Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
201 lines
7.4 KiB
Python
201 lines
7.4 KiB
Python
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import unittest
|
|
|
|
import pytest
|
|
|
|
from transformers import is_torch_available
|
|
from transformers.testing_utils import cleanup, require_torch, require_torch_accelerator, slow, torch_device
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import AfmoeForCausalLM, AfmoeModel, AutoTokenizer
|
|
from transformers.conversion_mapping import get_model_conversion_mapping
|
|
|
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
|
|
|
|
|
class AfmoeModelTester(CausalLMModelTester):
    """Model tester that builds a tiny Afmoe setup for the shared causal-LM tests.

    Generic causal-LM settings (batch/sequence sizes, vocabulary, layer counts,
    attention heads, activation, initializer range) are forwarded to
    ``CausalLMModelTester.__init__``. The Afmoe-specific attention,
    normalization and mixture-of-experts settings are then stored as
    attributes on the tester instance.
    """

    if is_torch_available():
        base_model_class = AfmoeModel

    def __init__(
        self,
        parent,
        batch_size=4,
        seq_length=12,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=64,
        hidden_size=32,
        intermediate_size=16,
        moe_intermediate_size=16,
        num_hidden_layers=2,
        num_dense_layers=1,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=128,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=False,
        rope_theta=10000.0,
        rope_parameters=None,
        num_experts=4,
        num_experts_per_tok=2,
        num_shared_experts=2,
        route_norm=True,
        route_scale=1.0,
        global_attn_every_n_layers=2,
        sliding_window=128,
        attention_dropout=0.0,
    ):
        # Settings handled by the generic causal-LM tester.
        super().__init__(
            parent=parent,
            batch_size=batch_size,
            seq_length=seq_length,
            is_training=is_training,
            use_input_mask=use_input_mask,
            use_token_type_ids=use_token_type_ids,
            use_labels=use_labels,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act=hidden_act,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
        )

        # Attention layout. Going by the parameter names: explicit per-head size,
        # sliding-window length, and the spacing of global-attention layers.
        self.head_dim = head_dim
        self.attention_dropout = attention_dropout
        self.sliding_window = sliding_window
        self.global_attn_every_n_layers = global_attn_every_n_layers

        # Normalization, rotary embedding base and KV-cache switch.
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.use_cache = use_cache
        # NOTE(review): `rope_parameters` is accepted but never stored on the
        # tester, so passing it has no effect here; confirm this is intentional.

        # Mixture-of-experts layout and routing.
        self.num_dense_layers = num_dense_layers
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.num_shared_experts = num_shared_experts
        self.route_norm = route_norm
        self.route_scale = route_scale
|
|
|
|
|
|
@require_torch
class AfmoeModelTest(CausalLMModelTest, unittest.TestCase):
    """Shared causal-LM test-suite for Afmoe, plus Afmoe-specific MoE checks."""

    model_tester_class = AfmoeModelTester
    if is_torch_available():
        all_model_classes = (AfmoeModel, AfmoeForCausalLM)
        pipeline_model_mapping = {"feature-extraction": AfmoeModel, "text-generation": AfmoeForCausalLM}
    else:
        all_model_classes = ()
        pipeline_model_mapping = {}

    @unittest.skip("Afmoe applies key/query norm which doesn't work with packing")
    def test_eager_padding_matches_padding_free_with_position_ids(self):
        pass

    @unittest.skip("Afmoe applies key/query norm which doesn't work with packing")
    def test_sdpa_padding_matches_padding_free_with_position_ids(self):
        pass

    # NOTE(review): this skip reuses the packing reason from the two tests above,
    # which does not obviously explain skipping a RoPE-scaling test; confirm.
    @unittest.skip("Afmoe applies key/query norm which doesn't work with packing")
    def test_model_rope_scaling_frequencies(self):
        pass

    @unittest.skip("Afmoe has moe, output can be different")
    def test_model_outputs_equivalence(self, **kwargs):
        pass

    def test_router_logits_without_aux_loss(self):
        """Request router logits with ``num_dense_layers = 0`` and check the output.

        The causal-LM output must expose router logits whose last dimension is
        ``config.num_experts``, and its ``aux_loss`` must be ``None``.
        """
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
        # Presumably makes every decoder layer an MoE layer, so each one is routed.
        config.num_dense_layers = 0
        config.output_router_logits = True

        token_ids = inputs["input_ids"]
        # Positions whose token id equals 1 are masked out of attention.
        mask = token_ids.ne(1).to(torch_device)

        model = AfmoeForCausalLM(config).to(torch_device)
        model.eval()
        outputs = model(token_ids, attention_mask=mask)

        self.assertIsNotNone(outputs.router_logits)
        self.assertEqual(outputs.router_logits[0].shape[-1], config.num_experts)
        self.assertIsNone(outputs.aux_loss)

    def test_moe_legacy_conversion_mapping_registered(self):
        """Check that AfmoeModel registers the fused-expert weight converter.

        The converter maps per-expert ``gate_proj`` weights onto the fused
        ``gate_up_proj`` expert tensor.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        conversion_mapping = get_model_conversion_mapping(AfmoeModel(config))

        has_fused_expert_converter = False
        for entry in conversion_mapping:
            if (
                "mlp.experts.*.gate_proj.weight" in entry.source_patterns
                and "mlp.experts.gate_up_proj" in entry.target_patterns
            ):
                has_fused_expert_converter = True
                break

        self.assertTrue(has_fused_expert_converter)
|
|
|
|
|
|
@require_torch_accelerator
@slow
class AfmoeIntegrationTest(unittest.TestCase):
    """Slow end-to-end generation checks on the ``arcee-ai/trinity-nano-preview`` checkpoint."""

    def tearDown(self):
        # Run the shared `cleanup` helper for `torch_device` after each test, without gc.
        # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed.
        cleanup(torch_device, gc_collect=False)

    @slow
    @require_torch_accelerator
    @pytest.mark.torch_compile_test
    def test_compile_static_cache(self):
        """Compare greedy batched generations across three cache/compile setups.

        The setups are the default cache, ``cache_implementation="static"``,
        and a static cache with ``torch.compile``-d ``forward``. All three
        generations must decode to identical text.
        """
        checkpoint = "arcee-ai/trinity-nano-preview"
        num_tokens_to_generate = 24
        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]

        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        if tokenizer.pad_token_id is None:
            # Batched prompts are padded, so a pad token is required; reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token

        model = AfmoeForCausalLM.from_pretrained(checkpoint, device_map=torch_device, dtype=torch.bfloat16)
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        greedy_kwargs = {"max_new_tokens": num_tokens_to_generate, "do_sample": False}
        static_greedy_kwargs = {**greedy_kwargs, "cache_implementation": "static"}

        def generate_text(**generate_kwargs):
            # Looks up `model` at call time, so it picks up the compiled forward below.
            output_ids = model.generate(**inputs, **generate_kwargs)
            return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        dynamic_text = generate_text(**greedy_kwargs)

        static_text = generate_text(**static_greedy_kwargs)
        self.assertEqual(dynamic_text, static_text)

        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        static_compiled_text = generate_text(**static_greedy_kwargs)
        self.assertEqual(dynamic_text, static_compiled_text)
|