# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Parakeet model."""

import json
import tempfile
import unittest
from pathlib import Path

from transformers import is_datasets_available, is_torch_available
from transformers.testing_utils import cleanup, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_datasets_available():
    from datasets import Audio, load_dataset

if is_torch_available():
    import torch

    from transformers import (
        AutoProcessor,
        ParakeetCTCConfig,
        ParakeetEncoder,
        ParakeetEncoderConfig,
        ParakeetForCTC,
        ParakeetForTDT,
        ParakeetTDTConfig,
    )
    from transformers.loss.loss_tdt import tdt_loss


@require_torch
class TDTLossTest(unittest.TestCase):
    """Test tdt_loss against reference values generated by NeMo's TDTLossPytorch.

    reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-generate_tdt_loss_fixtures-py
    """

    FIXTURE_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_tdt_loss.json"

    @classmethod
    def setUpClass(cls):
        # Load the reference fixture once; every test below reads from it.
        with open(cls.FIXTURE_PATH, encoding="utf-8") as f:
            cls.fixture = json.load(f)

    def _make_inputs(self):
        """Rebuild the exact random inputs used to generate the fixture.

        Returns a dict of keyword arguments for `tdt_loss`. The seed and the order of the random draws
        (`randn` for the logits, then `randint` for the targets) must stay unchanged, or the inputs
        will no longer match the stored expected losses.
        """
        torch.manual_seed(self.fixture["seed"])
        batch_size = self.fixture["batch_size"]
        max_t = self.fixture["max_t"]
        max_u = self.fixture["max_u"]
        vocab_size = self.fixture["vocab_size"]
        num_durations = len(self.fixture["durations"])
        # The blank token is placed right after the regular vocabulary.
        blank_token_id = vocab_size

        # One joint tensor, split below into token logits (vocab + blank) and duration logits.
        combined_logits = torch.randn(batch_size, max_t, max_u + 1, vocab_size + 1 + num_durations)
        targets = torch.randint(0, vocab_size, (batch_size, max_u))
        logit_lengths = torch.tensor(self.fixture["logit_lengths"])
        target_lengths = torch.tensor(self.fixture["target_lengths"])

        return {
            "token_logits": combined_logits[..., : vocab_size + 1],
            "duration_logits": combined_logits[..., vocab_size + 1 :],
            "targets": targets,
            "logit_lengths": logit_lengths,
            "target_lengths": target_lengths,
            "blank_token_id": blank_token_id,
            "durations": self.fixture["durations"],
        }

    def test_tdt_loss_sum(self):
        inputs = self._make_inputs()
        loss = tdt_loss(**inputs, reduction="sum")
        expected = torch.tensor(self.fixture["expected_loss_sum"])
        torch.testing.assert_close(loss, expected)

    def test_tdt_loss_mean(self):
        inputs = self._make_inputs()
        loss = tdt_loss(**inputs, reduction="mean")
        expected = torch.tensor(self.fixture["expected_loss_mean"])
        torch.testing.assert_close(loss, expected)

    def test_tdt_loss_none(self):
        inputs = self._make_inputs()
        losses = tdt_loss(**inputs, reduction="none")
        expected = torch.tensor(self.fixture["expected_loss_none"])
        torch.testing.assert_close(losses, expected)

    def test_tdt_loss_with_sigma(self):
        inputs = self._make_inputs()
        loss_no_sigma = tdt_loss(**inputs, sigma=0.0, reduction="mean")
        loss_with_sigma = tdt_loss(**inputs, sigma=0.05, reduction="mean")
        # A non-zero sigma must change the loss, and for this fixture it is expected to increase it.
        self.assertFalse(torch.allclose(loss_no_sigma, loss_with_sigma))
        self.assertGreater(loss_with_sigma.item(), loss_no_sigma.item())
        expected = torch.tensor(self.fixture["expected_loss_mean_sigma_0p05"])
        torch.testing.assert_close(loss_with_sigma, expected)

    def test_tdt_loss_gradient_flows(self):
        inputs = self._make_inputs()
        inputs["token_logits"] = inputs["token_logits"].requires_grad_(True)
        inputs["duration_logits"] = inputs["duration_logits"].requires_grad_(True)
        loss = tdt_loss(**inputs, reduction="mean")
        loss.backward()
        # Both heads must receive gradients, and not only all-zero ones.
        self.assertIsNotNone(inputs["token_logits"].grad)
        self.assertIsNotNone(inputs["duration_logits"].grad)
        self.assertFalse(torch.all(inputs["token_logits"].grad == 0))
        self.assertFalse(torch.all(inputs["duration_logits"].grad == 0))


class ParakeetEncoderModelTester:
    """Build small `ParakeetEncoderConfig`s and random inputs for the encoder tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=1024,
        is_training=True,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        hidden_act="silu",
        dropout=0.0,  # so gradient checkpointing doesn't fail
        conv_kernel_size=9,
        subsampling_factor=8,
        subsampling_conv_channels=32,
        attention_bias=True,
        num_mel_bins=80,
        scale_input=True,
    ):
        # testing suite parameters
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.num_mel_bins = num_mel_bins
        self.is_training = is_training

        # config parameters
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.dropout = dropout
        self.conv_kernel_size = conv_kernel_size
        self.subsampling_factor = subsampling_factor
        self.subsampling_conv_channels = subsampling_conv_channels
        self.attention_bias = attention_bias
        self.scale_input = scale_input

        # Calculate output sequence length after subsampling
        self.output_seq_length = seq_length // subsampling_factor
        self.encoder_seq_length = self.output_seq_length
        self.key_length = self.output_seq_length

    def prepare_config_and_inputs(self):
        # input_features: (batch_size, seq_length, num_mel_bins); attention_mask: (batch_size, seq_length)
        input_features = floats_tensor([self.batch_size, self.seq_length, self.num_mel_bins])
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])
        config = self.get_config()

        return config, input_features, attention_mask

    def get_config(self):
        # A single `dropout` value drives every dropout-like knob, so the tests stay deterministic when it is 0.
        return ParakeetEncoderConfig(
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            dropout=self.dropout,
            dropout_positions=self.dropout,
            layerdrop=self.dropout,
            activation_dropout=self.dropout,
            attention_dropout=self.dropout,
            conv_kernel_size=self.conv_kernel_size,
            subsampling_factor=self.subsampling_factor,
            subsampling_conv_channels=self.subsampling_conv_channels,
            attention_bias=self.attention_bias,
            num_mel_bins=self.num_mel_bins,
            scale_input=self.scale_input,
        )

    def create_and_check_model(self, config, input_features, attention_mask):
        model = ParakeetEncoder(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features, attention_mask=attention_mask)

        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, config.hidden_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config, input_features, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_features": input_features,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def check_ctc_loss(self, config, input_values, *args):
        """Check that a `ParakeetForCTC` forward pass with labels returns float losses for both reductions."""
        # NOTE(review): this helper looks ported from a raw-waveform (batch, time) model. `prepare_config_and_inputs`
        # produces input_features of shape (batch, time, num_mel_bins), so `input_values.shape[-1]` below is the mel
        # dimension, not time. Also confirm that ParakeetForCTC provides `_get_feat_extract_output_lengths`. It is
        # only reached from `ParakeetForCTCModelTester.test_ctc_loss_inference` (on a non-TestCase helper class);
        # verify before wiring it into a collected test.
        model = ParakeetForCTC(config=config)
        model.to(torch_device)

        # make sure that dropout is disabled
        model.eval()

        input_values = input_values[:3]
        attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
        max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
        labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i] :] = 0.0
            attention_mask[i, input_lengths[i] :] = 0

        model.config.ctc_loss_reduction = "sum"
        sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()

        model.config.ctc_loss_reduction = "mean"
        mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()

        self.parent.assertTrue(isinstance(sum_loss, float))
        self.parent.assertTrue(isinstance(mean_loss, float))


@require_torch
class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ParakeetEncoder,) if is_torch_available() else ()

    test_resize_embeddings = False

    @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    def setUp(self):
        self.model_tester = ParakeetEncoderModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ParakeetEncoderConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass


class ParakeetForCTCModelTester:
    """Wrap a `ParakeetEncoderModelTester` and add the CTC head's config and checks."""

    def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0):
        if encoder_kwargs is None:
            encoder_kwargs = {}

        self.parent = parent
        self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs)
        self.is_training = is_training

        self.batch_size = self.encoder_model_tester.batch_size
        self.output_seq_length = self.encoder_model_tester.output_seq_length
        self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers
        # Hidden states come out of the subsampling encoder, so the mixin must see the subsampled length.
        # (Previously this was set to `vocab_size`; it only worked because 128 == 1024 // 8.)
        self.seq_length = self.encoder_model_tester.output_seq_length
        self.hidden_size = self.encoder_model_tester.hidden_size

        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

    def prepare_config_and_inputs(self):
        _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, input_features, attention_mask

    def get_config(self):
        return ParakeetCTCConfig(
            encoder_config=self.encoder_model_tester.get_config(),
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id,
        )

    def create_and_check_model(self, config, input_features, attention_mask):
        model = ParakeetForCTC(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features, attention_mask=attention_mask)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config, input_features, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_features": input_features,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def test_ctc_loss_inference(self):
        # NOTE(review): this method lives on a helper class, not a TestCase, so it is presumably never collected.
        # It previously referenced a nonexistent `self.model_tester`; it now uses this tester's own inputs.
        config_and_inputs = self.prepare_config_and_inputs()
        self.encoder_model_tester.check_ctc_loss(*config_and_inputs)


@require_torch
class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ParakeetForCTC,) if is_torch_available() else ()
    all_generative_model_classes = ()  # ParakeetForCTC has a custom generate method
    pipeline_model_mapping = (
        {
            "feature-extraction": ParakeetEncoder,
            "automatic-speech-recognition": ParakeetForCTC,
        }
        if is_torch_available()
        else {}
    )

    test_attention_outputs = False

    test_resize_embeddings = False

    _is_composite = True

    @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    def setUp(self):
        self.model_tester = ParakeetForCTCModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ParakeetCTCConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass

    # Original function assumes vision+text model, so overwrite since Parakeet is audio+text
    # Below is modified from `tests/models/granite_speech/test_modeling_granite_speech.py`
    def test_sdpa_can_dispatch_composite_models(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        if not self._is_composite:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

        for model_class in self.all_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_sdpa = model_class.from_pretrained(tmpdirname)
                model_sdpa = model_sdpa.eval().to(torch_device)

                model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
                model_eager = model_eager.eval().to(torch_device)
                self.assertTrue(model_eager.config._attn_implementation == "eager")

                # An explicitly eager-loaded model must not contain any SDPA attention modules.
                for _, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")


@require_torch
class ParakeetForCTCIntegrationTest(unittest.TestCase):
    _dataset = None

    # Kept as a per-test `setUp` (not `setUpClass`) so that the processor is presumably only fetched for tests
    # that actually run; all tests here are @slow.
    @classmethod
    def setUp(cls):
        cls.checkpoint_name = "nvidia/parakeet-ctc-1.1b"
        cls.dtype = torch.bfloat16
        cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @classmethod
    def _load_dataset(cls):
        # Lazily load the dummy LibriSpeech split once and resample it to the feature extractor's rate.
        if cls._dataset is None:
            cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
            cls._dataset = cls._dataset.cast_column(
                "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate)
            )

    def _load_datasamples(self, num_samples):
        """Return the audio arrays of the first `num_samples` examples, ordered by id."""
        self._load_dataset()
        ds = self._dataset
        speech_samples = ds.sort("id")[:num_samples]["audio"]
        return [x["array"] for x in speech_samples]

    @slow
    def test_1b_model_integration(self):
        """
        reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py
        """
        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json"
        with open(RESULTS_PATH, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"])
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]

        samples = self._load_datasamples(1)
        model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
        inputs = self.processor(samples)
        inputs.to(model.device, dtype=self.dtype)
        predicted_ids = model.generate(**inputs)
        torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
        predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True)
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)

    @slow
    def test_1b_model_integration_batched(self):
        """
        reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py
        """
        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch.json"
        with open(RESULTS_PATH, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"])
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]

        samples = self._load_datasamples(5)
        model = ParakeetForCTC.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
        inputs = self.processor(samples)
        inputs.to(model.device, dtype=self.dtype)
        predicted_ids = model.generate(**inputs)
        torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
        predicted_transcripts = self.processor.decode(predicted_ids, skip_special_tokens=True)
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)


class ParakeetForTDTModelTester:
    """Wrap a `ParakeetEncoderModelTester` and add the TDT decoder/joint config and checks."""

    def __init__(
        self,
        parent,
        encoder_kwargs=None,
        is_training=True,
        vocab_size=129,
        decoder_hidden_size=32,
        num_decoder_layers=1,
        durations=None,
        hidden_act="relu",
        max_symbols_per_step=5,
        pad_token_id=2,
    ):
        if encoder_kwargs is None:
            encoder_kwargs = {}
        # Avoid a shared mutable default; `None` means the standard duration set.
        if durations is None:
            durations = [0, 1, 2, 3, 4]

        self.parent = parent
        self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs)
        self.is_training = is_training

        self.batch_size = self.encoder_model_tester.batch_size
        self.output_seq_length = self.encoder_model_tester.output_seq_length
        self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers
        self.hidden_size = self.encoder_model_tester.hidden_size
        self.seq_length = self.encoder_model_tester.output_seq_length
        self.encoder_seq_length = self.encoder_model_tester.output_seq_length

        self.vocab_size = vocab_size
        self.decoder_hidden_size = decoder_hidden_size
        self.num_decoder_layers = num_decoder_layers
        self.durations = durations
        self.hidden_act = hidden_act
        self.max_symbols_per_step = max_symbols_per_step
        self.pad_token_id = pad_token_id
        # The last vocabulary id is used as the blank token.
        self.blank_token_id = vocab_size - 1

    def prepare_config_and_inputs(self):
        _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, input_features, attention_mask

    def get_config(self):
        return ParakeetTDTConfig(
            vocab_size=self.vocab_size,
            decoder_hidden_size=self.decoder_hidden_size,
            num_decoder_layers=self.num_decoder_layers,
            durations=self.durations,
            hidden_act=self.hidden_act,
            max_symbols_per_step=self.max_symbols_per_step,
            encoder_config=self.encoder_model_tester.get_config().to_dict(),
            pad_token_id=self.pad_token_id,
            blank_token_id=self.blank_token_id,
        )

    def create_and_check_model(self, config, inputs_dict):
        model = ParakeetForTDT(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(**inputs_dict)

        # Check encoder last hidden state
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.output_seq_length, self.encoder_model_tester.hidden_size),
        )

    def prepare_config_and_inputs_for_common(self):
        config, input_features, attention_mask = self.prepare_config_and_inputs()
        # A single decoder step per example is enough to drive the prediction network in the common tests.
        decoder_input_ids = ids_tensor([self.batch_size, 1], self.vocab_size)
        inputs_dict = {
            "input_features": input_features,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
        }
        return config, inputs_dict


@require_torch
class ParakeetForTDTModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ParakeetForTDT,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": ParakeetEncoder,
            "automatic-speech-recognition": ParakeetForTDT,
        }
        if is_torch_available()
        else {}
    )

    test_attention_outputs = False
    test_resize_embeddings = False
    test_torch_exportable = False

    _is_composite = True

    @unittest.skip(reason="No available flash-SDPA kernels for Parakeet test shapes on this setup")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    def setUp(self):
        self.model_tester = ParakeetForTDTModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ParakeetTDTConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="ParakeetForTDT does not use inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(
        reason="ParakeetForTDT is a transducer, not a standard encoder-decoder: no separate text config to set"
    )
    def test_attn_implementation_composite_models(self):
        pass

    @unittest.skip(
        reason="ParakeetForTDT is a transducer with an LSTM prediction network; "
        "it does not expose encoder_hidden_states in the standard encoder-decoder sense"
    )
    def test_hidden_states_output(self):
        pass

    @unittest.skip(
        reason="ParakeetForTDT is a transducer with an LSTM prediction network; "
        "it does not expose encoder_hidden_states in the standard encoder-decoder sense"
    )
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(
        reason="ParakeetForTDT has a custom generate() that is not fully compatible with GenerationTesterMixin"
    )
    def test_generation_tester_mixin_inheritance(self):
        pass

    @unittest.skip(reason="ParakeetForTDT is a flat composite model without a separate base_model sub-module")
    def test_model_base_model_prefix(self):
        pass

    @unittest.skip(reason="ParakeetForTDT decoder is an LSTM prediction network without attention")
    def test_flex_attention_with_grads(self):
        pass

    # Original function assumes vision+text model, so overwrite since Parakeet is audio+text
    def test_sdpa_can_dispatch_composite_models(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        if not self._is_composite:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

        for model_class in self.all_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_sdpa = model_class.from_pretrained(tmpdirname)
                model_sdpa = model_sdpa.eval().to(torch_device)

                model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
                model_eager = model_eager.eval().to(torch_device)
                self.assertTrue(model_eager.config._attn_implementation == "eager")

                # An explicitly eager-loaded model must not contain any SDPA attention modules.
                for _, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")


@require_torch
class ParakeetForTDTIntegrationTest(unittest.TestCase):
    _dataset = None

    # Kept as a per-test `setUp` (not `setUpClass`) so that the processor is presumably only fetched for tests
    # that actually run; all tests here are @slow.
    @classmethod
    def setUp(cls):
        cls.checkpoint_name = "nvidia/parakeet-tdt-0.6b-v3"
        cls.dtype = torch.bfloat16
        cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @classmethod
    def _load_dataset(cls):
        # Lazily load the dummy LibriSpeech split once and resample it to the feature extractor's rate.
        if cls._dataset is None:
            cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
            cls._dataset = cls._dataset.cast_column(
                "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate)
            )

    def _load_datasamples(self, num_samples):
        """Return the audio arrays of the first `num_samples` examples, ordered by id."""
        self._load_dataset()
        ds = self._dataset
        speech_samples = ds.sort("id")[:num_samples]["audio"]
        return [x["array"] for x in speech_samples]

    @slow
    def test_tdt_model_integration(self):
        """
        reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single_tdt-py
        """
        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single_tdt.json"
        with open(RESULTS_PATH, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]

        samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS))
        model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
        inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate)
        inputs.to(model.device, dtype=self.dtype)
        output = model.generate(**inputs, return_dict_in_generate=True)
        predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True)
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)

    @slow
    def test_tdt_model_integration_batched(self):
        """
        reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batch_tdt-py
        """
        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt.json"
        with open(RESULTS_PATH, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]

        samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS))
        model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=self.dtype, device_map="auto")
        inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate)
        inputs.to(model.device, dtype=self.dtype)
        output = model.generate(**inputs, return_dict_in_generate=True)
        predicted_transcripts = self.processor.decode(output.sequences, skip_special_tokens=True)
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)

    @slow
    def test_tdt_model_integration_timestamps(self):
        """
        reproducer: https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batch_tdt_timestamps-py
        """
        RESULTS_PATH = (
            Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch_tdt_timestamp.json"
        )
        with open(RESULTS_PATH, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
        EXPECTED_START_TIMESTAMPS = raw_data["start_timestamps"]
        EXPECTED_END_TIMESTAMPS = raw_data["end_timestamps"]

        # Use larger precision for testing token durations and timestamps
        samples = self._load_datasamples(len(EXPECTED_TRANSCRIPTIONS))
        model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto")
        inputs = self.processor(samples, sampling_rate=self.processor.feature_extractor.sampling_rate)
        inputs.to(model.device, dtype=model.dtype)
        output = model.generate(**inputs, return_dict_in_generate=True)
        predicted_transcripts, predicted_timestamps = self.processor.decode(
            output.sequences,
            durations=output.durations,
            skip_special_tokens=True,
        )
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)

        # Check timestamps and durations
        self.assertIsNotNone(output.durations, "durations should be returned")
        predicted_start_times = [[entry["start"] for entry in el] for el in predicted_timestamps]
        predicted_end_times = [[entry["end"] for entry in el] for el in predicted_timestamps]
        torch.testing.assert_close(predicted_start_times, EXPECTED_START_TIMESTAMPS)
        torch.testing.assert_close(predicted_end_times, EXPECTED_END_TIMESTAMPS)

    @slow
    def test_tdt_model_integration_loss(self):
        """
        Verify that ParakeetForTDT loss matches NeMo's TDT loss (sigma=0).

        reproducer: https://gist.github.com/883ea42bf7d8ce2af42f3055627476a7
        """
        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_loss_tdt.json"
        with open(RESULTS_PATH, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        EXPECTED_MEAN_LOSS = torch.tensor(raw_data["expected_mean_loss"])
        num_samples = raw_data["num_samples"]

        samples = self._load_datasamples(num_samples)
        # `_load_datasamples` has populated `self._dataset`, so the matching transcripts can be read from it.
        transcripts = self._dataset.sort("id")[:num_samples]["text"]
        transcripts = [t.lower() for t in transcripts]

        # Use float32 for loss precision
        model = ParakeetForTDT.from_pretrained(self.checkpoint_name, dtype=torch.float32, device_map="auto")
        inputs = self.processor(
            audio=samples,
            text=transcripts,
            sampling_rate=self.processor.feature_extractor.sampling_rate,
        )
        inputs.to(model.device)

        # Forward in eval mode — check loss matches NeMo
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs)
        self.assertIsNotNone(outputs.loss, "Loss must be computed when labels are provided")
        self.assertEqual(outputs.logits.dim(), 4, "Training logits must be 4D (B, T, U+1, V+D)")
        torch.testing.assert_close(outputs.loss.cpu(), EXPECTED_MEAN_LOSS, rtol=1e-3, atol=1e-3)

        # Backward — verify gradients flow
        del outputs
        # Device-agnostic cache release (previously hard-coded to torch.cuda).
        cleanup(torch_device, gc_collect=True)
        model.train()
        model.zero_grad()
        outputs = model(**inputs)
        outputs.loss.backward()
        n_with_grad = sum(1 for p in model.parameters() if p.grad is not None)
        self.assertGreater(n_with_grad, 0, "No gradients after backward")