This model was contributed to Hugging Face Transformers on 2025-09-25.
Parakeet
Overview
Parakeet models, introduced by NVIDIA NeMo, combine a Fast Conformer encoder with a connectionist temporal classification (CTC), recurrent neural network transducer (RNNT), or token-and-duration transducer (TDT) decoder for automatic speech recognition.
Model Architecture
- Fast Conformer Encoder: A linearly scalable Conformer architecture that processes mel-spectrogram features and reduces sequence length through subsampling. This is a more efficient version of the Conformer encoder found in FastSpeech2Conformer (see [ParakeetEncoder] for the encoder implementation and details).
- ParakeetForCTC: a Fast Conformer Encoder + a CTC decoder
  - CTC Decoder: Simple but effective decoder consisting of:
    - 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility).
    - CTC loss computation for training.
    - Greedy CTC decoding for inference.
- ParakeetForTDT: a Fast Conformer Encoder + a TDT (Token-and-Duration Transducer) decoder
  - TDT Decoder: Jointly predicts tokens and their durations, enabling efficient decoding:
    - LSTM prediction network maintains language context across token predictions.
    - Joint network combines encoder and decoder outputs.
    - Duration head predicts how many frames to skip, enabling fast inference.
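The two greedy decoding strategies named in the list above can be sketched in plain Python. This is a minimal toy illustration, not the Transformers or NeMo implementation: the blank index `BLANK_ID`, the function names, and the `joint_step` callable (a stand-in for the LSTM prediction network plus joint network) are all assumptions made for this example.

```python
from typing import Callable, List, Sequence, Tuple

BLANK_ID = 0  # assumption: blank token index chosen for this toy example


def ctc_greedy_decode(frame_ids: Sequence[int], blank_id: int = BLANK_ID) -> List[int]:
    """Greedy CTC: collapse consecutive repeats, then drop blank tokens."""
    decoded = []
    previous = None
    for token_id in frame_ids:
        if token_id != previous and token_id != blank_id:
            decoded.append(token_id)
        previous = token_id
    return decoded


def tdt_greedy_decode(
    num_frames: int,
    joint_step: Callable[[int, List[int]], Tuple[int, int]],
    blank_id: int = BLANK_ID,
    max_symbols_per_frame: int = 10,
) -> List[Tuple[int, int, int]]:
    """Greedy TDT loop returning (token_id, start_frame, duration) triples.

    `joint_step(t, history)` is hypothetical: it returns the argmax token and
    the argmax duration for encoder frame `t` given the tokens emitted so far.
    """
    emitted: List[Tuple[int, int, int]] = []
    history: List[int] = []
    t = 0
    symbols_at_t = 0
    while t < num_frames:
        token_id, duration = joint_step(t, history)
        if token_id != blank_id:
            emitted.append((token_id, t, duration))
            history.append(token_id)
            symbols_at_t += 1
        # Guard against stalling: a blank with duration 0, or too many symbols
        # emitted at the same frame, forces the loop to advance by one frame.
        if duration == 0 and (token_id == blank_id or symbols_at_t >= max_symbols_per_frame):
            duration = 1
        if duration > 0:
            t += duration  # skip ahead by the predicted duration
            symbols_at_t = 0
    return emitted


# Per-frame argmax ids from a CTC head: repeats collapse, blanks disappear.
print(ctc_greedy_decode([0, 3, 3, 0, 3, 5, 5, 0]))  # [3, 3, 5]

# Scripted (token, duration) predictions standing in for a real TDT decoder.
script = {0: (7, 2), 2: (BLANK_ID, 3), 5: (8, 1)}
print(tdt_greedy_decode(6, lambda t, history: script.get(t, (BLANK_ID, 1))))
# [(7, 0, 2), (8, 5, 1)]
```

In the toy TDT run only three of the six frames are visited, which is the frame skipping that the duration head makes possible.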
The original implementation can be found in NVIDIA NeMo. Model checkpoints are available under the NVIDIA organization.
This model was contributed by Nithin Rao Koluguri, Eustache Le Bihan, Eric Bezzam, Maksym Lypivskyi, and Hainan Xu.
Usage
ParakeetForCTC usage
from transformers import pipeline
pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
# {'text': 'yesterday it was thirty five degrees in barcelona but today the temperature will go down to minus twenty degrees'}
from datasets import Audio, load_dataset
from transformers import AutoModelForCTC, AutoProcessor
model_id = "nvidia/parakeet-ctc-1.1b"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id, device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.decode(outputs))
ParakeetForTDT usage
Parakeet TDT transcripts include casing and punctuation, and the model can also return token-level timestamps.
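As a rough sketch of how token timestamps relate to predicted durations: the timestamped example later in this section reports times on a 0.08 s grid, which suggests each encoder frame covers about 0.08 s. The helper below assumes that frame length and assumes a token ends at its start frame plus its duration. The function name, the `FRAME_SECONDS` constant, and the rounding are illustrative choices, not the processor's actual code.

```python
from typing import Dict, List, Sequence, Union

FRAME_SECONDS = 0.08  # assumption: inferred from the 0.08 s grid in the example output


def frames_to_timestamps(
    tokens: Sequence[str],
    start_frames: Sequence[int],
    durations: Sequence[int],
    frame_seconds: float = FRAME_SECONDS,
) -> List[Dict[str, Union[str, float]]]:
    """Convert per-token start frames and durations (in frames) to seconds."""
    return [
        {
            "token": token,
            "start": round(start * frame_seconds, 2),
            "end": round((start + duration) * frame_seconds, 2),
        }
        for token, start, duration in zip(tokens, start_frames, durations)
    ]


# Hypothetical frame indices chosen to reproduce the first tokens of the example.
for entry in frames_to_timestamps(["m", "ister", "Qu"], [3, 6, 8], [3, 2, 3]):
    print(entry)
```

A token with duration 0, such as punctuation, gets identical start and end times under this scheme.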
from transformers import pipeline
pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-tdt-0.6b-v3")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
# {'text': 'Yesterday it was 35 degrees in Barcelona, but today the temperature will go down to minus 20 degrees.'}
from transformers import AutoModelForTDT, AutoProcessor
from datasets import load_dataset, Audio
model_id = "nvidia/parakeet-tdt-0.6b-v3"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
output = model.generate(**inputs, return_dict_in_generate=True)
print(processor.decode(output.sequences, skip_special_tokens=True))
from datasets import Audio, load_dataset
from transformers import AutoModelForTDT, AutoProcessor
model_id = "nvidia/parakeet-tdt-0.6b-v3"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype="auto", device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:1]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
output = model.generate(**inputs, return_dict_in_generate=True)
decoded_output, decoded_timestamps = processor.decode(
    output.sequences,
    durations=output.durations,
    skip_special_tokens=True,
)
print("Transcription:", decoded_output)
print("\nTimestamped tokens:", decoded_timestamps)
"""
Transcription: ['mister Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']
Timestamped tokens: [[{'token': 'm', 'start': 0.24, 'end': 0.48}, {'token': 'ister', 'start': 0.48, 'end': 0.64}, {'token': 'Qu', 'start': 0.64, 'end': 0.88}, {'token': 'il', 'start': 0.88, 'end': 1.12}, {'token': 'ter', 'start': 1.12, 'end': 1.36}, {'token': 'is', 'start': 1.36, 'end': 1.44}, {'token': 'the', 'start': 1.44, 'end': 1.6}, {'token': 'ap', 'start': 1.6, 'end': 1.76}, {'token': 'ost', 'start': 1.76, 'end': 1.92}, {'token': 'le', 'start': 2.0, 'end': 2.16}, {'token': 'of', 'start': 2.16, 'end': 2.24}, {'token': 'the', 'start': 2.24, 'end': 2.4}, {'token': 'mid', 'start': 2.4, 'end': 2.48}, {'token': 'd', 'start': 2.48, 'end': 2.56}, {'token': 'le', 'start': 2.56, 'end': 2.64}, {'token': 'clas', 'start': 2.72, 'end': 2.88}, {'token': 's', 'start': 2.88, 'end': 3.04}, {'token': 'es', 'start': 3.04, 'end': 3.12}, {'token': ',', 'start': 3.12, 'end': 3.12}, {'token': 'and', 'start': 3.2800000000000002, 'end': 3.44}, {'token': 'we', 'start': 3.44, 'end': 3.6}, {'token': 'are', 'start': 3.6, 'end': 3.7600000000000002}, {'token': 'gl', 'start': 3.7600000000000002, 'end': 3.92}, {'token': 'ad', 'start': 3.92, 'end': 4.08}, {'token': 'to', 'start': 4.08, 'end': 4.24}, {'token': 'wel', 'start': 4.24, 'end': 4.4}, {'token': 'c', 'start': 4.4, 'end': 4.48}, {'token': 'ome', 'start': 4.48, 'end': 4.72}, {'token': 'his', 'start': 4.72, 'end': 4.96}, {'token': 'gos', 'start': 4.96, 'end': 5.12}, {'token': 'pel', 'start': 5.36, 'end': 5.6000000000000005}, {'token': '.', 'start': 5.6000000000000005, 'end': 5.6000000000000005}]]
"""
Making The Model Go Brrr
Parakeet supports full-graph compilation with CUDA graphs! This optimization is most effective when you know the maximum audio length you want to transcribe. The key idea is using static input shapes to avoid recompilation. For example, if you know your audio will be under 30 seconds, you can use the processor to pad all inputs to 30 seconds, preparing consistent input features and attention masks. See the example below!
import torch
from datasets import Audio, load_dataset
from transformers import AutoModelForCTC, AutoProcessor
processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", device_map="auto")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
# Compile the generate method with fullgraph and CUDA graphs
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
# let's define processor kwargs to pad to 30 seconds
processor_kwargs = {
    "padding": "max_length",
    "max_length": 30 * processor.feature_extractor.sampling_rate,
}
# Define a timing context using CUDA events
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None
    def __enter__(self):
        # Use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self
    def __exit__(self, *args):
        self.end_event.record()
        # Wait for the GPU to finish so the elapsed time is valid
        torch.cuda.synchronize()
        # elapsed_time() returns milliseconds; convert to seconds
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")
inputs = processor(speech_samples[0], **processor_kwargs)
inputs.to(model.device, dtype=model.dtype)
print("\n" + "="*50)
print("First generation - compiling...")
# Generate with the compiled model
with TimerContext("First generation"):
    outputs = model.generate(**inputs)
print(processor.decode(outputs))
inputs = processor(speech_samples[1], **processor_kwargs)
inputs.to(model.device, dtype=model.dtype)
print("\n" + "="*50)
print("Second generation - recording CUDA graphs...")
with TimerContext("Second generation"):
    outputs = model.generate(**inputs)
print(processor.decode(outputs))
inputs = processor(speech_samples[2], **processor_kwargs)
inputs.to(model.device, dtype=model.dtype)
print("\n" + "="*50)
print("Third generation - fast !!!")
with TimerContext("Third generation"):
    outputs = model.generate(**inputs)
print(processor.decode(outputs))
inputs = processor(speech_samples[3], **processor_kwargs)
inputs.to(model.device, dtype=model.dtype)
print("\n" + "="*50)
print("Fourth generation - still fast !!!")
with TimerContext("Fourth generation"):
    outputs = model.generate(**inputs)
print(processor.decode(outputs))
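As a model-free illustration of the static-shape idea described in this section, the sketch below pads every waveform to the same number of samples and builds a matching attention mask, which is what keeps input shapes constant across calls. The helper name `pad_to_fixed_length` and the 16 kHz sampling rate are assumptions made for this example; in practice the processor handles this through `padding="max_length"` and `max_length`.

```python
from typing import List, Sequence, Tuple

SAMPLING_RATE = 16_000  # assumption for this sketch; read it from the feature extractor in practice
MAX_SECONDS = 30


def pad_to_fixed_length(
    waveform: Sequence[float],
    max_length: int = MAX_SECONDS * SAMPLING_RATE,
) -> Tuple[List[float], List[int]]:
    """Right-pad a waveform with zeros and return it with a 1/0 attention mask."""
    num_samples = len(waveform)
    if num_samples > max_length:
        raise ValueError(f"waveform has {num_samples} samples, more than max_length={max_length}")
    num_padding = max_length - num_samples
    padded = list(waveform) + [0.0] * num_padding
    attention_mask = [1] * num_samples + [0] * num_padding
    return padded, attention_mask


# Clips of different lengths end up with identical shapes after padding.
short_clip = [0.1] * (2 * SAMPLING_RATE)
long_clip = [0.1] * (12 * SAMPLING_RATE)
for clip in (short_clip, long_clip):
    padded, mask = pad_to_fixed_length(clip)
    print(len(padded), sum(mask))
# 480000 32000
# 480000 192000
```

Because the padded length never changes, a compiled function sees the same input shape on every call; only the attention mask tells it how much of the input is real audio.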
CTC Training
import torch
from datasets import Audio, load_dataset
from transformers import AutoModelForCTC, AutoProcessor
model_id = "nvidia/parakeet-ctc-1.1b"
NUM_SAMPLES = 5
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
model.train()
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]]
text_samples = ds["text"][:NUM_SAMPLES]
# passing `text` to the processor will prepare inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model(**inputs)
print("Loss:", outputs.loss.item())
outputs.loss.backward()
TDT Training
from datasets import Audio, load_dataset
import torch
from transformers import AutoModelForTDT, AutoProcessor
model_id = "nvidia/parakeet-tdt-0.6b-v3"
NUM_SAMPLES = 4
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForTDT.from_pretrained(model_id, dtype=torch.bfloat16, device_map="auto")
model.train()
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:NUM_SAMPLES]]
text_samples = ds["text"][:NUM_SAMPLES]
# passing `text` to the processor will prepare inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model(**inputs)
print("Loss:", outputs.loss.item())
outputs.loss.backward()
ParakeetTokenizer
autodoc ParakeetTokenizer
ParakeetFeatureExtractor
autodoc ParakeetFeatureExtractor - __call__
ParakeetProcessor
autodoc ParakeetProcessor - __call__ - decode
ParakeetEncoderConfig
autodoc ParakeetEncoderConfig
ParakeetCTCConfig
autodoc ParakeetCTCConfig
ParakeetTDTConfig
autodoc ParakeetTDTConfig
ParakeetEncoder
autodoc ParakeetEncoder
ParakeetForCTC
autodoc ParakeetForCTC
ParakeetForTDT
autodoc ParakeetForTDT