# Copyright 2025 Meituan and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch LongcatFlash model."""

import copy
import tempfile
import unittest

from pytest import mark

from transformers import LongcatFlashConfig, is_torch_available
from transformers.testing_utils import (
    require_bitsandbytes,
    require_flash_attn,
    require_large_cpu_ram,
    require_torch,
    require_torch_accelerator,
    slow,
    torch_device,
)

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
from ...test_modeling_common import ids_tensor


if is_torch_available():
    import torch

    from transformers import AutoTokenizer, Cache, LongcatFlashForCausalLM, LongcatFlashModel


class LongcatFlashModelTester(CausalLMModelTester):
    """Build tiny LongcatFlash configs and random inputs for the shared causal-LM test suite."""

    if is_torch_available():
        base_model_class = LongcatFlashModel

    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=144,
        ffn_hidden_size=288,
        expert_ffn_hidden_size=48,
        num_layers=1,  # We have `self.num_hidden_layers = 2 * num_layers` in the body. See `LongcatFlashConfig`.
        num_attention_heads=8,
        num_key_value_heads=8,
        kv_lora_rank=16,
        q_lora_rank=48,
        qk_rope_head_dim=4,
        v_head_dim=8,
        qk_nope_head_dim=8,
        head_dim=4,
        n_routed_experts=4,
        zero_expert_num=2,
        moe_topk=2,
        routed_scaling_factor=1.0,
        hidden_act="silu",
        max_position_embeddings=128,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        bos_token_id=1,
        eos_token_id=2,
        pad_token_id=3,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
    ):
        """Store the tiny model hyper-parameters that `get_config` and the input builders read.

        Args:
            parent: the running `unittest.TestCase`, used for assertions.
            num_layers: value forwarded to `LongcatFlashConfig(num_layers=...)`; the tester
                exposes `num_hidden_layers = 2 * num_layers` to the shared test suite.
            Remaining arguments are copied verbatim onto attributes of the same name.
        """
        super().__init__(parent)
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.expert_ffn_hidden_size = expert_ffn_hidden_size
        self.num_layers = num_layers
        self.num_hidden_layers = 2 * num_layers  # for compatibility
        # NOTE(review): the trailing comment says "embedding + 2 layers" (which would be 3) but the
        # value is 2 -- confirm which count the shared hidden-states test expects.
        self.expected_num_hidden_layers = 2  # embedding + 2 layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.head_dim = head_dim
        self.n_routed_experts = n_routed_experts
        self.zero_expert_num = zero_expert_num
        self.moe_topk = moe_topk
        self.routed_scaling_factor = routed_scaling_factor
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices

    def get_config(self):
        """Return a `LongcatFlashConfig` built from the tester's hyper-parameters."""
        return LongcatFlashConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
            expert_ffn_hidden_size=self.expert_ffn_hidden_size,
            num_layers=self.num_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            kv_lora_rank=self.kv_lora_rank,
            q_lora_rank=self.q_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            qk_nope_head_dim=self.qk_nope_head_dim,
            head_dim=self.head_dim,
            n_routed_experts=self.n_routed_experts,
            zero_expert_num=self.zero_expert_num,
            moe_topk=self.moe_topk,
            routed_scaling_factor=self.routed_scaling_factor,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            rms_norm_eps=self.rms_norm_eps,
            pad_token_id=self.pad_token_id,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run the base model in eval mode and check the `last_hidden_state` shape."""
        model = LongcatFlashModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        """Run the causal-LM head with labels and check the logits shape.

        `encoder_hidden_states` / `encoder_attention_mask` are accepted for signature
        compatibility but are not used by this decoder-only model.
        """
        model = LongcatFlashForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)`.

        `token_type_ids` is always None; the mask and labels are None when the corresponding
        `use_input_mask` / `use_labels` flag is False.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            # Lower-triangular mask: row `b` keeps only its first `b + 1` positions.
            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

        token_type_ids = None

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` has `input_ids` and `attention_mask`.

        `attention_mask` may be None when `use_input_mask` is False.
        """
        config, input_ids, _, input_mask, *_ = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class LongcatFlashModelTest(CausalLMModelTest, unittest.TestCase):
    """Shared causal-LM model tests, with LongcatFlash-specific overrides and skips."""

    model_split_percents = [0.5, 0.8]

    model_tester_class = LongcatFlashModelTester

    @unittest.skip("LongcatFlash buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("LongcatFlash buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_to_base(self):
        pass

    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
        """Check every cache layer's key/value shapes against the split (nope + rope) key width.

        Keys are expected to be `qk_nope_head_dim + qk_rope_head_dim` wide and values
        `v_head_dim` wide, for each of `config.num_hidden_layers` cache layers.
        """
        self.assertIsInstance(past_key_values, Cache)

        k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        v_embed_dim = config.v_head_dim

        expected_key_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
        expected_value_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)

        for layer_idx in range(config.num_hidden_layers):
            self.assertEqual(past_key_values.layers[layer_idx].keys.shape, expected_key_shape)
            self.assertEqual(past_key_values.layers[layer_idx].values.shape, expected_value_shape)

    @unittest.skip("LongcatFlash router uses weight.type() directly in forward which prevents offloading")
    def test_cpu_offload(self):
        pass

    @unittest.skip("LongcatFlash router uses weight.type() directly in forward which prevents offloading")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip("LongcatFlash router uses weight.type() directly in forward which prevents offloading")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip("Most probably because of the MOE, the moe and router does not ignore padding tokens")
    def test_eager_padding_matches_padding_free_with_position_ids(self):
        pass

    @unittest.skip(reason="SDPA can't dispatch on flash due to unsupported head dims")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @staticmethod
    def _prepare_config_headdim(config, requested_dim):
        """Return a deep copy of `config` resized so attention can run with head dim `requested_dim`.

        Not referenced in this file; presumably a hook called by the shared test suite -- confirm.
        The input `config` is left untouched.
        """
        # there's specific head dims due to lora compressions in longcat
        config = copy.deepcopy(config)
        config.attention_dropout = 0

        if requested_dim > config.qk_rope_head_dim:
            config.qk_rope_head_dim = requested_dim
        config.qk_nope_head_dim = max(config.qk_nope_head_dim, requested_dim)
        config.v_head_dim = max(config.v_head_dim, requested_dim)
        config.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        config.head_dim = requested_dim
        # Keep the LoRA bottlenecks and hidden size large enough for the widened heads.
        config.q_lora_rank = max(config.q_lora_rank, requested_dim * 4)
        config.kv_lora_rank = max(config.kv_lora_rank, requested_dim * 2)
        config.hidden_size = max(config.hidden_size, config.num_attention_heads * requested_dim)
        return config

    @require_flash_attn
    @require_torch_accelerator
    # NOTE(review): the body below does not quantize the model; this requirement may be
    # unnecessary -- confirm before removing.
    @require_bitsandbytes
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_fp32_ln(self):
        """Reload each generative model in fp16 with flash-attention-2 and run forward passes."""
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        for model_class in self.all_generative_model_classes:  # TODO: this test should run on all classes instead
            if not model_class._supports_flash_attn:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_input = inputs_dict[model.main_input_name]
                # `prepare_config_and_inputs_for_common` always stores the "attention_mask" key,
                # with a None value when masks are disabled, so `.get(key, default)` alone would
                # hand back None and `.shape` below would fail.
                dummy_attention_mask = inputs_dict.get("attention_mask")
                if dummy_attention_mask is None:
                    dummy_attention_mask = torch.ones_like(dummy_input)
                batch_size = dummy_attention_mask.shape[0]

                is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size

                # To avoid errors with padding_side=="right"
                if is_padding_right:
                    dummy_attention_mask = torch.ones_like(dummy_input)

                model = model_class.from_pretrained(
                    tmpdirname,
                    dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",  # small change to ensure device placement
                )

                # no upcasting at all

                if model.config.is_encoder_decoder:
                    dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
                    dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]

                    _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
                    # with attention mask
                    _ = model(
                        dummy_input,
                        attention_mask=dummy_attention_mask,
                        decoder_input_ids=dummy_decoder_input_ids,
                        decoder_attention_mask=dummy_decoder_attention_mask,
                    )
                else:
                    _ = model(dummy_input)
                    # with attention mask
                    _ = model(dummy_input, attention_mask=dummy_attention_mask)


@slow
class LongcatFlashIntegrationTest(unittest.TestCase):
    """Greedy-generation regression tests against published LongCat checkpoints."""

    short_model_id = "hf-internal-testing/LongCat-ShortCat"
    # This is a cut-down model that matches part of the early logits of the larger one
    # Only a couple experts + layers
    # But if it fails, it means the larger model might have issues as well

    model_id = "meituan-longcat/LongCat-Flash-Chat"

    @slow
    def test_shortcat_generation(self):
        """Greedy-decode 10 new tokens with the cut-down checkpoint and compare the exact text."""
        # Kept as locals (like `test_longcat_generation_cpu`) rather than attributes on the
        # test instance, so the checkpoint is not pinned on the TestCase object.
        model = LongcatFlashForCausalLM.from_pretrained(
            self.short_model_id,
            device_map="auto",
            dtype=torch.bfloat16,
        )
        model.generation_config.bos_token_id = 1
        model.generation_config.pad_token_id = 3
        model.generation_config.eos_token_id = 2
        # The cut-down checkpoint is paired with the full model's tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        chat = [{"role": "user", "content": "Paris is..."}]
        inputs = tokenizer.apply_chat_template(
            chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(inputs["input_ids"], max_new_tokens=10, do_sample=False)

        response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
        expected_output = "[Round 0] USER:Paris is... ASSISTANT: dig年车龄juanaheast稍achaotingupebarebones"

        self.assertEqual(response, expected_output)

    @slow
    @require_large_cpu_ram
    def test_longcat_generation_cpu(self):
        """Greedy-decode 3 new tokens with the full checkpoint and compare the exact text."""
        # takes absolutely forever and a lot RAM, but allows to test the output in the CI
        model = LongcatFlashForCausalLM.from_pretrained(self.model_id, device_map="auto", dtype=torch.bfloat16)
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        chat = [{"role": "user", "content": "Paris is..."}]
        inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")

        with torch.no_grad():
            outputs = model.generate(inputs["input_ids"], max_new_tokens=3, do_sample=False)

        response = tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
        expected_output = "[Round 0] USER:Paris is... ASSISTANT:Paris is..."

        self.assertEqual(response, expected_output)