Files
transformers/tests/generation/test_continuous_batching.py
陈赣 06f1fd69a6
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
first commit
2026-06-05 16:53:03 +08:00

1667 lines
78 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2025 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import gc
import itertools
import os
import unittest
from typing import Any
from unittest.mock import patch
import torch
from parameterized import parameterized
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
CompileConfig,
ContinuousBatchingConfig,
GenerationConfig,
GenerationMixin,
StaticCache,
)
from transformers.generation.continuous_batching.cache import (
PagedAttentionCache,
PagedAttentionMemoryHandler,
SlidingAttentionCacheAllocator,
group_layers_by_attn_type,
)
from transformers.generation.continuous_batching.cache_manager import FullAttentionCacheAllocator
from transformers.generation.continuous_batching.continuous_api import OutputRouter
from transformers.generation.continuous_batching.distributed import DistributedHelper
from transformers.generation.continuous_batching.input_outputs import build_attention_mask
from transformers.generation.continuous_batching.offloading_manager import OffloadingManager
from transformers.generation.continuous_batching.requests import GenerationOutput, RequestStatus
from transformers.testing_utils import (
require_deterministic_for_xpu,
require_flash_attn,
require_flash_attn_3,
require_kernels,
require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator,
slow,
torch_device,
)
from transformers.utils import (
is_flash_attn_2_available,
is_kernels_available,
is_torch_xpu_available,
)
from transformers.utils.generic import is_flash_attention_requested
from ..test_tensor_parallel_mixin import _init_distributed
# Constants for tests
# Prompts shared by the generation tests in this file. The parity test repeats them `num_repeat_prompts` times and
# matches continuous batching outputs back to prompts by their token IDs, so the exact text of each prompt matters.
_DEFAULT_USER_MESSAGES = [
    "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
    "Josh decides to try flipping a house. He buys a house for $80,000 and then puts in $50,000 in repairs. This increased the value of the house by 150%. How much profit did he make?",
    "A basket contains 25 oranges among which 1 is bad, 20% are unripe, 2 are sour and the rest are good. How many oranges are good?",
] # fmt: skip
# Helper functions
def flush_memory(flush_compile: bool = True) -> None:
    """Flushes the memory of the current device and, if the flush_compile flag is True, all data related to
    torch.compile.

    Args:
        flush_compile (`bool`, defaults to `True`):
            Whether to also reset dynamo and clear the inductor code caches that exist in the installed torch
            version.
    """
    gc.collect()
    # If needed, flush everything related to torch.compile
    if flush_compile:
        # Dynamo resets
        torch._dynamo.reset()
        torch._dynamo.reset_code_caches()
        # The inductor cache classes differ across torch versions, so each one is looked up defensively
        codecache = getattr(torch._inductor, "codecache", None)
        if codecache is not None:
            # Clear FX graph cache
            if hasattr(codecache, "FxGraphCache"):
                codecache.FxGraphCache.clear()
            # Clear PyCodeCache
            if hasattr(codecache, "PyCodeCache"):
                codecache.PyCodeCache.cache_clear()
            # Clear TritonFuture cache (for async compilation)
            triton_future = getattr(codecache, "TritonFuture", None)
            if triton_future is not None and hasattr(triton_future, "_compile_cache"):
                triton_future._compile_cache.clear()
    # Release the cached memory of the accelerator (CUDA, else XPU). `torch.xpu` is looked up with getattr because
    # torch builds without that module would otherwise raise AttributeError on machines without CUDA.
    xpu = getattr(torch, "xpu", None)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    elif xpu is not None and xpu.is_available():
        xpu.empty_cache()
        xpu.synchronize()
    gc.collect()
def get_tokenizer_and_model(
    model_id: str, attn_implementation: str, device: str, dtype: str | torch.dtype = "auto"
) -> tuple[AutoTokenizer, GenerationMixin]:
    """Returns a tokenizer and a model for the given model ID. Attributes to setup the models (attn_implementation,
    dtype and device) are needed as arguments.

    Args:
        model_id: Hub ID (or local path) of the checkpoint, used for both the tokenizer and the model.
        attn_implementation: attention implementation the model is loaded with.
        device: device the model is moved to once loaded.
        dtype: dtype the model weights are loaded with (`"auto"` by default).

    Returns:
        A `(tokenizer, model)` tuple. The tokenizer pads on the left and falls back to its EOS token when it has no
        padding token. The model is on `device` and in eval mode.
    """
    # Get tokenizer, with a padding token if not present. `pad_token` is a property that is None when unset, so the
    # fallback has to check its value: `hasattr(tokenizer, "pad_token")` is always True.
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load the model (no device_map, so on CPU), then move it to the requested device
    model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=attn_implementation, dtype=dtype)
    model = model.to(device).eval()
    return tokenizer, model
def with_flush_memory(func):
    """Decorator that ensures flush_memory is called after the test, even if it fails.

    torch.compile state is also flushed when the call's `continuous_batching_config` (or, failing that, its
    `generation_config`) enables compilation. These arguments are found whether they are passed by keyword or
    positionally.
    """
    import inspect

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Signature cannot be introspected: only keyword arguments can be inspected
        signature = None

    def _should_flush_compile(args: tuple, kwargs: dict) -> bool:
        """Returns whether the configs passed to the decorated call enable any kind of compilation."""
        # Keyword arguments take precedence; positional ones are mapped to parameter names through the signature
        arguments = dict(kwargs)
        if signature is not None:
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                # Invalid call: the decorated function will raise by itself, keep the keyword-only lookup
                bound = None
            if bound is not None:
                for name, value in bound.arguments.items():
                    arguments.setdefault(name, value)
        cb_config = arguments.get("continuous_batching_config")
        generation_config = arguments.get("generation_config")
        if isinstance(cb_config, ContinuousBatchingConfig):
            return bool(
                cb_config.use_default_compile_configs
                or cb_config.varlen_compile_config is not None
                or cb_config.decode_compile_config is not None
            )
        if isinstance(generation_config, GenerationConfig):
            return generation_config.compile_config is not None
        return False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Decided before the call, since the decorated test may modify the config it receives
        flush_compile = _should_flush_compile(args, kwargs)
        # Run the test and always flush memory
        try:
            return func(*args, **kwargs)
        finally:
            flush_memory(flush_compile=flush_compile)

    return wrapper
def get_generation_inputs(
    user_messages: list[str], tokenizer: AutoTokenizer, for_continuous_batching: bool = False
) -> Any:
    """Tokenizes every user message as a single-turn chat, with the generation prompt appended.

    With `for_continuous_batching=True`, the result is one plain list of token IDs per message, which is what
    `generate_batch` is given in this file. Otherwise, the result is the padded batch returned by the chat template
    (exposing `input_ids` and `attention_mask`), ready to be passed to `generate`.
    """
    conversations = [[{"role": "user", "content": message}] for message in user_messages]
    if not for_continuous_batching:
        return tokenizer.apply_chat_template(
            conversations,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            return_dict=True,
            return_attention_mask=True,
        )
    per_request_ids = []
    for conversation in conversations:
        encoded = tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
        # The template may hand back either a bare list of IDs or a dict-like encoding; both are accepted
        per_request_ids.append(encoded if isinstance(encoded, list) else encoded["input_ids"])
    return per_request_ids
def regular_generate(
    model: GenerationMixin,
    tokenizer: AutoTokenizer,
    user_messages: list[str],
    **generate_kwargs,
) -> tuple[list[list[int]], list[list[float]]]:
    """Runs regular (non continuous batching) `generate` on the chat-formatted `user_messages`.

    Args:
        model: model to generate with; the inputs are moved to `model.device`.
        tokenizer: tokenizer used to apply the chat template. The slicing below assumes all prompts end at the same
            index, i.e. left padding.
        user_messages: one user message per prompt.
        **generate_kwargs: forwarded to `model.generate`. Passing `output_scores=True` also enables the
            log-probability computation.

    Returns:
        A `(generated_tokens, logprobs)` tuple. `generated_tokens[i]` holds the tokens generated for prompt `i`,
        without the prompt and without trailing padding. `logprobs[i]` holds the log-probability of the token picked
        at each generation step under the processed scores, if `output_scores` was requested. Otherwise `logprobs`
        is an empty list.
    """
    # Run generation
    inputs = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=False)
    generate_outputs = model.generate(**inputs.to(model.device), return_dict_in_generate=True, **generate_kwargs)
    # Keep only generated tokens (remove left-side input and padding tokens)
    num_input_tokens = inputs.input_ids.shape[1]
    generated_sequences = generate_outputs.sequences[:, num_input_tokens:]
    pad_token_id = model.generation_config.pad_token_id
    all_generated_tokens = []
    for i in range(len(user_messages)):
        generated_toks = generated_sequences[i].tolist()
        # Remove right-side padding tokens; the emptiness check avoids an IndexError if everything is padding
        while generated_toks and generated_toks[-1] == pad_token_id:
            generated_toks.pop()
        all_generated_tokens.append(generated_toks)
    # Retrieve logprobs only if the scores were requested, otherwise return an empty list
    per_prompt_logprobs = []
    if generate_kwargs.get("output_scores", False):
        for i in range(len(user_messages)):
            logprobs = []
            # One entry per generation step, including steps that only produced trailing padding
            for score, token in zip(generate_outputs.scores, generated_sequences[i].tolist()):
                # Scores already have logits processors applied (including temperature)
                probs = torch.nn.functional.softmax(score[i], dim=-1)
                logprobs.append(probs[token].log().item())
            per_prompt_logprobs.append(logprobs)
    return all_generated_tokens, per_prompt_logprobs
# Class for all continuous batching tests that do not require any accelerator. These tests are usually faster to run.
class ContinuousBatchingNoAcceleratorTest(unittest.TestCase):
    """Continuous batching tests that do not need an accelerator: layer grouping, attention masks, cache allocation
    checks, cache indices, output routing, distributed helpers and configuration side effects."""

    @parameterized.expand(
        [
            (None, None, "0"),
            (None, 4096, "0"),
            ("f", None, "0"),
            ("ffff", None, "0000"),
            ("sssss", 4096, "00000"),
            ("fs", 4096, "01"),
            ("ssfssf", 4096, "001221"),
            ("ssssf", 4096, "01234"),
            ("fffsffs", 4096, "0123456"),
        ]
    )
    def test_group_layers(
        self,
        layer_types_str: str | None,
        sliding_window: int | None,
        expected_groups: str,
    ) -> None:
        """Test the layer grouping algorithm of the hybrid allocator.

        `layer_types_str` describes one layer per character ("f" for full attention, "s" for sliding window) and
        `expected_groups` gives, for each layer, the index of the group it is expected to land in.
        """
        # Take a config and change the layer_types attribute to the mix we want
        config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
        if layer_types_str is not None:
            layer_types = [{"f": "full_attention", "s": "sliding_window"}[char] for char in layer_types_str]
        else:
            layer_types = None
        config.num_hidden_layers = len(expected_groups)
        config.layer_types = layer_types
        config.sliding_window = sliding_window
        # Turn the per-layer group indices into the list of layer indices of each group
        expected_lg = {}
        for i, group in enumerate(expected_groups):
            group = int(group)
            expected_lg[group] = expected_lg.get(group, []) + [i]
        expected_layer_groups = [expected_lg[i] for i in sorted(expected_lg.keys())]
        # Test layer groups formation (compared after sorting, so the order of the groups does not matter)
        layer_groups, group_types = group_layers_by_attn_type(config)
        self.assertEqual(
            sorted(expected_layer_groups),
            sorted(layer_groups),
            f"Test failed for: {layer_types_str = }, {sliding_window = }, {expected_layer_groups = }, {layer_groups = }",
        )
        # If layer_types is provided, check that group_types matches the type of all the layers in each group
        if layer_types is not None:
            for layer_group, group_type in zip(layer_groups, group_types):
                types_in_group = [config.layer_types[i] for i in layer_group]
                self.assertEqual(types_in_group, [group_type] * len(types_in_group))
        # If layer_types is None, all groups should be of the same type, decided by the presence of a sliding window
        else:
            sliding_window = getattr(config, "sliding_window", None)
            expected_group_type = "sliding_attention" if sliding_window is not None else "full_attention"
            for group_type in group_types:
                self.assertEqual(
                    group_type,
                    expected_group_type,
                    f"Test failed for: {layer_types_str = }, {sliding_window = }, {group_types = }",
                )

    @parameterized.expand(
        [
            ([0, 4], [0, 4], 1, ["1000", "1100", "1110", "1111"]),
            ([0, 4], [0, 4], 2, ["1000", "1100", "0110", "0011"]),
            ([0, 3], [0, 5], 1, ["11100", "11110", "11111"]),
            ([0, 3], [0, 5], 3, ["11100", "01110", "00111"]),
            ([0, 3, 6], [0, 3, 6], 1, ["100000", "110000", "111000", "000100", "000110", "000111"]),
            ([0, 3, 6], [0, 3, 6], 2, ["100000", "110000", "011000", "000100", "000110", "000011"]),
        ]
    )
    def test_attention_mask(
        self,
        cumulative_seqlens_q: list[int],
        cumulative_seqlens_k: list[int],
        sliding_window: int,  # the sliding window size, 1 means no sliding window
        str_expected_mask_lines: list[str],  # the attention mask, broken down by line as a string of 0s and 1s
    ) -> None:
        """Tests the correctness of the attention mask used in the continuous batching API.

        In the expected mask strings, "1" stands for an attended position (value 0) and "0" for a masked position
        (value float32 min).
        """
        # Build expected mask
        minus_inf = torch.finfo(torch.float32).min
        expected_mask = torch.empty((cumulative_seqlens_q[-1], cumulative_seqlens_k[-1]), dtype=torch.float32)
        for i, line in enumerate(str_expected_mask_lines):
            expected_mask[i, :] = torch.tensor([minus_inf if c == "0" else 0 for c in line])
        # Build actual mask
        actual_mask = torch.full_like(expected_mask, minus_inf)  # function modifies in place
        build_attention_mask(actual_mask, cumulative_seqlens_q, cumulative_seqlens_k, sliding_window)
        # Check that the actual mask matches the expected mask
        matches = (expected_mask == actual_mask).all()
        # If it doesn't match, print the masks in a readable form and fail the test
        if not matches:
            str_mask = [
                "".join("1" if x == 0 else "0" for x in token_attn_vector) for token_attn_vector in actual_mask
            ]
            str_mask = "\n".join(str_mask)
            str_expected_mask = "\n".join(str_expected_mask_lines)
            self.fail(
                f"Test failed for: {cumulative_seqlens_q = }, {cumulative_seqlens_k = }, {sliding_window = }\n"
                f"Expected mask:\n{str_expected_mask}\n"
                f"Actual mask:\n{str_mask}"
            )

    @parameterized.expand(
        [
            # Case 1: Only full attention groups, allocation succeeds
            # needed_blocks = 2 * 1 = 2, free_blocks = 10 -> 2 <= 10 = True
            (2, 0, 1, 0, 0, 10, True),
            # Case 2: Only full attention groups, allocation fails
            # needed_blocks = 5 * 2 = 10, free_blocks = 5 -> 10 <= 5 = False
            (5, 0, 2, 0, 0, 5, False),
            # Case 3: Mixed attention, sliding window not yet full
            # needed_blocks = 2 * 1 + min(4 - 0, 2) * 1 = 2 + 2 = 4, free_blocks = 10 -> 4 <= 10 = True
            (2, 0, 1, 1, 4, 10, True),
            # Case 4: Mixed attention, sliding window partially filled
            # needed_blocks = 3 * 1 + min(4 - 2, 3) * 1 = 3 + 2 = 5, free_blocks = 5 -> 5 <= 5 = True
            (3, 2, 1, 1, 4, 5, True),
            # Case 5: Mixed attention, sliding window already full (allocated_blocks >= max_sliding)
            # blocks_left = max(4 - 5, 0) = 0, needed_blocks = 3 * 1 + 0 = 3, free_blocks = 5 -> 3 <= 5 = True
            (3, 5, 1, 1, 4, 5, True),
            # Case 6: Mixed attention, sliding window full, allocation fails due to full attention
            # blocks_left = max(4 - 4, 0) = 0, needed_blocks = 6 * 1 + 0 = 6, free_blocks = 5 -> 6 <= 5 = False
            (6, 4, 1, 1, 4, 5, False),
            # Case 7: Multiple full attention groups
            # needed_blocks = 3 * 2 = 6, free_blocks = 6 -> 6 <= 6 = True
            (3, 0, 2, 0, 0, 6, True),
            # Case 8: Multiple sliding attention groups, not full
            # needed_blocks = 2 * 1 + min(4 - 1, 2) * 2 = 2 + 4 = 6, free_blocks = 6 -> 6 <= 6 = True
            (2, 1, 1, 2, 4, 6, True),
            # Case 9: Edge case - requesting 0 blocks always succeeds
            # needed_blocks = 0, free_blocks = 0 -> 0 <= 0 = True
            (0, 0, 1, 1, 4, 0, True),
            # Case 10: Edge case - exactly enough blocks
            # needed_blocks = 2 * 1 + min(3 - 0, 2) * 1 = 2 + 2 = 4, free_blocks = 4 -> 4 <= 4 = True
            (2, 0, 1, 1, 3, 4, True),
        ]
    )
    def test_continuous_batching_will_allocation_be_successful(
        self,
        num_requested_blocks: int,
        allocated_blocks: int,
        num_full_attention_groups: int,
        num_sliding_attention_groups: int,
        max_sliding_window_blocks_per_request: int,
        num_free_blocks: int,
        expected_result: bool,
    ) -> None:
        """Test the will_allocation_be_successful method of PagedAttentionCache, overloading the relevant attributes of
        a dummy cache."""
        if torch_device is None:  # this check should always pass and helps with type checking
            raise ValueError(f"This requires a torch accelerator, yet {torch_device = } and the test was not skipped.")
        # Create the cache
        cache = PagedAttentionCache(
            config=AutoConfig.from_pretrained("HuggingFaceTB/SmolLM-1.7B", attn_implementation="sdpa"),
            continuous_batching_config=ContinuousBatchingConfig(block_size=16, num_blocks=8, max_batch_tokens=8),
            device=torch_device,
            tp_plan={},
            distributed_helper=DistributedHelper(device_mesh=None, cpu_group_timeout=300),
        )
        # Overload cache parameters to match test scenario
        cache.num_full_attention_groups = num_full_attention_groups
        cache.num_sliding_attention_groups = num_sliding_attention_groups
        cache.max_sliding_window_blocks_per_request = max_sliding_window_blocks_per_request
        # Overload the cache get_num_free_blocks method
        cache.get_num_free_blocks = lambda: num_free_blocks
        # Test the method
        result = cache.will_allocation_be_successful(num_requested_blocks, allocated_blocks)
        self.assertEqual(
            result,
            expected_result,
            f"Failed for: {num_requested_blocks=}, {allocated_blocks=}, {num_full_attention_groups=}, "
            f"{num_sliding_attention_groups=}, {max_sliding_window_blocks_per_request=}, {num_free_blocks=}. "
            f"Expected {expected_result}, got {result}",
        )

    @parameterized.expand(
        [
            # (block_size, block_table, past_length, query_length)
            # Contiguous blocks
            (32, [0, 1, 2], 0, 16),
            (32, [0, 1, 2], 0, 64),
            (32, [0, 1, 2], 16, 16),
            (32, [0, 1, 2], 31, 2),
            # Non-contiguous blocks
            (32, [0, 3, 6], 0, 64),
            (32, [2, 5, 8], 60, 10),
            # Different block sizes
            (16, [0, 1, 2, 3], 14, 4),
            (64, [0, 1], 60, 10),
        ]
    )
    def test_full_attention_get_indices(
        self,
        block_size: int,
        block_table: list[int],
        past_length: int,
        query_length: int,
    ) -> None:
        """Test FullAttentionCacheAllocator.get_read_indices and get_write_indices return correct physical indices."""

        def reference_indices(start: int, end: int) -> list[int]:
            """Reference implementation: converts logical indices to physical indices."""
            return [block_table[i // block_size] * block_size + i % block_size for i in range(start, end)]

        allocator = FullAttentionCacheAllocator(index=0, block_size=block_size, allow_block_sharing=False)
        allocator.block_table["req"] = block_table
        # Test read indices (from 0 to past_length + query_length)
        expected_read = reference_indices(0, past_length + query_length)
        self.assertEqual(allocator.get_read_indices("req", past_length, query_length), expected_read)
        # Test write indices (from past_length to past_length + query_length)
        expected_write = reference_indices(past_length, past_length + query_length)
        self.assertEqual(allocator.get_write_indices("req", past_length, query_length), expected_write)

    @slow
    def test_continuous_batching_no_accelerators(self) -> None:
        """Test continuous batching generation when no accelerator is available. It uses a simulated CPU-only PyTorch
        environment by mocking all accelerator availability checks to return False"""
        import transformers.utils as transformers_utils

        model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        # Mock all accelerator availability checks to simulate CPU-only PyTorch
        with (
            patch("torch.cuda.is_available", return_value=False),
            patch("transformers.utils.is_torch_xpu_available", return_value=False),
            patch("torch.backends.mps.is_available", return_value=False),
        ):
            # Verify patches work. The XPU check goes through the patched module attribute: the
            # `is_torch_xpu_available` name imported at the top of this file still points to the unpatched function.
            self.assertFalse(torch.cuda.is_available())
            self.assertFalse(transformers_utils.is_torch_xpu_available())
            self.assertFalse(torch.backends.mps.is_available())
            tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", "cpu")
            user_messages = _DEFAULT_USER_MESSAGES[:1]
            input_ids = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)
            model.generation_config.max_new_tokens = 10
            model.generation_config.do_sample = False
            continuous_batching_config = ContinuousBatchingConfig(use_cuda_graph=False, use_async_batching=False)
            # This should not crash even with all accelerators unavailable
            outputs = model.generate_batch(
                inputs=input_ids,
                generation_config=model.generation_config,
                continuous_batching_config=continuous_batching_config,
            )
            # Verify we got outputs
            self.assertEqual(len(outputs), len(input_ids))
            for output in outputs.values():
                self.assertIsNotNone(output.generated_tokens)
                self.assertGreater(len(output.generated_tokens), 0)

    def test_output_router_deliver_to_queue(self):
        """Test that OutputRouter.deliver places outputs on the queue when no handler is registered."""
        router = OutputRouter()
        output = GenerationOutput(request_id="req_0", status=RequestStatus.FINISHED)
        router.deliver(output)
        result = router.output_queue.get_nowait()
        self.assertEqual(result.request_id, "req_0")
        self.assertTrue(router.output_queue.empty())

    def test_output_router_deliver_to_handler(self):
        """Test that OutputRouter.deliver forwards to a registered handler instead of the queue."""
        router = OutputRouter()
        received = []
        # The event loop is a mock: the test only checks that delivery is scheduled on it
        loop = unittest.mock.Mock()
        with router._lock:
            router.result_handlers["req_0"] = (lambda out: received.append(out), loop)
        output = GenerationOutput(request_id="req_0", status=RequestStatus.DECODING)
        router.deliver(output)
        loop.call_soon_threadsafe.assert_called_once()
        self.assertTrue(router.output_queue.empty())

    def test_distributed_helper_no_dist(self) -> None:
        """Test that DistributedHelper falls back to a single-rank, TP-driver setup when distributed is not on."""
        helper = DistributedHelper(device_mesh=None, cpu_group_timeout=300)
        self.assertFalse(helper.dist_on)
        self.assertEqual(helper.global_rank, 0)
        self.assertEqual(helper.world_size, 1)
        self.assertEqual(helper.tp_size, 1)
        self.assertEqual(helper.tp_local_rank, 0)
        self.assertEqual(helper.dp_rank, 0)
        self.assertEqual(helper.dp_size, 1)
        self.assertTrue(helper.is_tp_driver)
        self.assertIsNone(helper.tp_group)
        self.assertIsNone(helper.cpu_comm_group)
        # Tensor and object broadcasts should be no-ops without a TP group
        tensor = torch.tensor([1.0, 2.0])
        self.assertTrue(torch.equal(helper.tp_broadcast_from_rank_0(tensor), tensor))
        obj = {"some_request": "payload"}
        self.assertIs(helper.tp_broadcast_object_from_rank_0(obj), obj)
        # All-reduce-min should be a no-op without a TP group
        reduce_tensor = torch.tensor([7, 3], dtype=torch.int64)
        self.assertIs(helper.tp_all_reduce_min(reduce_tensor), reduce_tensor)
        self.assertTrue(torch.equal(reduce_tensor, torch.tensor([7, 3], dtype=torch.int64)))

    def test_distributed_helper_set_tp_seed_no_dist(self) -> None:
        """Test that set_tp_seed sets a torch seed without distributed initialized, both with and without a user seed."""
        helper = DistributedHelper(device_mesh=None, cpu_group_timeout=300)
        # Explicit seed: torch RNG state must be reproducible across calls
        helper.set_tp_seed(seed=42, model_device=torch.device("cpu"))
        first = torch.randint(0, 2**31 - 1, (4,))
        helper.set_tp_seed(seed=42, model_device=torch.device("cpu"))
        second = torch.randint(0, 2**31 - 1, (4,))
        self.assertTrue(torch.equal(first, second))
        # No seed: should not raise and should still set a torch seed
        helper.set_tp_seed(seed=None, model_device=torch.device("cpu"))

    def test_continuous_batching_config_disables_nccl_graph_mixing(self) -> None:
        """Test that ContinuousBatchingConfig sets NCCL_GRAPH_MIXING_SUPPORT=0 only under a distributed launch
        (WORLD_SIZE > 1) and respects the disable_nccl_graph_mixing flag."""
        # patch.dict snapshots os.environ and restores it on exit, even when an assertion fails
        with patch.dict(os.environ):
            os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None)
            os.environ.pop("WORLD_SIZE", None)
            # Single-GPU launch (no WORLD_SIZE): env var is left untouched
            ContinuousBatchingConfig()
            self.assertNotIn("NCCL_GRAPH_MIXING_SUPPORT", os.environ)
            # Distributed launch (WORLD_SIZE > 1): env var is set to "0"
            os.environ["WORLD_SIZE"] = "2"
            ContinuousBatchingConfig()
            self.assertEqual(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"), "0")
            # Explicitly disabled flag: env var is left untouched even under a distributed launch
            os.environ.pop("NCCL_GRAPH_MIXING_SUPPORT", None)
            ContinuousBatchingConfig(disable_nccl_graph_mixing=False)
            self.assertNotIn("NCCL_GRAPH_MIXING_SUPPORT", os.environ)
            # setdefault semantics: a pre-existing value is preserved
            os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "1"
            ContinuousBatchingConfig()
            self.assertEqual(os.environ.get("NCCL_GRAPH_MIXING_SUPPORT"), "1")
@require_torch_accelerator
class ContinuousBatchingWithAcceleratorTest(unittest.TestCase):
# -----------------------------------------------Parity tests----------------------------------------------- #
# Ensure continuous batching and non-continuous batching generation produce the same outputs #
# ---------------------------------------------------------------------------------------------------------- #
@require_deterministic_for_xpu
@with_flush_memory
def _test_continuous_batching_parity(
    self,
    model_id: str,
    continuous_batching_config: ContinuousBatchingConfig,
    attn_implementation: str,
    max_new_tokens: int = 20,
    num_repeat_prompts: int = 1,
) -> None:
    """Tests the parity between continuous batching and non-continuous batching generation.

    Greedily generates up to `max_new_tokens` tokens for `_DEFAULT_USER_MESSAGES` (repeated `num_repeat_prompts`
    times) in two ways: with `generate_batch` and `continuous_batching_config`, then with regular `generate` on a
    freshly loaded model. The test fails if the generated tokens differ for any prompt. It is skipped when the
    requested attention implementation or CUDA graphs cannot be used on the current device.

    Note: when `continuous_batching_config.use_default_compile_configs` is set, the config's
    `varlen_compile_config` is replaced in place.
    """
    # Skip the test if Flash Attention is required but not available
    is_fa = is_flash_attention_requested(requested_attention_implementation=attn_implementation)
    if is_fa and not (is_flash_attn_2_available() or is_kernels_available()):
        self.skipTest("Flash Attention is not available and neither is the kernels library. Skipping test.")
    # Skip the test if cuda graph is on but the device is not CUDA. `use_cuda_graph` may be a (varlen, decode)
    # tuple, which is truthy even when both entries are False, so its entries are checked individually.
    use_cuda_graph = continuous_batching_config.use_cuda_graph
    if isinstance(use_cuda_graph, (tuple, list)):
        uses_cuda_graph = any(use_cuda_graph)
    else:
        uses_cuda_graph = bool(use_cuda_graph)
    if uses_cuda_graph and torch_device != "cuda":
        self.skipTest("CUDA graph is only supported on CUDA devices. Skipping test.")
    # If the config turns on compile, change the generation config to use the default mode instead of
    # max-autotune-no-cudagraphs which can change the kernels between generate_batch and generate
    if continuous_batching_config.use_default_compile_configs:
        fullgraph = not is_flash_attention_requested(requested_attention_implementation=attn_implementation)
        compile_config = CompileConfig(mode="default", fullgraph=fullgraph, dynamic=True)
        continuous_batching_config.varlen_compile_config = compile_config
    # Eager and SDPA implementations get a precision boost to account for the fact that an attention mask is used in
    # continuous batching but not in generate
    dtype = "auto" if is_fa else torch.float32
    # Prepare inputs
    tokenizer, model = get_tokenizer_and_model(model_id, attn_implementation, torch_device, dtype)
    if (
        attn_implementation == "flash_attention_2"
        and torch_device == "cpu"
        and getattr(model.config, "sliding_window", None) is not None
        and model.config.sliding_window > 0
    ):
        self.skipTest("Flash Attention 2 with sliding window attention is not supported on CPU. Skipping test.")
    user_messages = _DEFAULT_USER_MESSAGES * num_repeat_prompts
    input_ids = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)
    model.generation_config.max_new_tokens = max_new_tokens
    model.generation_config.do_sample = False
    # Generation with continuous batching
    continuous_batching_outputs = model.generate_batch(
        inputs=input_ids,
        generation_config=model.generation_config,
        continuous_batching_config=continuous_batching_config,
    )
    # Prepare non-continuous batching inputs
    inputs = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=False)
    num_input_tokens = inputs.input_ids.shape[1]
    # Release the continuous batching model before loading a fresh one, so two copies are not held at once
    del model
    flush_memory(flush_compile=False)
    # Generation without continuous batching (reload model to avoid any state contamination)
    _, model = get_tokenizer_and_model(model_id, attn_implementation, torch_device, dtype)
    model.generation_config.max_new_tokens = max_new_tokens
    model.generation_config.do_sample = False
    model.generation_config.use_cuda_graph = continuous_batching_config.use_cuda_graph
    model.generation_config.compile_config = continuous_batching_config.varlen_compile_config
    # Create a static cache if compile_config is set, because regular generate requires a compileable cache
    past_key_values = None
    if model.generation_config.compile_config is not None:
        max_cache_len = num_input_tokens + max_new_tokens
        past_key_values = StaticCache(config=model.config, max_cache_len=max_cache_len)
    generate_outputs = model.generate(
        **inputs.to(torch_device), generation_config=model.generation_config, past_key_values=past_key_values
    )
    for i, user_message in enumerate(user_messages):
        # Find the corresponding request in the continuous batching outputs (matched on the unpadded prompt tokens)
        input_tokens = inputs.input_ids[i][inputs.attention_mask[i] == 1].tolist()
        key_to_pop = None
        for key, state in continuous_batching_outputs.items():
            if state.prompt_ids == input_tokens:
                key_to_pop = key
                break
        if key_to_pop is None:
            self.fail(f"Request {i} not found in continuous batching outputs")
        continuous_batching_output = continuous_batching_outputs.pop(key_to_pop).generated_tokens
        generate_output = generate_outputs[i][num_input_tokens:].tolist()
        # Strip right-side padding; the emptiness check avoids an IndexError if everything is padding
        while generate_output and generate_output[-1] == model.generation_config.pad_token_id:
            generate_output.pop()
        if continuous_batching_output != generate_output:
            decoded_continuous_batching_output = tokenizer.decode(continuous_batching_output)
            decoded_generate_output = tokenizer.decode(generate_output)
            msg = f"Test failed for {model_id = } {continuous_batching_config = }, {attn_implementation = }\n"
            msg += f"User message : {repr(user_message)}\n"
            msg += f"Continuous batching output: {repr(decoded_continuous_batching_output)}\n"
            msg += f"Generate output : {repr(decoded_generate_output)}"
            self.fail(msg)
@parameterized.expand(
    list(
        itertools.product(
            [False, True],
            ["eager", "sdpa", "flash_attention_2"],
            [False, True],
        )
    )
)
@slow
def test_continuous_batching_config_combinations_no_compile(
    self,
    allow_block_sharing: bool,
    attn_implementation: str,
    use_cuda_graph: bool,
) -> None:
    """Parity check over block sharing x attention implementation x CUDA graph, with compilation turned off.

    Compiling adds a lot of overhead, so it is left out of this cross product (2*3*2=12 tests).
    """
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=ContinuousBatchingConfig(
            use_default_compile_configs=False,
            use_cuda_graph=use_cuda_graph,
            allow_block_sharing=allow_block_sharing,
        ),
        attn_implementation=attn_implementation,
    )
@parameterized.expand([("eager", False), ("sdpa", False), ("sdpa", True), ("flash_attention_2", True)])
@slow
def test_continuous_batching_config_combinations_with_compile(
    self,
    attn_implementation: str,
    use_cuda_graph: bool,
) -> None:
    """Parity check with the default compile configs enabled, for a few attention / CUDA graph combinations."""
    compiled_config = ContinuousBatchingConfig(use_default_compile_configs=True, use_cuda_graph=use_cuda_graph)
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=compiled_config,
        attn_implementation=attn_implementation,
    )
# FIXME: Qwen2.5-0.5B-Instruct is not here because it's broken (it uses a repetition penalty logits processor)
# TODO: replace gemma2 with a tiny version of GPT-OSS? That way we can test sliding window AND attention sink
@parameterized.expand(
    list(
        itertools.product(
            ["google/gemma-2-2b-it"],  # model_id
            [False, True],  # use_cuda_graph
            [False, True],  # use_compile
        )
    )
)
@slow
def test_continuous_batching_diverse_models(self, model_id: str, use_cuda_graph: bool, use_compile: bool) -> None:
    """Checks CB/generate parity with flash_attention_2 on other model families, with and without CUDA graphs and
    the default compile configs."""
    self._test_continuous_batching_parity(
        model_id=model_id,
        continuous_batching_config=ContinuousBatchingConfig(
            use_cuda_graph=use_cuda_graph,
            use_default_compile_configs=use_compile,
        ),
        attn_implementation="flash_attention_2",
    )
@parameterized.expand([(True, False), (False, True)])
@require_flash_attn_3
@slow
def test_continuous_batching_tuple_cuda_graph(self, varlen_cg: bool, decode_cg: bool) -> None:
    """Tests that use_cuda_graph can be a tuple to independently control varlen and decode paths."""
    # (varlen, decode): each parametrization graphs exactly one of the two paths
    cuda_graph_flags = (varlen_cg, decode_cg)
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=ContinuousBatchingConfig(
            use_cuda_graph=cuda_graph_flags,
            use_async_batching=False,
        ),
        attn_implementation="flash_attention_3",
    )
def test_continuous_batching_fast(self) -> None:
    """Single CB/generate parity check with SDPA and every optional feature turned off (CUDA graphs, block sharing,
    async batching, default compile configs). Unlike most parity tests here, it is not marked `@slow`."""
    minimal_config = ContinuousBatchingConfig(
        use_cuda_graph=False,
        allow_block_sharing=False,
        use_async_batching=False,
        use_default_compile_configs=False,
    )
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=minimal_config,
        attn_implementation="sdpa",
    )
def test_continuous_batching_long_generate(self) -> None:
    """CB/generate parity over a longer generation (80 new tokens) with SDPA, CUDA graphs, block sharing and the
    default compile configs enabled, and async batching disabled."""
    long_run_config = ContinuousBatchingConfig(
        use_cuda_graph=True,
        allow_block_sharing=True,
        use_async_batching=False,
        use_default_compile_configs=True,
    )
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=long_run_config,
        attn_implementation="sdpa",
        max_new_tokens=80,
    )
@parameterized.expand([(False, False), (False, True), (True, False), (True, True)])
@slow
def test_continuous_batching_log_probs(self, use_cuda_graph: bool, use_async_batching: bool) -> None:
    """Test that log probabilities match between continuous batching and regular generate.

    Both paths decode greedily in float32 with SDPA. Generated tokens must match exactly, and logprobs must match
    within a small tolerance (slightly larger when CUDA graphs are used).
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    max_new_tokens = 10
    # Retrieve tokenizer, model and eos_token_id (required otherwise logits will be misaligned)
    tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", torch_device, torch.float32)
    eos_token_id = model.config.eos_token_id  # type: ignore[attr-defined]

    # Run CB generation
    user_messages = ["What is 2+2?", "Hello world"]
    input_ids = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)
    gen_config = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False, eos_token_id=eos_token_id)
    continuous_batching_config = ContinuousBatchingConfig(
        use_cuda_graph=use_cuda_graph,
        use_async_batching=use_async_batching,
        return_logprobs=True,
    )
    cb_outputs = model.generate_batch(
        inputs=input_ids, generation_config=gen_config, continuous_batching_config=continuous_batching_config
    )
    # One CB output per prompt is required, otherwise the positional comparison below would silently skip requests
    self.assertEqual(len(cb_outputs), len(user_messages))

    # Load fresh model for regular generate
    tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", torch_device, torch.float32)
    # Run regular generate
    regular_outputs, regular_logprobs = regular_generate(
        model=model,
        tokenizer=tokenizer,
        user_messages=user_messages,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        output_scores=True,
        eos_token_id=eos_token_id,
    )

    # Tolerance for floating point differences (because of padding, tol is higher for cuda graphs)
    delta = 2e-5 if use_cuda_graph else 1e-5
    # Requests are compared by position, not by prompt ids: this assumes `cb_outputs` iterates in submission order,
    # i.e. the order of `user_messages` (TODO confirm, or match on request ids instead)
    for i, cb_output in enumerate(cb_outputs.values()):
        # Compare CB and regular generate token ids
        cb_output_ids = cb_output.generated_tokens
        regular_output_ids = regular_outputs[i]
        self.assertEqual(len(cb_output_ids), len(regular_output_ids))
        self.assertEqual(cb_output_ids, regular_output_ids)
        # Because of padding, the two logprob sequences may differ in length: only their common prefix is compared
        min_len = min(len(cb_output.logprobs), len(regular_logprobs[i]))
        cb_logprobs = cb_output.logprobs[:min_len]
        expected_logprobs = regular_logprobs[i][:min_len]
        for j, (cb_lp, exp_lp) in enumerate(zip(cb_logprobs, expected_logprobs)):
            error_msg = f"logprob mismatch at position {j} for request {i}: CB={cb_lp}, expected={exp_lp}"
            self.assertAlmostEqual(cb_lp, exp_lp, delta=delta, msg=error_msg)
def test_continuous_batching_few_blocks(self) -> None:
    """Checks CB/generate parity when the cache is so small (4 blocks of 32 tokens) that a request has to be
    offloaded at some point. Prefix sharing is enabled and the same prompt is repeated 4 times to add complexity."""
    tiny_cache_config = ContinuousBatchingConfig(
        use_cuda_graph=True,
        allow_block_sharing=True,
        use_async_batching=False,
        num_blocks=4,
        block_size=32,
    )
    # Spy on offload_one_request: `side_effect` still runs the real implementation, while every call is recorded
    with patch.object(
        OffloadingManager,
        "offload_one_request",
        autospec=True,
        side_effect=OffloadingManager.offload_one_request,
    ) as offload_spy:
        self._test_continuous_batching_parity(
            model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            continuous_batching_config=tiny_cache_config,
            attn_implementation="sdpa",
            max_new_tokens=30,
            num_repeat_prompts=4,
        )
    self.assertTrue(offload_spy.called, "Offload method was not called.")
# ---------------------------------------Streaming tests--------------------------------------- #
# Ensures the requests have the right behavior with and without streaming #
# --------------------------------------------------------------------------------------------- #
def _test_streaming_or_not_request(self, with_streaming: bool, with_non_streaming: bool) -> None:
    """Checks how many tokens each chunk of a request carries, depending on the streaming mode.

    A non-streaming request yields a single chunk holding all `max_new_tokens` tokens. A streaming request yields
    one chunk per step, and each chunk holds one more token than the previous one.

    Args:
        with_streaming: whether to submit and check a streaming request.
        with_non_streaming: whether to submit and check a non-streaming request (submitted first when both are set).
    """
    model_id = "Qwen/Qwen2.5-0.5B-Instruct"
    max_new_tokens = 3
    tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", torch_device)
    manager = model.init_continuous_batching()
    manager.logit_processor.clear()
    manager.start()
    # Always stop the started manager, even when an assertion (or `next`) fails, so it does not outlive the test
    try:
        user_messages = ["What is the Transformers library known for?"]
        inputs = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)[0]
        # Test with non-streaming
        if with_non_streaming:
            request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
            # In non-streaming mode, the total number of generated tokens is equal to the max new tokens
            chunk = next(manager.request_id_iter(request_id))
            self.assertEqual(len(chunk.generated_tokens), max_new_tokens)
        # Test with streaming
        if with_streaming:
            request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
            # In streaming mode, the total number of generated tokens is incremented by 1 on each iteration
            for expected_num_tokens in range(1, max_new_tokens + 1):
                chunk = next(manager.request_id_iter(request_id))
                self.assertEqual(len(chunk.generated_tokens), expected_num_tokens)
    finally:
        manager.stop(block=True)
def test_streaming_request(self) -> None:
    """Streaming only: every chunk carries one more generated token than the previous one."""
    self._test_streaming_or_not_request(with_non_streaming=False, with_streaming=True)
def test_non_streaming_request(self) -> None:
    """Non-streaming only: a single chunk carries all the generated tokens."""
    self._test_streaming_or_not_request(with_non_streaming=True, with_streaming=False)
def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:
    """A non-streaming request followed by a streaming one on the same manager both behave as expected."""
    self._test_streaming_or_not_request(with_non_streaming=True, with_streaming=True)
def test_register_result_handler(self) -> None:
    """Test that register_result_handler receives streaming outputs through the OutputRouter."""
    import asyncio

    model_id = "Qwen/Qwen2.5-0.5B-Instruct"
    max_new_tokens = 3
    tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", torch_device)
    manager = model.init_continuous_batching()
    manager.logit_processor.clear()
    manager.start()
    # Always stop the started manager, even if the wait times out or an assertion fails
    try:
        user_messages = ["What is the Transformers library known for?"]
        inputs = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)[0]

        async def collect_results():
            """Submits one streaming request and records the token count of every output passed to the handler."""
            token_counts = []
            future = asyncio.get_running_loop().create_future()

            def on_result(output):
                token_counts.append(len(output.generated_tokens))
                if output.is_finished():
                    future.set_result(True)

            request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
            # NOTE(review): the handler is registered after the request is submitted; confirm the router cannot
            # deliver early outputs to the queue before this registration
            manager.register_result_handler(request_id, on_result)
            # Bounded wait: a handler that never sees the final output fails the test instead of hanging it
            await asyncio.wait_for(future, timeout=30)
            return token_counts

        token_counts = asyncio.run(collect_results())
        # Streaming via handler: incremental token count, same as request_id_iter
        self.assertEqual(token_counts, [1, 2, 3])
        # Queue should be empty — everything went through the handler
        self.assertTrue(manager.output_router.output_queue.empty())
    finally:
        manager.stop(block=True)
# -----------------------------------------Misc. tests----------------------------------------- #
# Various tests that don't fit into the other categories #
# --------------------------------------------------------------------------------------------- #
def _test_block_sharing(self, model_id: str, expected_layer_types: dict[str, int], input_msg: str) -> None:
    """Checks that a second, identical request reuses the cached blocks of the first one and yields the same tokens.

    Also checks the number and type of layer groups, the number of hashed blocks, the matched prefix length and
    parity with regular generate.

    Args:
        model_id: checkpoint to load (SDPA, float32).
        expected_layer_types: expected number of layer groups per type, keyed by "full_attention" and
            "sliding_window".
        input_msg: user message; once templated, it must span at least one but less than two cache blocks.
    """
    block_size = 32
    max_new_tokens = 32
    # Use float32 for SDPA to handle precision differences from attention masks (same as parity test)
    tokenizer, model = get_tokenizer_and_model(model_id, "sdpa", torch_device, dtype=torch.float32)
    # Configure generation for parity: disable processors not supported by CB (like repetition_penalty)
    model.generation_config.max_new_tokens = max_new_tokens
    model.generation_config.do_sample = False
    model.generation_config.repetition_penalty = None
    # Get expected output from regular generate for parity check
    expected_output_tokens, _ = regular_generate(model, tokenizer, [input_msg])
    cb_context_manager = model.continuous_batching_context_manager(
        generation_config=model.generation_config,
        continuous_batching_config=ContinuousBatchingConfig(block_size=block_size),
    )
    with cb_context_manager as manager:
        # Create a request with at least one block of tokens but less than two, so prefill only generates one
        # complete block
        inputs = get_generation_inputs([input_msg], tokenizer, for_continuous_batching=True)[0]
        self.assertGreaterEqual(
            len(inputs), block_size, f"Input length is {len(inputs)} instead of at least {block_size}"
        )
        self.assertLess(
            len(inputs), 2 * block_size, f"Input length is {len(inputs)} instead of less than {2 * block_size}"
        )

        # First request, which populates the cache w/ 2 complete blocks for each full attention layer group
        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens)
        chunk_no_reuse = next(manager.request_id_iter(request_id))

        num_fa = expected_layer_types["full_attention"]
        num_sw = expected_layer_types["sliding_window"]
        if manager.batch_processor is None:
            raise RuntimeError("Batch processor is None even after a request was added.")
        hash_table = manager.batch_processor.cache._block_manager._hash_to_id
        self.assertEqual(
            len(hash_table),
            2 * num_fa,  # 2 = 1 for prefill + 1 for decode
            f"There should be {2 * num_fa} blocks, 2 for each full attention layer group, but {len(hash_table) = }",
        )
        total_prefix_length = manager.batch_processor.cache._total_prefix_length
        self.assertEqual(
            total_prefix_length, 0, f"Expected total prefix length to be 0, got {total_prefix_length}"
        )

        # Assert the number of layer groups and their types are the expected ones
        layer_groups = manager.batch_processor.cache.group_cache_managers
        self.assertEqual(
            len(layer_groups),
            num_fa + num_sw,
            f"There should be {num_fa + num_sw} layer groups, but {len(layer_groups) = }",
        )
        layer_group_types = {"full_attention": 0, "sliding_window": 0}
        for cm in layer_groups:
            if isinstance(cm, FullAttentionCacheAllocator):
                layer_group_types["full_attention"] += 1
            elif isinstance(cm, SlidingAttentionCacheAllocator):
                layer_group_types["sliding_window"] += 1
            else:
                raise ValueError(f"Invalid layer group type: {type(cm)}")
        self.assertEqual(
            layer_group_types,
            expected_layer_types,
            f"The expected layer group types are\n{expected_layer_types}\nbut got\n{layer_group_types}",
        )

        # Second request, which should reuse the same blocks for the full attention layer groups
        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens)
        chunk_with_reuse = next(manager.request_id_iter(request_id))

        # Because of block reuse, the hash table should not have grown: still 2 blocks per full attention group
        self.assertEqual(
            len(hash_table),
            2 * num_fa,
            f"Because of block reuse, there should still be {2 * num_fa} blocks in the hash table, but "
            f"{len(hash_table) = }",
        )

        # With only full attention layers, the whole prefill block is matched; sliding window groups prevent reuse
        expected_total_prefix_length = block_size if num_sw == 0 else 0
        total_prefix_length = manager.batch_processor.cache._total_prefix_length
        self.assertEqual(
            total_prefix_length,
            expected_total_prefix_length,
            f"Expected total prefix length to be {expected_total_prefix_length}, but got {total_prefix_length = }",
        )

        # Check the outputs were the same (block sharing should produce identical results)
        self.assertEqual(chunk_no_reuse.generated_tokens, chunk_with_reuse.generated_tokens)
        # Verify parity with regular generate
        self.assertEqual(chunk_no_reuse.generated_tokens, expected_output_tokens[0])
def test_prefix_sharing(self) -> None:
    """Block sharing on a model with a single full-attention layer group and no sliding-window group."""
    expected_layer_types = {"full_attention": 1, "sliding_window": 0}
    return self._test_block_sharing(
        "Qwen/Qwen2.5-0.5B-Instruct",
        expected_layer_types,
        "What is the Transformers library known for?",
    )
def test_block_sharing_with_hybrid_model(self) -> None:
    """Block sharing on a hybrid model mixing full-attention and sliding-window layer groups."""
    expected_layer_types = {"full_attention": 2, "sliding_window": 11}
    # Long enough to fill at least one complete cache block once templated
    prompt = (
        "I am a software engineer looking to use open source software to build a new AI agent. "
        "What is the Transformers library known for?"
    )
    return self._test_block_sharing("google/gemma-3-1b-it", expected_layer_types, prompt)
@parameterized.expand([True, False])
@require_flash_attn  # otherwise the test can fail because attention bias has a very slight impact on SDPA and eager
def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
    """With greedy decoding and `num_return_sequences = 2`, the manager must return two finished results with
    identical generated tokens."""
    num_sequences = 2
    tokenizer, model = get_tokenizer_and_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0", "flash_attention_2", torch_device)
    input_ids = get_generation_inputs(_DEFAULT_USER_MESSAGES[:1], tokenizer, for_continuous_batching=True)
    model.generation_config.max_new_tokens = 30
    model.generation_config.do_sample = False

    # Generation with continuous batching: collect finished results until both sequences are in, or stop early if the
    # manager is no longer running
    results = []
    with model.continuous_batching_context_manager(
        continuous_batching_config=ContinuousBatchingConfig(allow_block_sharing=allow_block_sharing),
        block=True,
        timeout=5,
    ) as manager:
        manager.num_return_sequences = num_sequences
        manager.add_requests(inputs=input_ids, max_new_tokens=30)
        while len(results) < num_sequences:
            result = manager.get_result(timeout=1)
            if result and result.is_finished():
                results.append(result)
            elif not manager.is_running():
                break
    self.assertEqual(len(results), 2, f"Expected 2 results, but got {len(results) = }")
    self.assertEqual(results[0].generated_tokens, results[1].generated_tokens)
# ----------------------------------Additional features tests---------------------------------- #
# Tests to check that additional features of CB do not change its results #
# --------------------------------------------------------------------------------------------- #
@parameterized.expand(
    [
        # SDPA: basic features or full features
        ("sdpa", False, False),
        ("sdpa", True, True),
        # FA2: full coverage
        ("flash_attention_2", False, False),
        ("flash_attention_2", False, True),
        ("flash_attention_2", True, False),
        ("flash_attention_2", True, True),
        # FA3: always turn on CUDA graphs
        ("flash_attention_3", True, False),
        ("flash_attention_3", True, True),
    ]
)
@slow
def test_continuous_batching_async(self, attn_implementation: str, use_cuda_graph: bool, use_compile: bool) -> None:
    """CB/generate parity with async batching and block sharing always on, across attention backends, CUDA graphs
    and default compile configs. Compile is only enabled for a few cases because it adds a lot of overhead."""
    async_config = ContinuousBatchingConfig(
        allow_block_sharing=True,
        use_cuda_graph=use_cuda_graph,
        use_async_batching=True,
        use_default_compile_configs=use_compile,
    )
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=async_config,
        attn_implementation=attn_implementation,
    )
@parameterized.expand([(False, False), (False, True), (True, False), (True, True)])
@slow
@require_kernels
def test_flash_attn_with_kvcache_parity(self, use_cuda_graph: bool, use_async: bool) -> None:
    """Test that paged flash_attn3 (flash_attn_with_kvcache path) produces same outputs as varlen.

    The same inputs are generated twice with flash-attn3: once restricted to the varlen path
    (`max_blocks_per_request = 0`), and once allowing the flash_attn_with_kvcache path for decode. The decoded
    texts must be identical.
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer, model = get_tokenizer_and_model(
        model_id, "paged|kernels-community/flash-attn3", torch_device, torch.bfloat16
    )
    user_messages = _DEFAULT_USER_MESSAGES[:]
    input_ids = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)
    gen_config = GenerationConfig(do_sample=False, max_new_tokens=20)
    continuous_batching_config = ContinuousBatchingConfig(
        block_size=256,
        num_blocks=64,
        max_batch_tokens=16,
        use_cuda_graph=use_cuda_graph,
        use_async_batching=use_async,
    )

    # Reference run: generate with the varlen path only
    continuous_batching_config.max_blocks_per_request = 0
    outputs_varlen = model.generate_batch(
        inputs=input_ids, generation_config=gen_config, continuous_batching_config=continuous_batching_config
    )

    # Second run: allow the flash_attn_with_kvcache path for decode
    continuous_batching_config.max_blocks_per_request = 16
    # Spy on get_block_table_key (the real method still runs through `side_effect`) to check that the block-table /
    # kvcache path is actually exercised, not the varlen one
    og_get_block_table_key = PagedAttentionCache.get_block_table_key
    with patch.object(
        PagedAttentionCache, "get_block_table_key", autospec=True, side_effect=og_get_block_table_key
    ) as mock_get_block_table_key:
        outputs_kvcache = model.generate_batch(
            inputs=input_ids, generation_config=gen_config, continuous_batching_config=continuous_batching_config
        )
    self.assertTrue(mock_get_block_table_key.called, "get_block_table_key method was not called.")

    # Both runs use flash-attn3, so outputs are labelled by path (varlen vs kvcache). They are compared by position,
    # which assumes both output dicts iterate in submission order (TODO confirm)
    self.assertEqual(len(outputs_varlen), len(outputs_kvcache))
    for out_varlen, out_kvcache in zip(outputs_varlen.values(), outputs_kvcache.values()):
        text_varlen = tokenizer.decode(out_varlen.generated_tokens, skip_special_tokens=True)
        text_kvcache = tokenizer.decode(out_kvcache.generated_tokens, skip_special_tokens=True)
        self.assertEqual(text_varlen, text_kvcache, f"Mismatch:\nvarlen: {text_varlen}\nkvcache: {text_kvcache}")
@parameterized.expand([(False, False), (False, True), (True, False), (True, True)])
@slow
def test_per_request_logits_processors(self, use_cuda_graph: bool, use_async_batching: bool) -> None:
    """Tests that per-request logits processor kwargs (temperature, top_k, top_p) work correctly in generation.

    Two requests share the same prompt but use different processor kwargs. Each CB result is compared with regular
    generate using the same values: tokens must match exactly, and logprobs must match within a tolerance.
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    max_new_tokens = 10
    # Per-request processor values, index i belongs to request i. Both temperatures are equal, so the two requests
    # differ through top-k / top-p only
    temperatures = [1.0, 1.0]
    top_ks = [10, 50]
    top_ps = [0.9, 0.99]
    tokenizer, model = get_tokenizer_and_model(model_id, "flash_attention_2", torch_device)
    eos_token_id = model.config.eos_token_id  # type: ignore[attr-defined]
    # Same prompt for both requests
    user_messages = ["Write a random number:"]
    input_ids = get_generation_inputs(user_messages, tokenizer, for_continuous_batching=True)[0]

    # Manager-level defaults, overridden by each request's kwargs below. Every value differs from all per-request
    # values, presumably so each processor is enabled at the manager level and visibly overridden per request
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=max(temperatures) + 1,  # enables temperature warping
        top_k=max(top_ks) + 1,
        top_p=min(top_ps) - 0.01,
        max_new_tokens=max_new_tokens,
        eos_token_id=eos_token_id,
    )
    continuous_batching_config = ContinuousBatchingConfig(
        use_cuda_graph=use_cuda_graph,
        use_async_batching=use_async_batching,
        per_request_processors=True,
        return_logprobs=True,
        q_padding_interval_size=16,  # allows for exact comparison between CB and regular generation
    )
    manager = model.init_continuous_batching(
        generation_config=generation_config,
        continuous_batching_config=continuous_batching_config,
    )
    # Trick to have temperature, top-k, top-p ... without randomness: disable sampling after manager creation
    manager.generation_config.do_sample = False
    manager.start()
    try:
        # Request 0: smaller top-k / top-p (same temperature as request 1)
        req0_id = manager.add_request(
            input_ids, max_new_tokens=max_new_tokens, temperature=temperatures[0], top_k=top_ks[0], top_p=top_ps[0]
        )
        # Request 1: larger top-k / top-p
        req1_id = manager.add_request(
            input_ids, max_new_tokens=max_new_tokens, temperature=temperatures[1], top_k=top_ks[1], top_p=top_ps[1]
        )
        # Collect results
        results = {}
        while len(results) < 2:
            result = manager.get_result(timeout=1)
            if result is not None and result.is_finished():
                results[result.request_id] = result
            elif not manager.is_running():
                break
    finally:
        manager.stop(block=True)

    # Both requests should complete and have logprobs
    self.assertEqual(len(results), 2, f"Expected 2 results, got {len(results)}")
    self.assertGreater(len(results[req0_id].logprobs), 0)
    self.assertGreater(len(results[req1_id].logprobs), 0)
    # Also ensure the logprobs were not the same
    self.assertNotEqual(results[req0_id].logprobs, results[req1_id].logprobs)

    # Compare each request with regular generation. The logits processor is built from a config with do_sample=True
    # (so the sampling processors are included). Generation then runs with do_sample=False, which is the same trick
    # that CB uses
    delta = 2e-5 if use_cuda_graph else 1e-5
    for i, req_id in enumerate([req0_id, req1_id]):
        tokenizer, model = get_tokenizer_and_model(model_id, "flash_attention_2", torch_device)
        gen_config = GenerationConfig(
            do_sample=True,
            temperature=temperatures[i],
            top_k=top_ks[i],
            top_p=top_ps[i],
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
        )
        logits_processor = model._get_logits_processor(gen_config)
        regular_generated_tokens, regular_logprobs = regular_generate(
            model=model,
            tokenizer=tokenizer,
            user_messages=user_messages,
            logits_processor=logits_processor,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            output_scores=True,
            eos_token_id=eos_token_id,
        )
        self.assertEqual(results[req_id].generated_tokens, regular_generated_tokens[0])
        for j, (cb_lp, exp_lp) in enumerate(zip(results[req_id].logprobs, regular_logprobs[0])):
            error_msg = f"Request {i}: logprob mismatch at position {j}: CB={cb_lp}, expected={exp_lp}"
            self.assertAlmostEqual(cb_lp, exp_lp, delta=delta, msg=error_msg)
# ---------------------------------- CPU offloading tests ---------------------------------- #
@require_torch_accelerator
def test_cpu_offloading_parity(self) -> None:
    """With CPU offloading enabled and a cache so small (4 blocks of 32 tokens) that requests must be offloaded, CB
    must still match regular generate, and `OffloadingManager._offload_to_cpu` must be called at least once."""
    offload_config = ContinuousBatchingConfig(
        use_cuda_graph=True,
        allow_block_sharing=True,
        use_async_batching=False,
        num_blocks=4,
        block_size=32,
        cpu_offload_space=1.0,
    )
    # Spy on _offload_to_cpu: the real method still runs through `side_effect`, and calls are recorded
    with patch.object(
        OffloadingManager,
        "_offload_to_cpu",
        autospec=True,
        side_effect=OffloadingManager._offload_to_cpu,
    ) as offload_to_cpu_spy:
        self._test_continuous_batching_parity(
            model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            continuous_batching_config=offload_config,
            attn_implementation="sdpa",
            max_new_tokens=30,
            num_repeat_prompts=4,
        )
    self.assertTrue(offload_to_cpu_spy.called, "_offload_to_cpu was not called despite few blocks being available.")
@require_torch_accelerator
def test_cpu_offloading_disabled_when_zero(self) -> None:
    """With `cpu_offload_space=0.0` and the same tiny cache as the CPU-offloading test, CB must still match regular
    generate, as in `test_continuous_batching_few_blocks`."""
    no_cpu_offload_config = ContinuousBatchingConfig(
        use_cuda_graph=True,
        allow_block_sharing=True,
        use_async_batching=False,
        num_blocks=4,
        block_size=32,
        cpu_offload_space=0.0,
    )
    self._test_continuous_batching_parity(
        model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        continuous_batching_config=no_cpu_offload_config,
        attn_implementation="sdpa",
        max_new_tokens=30,
        num_repeat_prompts=4,
    )
@require_torch_gpu
class TestMemoryHandlerPrediction(unittest.TestCase):
    """Checks ``PagedAttentionMemoryHandler.compute_memory_footprint`` against real GPU memory usage.

    Each configuration allocates tensors with exactly the shapes and dtypes the handler models (no alignment padding,
    no extra blocks). The prediction is then compared with both the summed ``nbytes`` of those tensors and the CUDA
    allocator delta.
    """

    # Each entry: (block_size, page_size, num_groups, group_size, peak_act, num_attn_masks, max_bpr, logprobs,
    #              cache_dtype, use_async_batching)
    CONFIGS = [
        (32, 256, 1, 22, 34048, 1, 0, False, torch.float16, False),  # sdpa-like, 1 attn mask
        (256, 256, 1, 22, 34048, 0, 0, False, torch.float16, False),  # flash-like, no attn mask
        (32, 256, 2, 14, 34048, 2, 0, False, torch.bfloat16, False),  # hybrid model, 2 groups + 2 masks
        (32, 128, 1, 16, 8192, 1, 8, True, torch.float16, False),  # with block_table + logprobs
        (32, 128, 1, 16, 8192, 1, 8, True, torch.float16, True),  # with block_table + logprobs + async batching
    ]
    NUM_BLOCKS = 4
    MAX_BATCH_TOKENS = 64

    @parameterized.expand(CONFIGS)
    def test_memory_prediction(
        self,
        block_size: int,
        page_size: int,
        num_groups: int,
        group_size: int,
        peak_act: int,
        num_attn_masks: int,
        max_bpr: int,
        logprobs: bool,
        cache_dtype: torch.dtype,
        use_async_batching: bool,
    ) -> None:
        handler = PagedAttentionMemoryHandler(
            continuous_batching_config=ContinuousBatchingConfig(
                max_blocks_per_request=max_bpr,
                return_logprobs=logprobs,
                use_async_batching=use_async_batching,
                block_size=block_size,
            ),
            page_size=page_size,
            num_groups=num_groups,
            group_size=group_size,
            activation_peaks=[(0, peak_act)],
            num_attention_masks=num_attn_masks,
        )
        num_pages = self.NUM_BLOCKS * block_size
        max_tokens = self.MAX_BATCH_TOKENS
        predicted = handler.compute_memory_footprint(self.NUM_BLOCKS, max_tokens, cache_dtype)

        act_dtype = handler._activation_dtype
        int_dtype = handler._input_dtype
        output_rows = 2 if logprobs else 1
        io_copies = handler.io_multiplier  # 1 sync, 2 async -- scales IO tensors only

        # Describe every allocation as (shape, dtype), in allocation order, before touching the GPU
        specs = []
        # kv_cache: 2 * group_size tensors of [num_pages, page_size] (not scaled by io_copies)
        specs.extend([((num_pages, page_size), cache_dtype)] * (2 * group_size))
        # activation peak: flat tensor of peak_act * max_tokens elements (not scaled by io_copies)
        specs.append(((peak_act * max_tokens,), act_dtype))
        # IO tensors below are described io_copies times (once per IO instance)
        for _ in range(io_copies):
            specs.append(((7, max_tokens), int_dtype))  # bulk_input
            specs.append(((output_rows, max_tokens), int_dtype))  # output_ids
            # attention_mask: one [1, 1, max_tokens, num_pages + max_tokens] tensor per mask type
            specs.extend([((1, 1, max_tokens, num_pages + max_tokens), act_dtype)] * num_attn_masks)
            if max_bpr > 0:  # block_table is empty when max_bpr == 0
                specs.append(((num_groups, max_tokens, max_bpr), int_dtype))
            specs.append(((num_groups, max_tokens), torch.int64))  # write_index
            specs.append(((num_groups, num_pages + max_tokens), torch.int64))  # read_index

        # -- Allocate tensors at the exact idealized sizes the handler models --
        device = "cuda"
        torch.cuda.empty_cache()
        baseline = torch.cuda.memory_allocated(device)
        tensors = [torch.empty(shape, dtype=dtype, device=device) for shape, dtype in specs]
        actual_cuda = torch.cuda.memory_allocated(device) - baseline
        expected_nbytes = sum(t.nbytes for t in tensors)
        num_allocations = len(tensors)
        del tensors
        torch.cuda.empty_cache()

        # 1) Exact check: the prediction must equal the sum of tensor nbytes, with zero tolerance. This validates the
        # polynomial coefficients against the tensor shapes
        self.assertEqual(
            predicted,
            expected_nbytes,
            f"Prediction ({predicted}) != sum of tensor nbytes ({expected_nbytes})",
        )
        # 2) GPU memory check: CUDA's caching allocator rounds each allocation up (typically to 512 bytes), so up to
        # 512 bytes of overhead are tolerated per allocation
        max_cuda_overhead = 512 * num_allocations
        self.assertLessEqual(
            abs(actual_cuda - predicted),
            max_cuda_overhead,
            f"CUDA delta ({actual_cuda}) too far from prediction ({predicted}), "
            f"allowed overhead = {max_cuda_overhead} ({num_allocations} allocs × 512B)",
        )
# Worker functions for the TP continuous batching tests, spawned through `_init_distributed`.
def _tp_continuous_batching_worker(
    rank: int,
    model_id: str,
    attn_implementation: str,
    max_new_tokens: int,
    do_sample: bool,
    seed: int,
    use_cuda_graph: bool,
    use_async_batching: bool,
) -> None:
    """Loads `model_id` with `tp_plan="auto"`, checks three TP-specific paths in the same process: (a) direct
    broadcasts via `DistributedHelper`, (b) per-rank parity of CB-generated tokens via `dist.all_gather_object`, and
    (c) reproducibility across two CB runs sharing the same seed. Rank 0 owns all the assertions; the other ranks
    only need to participate in the collectives."""
    import torch
    import torch.distributed as dist

    from transformers.generation.continuous_batching.distributed import DistributedHelper

    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    # `hasattr` is not a reliable "unset" check: a tokenizer can define `pad_token` with a None value, so the
    # fallback to eos is decided on the values instead
    if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=attn_implementation, tp_plan="auto", dtype=torch.float32
    ).eval()

    # Direct broadcast tests: only rank 0's value should propagate to every TP rank
    helper = DistributedHelper(device_mesh=model._device_mesh, cpu_group_timeout=300)
    received_obj = helper.tp_broadcast_object_from_rank_0({"src_rank": rank})
    assert received_obj == {"src_rank": 0}, f"tp_broadcast_object: rank {rank} got {received_obj}"
    sent_tensor = torch.tensor([float(rank)], device=model.device)
    helper.tp_broadcast_from_rank_0(sent_tensor)
    assert sent_tensor.item() == 0.0, f"tp_broadcast_from_rank_0: rank {rank} got {sent_tensor.item()}"

    # CB runs: same seed twice, assert reproducibility AND cross-rank parity
    user_messages = [
        "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"
    ]
    chats = [[{"role": "user", "content": m}] for m in user_messages]
    tokenized = [tokenizer.apply_chat_template(chat, add_generation_prompt=True) for chat in chats]
    input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]
    cb_config_kwargs = {"use_cuda_graph": use_cuda_graph, "use_async_batching": use_async_batching, "seed": seed}
    gen_config = GenerationConfig(do_sample=do_sample, max_new_tokens=max_new_tokens)
    first_outputs = model.generate_batch(
        inputs=input_ids,
        generation_config=gen_config,
        continuous_batching_config=ContinuousBatchingConfig(**cb_config_kwargs),
    )
    second_outputs = model.generate_batch(
        inputs=input_ids,
        generation_config=gen_config,
        continuous_batching_config=ContinuousBatchingConfig(**cb_config_kwargs),
    )

    # Cross-rank parity: every TP rank must produce the same tokens, otherwise the seed broadcast / TP collectives are
    # diverging silently. Gather the first run's tokens onto all ranks and let rank 0 compare. Every rank must reach
    # this collective before the non-zero ranks return
    local_tokens = [out.generated_tokens for out in first_outputs.values()]
    gathered_tokens = [None] * helper.tp_size
    dist.all_gather_object(gathered_tokens, local_tokens, group=helper.tp_group)
    if rank != 0:
        return

    assert len(first_outputs) == len(input_ids), f"Expected {len(input_ids)} CB outputs, got {len(first_outputs)}"
    for i, output in enumerate(first_outputs.values()):
        assert len(output.generated_tokens) > 0, f"Request {i} got no generated tokens"
    for src_rank, src_tokens in enumerate(gathered_tokens):
        if src_tokens != gathered_tokens[0]:
            raise AssertionError(
                f"TP continuous batching diverges across ranks: rank {src_rank} got {src_tokens}, rank 0 got "
                f"{gathered_tokens[0]}"
            )
    second_tokens = [out.generated_tokens for out in second_outputs.values()]
    if local_tokens != second_tokens:
        raise AssertionError(
            f"TP continuous batching is not reproducible across runs with the same seed\n"
            f"First run : {local_tokens}\n"
            f"Second run: {second_tokens}"
        )
def _tp_cancellation_worker(
    rank: int,
    model_id: str,
    attn_implementation: str,
    use_cuda_graph: bool = False,
    use_async_batching: bool = False,
) -> None:
    """Load `model_id` with `tp_plan="auto"`, submit a long-running streaming request, and cancel it mid-flight.

    The cancellation goes through the cancel-queue + `tp_broadcast_object` path: if the broadcast were broken, the
    non-driver rank's scheduler would not learn about the cancellation and the test would hang or crash on the next
    TP forward pass. Rank 0 owns the assertions.

    Args:
        rank: Rank of this process inside the TP group; only rank 0 checks the assertions.
        model_id: Hub id of the causal LM to load.
        attn_implementation: Attention implementation forwarded to `from_pretrained`.
        use_cuda_graph: Forwarded to `ContinuousBatchingConfig`.
        use_async_batching: Forwarded to `ContinuousBatchingConfig`.
    """
    import time

    import torch

    cb_config = ContinuousBatchingConfig(use_cuda_graph=use_cuda_graph, use_async_batching=use_async_batching)
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    # Check the value, not the attribute's existence: `pad_token` can be present on the tokenizer while being `None`,
    # in which case a `hasattr` test would skip the eos fallback.
    if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, attn_implementation=attn_implementation, tp_plan="auto", dtype=torch.float32
    ).eval()

    chat = [{"role": "user", "content": "Tell me a long story about a robot exploring the galaxy."}]
    tokenized = tokenizer.apply_chat_template(chat, add_generation_prompt=True)
    inputs = tokenized if isinstance(tokenized, list) else tokenized["input_ids"]

    # Large budget so that an effective cancellation is clearly distinguishable from natural completion.
    max_new_tokens = 200
    cancel_after_n_chunks = 3

    manager = model.init_continuous_batching(continuous_batching_config=cb_config)
    manager.logit_processor.clear()
    # Warm up synchronously so CUDA-graph capture doesn't eat the streaming-loop deadline below
    manager.warmup()
    manager.start()
    try:
        request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)
        chunks_seen = 0
        cancelled = False
        # Monotonic clock: immune to wall-clock adjustments while waiting for chunks.
        deadline = time.monotonic() + 60
        while time.monotonic() < deadline:
            chunk = manager.get_result(request_id=request_id, timeout=2.0)
            if chunk is None:
                # No chunk within the timeout. After the cancel this means generation stopped on every rank; if it
                # happens before the cancel, the `cancelled` assertion below reports it.
                break
            chunks_seen += 1
            if chunks_seen >= cancel_after_n_chunks and not cancelled:
                manager.cancel_request(request_id)
                cancelled = True
        if rank == 0:
            assert cancelled, "Test setup did not reach the cancel call"
            assert chunks_seen < max_new_tokens, (
                f"Cancellation did not stop generation early: saw {chunks_seen} chunks "
                f"for max_new_tokens={max_new_tokens}"
            )
    finally:
        manager.stop(block=True)
@require_torch_multi_accelerator
class ContinuousBatchingTensorParallelTest(unittest.TestCase):
    """Integration tests for continuous batching with tensor parallelism. Each test spawns a TP-sized process group
    via `_init_distributed` (see `tests/test_tensor_parallel_mixin.py`) with the NCCL backend."""

    # Model and attention implementation shared by every test in this class.
    _MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    _ATTN_IMPLEMENTATION = "sdpa"

    @property
    def tp_size(self) -> int:
        """Number of TP ranks to spawn: the visible CUDA device count, capped at 2.

        NOTE(review): this counts CUDA devices only, matching the NCCL backend used below.
        """
        return min(torch.cuda.device_count(), 2)

    def _run_cb_worker(self, max_new_tokens: int = 20, **worker_kwargs) -> None:
        """Spawn `_tp_continuous_batching_worker` on `tp_size` NCCL processes with sensible defaults.

        Any keyword in `worker_kwargs` overrides the matching default.
        """
        defaults = {
            "model_id": self._MODEL_ID,
            "attn_implementation": self._ATTN_IMPLEMENTATION,
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "seed": 42,
            "use_cuda_graph": False,
            "use_async_batching": False,
        }
        defaults.update(worker_kwargs)
        _init_distributed(tp=self.tp_size, backend="nccl")(_tp_continuous_batching_worker)(**defaults)

    def _run_cancellation_worker(self, **worker_kwargs) -> None:
        """Spawn `_tp_cancellation_worker` on `tp_size` NCCL processes with the class's model defaults.

        Any keyword in `worker_kwargs` (e.g. `use_cuda_graph`, `use_async_batching`) overrides or extends them.
        """
        kwargs = {"model_id": self._MODEL_ID, "attn_implementation": self._ATTN_IMPLEMENTATION}
        kwargs.update(worker_kwargs)
        _init_distributed(tp=self.tp_size, backend="nccl")(_tp_cancellation_worker)(**kwargs)

    def test_continuous_batching_tp_fast(self) -> None:
        """Fast (non-slow) smoke test: with only 4 new tokens, check that continuous batching with TP produces
        non-empty, reproducible greedy outputs and that all TP ranks agree on the generated tokens."""
        self._run_cb_worker(max_new_tokens=4)

    @slow
    def test_continuous_batching_tp_greedy(self) -> None:
        """Test that continuous batching with `DistributedConfig(tp_size=...)` produces non-empty, reproducible greedy outputs and
        that all TP ranks agree on the generated tokens."""
        self._run_cb_worker()

    @slow
    def test_continuous_batching_tp_with_sampling(self) -> None:
        """Test that continuous batching with TP and sampling is reproducible across runs with the same seed and that
        all TP ranks agree on the sampled tokens — implicitly validating the seed broadcast from rank 0."""
        self._run_cb_worker(do_sample=True, seed=123)

    @slow
    def test_continuous_batching_tp_with_cuda_graph(self) -> None:
        """Test that continuous batching with TP and CUDA graphs is reproducible across runs and that all TP ranks
        agree on the generated tokens — captured-graph collectives must stay in sync across ranks."""
        self._run_cb_worker(use_cuda_graph=True)

    @slow
    def test_continuous_batching_tp_with_cuda_graph_and_async(self) -> None:
        """Test that continuous batching with TP, CUDA graphs, and async batching is reproducible across runs and
        that all TP ranks agree on the generated tokens — the toughest combination, exercising both captured-graph
        collectives and the async producer/consumer split."""
        self._run_cb_worker(use_cuda_graph=True, use_async_batching=True)

    @slow
    def test_continuous_batching_tp_cancellation(self) -> None:
        """Test that `cancel_request` propagates across the TP group: the driver enqueues the cancellation, broadcasts
        it to non-driver ranks via `tp_broadcast_object`, and generation stops well before `max_new_tokens`."""
        self._run_cancellation_worker()

    @slow
    def test_continuous_batching_tp_cancellation_realistic(self) -> None:
        """Same cancellation check as `test_continuous_batching_tp_cancellation`, but with CUDA graphs and async
        batching enabled: the cancellation broadcast must still reach non-driver ranks and stop generation well
        before `max_new_tokens` while captured graphs and the async producer/consumer split are in use."""
        self._run_cancellation_worker(use_async_batching=True, use_cuda_graph=True)