Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
169 lines
7.9 KiB
Python
169 lines
7.9 KiB
Python
import time
|
|
import unittest
|
|
|
|
from parameterized import parameterized
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
|
from transformers.generation.configuration_utils import ContinuousBatchingConfig
|
|
from transformers.testing_utils import Expectations, slow
|
|
|
|
|
|
_TEST_PROMPTS = [
|
|
"A man is a walking his dog down the street, and a the turn he sees",
|
|
"Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
|
|
"A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
|
|
"Please fill in the form to",
|
|
"For safety reasons, the train is stopped in the middle of the",
|
|
]
|
|
|
|
# Expected continuation for each entry of `_TEST_PROMPTS`, in the same order, keyed by device.
# An entry is either a single string or a tuple of acceptable alternatives; the consistency
# test accepts a generation that matches any one of the alternatives after stripping.
# NOTE(review): the keys look like (device type, compute capability) — confirm against
# `Expectations`. The "A10 and H100" remark sits on a (9, 0) key; verify that this entry is
# really the one selected on A10 runners.
_EXPECTED_OUTPUTS = Expectations(
    {
        ("cpu", None): [  # FIXME: CPU tests only pass for eager and flex. Maybe the test should be re-thought.
            "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes",
            "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n##",
            "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer",
            "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]",
            "track. The train is stopped because of a mechanical failure. The train is stopped because of a mechanical failure. The train is stopped because of a mechanical",
            # TODO: investigate why that last expectation seems incorrect
        ],
        ("cuda", (9, 0)): [  # A10 and H100
            "a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes",
            "orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n##",
            "This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer",
            "get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]",
            # The last prompt sits on a numerical boundary: eager/flex produce "does", sdpa/fa2 produce "will".
            # We use a tuple to accept either variant.
            (
                "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train",
                "track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers will the train",
            ),
        ],
    }
)
|
|
|
|
|
|
@slow
class TestBatchGeneration(unittest.TestCase):
    """Slow integration tests for `generate_batch` across the paged attention implementations.

    The model and tokenizer are loaded once per class in `setUpClass`; every test then
    switches `config.attn_implementation` and runs `_TEST_PROMPTS` through `generate_batch`.
    """

    # (attn_implementation, num_blocks, block_size, max_batch_tokens) — shared by both tests so
    # the greedy and sampling paths are always exercised on the same configurations.
    _PAGED_CONFIGS = [
        ("paged|eager", 64, 128, 64),
        ("paged|sdpa", 32, 256, 128),
        ("paged|flash_attention_2", 16, 512, 256),
        ("paged|flex_attention", 64, 128, 64),
    ]

    @classmethod
    def setUpClass(cls):
        """Load the model (bfloat16, on CUDA) and a left-padding tokenizer once for all tests."""
        cls.model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3b-Instruct", dtype="bfloat16", device_map="cuda"
        ).eval()

        cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")

        # Fall back to EOS as the padding token when the tokenizer does not define one, and
        # mirror that choice on the model config (taking the first id if EOS is a list).
        if cls.tokenizer.pad_token is None:
            cls.tokenizer.pad_token = cls.tokenizer.eos_token
            eos_id = cls.model.config.eos_token_id
            cls.model.config.pad_token_id = eos_id[0] if isinstance(eos_id, list) else eos_id

        # NOTE(review): this sets an attribute on the model object, not on `model.config`;
        # confirm it has the intended effect. Kept as-is to preserve behavior.
        cls.model.use_cache = False

    def _run_generate_batch(self, attn_impl, num_blocks, block_size, max_batch_tokens, generation_config):
        """Run `_TEST_PROMPTS` through `generate_batch` with the given paged-attention setup.

        Args:
            attn_impl: value assigned to `model.config.attn_implementation`.
            num_blocks, block_size, max_batch_tokens: forwarded to `ContinuousBatchingConfig`.
            generation_config: the `GenerationConfig` to generate with.

        Returns:
            A `(batch_inputs, batch_outputs, elapsed_seconds)` tuple, where `elapsed_seconds`
            covers only the `generate_batch` call.
        """
        self.model.config.attn_implementation = attn_impl

        cb_config = ContinuousBatchingConfig(
            num_blocks=num_blocks,
            block_size=block_size,
            max_batch_tokens=max_batch_tokens,
        )

        tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
        batch_inputs = list(tokenized["input_ids"])

        # perf_counter is monotonic, so the measured duration cannot be skewed by clock changes.
        start = time.perf_counter()
        batch_outputs = self.model.generate_batch(
            inputs=batch_inputs,
            generation_config=generation_config,
            continuous_batching_config=cb_config,
        )
        elapsed = time.perf_counter() - start

        return batch_inputs, batch_outputs, elapsed

    @parameterized.expand(_PAGED_CONFIGS)
    def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Check that non-sampling generation reproduces the recorded expectation for each prompt."""
        generation_config = GenerationConfig(
            max_new_tokens=30,
            top_k=0,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
        )

        _, batch_outputs, _ = self._run_generate_batch(
            attn_impl, num_blocks, block_size, max_batch_tokens, generation_config
        )

        expected_outputs = _EXPECTED_OUTPUTS.get_expectation()

        # `zip` stops at the shorter iterable, so without this check a run that returned fewer
        # results than expected would skip the missing requests and still pass.
        self.assertEqual(
            len(batch_outputs),
            len(expected_outputs),
            msg=f"[{attn_impl}] Got {len(batch_outputs)} outputs for {len(expected_outputs)} expectations",
        )

        for i, (output, expected_output) in enumerate(zip(batch_outputs.values(), expected_outputs)):
            generated = self.tokenizer.decode(output.generated_tokens, skip_special_tokens=False).strip()
            # An expectation is either one string or a tuple of acceptable variants.
            variants = (expected_output,) if isinstance(expected_output, str) else expected_output
            candidates = tuple(variant.strip() for variant in variants)
            self.assertIn(
                generated,
                candidates,
                msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected one of: {candidates}\nGot: {generated}",
            )

    @parameterized.expand(_PAGED_CONFIGS)
    def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
        """Test batch generation with do_sample=True to verify sampling runs and yields output."""
        generation_config = GenerationConfig(
            max_new_tokens=30,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=0.8,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
        )

        batch_inputs, batch_outputs, elapsed = self._run_generate_batch(
            attn_impl, num_blocks, block_size, max_batch_tokens, generation_config
        )
        print(
            f"\n[{attn_impl}] Sampling batch took {elapsed:.2f}s with config: "
            f"blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
        )

        # With sampling enabled, we can't check exact outputs, so we only verify:
        #   1. All requests completed successfully
        #   2. Generated text is non-empty
        #   3. At least one token was generated per request
        # (Divergence from the greedy outputs is NOT checked here.)
        self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")

        for i, (req_id, output) in enumerate(batch_outputs.items()):
            generated_tokens = output.generated_tokens
            generated = self.tokenizer.decode(generated_tokens, skip_special_tokens=False).strip()
            self.assertGreater(
                len(generated),
                0,
                msg=f"[{attn_impl}] Empty output for request {i}",
            )
            self.assertGreater(
                len(generated_tokens),
                0,
                msg=f"[{attn_impl}] No tokens generated for request {i}",
            )
|