Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
424 lines
18 KiB
Python
424 lines
18 KiB
Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch Blt model."""
|
|
|
|
import unittest
|
|
|
|
import pytest
|
|
from parameterized import parameterized
|
|
|
|
from transformers import AutoTokenizer, is_torch_available
|
|
from transformers.testing_utils import (
|
|
Expectations,
|
|
cleanup,
|
|
require_torch,
|
|
require_torch_accelerator,
|
|
require_torch_bf16,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
|
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
|
from ...test_modeling_common import (
|
|
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
|
_test_eager_matches_sdpa_inference,
|
|
)
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import BltConfig, BltForCausalLM, BltModel
|
|
|
|
|
|
class BltModelTester(CausalLMModelTester):
|
|
if is_torch_available():
|
|
base_model_class = BltModel
|
|
|
|
def __init__(
|
|
self,
|
|
parent,
|
|
ignore_index=-100,
|
|
seq_length=7,
|
|
is_training=True,
|
|
):
|
|
super().__init__(parent)
|
|
self.parent = parent
|
|
self.ignore_index = ignore_index
|
|
self.seq_length = seq_length
|
|
self.is_training = is_training
|
|
self.batch_size = 3
|
|
|
|
# Common parameters for all configs
|
|
self.hidden_size = 16
|
|
self.num_hidden_layers = 1
|
|
self.num_attention_heads = 2
|
|
self.num_key_value_heads = 2
|
|
self.intermediate_size = 32
|
|
self.hidden_act = "silu"
|
|
self.max_position_embeddings = 32
|
|
self.vocab_size = 32
|
|
self.rope_theta = 500000.0
|
|
self.rope_parameters = {"rope_type": "default"}
|
|
self.rms_norm_eps = 1e-5
|
|
self.dropout = 0.0
|
|
self.encoder_hash_byte_group_size = [2, 3]
|
|
self.encoder_hash_byte_group_vocab = 64
|
|
self.encoder_hash_byte_group_nb_functions = 1
|
|
# Common parameters for all configs
|
|
self.patcher_config = {
|
|
"hidden_size": self.hidden_size,
|
|
"num_hidden_layers": self.num_hidden_layers,
|
|
"num_attention_heads": self.num_attention_heads,
|
|
"num_key_value_heads": self.num_key_value_heads,
|
|
"intermediate_size": self.intermediate_size,
|
|
"max_position_embeddings": self.max_position_embeddings,
|
|
"rope_theta": self.rope_theta,
|
|
"rope_parameters": self.rope_parameters,
|
|
"hidden_act": self.hidden_act,
|
|
"rms_norm_eps": self.rms_norm_eps,
|
|
"dropout": self.dropout,
|
|
}
|
|
|
|
self.encoder_config = {
|
|
"hidden_size": self.hidden_size,
|
|
"num_hidden_layers": self.num_hidden_layers,
|
|
"num_attention_heads": self.num_attention_heads,
|
|
"num_key_value_heads": self.num_key_value_heads,
|
|
"intermediate_size": self.intermediate_size,
|
|
"max_position_embeddings": self.max_position_embeddings,
|
|
"rope_theta": self.rope_theta,
|
|
"rope_parameters": self.rope_parameters,
|
|
"hidden_act": self.hidden_act,
|
|
"rms_norm_eps": self.rms_norm_eps,
|
|
"dropout": self.dropout,
|
|
}
|
|
|
|
self.decoder_config = {
|
|
"vocab_size": self.vocab_size,
|
|
"hidden_size": self.hidden_size,
|
|
"hidden_size_global": self.hidden_size * 2, # Must match global transformer output size
|
|
"num_hidden_layers": self.num_hidden_layers,
|
|
"num_attention_heads": self.num_attention_heads,
|
|
"num_key_value_heads": self.num_key_value_heads,
|
|
"intermediate_size": self.intermediate_size,
|
|
"max_position_embeddings": self.max_position_embeddings,
|
|
"rope_theta": self.rope_theta,
|
|
"rope_parameters": self.rope_parameters,
|
|
"hidden_act": self.hidden_act,
|
|
"rms_norm_eps": self.rms_norm_eps,
|
|
"dropout": self.dropout,
|
|
}
|
|
|
|
self.global_config = {
|
|
"hidden_size": self.hidden_size * 2, # Double the hidden size for global transformer
|
|
"num_hidden_layers": self.num_hidden_layers,
|
|
"num_attention_heads": self.num_attention_heads,
|
|
"num_key_value_heads": self.num_key_value_heads,
|
|
"intermediate_size": self.intermediate_size,
|
|
"max_position_embeddings": self.max_position_embeddings,
|
|
"rope_theta": self.rope_theta,
|
|
"rope_parameters": self.rope_parameters,
|
|
"hidden_act": self.hidden_act,
|
|
"rms_norm_eps": self.rms_norm_eps,
|
|
"dropout": self.dropout,
|
|
}
|
|
|
|
self.num_hidden_layers = self.encoder_config["num_hidden_layers"]
|
|
|
|
def get_config(self):
|
|
config = BltConfig(
|
|
vocab_size=self.vocab_size,
|
|
max_position_embeddings=self.max_position_embeddings,
|
|
patch_in_forward=False, # Disable patching for tests
|
|
patch_size=4,
|
|
patching_mode="entropy",
|
|
patching_threshold=1.335442066192627,
|
|
patching_batch_size=1,
|
|
max_patch_length=None,
|
|
cross_attn_k=2,
|
|
encoder_hash_byte_group_size=self.encoder_hash_byte_group_size,
|
|
encoder_hash_byte_group_vocab=self.encoder_hash_byte_group_vocab,
|
|
encoder_hash_byte_group_nb_functions=self.encoder_hash_byte_group_nb_functions,
|
|
patcher_config=self.patcher_config,
|
|
encoder_config=self.encoder_config,
|
|
decoder_config=self.decoder_config,
|
|
global_config=self.global_config,
|
|
rope_parameters=self.rope_parameters,
|
|
tie_word_embeddings=False,
|
|
)
|
|
|
|
config.num_attention_heads = config.decoder_config.num_attention_heads
|
|
config.num_hidden_layers = config.encoder_config.num_hidden_layers
|
|
config.hidden_size = config.decoder_config.hidden_size
|
|
|
|
return config
|
|
|
|
|
|
@require_torch
|
|
class BltModelTest(CausalLMModelTest, unittest.TestCase):
|
|
model_tester_class = BltModelTester
|
|
|
|
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
|
# This is because we are hitting edge cases with the causal_mask buffer
|
|
model_split_percents = [0.5, 0.7, 0.8]
|
|
|
|
# used in `test_torch_compile_for_training`
|
|
_torch_compile_train_cls = BltForCausalLM if is_torch_available() else None
|
|
|
|
@pytest.mark.generate
|
|
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
|
@unittest.skip(
|
|
"Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs"
|
|
)
|
|
def test_generate_from_inputs_embeds(self, _, num_beams):
|
|
pass
|
|
|
|
@pytest.mark.generate
|
|
def test_generate_with_quant_cache(self):
|
|
self.skipTest("BLT uses EncoderDecoderCache internally and does not support quantized cache")
|
|
|
|
@pytest.mark.generate
|
|
@unittest.skip(
|
|
"Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs"
|
|
)
|
|
def test_inputs_embeds_matches_input_ids(self):
|
|
pass
|
|
|
|
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
|
def test_eager_matches_sdpa_inference(
|
|
self,
|
|
name,
|
|
torch_dtype,
|
|
padding_side,
|
|
use_attention_mask,
|
|
output_attentions,
|
|
enable_kernels,
|
|
):
|
|
"We need to relax a bit the `atols` for fp32 here due to the altup projections"
|
|
atols = {
|
|
("cpu", False, torch.float32): 2e-2, # this was relaxed
|
|
("cpu", False, torch.float16): 5e-3,
|
|
("cpu", False, torch.bfloat16): 1e-2,
|
|
("cpu", True, torch.float32): 2e-2, # this was relaxed
|
|
("cpu", True, torch.float16): 5e-3,
|
|
("cpu", True, torch.bfloat16): 1e-2,
|
|
("cuda", False, torch.float32): 2e-2, # this was relaxed
|
|
("cuda", False, torch.bfloat16): 1e-2,
|
|
("cuda", False, torch.float16): 5e-3,
|
|
("cuda", True, torch.float32): 2e-2, # this was relaxed
|
|
("cuda", True, torch.bfloat16): 1e-2,
|
|
("cuda", True, torch.float16): 5e-3,
|
|
}
|
|
_test_eager_matches_sdpa_inference(
|
|
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols
|
|
)
|
|
|
|
@require_torch_accelerator
|
|
@slow
|
|
def test_sdpa_can_dispatch_on_flash(self):
|
|
self.skipTest("BLT always has an attention_mask input")
|
|
|
|
|
|
@require_torch_accelerator
|
|
class BltIntegrationTest(unittest.TestCase):
|
|
def setup(self):
|
|
cleanup(torch_device, gc_collect=True)
|
|
|
|
def tearDown(self):
|
|
# TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
|
|
# some memory allocated in the cache, which means some object is not being released properly. This causes some
|
|
# unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
|
|
# Investigate the root cause.
|
|
cleanup(torch_device, gc_collect=True)
|
|
|
|
@slow
|
|
def test_model(self):
|
|
NUM_TOKENS_TO_GENERATE = 200
|
|
EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
|
|
|
|
prompt = "my name is"
|
|
|
|
model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa")
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
generated_ids = model.generate(
|
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
|
|
)
|
|
|
|
output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
self.assertEqual(output_text, EXPECTED_TEXT)
|
|
|
|
@slow
|
|
def test_model_logits(self):
|
|
# fmt: off
|
|
EXPECTED_OUTPUT = Expectations(
|
|
{
|
|
(None, None): torch.tensor(
|
|
[
|
|
[-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750],
|
|
[-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750],
|
|
]
|
|
),
|
|
("xpu", None): torch.tensor(
|
|
[
|
|
[-10.4375, -10.6875, -6.1875, -10.5000, -10.3125, -9.1250, -8.4375, -8.6250, -9.1875, -9.5625, -9.3125, -8.4375, -9.0625, -3.4375, 2.9531, -10.2500, -6.4062, -6.0000, -9.6875, -9.1875, -8.8125, -9.8125, -9.7500, -9.4375, -9.7500, -9.4375, -9.0000, -9.8125, -9.4375, -9.3125],
|
|
[-13.3125, -13.2500, -5.5938, -13.3125, -13.5000, -8.7500, -7.0625, -7.0312, -10.1875, -10.3750, -9.9375, -7.8438, -8.8750, -5.3438, -3.5938, -12.5625, -9.2500, -6.8125, -10.3750, -9.3125, -10.6875, -11.5625, -11.3125, -11.0000, -10.6250, -10.9375, -11.0625, -11.3750, -10.5625, -10.0000],
|
|
]
|
|
),
|
|
}
|
|
).get_expectation()
|
|
EXPECTED_OUTPUT = EXPECTED_OUTPUT.to(torch_device)
|
|
# fmt: on
|
|
|
|
input_ids = [1, 42, 21, 12, 43, 23, 1, 4]
|
|
|
|
model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", attn_implementation="sdpa", device_map="auto")
|
|
|
|
with torch.no_grad():
|
|
output = model(torch.tensor([input_ids]).to(torch_device))[0]
|
|
|
|
torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30].to(torch_device), rtol=1e-3, atol=1e-3)
|
|
|
|
@slow
|
|
@require_torch_bf16
|
|
def test_model_bf16(self):
|
|
"""Test Blt model with bfloat16 precision."""
|
|
NUM_TOKENS_TO_GENERATE = 200
|
|
# fmt: off
|
|
EXPECTED_TEXT = Expectations(
|
|
{
|
|
(None, None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s",
|
|
("xpu", None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s",
|
|
}
|
|
)
|
|
# fmt: on
|
|
|
|
prompt = "my name is"
|
|
|
|
model = BltForCausalLM.from_pretrained(
|
|
"itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
|
|
)
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
generated_ids = model.generate(
|
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
|
|
)
|
|
|
|
output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
self.assertEqual(output_text, EXPECTED_TEXT.get_expectation())
|
|
|
|
@slow
|
|
@require_torch_bf16
|
|
def test_model_logits_bf16(self):
|
|
"""Test Blt model logits with bfloat16 precision."""
|
|
|
|
# fmt: off
|
|
EXPECTED_OUTPUT = Expectations(
|
|
{
|
|
(None, None): torch.tensor(
|
|
[
|
|
[-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750],
|
|
[-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750],
|
|
]
|
|
),
|
|
("xpu", None): torch.tensor(
|
|
[
|
|
[-10.4375, -10.6875, -6.1875, -10.5000, -10.3125, -9.1250, -8.4375, -8.6250, -9.1875, -9.5625, -9.3125, -8.4375, -9.0625, -3.4375, 2.9531, -10.2500, -6.4062, -6.0000, -9.6875, -9.1875, -8.8125, -9.8125, -9.7500, -9.4375, -9.7500, -9.4375, -9.0000, -9.8125, -9.4375, -9.3125],
|
|
[-13.3125, -13.2500, -5.5938, -13.3125, -13.5000, -8.7500, -7.0625, -7.0312, -10.1875, -10.3750, -9.9375, -7.8438, -8.8750, -5.3438, -3.5938, -12.5625, -9.2500, -6.8125, -10.3750, -9.3125, -10.6875, -11.5625, -11.3125, -11.0000, -10.6250, -10.9375, -11.0625, -11.3750, -10.5625, -10.0000],
|
|
]
|
|
),
|
|
}
|
|
).get_expectation()
|
|
EXPECTED_OUTPUT = EXPECTED_OUTPUT.to(torch_device)
|
|
# fmt: on
|
|
|
|
input_ids = [1, 42, 21, 12, 43, 23, 1, 4]
|
|
|
|
model = BltForCausalLM.from_pretrained(
|
|
"itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
|
|
)
|
|
|
|
with torch.no_grad():
|
|
output = model(torch.tensor([input_ids]).to(torch_device))[0]
|
|
|
|
torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30].to(torch_device), rtol=1e-3, atol=1e-3)
|
|
|
|
@slow
|
|
def test_model_eager(self):
|
|
"""Test Blt model with bfloat16 precision using eager attention implementation."""
|
|
NUM_TOKENS_TO_GENERATE = 200
|
|
# fmt: off
|
|
EXPECTED_TEXT = Expectations(
|
|
{
|
|
(None, None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s",
|
|
("xpu", None): "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m",
|
|
}
|
|
)
|
|
# fmt: on
|
|
|
|
prompt = "my name is"
|
|
|
|
model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf", device_map="auto", attn_implementation="eager")
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
generated_ids = model.generate(
|
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
|
|
)
|
|
|
|
output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
self.assertEqual(output_text, EXPECTED_TEXT.get_expectation())
|
|
|
|
@slow
|
|
@require_torch_bf16
|
|
def test_model_bf16_static_cache(self):
|
|
"""Test Blt model with bfloat16 precision and static cache."""
|
|
NUM_TOKENS_TO_GENERATE = 200
|
|
# fmt: off
|
|
EXPECTED_TEXT = Expectations(
|
|
{
|
|
(None, None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s",
|
|
("xpu", None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s",
|
|
}
|
|
)
|
|
# fmt: on
|
|
|
|
prompt = "my name is"
|
|
|
|
model = BltForCausalLM.from_pretrained(
|
|
"itazap/blt-1b-hf", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
|
|
)
|
|
|
|
model.generation_config.cache_implementation = "static"
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
generated_ids = model.generate(
|
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, use_cache=False
|
|
)
|
|
|
|
output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
self.assertEqual(output_text, EXPECTED_TEXT.get_expectation())
|