first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
168
tests/generation/test_paged_attention.py
Normal file
168
tests/generation/test_paged_attention.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from transformers.generation.configuration_utils import ContinuousBatchingConfig
|
||||
from transformers.testing_utils import Expectations, slow
|
||||
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"A man is a walking his dog down the street, and a the turn he sees",
|
||||
"Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
|
||||
"A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
|
||||
"Please fill in the form to",
|
||||
"For safety reasons, the train is stopped in the middle of the",
|
||||
]
|
||||
|
||||
_EXPECTED_OUTPUTS = Expectations(
|
||||
{
|
||||
("cpu", None): [ # FIXME: CPU tests only pass for eager and flex. Maybe the test should be re-thought.
|
||||
"a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes",
|
||||
"orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n##",
|
||||
"This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer",
|
||||
"get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]",
|
||||
"track. The train is stopped because of a mechanical failure. The train is stopped because of a mechanical failure. The train is stopped because of a mechanical",
|
||||
# TODO: investigate why that last expectation seems incorrect
|
||||
],
|
||||
("cuda", (9, 0)): [ # A10 and H100
|
||||
"a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes",
|
||||
"orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n##",
|
||||
"This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer",
|
||||
"get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]",
|
||||
# The last prompt sits on a numerical boundary: eager/flex produce "does", sdpa/fa2 produce "will".
|
||||
# We use a tuple to accept either variant.
|
||||
(
|
||||
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train",
|
||||
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers will the train",
|
||||
),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
class TestBatchGeneration(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.2-3b-Instruct", dtype="bfloat16", device_map="cuda"
|
||||
).eval()
|
||||
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")
|
||||
|
||||
if cls.tokenizer.pad_token is None:
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
eos_id = cls.model.config.eos_token_id
|
||||
cls.model.config.pad_token_id = eos_id[0] if isinstance(eos_id, list) else eos_id
|
||||
|
||||
cls.model.use_cache = False
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("paged|eager", 64, 128, 64),
|
||||
("paged|sdpa", 32, 256, 128),
|
||||
("paged|flash_attention_2", 16, 512, 256),
|
||||
("paged|flex_attention", 64, 128, 64),
|
||||
]
|
||||
)
|
||||
def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
|
||||
self.model.config.attn_implementation = attn_impl
|
||||
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size,
|
||||
max_batch_tokens=max_batch_tokens,
|
||||
)
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=30,
|
||||
top_k=0,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
|
||||
batch_inputs = list(tokenized["input_ids"])
|
||||
|
||||
batch_outputs = self.model.generate_batch(
|
||||
inputs=batch_inputs,
|
||||
generation_config=generation_config,
|
||||
continuous_batching_config=cb_config,
|
||||
)
|
||||
|
||||
expected_outputs = _EXPECTED_OUTPUTS.get_expectation()
|
||||
|
||||
for i, (output, expected_output) in enumerate(zip(batch_outputs.values(), expected_outputs)):
|
||||
generated = self.tokenizer.decode(output.generated_tokens, skip_special_tokens=False).strip()
|
||||
expected_output = (expected_output.strip(),) if isinstance(expected_output, str) else expected_output
|
||||
self.assertIn(
|
||||
generated,
|
||||
[e.strip() for e in expected_output],
|
||||
msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected one of: {expected_output}\nGot: {generated}",
|
||||
)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("paged|eager", 64, 128, 64),
|
||||
("paged|sdpa", 32, 256, 128),
|
||||
("paged|flash_attention_2", 16, 512, 256),
|
||||
("paged|flex_attention", 64, 128, 64),
|
||||
]
|
||||
)
|
||||
def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
|
||||
"""Test batch generation with do_sampling=True to verify sampling works correctly."""
|
||||
self.model.config.attn_implementation = attn_impl
|
||||
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size,
|
||||
max_batch_tokens=max_batch_tokens,
|
||||
)
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=30,
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
temperature=0.8,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512) # Use fewer prompts for faster test
|
||||
batch_inputs = list(tokenized["input_ids"])
|
||||
|
||||
start = time.time()
|
||||
batch_outputs = self.model.generate_batch(
|
||||
inputs=batch_inputs,
|
||||
generation_config=generation_config,
|
||||
continuous_batching_config=cb_config,
|
||||
)
|
||||
end = time.time()
|
||||
print(
|
||||
f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
|
||||
)
|
||||
|
||||
# With sampling enabled, we can't check exact outputs, but we should verify:
|
||||
# 1. All requests completed successfully
|
||||
# 2. Generated text is non-empty
|
||||
# 3. Generated text is different from greedy (demonstrating sampling is working)
|
||||
self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
|
||||
|
||||
for i, req_id in enumerate(batch_outputs):
|
||||
generated = self.tokenizer.decode(
|
||||
batch_outputs[req_id].generated_tokens, skip_special_tokens=False
|
||||
).strip()
|
||||
self.assertTrue(
|
||||
len(generated) > 0,
|
||||
msg=f"[{attn_impl}] Empty output for request {i}",
|
||||
)
|
||||
# Check that we got at least some tokens generated
|
||||
generated_tokens = batch_outputs[req_id].generated_tokens
|
||||
self.assertGreater(
|
||||
len(generated_tokens),
|
||||
0,
|
||||
msg=f"[{attn_impl}] No tokens generated for request {i}",
|
||||
)
|
||||
Reference in New Issue
Block a user