first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
0
tests/models/cwm/__init__.py
Normal file
0
tests/models/cwm/__init__.py
Normal file
128
tests/models/cwm/test_configuration_cwm.py
Normal file
128
tests/models/cwm/test_configuration_cwm.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
from transformers.models.cwm import CwmConfig
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
|
||||
|
||||
class CwmConfigTest(unittest.TestCase):
|
||||
def test_default_config(self):
|
||||
"""Test default CWM configuration"""
|
||||
config = CwmConfig()
|
||||
|
||||
# CWM defaults
|
||||
self.assertEqual(config.sliding_window, 8192)
|
||||
self.assertIsInstance(config.layer_types, list)
|
||||
|
||||
# Llama3 defaults
|
||||
self.assertEqual(config.vocab_size, 128256)
|
||||
self.assertIsNotNone(config.rope_parameters)
|
||||
self.assertEqual(config.rope_parameters["rope_type"], "llama3")
|
||||
|
||||
def test_custom_sliding_window_config(self):
|
||||
config = CwmConfig(sliding_window=4096)
|
||||
|
||||
self.assertEqual(config.sliding_window, 4096)
|
||||
|
||||
def test_custom_layer_types_config(self):
|
||||
layer_types = ["full_attention", "sliding_attention", "sliding_attention", "full_attention"]
|
||||
config = CwmConfig(num_hidden_layers=4, layer_types=layer_types)
|
||||
|
||||
self.assertEqual(config.layer_types, layer_types)
|
||||
self.assertEqual(len(config.layer_types), config.num_hidden_layers)
|
||||
|
||||
def test_invalid_layer_types_length(self):
|
||||
with self.assertRaises(huggingface_hub.errors.StrictDataclassClassValidationError):
|
||||
CwmConfig(
|
||||
num_hidden_layers=4,
|
||||
layer_types=["full_attention", "sliding_attention"], # Only 2 types for 4 layers
|
||||
)
|
||||
|
||||
def test_invalid_layer_type_value(self):
|
||||
with self.assertRaises(huggingface_hub.errors.StrictDataclassClassValidationError):
|
||||
CwmConfig(num_hidden_layers=2, layer_types=["full_attention", "invalid_attention"])
|
||||
|
||||
def test_automatic_layer_types_generation(self):
|
||||
# Test default pattern (every 4th layer uses full attention)
|
||||
config = CwmConfig(num_hidden_layers=8)
|
||||
|
||||
expected_types = [
|
||||
"full_attention", # layer 0: 0 % 4 == 0
|
||||
"sliding_attention", # layer 1: 1 % 4 != 0
|
||||
"sliding_attention", # layer 2: 2 % 4 != 0
|
||||
"sliding_attention", # layer 3: 3 % 4 != 0
|
||||
"full_attention", # layer 4: 4 % 4 == 0
|
||||
"sliding_attention", # layer 5: 5 % 4 != 0
|
||||
"sliding_attention", # layer 6: 6 % 4 != 0
|
||||
"sliding_attention", # layer 7: 7 % 4 != 0
|
||||
]
|
||||
|
||||
self.assertEqual(config.layer_types, expected_types)
|
||||
|
||||
def test_rope_parameters_config(self):
|
||||
custom_rope_parameters = {
|
||||
"factor": 8.0,
|
||||
"high_freq_factor": 2.0,
|
||||
"low_freq_factor": 0.5,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"rope_type": "llama3",
|
||||
"rope_theta": 1_000_000.0,
|
||||
}
|
||||
|
||||
config = CwmConfig(rope_parameters=custom_rope_parameters)
|
||||
|
||||
self.assertEqual(config.rope_parameters, custom_rope_parameters)
|
||||
|
||||
def test_config_serialization(self):
|
||||
config = CwmConfig(
|
||||
sliding_window=4096,
|
||||
layer_types=["full_attention", "sliding_attention"] * 3,
|
||||
num_hidden_layers=6,
|
||||
)
|
||||
|
||||
config_dict = config.to_dict()
|
||||
self.assertIn("sliding_window", config_dict)
|
||||
self.assertIn("layer_types", config_dict)
|
||||
|
||||
new_config = CwmConfig.from_dict(config_dict)
|
||||
self.assertEqual(new_config.sliding_window, config.sliding_window)
|
||||
self.assertEqual(new_config.layer_types, config.layer_types)
|
||||
|
||||
def test_config_inheritance_from_llama(self):
|
||||
config = CwmConfig()
|
||||
|
||||
# Llama config attributes
|
||||
self.assertTrue(hasattr(config, "hidden_size"))
|
||||
self.assertTrue(hasattr(config, "num_attention_heads"))
|
||||
self.assertTrue(hasattr(config, "num_key_value_heads"))
|
||||
self.assertTrue(hasattr(config, "intermediate_size"))
|
||||
self.assertTrue(hasattr(config, "rope_parameters"))
|
||||
self.assertTrue(hasattr(config, "attention_dropout"))
|
||||
|
||||
|
||||
@require_torch
|
||||
class CwmConfigTester(ConfigTester):
|
||||
def __init__(self, parent, config_class=None, **kwargs):
|
||||
super().__init__(parent, config_class=config_class, **kwargs)
|
||||
|
||||
def test_config(self):
|
||||
config_class = CwmConfig
|
||||
self.config_tester = ConfigTester(self, config_class=config_class)
|
||||
self.config_tester.run_common_tests()
|
||||
257
tests/models/cwm/test_modeling_cwm.py
Normal file
257
tests/models/cwm/test_modeling_cwm.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
require_deterministic_for_xpu,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.models.cwm import (
|
||||
CwmConfig,
|
||||
CwmForCausalLM,
|
||||
CwmModel,
|
||||
)
|
||||
|
||||
|
||||
class CwmModelTester(CausalLMModelTester):
|
||||
if is_torch_available():
|
||||
config_class = CwmConfig
|
||||
base_model_class = CwmModel
|
||||
causal_lm_class = CwmForCausalLM
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config()
|
||||
|
||||
config.sliding_window = 8192
|
||||
config.rope_parameters = {
|
||||
"factor": 16.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"original_max_position_embeddings": 8192,
|
||||
"rope_type": "llama3",
|
||||
"rope_theta": 1000000.0,
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@require_torch
|
||||
class CwmModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
CwmModel,
|
||||
CwmForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": CwmModel,
|
||||
"text-generation": CwmForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
model_tester_class = CwmModelTester
|
||||
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
_torch_compile_train_cls = CwmForCausalLM if is_torch_available() else None
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class CwmIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
@require_deterministic_for_xpu
|
||||
def test_cwm_integration(self):
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
|
||||
model = CwmForCausalLM.from_pretrained("facebook/cwm", device_map="auto", dtype=torch.bfloat16)
|
||||
|
||||
self.assertIsNotNone(model.config.sliding_window)
|
||||
self.assertIsNotNone(model.config.layer_types)
|
||||
self.assertIn("full_attention", model.config.layer_types)
|
||||
self.assertIn("sliding_attention", model.config.layer_types)
|
||||
|
||||
for i, layer in enumerate(model.model.layers):
|
||||
expected_type = model.config.layer_types[i]
|
||||
self.assertEqual(layer.attention_type, expected_type)
|
||||
if expected_type == "sliding_attention":
|
||||
self.assertEqual(layer.self_attn.sliding_window, model.config.sliding_window)
|
||||
|
||||
prompt = "def quicksort(arr):"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(**inputs)
|
||||
|
||||
# fmt: off
|
||||
expected_logits = Expectations(
|
||||
{
|
||||
("cuda", None): torch.tensor(
|
||||
[0.5625, 2.9531, 9.1875, 0.5039, -0.3262, 2.2344, 3.0312, 1.5312, 0.5664, 1.5625, 2.7656, 3.4219, 2.0312, 2.1719, 1.5391, 2.5469, 2.8281, 1.8125, 1.7109, 1.3906, 1.0391, 0.1621, 0.4277, 0.1455, -0.1230, 0.8477, 2.2344, 5.2188, 1.2969, 1.5547, 0.8516, 0.7148]
|
||||
),
|
||||
("xpu", None): torch.Tensor(
|
||||
[0.5625, 2.9688, 9.1875, 0.4766, -0.3574, 2.2344, 3.0156, 1.4922, 0.5625, 1.5547, 2.7656, 3.4062, 2.0156, 2.1719, 1.5469, 2.5156, 2.8125, 1.7891, 1.7031, 1.3828, 1.0312, 0.1602, 0.4277, 0.1328, -0.1348, 0.8281, 2.2188, 5.2812, 1.2734, 1.5312, 0.8398, 0.7070]
|
||||
),
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
expected_logits = expected_logits.get_expectation().to(model.device, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2)
|
||||
|
||||
self.assertEqual(out.logits.shape[1], inputs.input_ids.shape[1])
|
||||
self.assertEqual(out.logits.shape[2], model.config.vocab_size)
|
||||
self.assertFalse(torch.isnan(out.logits).any())
|
||||
self.assertFalse(torch.isinf(out.logits).any())
|
||||
|
||||
@slow
|
||||
@require_deterministic_for_xpu
|
||||
def test_cwm_sliding_window_long_sequence(self):
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
|
||||
# original `sliding_window` is `8192`, but it causes GPU OOM on A10
|
||||
model = CwmForCausalLM.from_pretrained(
|
||||
"facebook/cwm", device_map="auto", dtype=torch.bfloat16, sliding_window=4096
|
||||
)
|
||||
|
||||
sliding_window = model.config.sliding_window
|
||||
long_text = "for i in range(1000):\n print(f'iteration {i}')\n" * 270
|
||||
|
||||
inputs = tokenizer(long_text, return_tensors="pt").to(model.device)
|
||||
seq_len = inputs.input_ids.shape[1]
|
||||
|
||||
# create a sequence longer than sliding window
|
||||
self.assertGreater(
|
||||
seq_len, sliding_window, f"Test sequence length {seq_len} should be > sliding window {sliding_window}"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(**inputs)
|
||||
|
||||
# fmt: off
|
||||
expected_logits = Expectations(
|
||||
{
|
||||
("cuda", None): torch.tensor(
|
||||
[5.2812, 6.4688, 12.8125, 4.6875, 5.2500, 4.2500, 6.9688, 4.9375, 2.7656, 6.5938, 4.9688, 1.1016, 5.9375, 3.7500, 3.1094, 5.5312, 6.1250, 4.7500, 4.5312, 2.8281, 4.0625, 3.3125, 3.9219, 3.3906, 3.1406, 3.6719, 3.2031, 7.0938, 4.8750, 6.0000, 2.7188, 6.2500]
|
||||
),
|
||||
("xpu", None): torch.Tensor(
|
||||
[5.2500, 6.4688, 12.8125, 4.6562, 5.2812, 4.2812, 7.0000, 4.9062, 2.7344, 6.5938, 4.9062, 1.1094, 5.9375, 3.7188, 3.0469, 5.5000, 6.0938, 4.7188, 4.5000, 2.7344, 4.0312, 3.2812, 3.8750, 3.3438, 3.1094, 3.6406, 3.2031, 7.1250, 4.8750, 6.0000, 2.7031, 6.2188]
|
||||
),
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
expected_logits = expected_logits.get_expectation().to(model.device, torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out.logits[0, -1, :32], expected_logits, atol=1e-2, rtol=1e-2)
|
||||
|
||||
logits = out.logits.to("cpu")
|
||||
|
||||
self.assertEqual(logits.shape[1], seq_len)
|
||||
self.assertEqual(logits.shape[2], model.config.vocab_size)
|
||||
self.assertFalse(torch.isnan(logits).any())
|
||||
self.assertFalse(torch.isinf(logits).any())
|
||||
|
||||
for i, layer in enumerate(model.model.layers):
|
||||
if model.config.layer_types[i] == "sliding_attention":
|
||||
self.assertEqual(layer.self_attn.sliding_window, sliding_window)
|
||||
|
||||
@slow
|
||||
def test_cwm_generation_20_tokens(self):
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/cwm")
|
||||
model = CwmForCausalLM.from_pretrained("facebook/cwm", device_map="auto", dtype=torch.bfloat16)
|
||||
|
||||
system_prompt = "You are a helpful AI assistant. You always reason before responding, using the following format:\n\n<think>\nyour internal reasoning\n</think>\nyour external response"
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Write a simple Python function to add two numbers."},
|
||||
]
|
||||
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
preserve_previous_think=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=20,
|
||||
do_sample=False,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()
|
||||
generated_text = tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
|
||||
self.assertEqual(len(output_ids), 20, "Should generate exactly 20 tokens")
|
||||
|
||||
expected_token_ids = [
|
||||
33413,
|
||||
11,
|
||||
358,
|
||||
1205,
|
||||
311,
|
||||
3350,
|
||||
264,
|
||||
13325,
|
||||
734,
|
||||
430,
|
||||
11621,
|
||||
1403,
|
||||
5219,
|
||||
13,
|
||||
6914,
|
||||
596,
|
||||
1212,
|
||||
555,
|
||||
89746,
|
||||
1268,
|
||||
]
|
||||
expected_text = "Okay, I need to write a Python function that adds two numbers. Let's start by recalling how"
|
||||
|
||||
self.assertEqual(output_ids, expected_token_ids, "Generated tokens should match ground truth")
|
||||
self.assertEqual(generated_text, expected_text, "Generated text should match ground truth")
|
||||
Reference in New Issue
Block a user