first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
0
tests/generation/__init__.py
Normal file
338
tests/generation/test_candidate_generator.py
Normal file
@@ -0,0 +1,338 @@
import gc
import unittest
import weakref
from unittest.mock import MagicMock

import torch

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline
from transformers.generation.candidate_generator import (
    AssistantToTargetTranslator,
    AssistantVocabTranslatorCache,
    UniversalSpeculativeDecodingGenerator,
)
from transformers.testing_utils import require_torch, torch_device


@require_torch
class TestAssistantToTargetTranslator(unittest.TestCase):
    def setUp(self):
        # Create mock tokenizers with predefined vocabularies
        self.target_tokenizer = MagicMock()
        self.assistant_tokenizer = MagicMock()
        self.assistant_model = MagicMock(device=torch_device)

        # Define mock vocabularies for the tokenizers
        self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
        self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}

        self.target_tokenizer.get_vocab.return_value = self.target_vocab
        self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
        self.target_vocab_size = 6

        # Instantiate the class under test
        self.translator = AssistantToTargetTranslator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )

    def test_get_assistant_to_target_input_ids(self):
        """Test the mapping from assistant tokens to target tokens."""
        expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID]
        actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
        self.assertEqual(actual_mapping, expected_mapping)

    def test_get_suppress_input_ids(self):
        """Test the suppression of assistant input IDs not present in the target vocabulary."""
        expected_suppress_ids = [3, 4]
        actual_suppress_ids = self.translator._get_suppress_input_ids().tolist()
        self.assertEqual(actual_suppress_ids, expected_suppress_ids)

    def test_get_target_ids(self):
        """Test the translation of assistant candidate IDs to target candidate IDs."""
        assistant_input_ids = torch.LongTensor([[0, 1, 2]]).to(
            self.assistant_model.device
        )  # 'hello world foo' in assistant tokenizer
        target_input_ids = torch.LongTensor([[0, 1, 2]]).to(
            self.assistant_model.device
        )  # 'hello world foo' in target tokenizer
        assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(
            self.assistant_model.device
        )  # 'hello world foo baz' in assistant tokenizer

        # 'hello world foo baz' in target tokenizer: 'baz' is mapped to `SUPPRESS_TOKEN_ID`
        # because it does not exist in the target vocab
        expected_target_ids = torch.LongTensor([[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]).to(
            self.assistant_model.device
        )

        actual_target_ids = self.translator.get_target_ids(
            assistant_input_ids, target_input_ids, assistant_candidate_ids
        )
        self.assertTrue(torch.equal(actual_target_ids, expected_target_ids))

    def test_get_target_logits(self):
        """Test the conversion of assistant logits to target logits."""
        # Assistant logits for assistant IDs 0-4; only IDs 0, 1 and 2 map to target tokens
        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]]).to(
            self.assistant_model.device
        )  # Shape (1, 1, 5)

        # Expected target logits (target_vocab_size = 6)
        expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE).to(
            self.assistant_model.device
        )
        expected_target_logits[0, 0, 0] = 0.1  # 'hello'
        expected_target_logits[0, 0, 1] = 0.2  # 'world'
        expected_target_logits[0, 0, 2] = 0.3  # 'foo'
        # The 'bar' token in target vocab remains at -inf

        actual_target_logits = self.translator.get_target_logits(assistant_logits)
        self.assertTrue(torch.equal(actual_target_logits, expected_target_logits))
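The mapping these tests check can be illustrated without the library. The sketch below is a hypothetical pure-Python stand-in, not the `AssistantToTargetTranslator` implementation: `build_mapping` and the `SUPPRESS` sentinel are names invented for this example, and the two vocabularies are copied from `setUp` above.

```python
# Hypothetical stand-in for illustration only; not the library's implementation.
SUPPRESS = -1  # assumed sentinel meaning "no equivalent token in the target vocab"


def build_mapping(target_vocab, assistant_vocab):
    """Map each assistant token ID to the target ID of the same token string."""
    size = max(assistant_vocab.values()) + 1
    mapping = [SUPPRESS] * size
    for token, assistant_id in assistant_vocab.items():
        if token in target_vocab:
            mapping[assistant_id] = target_vocab[token]
    return mapping


target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}

mapping = build_mapping(target_vocab, assistant_vocab)
# Assistant IDs with no counterpart in the target vocabulary
suppressed = [i for i, target_id in enumerate(mapping) if target_id == SUPPRESS]
print(mapping)
print(suppressed)
```

Under these assumptions, assistant ID 3 (unused) and ID 4 (`baz`) have no target counterpart, which mirrors the `[3, 4]` expected by `test_get_suppress_input_ids`.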


class MockTokenizer:
    """A simple mock tokenizer class that supports weak references."""

    def __init__(self, vocab=None):
        self._vocab = vocab or {}

    def get_vocab(self):
        return self._vocab

    def __call__(self, text, add_special_tokens=True):
        # Mock implementation of the __call__ method
        tokens = text.split()
        input_ids = [self._vocab.get(token, 0) for token in tokens]
        return {"input_ids": input_ids}


@require_torch
class TestAssistantVocabTranslatorCache(unittest.TestCase):
    def setUp(self):
        # Clear the cache before each test
        AssistantVocabTranslatorCache._cache.clear()
        # Create mock tokenizers with different vocabularies
        self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1})
        self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
        self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
        self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
        self.assistant_model = MagicMock(device=torch_device)

        self.target_vocab_size = 6

    def test_same_instance_for_same_tokenizers(self):
        """Test that the same translator is returned for the same tokenizers."""
        translator1 = AssistantVocabTranslatorCache.get_translator(
            self.target_tokenizer,
            self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        translator2 = AssistantVocabTranslatorCache.get_translator(
            self.target_tokenizer,
            self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        self.assertIs(translator1, translator2, "Translators should be cached and identical")

    def test_different_instances_for_different_tokenizers(self):
        """Test that different tokenizers produce different translators."""
        translator1 = AssistantVocabTranslatorCache.get_translator(
            self.target_tokenizer,
            self.assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        translator2 = AssistantVocabTranslatorCache.get_translator(
            self.other_target_tokenizer,
            self.other_assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        self.assertIsNot(translator1, translator2, "Translators should differ for different tokenizers")

    def test_cache_with_weakref_key(self):
        """Ensure that the cache uses weak references as keys."""
        initial_cache_size = len(AssistantVocabTranslatorCache._cache)
        target_tokenizer = MockTokenizer({"hello": 0})
        assistant_tokenizer = MockTokenizer({"hello": 0})

        # Store translator in a local variable to avoid it being kept alive
        translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer,
            assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )
        self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

        # Delete all strong references
        del target_tokenizer
        del assistant_tokenizer
        del translator

        # Force garbage collection
        gc.collect()

        # Call cleanup to remove dead entries
        AssistantVocabTranslatorCache.cleanup()

        # The cache size remains increased due to strong references
        self.assertEqual(len(AssistantVocabTranslatorCache._cache), initial_cache_size + 1)

    def test_weakref_cache_cleanup(self):
        """Test that the cache cleans up translators when tokenizers are garbage collected."""

        def create_translator():
            target_tokenizer = MockTokenizer({"hello": 0})
            assistant_tokenizer = MockTokenizer({"hello": 0})
            translator = AssistantVocabTranslatorCache.get_translator(
                target_tokenizer,
                assistant_tokenizer,
                target_vocab_size=self.target_vocab_size,
                assistant_model=self.assistant_model,
                assistant_prune_lm_head=False,
            )
            # Create weak references before returning
            refs = (weakref.ref(translator), weakref.ref(target_tokenizer), weakref.ref(assistant_tokenizer))
            # Remove strong references inside the function
            del target_tokenizer
            del assistant_tokenizer
            del translator
            return refs

        translator_ref, target_ref, assistant_ref = create_translator()

        # Force garbage collection
        gc.collect()

        # Call cleanup to remove dead entries
        AssistantVocabTranslatorCache.cleanup()

        # The tokenizers and translator are not garbage collected due to strong references
        self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references")
        self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references")
        self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references")
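The two weak-reference tests above assert that entries survive garbage collection "due to strong references". One general way a weak-keyed cache can end up in that state is when the cached value itself holds a strong reference to its key. The sketch below shows that effect with the standard library only. `Key`, `Holder`, and `cache_size_after_release` are hypothetical names, and whether `AssistantVocabTranslatorCache` behaves this way for the same reason is an assumption the source does not confirm.

```python
import gc
import weakref


class Key:
    """Plain object that supports weak references (stands in for a tokenizer)."""


class Holder:
    """Cached value that keeps a strong reference to its key."""

    def __init__(self, key):
        self.key = key


def cache_size_after_release(value_holds_key):
    """Return the cache size after the caller drops its only reference to the key."""
    cache = weakref.WeakKeyDictionary()
    key = Key()
    cache[key] = Holder(key) if value_holds_key else object()
    del key  # drop the caller's strong reference
    gc.collect()
    return len(cache)


print(cache_size_after_release(value_holds_key=False))
print(cache_size_after_release(value_holds_key=True))
```

When the value does not reference the key, the entry disappears once the key is released. When it does, the cache keeps the value alive, the value keeps the key alive, and the entry stays.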


@require_torch
class TestUniversalSpeculativeDecoding(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM"

    def setUp(self):
        self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name)
        self.target_config = AutoConfig.from_pretrained(self.target_name)
        self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device)
        self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name)
        self.generation_config = GenerationConfig(max_length=20, min_length=0)

        # Ensure required tokens exist
        if self.target_tokenizer.pad_token_id is None:
            self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id
        if self.target_tokenizer.bos_token_id is None:
            self.target_tokenizer.bos_token_id = self.target_tokenizer.eos_token_id
        if self.assistant_tokenizer.pad_token_id is None:
            self.assistant_tokenizer.pad_token_id = self.assistant_tokenizer.eos_token_id
        if self.assistant_tokenizer.bos_token_id is None:
            self.assistant_tokenizer.bos_token_id = self.assistant_tokenizer.eos_token_id

        self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.model_kwargs = {
            "attention_mask": torch.ones_like(self.input_ids).to(torch_device),
        }
        atm_translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            assistant_model=self.assistant_model,
            target_vocab_size=self.target_config.vocab_size,
        )
        self.generator = UniversalSpeculativeDecodingGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            generation_config=self.generation_config,
            model_kwargs=self.model_kwargs,
            atm_translator=atm_translator,
        )

    def test_basic_generation(self):
        """Test basic speculative decoding works"""
        input_text = "The quick brown fox"
        input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt")
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)

        self.assertIsNotNone(candidates)
        self.assertIsNotNone(scores)
        self.assertTrue(torch.is_tensor(candidates))
        self.assertTrue(torch.is_tensor(scores))

    def test_mismatched_vocabularies(self):
        """Test handling of mismatched vocabularies between models"""
        # Create input with tokens present in main but not assistant vocab
        # Find a token that is not in the assistant tokenizer but in
        # the main tokenizer.
        missing_token = next(
            token
            for token in self.target_tokenizer.get_vocab()
            if token not in self.assistant_tokenizer.get_vocab()
            and token not in self.target_tokenizer.all_special_tokens
            and "reserved_" not in token
        )
        input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertIsNotNone(candidates)

    def test_speculation_depth(self):
        """Test different speculation depths"""
        input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt")
        self.generator.input_ids = input_ids

        for depth in [1, 8, 17]:
            self.generator.num_assistant_tokens = depth
            candidates, _ = self.generator.get_candidates(input_ids)
            self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)

    def test_device_consistency(self):
        """Test that candidates are returned on the same device as the inputs"""
        input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertEqual(candidates.device, input_ids.device)

    def test_usd_vs_vanilla_sampling(self):
        """Test that USD matches vanilla sampling with temperature set to nearly 0"""
        prompt = "Test text"

        pipe_vanilla = pipeline(
            "text-generation",
            model=self.target_name,
        )
        pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
        vanilla_text = pipe_vanilla_output[0]["generated_text"]

        pipe_usd = pipeline(
            "text-generation",
            model=self.target_name,
            assistant_model=self.assistant_name,
        )
        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)  # Nearly 0 temperature
        usd_text = pipe_usd_output[0]["generated_text"]

        # Assert that the outputs match
        self.assertEqual(usd_text, vanilla_text)
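The last test above relies on sampling at a near-zero temperature behaving like greedy decoding. A minimal sketch of why that holds, in plain Python: dividing logits by a tiny temperature before the softmax pushes essentially all probability mass onto the largest logit. `softmax_with_temperature` and the toy logits are invented for this illustration and are not part of the library.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax over logits / temperature, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 3.5, 2.0]  # toy values, assumed for the example
probs = softmax_with_temperature(logits, temperature=1e-9)
greedy = max(range(len(logits)), key=logits.__getitem__)  # index greedy decoding would pick
print(probs, greedy)
```

With these toy values the distribution collapses onto the index that greedy decoding selects, so a sampler drawing from it has only one outcome with non-zero probability.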
843
tests/generation/test_configuration_utils.py
Normal file
@@ -0,0 +1,843 @@
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
import tempfile
import unittest
import warnings

from huggingface_hub import create_pull_request
from parameterized import parameterized

from transformers import AutoConfig, GenerationConfig, WatermarkingConfig, is_torch_available
from transformers import logging as transformers_logging


if is_torch_available():
    import torch

    from transformers.generation import (
        ClassifierFreeGuidanceLogitsProcessor,
        EncoderNoRepeatNGramLogitsProcessor,
        EncoderRepetitionPenaltyLogitsProcessor,
        EpsilonLogitsWarper,
        EtaLogitsWarper,
        ExponentialDecayLengthPenalty,
        ForcedBOSTokenLogitsProcessor,
        ForcedEOSTokenLogitsProcessor,
        GenerationMode,
        MinLengthLogitsProcessor,
        MinNewTokensLengthLogitsProcessor,
        MinPLogitsWarper,
        NoBadWordsLogitsProcessor,
        NoRepeatNGramLogitsProcessor,
        PrefixConstrainedLogitsProcessor,
        RepetitionPenaltyLogitsProcessor,
        SequenceBiasLogitsProcessor,
        SuppressTokensAtBeginLogitsProcessor,
        SuppressTokensLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
        TypicalLogitsWarper,
        UnbatchedClassifierFreeGuidanceLogitsProcessor,
        WatermarkLogitsProcessor,
    )
from transformers.testing_utils import (
    TOKEN,
    CaptureLogger,
    LoggingLevel,
    TemporaryHubRepo,
    is_staging_test,
    torch_device,
)
|
||||
|
||||
|
||||
class GenerationConfigTest(unittest.TestCase):
|
||||
@parameterized.expand([(None,), ("foo.json",)])
|
||||
def test_save_load_config(self, config_name):
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
bad_words_ids=[[1, 2, 3], [4, 5]],
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, config_name=config_name)
|
||||
loaded_config = GenerationConfig.from_pretrained(tmp_dir, config_name=config_name)
|
||||
|
||||
# Checks parameters that were specified
|
||||
self.assertEqual(loaded_config.do_sample, True)
|
||||
self.assertEqual(loaded_config.temperature, 0.7)
|
||||
self.assertEqual(loaded_config.length_penalty, 1.0)
|
||||
self.assertEqual(loaded_config.bad_words_ids, [[1, 2, 3], [4, 5]])
|
||||
|
||||
# Checks parameters that were not specified (defaults)
|
||||
self.assertEqual(loaded_config.top_k, None)
|
||||
self.assertEqual(loaded_config.max_length, None)
|
||||
self.assertEqual(loaded_config.max_time, None)
|
||||
|
||||
def test_from_model_config(self):
|
||||
model_config = AutoConfig.from_pretrained("openai-community/gpt2")
|
||||
generation_config_from_model = GenerationConfig.from_model_config(model_config)
|
||||
default_generation_config = GenerationConfig()
|
||||
|
||||
# The generation config has loaded a few non-default parameters from the model config
|
||||
self.assertNotEqual(generation_config_from_model, default_generation_config)
|
||||
|
||||
# One of those parameters is eos_token_id -- check if it matches
|
||||
self.assertNotEqual(generation_config_from_model.eos_token_id, default_generation_config.eos_token_id)
|
||||
self.assertEqual(generation_config_from_model.eos_token_id, model_config.eos_token_id)
|
||||
|
||||
def test_update(self):
|
||||
generation_config = GenerationConfig()
|
||||
update_kwargs = {
|
||||
"max_new_tokens": 1024,
|
||||
"foo": "bar",
|
||||
}
|
||||
update_kwargs_copy = copy.deepcopy(update_kwargs)
|
||||
unused_kwargs = generation_config.update(**update_kwargs)
|
||||
|
||||
# update_kwargs was not modified (no side effects)
|
||||
self.assertEqual(update_kwargs, update_kwargs_copy)
|
||||
|
||||
# update_kwargs was used to update the config on valid attributes
|
||||
self.assertEqual(generation_config.max_new_tokens, 1024)
|
||||
|
||||
# `.update()` returns a dictionary of unused kwargs
|
||||
self.assertEqual(unused_kwargs, {"foo": "bar"})
|
||||
|
||||
def test_kwarg_init(self):
|
||||
"""Tests that we can overwrite attributes at `from_pretrained` time."""
|
||||
default_config = GenerationConfig()
|
||||
self.assertEqual(default_config.temperature, None)
|
||||
self.assertEqual(default_config.do_sample, None)
|
||||
self.assertEqual(default_config.num_beams, None)
|
||||
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
bad_words_ids=[[1, 2, 3], [4, 5]],
|
||||
)
|
||||
self.assertEqual(config.temperature, 0.7)
|
||||
self.assertEqual(config.do_sample, True)
|
||||
self.assertEqual(config.num_beams, None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir)
|
||||
loaded_config = GenerationConfig.from_pretrained(tmp_dir, temperature=1.0)
|
||||
|
||||
self.assertEqual(loaded_config.temperature, 1.0)
|
||||
self.assertEqual(loaded_config.do_sample, True)
|
||||
self.assertEqual(loaded_config.num_beams, None) # default value
|
||||
|
||||
def test_validate(self):
|
||||
"""
|
||||
Tests that the `validate` method is working as expected. Note that `validate` is called at initialization time
|
||||
"""
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
|
||||
# A correct configuration will not throw any warning
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig()
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Inconsequent but technically wrong configuration will throw a warning (e.g. requesting an extra output
|
||||
# without `return_dict_in_generate=True`). May be escalated to an error in the future.
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(return_dict_in_generate=False, output_scores=True)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Explicitly setting a sampling flag alongside `do_sample=False` still warns: this is a user-level mistake.
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# But a value inherited from a model's default config (i.e. not in this update's kwargs) does NOT warn: in
|
||||
# the real world, `generate(do_sample=False)` on a model whose `generation_config.json` has `temperature=0.6`
|
||||
# would otherwise log a useless warning.
|
||||
logger.warning_once.cache_clear()
|
||||
base_config = GenerationConfig(do_sample=True, temperature=0.6) # mimics a model's default config
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
base_config.update(do_sample=False)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Inverse provenance case: `do_sample=False` inherited from a model's config (so not user-set this call), user only
|
||||
# sets a sampling flag. The conflict SHOULD produce noise because the user may think that it's non-greedy by default
|
||||
logger.warning_once.cache_clear()
|
||||
greedy_hub_config = GenerationConfig(do_sample=False) # mimics a model's default config forcing greedy
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
greedy_hub_config.update(top_p=0.8)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Updating only `temperature` (do_sample was pre-existing, i.e. "from the hub") does warn
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
generation_config_bad_temperature.update(temperature=0.9)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
# But setting both in the same `update()` call DOES warn.
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
generation_config_bad_temperature.update(do_sample=False, temperature=0.9)
|
||||
self.assertNotEqual(len(captured_logs.out), 0)
|
||||
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
# OK - None means it is unset, nothing to warn about
|
||||
generation_config_bad_temperature.update(temperature=None)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# Impossible sets of parameters will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# Passing `generate()`-only flags to `validate` will raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
GenerationConfig(logits_processor="foo")
|
||||
|
||||
# Model-specific parameters will NOT raise an exception or a warning
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(foo="bar")
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
|
||||
# By default we throw a short warning. However, we log with INFO level the details.
|
||||
# Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
|
||||
logger.warning_once.cache_clear()
|
||||
with LoggingLevel(logging.WARNING):
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertNotIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) < 150) # short log
|
||||
self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# INFO level: we log the full details, including the offending values
|
||||
logger.warning_once.cache_clear()
|
||||
logger.info_once.cache_clear()
|
||||
with LoggingLevel(logging.INFO):
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
GenerationConfig(do_sample=False, temperature=0.5)
|
||||
self.assertIn("0.5", captured_logs.out)
|
||||
self.assertTrue(len(captured_logs.out) > 400) # long log
|
||||
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
|
||||
|
||||
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
|
||||
generation_config = GenerationConfig()
|
||||
generation_config.temperature = 0.5
|
||||
generation_config.do_sample = False
|
||||
with self.assertRaises(ValueError):
|
||||
generation_config.validate(strict=True)
|
||||
|
||||
def test_validate_sampling_flag_provenance(self):
|
||||
"""
|
||||
Dedicated coverage for the provenance-aware warning rule on sampling-only flags:
|
||||
we warn when `do_sample=False` conflicts with a sampling flag (e.g. `top_p`, `temperature`) and either both were
explicitly provided by the caller in the same call, neither was directly provided, or only the sampling flag was
provided while `do_sample=False` was already set.
|
||||
"""
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
|
||||
def _warn_count(fn):
|
||||
logger.warning_once.cache_clear()
|
||||
with CaptureLogger(logger) as captured:
|
||||
fn()
|
||||
return len(captured.out)
|
||||
|
||||
# 1. Hub config sets `temperature`, user does only `generate(do_sample=False)` -> NO warning.
|
||||
# (Emulates: model whose `generation_config.json` carries `do_sample=True, temperature=0.6`, user
|
||||
# explicitly asks for greedy decoding.)
|
||||
def case_hub_temp_user_do_sample_only():
|
||||
cfg = GenerationConfig(do_sample=True, temperature=0.6) # stands in for the hub default
|
||||
cfg.update(do_sample=False)
|
||||
|
||||
self.assertEqual(_warn_count(case_hub_temp_user_do_sample_only), 0)
|
||||
|
||||
# 2. User explicitly sets BOTH `do_sample=False` and `top_p=0.8` in the same call -> WARN.
|
||||
self.assertNotEqual(_warn_count(lambda: GenerationConfig(do_sample=False, top_p=0.8)), 0)
|
||||
|
||||
# 3. User explicitly sets only `do_sample=False` (no sampling flag) -> NO warning, even though
|
||||
# attribute defaults (like `top_k=50`) may be present.
|
||||
self.assertEqual(_warn_count(lambda: GenerationConfig(do_sample=False)), 0)
|
||||
|
||||
# 4. Hub config forces greedy (`do_sample=False`), user sets only `top_p=0.8` -> WARN:
# `do_sample` was inherited, but it clashes with the user's expressed intent, so we flag their `top_p`.
|
||||
def case_hub_greedy_user_top_p():
|
||||
cfg = GenerationConfig(do_sample=False) # stands in for the hub default
|
||||
cfg.update(top_p=0.8)
|
||||
|
||||
self.assertNotEqual(_warn_count(case_hub_greedy_user_top_p), 0)
|
||||
|
||||
# 5. User sets `do_sample=False` and `temperature=0.5` via a single `update()` call -> WARN.
|
||||
def case_update_both_sides():
|
||||
cfg = GenerationConfig()
|
||||
cfg.update(do_sample=False, temperature=0.5)
|
||||
|
||||
self.assertNotEqual(_warn_count(case_update_both_sides), 0)
|
||||
|
||||
# 6. Same idea for beam flags: user only asks for `num_beams=1`, hub default has `length_penalty=0.8`
|
||||
# -> NO warning.
|
||||
def case_hub_length_penalty_user_num_beams_only():
|
||||
cfg = GenerationConfig(num_beams=4, length_penalty=0.8) # stands in for the hub default
|
||||
cfg.update(num_beams=1)
|
||||
|
||||
self.assertEqual(_warn_count(case_hub_length_penalty_user_num_beams_only), 0)
|
||||
|
||||
# 7. User sets BOTH `num_beams=1` and `length_penalty=0.8` explicitly -> WARN.
|
||||
self.assertNotEqual(_warn_count(lambda: GenerationConfig(num_beams=1, length_penalty=0.8)), 0)
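
The sampling-flag cases above can be summarised with a small toy model. This is only a sketch of the rule as the test describes it, not the logic in `transformers`: the function name `should_warn`, the `explicit` set, and the neutral values in `NEUTRAL_SAMPLING_VALUES` are all assumptions made for illustration.

```python
# Toy model of a provenance-aware check. NOT the transformers implementation.
# Assumption: these are the "neutral" values at which a sampling flag is a no-op.
NEUTRAL_SAMPLING_VALUES = {"temperature": 1.0, "top_p": 1.0}


def should_warn(state, explicit):
    """Return True if the merged `state` deserves a sampling-flag warning.

    state:    dict of config values after the call was applied
    explicit: set of keys the caller passed in this call (empty for a bare validate)
    """
    if state.get("do_sample", False):
        return False  # sampling is on, so sampling flags are meaningful
    conflicting = {
        name
        for name, neutral in NEUTRAL_SAMPLING_VALUES.items()
        if state.get(name) is not None and state[name] != neutral
    }
    if not conflicting:
        return False
    if not explicit:
        return True  # nothing was passed directly: check the whole config
    # Only flag sampling values the caller touched in this call.
    return bool(conflicting & explicit)


# Case 1: inherited temperature, caller only asks for greedy decoding.
print(should_warn({"do_sample": False, "temperature": 0.6}, {"do_sample"}))
# Case 2: caller passes both sides of the conflict.
print(should_warn({"do_sample": False, "top_p": 0.8}, {"do_sample", "top_p"}))
# Case 4: inherited greedy decoding, caller passes only the sampling flag.
print(should_warn({"do_sample": False, "top_p": 0.8}, {"top_p"}))
```

The three calls print `False`, `True`, `True`, matching cases 1, 2 and 4 of the test. The beam-flag cases (6 and 7) would follow the same pattern with `num_beams` in place of `do_sample`.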
|
||||
|
||||
def test_refuse_to_save(self):
|
||||
"""Tests that we refuse to save a generation config that fails validation."""
|
||||
|
||||
# setting the temperature alone is invalid, as we also need to set do_sample to True -> throws a warning that
|
||||
# is caught, doesn't save, and raises an exception
|
||||
config = GenerationConfig()
|
||||
config.temperature = 0.5
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue("`temperature` is set to `0.5`" in str(exc.exception))
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# greedy decoding does not support returning multiple sequences -> validation throws an exception that is
# caught, doesn't save, and raises an exception
|
||||
config = GenerationConfig()
|
||||
config.num_return_sequences = 2
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(ValueError) as exc:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertTrue("Fix these issues to save the configuration." in str(exc.exception))
|
||||
self.assertTrue(
|
||||
"Greedy methods (do_sample != True) without beam search do not support `num_return_sequences` different than 1"
|
||||
in str(exc.exception)
|
||||
)
|
||||
self.assertTrue(len(os.listdir(tmp_dir)) == 0)
|
||||
|
||||
# Final check: no logs at warning level/warnings/exceptions thrown if it is correct, and file is saved.
|
||||
config = GenerationConfig()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Catch warnings
|
||||
with warnings.catch_warnings(record=True) as captured_warnings:
|
||||
# Catch logs (up to WARNING level, the default level)
|
||||
with LoggingLevel(logging.WARNING):
|
||||
logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
|
||||
with CaptureLogger(logger) as captured_logs:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(captured_warnings), 0)
|
||||
self.assertEqual(len(captured_logs.out), 0)
|
||||
self.assertEqual(len(os.listdir(tmp_dir)), 1)
|
||||
|
||||
def test_generation_mode(self):
|
||||
"""Tests that the `get_generation_mode` method is working as expected."""
|
||||
config = GenerationConfig()
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.GREEDY_SEARCH)
|
||||
|
||||
config = GenerationConfig(do_sample=True)
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.SAMPLE)
|
||||
|
||||
config = GenerationConfig(num_beams=2)
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.BEAM_SEARCH)
|
||||
|
||||
# TODO joao, manuel: remove this in v4.62.0
|
||||
config = GenerationConfig(top_k=10, do_sample=False, penalty_alpha=0.6)
|
||||
self.assertEqual(config.get_generation_mode(), GenerationMode.CONTRASTIVE_SEARCH)
|
||||
|
||||
config = GenerationConfig()
|
||||
self.assertEqual(config.get_generation_mode(assistant_model="foo"), GenerationMode.ASSISTED_GENERATION)
|
||||
|
||||
def test_static_cache_without_cache_config(self):
|
||||
"""Regression test for #35026 -- static cache should work without a cache config."""
|
||||
config = GenerationConfig(cache_implementation="static")
|
||||
self.assertEqual(config.cache_implementation, "static")
|
||||
self.assertEqual(config.cache_config, None)
|
||||
|
||||
|
||||
class GenerationConfigSerializationTest(unittest.TestCase):
|
||||
def test_serialize_generation_sequence_bias(self):
|
||||
"""Tests that GenerationConfig is serialized and SequenceBiasLogitsProcessor is initialized with sequence_bias parameter"""
|
||||
generation_config = GenerationConfig()
|
||||
sequence_bias = [[[45, 67], -0.6], [[89], 1.2]]
|
||||
generation_config.sequence_bias = sequence_bias
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertSequenceEqual(new_config.sequence_bias, sequence_bias)
|
||||
|
||||
expected_sequence_bias = {(45, 67): -0.6, (89,): 1.2}
|
||||
bias_logits_processor = SequenceBiasLogitsProcessor(new_config.sequence_bias)
|
||||
self.assertDictEqual(bias_logits_processor.sequence_bias, expected_sequence_bias)
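
The test above stores `sequence_bias` as nested lists and expects the processor to expose a dict keyed by tuples. A minimal sketch of that conversion, assuming a hypothetical helper `to_bias_dict` (the real processor's internals are not shown here), also illustrates why the list form is the one that can be written to JSON:

```python
import json


def to_bias_dict(sequence_bias):
    """Convert [[token_ids, bias], ...] into {tuple(token_ids): bias}."""
    return {tuple(token_ids): bias for token_ids, bias in sequence_bias}


sequence_bias = [[[45, 67], -0.6], [[89], 1.2]]
bias_dict = to_bias_dict(sequence_bias)
print(bias_dict)

# The nested-list form survives a JSON round trip unchanged...
assert json.loads(json.dumps(sequence_bias)) == sequence_bias

# ...whereas a dict with tuple keys cannot be dumped at all.
try:
    json.dumps(bias_dict)
except TypeError as err:
    print("not JSON-serializable:", err)
```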
|
||||
|
||||
def test_serialize_generation_min_length_eos_token(self):
|
||||
"""Tests that GenerationConfig is serialized and MinLengthLogitsProcessor is initialized with min_length and eos_token_id"""
|
||||
eos_token_id = 0
|
||||
min_length = 10
|
||||
|
||||
generation_config = GenerationConfig(min_length=min_length, eos_token_id=eos_token_id)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.min_length, min_length)
|
||||
self.assertEqual(new_config.eos_token_id, eos_token_id)
|
||||
|
||||
min_dist_processor = MinLengthLogitsProcessor(
|
||||
min_length=new_config.min_length, eos_token_id=new_config.eos_token_id
|
||||
)
|
||||
self.assertEqual(min_dist_processor.min_length, min_length)
|
||||
self.assertEqual(min_dist_processor.eos_token_id, eos_token_id)
|
||||
|
||||
def test_serialize_generation_min_new_tokens(self):
|
||||
"""Tests that GenerationConfig is serialized and MinNewTokensLengthLogitsProcessor is initialized with min_new_tokens"""
|
||||
eos_token_id = 0
|
||||
min_new_tokens = 5
|
||||
prompt_length_to_skip = 2
|
||||
|
||||
generation_config = GenerationConfig(min_new_tokens=min_new_tokens)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.min_new_tokens, min_new_tokens)
|
||||
|
||||
min_new_tokens_processor = MinNewTokensLengthLogitsProcessor(
|
||||
prompt_length_to_skip=prompt_length_to_skip,
|
||||
min_new_tokens=new_config.min_new_tokens,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
self.assertEqual(min_new_tokens_processor.min_new_tokens, min_new_tokens)
|
||||
|
||||
def test_serialize_generation_temperature(self):
|
||||
"""Tests that GenerationConfig is serialized and TemperatureLogitsWarper is initialized with temperature"""
|
||||
temperature = 2.0
|
||||
|
||||
generation_config = GenerationConfig(temperature=temperature, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.temperature, temperature)
|
||||
|
||||
temperature_logits_warper = TemperatureLogitsWarper(temperature=new_config.temperature)
|
||||
self.assertEqual(temperature_logits_warper.temperature, temperature)
|
||||
|
||||
def test_serialize_generation_repetition_penalty(self):
|
||||
"""Tests that GenerationConfig is serialized and RepetitionPenaltyLogitsProcessor is initialized with repetition_penalty"""
|
||||
penalty = 2.0
|
||||
|
||||
generation_config = GenerationConfig(repetition_penalty=penalty)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.repetition_penalty, penalty)
|
||||
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=new_config.repetition_penalty)
|
||||
self.assertEqual(rep_penalty_proc.penalty, penalty)
|
||||
|
||||
def test_serialize_generation_encoder_repetition_penalty(self):
|
||||
"""Tests that GenerationConfig is serialized and EncoderRepetitionPenaltyLogitsProcessor is initialized with penalty and input_ids"""
|
||||
penalty = 2.0
|
||||
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||
|
||||
generation_config = GenerationConfig(encoder_repetition_penalty=penalty)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.encoder_repetition_penalty, penalty)
|
||||
|
||||
rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(
|
||||
penalty=new_config.encoder_repetition_penalty, encoder_input_ids=input_ids
|
||||
)
|
||||
self.assertEqual(rep_penalty_proc.penalty, 1 / penalty)
|
||||
torch.testing.assert_close(rep_penalty_proc.encoder_input_ids, input_ids)
|
||||
|
||||
def test_serialize_generation_top_p(self):
|
||||
"""Tests that GenerationConfig is serialized and TopPLogitsWarper is initialized with top_p"""
|
||||
top_p = 0.8
|
||||
|
||||
generation_config = GenerationConfig(top_p=top_p, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.top_p, top_p)
|
||||
|
||||
top_p_logits_warper = TopPLogitsWarper(top_p=new_config.top_p)
self.assertEqual(top_p_logits_warper.top_p, top_p)
|
||||
|
||||
def test_serialize_generation_top_k(self):
|
||||
"""Tests that GenerationConfig is serialized and TopKLogitsWarper is initialized with top_k"""
|
||||
top_k = 2
|
||||
|
||||
generation_config = GenerationConfig(top_k=top_k, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.top_k, top_k)
|
||||
|
||||
top_k_logits_wrap = TopKLogitsWarper(top_k=new_config.top_k)
|
||||
self.assertEqual(top_k_logits_wrap.top_k, top_k)
|
||||
|
||||
def test_serialize_generation_min_p(self):
|
||||
"""Tests that GenerationConfig is serialized and MinPLogitsWarper is initialized with min_p"""
|
||||
min_p = 0.8
|
||||
|
||||
generation_config = GenerationConfig(min_p=min_p, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.min_p, min_p)
|
||||
|
||||
min_p_logits_warper = MinPLogitsWarper(min_p=new_config.min_p)
self.assertEqual(min_p_logits_warper.min_p, min_p)
|
||||
|
||||
def test_serialize_generation_typical_p(self):
|
||||
"""Tests that GenerationConfig is serialized and TypicalLogitsWarper is initialized with mass"""
|
||||
mass = 0.8
|
||||
|
||||
generation_config = GenerationConfig(typical_p=mass, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.typical_p, mass)
|
||||
|
||||
typical_p_logits_wrap = TypicalLogitsWarper(mass=new_config.typical_p)
|
||||
self.assertEqual(typical_p_logits_wrap.mass, mass)
|
||||
|
||||
def test_serialize_generation_epsilon_cutoff(self):
|
||||
"""Tests that GenerationConfig is serialized and EpsilonLogitsWarper is initialized with epsilon"""
|
||||
epsilon = 0.8
|
||||
|
||||
generation_config = GenerationConfig(epsilon_cutoff=epsilon, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.epsilon_cutoff, epsilon)
|
||||
|
||||
epsilon_logits_wrap = EpsilonLogitsWarper(epsilon=new_config.epsilon_cutoff)
|
||||
self.assertEqual(epsilon_logits_wrap.epsilon, epsilon)
|
||||
|
||||
def test_serialize_generation_eta_cutoff(self):
|
||||
"""Tests that GenerationConfig is serialized and EtaLogitsWarper is initialized with epsilon"""
|
||||
epsilon = 0.8
|
||||
|
||||
generation_config = GenerationConfig(eta_cutoff=epsilon, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.eta_cutoff, epsilon)
|
||||
|
||||
eta_logits_wrap = EtaLogitsWarper(epsilon=new_config.eta_cutoff)
|
||||
self.assertEqual(eta_logits_wrap.epsilon, epsilon)
|
||||
|
||||
def test_serialize_generation_ngram_size(self):
|
||||
"""Tests that GenerationConfig is serialized and NoRepeatNGramLogitsProcessor is initialized with ngram_size"""
|
||||
ngram_size = 2
|
||||
|
||||
generation_config = GenerationConfig(no_repeat_ngram_size=ngram_size, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.no_repeat_ngram_size, ngram_size)
|
||||
|
||||
no_repeat_ngram_proc = NoRepeatNGramLogitsProcessor(ngram_size=new_config.no_repeat_ngram_size)
|
||||
self.assertEqual(no_repeat_ngram_proc.ngram_size, ngram_size)
|
||||
|
||||
def test_serialize_generation_encoder_ngram_size(self):
|
||||
"""Tests that GenerationConfig is serialized and EncoderNoRepeatNGramLogitsProcessor is initialized with ngram_size"""
|
||||
ngram_size = 2
|
||||
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||
|
||||
generation_config = GenerationConfig(encoder_no_repeat_ngram_size=ngram_size, do_sample=True)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.encoder_no_repeat_ngram_size, ngram_size)
|
||||
|
||||
encoder_no_repeat_ngram_proc = EncoderNoRepeatNGramLogitsProcessor(
|
||||
encoder_ngram_size=new_config.encoder_no_repeat_ngram_size, encoder_input_ids=input_ids
|
||||
)
|
||||
self.assertEqual(encoder_no_repeat_ngram_proc.ngram_size, ngram_size)
|
||||
|
||||
def test_serialize_generation_bad_words_ids(self):
|
||||
"""Tests that GenerationConfig is serialized and NoBadWordsLogitsProcessor is initialized with bad_words_ids"""
|
||||
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
|
||||
|
||||
generation_config = GenerationConfig(bad_words_ids=bad_word_tokens)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertSequenceEqual(new_config.bad_words_ids, bad_word_tokens)
|
||||
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=new_config.bad_words_ids)
|
||||
self.assertSequenceEqual(no_bad_words_dist_proc.bad_word_ids, bad_word_tokens)
|
||||
|
||||
def test_serialize_generation_num_beams(self):
|
||||
"""Tests that GenerationConfig is serialized and PrefixConstrainedLogitsProcessor is initialized with num_beams"""
|
||||
num_beams = 1
|
||||
|
||||
def prefix_allowed_tokens_fn(batch_id, inputs_ids):
|
||||
return [[0, 1], [2, 3]][batch_id]
|
||||
|
||||
generation_config = GenerationConfig(num_beams=num_beams)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.num_beams, num_beams)
|
||||
|
||||
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(
|
||||
prefix_allowed_tokens_fn, num_beams=new_config.num_beams
|
||||
)
|
||||
self.assertEqual(prefix_constrained_logits_proc._num_beams, num_beams)
|
||||
|
||||
def test_serialize_generation_bos_token_id(self):
|
||||
"""Tests that GenerationConfig is serialized and ForcedBOSTokenLogitsProcessor is initialized with bos_token_id"""
|
||||
bos_token_id = 0
|
||||
|
||||
generation_config = GenerationConfig(bos_token_id=bos_token_id)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.bos_token_id, bos_token_id)
|
||||
|
||||
logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=new_config.bos_token_id)
|
||||
self.assertEqual(logits_processor.bos_token_id, bos_token_id)
|
||||
|
||||
def test_serialize_generation_eos_token_id(self):
|
||||
"""Tests that GenerationConfig is serialized and ForcedEOSTokenLogitsProcessor is initialized with eos_token_id"""
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
generation_config = GenerationConfig(eos_token_id=eos_token_id)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.eos_token_id, eos_token_id)
|
||||
|
||||
logits_processor = ForcedEOSTokenLogitsProcessor(
|
||||
max_length=max_length, eos_token_id=new_config.eos_token_id, device=torch_device
|
||||
)
|
||||
self.assertEqual(logits_processor.eos_token_id, eos_token_id)
|
||||
|
||||
def test_serialize_generation_exponential_decay_length_penalty(self):
|
||||
"""Tests that GenerationConfig is serialized and ExponentialDecayLengthPenalty is initialized with regulation_start and regulation_factor"""
|
||||
eos_token_id = 0
|
||||
penalty_start = 5
|
||||
penalty_factor = 1.1
|
||||
input_ids_seq_length = 10
|
||||
exponential_decay_length_penalty = (penalty_start, penalty_factor)
|
||||
|
||||
generation_config = GenerationConfig(exponential_decay_length_penalty=exponential_decay_length_penalty)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.exponential_decay_length_penalty, [penalty_start, penalty_factor])
|
||||
|
||||
exponential_decay_processor = ExponentialDecayLengthPenalty(
|
||||
exponential_decay_length_penalty=new_config.exponential_decay_length_penalty,
|
||||
eos_token_id=eos_token_id,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
)
|
||||
self.assertEqual(
|
||||
exponential_decay_processor.regulation_start, exponential_decay_length_penalty[0] + input_ids_seq_length
|
||||
)
|
||||
self.assertEqual(exponential_decay_processor.regulation_factor, exponential_decay_length_penalty[1])
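
The test above passes a tuple but compares the reloaded value against a list. Assuming the config is stored as JSON (the comments earlier in this file refer to `generation_config.json`), that follows from JSON having only one array type. The helper `as_tuple` below is hypothetical and only shows how a caller could normalise the value if a tuple is needed again.

```python
import json

# (start index, decay factor), same values as in the test above
exponential_decay_length_penalty = (5, 1.1)

# A JSON round trip turns the tuple into a list.
restored = json.loads(json.dumps(exponential_decay_length_penalty))
print(type(restored).__name__, restored)


def as_tuple(value):
    """Hypothetical normaliser: turn a reloaded list back into a tuple."""
    return tuple(value) if isinstance(value, list) else value


print(as_tuple(restored) == exponential_decay_length_penalty)
```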
|
||||
|
||||
def test_serialize_generation_begin_suppress_tokens(self):
|
||||
"""Tests that GenerationConfig is serialized and SuppressTokensAtBeginLogitsProcessor is initialized with begin_suppress_token and begin_index"""
|
||||
|
||||
begin_suppress_tokens = [220, 50256]
|
||||
begin_index = 0
|
||||
generation_config = GenerationConfig(begin_suppress_tokens=begin_suppress_tokens)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertSequenceEqual(new_config.begin_suppress_tokens, begin_suppress_tokens)
|
||||
|
||||
suppress_processor = SuppressTokensAtBeginLogitsProcessor(
|
||||
begin_suppress_tokens=new_config.begin_suppress_tokens, begin_index=begin_index
|
||||
)
|
||||
self.assertSequenceEqual(suppress_processor.begin_suppress_tokens, begin_suppress_tokens)
|
||||
self.assertEqual(suppress_processor.begin_index, begin_index)
|
||||
|
||||
def test_serialize_generation_suppress_tokens(self):
|
||||
"""Tests that GenerationConfig is serialized and SuppressTokensLogitsProcessor is initialized with suppress_token"""
|
||||
suppress_tokens = [220, 50256]
|
||||
|
||||
generation_config = GenerationConfig(suppress_tokens=suppress_tokens)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertSequenceEqual(new_config.suppress_tokens, suppress_tokens)
|
||||
|
||||
suppress_processor = SuppressTokensLogitsProcessor(suppress_tokens=new_config.suppress_tokens)
|
||||
self.assertSequenceEqual(suppress_processor.suppress_tokens, suppress_tokens)
|
||||
|
||||
def test_serialize_generation_guidance_scale(self):
|
||||
"""Tests that GenerationConfig is serialized and ClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
|
||||
guidance_scale = 2.0
|
||||
generation_config = GenerationConfig(guidance_scale=guidance_scale)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.guidance_scale, guidance_scale)
|
||||
|
||||
classifier_processor = ClassifierFreeGuidanceLogitsProcessor(guidance_scale=new_config.guidance_scale)
|
||||
self.assertEqual(classifier_processor.guidance_scale, guidance_scale)
|
||||
|
||||
def test_serialize_generation_guidance_scale_unbatched(self):
|
||||
"""Tests that GenerationConfig is serialized and UnbatchedClassifierFreeGuidanceLogitsProcessor is initialized with guidance_scale"""
|
||||
guidance_scale = 2.0
|
||||
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
|
||||
generation_config = GenerationConfig(guidance_scale=guidance_scale)
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.guidance_scale, guidance_scale)
|
||||
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(new_config.guidance_scale, {}, input_ids)
|
||||
self.assertEqual(cfg.guidance_scale, guidance_scale)
|
||||
|
||||
def test_serialize_generation_watermarking_config(self):
|
||||
"""Tests that GenerationConfig is serialized and WatermarkLogitsProcessor is initialized with WatermarkingConfig parameters"""
|
||||
|
||||
vocab_size = 20
|
||||
bias = 2.0
|
||||
greenlist_ratio = 0.5
|
||||
hashing_key = 10
|
||||
seeding_scheme = "lefthash"
|
||||
context_width = 10
|
||||
watermarking_config = WatermarkingConfig(
|
||||
bias=bias,
|
||||
greenlist_ratio=greenlist_ratio,
|
||||
hashing_key=hashing_key,
|
||||
seeding_scheme=seeding_scheme,
|
||||
context_width=context_width,
|
||||
)
|
||||
generation_config = GenerationConfig(watermarking_config=watermarking_config)
|
||||
|
||||
with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
|
||||
generation_config.save_pretrained(tmp_dir)
|
||||
new_config = GenerationConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_config.watermarking_config.bias, bias)
|
||||
self.assertEqual(new_config.watermarking_config.greenlist_ratio, greenlist_ratio)
|
||||
self.assertEqual(new_config.watermarking_config.hashing_key, hashing_key)
|
||||
self.assertEqual(new_config.watermarking_config.seeding_scheme, seeding_scheme)
|
||||
self.assertEqual(new_config.watermarking_config.context_width, context_width)
|
||||
|
||||
watermark = WatermarkLogitsProcessor(
|
||||
vocab_size=vocab_size,
|
||||
device=torch_device,
|
||||
greenlist_ratio=new_config.watermarking_config.greenlist_ratio,
|
||||
bias=new_config.watermarking_config.bias,
|
||||
hashing_key=new_config.watermarking_config.hashing_key,
|
||||
seeding_scheme=new_config.watermarking_config.seeding_scheme,
|
||||
context_width=new_config.watermarking_config.context_width,
|
||||
)
|
||||
self.assertEqual(watermark.bias, bias)
|
||||
self.assertEqual(watermark.greenlist_size, int(vocab_size * greenlist_ratio))
|
||||
self.assertEqual(watermark.hash_key, hashing_key)
|
||||
self.assertEqual(watermark.seeding_scheme, seeding_scheme)
|
||||
self.assertEqual(watermark.context_width, context_width)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
)
|
||||
config.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
)
|
||||
config.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_on_pr_revision(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
# create a PR
|
||||
pr = create_pull_request(repo_id=tmp_repo.repo_id, title="Test PR", token=self._token)
|
||||
revision = f"refs/pr/{pr.num}"
|
||||
|
||||
# push to PR ref
|
||||
config = GenerationConfig(
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
length_penalty=1.0,
|
||||
)
|
||||
config.push_to_hub(tmp_repo.repo_id, token=self._token, revision=revision)
|
||||
|
||||
# load from PR ref
|
||||
new_config = GenerationConfig.from_pretrained(tmp_repo.repo_id, revision=revision)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
1666
tests/generation/test_continuous_batching.py
Normal file
File diff suppressed because it is too large
176
tests/generation/test_flash_attention_parity.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright 2025 Eduard Durech, SGLang, and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Usage:
|
||||
# RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.testing_utils import require_all_flash_attn, require_torch_gpu, slow
|
||||
|
||||
|
||||
class FlashAttentionParityTest(unittest.TestCase):
|
||||
# From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
|
||||
def _lcs(self, X, Y):
|
||||
m = len(X)
|
||||
n = len(Y)
|
||||
L = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
|
||||
for i in range(m + 1):
|
||||
for j in range(n + 1):
|
||||
if i == 0 or j == 0:
|
||||
L[i][j] = 0
|
||||
elif X[i - 1] == Y[j - 1]:
|
||||
L[i][j] = L[i - 1][j - 1] + 1
|
||||
else:
|
||||
L[i][j] = max(L[i - 1][j], L[i][j - 1])
|
||||
|
||||
return L[m][n]
|
||||
|
||||
# From https://github.com/sgl-project/sglang/blob/main/python/sglang/test/test_utils.py
|
||||
def _calculate_rouge_l(self, output_strs_list1, output_strs_list2):
|
||||
rouge_l_scores = []
|
||||
|
||||
for s1, s2 in zip(output_strs_list1, output_strs_list2):
|
||||
lcs_len = self._lcs(s1, s2)
|
||||
precision = lcs_len / len(s1) if len(s1) > 0 else 0
|
||||
recall = lcs_len / len(s2) if len(s2) > 0 else 0
|
||||
if precision + recall > 0:
|
||||
fmeasure = (2 * precision * recall) / (precision + recall)
|
||||
else:
|
||||
fmeasure = 0.0
|
||||
rouge_l_scores.append(fmeasure)
|
||||
|
||||
return rouge_l_scores
|
||||
|
||||
def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5):
|
||||
for _ in range(n_warmup):
|
||||
model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_time = torch.cuda.Event(enable_timing=True)
|
||||
end_time = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_time.record()
|
||||
for _ in range(n_runs):
|
||||
model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
end_time.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return start_time.elapsed_time(end_time) / n_runs
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_all_flash_attn
|
||||
@pytest.mark.all_flash_attn_test
|
||||
def test_flash_attention_parity(self):
|
||||
flash_attn_versions = [2, 3, 4]
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
prompt = ["The ETH AI Center is", "What is life?"]
|
||||
|
||||
# 1. Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
# 2. Generate with each flash attention version (same model, switched in place)
|
||||
inputs = tokenizer(prompt, padding=True, padding_side="left", return_tensors="pt").to("cuda")
|
||||
|
||||
logits = {}
|
||||
logprobs = {}
|
||||
outputs = defaultdict(list)
|
||||
with torch.no_grad():
|
||||
|
||||
def generate(model, version, outputs, logits, logprobs):
|
||||
model.set_attn_implementation(f"flash_attention_{version}")
|
||||
output = model.generate(
|
||||
**inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True
|
||||
)
|
||||
logit = torch.stack(output.scores)
|
||||
logprob = torch.nn.functional.log_softmax(logit, dim=-1)
|
||||
|
||||
for i in range(len(prompt)):
|
||||
outputs[version].append(tokenizer.decode(output.sequences[i], skip_special_tokens=True))
|
||||
logits[version] = logit
|
||||
logprobs[version] = logprob
|
||||
|
||||
for version in flash_attn_versions:
|
||||
generate(model, version, outputs, logits, logprobs)
|
||||
|
||||
# 3. Correctness check
|
||||
# 3a. Logits
|
||||
# FA2 as base to compare against
|
||||
logits_1 = logits[2]
|
||||
logprobs_1 = logprobs[2]
|
||||
max_logprob_diffs = []
|
||||
for version in range(1, len(flash_attn_versions)):
|
||||
logits_x = logits[flash_attn_versions[version]]
|
||||
logprobs_x = logprobs[flash_attn_versions[version]]
|
||||
max_logprob_diffs.append(torch.max(torch.abs(logprobs_1 - logprobs_x)).item())
|
||||
|
||||
# Only 80% of the logit entries need to be within the tolerance (big model with several steps)
|
||||
atol, fraction = 4e-2, 0.8
|
||||
logits_ok = (torch.abs(logits_1 - logits_x) <= atol).float().mean().item()
|
||||
assert logits_ok >= fraction, (
|
||||
f"FA{flash_attn_versions[version]} logits pass fraction {logits_ok:.6f} < {fraction:.6f}"
|
||||
)
|
||||
|
||||
# 3b. Generated text
|
||||
# FA2 as base to compare against
|
||||
texts_1 = outputs[2]
|
||||
rouge_scores = []
|
||||
for version in range(1, len(flash_attn_versions)):
|
||||
fa_version = flash_attn_versions[version]
|
||||
texts_x = outputs[fa_version]
|
||||
rouge_score = self._calculate_rouge_l(texts_1, texts_x)
|
||||
for idx, score in enumerate(rouge_score):
|
||||
assert score > 0.99, (
|
||||
f"Generated texts at prompt {idx} do not match (ROUGE-L: {score}) comparing FA2 vs FA{fa_version}"
|
||||
)
|
||||
rouge_scores.append(rouge_score)  # reuse the scores computed above instead of recomputing them
|
||||
|
||||
# 4. Performance check
|
||||
times = []
|
||||
with torch.no_grad():
|
||||
for version in flash_attn_versions:
|
||||
model.set_attn_implementation(f"flash_attention_{version}")
|
||||
times.append(self._benchmark_generation(model, inputs))
|
||||
|
||||
# Summary
|
||||
print(f"\n--- Flash Attention Parity Test on {model_id} ---")
|
||||
print(f"Prompts: '{prompt}'")
|
||||
print("\nGenerated texts:")
|
||||
for version in flash_attn_versions:
|
||||
print(f" With FA{version}: {outputs[version]}")
|
||||
print("\nROUGE-L scores:")
|
||||
for idx, version in enumerate(range(1, len(flash_attn_versions))):
|
||||
print(f" Between FA2 and FA{flash_attn_versions[version]}: {rouge_scores[idx]}")
|
||||
print("\nMax absolute difference in logprobs:")
|
||||
for idx, version in enumerate(range(1, len(flash_attn_versions))):
|
||||
print(f" Between FA2 and FA{flash_attn_versions[version]}: {max_logprob_diffs[idx]:.5e}")
|
||||
print("\nLatency:")
|
||||
for idx, version in enumerate(flash_attn_versions):
|
||||
print(f" With FA{version}: {times[idx]}")
|
||||
print("---")
|
||||
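The `_lcs` and `_calculate_rouge_l` helpers above can be tried without a GPU or a model. The following is a minimal standalone sketch of the same idea, an LCS-based F-measure; the function names are illustrative and not part of the test file. Because the test passes decoded strings, the comparison there is character-level, and the sketch does the same.

```python
def lcs_length(x, y):
    """Length of the longest common subsequence of two sequences (dynamic programming)."""
    m, n = len(x), len(y)
    table = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[m][n]


def rouge_l_f1(candidate, reference):
    """F-measure of LCS-based precision and recall; 0.0 when nothing overlaps."""
    lcs = lcs_length(candidate, reference)
    precision = lcs / len(candidate) if candidate else 0.0
    recall = lcs / len(reference) if reference else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


print(rouge_l_f1("The train is stopped", "The train is stopped"))
```

With a threshold of `0.99`, as in the test, even a single differing character in a short generation is enough to fail the check.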
1418
tests/generation/test_logits_process.py
Normal file
File diff suppressed because it is too large
168
tests/generation/test_paged_attention.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
from transformers.generation.configuration_utils import ContinuousBatchingConfig
|
||||
from transformers.testing_utils import Expectations, slow
|
||||
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"A man is a walking his dog down the street, and a the turn he sees",
|
||||
"Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
|
||||
"A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
|
||||
"Please fill in the form to",
|
||||
"For safety reasons, the train is stopped in the middle of the",
|
||||
]
|
||||
|
||||
_EXPECTED_OUTPUTS = Expectations(
|
||||
{
|
||||
("cpu", None): [ # FIXME: CPU tests only pass for eager and flex. Maybe the test should be re-thought.
|
||||
"a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes",
|
||||
"orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n##",
|
||||
"This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer",
|
||||
"get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]",
|
||||
"track. The train is stopped because of a mechanical failure. The train is stopped because of a mechanical failure. The train is stopped because of a mechanical",
|
||||
# TODO: investigate why that last expectation seems incorrect
|
||||
],
|
||||
("cuda", (9, 0)): [ # A10 and H100
|
||||
"a woman standing on the sidewalk, looking at him. He is immediately drawn to her and feels a strong attraction. He walks up to her and strikes",
|
||||
"orange.\n\n## Step 1: Identify the key characteristics of the fruit\nThe fruit is described as being orange in color and round in shape.\n\n##",
|
||||
"This riddle is a classic example of a lateral thinking puzzle, which requires the test-taker to think creatively and consider multiple possibilities. The answer",
|
||||
"get in touch with us. We will respond to your message as soon as possible.\n\n[Your Name]\n[Your Email]\n[Your Phone Number]",
|
||||
# The last prompt sits on a numerical boundary: eager/flex produce "does", sdpa/fa2 produce "will".
|
||||
# We use a tuple to accept either variant.
|
||||
(
|
||||
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers does the train",
|
||||
"track. The train is stopped for 30 minutes. The train is moving at a speed of 60 km/h. How many kilometers will the train",
|
||||
),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@slow
|
||||
class TestBatchGeneration(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.2-3b-Instruct", dtype="bfloat16", device_map="cuda"
|
||||
).eval()
|
||||
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")
|
||||
|
||||
if cls.tokenizer.pad_token is None:
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
eos_id = cls.model.config.eos_token_id
|
||||
cls.model.config.pad_token_id = eos_id[0] if isinstance(eos_id, list) else eos_id
|
||||
|
||||
cls.model.use_cache = False
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("paged|eager", 64, 128, 64),
|
||||
("paged|sdpa", 32, 256, 128),
|
||||
("paged|flash_attention_2", 16, 512, 256),
|
||||
("paged|flex_attention", 64, 128, 64),
|
||||
]
|
||||
)
|
||||
def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
|
||||
self.model.config.attn_implementation = attn_impl
|
||||
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size,
|
||||
max_batch_tokens=max_batch_tokens,
|
||||
)
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=30,
|
||||
top_k=0,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
|
||||
batch_inputs = list(tokenized["input_ids"])
|
||||
|
||||
batch_outputs = self.model.generate_batch(
|
||||
inputs=batch_inputs,
|
||||
generation_config=generation_config,
|
||||
continuous_batching_config=cb_config,
|
||||
)
|
||||
|
||||
expected_outputs = _EXPECTED_OUTPUTS.get_expectation()
|
||||
|
||||
for i, (output, expected_output) in enumerate(zip(batch_outputs.values(), expected_outputs)):
|
||||
generated = self.tokenizer.decode(output.generated_tokens, skip_special_tokens=False).strip()
|
||||
expected_output = (expected_output.strip(),) if isinstance(expected_output, str) else expected_output
|
||||
self.assertIn(
|
||||
generated,
|
||||
[e.strip() for e in expected_output],
|
||||
msg=f"[{attn_impl}] Mismatch in request {i}:\nExpected one of: {expected_output}\nGot: {generated}",
|
||||
)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("paged|eager", 64, 128, 64),
|
||||
("paged|sdpa", 32, 256, 128),
|
||||
("paged|flash_attention_2", 16, 512, 256),
|
||||
("paged|flex_attention", 64, 128, 64),
|
||||
]
|
||||
)
|
||||
def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
|
||||
"""Test batch generation with do_sample=True to verify sampling works correctly."""
|
||||
self.model.config.attn_implementation = attn_impl
|
||||
|
||||
cb_config = ContinuousBatchingConfig(
|
||||
num_blocks=num_blocks,
|
||||
block_size=block_size,
|
||||
max_batch_tokens=max_batch_tokens,
|
||||
)
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=30,
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=0.9,
|
||||
temperature=0.8,
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
tokenized = self.tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
|
||||
batch_inputs = list(tokenized["input_ids"])
|
||||
|
||||
start = time.time()
|
||||
batch_outputs = self.model.generate_batch(
|
||||
inputs=batch_inputs,
|
||||
generation_config=generation_config,
|
||||
continuous_batching_config=cb_config,
|
||||
)
|
||||
end = time.time()
|
||||
print(
|
||||
f"\n[{attn_impl}] Sampling batch took {end - start:.2f}s with config: blocks={num_blocks}, block_size={block_size}, max_batch_tokens={max_batch_tokens}"
|
||||
)
|
||||
|
||||
# With sampling enabled, we can't check exact outputs, but we should verify:
|
||||
# 1. All requests completed successfully
|
||||
# 2. Generated text is non-empty
|
||||
# (Whether the sampled text differs from the greedy output is not asserted here.)
|
||||
self.assertEqual(len(batch_outputs), len(batch_inputs), f"[{attn_impl}] Not all requests completed")
|
||||
|
||||
for i, req_id in enumerate(batch_outputs):
|
||||
generated = self.tokenizer.decode(
|
||||
batch_outputs[req_id].generated_tokens, skip_special_tokens=False
|
||||
).strip()
|
||||
self.assertTrue(
|
||||
len(generated) > 0,
|
||||
msg=f"[{attn_impl}] Empty output for request {i}",
|
||||
)
|
||||
# Check that we got at least some tokens generated
|
||||
generated_tokens = batch_outputs[req_id].generated_tokens
|
||||
self.assertGreater(
|
||||
len(generated_tokens),
|
||||
0,
|
||||
msg=f"[{attn_impl}] No tokens generated for request {i}",
|
||||
)
|
||||
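The consistency test above accepts either a single expected string or a tuple of acceptable variants for each prompt. A minimal sketch of that matching rule, using a hypothetical helper name and made-up strings rather than real model output, could look like this:

```python
def matches_expectation(generated, expected):
    """Return True if `generated` equals the expected string or any accepted variant.

    `expected` is either one string or a tuple of acceptable strings; whitespace
    at both ends is ignored, as in the test above. (Hypothetical helper.)
    """
    candidates = (expected,) if isinstance(expected, str) else expected
    return generated.strip() in [candidate.strip() for candidate in candidates]


# Made-up strings standing in for decoded generations.
variants = ("How many kilometers does the train", "How many kilometers will the train")
print(matches_expectation(" How many kilometers will the train ", variants))
```

Normalizing a lone string into a one-element tuple keeps a single code path for both cases, which is the same choice the test makes.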
289
tests/generation/test_stopping_criteria.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from ..test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation import (
|
||||
ConfidenceCriteria,
|
||||
EosTokenCriteria,
|
||||
MaxLengthCriteria,
|
||||
MaxTimeCriteria,
|
||||
StoppingCriteriaList,
|
||||
StopStringCriteria,
|
||||
validate_stopping_criteria,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
def _get_tensors(self, length):
|
||||
batch_size = 3
|
||||
vocab_size = 250
|
||||
|
||||
input_ids = ids_tensor((batch_size, length), vocab_size)
|
||||
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
|
||||
return input_ids, scores
|
||||
|
||||
def test_list_criteria(self):
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
|
||||
criteria = StoppingCriteriaList(
|
||||
[
|
||||
MaxLengthCriteria(max_length=10),
|
||||
MaxTimeCriteria(max_time=0.1),
|
||||
]
|
||||
)
|
||||
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_max_length_criteria(self):
|
||||
criteria = MaxLengthCriteria(max_length=10)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_max_time_criteria(self):
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
|
||||
criteria = MaxTimeCriteria(max_time=0.1)
|
||||
self.assertFalse(all(criteria(input_ids, scores)))
|
||||
|
||||
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
def test_eos_token_criteria(self):
|
||||
criteria = EosTokenCriteria(eos_token_id=0)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:, -1] = 0
|
||||
self.assertTrue(all(criteria(input_ids, scores)))
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:2, -1] = 0
|
||||
input_ids[2, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [True, True, False])
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
input_ids[:, -1] = 1
|
||||
self.assertListEqual(criteria(input_ids, scores).tolist(), [False, False, False])
|
||||
|
||||
def test_confidence_criteria(self):
|
||||
criteria = ConfidenceCriteria(assistant_confidence_threshold=0.5)
|
||||
|
||||
vocab_size = 250
|
||||
length = 5
|
||||
|
||||
input_ids = ids_tensor((1, length), vocab_size)
|
||||
scores = (torch.randn((1, vocab_size)),)
|
||||
|
||||
# Simulate high confidence by setting the probability of the last token to be high
|
||||
scores[0][0, input_ids[0, -1]] = 10.0 # Logits before softmax
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
# Simulate low confidence by setting the probability of the last token to be low
|
||||
scores[0][0, input_ids[0, -1]] = -10.0 # Logits before softmax
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_validate_stopping_criteria(self):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
|
||||
|
||||
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
|
||||
|
||||
self.assertEqual(len(stopping_criteria), 1)
|
||||
|
||||
def test_stop_string_criteria(self):
|
||||
true_strings = [
|
||||
"<|im_start|><|im_end|>",
|
||||
"<|im_start|><|im_end|<|im_end|>",
|
||||
">><|im_start|>>stop",
|
||||
"stop",
|
||||
"e nd",
|
||||
]
|
||||
false_strings = [
|
||||
"<|im_start|><|im_end|",
|
||||
"<|im_start|><|im_end|<|im_end|",
|
||||
"<|im_end|><|im_start|>",
|
||||
"<|im_end|<>stop<|im_end|",
|
||||
"end",
|
||||
"en d",
|
||||
"eNd",
|
||||
"<|im_end|",
|
||||
"|im_end|>",
|
||||
"s",
|
||||
]
|
||||
stop_strings = ["<|im_end|>", "stop", "e nd"]
|
||||
|
||||
# Use a tokenizer that won't actually have special tokens for these
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = "left"
|
||||
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
|
||||
for i in range(len(true_strings)):
|
||||
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
|
||||
for i in range(len(false_strings)):
|
||||
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||
|
||||
# Now try it with a tokenizer where those are actually special tokens
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer.padding_side = "left"
|
||||
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
|
||||
for i in range(len(true_strings)):
|
||||
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
|
||||
for i in range(len(false_strings)):
|
||||
self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores))
|
||||
|
||||
def test_stop_string_criteria_vocab_size_mismatch(self):
|
||||
"""Test that StopStringCriteria handles tokens above len(tokenizer) correctly."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Create input_ids with tokens above len(tokenizer)
|
||||
input_ids = torch.tensor([[len(tokenizer) + 1024, 1, 2]], device=torch_device)
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["test"])
|
||||
|
||||
# This should not raise an error and should return False since no stop string is matched
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
def test_stop_string_matching_positions(self):
|
||||
stop_string = "stop"
|
||||
token_list = ["last", "top", "topper", "s", "p"]
|
||||
token_indices = list(range(len(token_list)))
|
||||
all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions(
|
||||
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
|
||||
)
|
||||
valid_positions = {
|
||||
token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items()
|
||||
}
|
||||
end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
|
||||
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
|
||||
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
|
||||
|
||||
def test_stop_string_embedding_vecs(self):
|
||||
stop_string = "stop"
|
||||
token_list = ["last", "top", "topper", "s", "p"]
|
||||
token_indices = list(range(len(token_list)))
|
||||
embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec(
|
||||
token_list=token_list, token_indices=token_indices, stop_strings=[stop_string]
|
||||
)
|
||||
|
||||
# Positions inside the stop string where the token matches (excluding end overlaps)
|
||||
valid_positions = embedding_vec[:, 0].tolist()
|
||||
self.assertEqual(valid_positions, [2, -1, -1, 3, -1, -1])
|
||||
|
||||
# Overlap lengths between end of stop string and start of token
|
||||
end_overlaps = embedding_vec[:, 1].tolist()
|
||||
self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1, -1])
|
||||
|
||||
# Length of each token
|
||||
token_lengths = embedding_vec[:-1, 2].tolist()
|
||||
self.assertEqual(token_lengths, [len(token) for token in token_list])
|
||||
|
||||
def test_single_letter_stop_string(self):
|
||||
true_strings = ["a", "baa", "abc"] # "abc" is a single token
|
||||
false_strings = ["abbbbbbb", "b"] # "abbbbbbb" is split into multiple tokens
|
||||
stop_strings = ["a"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", add_prefix_space=False)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings)
|
||||
for input_ids in true_input_ids["input_ids"]:
|
||||
self.assertTrue(criteria(input_ids.unsqueeze(0), scores))
|
||||
for input_ids in false_input_ids["input_ids"]:
|
||||
self.assertFalse(criteria(input_ids.unsqueeze(0), scores))
|
||||
|
||||
def test_criteria_per_row(self):
|
||||
text = "They completed the challenging puzzle, revealing the hidden image at the end"
|
||||
stop_strings = ["end"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StoppingCriteriaList(
|
||||
[
|
||||
MaxLengthCriteria(max_length=20),
|
||||
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
|
||||
]
|
||||
)
|
||||
|
||||
# trigger stopping when at least one criterion is satisfied, one value per batch
|
||||
self.assertTrue(criteria(inputs["input_ids"], scores))
|
||||
|
||||
# return False when neither is satisfied
|
||||
self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores))
|
||||
|
||||
def test_criteria_per_row_batched(self):
|
||||
text = [
|
||||
"They completed the challenging puzzle, revealing the hidden image at the end",
|
||||
"Today a dragon flew over France",
|
||||
"The aroma of freshly baked pizza filled the kitchen",
|
||||
]
|
||||
stop_strings = ["end"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
tokenizer.padding_side = "left"
|
||||
inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False)
|
||||
|
||||
scores = None
|
||||
criteria = StoppingCriteriaList(
|
||||
[
|
||||
MaxLengthCriteria(max_length=20),
|
||||
StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings),
|
||||
]
|
||||
)
|
||||
|
||||
# trigger stopping when at least one criterion is satisfied
|
||||
self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False])
|
||||
|
||||
# False when neither is satisfied
|
||||
self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False])
|
||||
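The `end_overlaps` values expected in `test_stop_string_matching_positions` can be reproduced with a simplified model: for each token, collect the lengths `k` for which the last `k` characters of the stop string equal the first `k` characters of the token. The sketch below is an assumption-laden illustration of that one case only. It is not the library's implementation, the function name is hypothetical, and it does not cover the `valid_positions` side of the test.

```python
def end_overlaps(token_list, stop_string):
    """Map each token to the overlap lengths between the stop string's tail and the token's head.

    Simplified model (not the transformers implementation): a length k is
    recorded when stop_string[-k:] == token[:k]. Tokens with no overlap are omitted.
    """
    result = {}
    for token in token_list:
        lengths = [
            k
            for k in range(1, min(len(token), len(stop_string)) + 1)
            if stop_string[-k:] == token[:k]
        ]
        if lengths:
            result[token] = lengths
    return result


print(end_overlaps(["last", "top", "topper", "s", "p"], "stop"))
```

A token such as `"topper"` is recorded with overlap 3 because it finishes the stop string and then continues past it, which is why generation can stop on a token that overhangs the end of the stop string.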
174
tests/generation/test_streamers.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright 2023 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from queue import Empty
|
||||
from threading import Thread
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import (
|
||||
AsyncTextIteratorStreamer,
|
||||
AutoTokenizer,
|
||||
TextIteratorStreamer,
|
||||
TextStreamer,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import CaptureStdout, require_torch, torch_device
|
||||
from transformers.utils.logging import _get_library_root_logger
|
||||
|
||||
from ..test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
@require_torch
|
||||
class StreamerTester(unittest.TestCase):
|
||||
def test_text_streamer_matches_non_streaming(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
model.config.eos_token_id = -1
|
||||
|
||||
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
|
||||
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
greedy_text = tokenizer.decode(greedy_ids[0])
|
||||
|
||||
with CaptureStdout() as cs:
|
||||
streamer = TextStreamer(tokenizer)
|
||||
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
|
||||
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
|
||||
streamer_text = cs.out[:-1]
|
||||
|
||||
self.assertEqual(streamer_text, greedy_text)
|
||||
|
||||
def test_iterator_streamer_matches_non_streaming(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
model.config.eos_token_id = -1
|
||||
|
||||
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
|
||||
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
greedy_text = tokenizer.decode(greedy_ids[0])
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
streamer_text = ""
|
||||
for new_text in streamer:
|
||||
streamer_text += new_text
|
||||
|
||||
self.assertEqual(streamer_text, greedy_text)
|
||||
|
||||
def test_text_streamer_skip_prompt(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
model.config.eos_token_id = -1
|
||||
|
||||
input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
|
||||
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
|
||||
new_greedy_text = tokenizer.decode(new_greedy_ids[0])
|
||||
|
||||
with CaptureStdout() as cs:
|
||||
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
||||
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
|
||||
# The greedy text should be printed to stdout, except for the final "\n" in the streamer
|
||||
streamer_text = cs.out[:-1]
|
||||
|
||||
self.assertEqual(streamer_text, new_greedy_text)
|
||||
|
||||
    def test_text_streamer_decode_kwargs(self):
        # Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be tested
        # with actual models -- the dummy models' tokenizers are not aligned with their models, and
        # `skip_special_tokens=True` has no effect on them
        tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
        model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2").to(torch_device)
        model.config.eos_token_id = -1

        input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id

        root = _get_library_root_logger()
        with patch.object(root, "propagate", False):
            with CaptureStdout() as cs:
                streamer = TextStreamer(tokenizer, skip_special_tokens=True)
                model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)

        # The prompt contains a special token, so the streamer should not print it. As such, the output text, when
        # re-tokenized, must only contain one token
        streamer_text = cs.out[:-1]  # Remove the final "\n"
        streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
        self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))

    def test_iterator_streamer_timeout(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        streamer = TextIteratorStreamer(tokenizer, timeout=0.001)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # The streamer will timeout after 0.001 seconds, so an exception will be raised
        with self.assertRaises(Empty):
            streamer_text = ""
            for new_text in streamer:
                streamer_text += new_text


@require_torch
@pytest.mark.asyncio(loop_scope="class")
class AsyncStreamerTester(unittest.IsolatedAsyncioTestCase):
    async def test_async_iterator_streamer_matches_non_streaming(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        greedy_text = tokenizer.decode(greedy_ids[0])

        streamer = AsyncTextIteratorStreamer(tokenizer)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        streamer_text = ""
        async for new_text in streamer:
            streamer_text += new_text

        self.assertEqual(streamer_text, greedy_text)

    async def test_async_iterator_streamer_timeout(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        streamer = AsyncTextIteratorStreamer(tokenizer, timeout=0.001)
        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # The streamer will timeout after 0.001 seconds, so TimeoutError will be raised
        with self.assertRaises(TimeoutError):
            streamer_text = ""
            async for new_text in streamer:
                streamer_text += new_text
5098
tests/generation/test_utils.py
Normal file
File diff suppressed because it is too large