first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled

This commit is contained in:
陈赣
2026-06-05 16:53:03 +08:00
commit 06f1fd69a6
6047 changed files with 1895387 additions and 0 deletions

0
tests/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,23 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# fmt: off
from transformers.utils.import_utils import requires
@requires(backends=("random_item_that_should_not_exist",))
class A0:
def __init__(self):
pass

View File

@@ -0,0 +1,78 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# fmt: off
from transformers.utils.import_utils import requires
@requires()
class A0:
def __init__(self):
pass
@requires()
def a0():
pass
@requires(backends=("torch",))
class A1:
def __init__(self):
pass
@requires(backends=("torch",))
def a1():
pass
@requires(
backends=("torch",)
)
class A2:
def __init__(self):
pass
@requires(
backends=("torch",)
)
def a2():
pass
@requires(
backends=(
"torch",
)
)
class A3:
def __init__(self):
pass
@requires(
backends=(
"torch",
)
)
def a3():
pass
@requires(backends=())
class A4:
def __init__(self):
pass

View File

@@ -0,0 +1,92 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# fmt: off
from transformers.utils.import_utils import requires
@requires(backends=("torch>=2.5",))
class D0:
def __init__(self):
pass
@requires(backends=("torch>=2.5",))
def d0():
pass
@requires(backends=("torch>2.5",))
class D1:
def __init__(self):
pass
@requires(backends=("torch>2.5",))
def d1():
pass
@requires(backends=("torch<=2.5",))
class D2:
def __init__(self):
pass
@requires(backends=("torch<=2.5",))
def d2():
pass
@requires(backends=("torch<2.5",))
class D3:
def __init__(self):
pass
@requires(backends=("torch<2.5",))
def d3():
pass
@requires(backends=("torch==2.5",))
class D4:
def __init__(self):
pass
@requires(backends=("torch==2.5",))
def d4():
pass
@requires(backends=("torch!=2.5",))
class D5:
def __init__(self):
pass
@requires(backends=("torch!=2.5",))
def d5():
pass
@requires(backends=("torch>=2.5", "accelerate<0.20"))
class D6:
def __init__(self):
pass
@requires(backends=("torch>=2.5", "accelerate<0.20"))
def d6():
pass

View File

@@ -0,0 +1,77 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# fmt: off
from transformers.utils.import_utils import requires
@requires()
# That's a statement
class B0:
def __init__(self):
pass
@requires()
# That's a statement
def b0():
pass
@requires(backends=("torch",))
# That's a statement
class B1:
def __init__(self):
pass
@requires(backends=("torch",))
# That's a statement
def b1():
pass
@requires(backends=("torch",))
# That's a statement
class B2:
def __init__(self):
pass
@requires(backends=("torch",))
# That's a statement
def b2():
pass
@requires(
backends=(
"torch",
)
)
# That's a statement
class B3:
def __init__(self):
pass
@requires(
backends=(
"torch",
)
)
# That's a statement
def b3():
pass

View File

@@ -0,0 +1,77 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# fmt: off
from transformers.utils.import_utils import requires
@requires(backends=("torch", "torch"))
class C0:
def __init__(self):
pass
@requires(backends=("torch", "torch"))
def c0():
pass
@requires(backends=("torch", "torch"))
# That's a statement
class C1:
def __init__(self):
pass
@requires(backends=("torch", "torch"))
# That's a statement
def c1():
pass
@requires(backends=("torch", "torch"))
# That's a statement
class C2:
def __init__(self):
pass
@requires(backends=("torch", "torch"))
# That's a statement
def c2():
pass
@requires(
backends=(
"torch",
"torch"
)
)
# That's a statement
class C3:
def __init__(self):
pass
@requires(
backends=(
"torch",
"torch"
)
)
# That's a statement
def c3():
pass

View File

@@ -0,0 +1,74 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from transformers.activations import gelu_new, gelu_python, get_activation
@require_torch
class TestActivations(unittest.TestCase):
def test_gelu_versions(self):
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
torch_builtin = get_activation("gelu")
torch.testing.assert_close(gelu_python(x), torch_builtin(x))
self.assertFalse(torch.allclose(gelu_python(x), gelu_new(x)))
def test_gelu_10(self):
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
torch_builtin = get_activation("gelu")
gelu10 = get_activation("gelu_10")
y_gelu = torch_builtin(x)
y_gelu_10 = gelu10(x)
clipped_mask = torch.where(y_gelu_10 < 10.0, 1, 0)
self.assertTrue(torch.max(y_gelu_10).item() == 10.0)
torch.testing.assert_close(y_gelu * clipped_mask, y_gelu_10 * clipped_mask)
def test_get_activation(self):
get_activation("gelu")
get_activation("gelu_10")
get_activation("gelu_fast")
get_activation("gelu_new")
get_activation("gelu_python")
get_activation("gelu_pytorch_tanh")
get_activation("linear")
get_activation("mish")
get_activation("quick_gelu")
get_activation("relu")
get_activation("sigmoid")
get_activation("silu")
get_activation("swish")
get_activation("tanh")
with self.assertRaises(KeyError):
get_activation("bogus")
with self.assertRaises(KeyError):
get_activation(None)
def test_activations_are_distinct_objects(self):
act1 = get_activation("gelu")
act1.a = 1
act2 = get_activation("gelu")
self.assertEqual(act1.a, 1)
with self.assertRaises(AttributeError):
_ = act2.a

View File

@@ -0,0 +1,806 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import textwrap
import unittest
from datetime import date
from pathlib import Path
from transformers.cli.add_new_model_like import ModelInfos, _add_new_model_like_internal
from transformers.testing_utils import require_torch
REPO_PATH = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
MODELS_TO_COPY = ("auto", "llama", "phi4_multimodal")
CURRENT_YEAR = date.today().year
@require_torch
class TestAddNewModelLike(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""
Create a temporary repo with the same structure as Transformers, with just 2 models.
"""
cls.tmp_dir = tempfile.TemporaryDirectory()
cls.FAKE_REPO = cls.tmp_dir.name
os.makedirs(os.path.join(cls.FAKE_REPO, "src", "transformers", "models"), exist_ok=True)
os.makedirs(os.path.join(cls.FAKE_REPO, "tests", "models"), exist_ok=True)
os.makedirs(os.path.join(cls.FAKE_REPO, "docs", "source", "en", "model_doc"), exist_ok=True)
# We need to copy the utils to run the cleanup commands
utils_src = os.path.join(REPO_PATH, "utils")
shutil.copytree(utils_src, utils_src.replace(REPO_PATH, cls.FAKE_REPO))
# Copy the __init__ files
model_init = os.path.join(REPO_PATH, "src", "transformers", "models", "__init__.py")
shutil.copy(model_init, model_init.replace(REPO_PATH, cls.FAKE_REPO))
doc_toc = os.path.join(REPO_PATH, "docs", "source", "en", "_toctree.yml")
shutil.copy(doc_toc, doc_toc.replace(REPO_PATH, cls.FAKE_REPO))
# We need the pyproject for ruff as well
pyproject = os.path.join(REPO_PATH, "pyproject.toml")
shutil.copy(pyproject, pyproject.replace(REPO_PATH, cls.FAKE_REPO))
# Copy over all the specific model files
for model in MODELS_TO_COPY:
model_src = os.path.join(REPO_PATH, "src", "transformers", "models", model)
shutil.copytree(model_src, model_src.replace(REPO_PATH, cls.FAKE_REPO))
test_src = os.path.join(REPO_PATH, "tests", "models", model)
shutil.copytree(test_src, test_src.replace(REPO_PATH, cls.FAKE_REPO))
if model != "auto":
doc_src = os.path.join(REPO_PATH, "docs", "source", "en", "model_doc", f"{model}.md")
shutil.copy(doc_src, doc_src.replace(REPO_PATH, cls.FAKE_REPO))
# For convenience
cls.MODEL_PATH = os.path.join(cls.FAKE_REPO, "src", "transformers", "models")
cls.TESTS_MODEL_PATH = os.path.join(cls.FAKE_REPO, "tests", "models")
cls.DOC_PATH = os.path.join(cls.FAKE_REPO, "docs", "source", "en")
@classmethod
def tearDownClass(cls):
cls.tmp_dir.cleanup()
def assertFileIsEqual(self, text: str, filepath: str):
with open(filepath, "r") as f:
file_text = f.read()
self.assertEqual(file_text.strip(), text.strip())
def assertInFile(self, text: str, filepath: str):
with open(filepath, "r") as f:
file_text = f.read()
self.assertTrue(text in file_text)
def test_llama_without_tokenizers(self):
# This is the structure without adding the tokenizers
filenames_to_add = (
("configuration_llama.py", True),
("modeling_llama.py", True),
("tokenization_llama.py", False),
("tokenization_llama_fast.py", False),
("image_processing_llama_pil.py", False),
("image_processing_llama.py", False),
("video_processing_llama.py", False),
("feature_extraction_llama.py", False),
("processing_llama.py", False),
)
# Run the command
_add_new_model_like_internal(
repo_path=Path(self.FAKE_REPO),
old_model_infos=ModelInfos("llama"),
new_lowercase_name="my_test",
new_model_paper_name="MyTest",
filenames_to_add=filenames_to_add,
)
# First assert that all files were created correctly
model_repo = os.path.join(self.MODEL_PATH, "my_test")
tests_repo = os.path.join(self.TESTS_MODEL_PATH, "my_test")
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modular_my_test.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modeling_my_test.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "configuration_my_test.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "__init__.py")))
self.assertTrue(os.path.isfile(os.path.join(self.DOC_PATH, "model_doc", "my_test.md")))
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "__init__.py")))
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_modeling_my_test.py")))
# Now assert the correct imports/auto mappings/toctree were added
self.assertInFile(
"from .my_test import *\n",
os.path.join(self.MODEL_PATH, "__init__.py"),
)
self.assertInFile(
'("my_test", "MyTestConfig"),\n',
os.path.join(self.MODEL_PATH, "auto", "auto_mappings.py"),
)
self.assertInFile(
'("my_test", "MyTestModel"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
'("my_test", "MyTestForCausalLM"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
'("my_test", "MyTestForSequenceClassification"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
'("my_test", "MyTestForQuestionAnswering"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
'("my_test", "MyTestForTokenClassification"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
"- local: model_doc/my_test\n title: MyTest\n",
os.path.join(self.DOC_PATH, "_toctree.yml"),
)
# Check some exact file creation. For model definition, only check modular as modeling/config/etc... are created
# directly from it
EXPECTED_MODULAR = textwrap.dedent(
f"""
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
class MyTestConfig(LlamaConfig):
pass
class MyTestRMSNorm(LlamaRMSNorm):
pass
class MyTestRotaryEmbedding(LlamaRotaryEmbedding):
pass
class MyTestMLP(LlamaMLP):
pass
class MyTestAttention(LlamaAttention):
pass
class MyTestDecoderLayer(LlamaDecoderLayer):
pass
class MyTestPreTrainedModel(LlamaPreTrainedModel):
pass
class MyTestModel(LlamaModel):
pass
class MyTestForCausalLM(LlamaForCausalLM):
pass
class MyTestForSequenceClassification(LlamaForSequenceClassification):
pass
class MyTestForQuestionAnswering(LlamaForQuestionAnswering):
pass
class MyTestForTokenClassification(LlamaForTokenClassification):
pass
__all__ = [
"MyTestConfig",
"MyTestForCausalLM",
"MyTestModel",
"MyTestPreTrainedModel",
"MyTestForSequenceClassification",
"MyTestForQuestionAnswering",
"MyTestForTokenClassification",
]
"""
)
self.assertFileIsEqual(EXPECTED_MODULAR, os.path.join(model_repo, "modular_my_test.py"))
EXPECTED_INIT = textwrap.dedent(
f"""
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_my_test import *
from .modeling_my_test import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
"""
)
self.assertFileIsEqual(EXPECTED_INIT, os.path.join(model_repo, "__init__.py"))
EXPECTED_DOC = textwrap.dedent(
f"""
<!--Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
-->
# MyTest
## Overview
The MyTest model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
<INSERT PAPER ABSTRACT HERE>
Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
## Usage examples
<INSERT SOME NICE EXAMPLES HERE>
## MyTestConfig
[[autodoc]] MyTestConfig
## MyTestForCausalLM
[[autodoc]] MyTestForCausalLM
## MyTestModel
[[autodoc]] MyTestModel
- forward
## MyTestPreTrainedModel
[[autodoc]] MyTestPreTrainedModel
- forward
## MyTestForSequenceClassification
[[autodoc]] MyTestForSequenceClassification
## MyTestForQuestionAnswering
[[autodoc]] MyTestForQuestionAnswering
## MyTestForTokenClassification
[[autodoc]] MyTestForTokenClassification
"""
)
self.assertFileIsEqual(EXPECTED_DOC, os.path.join(self.DOC_PATH, "model_doc", "my_test.md"))
def test_phi4_with_all_processors(self):
# This is the structure without adding the tokenizers
filenames_to_add = (
("configuration_phi4_multimodal.py", True),
("modeling_phi4_multimodal.py", True),
("tokenization_phi4_multimodal.py", False),
("tokenization_phi4_multimodal_fast.py", False),
("image_processing_phi4_multimodal_pil.py", False),
("image_processing_phi4_multimodal.py", True),
("video_processing_phi4_multimodal.py", False),
("feature_extraction_phi4_multimodal.py", True),
("processing_phi4_multimodal.py", True),
)
# Run the command
_add_new_model_like_internal(
repo_path=Path(self.FAKE_REPO),
old_model_infos=ModelInfos("phi4_multimodal"),
new_lowercase_name="my_test2",
new_model_paper_name="MyTest2",
filenames_to_add=filenames_to_add,
)
# First assert that all files were created correctly
model_repo = os.path.join(self.MODEL_PATH, "my_test2")
tests_repo = os.path.join(self.TESTS_MODEL_PATH, "my_test2")
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modular_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modeling_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "configuration_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "image_processing_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "feature_extraction_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "processing_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(model_repo, "__init__.py")))
self.assertTrue(os.path.isfile(os.path.join(self.DOC_PATH, "model_doc", "my_test2.md")))
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "__init__.py")))
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_modeling_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_feature_extraction_my_test2.py")))
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_image_processing_my_test2.py")))
# Now assert the correct imports/auto mappings/toctree were added
self.assertInFile(
"from .my_test2 import *\n",
os.path.join(self.MODEL_PATH, "__init__.py"),
)
self.assertInFile(
'("my_test2", "MyTest2Config"),\n',
os.path.join(self.MODEL_PATH, "auto", "auto_mappings.py"),
)
self.assertInFile(
'("my_test2", "MyTest2Model"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
'("my_test2", "MyTest2ForCausalLM"),\n',
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
)
self.assertInFile(
'("my_test2", {"torchvision": "MyTest2ImageProcessor"}),\n',
os.path.join(self.MODEL_PATH, "auto", "auto_mappings.py"),
)
self.assertInFile(
'("my_test2", "MyTest2FeatureExtractor"),\n',
os.path.join(self.MODEL_PATH, "auto", "feature_extraction_auto.py"),
)
self.assertInFile(
'("my_test2", "MyTest2Processor"),\n',
os.path.join(self.MODEL_PATH, "auto", "processing_auto.py"),
)
self.assertInFile(
"- local: model_doc/my_test2\n title: MyTest2\n",
os.path.join(self.DOC_PATH, "_toctree.yml"),
)
# Check some exact file creation. For model definition, only check modular as modeling/config/etc... are created
# directly from it
EXPECTED_MODULAR = textwrap.dedent(
f"""
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..phi4_multimodal.configuration_phi4_multimodal import (
Phi4MultimodalAudioConfig,
Phi4MultimodalConfig,
Phi4MultimodalVisionConfig,
)
from ..phi4_multimodal.feature_extraction_phi4_multimodal import Phi4MultimodalFeatureExtractor
from ..phi4_multimodal.image_processing_phi4_multimodal import (
Phi4MultimodalImageProcessor,
Phi4MultimodalImageProcessorKwargs,
)
from ..phi4_multimodal.modeling_phi4_multimodal import (
Phi4MultimodalAttention,
Phi4MultimodalAudioAttention,
Phi4MultimodalAudioConformerEncoderLayer,
Phi4MultimodalAudioConvModule,
Phi4MultimodalAudioDepthWiseSeparableConv1d,
Phi4MultimodalAudioEmbedding,
Phi4MultimodalAudioGluPointWiseConv,
Phi4MultimodalAudioMeanVarianceNormLayer,
Phi4MultimodalAudioMLP,
Phi4MultimodalAudioModel,
Phi4MultimodalAudioNemoConvSubsampling,
Phi4MultimodalAudioPreTrainedModel,
Phi4MultimodalAudioRelativeAttentionBias,
Phi4MultimodalDecoderLayer,
Phi4MultimodalFeatureEmbedding,
Phi4MultimodalForCausalLM,
Phi4MultimodalImageEmbedding,
Phi4MultimodalMLP,
Phi4MultimodalModel,
Phi4MultimodalPreTrainedModel,
Phi4MultimodalRMSNorm,
Phi4MultimodalRotaryEmbedding,
Phi4MultimodalVisionAttention,
Phi4MultimodalVisionEmbeddings,
Phi4MultimodalVisionEncoder,
Phi4MultimodalVisionEncoderLayer,
Phi4MultimodalVisionMLP,
Phi4MultimodalVisionModel,
Phi4MultimodalVisionMultiheadAttentionPoolingHead,
Phi4MultimodalVisionPreTrainedModel,
)
from ..phi4_multimodal.processing_phi4_multimodal import Phi4MultimodalProcessor, Phi4MultimodalProcessorKwargs
class MyTest2VisionConfig(Phi4MultimodalVisionConfig):
pass
class MyTest2AudioConfig(Phi4MultimodalAudioConfig):
pass
class MyTest2Config(Phi4MultimodalConfig):
pass
class MyTest2VisionMLP(Phi4MultimodalVisionMLP):
pass
class MyTest2VisionAttention(Phi4MultimodalVisionAttention):
pass
class MyTest2VisionEncoderLayer(Phi4MultimodalVisionEncoderLayer):
pass
class MyTest2VisionEncoder(Phi4MultimodalVisionEncoder):
pass
class MyTest2VisionPreTrainedModel(Phi4MultimodalVisionPreTrainedModel):
pass
class MyTest2VisionEmbeddings(Phi4MultimodalVisionEmbeddings):
pass
class MyTest2VisionMultiheadAttentionPoolingHead(Phi4MultimodalVisionMultiheadAttentionPoolingHead):
pass
class MyTest2VisionModel(Phi4MultimodalVisionModel):
pass
class MyTest2ImageEmbedding(Phi4MultimodalImageEmbedding):
pass
class MyTest2AudioMLP(Phi4MultimodalAudioMLP):
pass
class MyTest2AudioAttention(Phi4MultimodalAudioAttention):
pass
class MyTest2AudioDepthWiseSeparableConv1d(Phi4MultimodalAudioDepthWiseSeparableConv1d):
pass
class MyTest2AudioGluPointWiseConv(Phi4MultimodalAudioGluPointWiseConv):
pass
class MyTest2AudioConvModule(Phi4MultimodalAudioConvModule):
pass
class MyTest2AudioConformerEncoderLayer(Phi4MultimodalAudioConformerEncoderLayer):
pass
class MyTest2AudioNemoConvSubsampling(Phi4MultimodalAudioNemoConvSubsampling):
pass
class MyTest2AudioRelativeAttentionBias(Phi4MultimodalAudioRelativeAttentionBias):
pass
class MyTest2AudioMeanVarianceNormLayer(Phi4MultimodalAudioMeanVarianceNormLayer):
pass
class MyTest2AudioPreTrainedModel(Phi4MultimodalAudioPreTrainedModel):
pass
class MyTest2AudioModel(Phi4MultimodalAudioModel):
pass
class MyTest2AudioEmbedding(Phi4MultimodalAudioEmbedding):
pass
class MyTest2RMSNorm(Phi4MultimodalRMSNorm):
pass
class MyTest2MLP(Phi4MultimodalMLP):
pass
class MyTest2Attention(Phi4MultimodalAttention):
pass
class MyTest2DecoderLayer(Phi4MultimodalDecoderLayer):
pass
class MyTest2FeatureEmbedding(Phi4MultimodalFeatureEmbedding):
pass
class MyTest2PreTrainedModel(Phi4MultimodalPreTrainedModel):
pass
class MyTest2RotaryEmbedding(Phi4MultimodalRotaryEmbedding):
pass
class MyTest2Model(Phi4MultimodalModel):
pass
class MyTest2ForCausalLM(Phi4MultimodalForCausalLM):
pass
class MyTest2ImageProcessorKwargs(Phi4MultimodalImageProcessorKwargs):
pass
class MyTest2ImageProcessor(Phi4MultimodalImageProcessor):
pass
class MyTest2FeatureExtractor(Phi4MultimodalFeatureExtractor):
pass
class MyTest2ProcessorKwargs(Phi4MultimodalProcessorKwargs):
pass
class MyTest2Processor(Phi4MultimodalProcessor):
pass
__all__ = [
"MyTest2VisionConfig",
"MyTest2AudioConfig",
"MyTest2Config",
"MyTest2AudioPreTrainedModel",
"MyTest2AudioModel",
"MyTest2VisionPreTrainedModel",
"MyTest2VisionModel",
"MyTest2PreTrainedModel",
"MyTest2Model",
"MyTest2ForCausalLM",
"MyTest2ImageProcessor",
"MyTest2FeatureExtractor",
"MyTest2Processor",
]
"""
)
self.assertFileIsEqual(EXPECTED_MODULAR, os.path.join(model_repo, "modular_my_test2.py"))
EXPECTED_INIT = textwrap.dedent(
f"""
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_my_test2 import *
from .feature_extraction_my_test2 import *
from .image_processing_my_test2 import *
from .modeling_my_test2 import *
from .processing_my_test2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
"""
)
self.assertFileIsEqual(EXPECTED_INIT, os.path.join(model_repo, "__init__.py"))
EXPECTED_DOC = textwrap.dedent(
f"""
<!--Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
-->
# MyTest2
## Overview
The MyTest2 model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>
The abstract from the paper is the following:
<INSERT PAPER ABSTRACT HERE>
Tips:
<INSERT TIPS ABOUT MODEL HERE>
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
## Usage examples
<INSERT SOME NICE EXAMPLES HERE>
## MyTest2VisionConfig
[[autodoc]] MyTest2VisionConfig
## MyTest2AudioConfig
[[autodoc]] MyTest2AudioConfig
## MyTest2Config
[[autodoc]] MyTest2Config
## MyTest2AudioPreTrainedModel
[[autodoc]] MyTest2AudioPreTrainedModel
- forward
## MyTest2AudioModel
[[autodoc]] MyTest2AudioModel
- forward
## MyTest2VisionPreTrainedModel
[[autodoc]] MyTest2VisionPreTrainedModel
- forward
## MyTest2VisionModel
[[autodoc]] MyTest2VisionModel
- forward
## MyTest2PreTrainedModel
[[autodoc]] MyTest2PreTrainedModel
- forward
## MyTest2Model
[[autodoc]] MyTest2Model
- forward
## MyTest2ForCausalLM
[[autodoc]] MyTest2ForCausalLM
## MyTest2ImageProcessor
[[autodoc]] MyTest2ImageProcessor
## MyTest2FeatureExtractor
[[autodoc]] MyTest2FeatureExtractor
## MyTest2Processor
[[autodoc]] MyTest2Processor
"""
)
self.assertFileIsEqual(EXPECTED_DOC, os.path.join(self.DOC_PATH, "model_doc", "my_test2.md"))

View File

@@ -0,0 +1,125 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import io
import re
import unittest
from transformers.testing_utils import require_torch
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
def _normalize(s: str) -> str:
# drop ANSI (colors may be disabled on CI), normalize line endings,
# and strip trailing spaces without touching alignment inside lines
s = ANSI_RE.sub("", s)
s = s.replace("\r\n", "\n").replace("\r", "\n")
return "\n".join(line.rstrip() for line in s.split("\n")).strip()
@require_torch
class AttentionMaskVisualizerTester(unittest.TestCase):
"""Test suite for AttentionMaskVisualizer"""
def test_paligemma_multimodal_visualization(self):
"""Test AttentionMaskVisualizer with PaliGemma multimodal model"""
model_name = "hf-internal-testing/namespace_google_repo_name_paligemma-3b-pt-224"
input_text = "<img> What is in this image?"
buf = io.StringIO()
orig_print = builtins.print
def _print(*args, **kwargs):
kwargs.setdefault("file", buf)
orig_print(*args, **kwargs)
try:
builtins.print = _print
visualizer = AttentionMaskVisualizer(model_name)
visualizer(input_text)
finally:
builtins.print = orig_print
output = buf.getvalue()
expected_output = """
##########################################################################################################################################################################################################################################
## Attention visualization for \033[1mpaligemma:hf-internal-testing/namespace_google_repo_name_paligemma-3b-pt-224\033[0m PaliGemmaModel ##
##########################################################################################################################################################################################################################################
\033[92m■\033[0m: i == j (diagonal) \033[93m■\033[0m: token_type_ids
Attention Matrix
\033[93m'<image>'\033[0m: 0 \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
\033[93m'<image>'\033[0m: 1 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
\033[93m'<image>'\033[0m: 2 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
\033[93m'<image>'\033[0m: 3 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
\033[93m'<image>'\033[0m: 4 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'<bos>' : 5 ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁What' : 6 ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁is' : 7 ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁in' : 8 ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁this' : 9 ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ |
'▁image' : 10 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ |
'?' : 11 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ |
'\\n' : 12 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ |
'<eos>' : 13 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m |
##########################################################################################################################################################################################################################################
""" # noqa
self.assertEqual(_normalize(output), _normalize(expected_output))
def test_llama_text_only_visualization(self):
"""Test AttentionMaskVisualizer with Llama text-only model"""
model_name = "hf-internal-testing/namespace_meta-llama_repo_name_Llama-2-7b-hf"
input_text = "Plants create energy through a process known as"
buf = io.StringIO()
orig_print = builtins.print
def _print(*args, **kwargs):
kwargs.setdefault("file", buf)
orig_print(*args, **kwargs)
try:
builtins.print = _print
visualizer = AttentionMaskVisualizer(model_name)
visualizer(input_text)
finally:
builtins.print = orig_print
output = buf.getvalue()
expected_output = """
##########################################################################################################################################################################################################
## Attention visualization for \033[1mllama:hf-internal-testing/namespace_meta-llama_repo_name_Llama-2-7b-hf\033[0m LlamaModel ##
##########################################################################################################################################################################################################
\033[92m■\033[0m: i == j (diagonal) \033[93m■\033[0m: token_type_ids
Attention Matrix
'▁Pl' : 0 \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'ants' : 1 ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁create' : 2 ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁energy' : 3 ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ |
'▁through': 4 ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ |
'▁a' : 5 ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ |
'▁process': 6 ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ |
'▁known' : 7 ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ |
'▁as' : 8 ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m |
##########################################################################################################################################################################################################
""" # noqa
self.assertEqual(_normalize(output), _normalize(expected_output))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,761 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for auto_docstring decorator and check_auto_docstrings function.
"""
import importlib
import os
import statistics
import sys
import tempfile
import textwrap
import time
import unittest
from pathlib import Path
import torch
from huggingface_hub.dataclasses import strict
from transformers.configuration_utils import PretrainedConfig
from transformers.image_processing_backends import TorchvisionBackend
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from transformers.testing_utils import require_torch
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils.auto_docstring import (
auto_docstring,
)
from transformers.utils.import_utils import is_torch_available
if is_torch_available():
import torch
_repo_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(_repo_root / "utils"))
from check_docstrings import ( # noqa: E402
_build_ast_indexes,
_find_typed_dict_classes,
find_files_with_auto_docstring,
update_file_with_new_docstrings,
)
class TestCheckDocstrings(unittest.TestCase):
"""Test check_auto_docstrings static analysis tool for detecting and fixing docstring issues."""
def test_missing_args_detection_and_placeholder_generation(self):
"""Test that missing custom args are detected and placeholders generated while preserving Examples and code."""
with tempfile.TemporaryDirectory() as tmpdir:
test_file = os.path.join(tmpdir, "model.py")
original = textwrap.dedent("""
from transformers.utils.auto_docstring import auto_docstring
@auto_docstring
def forward(self, input_ids, custom_temperature: float = 1.0):
'''
Example:
```python
>>> model.forward(input_ids, custom_temperature=0.7)
```
'''
result = input_ids * custom_temperature
return result
""")
with open(test_file, "w") as f:
f.write(original)
with open(test_file, "r") as f:
content = f.read()
items = _build_ast_indexes(content)
lines = content.split("\n")
# Test detection (overwrite=False) - should detect missing arg
missing, fill, redundant = update_file_with_new_docstrings(
test_file, lines, items, content, overwrite=False
)
self.assertTrue(any("custom_temperature" in msg for msg in missing))
# Generate placeholders (overwrite=True)
update_file_with_new_docstrings(test_file, lines, items, content, overwrite=True)
with open(test_file, "r") as f:
updated = f.read()
# Verify results
self.assertIn("custom_temperature", updated)
self.assertIn("<fill_docstring>", updated) # Placeholder added
self.assertIn("input_ids", updated) # Standard arg from ModelArgs
self.assertIn("Example:", updated) # Example preserved
self.assertIn("result = input_ids * custom_temperature", updated) # Code preserved
def test_multi_item_file_processing(self):
"""Test processing files with multiple @auto_docstring decorators (class + method) in a single pass."""
with tempfile.TemporaryDirectory() as tmpdir:
test_file = os.path.join(tmpdir, "modeling.py")
original = textwrap.dedent("""
from transformers.utils.auto_docstring import auto_docstring
from transformers.modeling_utils import PreTrainedModel
@auto_docstring
class MyModel(PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.layer = None
@auto_docstring
def forward(self, input_ids, scale_factor: float = 1.0):
'''
Example:
```python
>>> outputs = model.forward(input_ids, scale_factor=2.0)
```
'''
return self.layer(input_ids) * scale_factor
""")
with open(test_file, "w") as f:
f.write(original)
with open(test_file, "r") as f:
content = f.read()
items = _build_ast_indexes(content)
# Should find 2 decorated items
self.assertEqual(len(items), 2)
self.assertEqual(items[0].kind, "class")
self.assertEqual(items[1].kind, "function")
lines = content.split("\n")
# Detect issues
missing, fill, redundant = update_file_with_new_docstrings(
test_file, lines, items, content, overwrite=False
)
# Should detect missing scale_factor in forward method
self.assertTrue(any("scale_factor" in msg for msg in missing))
# Update file
update_file_with_new_docstrings(test_file, lines, items, content, overwrite=True)
with open(test_file, "r") as f:
updated = f.read()
# Verify updates and preservation
self.assertIn("scale_factor", updated) # Custom arg added with placeholder
self.assertIn("<fill_docstring>", updated) # Placeholder present
self.assertIn("Example:", updated) # Example preserved
self.assertIn("self.layer = None", updated) # __init__ code preserved
self.assertIn("return self.layer(input_ids) * scale_factor", updated) # forward code preserved
def test_typed_dict_field_detection(self):
"""Test that _find_typed_dict_classes correctly identifies custom fields vs standard inherited fields."""
content = textwrap.dedent("""
from typing import TypedDict
from transformers.processing_utils import ImagesKwargs
class CustomImageKwargs(ImagesKwargs, total=False):
'''
custom_mode (`str`):
Custom processing mode.
'''
# Standard field from ImagesKwargs - should be in all_fields but not fields
do_resize: bool
# Custom fields - should be in both all_fields and fields
custom_mode: str
undocumented_custom: int
""")
typed_dicts = _find_typed_dict_classes(content)
# Should find the TypedDict
self.assertEqual(len(typed_dicts), 1)
self.assertEqual(typed_dicts[0]["name"], "CustomImageKwargs")
# all_fields includes everything
self.assertIn("do_resize", typed_dicts[0]["all_fields"])
self.assertIn("custom_mode", typed_dicts[0]["all_fields"])
self.assertIn("undocumented_custom", typed_dicts[0]["all_fields"])
# fields only includes custom fields (not standard args like do_resize)
# Both documented and undocumented custom fields are included
self.assertIn("custom_mode", typed_dicts[0]["fields"])
self.assertIn("undocumented_custom", typed_dicts[0]["fields"])
self.assertNotIn("do_resize", typed_dicts[0]["fields"]) # Standard arg excluded
def test_file_discovery_finds_decorated_files(self):
"""Test that check_auto_docstrings can discover files containing @auto_docstring."""
with tempfile.TemporaryDirectory() as tmpdir:
has_decorator = os.path.join(tmpdir, "modeling.py")
no_decorator = os.path.join(tmpdir, "utils.py")
with open(has_decorator, "w") as f:
f.write("@auto_docstring\ndef forward(self): pass")
with open(no_decorator, "w") as f:
f.write("def helper(): pass")
found = find_files_with_auto_docstring([has_decorator, no_decorator])
self.assertEqual(len(found), 1)
self.assertEqual(found[0], has_decorator)
class DummyConfig(PretrainedConfig):
model_type = "dummy_test"
def __init__(self, vocab_size=1000, hidden_size=768, num_attention_heads=12, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
@auto_docstring
class DummyForTestModel(PreTrainedModel):
config_class = DummyConfig
def __init__(self, config: DummyConfig):
super().__init__(config)
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
position_ids: torch.LongTensor | None = None,
inputs_embeds: torch.FloatTensor | None = None,
temperature: float = 1.0,
custom_dict: dict[str, int | float] | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
) -> CausalLMOutputWithPast:
r"""
temperature (`float`, *optional*, defaults to 1.0):
Temperature value for scaling logits during generation.
custom_dict (`dict[str, Union[int, float]]`, *optional*):
Custom dictionary parameter with string keys and numeric values.
Example:
```python
>>> from transformers import AutoTokenizer, DummyForTestModel
>>> import torch
>>> model = DummyForTestModel.from_pretrained("dummy-model")
>>> tokenizer = AutoTokenizer.from_pretrained("dummy-model")
>>> inputs = tokenizer("Hello world", return_tensors="pt")
>>> outputs = model.forward(**inputs, temperature=0.7)
>>> logits = outputs.logits
```
"""
pass
class ComplexProcessorKwargs(ProcessingKwargs, total=False):
r"""
custom_processing_mode (`str`, *optional*, defaults to `"standard"`):
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
enable_advanced_features (`bool`, *optional*, defaults to `False`):
Whether to enable advanced processing features like custom tokenization strategies.
custom_threshold (`float`, *optional*, defaults to 0.5):
Custom threshold value for filtering or processing decisions.
output_format (`str`, *optional*, defaults to `"default"`):
Output format specification. Can be 'default', 'extended', or 'minimal'.
"""
custom_processing_mode: str
enable_advanced_features: bool
custom_threshold: float
output_format: str
@auto_docstring
class DummyProcessorForTest(ProcessorMixin):
def __init__(
self,
image_processor=None,
tokenizer=None,
custom_processing_mode="standard",
enable_advanced_features=False,
custom_threshold=0.5,
output_format="default",
**kwargs,
):
r"""
custom_processing_mode (`str`, *optional*, defaults to `"standard"`):
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
enable_advanced_features (`bool`, *optional*, defaults to `False`):
Whether to enable advanced processing features like custom tokenization strategies.
custom_threshold (`float`, *optional*, defaults to 0.5):
Custom threshold value for filtering or processing decisions.
output_format (`str`, *optional*, defaults to `"default"`):
Output format specification. Can be 'default', 'extended', or 'minimal'.
"""
pass
@auto_docstring
def __call__(
self,
images: ImageInput | None = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
**kwargs: Unpack[ComplexProcessorKwargs],
) -> BatchFeature:
r"""
Example:
```python
>>> from transformers import DummyProcessorForTest
>>> processor = DummyProcessorForTest.from_pretrained("dummy-processor")
>>> inputs = processor(text="Hello world", images=["image.jpg"], return_tensors="pt")
```
"""
pass
class DummyImageProcessorKwargs(ImagesKwargs, total=False):
r"""
image_grid_pinpoints (`list[list[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method.
custom_scale (`float`, *optional*, defaults to 255.0):
Custom scale factor for preprocessing pipelines.
"""
image_grid_pinpoints: list[list[int]]
custom_scale: float
@auto_docstring(
custom_intro="""
Constructs a fast DummyForTest image processor.
"""
)
class DummyForTestImageProcessorFast(TorchvisionBackend):
model_input_names = ["pixel_values"]
valid_kwargs = DummyImageProcessorKwargs
def __init__(self, **kwargs: Unpack[DummyImageProcessorKwargs]):
super().__init__(**kwargs)
@auto_docstring
def preprocess(
self,
images: ImageInput,
**kwargs: Unpack[DummyImageProcessorKwargs],
) -> BatchFeature:
r"""
Example:
```python
>>> from transformers import DummyForTestImageProcessorFast
>>> from PIL import Image
>>> import requests
>>> processor = DummyForTestImageProcessorFast.from_pretrained("dummy-processor")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor.preprocess(images=image, return_tensors="pt")
```
"""
pass
@strict
@auto_docstring(custom_intro="A minimal configuration for testing that config docstrings only show own args.")
class DummyStrictConfig(PretrainedConfig):
r"""
extra_param (`str`, *optional*, defaults to `"test"`):
A custom parameter unique to this config, not inherited from PreTrainedConfig.
"""
model_type = "dummy_strict_for_docstring_test"
hidden_size: int = 256
num_layers: int = 6
extra_param: str = "test"
@require_torch
class TestFullDocstringGeneration(unittest.TestCase):
"""
End-to-end tests for @auto_docstring runtime docstring generation.
Tests validate complete docstrings with single assertEqual assertions to ensure structure,
formatting, standard args, custom params, and TypedDict unrolling work correctly.
"""
def test_strict_config_docstring_only_documents_own_args(self):
"""Test that config docstrings only document own class annotations, not inherited PreTrainedConfig args."""
self.maxDiff = None
actual_docstring = DummyStrictConfig.__doc__
expected_docstring = """A minimal configuration for testing that config docstrings only show own args.
Args:
hidden_size (`int`, *optional*, defaults to `256`):
Dimension of the hidden representations.
num_layers (`int`, *optional*, defaults to `6`):
Number of hidden layers in the Transformer decoder.
extra_param (`str`, *optional*, defaults to `"test"`):
A custom parameter unique to this config, not inherited from PreTrainedConfig.
"""
self.assertEqual(actual_docstring, expected_docstring)
def test_dummy_model_complete_docstring(self):
self.maxDiff = None
"""Test complete class and forward method docstrings for PreTrainedModel with ModelArgs and custom parameters."""
actual_class_docstring = DummyForTestModel.__doc__
expected_class_docstring = """
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`DummyConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
self.assertEqual(actual_class_docstring, expected_class_docstring)
actual_docstring = DummyForTestModel.forward.__doc__
expected_docstring = """ The [`DummyForTestModel`] forward method, overrides the `__call__` special method.
<Tip>
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
</Tip>
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
temperature (`float`, *optional*, defaults to 1.0):
Temperature value for scaling logits during generation.
custom_dict (`dict[str, Union[int, float]]`, *optional*):
Custom dictionary parameter with string keys and numeric values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
[`~modeling_outputs.CausalLMOutputWithPast`] or `tuple(torch.FloatTensor)`: A [`~modeling_outputs.CausalLMOutputWithPast`] or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([`None`]) and inputs.
- **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) -- Language modeling loss (for next-token prediction).
- **logits** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) -- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- **past_key_values** (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Example:
```python
>>> from transformers import AutoTokenizer, DummyForTestModel
>>> import torch
>>> model = DummyForTestModel.from_pretrained("dummy-model")
>>> tokenizer = AutoTokenizer.from_pretrained("dummy-model")
>>> inputs = tokenizer("Hello world", return_tensors="pt")
>>> outputs = model.forward(**inputs, temperature=0.7)
>>> logits = outputs.logits
```
"""
self.assertEqual(actual_docstring, expected_docstring)
def test_dummy_processor_complete_docstring(self):
self.maxDiff = None
"""Test complete class and __call__ docstrings for ProcessorMixin with complex TypedDict kwargs unrolling."""
actual_docstring = DummyProcessorForTest.__call__.__doc__
expected_docstring = """ Args:
images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[PIL.Image.Image], list[numpy.ndarray], list[torch.Tensor]]`, *optional*):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
text (`Union[str, list[str], list[list[str]]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If you pass a pretokenized input, set `is_split_into_words=True` to avoid ambiguity with batched inputs.
custom_processing_mode (`str`, *kwargs*, *optional*, defaults to `"standard"`):
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
enable_advanced_features (`bool`, *kwargs*, *optional*, defaults to `False`):
Whether to enable advanced processing features like custom tokenization strategies.
custom_threshold (`float`, *kwargs*, *optional*, defaults to 0.5):
Custom threshold value for filtering or processing decisions.
output_format (`str`, *kwargs*, *optional*, defaults to `"default"`):
Output format specification. Can be 'default', 'extended', or 'minimal'.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
**kwargs ([`ProcessingKwargs`], *optional*):
Additional processing options for each modality (text, images, videos, audio). Model-specific parameters
are listed above; see the TypedDict class for the complete list of supported arguments.
Returns:
`~image_processing_base.BatchFeature`:
- **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
- **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
initialization.
Example:
```python
>>> from transformers import DummyProcessorForTest
>>> processor = DummyProcessorForTest.from_pretrained("dummy-processor")
>>> inputs = processor(text="Hello world", images=["image.jpg"], return_tensors="pt")
```
"""
self.assertEqual(actual_docstring, expected_docstring)
actual_class_docstring = DummyProcessorForTest.__doc__
expected_class_docstring = """Constructs a DummyProcessorForTest which wraps a image processor and a tokenizer into a single processor.
[`DummyProcessorForTest`] offers all the functionalities of [`image_processor_class`] and [`tokenizer_class`]. See the
[`~image_processor_class`] and [`~tokenizer_class`] for more information.
Parameters:
image_processor (`image_processor_class`):
The image processor is a required input.
tokenizer (`tokenizer_class`):
The tokenizer is a required input.
custom_processing_mode (`str`, *optional*, defaults to `"standard"`):
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
enable_advanced_features (`bool`, *optional*, defaults to `False`):
Whether to enable advanced processing features like custom tokenization strategies.
custom_threshold (`float`, *optional*, defaults to 0.5):
Custom threshold value for filtering or processing decisions.
output_format (`str`, *optional*, defaults to `"default"`):
Output format specification. Can be 'default', 'extended', or 'minimal'.
"""
self.assertEqual(actual_class_docstring, expected_class_docstring)
def test_dummy_image_processor_complete_docstring(self):
self.maxDiff = None
"""Test complete class and preprocess docstrings for DummyForTestImageProcessorFast with custom ImagesKwargs and custom_intro."""
actual_preprocess_docstring = DummyForTestImageProcessorFast.preprocess.__doc__
expected_preprocess_docstring = """ Args:
images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[PIL.Image.Image], list[numpy.ndarray], list[torch.Tensor]]`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
image_grid_pinpoints (`list[list[int]]`, *kwargs*, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method.
custom_scale (`float`, *kwargs*, *optional*, defaults to 255.0):
Custom scale factor for preprocessing pipelines.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
Returns stacked tensors if set to `'pt'`, otherwise returns a list of tensors.
**kwargs ([`ImagesKwargs`], *optional*):
Additional image preprocessing options. Model-specific kwargs are listed above; see the TypedDict class
for the complete list of supported arguments.
Returns:
`~image_processing_base.BatchFeature`:
- **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
- **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
initialization.
Example:
```python
>>> from transformers import DummyForTestImageProcessorFast
>>> from PIL import Image
>>> import requests
>>> processor = DummyForTestImageProcessorFast.from_pretrained("dummy-processor")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor.preprocess(images=image, return_tensors="pt")
```
"""
self.assertEqual(actual_preprocess_docstring, expected_preprocess_docstring)
actual_class_docstring = DummyForTestImageProcessorFast.__doc__
expected_class_docstring = """
Constructs a fast DummyForTest image processor.
Args:
image_grid_pinpoints (`list[list[int]]`, *kwargs*, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method.
custom_scale (`float`, *kwargs*, *optional*, defaults to 255.0):
Custom scale factor for preprocessing pipelines.
**kwargs ([`ImagesKwargs`], *optional*):
Additional image preprocessing options. Model-specific kwargs are listed above; see the TypedDict class
for the complete list of supported arguments.
"""
self.assertEqual(actual_class_docstring, expected_class_docstring)
# ---------------------------------------------------------------------------
# Performance tests for auto_docstring
# ---------------------------------------------------------------------------
class TestAutoDocstringPerformance:
"""
Performance tests for auto_docstring.
The decorator runs at *class-definition / import time*, so with hundreds of
models in the library the cumulative cost matters even though each individual
call looks cheap. These tests assert an upper bound to catch regressions.
"""
# Upper bound (%) of total import time that auto_docstring overhead may take.
# Relative metric; robust across CI vs local. Catches serious regressions.
AUTO_DOCSTRING_COST_PCT_UPPER_BOUND = 70.0
def test_auto_docstring_import_time_upper_bound(self):
"""
Asserts that auto_docstring overhead stays below a percentage of total
import time.
Method
------
1. Collect ``modeling_*.py``, ``image_processing_*.py``, ``processing_*.py``
under ``transformers/models``, then sample every 10th for speed.
2. Warmup: import the sampled modules once so Python's bytecode cache is hot.
3. Measure WITH auto_docstring: clear cache, re-import, median over 5 runs.
4. Measure WITHOUT auto_docstring: noop-patch, clear cache, re-import, median.
5. cost_pct = (real - noop) / real * 100; assert cost_pct < upper bound.
"""
if "transformers.utils" not in sys.modules:
importlib.import_module("transformers.utils")
_utils_module = sys.modules["transformers.utils"]
src_root = Path(__file__).resolve().parent.parent.parent / "src"
models_dir = src_root / "transformers" / "models"
all_modules: list[str] = []
for pattern in ("modeling_*.py", "image_processing_*.py", "processing_*.py"):
for f in sorted(models_dir.rglob(pattern)):
rel = f.with_suffix("").relative_to(src_root)
all_modules.append(".".join(rel.parts))
model_modules = all_modules[::10]
def _clear():
for key in [k for k in sys.modules if k.startswith("transformers.models")]:
del sys.modules[key]
def _import_all():
for mod in model_modules:
try:
importlib.import_module(mod)
except Exception:
continue
_import_all() # warmup
# With auto_docstring (real)
times_real: list[float] = []
for _ in range(5):
_clear()
t0 = time.perf_counter()
_import_all()
times_real.append(time.perf_counter() - t0)
# Without auto_docstring (noop patch)
_orig = _utils_module.auto_docstring
_noop = lambda x=None, **kw: (lambda f: f) if x is None else x # noqa: E731
times_noop: list[float] = []
for _ in range(5):
_utils_module.auto_docstring = _noop
try:
_clear()
t0 = time.perf_counter()
_import_all()
times_noop.append(time.perf_counter() - t0)
finally:
_utils_module.auto_docstring = _orig
median_real = statistics.median(times_real)
median_noop = statistics.median(times_noop)
cost_pct = (median_real - median_noop) / median_real * 100 if median_real > 0 else 0.0
print(f"Cost percentage: {cost_pct:.1f}%")
assert cost_pct < self.AUTO_DOCSTRING_COST_PCT_UPPER_BOUND, (
f"auto_docstring cost {cost_pct:.1f}% of import time exceeds upper bound "
f"{self.AUTO_DOCSTRING_COST_PCT_UPPER_BOUND}% "
f"({len(model_modules)} of {len(all_modules)} modules)"
)

View File

@@ -0,0 +1,151 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from transformers import PreTrainedConfig
from transformers.backbone_utils import (
BackboneConfigMixin,
BackboneMixin,
)
from transformers.testing_utils import require_torch
from transformers.utils.import_utils import is_torch_available
if is_torch_available():
from transformers import PreTrainedModel
class AnyBackboneConfig(BackboneConfigMixin, PreTrainedConfig):
def __init__(
self,
stage_names: list | None = None,
out_indices: list | None = None,
out_features: list | None = None,
**kwargs,
):
self.stage_names = stage_names
self.set_output_features_output_indices(out_features=out_features, out_indices=out_indices)
super().__init__(**kwargs)
@require_torch
class AnyBackbone(BackboneMixin, PreTrainedModel): ...
class BackboneUtilsTester(unittest.TestCase):
def test_get_aligned_output_features_output_indices(self):
stage_names = ["a", "b", "c"]
# Defaults to last layer if both, `out_indices` and `out_features`, are None
config = AnyBackboneConfig(stage_names)
self.assertEqual(config.out_features, ["c"])
self.assertEqual(config.out_indices, [2])
# Out indices set to match out features
config = AnyBackboneConfig(stage_names=stage_names, out_features=["a", "c"])
self.assertEqual(config.out_features, ["a", "c"])
self.assertEqual(config.out_indices, [0, 2])
# Out features set to match out indices
config = AnyBackboneConfig(stage_names=stage_names, out_indices=[0, 2])
self.assertEqual(config.out_features, ["a", "c"])
self.assertEqual(config.out_indices, [0, 2])
# Out features selected from negative indices
config = AnyBackboneConfig(stage_names=stage_names, out_indices=[-3, -1])
self.assertEqual(config.out_features, ["a", "c"])
self.assertEqual(config.out_indices, [-3, -1])
def test_config_verify_out_features_out_indices(self):
# Stage names must be set
with pytest.raises(ValueError, match="Stage_names must be set for transformers backbones"):
AnyBackboneConfig(stage_names=None, out_features=["a", "b"], out_indices=(0, 1))
# Out features must be a list
with pytest.raises(ValueError, match="out_features must be a list got <class 'tuple'>"):
AnyBackboneConfig(stage_names=["a", "b"], out_features=("a", "b"), out_indices=[0, 1])
# Out features must be a subset of stage names
with pytest.raises(
ValueError, match=r"out_features must be a subset of stage_names: \['a'\] got \['a', 'b'\]"
):
AnyBackboneConfig(stage_names=["a"], out_features=["a", "b"], out_indices=[0, 1])
# Out features must contain no duplicates
with pytest.raises(ValueError, match=r"out_features must not contain any duplicates, got \['a', 'a'\]"):
AnyBackboneConfig(stage_names=["a"], out_features=["a", "a"], out_indices=None)
# Out indices must be a list
with pytest.raises(ValueError, match="out_indices must be a list, got <class 'int'>"):
AnyBackboneConfig(stage_names=["a", "b"], out_features=None, out_indices=0)
# Out indices must be a subset of stage names
with pytest.raises(
ValueError, match=r"out_indices must be valid indices for stage_names \['a'\], got \[0, 1\]"
):
AnyBackboneConfig(stage_names=["a"], out_features=None, out_indices=[0, 1])
# Out indices must contain no duplicates
with pytest.raises(ValueError, match=r"out_indices must not contain any duplicates, got \[0, 0\]"):
AnyBackboneConfig(stage_names=["a"], out_features=None, out_indices=[0, 0])
# Out features and out indices must be the same length
with pytest.raises(
ValueError, match="out_features and out_indices should have the same length if both are set"
):
AnyBackboneConfig(stage_names=["a", "b", "c"], out_features=["a", "b"], out_indices=[0])
# Out features should match out indices
with pytest.raises(
ValueError, match="out_features and out_indices should correspond to the same stages if both are set"
):
AnyBackboneConfig(stage_names=["a", "b", "c"], out_features=["a", "b"], out_indices=[0, 2])
# Out features and out indices should be in order
with pytest.raises(
ValueError,
match=r"out_features must be in the same order as stage_names, expected \['a', 'b'\] got \['b', 'a'\]",
):
AnyBackboneConfig(stage_names=["a", "b"], out_features=["b", "a"], out_indices=[0, 1])
with pytest.raises(
ValueError, match=r"out_indices must be in the same order as stage_names, expected \[-2, 1\] got \[1, -2\]"
):
AnyBackboneConfig(stage_names=["a", "b"], out_features=["a", "b"], out_indices=[1, -2])
# Check passes with valid inputs
AnyBackboneConfig(stage_names=["a", "b", "c", "d"], out_features=["a", "b", "d"], out_indices=[0, 1, -1])
@require_torch
def test_backbone_mixin(self):
config = AnyBackboneConfig(stage_names=["a", "b", "c"], out_features=["a", "c"], out_indices=[0, 2])
backbone = AnyBackbone(config)
backbone.config = config
# Check that the output features and indices are set correctly
self.assertEqual(backbone.out_features, ["a", "c"])
self.assertEqual(backbone.out_indices, [0, 2])
# Check out features and indices are updated correctly
backbone.out_features = ["a", "b"]
self.assertEqual(backbone.out_features, ["a", "b"])
self.assertEqual(backbone.out_indices, [0, 1])
backbone.out_indices = [-3, -1]
self.assertEqual(backbone.out_features, ["a", "c"])
self.assertEqual(backbone.out_indices, [-3, -1])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,623 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoTokenizer
from transformers.testing_utils import require_jmespath
from transformers.utils.chat_parsing_utils import recursive_parse
cohere_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"},
"tool_calls": {
"x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)",
"x-parser": "json",
"x-parser-args": {
"transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}"
},
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {},
},
},
},
},
},
},
},
}
ernie_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": "<response>\n(.*?)\n?</response>"},
"thinking": {"type": "string", "x-regex": r"(?:^|<think>\s*)(.*?)\s*<\/think>"},
"tool_calls": {
"x-regex-iterator": "<tool_call>(.*?)</tool_call>",
"type": "array",
"items": {
"type": "object",
"x-parser": "json",
"x-parser-args": {"transform": "{type: 'function', function: @}"},
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {},
},
},
},
},
},
},
},
}
gpt_oss_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
"thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
"tool_calls": {
"x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
"arguments": {
"type": "object",
"x-regex": r"<\|message\|>(.*)",
"x-parser": "json",
"additionalProperties": {},
},
},
},
},
},
},
},
}
smollm_schema = {
"x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?:<tool_call>(?P<tool_calls>.+?)</tool_call>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"},
"tool_calls": {
"x-parser": "json",
"x-parser-args": {"transform": "[{type: 'function', function: @}]"},
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
"type": "object",
"additionalProperties": {},
},
},
},
},
},
},
},
}
qwen3_schema = {
"x-regex": r"^(?:(?:<think>)?\s*(?P<thinking>.+?)\s*</think>)?\s*(?:<tool_call>(?P<tool_calls>.*?)\s*</tool_call>)?\s*(?P<content>.+?)?\s*$",
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string"},
"thinking": {"type": "string"},
"tool_calls": {
"x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string", "x-regex": r"<function=(\w+)>"},
"arguments": {
"type": "object",
"x-regex-key-value": r"<parameter=(?P<key>\w+)>\n(?P<value>.*?)\n</parameter>",
"additionalProperties": {
"x-parser": "json",
"x-parser-args": {"allow_non_json": True},
},
},
},
},
},
},
},
},
}
re_sub_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"thinking": {"type": "string"},
"content": {"type": "string"},
"tool_calls": {
"x-regex-iterator": r"<\|tool_call>(.*?)<tool_call\|>",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
"properties": {
"name": {
"type": "string",
},
"arguments": {
"type": "object",
"x-regex-key-value": r'(?P<key>\w+):(?P<value><\|"\|>.*?<\|"\|>|[^,}]+)',
"additionalProperties": {
"x-regex-substitutions": [[r'^<\|"\|>|<\|"\|>$', ""]],
},
},
},
},
},
},
},
},
"x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?",
}
gemma4_schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"thinking": {"type": "string"},
"content": {"type": "string"},
"tool_calls": {
"x-regex-iterator": r"<\|tool_call>(.*?)<tool_call\|>",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"const": "function"},
"function": {
"type": "object",
"x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
"properties": {
"name": {
"type": "string",
},
"arguments": {
"type": "object",
"x-parser": "gemma4-tool-call",
"additionalProperties": {},
},
},
},
},
},
},
},
"x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?",
}
prefix_items_schema = {
# Not intended to be "realistic", just checks that prefixItems can handle a heterogeneous array
"x-regex-iterator": r"<block>(.*?)<\/block>",
"type": "array",
"prefixItems": [
{"type": "string"},
{"type": "integer"},
{"type": "string"},
],
}
@require_jmespath
class ChatSchemaParserTest(unittest.TestCase):
def test_schema_save_load(self):
# Has no schema by default
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.response_schema = ernie_schema
with tempfile.TemporaryDirectory() as tmpdir:
tokenizer.save_pretrained(tmpdir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmpdir)
self.assertEqual(reloaded_tokenizer.response_schema, ernie_schema)
def test_tokenizer_method(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
parsed_chat = recursive_parse(model_out, cohere_schema)
tokenizer.response_schema = cohere_schema
tokenizer_parsed_chat = tokenizer.parse_response(model_out)
self.assertEqual(tokenizer_parsed_chat, parsed_chat)
def test_batched_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
tokenizer.response_schema = cohere_schema
parsed_chat = tokenizer.parse_response(model_out)
self.assertEqual(tokenizer.parse_response([model_out]), [parsed_chat])
self.assertEqual(tokenizer.parse_response([model_out] * 2), [parsed_chat] * 2)
def test_token_id_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") # Need an actual tokenizer to encode
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
tokenizer.response_schema = cohere_schema
parsed_chat = tokenizer.parse_response(model_out)
tokenized_out = tokenizer(model_out).input_ids
self.assertEqual(tokenizer.parse_response(tokenized_out), parsed_chat)
self.assertEqual(tokenizer.parse_response([tokenized_out]), [parsed_chat])
self.assertEqual(tokenizer.parse_response([tokenized_out] * 2), [parsed_chat] * 2)
def test_numpy_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") # Need an actual tokenizer to encode
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
tokenizer.response_schema = cohere_schema
parsed_chat = tokenizer.parse_response(model_out)
tokenized_out = tokenizer(model_out, return_tensors="np").input_ids
self.assertEqual(tokenizer.parse_response(tokenized_out), [parsed_chat])
def test_tensor_inputs(self):
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") # Need an actual tokenizer to encode
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
tokenizer.response_schema = cohere_schema
parsed_chat = tokenizer.parse_response(model_out)
tokenized_out = tokenizer(model_out, return_tensors="pt").input_ids
self.assertEqual(tokenizer.parse_response(tokenized_out), [parsed_chat])
def test_cohere_template(self):
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
parsed_chat = recursive_parse(model_out, cohere_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": "I should call a tool.",
"tool_calls": [
{
"type": "function",
"function": {"name": "simple_tool", "arguments": {"temperature_format": "Celsius"}},
}
],
},
)
def test_ernie_template_with_tools(self):
model_out = 'The user is asking about the weather in Paris today. Let me check the available tools. There\'s a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to "Paris". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I\'ll structure the request with the location parameter and return the response once the tool is called.\n</think>\n\n<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Paris"}}\n</tool_call>\n</s>'
parsed_chat = recursive_parse(model_out, ernie_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": "The user is asking about the weather in Paris today. Let me check the available tools. There's a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to \"Paris\". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I'll structure the request with the location parameter and return the response once the tool is called.",
"tool_calls": [
{
"type": "function",
"function": {"name": "get_current_temperature", "arguments": {"location": "Paris"}},
}
],
},
)
def test_ernie_template_no_tools(self):
model_out = "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.\n</think>\n\n<response>\nHello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!\n</response>\n</s>"
parsed_chat = recursive_parse(model_out, ernie_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "Hello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!",
"thinking": "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.",
},
)
def test_gpt_oss_template_with_tool_call(self):
model_out = '<|channel|>analysis<|message|>We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\n "location": "San Francisco, CA"\n}'
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": 'We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.',
"tool_calls": [
{
"type": "function",
"function": {"name": "get_current_weather", "arguments": {"location": "San Francisco, CA"}},
}
],
},
)
def test_gpt_oss_template_no_tool_call(self):
model_out = "<|channel|>analysis<|message|>User asks a simple math question: 2+2 = 4. Provide answer.<|end|><|start|>assistant<|channel|>final<|message|>2"
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "2",
"thinking": "User asks a simple math question: 2+2 = 4. Provide answer.",
},
)
def test_smollm_template_thinking_and_tool_call(self):
model_out = '<think>\nOkay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.\n</think>\n\n<tool_call>{"name": "greet_user", "arguments": {"greeting": "Hello! I\'m doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"}}</tool_call>'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"thinking": 'Okay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.',
"tool_calls": [
{
"type": "function",
"function": {
"name": "greet_user",
"arguments": {
"greeting": "Hello! I'm doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"
},
},
}
],
},
)
def test_smollm_template_tool_call_no_thinking(self):
model_out = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"tool_calls": [
{"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
],
},
)
def test_smollm_template_thinking_no_tool_call(self):
model_out = '<think>\nOkay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.</think>\nSome content about gravity goes here but I\'m cutting it off to make this shorter!'
parsed_chat = recursive_parse(model_out, smollm_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"content": "Some content about gravity goes here but I'm cutting it off to make this shorter!",
"thinking": 'Okay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.',
},
)
def test_qwen3_tool_calls(self):
model_out = '<tool_call>\n<function=get_weather>\n<parameter=locations>\n[{"country": "France", "city": "Paris"}]\n</parameter>\n<parameter=temp_units>\ncelsius\n</parameter>\n</function>\n</tool_call>'
parsed_chat = recursive_parse(model_out, qwen3_schema)
self.assertEqual(
parsed_chat,
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_weather",
"arguments": {
"locations": [{"country": "France", "city": "Paris"}],
"temp_units": "celsius",
},
},
}
],
},
)
def test_re_sub_schema(self):
"""Test that a schema doing re substitutions to enable JSON parsing works."""
model_out = '<|channel>thought\nThe user is asking for the current temperature in Paris. I should check the available tools to see if there\'s a function that can provide this information.<channel|><|tool_call>call:get_current_temperature{detail_level:0,location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>'
parsed = recursive_parse(model_out, re_sub_schema)
self.assertEqual(
parsed,
{
"role": "assistant",
"thinking": "The user is asking for the current temperature in Paris. I should check the available tools to see if there's a function that can provide this information.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_current_temperature",
"arguments": {"detail_level": "0", "location": "Paris, France", "unit": "celsius"},
},
}
],
},
)
def test_gemma4_tool_call(self):
model_out = '<|channel>thought\nThe user is asking for the current temperature in Paris. I should check the available tools to see if there\'s a function that can provide this information.<channel|><|tool_call>call:get_current_temperature{detail_level:0,location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>'
parsed = recursive_parse(model_out, gemma4_schema)
self.assertEqual(
parsed,
{
"role": "assistant",
"thinking": "The user is asking for the current temperature in Paris. I should check the available tools to see if there's a function that can provide this information.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "get_current_temperature",
"arguments": {"detail_level": 0, "location": "Paris, France", "unit": "celsius"},
},
}
],
},
)
def test_gemma4_complex_tool_call(self):
model_out = (
"<|channel>thought\nLet me call the tool.<channel|>"
'<|tool_call>call:foo{bool_value:true,list_value:[<|"|>foo<|"|>,<|"|>bar<|"|>],'
'null_value:null,number_value:1,string_value:<|"|>foo<|"|>,'
'struct_value:{foo:<|"|>bar<|"|>}}<tool_call|>'
)
parsed = recursive_parse(model_out, gemma4_schema)
self.assertEqual(
parsed,
{
"role": "assistant",
"thinking": "Let me call the tool.",
"tool_calls": [
{
"type": "function",
"function": {
"name": "foo",
"arguments": {
"bool_value": True,
"list_value": ["foo", "bar"],
"null_value": None,
"number_value": 1,
"string_value": "foo",
"struct_value": {"foo": "bar"},
},
},
}
],
},
)
def test_required_fields_present(self):
"""Test that required fields pass validation when present in the output."""
schema = {
"type": "object",
"required": ["role", "content"],
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<response>(.*?)</response>"},
"thinking": {"type": "string", "x-regex": r"<think>(.*?)</think>"},
},
}
model_out = "<think>Let me think.</think><response>Hello!</response>"
parsed = recursive_parse(model_out, schema)
self.assertEqual(
parsed,
{"role": "assistant", "content": "Hello!", "thinking": "Let me think."},
)
def test_required_field_missing_raises(self):
"""Test that a missing required field raises ValueError with a helpful message."""
schema = {
"type": "object",
"required": ["role", "content"],
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<response>(.*?)</response>"},
"thinking": {"type": "string", "x-regex": r"<think>(.*?)</think>"},
},
}
# This output has thinking but no <response> tags, so content will be missing
model_out = "<think>Let me think about this.</think>Some plain text without response tags"
with self.assertRaises(ValueError) as cm:
recursive_parse(model_out, schema)
self.assertIn("content", str(cm.exception))
self.assertIn("missing", str(cm.exception).lower())
def test_required_not_enforced_when_absent(self):
"""Test that schemas without 'required' still silently omit missing fields."""
schema = {
"type": "object",
"properties": {
"role": {"const": "assistant"},
"content": {"type": "string", "x-regex": r"<response>(.*?)</response>"},
"thinking": {"type": "string", "x-regex": r"<think>(.*?)</think>"},
},
}
# No <response> tags, but content is not required — should succeed
model_out = "<think>Just thinking.</think>"
parsed = recursive_parse(model_out, schema)
self.assertEqual(parsed, {"role": "assistant", "thinking": "Just thinking."})
def test_prefix_items(self):
model_out = "<block>hello</block><block>42</block><block>world</block>"
parsed = recursive_parse(model_out, prefix_items_schema)
self.assertEqual(parsed, ["hello", 42, "world"])
def test_prefix_items_wrong_length_raises(self):
model_out = "<block>hello</block><block>42</block>"
with self.assertRaises(ValueError):
recursive_parse(model_out, prefix_items_schema)
def test_prefix_items_wrong_type_raises(self):
model_out = "<block>hello</block><block>world</block><block>42</block>"
with self.assertRaises(ValueError):
recursive_parse(model_out, prefix_items_schema)
def test_type_any_passthrough(self):
"""Test that type 'any' passes content through without transformation."""
schema = {
"type": "object",
"x-regex": r"<data>(?P<value>.*?)</data>",
"properties": {
"value": {"type": "any"},
},
}
model_out = "<data>some arbitrary content 123</data>"
parsed = recursive_parse(model_out, schema)
self.assertEqual(parsed, {"value": "some arbitrary content 123"})
def test_type_any_in_additional_properties(self):
"""Test that type 'any' works in additionalProperties, matching the docs example."""
schema = {
"type": "object",
"x-parser": "json",
"additionalProperties": {"type": "any"},
}
node_content = '{"location": "San Francisco, CA", "units": "celsius"}'
parsed = recursive_parse(node_content, schema)
self.assertEqual(parsed, {"location": "San Francisco, CA", "units": "celsius"})

View File

@@ -0,0 +1,613 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Literal
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
class JsonSchemaGeneratorTest(unittest.TestCase):
def test_simple_function(self):
def fn(x: int):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_no_arguments(self):
def fn():
"""
Test function
"""
return True
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {"type": "object", "properties": {}},
}
self.assertEqual(schema["function"], expected_schema)
def test_union(self):
def fn(x: int | float):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_optional(self):
def fn(x: int | None):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_default_arg(self):
def fn(x: int = 42):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
}
self.assertEqual(schema["function"], expected_schema)
def test_nested_list(self):
def fn(x: list[list[str | int]]):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "array",
"items": {"type": "array", "items": {"type": ["integer", "string"]}},
"description": "The input",
}
},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiple_arguments(self):
def fn(x: int, y: str):
"""
Test function
Args:
x: The input
y: Also the input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer", "description": "The input"},
"y": {"type": "string", "description": "Also the input"},
},
"required": ["x", "y"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiple_complex_arguments(self):
def fn(x: list[int | float], y: int | str | None = None):
"""
Test function
Args:
x: The input
y: Also the input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
"y": {
"type": ["integer", "string"],
"nullable": True,
"description": "Also the input",
},
},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_missing_docstring(self):
def fn(x: int):
return x
with self.assertRaises(DocstringParsingException):
get_json_schema(fn)
def test_missing_param_docstring(self):
def fn(x: int):
"""
Test function
"""
return x
with self.assertRaises(DocstringParsingException):
get_json_schema(fn)
def test_missing_type_hint(self):
def fn(x):
"""
Test function
Args:
x: The input
"""
return x
with self.assertRaises(TypeHintParsingException):
get_json_schema(fn)
def test_return_value(self):
def fn(x: int) -> int:
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
"return": {"type": "integer"},
}
self.assertEqual(schema["function"], expected_schema)
def test_return_value_docstring(self):
def fn(x: int) -> int:
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
"return": {"type": "integer", "description": "The output"},
}
self.assertEqual(schema["function"], expected_schema)
def test_tuple(self):
def fn(x: tuple[int, str]):
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "array",
"prefixItems": [{"type": "integer"}, {"type": "string"}],
"description": "The input",
}
},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_single_element_tuple_fails(self):
def fn(x: tuple[int]):
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
# Single-element tuples should just be the type itself, or List[type] for variable-length inputs
with self.assertRaises(TypeHintParsingException):
get_json_schema(fn)
def test_ellipsis_type_fails(self):
def fn(x: tuple[int, ...]):
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
# Variable length inputs should be specified with List[type], not Tuple[type, ...]
with self.assertRaises(TypeHintParsingException):
get_json_schema(fn)
def test_enum_extraction(self):
def fn(temperature_format: str):
"""
Test function
Args:
temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
Returns:
The temperature
"""
return -40.0
# Let's see if that gets correctly parsed as an enum
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"temperature_format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature format to use",
}
},
"required": ["temperature_format"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_literal(self):
def fn(
temperature_format: Literal["celsius", "fahrenheit"],
booleanish: Literal[True, False, 0, 1, "y", "n"] = False,
):
"""
Test function
Args:
temperature_format: The temperature format to use
booleanish: A value that can be regarded as boolean
Returns:
The temperature
"""
return -40.0
# Let's see if that gets correctly parsed as an enum
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"temperature_format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature format to use",
},
"booleanish": {
"type": ["boolean", "integer", "string"],
"enum": [True, False, 0, 1, "y", "n"],
"description": "A value that can be regarded as boolean",
},
},
"required": ["temperature_format"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiline_docstring_with_types(self):
def fn(x: int, y: int):
"""
Test function
Args:
x: The first input
y: The second input. This is a longer description
that spans multiple lines with indentation and stuff.
Returns:
God knows what
"""
pass
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer", "description": "The first input"},
"y": {
"type": "integer",
"description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
},
},
"required": ["x", "y"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_return_none(self):
def fn(x: int) -> None:
"""
Test function
Args:
x: The first input
"""
pass
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer", "description": "The first input"},
},
"required": ["x"],
},
"return": {"type": "null"},
}
self.assertEqual(schema["function"], expected_schema)
def test_instance_method(self):
class Tool:
def fn(self, x: int):
"""
Test function
Args:
x: The input
"""
return x
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema) # unbound case
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema) # bound case
def test_static_method(self):
class Tool:
@staticmethod
def fn(x: int):
"""
Test function
Args:
x: The input
"""
return x
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema)
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema)
def test_class_method(self):
class Tool:
@classmethod
def fn(cls, x: int):
"""
Test function
Args:
x: The input
"""
return x
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema)
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema)
def test_everything_all_at_once(self):
def fn(x: str, y: list[str | int] | None, z: tuple[str | int, str] = (42, "hello")) -> tuple[int, str]:
"""
Test function with multiple args, and docstring args that we have to strip out.
Args:
x: The first input. It's got a big multiline
description and also contains
(choices: ["a", "b", "c"])
y: The second input. It's a big list with a single-line description.
z: The third input. It's some kind of tuple with a default arg.
Returns:
The output. The return description is also a big multiline
description that spans multiple lines.
"""
pass
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function with multiple args, and docstring args that we have to strip out.",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "string",
"enum": ["a", "b", "c"],
"description": "The first input. It's got a big multiline description and also contains",
},
"y": {
"type": "array",
"items": {"type": ["integer", "string"]},
"nullable": True,
"description": "The second input. It's a big list with a single-line description.",
},
"z": {
"type": "array",
"prefixItems": [{"type": ["integer", "string"]}, {"type": "string"}],
"description": "The third input. It's some kind of tuple with a default arg.",
},
},
"required": ["x", "y"],
},
"return": {
"type": "array",
"prefixItems": [{"type": "integer"}, {"type": "string"}],
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
},
}
self.assertEqual(schema["function"], expected_schema)

View File

@@ -0,0 +1,399 @@
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
import warnings
from pathlib import Path
import httpx
from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
from transformers.configuration_utils import PreTrainedConfig
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_torch
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
config_common_kwargs = {
"return_dict": False,
"output_hidden_states": True,
"output_attentions": True,
"dtype": "float16",
"chunk_size_feed_forward": 5,
"architectures": ["BertModel"],
"id2label": {0: "label"},
"label2id": {"label": "0"},
"problem_type": "regression",
}
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._token = TOKEN
def test_push_to_hub(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub(tmp_repo.repo_id, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_via_save_pretrained(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_in_organization(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub(tmp_repo.repo_id, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_in_organization_via_save_pretrained(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
def test_push_to_hub_dynamic_config(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
config.push_to_hub(tmp_repo.repo_id, token=self._token)
# This has added the proper auto_map field to the config
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
new_config = AutoConfig.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42)
class ConfigTestUtils(unittest.TestCase):
def test_config_from_string(self):
c = GPT2Config()
# attempt to modify each of int/float/bool/str config records and verify they were updated
n_embd = c.n_embd + 1 # int
resid_pdrop = c.resid_pdrop + 1.0 # float
scale_attn_weights = not c.scale_attn_weights # bool
summary_type = c.summary_type + "foo" # str
c.update_from_string(
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
)
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
def test_config_common_kwargs_is_complete(self):
base_config = PreTrainedConfig()
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
self.assertListEqual(
missing_keys,
[
"transformers_version",
"is_encoder_decoder",
"_name_or_path",
"_commit_hash",
"_output_attentions",
"_attn_implementation_internal",
"_experts_implementation_internal",
],
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(
"The following keys are set with the default values in"
" `test_configuration_common.config_common_kwargs` pick another value for them:"
f" {', '.join(keys_with_defaults)}."
)
def test_nested_config_load_from_dict(self):
config = AutoConfig.from_pretrained(
"hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
)
self.assertNotIsInstance(config.text_config, dict)
self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")
def test_from_pretrained_subfolder(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
self.assertIsNotNone(config)
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
self.assertIsNotNone(config)
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
"failed", request=mock.Mock(), response=mock.Mock()
)
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
def test_local_versioning(self):
configuration = AutoConfig.from_pretrained("google-bert/bert-base-cased")
configuration.configuration_files = ["config.4.0.0.json"]
with tempfile.TemporaryDirectory() as tmp_dir:
configuration.save_pretrained(tmp_dir)
configuration.hidden_size = 2
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
# This should pick the new configuration file as the version of Transformers is > 4.0.0
new_configuration = AutoConfig.from_pretrained(tmp_dir)
self.assertEqual(new_configuration.hidden_size, 2)
# Will need to be adjusted if we reach v42 and this test is still here.
# Should pick the old configuration file as the version of Transformers is < 4.42.0
configuration.configuration_files = ["config.42.0.0.json"]
configuration.hidden_size = 768
configuration.save_pretrained(tmp_dir)
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
new_configuration = AutoConfig.from_pretrained(tmp_dir)
self.assertEqual(new_configuration.hidden_size, 768)
def test_repo_versioning_before(self):
# This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
repo = "hf-internal-testing/test-two-configs"
import transformers as new_transformers
# Matt: Use a context manager to ensure everything is correctly reverted and we
# don't leak state between tests
with mock.patch.object(new_transformers.configuration_utils, "__version__", "v4.0.0"):
new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
repo, return_unused_kwargs=True
)
self.assertEqual(new_configuration.hidden_size, 2)
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
self.assertDictEqual(kwargs, {})
# Testing an older version by monkey-patching the version in the module it's used.
import transformers as old_transformers
with mock.patch.object(old_transformers.configuration_utils, "__version__", "v3.0.0"):
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
self.assertEqual(old_configuration.hidden_size, 768)
def test_saving_config_with_custom_generation_kwargs_raises_error(self):
config = BertConfig()
config.min_length = 3 # `min_length = 3` is a non-default generation kwarg
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(ValueError):
config.save_pretrained(tmp_dir)
def test_get_generation_parameters(self):
config = BertConfig()
self.assertFalse(len(config._get_generation_parameters()) > 0)
config.min_length = 3
self.assertTrue(len(config._get_generation_parameters()) > 0)
config.min_length = 0
self.assertTrue(len(config._get_generation_parameters()) > 0)
def test_loading_config_do_not_raise_future_warnings(self):
"""Regression test for https://github.com/huggingface/transformers/issues/31002."""
# Loading config should not raise a FutureWarning. It was the case before.
with warnings.catch_warnings():
warnings.simplefilter("error")
PreTrainedConfig.from_pretrained("bert-base-uncased")
def test_get_text_config(self):
"""Tests the `get_text_config` method."""
# 1. model with only text input -> returns the original config instance
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
self.assertEqual(config.get_text_config(), config)
self.assertEqual(config.get_text_config(decoder=True), config)
# 2. composite model (VLM) -> returns the text component
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
self.assertEqual(config.get_text_config(), config.text_config)
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
# 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
config = Florence2Config()
self.assertEqual(config.get_text_config(), config.text_config)
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
# 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
self.assertEqual(config.get_text_config(), config)
# both encoder_layers and decoder_layers exist
self.assertTrue(getattr(config, "encoder_ffn_dim", None) is not None)
self.assertTrue(getattr(config, "decoder_ffn_dim", None) is not None)
decoder_config = config.get_text_config(decoder=True)
self.assertNotEqual(decoder_config, config)
self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
encoder_config = config.get_text_config(encoder=True)
self.assertNotEqual(encoder_config, config)
self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
@require_torch
def test_bc_torch_dtype(self):
import torch
config = PreTrainedConfig(dtype="bfloat16")
self.assertEqual(config.dtype, torch.bfloat16)
config = PreTrainedConfig(torch_dtype="bfloat16")
self.assertEqual(config.dtype, torch.bfloat16)
# Check that if we pass both, `dtype` is used
config = PreTrainedConfig(dtype="bfloat16", torch_dtype="float32")
self.assertEqual(config.dtype, torch.bfloat16)
with tempfile.TemporaryDirectory() as tmpdirname:
config.save_pretrained(tmpdirname)
config = PreTrainedConfig.from_pretrained(tmpdirname)
self.assertEqual(config.dtype, torch.bfloat16)
config = PreTrainedConfig.from_pretrained(tmpdirname, dtype="float32")
self.assertEqual(config.dtype, "float32")
config = PreTrainedConfig.from_pretrained(tmpdirname, torch_dtype="float32")
self.assertEqual(config.dtype, "float32")
def test_unserializable_json_is_encoded(self):
class NewConfig(PreTrainedConfig):
def __init__(
self,
inf_positive: float = float("inf"),
inf_negative: float = float("-inf"),
nan: float = float("nan"),
**kwargs,
):
self.inf_positive = inf_positive
self.inf_negative = inf_negative
self.nan = nan
super().__init__(**kwargs)
new_config = NewConfig()
# All floats should remain as floats when being accessed in the config
self.assertIsInstance(new_config.inf_positive, float)
self.assertIsInstance(new_config.inf_negative, float)
self.assertIsInstance(new_config.nan, float)
with tempfile.TemporaryDirectory() as tmpdirname:
new_config.save_pretrained(tmpdirname)
config_file = Path(tmpdirname) / "config.json"
config_contents = json.loads(config_file.read_text())
new_config_instance = NewConfig.from_pretrained(tmpdirname)
# In the serialized JSON file, the non-JSON compatible floats should be updated
self.assertDictEqual(config_contents["inf_positive"], {"__float__": "Infinity"})
self.assertDictEqual(config_contents["inf_negative"], {"__float__": "-Infinity"})
self.assertDictEqual(config_contents["nan"], {"__float__": "NaN"})
with tempfile.TemporaryDirectory() as tmpdirname:
new_config.save_pretrained(tmpdirname)
# When reloading the config, it should have correct float values
self.assertIsInstance(new_config_instance.inf_positive, float)
self.assertIsInstance(new_config_instance.inf_negative, float)
self.assertIsInstance(new_config_instance.nan, float)
class ConfigSubclassKwOnlyTest(unittest.TestCase):
"""Test that config subclasses with non-default fields following parent default fields
no longer raise TypeError (fixed by kw_only=True in __init_subclass__). Regression
test for https://github.com/huggingface/transformers/issues/XXXX."""
def test_subclass_non_default_field_after_default(self):
"""A config subclass adding a required field after parent defaults must not raise."""
class MyConfig(PreTrainedConfig):
pooling: str # no default — would fail under Python dataclass ordering rules
# Should construct without TypeError
cfg = MyConfig(pooling="mean")
self.assertEqual(cfg.pooling, "mean")
def test_subclass_multiple_non_default_fields(self):
"""Multiple non-default fields in the subclass should all work."""
class EmbedConfig(PreTrainedConfig):
dim: int
pooling: str
cfg = EmbedConfig(dim=128, pooling="cls")
self.assertEqual(cfg.dim, 128)
self.assertEqual(cfg.pooling, "cls")
def test_inherited_defaults_still_work(self):
"""Inherited fields with defaults must still be accessible."""
from transformers import BertConfig
class BertWithPooling(BertConfig):
pooling: str
cfg = BertWithPooling(pooling="mean", hidden_size=256)
self.assertEqual(cfg.pooling, "mean")
self.assertEqual(cfg.hidden_size, 256)

View File

@@ -0,0 +1,39 @@
import unittest
import warnings
from dataclasses import dataclass
from transformers.convert_slow_tokenizer import SpmConverter
from transformers.testing_utils import get_tests_dir
@dataclass
class FakeOriginalTokenizer:
vocab_file: str
class ConvertSlowTokenizerTest(unittest.TestCase):
def test_spm_converter_bytefallback_warning(self):
spm_model_file_without_bytefallback = get_tests_dir("fixtures/test_sentencepiece.model")
spm_model_file_with_bytefallback = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.model")
original_tokenizer_without_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_without_bytefallback)
with warnings.catch_warnings(record=True) as w:
_ = SpmConverter(original_tokenizer_without_bytefallback)
# We are looking for if there is any `UserWarning` with
# `The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers.`
w = [x for x in w if x.category.__name__ != "DeprecationWarning"]
self.assertEqual(len(w), 0)
original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)
with warnings.catch_warnings(record=True) as w:
_ = SpmConverter(original_tokenizer_with_bytefallback)
w = [x for x in w if x.category.__name__ != "DeprecationWarning"]
self.assertEqual(len(w), 1)
self.assertIn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers.",
str(w[0].message),
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,197 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import warnings
import pytest
from parameterized import parameterized
from transformers import __version__, is_torch_available
from transformers.testing_utils import require_torch_accelerator, torch_device
from transformers.utils.deprecation import deprecate_kwarg
if is_torch_available():
import torch
INFINITE_VERSION = "9999.0.0"
class DeprecationDecoratorTester(unittest.TestCase):
def test_rename_kwarg(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
def dummy_function(new_name=None, other_name=None):
return new_name, other_name
# Test keyword argument is renamed
value, other_value = dummy_function(deprecated_name="old_value")
self.assertEqual(value, "old_value")
self.assertIsNone(other_value)
# Test deprecated keyword argument not passed
value, other_value = dummy_function(new_name="new_value")
self.assertEqual(value, "new_value")
self.assertIsNone(other_value)
# Test other keyword argument
value, other_value = dummy_function(other_name="other_value")
self.assertIsNone(value)
self.assertEqual(other_value, "other_value")
# Test deprecated and new args are passed, the new one should be returned
value, other_value = dummy_function(deprecated_name="old_value", new_name="new_value")
self.assertEqual(value, "new_value")
self.assertIsNone(other_value)
def test_rename_multiple_kwargs(self):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
@deprecate_kwarg("deprecated_name1", new_name="new_name1", version=INFINITE_VERSION)
@deprecate_kwarg("deprecated_name2", new_name="new_name2", version=INFINITE_VERSION)
def dummy_function(new_name1=None, new_name2=None, other_name=None):
return new_name1, new_name2, other_name
# Test keyword argument is renamed
value1, value2, other_value = dummy_function(deprecated_name1="old_value1", deprecated_name2="old_value2")
self.assertEqual(value1, "old_value1")
self.assertEqual(value2, "old_value2")
self.assertIsNone(other_value)
# Test deprecated keyword argument is not passed
value1, value2, other_value = dummy_function(new_name1="new_value1", new_name2="new_value2")
self.assertEqual(value1, "new_value1")
self.assertEqual(value2, "new_value2")
self.assertIsNone(other_value)
# Test other keyword argument is passed and correctly returned
value1, value2, other_value = dummy_function(other_name="other_value")
self.assertIsNone(value1)
self.assertIsNone(value2)
self.assertEqual(other_value, "other_value")
def test_warnings(self):
# Test warning is raised for future version
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
def dummy_function(new_name=None, other_name=None):
return new_name, other_name
with self.assertWarns(FutureWarning):
dummy_function(deprecated_name="old_value")
# Test warning is not raised for past version, but arg is still renamed
@deprecate_kwarg("deprecated_name", new_name="new_name", version="0.0.0")
def dummy_function(new_name=None, other_name=None):
return new_name, other_name
with warnings.catch_warnings(record=True) as raised_warnings:
warnings.simplefilter("always")
value, other_value = dummy_function(deprecated_name="old_value")
self.assertEqual(value, "old_value")
self.assertIsNone(other_value)
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
# Test warning is raised for future version if warn_if_greater_or_equal_version is set
@deprecate_kwarg("deprecated_name", version="0.0.0", warn_if_greater_or_equal_version=True)
def dummy_function(deprecated_name=None):
return deprecated_name
with self.assertWarns(FutureWarning):
value = dummy_function(deprecated_name="deprecated_value")
self.assertEqual(value, "deprecated_value")
# Test arg is not renamed if new_name is not specified, but warning is raised
@deprecate_kwarg("deprecated_name", version=INFINITE_VERSION)
def dummy_function(deprecated_name=None):
return deprecated_name
with self.assertWarns(FutureWarning):
value = dummy_function(deprecated_name="deprecated_value")
self.assertEqual(value, "deprecated_value")
def test_raises(self):
# Test if deprecated name and new name are both passed and raise_if_both_names is set -> raise error
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION, raise_if_both_names=True)
def dummy_function(new_name=None, other_name=None):
return new_name, other_name
with self.assertRaises(ValueError):
dummy_function(deprecated_name="old_value", new_name="new_value")
# Test for current version == deprecation version
@deprecate_kwarg("deprecated_name", version=__version__, raise_if_greater_or_equal_version=True)
def dummy_function(deprecated_name=None):
return deprecated_name
with self.assertRaises(ValueError):
dummy_function(deprecated_name="old_value")
# Test for current version > deprecation version
@deprecate_kwarg("deprecated_name", version="0.0.0", raise_if_greater_or_equal_version=True)
def dummy_function(deprecated_name=None):
return deprecated_name
with self.assertRaises(ValueError):
dummy_function(deprecated_name="old_value")
def test_additional_message(self):
# Test additional message is added to the warning
@deprecate_kwarg("deprecated_name", version=INFINITE_VERSION, additional_message="Additional message")
def dummy_function(deprecated_name=None):
return deprecated_name
with warnings.catch_warnings(record=True) as raised_warnings:
warnings.simplefilter("always")
dummy_function(deprecated_name="old_value")
self.assertTrue("Additional message" in str(raised_warnings[0].message))
@parameterized.expand(["0.0.0", __version__, INFINITE_VERSION])
def test_warning_for_both_names(self, version):
# We should raise warning if both names are passed for any specified version
@deprecate_kwarg("deprecated_name", new_name="new_name", version=version)
def dummy_function(new_name=None, **kwargs):
return new_name
with self.assertWarns(FutureWarning):
result = dummy_function(deprecated_name="old_value", new_name="new_value")
self.assertEqual(result, "new_value")
@pytest.mark.torch_compile_test
@require_torch_accelerator
def test_compile_safe(self):
@deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION)
def dummy_function(new_factor=None, **kwargs):
return new_factor * torch.ones(1, device=torch_device)
compiled_function = torch.compile(dummy_function, fullgraph=True)
# Check that we can correctly call the compiled function with the old name, without raising errors
out = compiled_function(deprecated_factor=2)
self.assertEqual(out.item(), 2)
# Check that we can correctly call the compiled function with the new name, without raising errors
out = compiled_function(new_factor=2)
self.assertEqual(out.item(), 2)
# Check that we can correctly call the compiled function with both names, without raising errors
out = compiled_function(new_factor=2, deprecated_factor=10)
self.assertEqual(out.item(), 2)

View File

@@ -0,0 +1,110 @@
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import logging
import os
import unittest
from pathlib import Path
import transformers
from transformers.testing_utils import require_torch, slow
logger = logging.getLogger()
@unittest.skip(reason="Temporarily disable the doc tests.")
@require_torch
@slow
class TestCodeExamples(unittest.TestCase):
def analyze_directory(
self,
directory: Path,
identifier: str | None = None,
ignore_files: list[str] | None = None,
n_identifier: str | list[str] | None = None,
only_modules: bool = True,
):
"""
Runs through the specific directory, looking for the files identified with `identifier`. Executes
the doctests in those files
Args:
directory (`Path`): Directory containing the files
identifier (`str`): Will parse files containing this
ignore_files (`List[str]`): List of files to skip
n_identifier (`str` or `List[str]`): Will not parse files containing this/these identifiers.
only_modules (`bool`): Whether to only analyze modules
"""
files = [file for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file))]
if identifier is not None:
files = [file for file in files if identifier in file]
if n_identifier is not None:
if isinstance(n_identifier, list):
for n_ in n_identifier:
files = [file for file in files if n_ not in file]
else:
files = [file for file in files if n_identifier not in file]
ignore_files = ignore_files or []
ignore_files.append("__init__.py")
files = [file for file in files if file not in ignore_files]
for file in files:
# Open all files
print("Testing", file)
if only_modules:
module_identifier = file.split(".")[0]
try:
module_identifier = getattr(transformers, module_identifier)
suite = doctest.DocTestSuite(module_identifier)
result = unittest.TextTestRunner().run(suite)
self.assertIs(len(result.failures), 0)
except AttributeError:
logger.info(f"{module_identifier} is not a module.")
else:
result = doctest.testfile(str(".." / directory / file), optionflags=doctest.ELLIPSIS)
self.assertIs(result.failed, 0)
def test_modeling_examples(self):
transformers_directory = Path("src/transformers")
files = "modeling"
ignore_files = [
"modeling_ctrl.py",
"modeling_tf_ctrl.py",
]
self.analyze_directory(transformers_directory, identifier=files, ignore_files=ignore_files)
def test_tokenization_examples(self):
transformers_directory = Path("src/transformers")
files = "tokenization"
self.analyze_directory(transformers_directory, identifier=files)
def test_configuration_examples(self):
transformers_directory = Path("src/transformers")
files = "configuration"
self.analyze_directory(transformers_directory, identifier=files)
def test_remaining_examples(self):
transformers_directory = Path("src/transformers")
n_identifiers = ["configuration", "modeling", "tokenization"]
self.analyze_directory(transformers_directory, n_identifier=n_identifiers)
def test_doc_sources(self):
doc_source_directory = Path("docs/source")
ignore_files = ["favicon.ico"]
self.analyze_directory(doc_source_directory, ignore_files=ignore_files, only_modules=False)

View File

@@ -0,0 +1,239 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
import pytest
from transformers import dynamic_module_utils
from transformers.dynamic_module_utils import get_cached_module_file, get_imports
TOP_LEVEL_IMPORT = """
import os
"""
IMPORT_IN_FUNCTION = """
def foo():
import os
return False
"""
DEEPLY_NESTED_IMPORT = """
def foo():
def bar():
if True:
import os
return False
return bar()
"""
TOP_LEVEL_TRY_IMPORT = """
import os
try:
import bar
except ImportError:
raise ValueError()
"""
TRY_IMPORT_IN_FUNCTION = """
import os
def foo():
try:
import bar
except ImportError:
raise ValueError()
"""
MULTIPLE_EXCEPTS_IMPORT = """
import os
try:
import bar
except (ImportError, AttributeError):
raise ValueError()
"""
EXCEPT_AS_IMPORT = """
import os
try:
import bar
except ImportError as e:
raise ValueError()
"""
GENERIC_EXCEPT_IMPORT = """
import os
try:
import bar
except:
raise ValueError()
"""
MULTILINE_TRY_IMPORT = """
import os
try:
import bar
import baz
except ImportError:
raise ValueError()
"""
MULTILINE_BOTH_IMPORT = """
import os
try:
import bar
import baz
except ImportError:
x = 1
raise ValueError()
"""
CASES = [
TOP_LEVEL_IMPORT,
IMPORT_IN_FUNCTION,
DEEPLY_NESTED_IMPORT,
TOP_LEVEL_TRY_IMPORT,
GENERIC_EXCEPT_IMPORT,
MULTILINE_TRY_IMPORT,
MULTILINE_BOTH_IMPORT,
MULTIPLE_EXCEPTS_IMPORT,
EXCEPT_AS_IMPORT,
TRY_IMPORT_IN_FUNCTION,
]
@pytest.mark.parametrize("case", CASES)
def test_import_parsing(tmp_path, case):
tmp_file_path = os.path.join(tmp_path, "test_file.py")
with open(tmp_file_path, "w") as _tmp_file:
_tmp_file.write(case)
parsed_imports = get_imports(tmp_file_path)
assert parsed_imports == ["os"]
def _create_local_module(module_dir: Path, module_code: str, helper_code: str | None = None):
module_dir.mkdir(parents=True, exist_ok=True)
(module_dir / "custom_model.py").write_text(module_code, encoding="utf-8")
if helper_code is not None:
(module_dir / "helper.py").write_text(helper_code, encoding="utf-8")
def test_get_cached_module_file_local_cache_key_uses_basename_and_content_hash(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
model_dir_a = tmp_path / "pretrained_a" / "subdir"
model_dir_b = tmp_path / "pretrained_b" / "subdir"
model_dir_c = tmp_path / "pretrained_c" / "subdir"
_create_local_module(model_dir_a, 'MAGIC = "A"\n')
_create_local_module(model_dir_b, 'MAGIC = "B"\n')
_create_local_module(model_dir_c, 'MAGIC = "A"\n')
cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py")
cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py")
cached_module_c = get_cached_module_file(str(model_dir_c), "custom_model.py")
cached_module_path_a = Path(cached_module_a)
assert cached_module_path_a.parent.parent.name == "subdir"
assert len(cached_module_path_a.parent.name) == 16
assert cached_module_a != cached_module_b
assert cached_module_a == cached_module_c
def test_get_cached_module_file_local_cache_key_includes_relative_import_sources(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
model_dir_a = tmp_path / "pretrained_a" / "subdir"
model_dir_b = tmp_path / "pretrained_b" / "subdir"
module_code = "from .helper import MAGIC\nVALUE = MAGIC\n"
_create_local_module(model_dir_a, module_code, 'MAGIC = "A"\n')
_create_local_module(model_dir_b, module_code, 'MAGIC = "B"\n')
cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py")
cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py")
cached_helper_a = modules_cache / Path(cached_module_a).parent / "helper.py"
cached_helper_b = modules_cache / Path(cached_module_b).parent / "helper.py"
assert cached_module_a != cached_module_b
assert cached_helper_a.read_text(encoding="utf-8") == 'MAGIC = "A"\n'
assert cached_helper_b.read_text(encoding="utf-8") == 'MAGIC = "B"\n'
def test_get_cached_module_file_local_copies_transitive_relative_imports(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
model_dir = tmp_path / "pretrained" / "subdir"
model_dir.mkdir(parents=True, exist_ok=True)
# A → B → C: only A is the entry point; C is a transitive dep that must still be copied
(model_dir / "custom_model.py").write_text("from .helper import VALUE\n", encoding="utf-8")
(model_dir / "helper.py").write_text("from .base import BASE\nVALUE = BASE\n", encoding="utf-8")
(model_dir / "base.py").write_text('BASE = "transitive"\n', encoding="utf-8")
cached_module = get_cached_module_file(str(model_dir), "custom_model.py")
cache_dir = modules_cache / Path(cached_module).parent
assert (cache_dir / "helper.py").exists(), "direct import must be copied"
assert (cache_dir / "base.py").exists(), "transitive import must be copied"
def test_get_cached_module_file_local_cache_key_includes_transitive_import_sources(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
for model_dir, base_val in [
(tmp_path / "pretrained_a" / "subdir", '"X"'),
(tmp_path / "pretrained_b" / "subdir", '"Y"'),
]:
model_dir.mkdir(parents=True, exist_ok=True)
(model_dir / "custom_model.py").write_text("from .helper import VALUE\n", encoding="utf-8")
(model_dir / "helper.py").write_text("from .base import BASE\nVALUE = BASE\n", encoding="utf-8")
(model_dir / "base.py").write_text(f"BASE = {base_val}\n", encoding="utf-8")
cached_a = get_cached_module_file(str(tmp_path / "pretrained_a" / "subdir"), "custom_model.py")
cached_b = get_cached_module_file(str(tmp_path / "pretrained_b" / "subdir"), "custom_model.py")
# Different content in transitive dep → different hash → different cache dirs
assert cached_a != cached_b
def test_get_cached_module_file_local_cache_key_keeps_hash_stable_with_different_basenames(monkeypatch, tmp_path):
modules_cache = tmp_path / "hf_modules_cache"
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
model_dir_a = tmp_path / "pretrained_a" / "alpha_subdir"
model_dir_b = tmp_path / "pretrained_b" / "beta_subdir"
_create_local_module(model_dir_a, 'MAGIC = "A"\n')
_create_local_module(model_dir_b, 'MAGIC = "A"\n')
cached_module_a = Path(get_cached_module_file(str(model_dir_a), "custom_model.py"))
cached_module_b = Path(get_cached_module_file(str(model_dir_b), "custom_model.py"))
assert cached_module_a.parent.parent.name == "alpha_subdir"
assert cached_module_b.parent.parent.name == "beta_subdir"
assert cached_module_a.parent.name == cached_module_b.parent.name

View File

@@ -0,0 +1,38 @@
import unittest
from transformers.testing_utils import Expectations
class ExpectationsTest(unittest.TestCase):
def test_expectations(self):
# We use the expectations below to make sure the right expectations are found for the right devices.
# Each value is just a unique ID.
expectations = Expectations(
{
(None, None): 1,
("cuda", 8): 2,
("cuda", 7): 3,
("rocm", 8): 4,
("rocm", None): 5,
("cpu", None): 6,
("xpu", 3): 7,
}
)
def check(expected_id, device_prop):
found_id = expectations.find_expectation(device_prop)
assert found_id == expected_id, f"Expected {expected_id} for {device_prop}, found {found_id}"
# npu has no matches so should find default expectation
check(1, ("npu", None, None))
check(7, ("xpu", 3, None))
check(2, ("cuda", 8, None))
check(3, ("cuda", 7, None))
check(4, ("rocm", 9, None))
check(4, ("rocm", None, None))
check(2, ("cuda", 2, None))
# We also test that if there is no default excpectation and no match is found, a ValueError is raised.
expectations = Expectations({("cuda", 8): 1})
with self.assertRaises(ValueError):
expectations.find_expectation(("xpu", None))

View File

@@ -0,0 +1,301 @@
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path
import httpx
import numpy as np
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test, require_torch
from transformers.utils import is_torch_available
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
if is_torch_available():
import torch
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
class BatchFeatureTester(unittest.TestCase):
"""Tests for the BatchFeature class and tensor conversion."""
def test_batch_feature_basic_access_and_no_conversion(self):
"""Test basic dict/attribute access and no conversion when tensor_type=None."""
data = {"input_values": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}
batch = BatchFeature(data)
# Dict-style and attribute-style access
self.assertEqual(batch["input_values"], [[1, 2, 3], [4, 5, 6]])
self.assertEqual(batch.labels, [0, 1])
# No conversion without tensor_type
self.assertIsInstance(batch["input_values"], list)
@require_torch
def test_batch_feature_numpy_conversion(self):
"""Test conversion to numpy arrays from lists and existing numpy arrays."""
# From lists
batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="np")
self.assertIsInstance(batch["input_values"], np.ndarray)
self.assertEqual(batch["input_values"].shape, (2, 3))
# From numpy arrays (should remain numpy)
numpy_data = np.array([[1, 2, 3], [4, 5, 6]])
batch_arrays = BatchFeature({"input_values": numpy_data}, tensor_type="np")
np.testing.assert_array_equal(batch_arrays["input_values"], numpy_data)
# From list of numpy arrays with same shape should stack
numpy_data = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9], [10, 11, 12]])]
batch_stacked = BatchFeature({"input_values": numpy_data}, tensor_type="np")
np.testing.assert_array_equal(
batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
)
# from tensor
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="np")
np.testing.assert_array_equal(batch_tensor["input_values"], tensor.numpy())
# from list of tensors with same shape should stack
tensors = [torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([[7, 8, 9], [10, 11, 12]])]
batch_stacked = BatchFeature({"input_values": tensors}, tensor_type="np")
self.assertIsInstance(batch_stacked["input_values"], np.ndarray)
np.testing.assert_array_equal(
batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
)
@require_torch
def test_batch_feature_pytorch_conversion(self):
"""Test conversion to PyTorch tensors from various input types."""
# From lists
batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="pt")
self.assertIsInstance(batch["input_values"], torch.Tensor)
self.assertEqual(batch["input_values"].shape, (2, 3))
# from tensor (should be returned as-is)
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="pt")
torch.testing.assert_close(batch_tensor["input_values"], tensor)
# From numpy arrays
batch_numpy = BatchFeature({"input_values": np.array([[1, 2]])}, tensor_type="pt")
self.assertIsInstance(batch_numpy["input_values"], torch.Tensor)
# List of same-shape tensors should stack
tensors = [torch.randn(3, 10, 10) for _ in range(3)]
batch_stacked = BatchFeature({"pixel_values": tensors}, tensor_type="pt")
self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
# List of same-shape numpy arrays should stack
numpy_arrays = [np.random.randn(3, 10, 10) for _ in range(3)]
batch_stacked = BatchFeature({"pixel_values": numpy_arrays}, tensor_type="pt")
self.assertIsInstance(batch_stacked["pixel_values"], torch.Tensor)
self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
@require_torch
def test_batch_feature_error_handling(self):
"""Test clear error messages for common conversion failures."""
# Ragged tensors (different shapes)
data_ragged = {"values": [torch.randn(3, 224, 224), torch.randn(3, 448, 448)]}
with self.assertRaises(ValueError) as context:
BatchFeature(data_ragged, tensor_type="pt")
error_msg = str(context.exception)
self.assertIn("stack expects each tensor to be equal size", error_msg.lower())
self.assertIn("return_tensors=None", error_msg)
# Ragged numpy arrays (different shapes)
data_ragged = {"values": [np.random.randn(3, 224, 224), np.random.randn(3, 448, 448)]}
with self.assertRaises(ValueError) as context:
BatchFeature(data_ragged, tensor_type="np")
error_msg = str(context.exception)
self.assertIn("inhomogeneous", error_msg.lower())
self.assertIn("return_tensors=None", error_msg)
@require_torch
def test_batch_feature_auto_skip_non_array_like(self):
"""Test that non-array-like values are automatically skipped during tensor conversion."""
data = {
"values": [[1, 2]],
"metadata": {"key": "val"},
"image_path": "/path/to/image.jpg",
"tags": ["tag1", "tag2"],
"extra": None,
}
batch = BatchFeature(data, tensor_type="pt")
# values should be converted
self.assertIsInstance(batch["values"], torch.Tensor)
# Non-array-like values should remain unchanged
self.assertIsInstance(batch["metadata"], dict)
self.assertEqual(batch["metadata"], {"key": "val"})
self.assertIsInstance(batch["image_path"], str)
self.assertIsInstance(batch["tags"], list)
self.assertEqual(batch["tags"], ["tag1", "tag2"])
self.assertIsNone(batch["extra"])
@require_torch
def test_batch_feature_skip_tensor_conversion(self):
"""Test skip_tensor_conversion parameter for metadata fields."""
import torch
data = {"pixel_values": [[1, 2, 3]], "num_crops": [1, 2], "sizes": [(224, 224)]}
batch = BatchFeature(data, tensor_type="pt", skip_tensor_conversion=["num_crops", "sizes"])
# pixel_values should be converted
self.assertIsInstance(batch["pixel_values"], torch.Tensor)
# num_crops and sizes should remain as lists
self.assertIsInstance(batch["num_crops"], list)
self.assertIsInstance(batch["sizes"], list)
@require_torch
def test_batch_feature_convert_to_tensors_method(self):
"""Test convert_to_tensors method can be called after initialization."""
import torch
data = {"input_values": [[1, 2, 3]], "metadata": [1, 2]}
batch = BatchFeature(data) # No conversion initially
self.assertIsInstance(batch["input_values"], list)
# Convert with skip parameter
batch.convert_to_tensors(tensor_type="pt", skip_tensor_conversion=["metadata"])
self.assertIsInstance(batch["input_values"], torch.Tensor)
self.assertIsInstance(batch["metadata"], list)
@require_torch
def test_batch_feature_to_with_nested_tensors(self):
"""Test .to() method works recursively with nested lists and tuples of tensors."""
batch = BatchFeature(
{
"list_tensors": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
"nested_list": [[torch.tensor([1.0]), torch.tensor([2.0])]],
"tuple_tensors": (torch.tensor([5.0]), torch.tensor([6.0])),
}
)
batch_fp16 = batch.to(torch.float16)
# Check lists of tensors are converted
self.assertIsInstance(batch_fp16["list_tensors"], list)
self.assertEqual(batch_fp16["list_tensors"][0].dtype, torch.float16)
self.assertEqual(batch_fp16["list_tensors"][1].dtype, torch.float16)
# Check nested lists are converted
self.assertIsInstance(batch_fp16["nested_list"][0], list)
self.assertEqual(batch_fp16["nested_list"][0][0].dtype, torch.float16)
# Check tuples are preserved and converted
self.assertIsInstance(batch_fp16["tuple_tensors"], tuple)
self.assertEqual(batch_fp16["tuple_tensors"][0].dtype, torch.float16)
class FeatureExtractorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
"failed", request=mock.Mock(), response=mock.Mock()
)
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This check we did call the fake head request
mock_head.assert_called()
@is_staging_test
class FeatureExtractorPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._token = TOKEN
def test_push_to_hub(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_via_save_pretrained(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization_via_save_pretrained(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_dynamic_feature_extractor(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
# This has added the proper auto_map field to the config
self.assertDictEqual(
feature_extractor.auto_map,
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
)
new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")

View File

@@ -0,0 +1,102 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib
import io
import unittest
import transformers
# Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_torch
from transformers.utils import ContextManagers, find_labels, is_torch_available
if is_torch_available():
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER
# An actual model hosted on huggingface.co
REVISION_ID_DEFAULT = "main"
# Default branch name
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
# One particular commit (not the top of `main`)
REVISION_ID_INVALID = "aaaaaaa"
# This commit does not exist, so we should 404.
PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
# Sha-1 of config.json on the top of `main`, for checking purposes
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
# Dummy contexts to test `ContextManagers`
@contextlib.contextmanager
def context_en():
print("Welcome!")
yield
print("Bye!")
@contextlib.contextmanager
def context_fr():
print("Bonjour!")
yield
print("Au revoir!")
class TestImportMechanisms(unittest.TestCase):
def test_module_spec_available(self):
# If the spec is missing, importlib would not be able to import the module dynamically.
assert transformers.__spec__ is not None
assert importlib.util.find_spec("transformers") is not None
class GenericUtilTests(unittest.TestCase):
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_context_managers_no_context(self, mock_stdout):
with ContextManagers([]):
print("Transformers are awesome!")
# The print statement adds a new line at the end of the output
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_context_managers_one_context(self, mock_stdout):
with ContextManagers([context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_context_managers_two_context(self, mock_stdout):
with ContextManagers([context_fr(), context_en()]):
print("Transformers are awesome!")
# The output should be wrapped with an English and French welcome and goodbye
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
@require_torch
def test_find_labels_pt(self):
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
# find_labels works regardless of the class name (it detects the framework through inheritance)
class DummyModel(BertForSequenceClassification):
pass
self.assertEqual(find_labels(DummyModel), ["labels"])

View File

@@ -0,0 +1,195 @@
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import tempfile
import types
import unittest
from copy import deepcopy
from unittest.mock import patch
import torch.nn as nn
import transformers.conversion_mapping as conversion_mapping
import transformers.fusion_mapping as fusion_mapping
import transformers.monkey_patching as monkey_patching
from transformers import PretrainedConfig
from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
from transformers.core_model_loading import Conv3dToLinear, WeightConverter
from transformers.fusion_mapping import register_fusion_patches
from transformers.modeling_utils import PreTrainedModel
from transformers.monkey_patching import apply_patches, get_patch_mapping
DUMMY_TRANSFORMERS_MODULE_NAME = "transformers.test_fusion_mapping_dummy"
# `apply_patches()` scans `sys.modules` and only rewrites class attributes exposed
# from `transformers.*` modules, so this dummy class must be reachable through a
# fake `transformers` module instead of only through a local symbol.
DUMMY_TRANSFORMERS_MODULE = types.ModuleType(DUMMY_TRANSFORMERS_MODULE_NAME)
sys.modules[DUMMY_TRANSFORMERS_MODULE_NAME] = DUMMY_TRANSFORMERS_MODULE
class DummyVisionConfig(PretrainedConfig):
model_type = "dummy_fusion_vision"
base_config_key = "vision_config"
def __init__(
self,
in_channels=3,
patch_size=2,
temporal_patch_size=2,
patch_embed_stride=(2, 2, 2),
**kwargs,
):
super().__init__(**kwargs)
self.in_channels = in_channels
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.patch_embed_stride = patch_embed_stride
class DummyFusionConfig(PretrainedConfig):
model_type = "dummy_fusion"
sub_configs = {"vision_config": DummyVisionConfig}
def __init__(self, vision_config=None, **kwargs):
super().__init__(**kwargs)
if vision_config is None:
vision_config = DummyVisionConfig()
elif isinstance(vision_config, dict):
vision_config = DummyVisionConfig(**vision_config)
self.vision_config = vision_config
class DummyPatchEmbedding(nn.Module):
def __init__(self, stride=(2, 2, 2), bias=False):
super().__init__()
self.embed_dim = 8
self.proj = nn.Conv3d(3, self.embed_dim, kernel_size=(2, 2, 2), stride=stride, bias=bias)
DUMMY_PATCHABLE_CLASSES = {"DummyPatchEmbedding": DummyPatchEmbedding}
for class_name, patchable_class in DUMMY_PATCHABLE_CLASSES.items():
setattr(DUMMY_TRANSFORMERS_MODULE, class_name, patchable_class)
class DummyFusionModel(PreTrainedModel):
config_class = DummyFusionConfig
def __init__(self, config):
super().__init__(config)
# Resolve the class through the fake `transformers.*` module so monkey patching
# can replace it before instantiation.
self.patch_embed = DUMMY_TRANSFORMERS_MODULE.DummyPatchEmbedding(
stride=config.vision_config.patch_embed_stride, bias=True
)
self.post_init()
class FusionMappingTest(unittest.TestCase):
"""Covers registration, no-match, and conflict handling for fusion mapping."""
fusion_config = {"patch_embeddings": True}
def setUp(self):
self.patch_mapping_patcher = patch.object(monkey_patching, "_monkey_patch_mapping_cache", {})
self.patch_mapping_patcher.start()
self.discovery_cache_patcher = patch.object(fusion_mapping, "_FUSION_DISCOVERY_CACHE", {})
self.discovery_cache_patcher.start()
self.checkpoint_conversion_mapping_cache = deepcopy(conversion_mapping._checkpoint_conversion_mapping_cache)
def tearDown(self):
self.patch_mapping_patcher.stop()
self.discovery_cache_patcher.stop()
conversion_mapping._checkpoint_conversion_mapping_cache = deepcopy(self.checkpoint_conversion_mapping_cache)
def test_register_fusion_patches_is_effective_on_dummy_model(self):
# Registers and applies a fusion on a dummy model.
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
config = DummyFusionConfig()
self.assertEqual(get_patch_mapping(), {})
self.assertIsNone(get_checkpoint_conversion_mapping(config.model_type))
self.assertIsInstance(DummyFusionModel(config).patch_embed.proj, nn.Conv3d)
register_fusion_patches(DummyFusionModel, config, fusion_config=self.fusion_config)
self.assertEqual(len(get_patch_mapping()), 1)
self.assertEqual(len(get_checkpoint_conversion_mapping(config.model_type)), 2)
with apply_patches():
fused_model = DummyFusionModel(config)
fused_projection = getattr(
fused_model.patch_embed, "linear_proj", getattr(fused_model.patch_embed, "proj", None)
)
self.assertIsInstance(fused_projection, nn.Linear)
def test_register_fusion_patches_skips_when_no_modules_match(self):
# Leaves registries untouched when nothing is fusable.
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
config = DummyFusionConfig(vision_config={"patch_embed_stride": (1, 1, 1)})
register_fusion_patches(DummyFusionModel, config, fusion_config=self.fusion_config)
self.assertEqual(get_patch_mapping(), {})
self.assertIsNone(get_checkpoint_conversion_mapping(config.model_type))
def test_register_fusion_patches_raises_on_transform_conflicts(self):
# Rejects transforms that would shadow an existing source pattern.
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
config = DummyFusionConfig()
model_type = config.model_type
# build a conflicting conversion mapping with the same source pattern but different target pattern
register_checkpoint_conversion_mapping(
model_type,
[
WeightConverter(
source_patterns=r"patch_embed\.proj\.weight$",
target_patterns=r"patch_embed\.other_linear_proj\.weight$",
operations=[Conv3dToLinear(in_channels=3, kernel_size=(2, 2, 2))],
)
],
overwrite=True,
)
with self.assertRaisesRegex(ValueError, "conflicts with an existing conversion mapping"):
register_fusion_patches(DummyFusionModel, config, fusion_config=self.fusion_config)
def test_from_pretrained_uses_serialized_fusion_config(self):
# A serialized `fusion_config` is reused on a later load.
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
with tempfile.TemporaryDirectory() as source_dir, tempfile.TemporaryDirectory() as fused_dir:
DummyFusionModel(DummyFusionConfig()).save_pretrained(source_dir)
fused_model = DummyFusionModel.from_pretrained(source_dir, fusion_config=self.fusion_config)
fused_model.save_pretrained(fused_dir)
# Simulate a fresh process so the second load comes only from the serialized config.
monkey_patching._monkey_patch_mapping_cache.clear()
fusion_mapping._FUSION_DISCOVERY_CACHE.clear()
conversion_mapping._checkpoint_conversion_mapping_cache = deepcopy(
self.checkpoint_conversion_mapping_cache
)
reloaded_model = DummyFusionModel.from_pretrained(fused_dir)
fused_projection = getattr(
reloaded_model.patch_embed, "linear_proj", getattr(reloaded_model.patch_embed, "proj", None)
)
self.assertIsInstance(fused_projection, nn.Linear)

482
tests/utils/test_generic.py Normal file
View File

@@ -0,0 +1,482 @@
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import warnings
from dataclasses import dataclass
from unittest.mock import patch
import numpy as np
import pytest
from transformers.configuration_utils import PreTrainedConfig
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, ModelOutput
from transformers.testing_utils import require_torch
from transformers.utils import (
can_return_tuple,
expand_dims,
filter_out_non_signature_kwargs,
flatten_dict,
is_torch_available,
reshape,
squeeze,
to_py_obj,
transpose,
)
from transformers.utils.generic import retry, split_attention_implementation
if is_torch_available():
import torch
class GenericTester(unittest.TestCase):
def test_flatten_dict(self):
input_dict = {
"task_specific_params": {
"summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
"summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
"summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
}
}
expected_dict = {
"task_specific_params.summarization.length_penalty": 1.0,
"task_specific_params.summarization.max_length": 128,
"task_specific_params.summarization.min_length": 12,
"task_specific_params.summarization.num_beams": 4,
"task_specific_params.summarization_cnn.length_penalty": 2.0,
"task_specific_params.summarization_cnn.max_length": 142,
"task_specific_params.summarization_cnn.min_length": 56,
"task_specific_params.summarization_cnn.num_beams": 4,
"task_specific_params.summarization_xsum.length_penalty": 1.0,
"task_specific_params.summarization_xsum.max_length": 62,
"task_specific_params.summarization_xsum.min_length": 11,
"task_specific_params.summarization_xsum.num_beams": 6,
}
self.assertEqual(flatten_dict(input_dict), expected_dict)
def test_transpose_numpy(self):
x = np.random.randn(3, 4)
self.assertTrue(np.allclose(transpose(x), x.transpose()))
x = np.random.randn(3, 4, 5)
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0))))
@require_torch
def test_transpose_torch(self):
x = np.random.randn(3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
x = np.random.randn(3, 4, 5)
t = torch.tensor(x)
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
@require_torch
def test_reshape_torch(self):
x = np.random.randn(3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
x = np.random.randn(3, 4, 5)
t = torch.tensor(x)
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
@require_torch
def test_squeeze_torch(self):
x = np.random.randn(1, 3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
x = np.random.randn(1, 4, 1, 5)
t = torch.tensor(x)
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
def test_expand_dims_numpy(self):
x = np.random.randn(3, 4)
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
@require_torch
def test_expand_dims_torch(self):
x = np.random.randn(3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
def test_to_py_obj_native(self):
self.assertTrue(to_py_obj(1) == 1)
self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3])
self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]])
def test_to_py_obj_numpy(self):
x1 = [[1, 2, 3], [4, 5, 6]]
t1 = np.array(x1)
self.assertTrue(to_py_obj(t1) == x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = np.array(x2)
self.assertTrue(to_py_obj(t2) == x2)
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
def test_split_attention_implementation(self):
self.assertEqual(split_attention_implementation(None), (False, None))
self.assertEqual(split_attention_implementation("sdpa"), (False, "sdpa"))
self.assertEqual(split_attention_implementation("paged|flash_attention_2"), (True, "flash_attention_2"))
self.assertEqual(
split_attention_implementation("paged|kernels-community/flash-attn3"),
(True, "kernels-community/flash-attn3"),
)
@require_torch
def test_to_py_obj_torch(self):
x1 = [[1, 2, 3], [4, 5, 6]]
t1 = torch.tensor(x1)
self.assertTrue(to_py_obj(t1) == x1)
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
t2 = torch.tensor(x2)
self.assertTrue(to_py_obj(t2) == x2)
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
def test_model_output_subclass(self):
# testing with “dict-like init” case
out = CausalLMOutputWithPast({"logits": torch.ones(2, 3, 4)})
self.assertNotEqual(out["logits"], None)
self.assertEqual(out.loss, None)
self.assertEqual(len(out.to_tuple()), 1)
# testing with dataclass init case
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
self.assertNotEqual(out["logits"], None)
self.assertEqual(out.loss, None)
self.assertEqual(len(out.to_tuple()), 1)
# testing with updating a previously-None key after init with attribute assignment
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
out.loss = torch.tensor(0.5)
self.assertEqual(out.loss, torch.tensor(0.5))
self.assertEqual(len(out.to_tuple()), 2)
# testing with updating a previously-None key after init with dictionary assignment
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
out["loss"] = torch.tensor(0.5)
self.assertEqual(out.loss, torch.tensor(0.5))
self.assertEqual(len(out.to_tuple()), 2)
@require_torch
def test_register_model_output_pytree_node_skipped_during_compile(self):
# Regression test: on AMD CI (PyTorch 2.8.0+rocm), `set.__contains__` is not
# traceable by TorchDynamo. `_register_model_output_pytree_node` must return
# early when called inside a compiled context, before touching the set.
from transformers.utils.generic import _register_model_output_pytree_node
@dataclass
class DummyOutput(ModelOutput):
last_hidden_state: "torch.Tensor" = None
# Eager registration works normally
_register_model_output_pytree_node(DummyOutput)
# Simulate being inside torch.compile — must not raise
with patch("torch.compiler.is_compiling", return_value=True):
_register_model_output_pytree_node(DummyOutput)
class ValidationDecoratorTester(unittest.TestCase):
def test_cases_no_warning(self):
with warnings.catch_warnings(record=True) as raised_warnings:
warnings.simplefilter("always")
# basic test
@filter_out_non_signature_kwargs()
def func1(a):
return a
result = func1(1)
self.assertEqual(result, 1)
# include extra kwarg
@filter_out_non_signature_kwargs(extra=["extra_arg"])
def func2(a, **kwargs):
return a, kwargs
a, kwargs = func2(1)
self.assertEqual(a, 1)
self.assertEqual(kwargs, {})
a, kwargs = func2(1, extra_arg=2)
self.assertEqual(a, 1)
self.assertEqual(kwargs, {"extra_arg": 2})
# multiple extra kwargs
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
def func3(a, **kwargs):
return a, kwargs
a, kwargs = func3(2)
self.assertEqual(a, 2)
self.assertEqual(kwargs, {})
a, kwargs = func3(3, extra_arg2=3)
self.assertEqual(a, 3)
self.assertEqual(kwargs, {"extra_arg2": 3})
a, kwargs = func3(1, extra_arg=2, extra_arg2=3)
self.assertEqual(a, 1)
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
# Check that no warnings were raised
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
def test_cases_with_warnings(self):
@filter_out_non_signature_kwargs()
def func1(a):
return a
with self.assertWarns(UserWarning):
func1(1, extra_arg=2)
@filter_out_non_signature_kwargs(extra=["extra_arg"])
def func2(a, **kwargs):
return kwargs
with self.assertWarns(UserWarning):
kwargs = func2(1, extra_arg=2, extra_arg2=3)
self.assertEqual(kwargs, {"extra_arg": 2})
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
def func3(a, **kwargs):
return kwargs
with self.assertWarns(UserWarning):
kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
@require_torch
class CanReturnTupleDecoratorTester(unittest.TestCase):
def _get_model(self, config, store_config=True, raise_in_forward=False):
# Simple model class for testing can_return_tuple decorator.
class SimpleTestModel(torch.nn.Module):
def __init__(self, config):
super().__init__()
if store_config:
self.config = config
@can_return_tuple
def forward(self, x):
if raise_in_forward:
raise ValueError("Test error")
return BaseModelOutput(
last_hidden_state=x,
hidden_states=None,
attentions=None,
)
return SimpleTestModel(config)
def test_decorator_eager(self):
"""Test that the can_return_tuple decorator works with eager mode."""
# test nothing is set
config = PreTrainedConfig()
model = self._get_model(config)
inputs = torch.tensor(10)
output = model(inputs)
self.assertIsInstance(
output, BaseModelOutput, "output should be a BaseModelOutput when return_dict is not set"
)
# test all explicit cases
for config_return_dict in [True, False, None]:
for return_dict in [True, False, None]:
config = PreTrainedConfig(return_dict=config_return_dict)
model = self._get_model(config)
output = model(torch.tensor(10), return_dict=return_dict)
expected_type = (
tuple
if return_dict is False
else (tuple if config_return_dict is False and return_dict is None else BaseModelOutput)
)
if config_return_dict is None and return_dict is None:
expected_type = tuple
message = f"output should be a {expected_type.__name__} when config.return_dict={config_return_dict} and return_dict={return_dict}"
self.assertIsInstance(output, expected_type, message)
@pytest.mark.torch_compile_test
def test_decorator_compiled(self):
"""Test that the can_return_tuple decorator works with compiled mode."""
config = PreTrainedConfig()
# Output object
model = self._get_model(config)
compiled_model = torch.compile(model)
output = compiled_model(torch.tensor(10))
self.assertIsInstance(output, BaseModelOutput)
# Tuple output
model = self._get_model(config)
compiled_model = torch.compile(model)
output = compiled_model(torch.tensor(10), return_dict=False)
self.assertIsInstance(output, tuple)
@pytest.mark.torch_export_test
def test_decorator_torch_export(self):
"""Test that the can_return_tuple decorator works with torch.export."""
config = PreTrainedConfig()
model = self._get_model(config)
torch.export.export(model, args=(torch.tensor(10),))
def test_attribute_cleanup(self):
"""Test that the `_is_top_level_module` attribute is removed after the forward call."""
config = PreTrainedConfig(return_dict=False)
inputs = torch.tensor(10)
# working case
model = self._get_model(config)
output = model(inputs)
self.assertIsInstance(output, tuple)
for name, module in model.named_modules():
self.assertFalse(
hasattr(module, "_is_top_level_module"),
f"Module `{name}` should not have `_is_top_level_module` attribute",
)
# model without config
no_config_model = self._get_model(config, store_config=False)
output = no_config_model(inputs)
self.assertIsInstance(output, BaseModelOutput)
for name, module in no_config_model.named_modules():
self.assertFalse(
hasattr(module, "_is_top_level_module"),
f"Module `{name}` should not have `_is_top_level_module` attribute",
)
# model with raise in forward
model_with_raise = self._get_model(config, raise_in_forward=True)
with self.assertRaises(ValueError):
model_with_raise(inputs)
for name, module in model_with_raise.named_modules():
self.assertFalse(
hasattr(module, "_is_top_level_module"),
f"Module `{name}` should not have `_is_top_level_module` attribute",
)
class RetryTest(unittest.TestCase):
def test_succeeds_on_first_attempt(self):
"""Test that retry returns immediately when the wrapped call succeeds."""
@retry(max_retries=3, exceptions=(ValueError,))
def succeed():
return "ok"
self.assertEqual(succeed(), "ok")
@patch("transformers.utils.generic.time.sleep")
def test_retries_then_succeeds(self, mock_sleep):
"""Test that retry sleeps and eventually returns after transient failures."""
call_count = 0
@retry(max_retries=3, initial_delay=1.0, jitter=False, exceptions=(ValueError,))
def fail_twice():
nonlocal call_count
call_count += 1
if call_count < 3:
raise ValueError("transient")
return "recovered"
self.assertEqual(fail_twice(), "recovered")
self.assertEqual(call_count, 3)
self.assertEqual(mock_sleep.call_count, 2)
@patch("transformers.utils.generic.time.sleep")
def test_raises_after_max_retries(self, mock_sleep):
"""Test that retry re-raises the configured exception after exhausting retries."""
@retry(max_retries=2, initial_delay=0.1, jitter=False, exceptions=(RuntimeError,))
def always_fail():
raise RuntimeError("permanent")
with self.assertRaises(RuntimeError, msg="permanent"):
always_fail()
self.assertEqual(mock_sleep.call_count, 1)
@patch("transformers.utils.generic.time.sleep")
def test_non_matching_exception_propagates_immediately(self, mock_sleep):
"""Test that retry does not intercept exceptions outside the configured set."""
@retry(max_retries=5, exceptions=(ValueError,))
def raise_type_error():
raise TypeError("wrong type")
with self.assertRaises(TypeError):
raise_type_error()
self.assertEqual(mock_sleep.call_count, 0)
@patch("transformers.utils.generic.time.sleep")
def test_exponential_backoff(self, mock_sleep):
"""Test that retry doubles the delay between attempts when jitter is disabled."""
call_count = 0
@retry(max_retries=4, initial_delay=1.0, max_delay=10.0, jitter=False, exceptions=(ValueError,))
def fail_thrice():
nonlocal call_count
call_count += 1
if call_count < 4:
raise ValueError("retry")
return "done"
fail_thrice()
delays = [call[0][0] for call in mock_sleep.call_args_list]
self.assertEqual(delays, [1.0, 2.0, 4.0])
@patch("transformers.utils.generic.time.sleep")
def test_max_delay_cap(self, mock_sleep):
"""Test that retry caps exponential backoff at the configured maximum delay."""
call_count = 0
@retry(max_retries=5, initial_delay=8.0, max_delay=10.0, jitter=False, exceptions=(ValueError,))
def fail_four():
nonlocal call_count
call_count += 1
if call_count < 5:
raise ValueError("retry")
return "done"
fail_four()
delays = [call[0][0] for call in mock_sleep.call_args_list]
# 8.0, then min(16, 10)=10, min(20, 10)=10, min(20, 10)=10
self.assertEqual(delays, [8.0, 10.0, 10.0, 10.0])
def test_preserves_function_metadata(self):
"""Test that retry preserves the wrapped function metadata."""
@retry(exceptions=(ValueError,))
def my_func():
"""My docstring."""
pass
self.assertEqual(my_func.__name__, "my_func")
self.assertEqual(my_func.__doc__, "My docstring.")

View File

@@ -0,0 +1,492 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import sys
import tempfile
import unittest
from argparse import Namespace
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Literal, Union, get_args, get_origin
from unittest.mock import patch
import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import make_choice_type_function, string_to_bool
from transformers.testing_utils import require_torch
# Since Python 3.10, we can use the builtin `|` operator for Union types
# See PEP 604: https://peps.python.org/pep-0604
is_python_no_less_than_3_10 = sys.version_info >= (3, 10)
def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)
@dataclass
class BasicExample:
foo: int
bar: float
baz: str
flag: bool
@dataclass
class WithDefaultExample:
foo: int = 42
baz: str = field(default="toto", metadata={"help": "help message"})
@dataclass
class WithDefaultBoolExample:
foo: bool = False
baz: bool = True
opt: bool | None = None
class BasicEnum(Enum):
titi = "titi"
toto = "toto"
class MixedTypeEnum(Enum):
titi = "titi"
toto = "toto"
fourtytwo = 42
@dataclass
class EnumExample:
foo: BasicEnum = "toto"
def __post_init__(self):
self.foo = BasicEnum(self.foo)
@dataclass
class MixedTypeEnumExample:
foo: MixedTypeEnum = "toto"
def __post_init__(self):
self.foo = MixedTypeEnum(self.foo)
@dataclass
class OptionalExample:
foo: int | None = None
bar: float | None = field(default=None, metadata={"help": "help message"})
baz: str | None = None
ces: list[str] | None = list_field(default=[])
des: list[int] | None = list_field(default=[])
@dataclass
class ListExample:
foo_int: list[int] = list_field(default=[])
bar_int: list[int] = list_field(default=[1, 2, 3])
foo_str: list[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
foo_float: list[float] = list_field(default=[0.1, 0.2, 0.3])
@dataclass
class RequiredExample:
required_list: list[int] = field()
required_str: str = field()
required_enum: BasicEnum = field()
def __post_init__(self):
self.required_enum = BasicEnum(self.required_enum)
@dataclass
class StringLiteralAnnotationExample:
foo: int
required_enum: "BasicEnum" = field()
opt: "bool | None" = None
baz: "str" = field(default="toto", metadata={"help": "help message"})
foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
if is_python_no_less_than_3_10:
@dataclass
class WithDefaultBoolExamplePep604:
foo: bool = False
baz: bool = True
opt: bool | None = None
@dataclass
class OptionalExamplePep604:
foo: int | None = None
bar: float | None = field(default=None, metadata={"help": "help message"})
baz: str | None = None
ces: list[str] | None = list_field(default=[])
des: list[int] | None = list_field(default=[])
class HfArgumentParserTest(unittest.TestCase):
def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser):
"""
Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
"""
self.assertEqual(len(a._actions), len(b._actions))
for x, y in zip(a._actions, b._actions):
xx = {k: v for k, v in vars(x).items() if k != "container"}
yy = {k: v for k, v in vars(y).items() if k != "container"}
# Choices with mixed type have custom function as "type"
# So we need to compare results directly for equality
if xx.get("choices") and yy.get("choices"):
for expected_choice in yy["choices"] + xx["choices"]:
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
del xx["type"], yy["type"]
self.assertEqual(xx, yy)
def test_00_basic(self):
parser = HfArgumentParser(BasicExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
self.argparsersEqual(parser, expected)
args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
(example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
self.assertFalse(example.flag)
def test_01_with_default(self):
parser = HfArgumentParser(WithDefaultExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=42, type=int)
expected.add_argument("--baz", default="toto", type=str, help="help message")
self.argparsersEqual(parser, expected)
def test_02_with_default_bool(self):
expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
# A boolean no_* argument always has to come after its "default: True" regular counter-part
# and its default must be set to False
expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
expected.add_argument("--opt", type=string_to_bool, default=None)
dataclass_types = [WithDefaultBoolExample]
if is_python_no_less_than_3_10:
dataclass_types.append(WithDefaultBoolExamplePep604)
for dataclass_type in dataclass_types:
parser = HfArgumentParser(dataclass_type)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--no-baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_03_with_enum(self):
parser = HfArgumentParser(MixedTypeEnumExample)
expected = argparse.ArgumentParser()
expected.add_argument(
"--foo",
default="toto",
choices=["titi", "toto", 42],
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args.foo, "toto")
enum_ex = parser.parse_args_into_dataclasses([])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi")
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
def test_04_with_literal(self):
@dataclass
class LiteralExample:
foo: Literal["titi", "toto", 42] = "toto"
parser = HfArgumentParser(LiteralExample)
expected = argparse.ArgumentParser()
expected.add_argument(
"--foo",
default="toto",
choices=("titi", "toto", 42),
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args.foo, "toto")
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi")
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
def test_05_with_list(self):
parser = HfArgumentParser(ListExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(
args,
Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
)
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
def test_06_with_optional(self):
expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=None, type=int)
expected.add_argument("--bar", default=None, type=float, help="help message")
expected.add_argument("--baz", default=None, type=str)
expected.add_argument("--ces", nargs="+", default=[], type=str)
expected.add_argument("--des", nargs="+", default=[], type=int)
dataclass_types = [OptionalExample]
if is_python_no_less_than_3_10:
dataclass_types.append(OptionalExamplePep604)
for dataclass_type in dataclass_types:
parser = HfArgumentParser(dataclass_type)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
def test_07_with_required(self):
parser = HfArgumentParser(RequiredExample)
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", "--required-str", type=str, required=True)
expected.add_argument(
"--required_enum",
"--required-enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
self.argparsersEqual(parser, expected)
def test_08_with_string_literal_annotation(self):
parser = HfArgumentParser(StringLiteralAnnotationExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=int, required=True)
expected.add_argument(
"--required_enum",
"--required-enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
expected.add_argument("--opt", type=string_to_bool, default=None)
expected.add_argument("--baz", default="toto", type=str, help="help message")
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
self.argparsersEqual(parser, expected)
def test_09_parse_dict(self):
parser = HfArgumentParser(BasicExample)
args_dict = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
parsed_args = parser.parse_dict(args_dict)[0]
args = BasicExample(**args_dict)
self.assertEqual(parsed_args, args)
def test_10_parse_dict_extra_key(self):
parser = HfArgumentParser(BasicExample)
args_dict = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
"extra": 42,
}
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
def test_11_parse_json(self):
parser = HfArgumentParser(BasicExample)
args_dict_for_json = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
with tempfile.TemporaryDirectory() as tmp_dir:
temp_local_path = os.path.join(tmp_dir, "temp_json")
os.mkdir(temp_local_path)
with open(temp_local_path + ".json", "w+") as f:
json.dump(args_dict_for_json, f)
parsed_args = parser.parse_json_file(Path(temp_local_path + ".json"))[0]
args = BasicExample(**args_dict_for_json)
self.assertEqual(parsed_args, args)
def test_12_parse_yaml(self):
parser = HfArgumentParser(BasicExample)
args_dict_for_yaml = {
"foo": 12,
"bar": 3.14,
"baz": "42",
"flag": True,
}
with tempfile.TemporaryDirectory() as tmp_dir:
temp_local_path = os.path.join(tmp_dir, "temp_yaml")
os.mkdir(temp_local_path)
with open(temp_local_path + ".yaml", "w+") as f:
yaml.dump(args_dict_for_yaml, f)
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
args = BasicExample(**args_dict_for_yaml)
self.assertEqual(parsed_args, args)
def test_13_valid_dict_annotation(self):
"""
Tests to make sure that `dict` based annotations
are correctly made in the `TrainingArguments`.
If this fails, a type annotation change is
needed on a new input
"""
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
args = TrainingArguments
# First find any annotations that contain `dict`
fields = args.__dataclass_fields__
raw_dict_fields = []
optional_dict_fields = []
for field_ in fields.values():
# First verify raw dict
if field_.type is dict:
raw_dict_fields.append(field_)
# Next check for `Union` or `Optional`
elif get_origin(field_.type) == Union:
if any(arg is dict for arg in get_args(field_.type)):
optional_dict_fields.append(field_)
# First check: anything in `raw_dict_fields` is very bad
self.assertEqual(
len(raw_dict_fields),
0,
f"Found invalid raw `dict` types in the `TrainingArgument` typings, which are {raw_dict_fields}. "
"This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
)
# Next check raw annotations
for field_ in optional_dict_fields:
args = get_args(field_.type)
# These should be returned as `dict`, `str`, ...
# we only care about the first two
self.assertIn(
dict,
args,
f"Expected field `{field_.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `dict` not found. Please fix this.",
)
self.assertIn(
str,
args,
f"Expected field `{field_.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `str` not found. Please fix this.",
)
# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
for field_ in optional_dict_fields:
self.assertIn(
field.name,
base_list,
f"Optional dict field `{field_.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
)
@require_torch
def test_14_valid_dict_input_parsing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
args = TrainingArguments(
output_dir=tmp_dir,
accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
)
self.assertEqual(args.accelerator_config.split_batches, True)
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
def test_15_integration_training_args(self):
parser = HfArgumentParser(TrainingArguments)
self.assertIsNotNone(parser)
@require_torch
@patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
def test_16_cli_input_parsing(self):
parser = HfArgumentParser(TrainingArguments)
training_args = parser.parse_args_into_dataclasses()[0]
self.assertEqual(training_args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)

View File

@@ -0,0 +1,207 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path
from huggingface_hub import constants, hf_hub_download
from huggingface_hub.errors import HfHubHTTPError, LocalEntryNotFoundError, OfflineModeIsEnabled
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, has_file, list_repo_templates
RANDOM_BERT = "hf-internal-testing/tiny-random-bert"
TINY_BERT_PT_ONLY = "hf-internal-testing/tiny-bert-pt-only"
CACHE_DIR = os.path.join(constants.HF_HUB_CACHE, "models--hf-internal-testing--tiny-random-bert")
FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6"
GATED_REPO = "hf-internal-testing/dummy-gated-model"
README_FILE = "README.md"
class GetFromCacheTests(unittest.TestCase):
def test_cached_file(self):
archive_file = cached_file(RANDOM_BERT, CONFIG_NAME)
# Should have downloaded the file in here
self.assertTrue(os.path.isdir(CACHE_DIR))
# Cache should contain at least those three subfolders:
for subfolder in ["blobs", "refs", "snapshots"]:
self.assertTrue(os.path.isdir(os.path.join(CACHE_DIR, subfolder)))
with open(os.path.join(CACHE_DIR, "refs", "main")) as f:
main_commit = f.read()
self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", main_commit, CONFIG_NAME))
self.assertTrue(os.path.isfile(archive_file))
# File is cached at the same place the second time.
new_archive_file = cached_file(RANDOM_BERT, CONFIG_NAME)
self.assertEqual(archive_file, new_archive_file)
# Using a specific revision to test the full commit hash.
archive_file = cached_file(RANDOM_BERT, CONFIG_NAME, revision="9b8c223")
self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, CONFIG_NAME))
def test_cached_file_errors(self):
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
_ = cached_file("tiny-random-bert", CONFIG_NAME)
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
_ = cached_file(RANDOM_BERT, CONFIG_NAME, revision="aaaa")
with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"):
_ = cached_file(RANDOM_BERT, "conf")
def test_non_existence_is_cached(self):
with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"):
_ = cached_file(RANDOM_BERT, "conf")
with open(os.path.join(CACHE_DIR, "refs", "main")) as f:
main_commit = f.read()
self.assertTrue(os.path.isfile(os.path.join(CACHE_DIR, ".no_exist", main_commit, "conf")))
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_missing_entries=False)
self.assertIsNone(path)
path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
self.assertIsNone(path)
# Under the mock environment, hf_hub_download will always raise an HTTPError
with mock.patch(
"transformers.utils.hub.hf_hub_download",
side_effect=HfHubHTTPError("failed", response=mock.Mock(status_code=404)),
) as mock_head:
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
self.assertIsNone(path)
# This check we did call the fake head request
mock_head.assert_called()
def test_has_file(self):
self.assertTrue(has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME))
self.assertFalse(has_file(TINY_BERT_PT_ONLY, "tf_model.h5"))
self.assertFalse(has_file(TINY_BERT_PT_ONLY, "flax_model.msgpack"))
def test_has_file_in_cache(self):
with tempfile.TemporaryDirectory() as tmp_dir:
# Empty cache dir + offline mode => return False
assert not has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
# Populate cache dir
# TODO: only necessary for read-only cache systems; replace with a shared helper
with unittest.mock.patch.dict(os.environ, {"HF_XET_CACHE": tmp_dir}):
hf_hub_download(TINY_BERT_PT_ONLY, WEIGHTS_NAME, cache_dir=tmp_dir)
# Cache dir + offline mode => return True
assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
def test_get_file_from_repo_distant(self):
# should return None if the file does not exist
self.assertIsNone(
cached_file(
"google-bert/bert-base-cased",
"ahah.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
# The function raises if the repository does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
cached_file(
"bert-base-case",
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
# The function raises if the revision does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
cached_file(
"google-bert/bert-base-cased",
CONFIG_NAME,
revision="ahaha",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
resolved_file = cached_file(
"google-bert/bert-base-cased",
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
# The name is the cached name which is not very easy to test, so instead we load the content.
config = json.loads(open(resolved_file).read())
self.assertEqual(config["hidden_size"], 768)
def test_get_file_from_repo_local(self):
with tempfile.TemporaryDirectory() as tmp_dir:
filename = Path(tmp_dir) / "a.txt"
filename.touch()
self.assertEqual(
cached_file(
tmp_dir,
"a.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
),
str(filename),
)
self.assertIsNone(
cached_file(
tmp_dir,
"b.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
def test_get_file_gated_repo(self):
"""Test download file from a gated repo fails with correct message when not authenticated."""
with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."):
# All files except README.md are protected on a gated repo.
cached_file(GATED_REPO, "gated_file.txt", token=False)
def test_has_file_gated_repo(self):
"""Test check file existence from a gated repo fails with correct message when not authenticated."""
with self.assertRaisesRegex(EnvironmentError, "is a gated repository"):
# All files except README.md are protected on a gated repo.
has_file(GATED_REPO, "gated_file.txt", token=False)
def test_cached_files_exception_raised(self):
"""Test that unhadled exceptions, e.g. ModuleNotFoundError, is properly re-raised by cached_files when hf_hub_download fails."""
with mock.patch(
"transformers.utils.hub.hf_hub_download", side_effect=ModuleNotFoundError("No module named 'MockModule'")
):
with self.assertRaises(ModuleNotFoundError):
# The error should be re-raised by cached_files, not caught in the exception handling block
cached_file(RANDOM_BERT, "nonexistent.json")
class OfflineModeTests(unittest.TestCase):
def test_list_repo_templates_w_offline(self):
with mock.patch("transformers.utils.hub.HfApi.list_repo_tree", side_effect=OfflineModeIsEnabled()):
with mock.patch(
"transformers.utils.hub.HfApi.snapshot_download",
side_effect=LocalEntryNotFoundError("no snapshot found"),
):
self.assertEqual(list_repo_templates(RANDOM_BERT, local_files_only=False), [])

View File

@@ -0,0 +1,224 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path
import httpx
from transformers import AutoImageProcessor, ViTImageProcessor, ViTImageProcessorFast
from transformers.image_processing_utils import get_size_dict
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
from test_module.custom_image_processing import CustomImageProcessor # noqa E402
SAMPLE_IMAGE_PROCESSING_CONFIG_DIR = get_tests_dir("fixtures")
class ImageProcessorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
"failed", request=mock.Mock(), response=mock.Mock()
)
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
_ = ViTImageProcessorFast.from_pretrained("hf-internal-testing/tiny-random-vit")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
_ = ViTImageProcessorFast.from_pretrained("hf-internal-testing/tiny-random-vit")
# This check we did call the fake head request
mock_head.assert_called()
def test_image_processor_from_pretrained_subfolder(self):
with self.assertRaises(OSError):
# config is in subfolder, the following should not work without specifying the subfolder
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
config = AutoImageProcessor.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
)
self.assertIsNotNone(config)
@is_staging_test
class ImageProcessorPushToHubTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._token = TOKEN
def test_push_to_hub(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_fast(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_via_save_pretrained(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_via_save_pretrained_fast(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_in_organization(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_in_organization_fast(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_in_organization_via_save_pretrained(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_in_organization_via_save_pretrained_fast(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
def test_push_to_hub_dynamic_image_processor(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
CustomImageProcessor.register_for_auto_class()
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
# This has added the proper auto_map field to the config
self.assertDictEqual(
image_processor.auto_map,
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
)
new_image_processor = AutoImageProcessor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
class ImageProcessingUtilsTester(unittest.TestCase):
def test_get_size_dict(self):
# Test a dict with the wrong keys raises an error
inputs = {"wrong_key": 224}
with self.assertRaises(ValueError):
get_size_dict(inputs)
inputs = {"height": 224}
with self.assertRaises(ValueError):
get_size_dict(inputs)
inputs = {"width": 224, "shortest_edge": 224}
with self.assertRaises(ValueError):
get_size_dict(inputs)
# Test a dict with the correct keys is returned as is
inputs = {"height": 224, "width": 224}
outputs = get_size_dict(inputs)
self.assertEqual(outputs, inputs)
inputs = {"shortest_edge": 224}
outputs = get_size_dict(inputs)
self.assertEqual(outputs, {"shortest_edge": 224})
inputs = {"longest_edge": 224, "shortest_edge": 224}
outputs = get_size_dict(inputs)
self.assertEqual(outputs, {"longest_edge": 224, "shortest_edge": 224})
# Test a single int value which represents (size, size)
outputs = get_size_dict(224)
self.assertEqual(outputs, {"height": 224, "width": 224})
# Test a single int value which represents the shortest edge
outputs = get_size_dict(224, default_to_square=False)
self.assertEqual(outputs, {"shortest_edge": 224})
# Test a tuple of ints which represents (height, width)
outputs = get_size_dict((150, 200))
self.assertEqual(outputs, {"height": 150, "width": 200})
# Test a tuple of ints which represents (width, height)
outputs = get_size_dict((150, 200), height_width_order=False)
self.assertEqual(outputs, {"height": 200, "width": 150})
# Test an int representing the shortest edge and max_size which represents the longest edge
outputs = get_size_dict(224, max_size=256, default_to_square=False)
self.assertEqual(outputs, {"shortest_edge": 224, "longest_edge": 256})
# Test int with default_to_square=True and max_size fails
with self.assertRaises(ValueError):
get_size_dict(224, max_size=256, default_to_square=True)

View File

@@ -0,0 +1,899 @@
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import unittest
import httpx
import numpy as np
import pytest
from huggingface_hub.file_download import hf_hub_download
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
from transformers import is_torch_available, is_vision_available
from transformers.image_utils import (
ChannelDimension,
get_channel_dimension_axis,
make_flat_list_of_images,
make_list_of_images,
make_nested_list_of_images,
)
from transformers.testing_utils import is_flaky, require_torch, require_vision
if is_torch_available():
import torch
if is_vision_available():
import PIL.Image
from transformers import ImageFeatureExtractionMixin
from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: str | None = None) -> "PIL.Image.Image":
path = hf_hub_download(dataset_id, filename, repo_type="dataset", revision=revision)
return PIL.Image.open(path)
def get_random_image(height, width):
random_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
return PIL.Image.fromarray(random_array)
@require_vision
class ImageFeatureExtractionTester(unittest.TestCase):
def test_conversion_image_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
# Conversion with defaults (rescale + channel first)
array1 = feature_extractor.to_numpy_array(image)
self.assertTrue(array1.dtype, np.float32)
self.assertEqual(array1.shape, (3, 16, 32))
# Conversion with rescale and not channel first
array2 = feature_extractor.to_numpy_array(image, channel_first=False)
self.assertTrue(array2.dtype, np.float32)
self.assertEqual(array2.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array1, array2.transpose(2, 0, 1)))
# Conversion with no rescale and channel first
array3 = feature_extractor.to_numpy_array(image, rescale=False)
self.assertTrue(array3.dtype, np.uint8)
self.assertEqual(array3.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array1, array3.astype(np.float32) * (1 / 255.0)))
# Conversion with no rescale and not channel first
array4 = feature_extractor.to_numpy_array(image, rescale=False, channel_first=False)
self.assertTrue(array4.dtype, np.uint8)
self.assertEqual(array4.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array2, array4.astype(np.float32) * (1 / 255.0)))
def test_conversion_array_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()
array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
# By default, rescale (for an array of ints) and channel permute
array1 = feature_extractor.to_numpy_array(array)
self.assertTrue(array1.dtype, np.float32)
self.assertEqual(array1.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)))
# Same with no permute
array2 = feature_extractor.to_numpy_array(array, channel_first=False)
self.assertTrue(array2.dtype, np.float32)
self.assertEqual(array2.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array2, array.astype(np.float32) * (1 / 255.0)))
# Force rescale to False
array3 = feature_extractor.to_numpy_array(array, rescale=False)
self.assertTrue(array3.dtype, np.uint8)
self.assertEqual(array3.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
# Force rescale to False and no channel permute
array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
self.assertTrue(array4.dtype, np.uint8)
self.assertEqual(array4.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array4, array))
# Now test the default rescale for a float array (defaults to False)
array5 = feature_extractor.to_numpy_array(array2)
self.assertTrue(array5.dtype, np.float32)
self.assertEqual(array5.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array5, array1))
def test_make_list_of_images_pil(self):
# Test a single image is converted to a list of 1 image
pil_image = get_random_image(16, 32)
images_list = make_list_of_images(pil_image)
self.assertIsInstance(images_list, list)
self.assertEqual(len(images_list), 1)
self.assertIsInstance(images_list[0], PIL.Image.Image)
# Test a list of images is not modified
images = [get_random_image(16, 32) for _ in range(4)]
images_list = make_list_of_images(images)
self.assertIsInstance(images_list, list)
self.assertEqual(len(images_list), 4)
self.assertIsInstance(images_list[0], PIL.Image.Image)
def test_make_list_of_images_numpy(self):
# Test a single image is converted to a list of 1 image
images = np.random.randint(0, 256, (16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0], images))
self.assertIsInstance(images_list, list)
# Test a batch of images is converted to a list of images
images = np.random.randint(0, 256, (4, 16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test a list of images is not modified
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test batched masks with no channel dimension are converted to a list of masks
masks = np.random.randint(0, 2, (4, 16, 32))
masks_list = make_list_of_images(masks, expected_ndims=2)
self.assertEqual(len(masks_list), 4)
self.assertTrue(np.array_equal(masks_list[0], masks[0]))
self.assertIsInstance(masks_list, list)
@require_torch
def test_make_list_of_images_torch(self):
# Test a single image is converted to a list of 1 image
images = torch.randint(0, 256, (16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0], images))
self.assertIsInstance(images_list, list)
# Test a batch of images is converted to a list of images
images = torch.randint(0, 256, (4, 16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test a list of images is left unchanged
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
def test_make_flat_list_of_images_pil(self):
# Test a single image is converted to a list of 1 image
pil_image = get_random_image(16, 32)
images_list = make_flat_list_of_images(pil_image)
self.assertIsInstance(images_list, list)
self.assertEqual(len(images_list), 1)
self.assertIsInstance(images_list[0], PIL.Image.Image)
# Test a list of images is not modified
images = [get_random_image(16, 32) for _ in range(4)]
images_list = make_flat_list_of_images(images)
self.assertIsInstance(images_list, list)
self.assertEqual(len(images_list), 4)
self.assertIsInstance(images_list[0], PIL.Image.Image)
# Test a nested list of images is flattened
images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
images_list = make_flat_list_of_images(images)
self.assertIsInstance(images_list, list)
self.assertEqual(len(images_list), 4)
self.assertIsInstance(images_list[0], PIL.Image.Image)
def test_make_flat_list_of_images_numpy(self):
# Test a single image is converted to a list of 1 image
images = np.random.randint(0, 256, (16, 32, 3))
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0], images))
self.assertIsInstance(images_list, list)
# Test a 4d array of images is changed to a list of images
images = np.random.randint(0, 256, (4, 16, 32, 3))
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertIsInstance(images_list, list)
self.assertIsInstance(images_list[0], np.ndarray)
self.assertTrue(np.array_equal(images_list[0], images[0]))
# Test a list of images is not modified
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test list of 4d array images is flattened
images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 8)
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
self.assertIsInstance(images_list, list)
self.assertIsInstance(images_list[0], np.ndarray)
# Test nested list of images is flattened
images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
self.assertIsInstance(images_list, list)
@require_torch
def test_make_flat_list_of_images_torch(self):
# Test a single image is converted to a list of 1 image
images = torch.randint(0, 256, (16, 32, 3))
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0], images))
self.assertIsInstance(images_list, list)
# Test a 4d tensors of images is changed to a list of images
images = torch.randint(0, 256, (4, 16, 32, 3))
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertIsInstance(images_list, list)
self.assertIsInstance(images_list[0], torch.Tensor)
self.assertTrue(np.array_equal(images_list[0], images[0]))
# Test a list of images is not modified
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test list of 4d tensors of imagess is flattened
images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 8)
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
self.assertIsInstance(images_list, list)
self.assertIsInstance(images_list[0], torch.Tensor)
# Test nested list of images is flattened
images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
images_list = make_flat_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
self.assertIsInstance(images_list, list)
def test_make_nested_list_of_images_pil(self):
# Test a single image is converted to a nested list of 1 image
pil_image = get_random_image(16, 32)
images_list = make_nested_list_of_images(pil_image)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list[0]), 1)
self.assertIsInstance(images_list[0][0], PIL.Image.Image)
# Test a list of images is converted to a nested list of images
images = [get_random_image(16, 32) for _ in range(4)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 1)
self.assertEqual(len(images_list[0]), 4)
self.assertIsInstance(images_list[0][0], PIL.Image.Image)
# Test a nested list of images is not modified
images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 2)
self.assertEqual(len(images_list[0]), 2)
self.assertIsInstance(images_list[0][0], PIL.Image.Image)
def test_make_nested_list_of_images_numpy(self):
# Test a single image is converted to a nested list of 1 image
images = np.random.randint(0, 256, (16, 32, 3))
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0][0], images))
# Test a 4d array of images is converted to a nested list of images
images = np.random.randint(0, 256, (4, 16, 32, 3))
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertIsInstance(images_list[0][0], np.ndarray)
self.assertEqual(len(images_list), 1)
self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
# Test a list of images is converted to a nested list of images
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 1)
self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
# Test a nested list of images is left unchanged
images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 2)
self.assertEqual(len(images_list[0]), 2)
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
# Test a list of 4d array images is converted to a nested list of images
images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertIsInstance(images_list[0][0], np.ndarray)
self.assertEqual(len(images_list), 2)
self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
@require_torch
def test_make_nested_list_of_images_torch(self):
# Test a single image is converted to a nested list of 1 image
images = torch.randint(0, 256, (16, 32, 3))
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list[0]), 1)
self.assertTrue(np.array_equal(images_list[0][0], images))
# Test a 4d tensor of images is converted to a nested list of images
images = torch.randint(0, 256, (4, 16, 32, 3))
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertIsInstance(images_list[0][0], torch.Tensor)
self.assertEqual(len(images_list), 1)
self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
# Test a list of images is converted to a nested list of images
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 1)
self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
# Test a nested list of images is left unchanged
images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertEqual(len(images_list), 2)
self.assertEqual(len(images_list[0]), 2)
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
# Test a list of 4d tensor images is converted to a nested list of images
images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
images_list = make_nested_list_of_images(images)
self.assertIsInstance(images_list[0], list)
self.assertIsInstance(images_list[0][0], torch.Tensor)
self.assertEqual(len(images_list), 2)
self.assertEqual(len(images_list[0]), 4)
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
@require_torch
def test_conversion_torch_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.randint(0, 256, (16, 32, 3))
array = tensor.numpy()
# By default, rescale (for a tensor of ints) and channel permute
array1 = feature_extractor.to_numpy_array(array)
self.assertTrue(array1.dtype, np.float32)
self.assertEqual(array1.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)))
# Same with no permute
array2 = feature_extractor.to_numpy_array(array, channel_first=False)
self.assertTrue(array2.dtype, np.float32)
self.assertEqual(array2.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array2, array.astype(np.float32) * (1 / 255.0)))
# Force rescale to False
array3 = feature_extractor.to_numpy_array(array, rescale=False)
self.assertTrue(array3.dtype, np.uint8)
self.assertEqual(array3.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
# Force rescale to False and no channel permute
array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
self.assertTrue(array4.dtype, np.uint8)
self.assertEqual(array4.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array4, array))
# Now test the default rescale for a float tensor (defaults to False)
array5 = feature_extractor.to_numpy_array(array2)
self.assertTrue(array5.dtype, np.float32)
self.assertEqual(array5.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array5, array1))
def test_conversion_image_to_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
# On an image, `to_pil_image1` is a noop.
image1 = feature_extractor.to_pil_image(image)
self.assertTrue(isinstance(image, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image), np.array(image1)))
def test_conversion_array_to_image(self):
feature_extractor = ImageFeatureExtractionMixin()
array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
# By default, no rescale (for an array of ints)
image1 = feature_extractor.to_pil_image(array)
self.assertTrue(isinstance(image1, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image1), array))
# If the array is channel-first, proper reordering of the channels is done.
image2 = feature_extractor.to_pil_image(array.transpose(2, 0, 1))
self.assertTrue(isinstance(image2, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image2), array))
# If the array has floating type, it's rescaled by default.
image3 = feature_extractor.to_pil_image(array.astype(np.float32) * (1 / 255.0))
self.assertTrue(isinstance(image3, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image3), array))
# You can override the default to rescale.
image4 = feature_extractor.to_pil_image(array.astype(np.float32), rescale=False)
self.assertTrue(isinstance(image4, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image4), array))
# And with floats + channel first.
image5 = feature_extractor.to_pil_image(array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0))
self.assertTrue(isinstance(image5, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image5), array))
@require_torch
def test_conversion_tensor_to_image(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.randint(0, 256, (16, 32, 3))
array = tensor.numpy()
# By default, no rescale (for a tensor of ints)
image1 = feature_extractor.to_pil_image(tensor)
self.assertTrue(isinstance(image1, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image1), array))
# If the tensor is channel-first, proper reordering of the channels is done.
image2 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1))
self.assertTrue(isinstance(image2, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image2), array))
# If the tensor has floating type, it's rescaled by default.
image3 = feature_extractor.to_pil_image(tensor.float() / 255.0)
self.assertTrue(isinstance(image3, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image3), array))
# You can override the default to rescale.
image4 = feature_extractor.to_pil_image(tensor.float(), rescale=False)
self.assertTrue(isinstance(image4, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image4), array))
# And with floats + channel first.
image5 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1).float() * (1 / 255.0))
self.assertTrue(isinstance(image5, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image5), array))
def test_resize_image_and_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = np.array(image)
# Size can be an int or a tuple of ints.
resized_image = feature_extractor.resize(image, 8)
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
self.assertEqual(resized_image.size, (8, 8))
resized_image1 = feature_extractor.resize(image, (8, 16))
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
self.assertEqual(resized_image1.size, (8, 16))
# Passing an array converts it to a PIL Image.
resized_image2 = feature_extractor.resize(array, 8)
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
self.assertEqual(resized_image2.size, (8, 8))
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
resized_image3 = feature_extractor.resize(image, (8, 16))
self.assertTrue(isinstance(resized_image3, PIL.Image.Image))
self.assertEqual(resized_image3.size, (8, 16))
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
def test_resize_image_and_array_non_default_to_square(self):
feature_extractor = ImageFeatureExtractionMixin()
heights_widths = [
# height, width
# square image
(28, 28),
(27, 27),
# rectangular image: h < w
(28, 34),
(29, 35),
# rectangular image: h > w
(34, 28),
(35, 29),
]
# single integer or single integer in tuple/list
sizes = [22, 27, 28, 36, [22], (27,)]
for (height, width), size in zip(heights_widths, sizes):
for max_size in (None, 37, 1000):
image = get_random_image(height, width)
array = np.array(image)
size = size[0] if isinstance(size, (list, tuple)) else size
# Size can be an int or a tuple of ints.
# If size is an int, smaller edge of the image will be matched to this number.
# i.e, if height > width, then image will be rescaled to (size * height / width, size).
if height < width:
exp_w, exp_h = (int(size * width / height), size)
if max_size is not None and max_size < exp_w:
exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
elif width < height:
exp_w, exp_h = (size, int(size * height / width))
if max_size is not None and max_size < exp_h:
exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
else:
exp_w, exp_h = (size, size)
if max_size is not None and max_size < size:
exp_w, exp_h = max_size, max_size
resized_image = feature_extractor.resize(image, size=size, default_to_square=False, max_size=max_size)
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
self.assertEqual(resized_image.size, (exp_w, exp_h))
# Passing an array converts it to a PIL Image.
resized_image2 = feature_extractor.resize(array, size=size, default_to_square=False, max_size=max_size)
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
self.assertEqual(resized_image2.size, (exp_w, exp_h))
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
@require_torch
def test_resize_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.randint(0, 256, (16, 32, 3))
array = tensor.numpy()
# Size can be an int or a tuple of ints.
resized_image = feature_extractor.resize(tensor, 8)
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
self.assertEqual(resized_image.size, (8, 8))
resized_image1 = feature_extractor.resize(tensor, (8, 16))
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
self.assertEqual(resized_image1.size, (8, 16))
# Check we get the same results as with NumPy arrays.
resized_image2 = feature_extractor.resize(array, 8)
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
resized_image3 = feature_extractor.resize(array, (8, 16))
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
def test_normalize_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = np.array(image)
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# PIL Image are converted to NumPy arrays for the normalization
normalized_image = feature_extractor.normalize(image, mean, std)
self.assertTrue(isinstance(normalized_image, np.ndarray))
self.assertEqual(normalized_image.shape, (3, 16, 32))
# During the conversion rescale and channel first will be applied.
expected = array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)
np_mean = np.array(mean).astype(np.float32)[:, None, None]
np_std = np.array(std).astype(np.float32)[:, None, None]
expected = (expected - np_mean) / np_std
self.assertTrue(np.array_equal(normalized_image, expected))
def test_normalize_array(self):
feature_extractor = ImageFeatureExtractionMixin()
array = np.random.random((16, 32, 3))
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# mean and std can be passed as lists or NumPy arrays.
expected = (array - np.array(mean)) / np.array(std)
normalized_array = feature_extractor.normalize(array, mean, std)
self.assertTrue(np.array_equal(normalized_array, expected))
normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
self.assertTrue(np.array_equal(normalized_array, expected))
# Normalize will detect automatically if channel first or channel last is used.
array = np.random.random((3, 16, 32))
expected = (array - np.array(mean)[:, None, None]) / np.array(std)[:, None, None]
normalized_array = feature_extractor.normalize(array, mean, std)
self.assertTrue(np.array_equal(normalized_array, expected))
normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
self.assertTrue(np.array_equal(normalized_array, expected))
@require_torch
def test_normalize_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.rand(16, 32, 3)
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# mean and std can be passed as lists or tensors.
expected = (tensor - torch.tensor(mean)) / torch.tensor(std)
normalized_tensor = feature_extractor.normalize(tensor, mean, std)
self.assertTrue(torch.equal(normalized_tensor, expected))
normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
self.assertTrue(torch.equal(normalized_tensor, expected))
# Normalize will detect automatically if channel first or channel last is used.
tensor = torch.rand(3, 16, 32)
expected = (tensor - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
normalized_tensor = feature_extractor.normalize(tensor, mean, std)
self.assertTrue(torch.equal(normalized_tensor, expected))
normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
self.assertTrue(torch.equal(normalized_tensor, expected))
def test_center_crop_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(isinstance(cropped_image, PIL.Image.Image))
# PIL Image.size is transposed compared to NumPy or PyTorch (width first instead of height first).
expected_size = (size, size) if isinstance(size, int) else (size[1], size[0])
self.assertEqual(cropped_image.size, expected_size)
def test_center_crop_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = feature_extractor.to_numpy_array(image)
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_array = feature_extractor.center_crop(array, size)
self.assertTrue(isinstance(cropped_array, np.ndarray))
expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_array.shape[-2:], expected_size)
# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(np.array_equal(cropped_array, feature_extractor.to_numpy_array(cropped_image)))
@require_torch
def test_center_crop_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = feature_extractor.to_numpy_array(image)
tensor = torch.tensor(array)
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_tensor = feature_extractor.center_crop(tensor, size)
self.assertTrue(isinstance(cropped_tensor, torch.Tensor))
expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_tensor.shape[-2:], expected_size)
# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
@require_vision
class LoadImageTester(unittest.TestCase):
def test_load_img_url(self):
img = load_image(INVOICE_URL)
img_arr = np.array(img)
self.assertEqual(img_arr.shape, (1061, 750, 3))
@is_flaky()
def test_load_img_url_timeout(self):
with self.assertRaises(httpx.ConnectTimeout):
load_image(INVOICE_URL, timeout=0.001)
def test_load_img_local(self):
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(480, 640, 3),
)
def test_load_img_base64_prefix(self):
path = hf_hub_download(
repo_id="hf-internal-testing/dummy-base64-images", filename="image_0.txt", repo_type="dataset"
)
with open(path, encoding="utf-8") as b64:
img = load_image(b64.read())
img_arr = np.array(img)
self.assertEqual(img_arr.shape, (64, 32, 3))
def test_load_img_base64(self):
path = hf_hub_download(
repo_id="hf-internal-testing/dummy-base64-images", filename="image_1.txt", repo_type="dataset"
)
with open(path, encoding="utf-8") as b64:
img = load_image(b64.read())
img_arr = np.array(img)
self.assertEqual(img_arr.shape, (64, 32, 3))
def test_load_img_base64_encoded_bytes(self):
path = hf_hub_download(
repo_id="hf-internal-testing/dummy-base64-images", filename="image_2.txt", repo_type="dataset"
)
with codecs.open(path, encoding="unicode_escape") as b64:
img = load_image(b64.read())
img_arr = np.array(img)
self.assertEqual(img_arr.shape, (256, 256, 3))
def test_load_img_rgba(self):
# we use revision="refs/pr/1" until the PR is merged
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
img = get_image_from_hub_dataset(
"hf-internal-testing/fixtures_image_utils", "0-test-lena.png", revision="refs/pr/1"
)
img = load_image(img) # img with mode RGBA
img_arr = np.array(img)
self.assertEqual(img_arr.shape, (512, 512, 3))
def test_load_img_la(self):
# we use revision="refs/pr/1" until the PR is merged
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
img = get_image_from_hub_dataset(
"hf-internal-testing/fixtures_image_utils", "1-test-parrots.png", revision="refs/pr/1"
)
img = load_image(img) # img with mode LA
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(512, 768, 3),
)
def test_load_img_l(self):
# we use revision="refs/pr/1" until the PR is merged
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
img = get_image_from_hub_dataset(
"hf-internal-testing/fixtures_image_utils", "2-test-tree.png", revision="refs/pr/1"
)
img = load_image(img) # img with mode L
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(381, 225, 3),
)
def test_load_img_exif_transpose(self):
# we use revision="refs/pr/1" until the PR is merged
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
img_without_exif_transpose = get_image_from_hub_dataset(
"hf-internal-testing/fixtures_image_utils", "3-test-cat-rotated.jpg", revision="refs/pr/1"
)
img_arr_without_exif_transpose = np.array(img_without_exif_transpose)
self.assertEqual(
img_arr_without_exif_transpose.shape,
(333, 500, 3),
)
img_with_exif_transpose = load_image(img_without_exif_transpose)
img_arr_with_exif_transpose = np.array(img_with_exif_transpose)
self.assertEqual(
img_arr_with_exif_transpose.shape,
(500, 333, 3),
)
class UtilFunctionTester(unittest.TestCase):
def test_get_image_size(self):
# Test we can infer the size and channel dimension of an image.
image = np.random.randint(0, 256, (32, 64, 3))
self.assertEqual(get_image_size(image), (32, 64))
image = np.random.randint(0, 256, (3, 32, 64))
self.assertEqual(get_image_size(image), (32, 64))
# Test the channel dimension can be overridden
image = np.random.randint(0, 256, (3, 32, 64))
self.assertEqual(get_image_size(image, channel_dim=ChannelDimension.LAST), (3, 32))
def test_infer_channel_dimension(self):
# Test we fail with invalid input
with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 10)))
with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 10, 10, 10, 10)))
# Test we fail if neither first not last dimension is of size 3 or 1
with pytest.raises(ValueError):
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
# But if we explicitly set one of the number of channels to 50 it works
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
self.assertEqual(inferred_dim, ChannelDimension.LAST)
# Test we correctly identify the channel dimension
image = np.random.randint(0, 256, (3, 4, 5))
inferred_dim = infer_channel_dimension_format(image)
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
image = np.random.randint(0, 256, (1, 4, 5))
inferred_dim = infer_channel_dimension_format(image)
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
image = np.random.randint(0, 256, (4, 5, 3))
inferred_dim = infer_channel_dimension_format(image)
self.assertEqual(inferred_dim, ChannelDimension.LAST)
image = np.random.randint(0, 256, (4, 5, 1))
inferred_dim = infer_channel_dimension_format(image)
self.assertEqual(inferred_dim, ChannelDimension.LAST)
# We can take a batched array of images and find the dimension
image = np.random.randint(0, 256, (1, 3, 4, 5))
inferred_dim = infer_channel_dimension_format(image)
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
def test_get_channel_dimension_axis(self):
# Test we correctly identify the channel dimension
image = np.random.randint(0, 256, (3, 4, 5))
inferred_axis = get_channel_dimension_axis(image)
self.assertEqual(inferred_axis, 0)
image = np.random.randint(0, 256, (1, 4, 5))
inferred_axis = get_channel_dimension_axis(image)
self.assertEqual(inferred_axis, 0)
image = np.random.randint(0, 256, (4, 5, 3))
inferred_axis = get_channel_dimension_axis(image)
self.assertEqual(inferred_axis, 2)
image = np.random.randint(0, 256, (4, 5, 1))
inferred_axis = get_channel_dimension_axis(image)
self.assertEqual(inferred_axis, 2)
# We can take a batched array of images and find the dimension
image = np.random.randint(0, 256, (1, 3, 4, 5))
inferred_axis = get_channel_dimension_axis(image)
self.assertEqual(inferred_axis, 1)

View File

@@ -0,0 +1,208 @@
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils.import_utils import (
Backend,
VersionComparison,
define_import_structure,
spread_import_structure,
)
import_structures = Path(__file__).parent / "import_structures"
def fetch__all__(file_content):
"""
Returns the content of the __all__ variable in the file content.
Returns None if not defined, otherwise returns a list of strings.
"""
lines = file_content.split("\n")
for line_index in range(len(lines)):
line = lines[line_index]
if line.startswith("__all__ = "):
# __all__ is defined on a single line
if line.endswith("]"):
return [obj.strip("\"' ") for obj in line.split("=")[1].strip(" []").split(",")]
# __all__ is defined on multiple lines
else:
_all = []
for __all__line_index in range(line_index + 1, len(lines)):
if lines[__all__line_index].strip() == "]":
return _all
else:
_all.append(lines[__all__line_index].strip("\"', "))
class TestImportStructures(unittest.TestCase):
base_transformers_path = Path(__file__).parent.parent.parent
models_path = base_transformers_path / "src" / "transformers" / "models"
models_import_structure = spread_import_structure(define_import_structure(models_path))
def test_definition(self):
import_structure = define_import_structure(import_structures)
valid_frozensets: dict[frozenset | frozenset[str], dict[str, set[str]]] = {
frozenset(): {
"import_structure_raw_register": {"A0", "A4", "a0"},
"import_structure_register_with_comments": {"B0", "b0"},
},
frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
frozenset({"torch"}): {
"import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
"import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"},
"import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
},
frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
frozenset({"torch>=2.5", "accelerate<0.20"}): {
"import_structure_raw_register_with_versions": {"D6", "d6"}
},
}
self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys()))
for _frozenset in valid_frozensets:
self.assertTrue(_frozenset in import_structure)
self.assertListEqual(
sorted(import_structure[_frozenset].keys()), sorted(valid_frozensets[_frozenset].keys())
)
for module, objects in valid_frozensets[_frozenset].items():
self.assertTrue(module in import_structure[_frozenset])
self.assertSetEqual(objects, import_structure[_frozenset][module])
def test_transformers_specific_model_import(self):
"""
This test ensures that there is equivalence between what is written down in __all__ and what is
written down with register().
It doesn't test the backends attributed to register().
"""
for architecture in os.listdir(self.models_path):
if (
os.path.isfile(self.models_path / architecture)
or architecture.startswith("_")
or architecture == "deprecated"
):
continue
with self.subTest(f"Testing arch {architecture}"):
import_structure = define_import_structure(self.models_path / architecture)
backend_agnostic_import_structure = {}
for module_object_mapping in import_structure.values():
for module, objects in module_object_mapping.items():
if module not in backend_agnostic_import_structure:
backend_agnostic_import_structure[module] = []
backend_agnostic_import_structure[module].extend(objects)
for module, objects in backend_agnostic_import_structure.items():
with open(self.models_path / architecture / f"{module}.py") as f:
content = f.read()
_all = fetch__all__(content)
if _all is None:
raise ValueError(f"{module} doesn't have __all__ defined.")
error_message = (
f"self.models_path / architecture / f'{module}.py doesn't seem to be defined correctly:\n"
f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
)
self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)
def test_import_spread(self):
"""
This test is specifically designed to test that varying levels of depth across import structures are
respected.
In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
- models.{frozensets}
- models.albert.{frozensets}
- models.deprecated.transfo_xl.{frozensets}
"""
initial_import_structure = {
frozenset(): {"dummy_non_model": {"DummyObject"}},
"models": {
frozenset(): {"dummy_config": {"DummyConfig"}},
"albert": {
frozenset(): {"configuration_albert": {"AlbertConfig"}},
frozenset({"torch"}): {
"modeling_albert": {
"AlbertForMaskedLM",
}
},
},
"llama": {
frozenset(): {"configuration_llama": {"LlamaConfig"}},
frozenset({"torch"}): {
"modeling_llama": {
"LlamaForCausalLM",
}
},
},
"deprecated": {
"transfo_xl": {
frozenset({"torch"}): {
"modeling_transfo_xl": {
"TransfoXLModel",
}
},
frozenset(): {
"configuration_transfo_xl": {"TransfoXLConfig"},
"tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
},
},
"deta": {
frozenset({"torch"}): {
"modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
},
frozenset(): {"configuration_deta": {"DetaConfig"}},
frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
},
},
},
}
ground_truth_spread_import_structure = {
frozenset(): {
"dummy_non_model": {"DummyObject"},
"models.dummy_config": {"DummyConfig"},
"models.albert.configuration_albert": {"AlbertConfig"},
"models.llama.configuration_llama": {"LlamaConfig"},
"models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
"models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
"models.deprecated.deta.configuration_deta": {"DetaConfig"},
},
frozenset({"torch"}): {
"models.albert.modeling_albert": {"AlbertForMaskedLM"},
"models.llama.modeling_llama": {"LlamaForCausalLM"},
"models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
"models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
},
frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
}
newly_spread_import_structure = spread_import_structure(initial_import_structure)
self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)
@pytest.mark.parametrize(
"backend,package_name,version_comparison,version",
[
pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL, "2.5"),
pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL, "0.19.1"),
],
)
def test_backend_specification(
backend: Backend, package_name: str, version_comparison: VersionComparison, version: str
):
assert backend.package_name == package_name
assert VersionComparison.from_string(backend.version_comparison) == version_comparison
assert backend.version == version

View File

@@ -0,0 +1,50 @@
import sys
from types import ModuleType
from unittest.mock import patch
from transformers.testing_utils import run_test_using_subprocess
from transformers.utils.import_utils import _is_package_available, clear_import_cache
@run_test_using_subprocess
def test_clear_import_cache():
"""Test the clear_import_cache function."""
# Save initial state
initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
assert len(initial_modules) > 0, "No transformers modules loaded before test"
# Execute clear_import_cache() function
clear_import_cache()
# Verify modules were removed
remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
assert len(remaining_modules) < len(initial_modules), "No modules were removed"
# Import and verify module exists
from transformers.models.auto import modeling_auto
assert "transformers.models.auto.modeling_auto" in sys.modules
assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto"
def test_is_package_available_edge_cases():
pkg_name = "definitely_not_a_real_pkg_xyz"
namespace_shadow = ModuleType(pkg_name)
versionless_install = ModuleType(pkg_name)
versionless_install.__file__ = f"/path/to/site-packages/{pkg_name}/__init__.py"
with_version = ModuleType(pkg_name)
with_version.__version__ = "1.2.3"
cases = [
(namespace_shadow, (False, "N/A")),
(versionless_install, (True, "N/A")),
(with_version, (True, "1.2.3")),
]
for fake_module, expected in cases:
with (
patch("transformers.utils.import_utils.importlib.util.find_spec", return_value=object()),
patch("transformers.utils.import_utils.importlib.import_module", return_value=fake_module),
):
assert _is_package_available(pkg_name, return_version=True) == expected

135
tests/utils/test_logging.py Normal file
View File

@@ -0,0 +1,135 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from huggingface_hub.utils import are_progress_bars_disabled
import transformers.models.roberta.tokenization_roberta
from transformers import logging
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
class HfArgumentParserTest(unittest.TestCase):
def test_set_level(self):
logger = logging.get_logger()
# the current default level is logging.WARNING
level_origin = logging.get_verbosity()
logging.set_verbosity_error()
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
logging.set_verbosity_warning()
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
logging.set_verbosity_info()
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
logging.set_verbosity_debug()
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
# restore to the original level
logging.set_verbosity(level_origin)
def test_integration(self):
level_origin = logging.get_verbosity()
logger = logging.get_logger("transformers.models.roberta.tokenization_roberta")
msg = "Testing 1, 2, 3"
# should be able to log warnings (if default settings weren't overridden by `pytest --log-level-all`)
if level_origin <= logging.WARNING:
with CaptureLogger(logger) as cl:
logger.warning(msg)
self.assertEqual(cl.out, msg + "\n")
# this is setting the level for all of `transformers.*` loggers
logging.set_verbosity_error()
# should not be able to log warnings
with CaptureLogger(logger) as cl:
logger.warning(msg)
self.assertEqual(cl.out, "")
# should be able to log warnings again
logging.set_verbosity_warning()
with CaptureLogger(logger) as cl:
logger.warning(msg)
self.assertEqual(cl.out, msg + "\n")
# restore to the original level
logging.set_verbosity(level_origin)
@mockenv(TRANSFORMERS_VERBOSITY="error")
def test_env_override(self):
# reset for the env var to take effect, next time some logger call is made
transformers.utils.logging._reset_library_root_logger()
# this action activates the env var
_ = logging.get_logger("transformers.models.roberta.tokenization_roberta")
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
env_level = logging.log_levels[env_level_str]
current_level = logging.get_verbosity()
self.assertEqual(
env_level,
current_level,
f"TRANSFORMERS_VERBOSITY={env_level_str}/{env_level}, but internal verbosity is {current_level}",
)
# restore to the original level
os.environ["TRANSFORMERS_VERBOSITY"] = ""
transformers.utils.logging._reset_library_root_logger()
@mockenv(TRANSFORMERS_VERBOSITY="super-error")
def test_env_invalid_override(self):
# reset for the env var to take effect, next time some logger call is made
transformers.utils.logging._reset_library_root_logger()
logger = logging.logging.getLogger()
with CaptureLogger(logger) as cl:
# this action activates the env var
logging.get_logger("transformers.models.roberta.tokenization_roberta")
self.assertIn("Unknown option TRANSFORMERS_VERBOSITY=super-error", cl.out)
# no need to restore as nothing was changed
def test_advisory_warnings(self):
# testing `logger.warning_advice()`
transformers.utils.logging._reset_library_root_logger()
logger = logging.get_logger("transformers.models.roberta.tokenization_roberta")
msg = "Testing 1, 2, 3"
with mockenv_context(TRANSFORMERS_NO_ADVISORY_WARNINGS="1"):
# nothing should be logged as env var disables this method
with CaptureLogger(logger) as cl:
logger.warning_advice(msg)
self.assertEqual(cl.out, "")
with mockenv_context(TRANSFORMERS_NO_ADVISORY_WARNINGS=""):
# should log normally as TRANSFORMERS_NO_ADVISORY_WARNINGS is unset
with CaptureLogger(logger) as cl:
logger.warning_advice(msg)
self.assertEqual(cl.out, msg + "\n")
def test_set_progress_bar_enabled():
disable_progress_bar()
assert are_progress_bars_disabled()
enable_progress_bar()
assert not are_progress_bars_disabled()

View File

@@ -0,0 +1,375 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import (
cleanup,
is_torch_available,
require_torch,
torch_device,
)
if is_torch_available():
import torch
from torch.nn.attention.flex_attention import create_block_mask
from transformers import DynamicCache, LlamaConfig
from transformers.cache_utils import DynamicSlidingWindowLayer
from transformers.masking_utils import (
create_bidirectional_mask,
create_causal_mask,
create_chunked_causal_mask,
find_packed_sequence_indices,
)
# fmt: off
EXPECTED_PACKED_MASK = torch.tensor([[[
[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[False, False, False, False, True, False, False, False, False, False],
[False, False, False, False, True, True, False, False, False, False],
[False, False, False, False, False, False, True, False, False, False],
[False, False, False, False, False, False, True, True, False, False],
[False, False, False, False, False, False, True, True, True, False],
[False, False, False, False, False, False, True, True, True, True]]],
[[[ True, False, False, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False, False, False],
[ True, True, True, True, False, False, False, False, False, False],
[ True, True, True, True, True, False, False, False, False, False],
[ True, True, True, True, True, True, False, False, False, False],
[False, False, False, False, False, False, True, False, False, False],
[False, False, False, False, False, False, True, True, False, False],
[False, False, False, False, False, False, True, True, True, False],
[False, False, False, False, False, False, True, True, True, True]
]]], dtype=torch.bool)
# fmt: on
@require_torch
class MaskTest(unittest.TestCase):
def setup(self):
cleanup(torch_device, gc_collect=True)
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_packed_sequence_mask_sdpa(self):
config = LlamaConfig()
config._attn_implementation = "sdpa"
batch_size = 2
sequence_length = 10
# First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
causal_mask = create_causal_mask(
config=config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
attention_mask=None,
past_key_values=None,
position_ids=position_ids,
)
self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all())
def test_packed_sequence_mask_eager(self):
config = LlamaConfig()
config._attn_implementation = "eager"
batch_size = 2
sequence_length = 10
# First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
causal_mask = create_causal_mask(
config=config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
attention_mask=None,
past_key_values=None,
position_ids=position_ids,
)
min_dtype = torch.finfo(torch.float16).min
self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all())
def test_packed_sequence_mask_flex_attention(self):
config = LlamaConfig()
config._attn_implementation = "flex_attention"
batch_size = 2
sequence_length = 10
# First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
causal_mask = create_causal_mask(
config=config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
attention_mask=None,
past_key_values=None,
position_ids=position_ids,
)
def dummy_mask_mod(b, h, q, kv):
return EXPECTED_PACKED_MASK[b, h, q, kv]
EXPECTED_BLOCK_MASK = create_block_mask(dummy_mask_mod, 2, None, 10, 10, device="cpu")
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
def test_find_packed_sequence_indices(self):
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
def test_nonpacked_sequence_mask_skip(self):
config = LlamaConfig()
config._attn_implementation = "sdpa"
batch_size = 2
sequence_length = 10
# Non-packed sequences
position_ids = torch.arange(sequence_length)[None, :]
causal_mask = create_causal_mask(
config=config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
attention_mask=None,
past_key_values=None,
position_ids=position_ids,
)
# packed sequence should be skipped
self.assertTrue(causal_mask is None)
create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead")
causal_mask = create_causal_mask_compiled(
config=config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
attention_mask=None,
past_key_values=None,
position_ids=position_ids,
)
# cannot be skipped under compile, should result into a triu mask
self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask))
def test_chunked_mask_with_left_padding_and_large_prefill(self):
# Make sure we have an attention_chunk_size in the config
config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")
batch_size = 2
sequence_length = 8
pad_tokens = 4
input_ids = torch.randint(100, 200, (batch_size, sequence_length))
attention_mask = torch.tensor(
[[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length]
)
inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
positions = torch.arange(sequence_length)
position_ids = torch.empty(batch_size, sequence_length, dtype=positions.dtype)
position_ids[0, :pad_tokens] = 1
position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens)
position_ids[1, :] = positions
chunked_attention_mask = create_chunked_causal_mask(
config=config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=None,
position_ids=position_ids,
)
# fmt: off
EXPECTED_CHUNKED_MASK = torch.tensor(
# Here, for the padded sequence, the chunk size should start correctly at index 4 (otherwise, with 4 padding
# tokens are chunk_size=3, the first chunk is from indices 0-2, then 3-6 if we don't account for the padding correctly)
[[[[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, False, False, False, False],
[False, False, False, False, True, False, False, False],
[False, False, False, False, True, True, False, False],
[False, False, False, False, True, True, True, False],
[False, False, False, False, False, False, False, True]]],
[[[ True, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False],
[False, False, False, True, False, False, False, False],
[False, False, False, True, True, False, False, False],
[False, False, False, True, True, True, False, False],
[False, False, False, False, False, False, True, False],
[False, False, False, False, False, False, True, True]]]],
dtype=torch.bool)
# fmt: on
self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
def test_chunked_mask_with_left_padding_decoding(self):
# Make sure we have an attention_chunk_size in the config
config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1)
cache = DynamicCache(config=config)
# Sanity check
self.assertEqual(len(cache), 1)
self.assertTrue(isinstance(cache.layers[0], DynamicSlidingWindowLayer))
# Fill-in the Cache (sequence length is bigger than chunk size here)
batch_size = 2
prefill_size = 8
pad_tokens = 7
fake_kv = torch.rand(batch_size, 32, prefill_size, 32)
cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size))
# Create a new input after the prefill
input_ids = torch.randint(100, 200, (batch_size, 1))
attention_mask = torch.tensor(
[[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)]
)
inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]])
chunked_attention_mask = create_chunked_causal_mask(
config=config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=cache,
position_ids=position_ids,
)
# To understand a bit more the following expected mask, here is the full 2d mask, where the "|" characters are the chunk
# separators (where the tokens should stop seeing each other)
# [0, 0, 0, 0, 0, 0, 0, | 1, 1], -> due to left padding, the first chunk only starts after the padding tokens
# [| 1, 1, 1, 1, | 1, 1, 1, 1, | 1]]) -> easy case, each 4 tokens is a new chunk
# fmt: off
EXPECTED_CHUNKED_MASK = torch.tensor(
# Here, for the padded sequence, the chunk size should start correctly at index 7 (the first unpadded
# index), and so only indices 7 and 8 should be True
[[[[False, False, True, True]]],
# Here, for the unpadded sequence, the chunks start at index 0. Since we have 9 tokens in total, the last
# token (index 8) will only see itself (we have 2 full chunks before)
[[[False, False, False, True]]]],
dtype=torch.bool)
# fmt: on
self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
@staticmethod
def _run_bidirectional_mask(mask_fn, attn_implementation):
def run_mask_creation(mask_fn, config, inputs_embeds, encoder_mask, cross_mask, encoder_hidden_states):
encoder_attn_mask = mask_fn(
config=config,
inputs_embeds=inputs_embeds,
attention_mask=encoder_mask,
)
cross_attn_mask = mask_fn(
config=config,
inputs_embeds=inputs_embeds,
attention_mask=cross_mask,
encoder_hidden_states=encoder_hidden_states,
)
return encoder_attn_mask, cross_attn_mask
# We use llama but could be also bert/bart --> we only need the `_attn_implementation` here
config = LlamaConfig()
config._attn_implementation = attn_implementation
# Meta data
batch_size = 2
q_length = 10
kv_length = 5
inputs_embeds = torch.ones((batch_size, q_length, 1), device=torch_device, dtype=torch.float16)
encoder_hidden_states = torch.ones((batch_size, kv_length, 1), device=torch_device, dtype=torch.float16)
encoder_mask = torch.ones_like(inputs_embeds)[..., 0]
cross_mask = torch.ones_like(encoder_hidden_states)[..., 0]
# Case 1: Full mask
full_mask_encoder_1, full_mask_cross_1 = run_mask_creation(
mask_fn=mask_fn,
config=config,
inputs_embeds=inputs_embeds,
encoder_mask=encoder_mask,
cross_mask=cross_mask,
encoder_hidden_states=encoder_hidden_states,
)
full_mask_encoder_2, full_mask_cross_2 = run_mask_creation(
mask_fn=mask_fn,
config=config,
inputs_embeds=inputs_embeds,
encoder_mask=None,
cross_mask=None,
encoder_hidden_states=encoder_hidden_states,
)
# Case 2: Padding involved
cross_mask[:, -1] = 0
encoder_mask[:, -1] = 0
padded_mask_encoder, padded_mask_cross = run_mask_creation(
mask_fn=mask_fn,
config=config,
inputs_embeds=inputs_embeds,
encoder_mask=encoder_mask,
cross_mask=cross_mask,
encoder_hidden_states=encoder_hidden_states,
)
full_masks = (full_mask_encoder_1, full_mask_encoder_2), (full_mask_cross_1, full_mask_cross_2)
padded_masks = (padded_mask_encoder, padded_mask_cross)
return full_masks, padded_masks
def test_bidirectional_mask_cudagraphs(self):
"""
Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not into any error
during this test.
"""
mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead")
self._run_bidirectional_mask(mask_fn=mask_creation_function, attn_implementation="sdpa")
def test_bidirectional_mask_skip_eager(self):
"""
Checks whether the bidirectional mask creation can skip the mask creation if we have a full mask.
"""
full_masks, padded_mask = self._run_bidirectional_mask(
mask_fn=create_bidirectional_mask, attn_implementation="eager"
)
for alternative_masks in full_masks:
self.assertTrue(alternative_masks[0] is None)
self.assertTrue(alternative_masks[1] is None)
self.assertTrue(padded_mask[0] is not None)
self.assertTrue(padded_mask[1] is not None)

View File

@@ -0,0 +1,121 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import tempfile
import unittest
from pathlib import Path
from transformers import is_torch_available
from transformers.model_debugging_utils import model_addition_debugger_context
if is_torch_available():
import torch
from torch import nn
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.embed = nn.Embedding(10, 4)
self.linear_1 = nn.Linear(4, 8)
self.linear_2 = nn.Linear(8, 2)
self.act = nn.ReLU()
def forward(self, input_ids: str):
hidden_states = self.embed(input_ids).mean(dim=1)
hidden_states = self.act(self.linear_1(hidden_states))
return self.linear_2(hidden_states)
class TestModelAdditionDebugger(unittest.TestCase):
def setUp(self):
self.model = ToyModel()
self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))}
def tearDown(self):
gc.collect()
def test_debugger_outputs(self):
with tempfile.TemporaryDirectory() as tmpdir:
with model_addition_debugger_context(self.model, debug_path=str(tmpdir)):
_ = self.model.forward(**self.inputs)
base = f"{self.model.__class__.__name__}_debug_tree"
summary = Path(os.path.join(tmpdir, f"{base}_SUMMARY.json"))
full = Path(os.path.join(tmpdir, f"{base}_FULL_TENSORS.json"))
self.assertTrue(os.path.isfile(summary) and os.path.isfile(full))
data = json.loads(summary.read_text())
self.assertTrue({"module_path", "inputs", "children"} <= data.keys())
self.assertTrue(data["children"])
class ToyLayer(nn.Module):
def __init__(self, layer_index):
super().__init__()
self.layer_index = layer_index
self.layer_operation = nn.Linear(4, 4)
def forward(self, hidden_states):
return self.layer_operation(hidden_states)
class ToyModelWithLayers(nn.Module):
def __init__(self):
super().__init__()
self.input_proj = nn.Linear(4, 4)
self.layers = nn.ModuleList([ToyLayer(layer_index) for layer_index in range(6)])
self.output_proj = nn.Linear(4, 2)
def forward(self, x):
x = self.input_proj(x)
for layer in self.layers:
x = layer(x)
return self.output_proj(x)
class TestModelWithLayers(unittest.TestCase):
def setUp(self):
self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))}
self.model_with_layers = ToyModelWithLayers()
self.dense_input = {"x": torch.randn(1, 4)}
def tearDown(self):
gc.collect()
def test_layer_pruning_behavior(self):
# No pruning: expect all 6 layers
with tempfile.TemporaryDirectory() as tmpdir:
with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=False):
_ = self.model_with_layers(**self.dense_input)
summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json")
with open(summary_path) as f:
data = json.load(f)
self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"})
for layer_index in range(6):
self.assertEqual(
data["children"][layer_index + 1]["module_path"],
f"ToyModelWithLayers.layers.{int(layer_index)}",
)
# Pruning: expect only 2 layers (0 and 5)
with tempfile.TemporaryDirectory() as tmpdir:
with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=True):
_ = self.model_with_layers(**self.dense_input)
summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json")
with open(summary_path) as f:
data = json.load(f)
self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"})
self.assertEqual(data["children"][1]["module_path"], "ToyModelWithLayers.layers.0")
self.assertEqual(data["children"][2]["module_path"], "ToyModelWithLayers.layers.5")

View File

@@ -0,0 +1,197 @@
# Copyright 2020 The Hugging Face Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import unittest
from dataclasses import dataclass
import pytest
from transformers import AlbertForMaskedLM
from transformers.testing_utils import require_torch
from transformers.utils import ModelOutput, is_torch_available
if is_torch_available():
import torch
@dataclass
class ModelOutputTest(ModelOutput):
a: float
b: float | None = None
c: float | None = None
class ModelOutputTester(unittest.TestCase):
def test_get_attributes(self):
x = ModelOutputTest(a=30)
self.assertEqual(x.a, 30)
self.assertIsNone(x.b)
self.assertIsNone(x.c)
with self.assertRaises(AttributeError):
_ = x.d
def test_index_with_ints_and_slices(self):
x = ModelOutputTest(a=30, b=10)
self.assertEqual(x[0], 30)
self.assertEqual(x[1], 10)
self.assertEqual(x[:2], (30, 10))
self.assertEqual(x[:], (30, 10))
x = ModelOutputTest(a=30, c=10)
self.assertEqual(x[0], 30)
self.assertEqual(x[1], 10)
self.assertEqual(x[:2], (30, 10))
self.assertEqual(x[:], (30, 10))
def test_index_with_strings(self):
x = ModelOutputTest(a=30, b=10)
self.assertEqual(x["a"], 30)
self.assertEqual(x["b"], 10)
with self.assertRaises(KeyError):
_ = x["c"]
x = ModelOutputTest(a=30, c=10)
self.assertEqual(x["a"], 30)
self.assertEqual(x["c"], 10)
with self.assertRaises(KeyError):
_ = x["b"]
def test_dict_like_properties(self):
x = ModelOutputTest(a=30)
self.assertEqual(list(x.keys()), ["a"])
self.assertEqual(list(x.values()), [30])
self.assertEqual(list(x.items()), [("a", 30)])
self.assertEqual(list(x), ["a"])
x = ModelOutputTest(a=30, b=10)
self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(list(x.values()), [30, 10])
self.assertEqual(list(x.items()), [("a", 30), ("b", 10)])
self.assertEqual(list(x), ["a", "b"])
x = ModelOutputTest(a=30, c=10)
self.assertEqual(list(x.keys()), ["a", "c"])
self.assertEqual(list(x.values()), [30, 10])
self.assertEqual(list(x.items()), [("a", 30), ("c", 10)])
self.assertEqual(list(x), ["a", "c"])
with self.assertRaises(Exception):
x = x.update({"d": 20})
with self.assertRaises(Exception):
del x["a"]
with self.assertRaises(Exception):
_ = x.pop("a")
with self.assertRaises(Exception):
_ = x.setdefault("d", 32)
def test_set_attributes(self):
x = ModelOutputTest(a=30)
x.a = 10
self.assertEqual(x.a, 10)
self.assertEqual(x["a"], 10)
def test_set_keys(self):
x = ModelOutputTest(a=30)
x["a"] = 10
self.assertEqual(x.a, 10)
self.assertEqual(x["a"], 10)
def test_instantiate_from_dict(self):
x = ModelOutputTest({"a": 30, "b": 10})
self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(x.a, 30)
self.assertEqual(x.b, 10)
def test_instantiate_from_iterator(self):
x = ModelOutputTest([("a", 30), ("b", 10)])
self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(x.a, 30)
self.assertEqual(x.b, 10)
with self.assertRaises(ValueError):
_ = ModelOutputTest([("a", 30), (10, 10)])
x = ModelOutputTest(a=(30, 30))
self.assertEqual(list(x.keys()), ["a"])
self.assertEqual(x.a, (30, 30))
@require_torch
def test_torch_pytree(self):
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
import torch.utils._pytree as pytree
x = ModelOutput({"a": 1.0, "c": 2.0})
self.assertFalse(pytree._is_leaf(x))
x = ModelOutputTest(a=1.0, c=2.0)
self.assertFalse(pytree._is_leaf(x))
expected_flat_outs = [1.0, 2.0]
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
self.assertEqual(expected_tree_spec, actual_tree_spec)
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": "[\\"a\\", \\"c\\"]", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)
# TODO: @ydshieh
@unittest.skip(reason="CPU OOM")
@require_torch
@pytest.mark.torch_export_test
def test_export_serialization(self):
model_cls = AlbertForMaskedLM
model_config = model_cls.config_class()
model = model_cls(model_config)
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
ep = torch.export.export(model, (), input_dict)
buffer = io.BytesIO()
torch.export.save(ep, buffer)
buffer.seek(0)
loaded_ep = torch.export.load(buffer)
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
a: float
b: float | None = None
c: float | None = None
class ModelOutputSubclassTester(unittest.TestCase):
def test_direct_model_output(self):
# Check that direct usage of ModelOutput instantiates without errors
ModelOutput({"a": 1.1})
def test_subclass_no_dataclass(self):
# Check that a subclass of ModelOutput without @dataclass is invalid
# A valid subclass is inherently tested other unit tests above.
with self.assertRaises(TypeError):
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)

View File

@@ -0,0 +1,708 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
from transformers import LlamaConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device
if is_torch_available():
import torch
from transformers import ROPE_INIT_FUNCTIONS
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
@require_torch
class RopeTest(unittest.TestCase):
def test_rope_validation(self):
config = LlamaConfig()
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
# The base config is always valid (default RoPE)
config.validate_rope()
# If we explicitly set the other (non-default) RoPE types with only rope_theta,
# validation should fail because required keys are missing (e.g. factor, short_factor)
for rope_type in all_rope_types:
if rope_type == "default":
continue # "default" is always valid with just rope_theta
# proportional is same as default wrt to expected keys
if rope_type == "proportional":
continue
config.rope_parameters = {"rope_type": rope_type, "rope_theta": 10000.0}
with self.assertRaises(KeyError):
config.validate_rope()
# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
valid_param_mapping = {
"factor": ["linear", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
"short_factor": ["longrope"],
"long_factor": ["longrope"],
}
for rope_type in all_rope_types:
if rope_type == "default":
continue # "default" only warns about unrecognised keys, never raises KeyError
# proportional is same as default wrt to expected keys
if rope_type == "proportional":
continue
for param, valid_rope_types in valid_param_mapping.items():
# Set `param` with a dummy value -- we want to test the dict key
config.rope_parameters = {"rope_type": rope_type, "rope_theta": 10000.0, param: True}
if rope_type in valid_rope_types:
continue
else:
with self.assertRaises(KeyError):
config.validate_rope()
# Any other parameters passed to RoPE will raise a warning that a particular key is not used
# But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys`
config.ignore_keys_at_rope_validation = {"mrope_sections"} # e,g in Qwen2-VL
config.rope_parameters = {"rope_type": "default", "rope_theta": 10000.0, "mrope_sections": True}
config.validate_rope()
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config.ignore_keys_at_rope_validation = set()
config.validate_rope()
self.assertEqual(len(logs.output), 1)
self.assertIn("mrope_sections", logs.output[0])
# We can indicate Different RoPE params for each attention type
# We can also have only one RoPE params defined for all layer, we don't raise an error
# because it is not required to have separate RoPE per layer type
config.layer_types = ["full_attention", "sliding_attention"]
config.rope_parameters = {
"full_attention": {"rope_type": "default", "rope_theta": 10000},
"sliding_attention": {"rope_type": "linear", "rope_theta": 10000, "factor": 2.0},
}
config.validate_rope()
config.rope_parameters = config.rope_parameters["full_attention"]
config.validate_rope()
def test_yarn_original_original_max_position_embeddings_validation(self):
"""Tests that models with no/bad `original_max_position_embeddings` raise a warning"""
config = LlamaConfig()
# good rope config: has a factor AND original_max_position_embeddings -> no warnings
rope_config = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": 2.0,
"original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
}
config.rope_parameters = rope_config
with self.assertRaises(AssertionError): # confirm that no warnings are thrown
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config.validate_rope()
# bad rope config, no `original_max_position_embeddings` -> raise error
rope_config = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": 2.0,
}
config.rope_parameters = rope_config
with self.assertRaises(KeyError):
config.validate_rope()
# bad rope config, bad implicit fator -> warning
rope_config = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": 2.0,
"original_max_position_embeddings": 1,
}
config.rope_parameters = rope_config
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config.validate_rope()
self.assertEqual(len(logs.output), 1)
self.assertIn("implicit factor", logs.output[0])
def test_convert_rope_params_to_dict_with_list_ignore_keys(self):
# Regression test for #46121: `ignore_keys_at_rope_validation` becomes a list when loaded from a config.json
# (JSON has no set type). `convert_rope_params_to_dict` used to do `list | set` and crash with
# TypeError when `partial_rotary_factor` was also set.
config = LlamaConfig(partial_rotary_factor=0.25)
config.ignore_keys_at_rope_validation = ["mrope_section", "mrope_interleaved"]
config.convert_rope_params_to_dict(partial_rotary_factor=0.25)
self.assertIsInstance(config.ignore_keys_at_rope_validation, set)
self.assertEqual(
config.ignore_keys_at_rope_validation,
{"mrope_section", "mrope_interleaved", "partial_rotary_factor"},
)
# Round-trip through from_dict to mimic the JSON-deserialized path that triggered this in production.
cfg_dict = config.to_dict()
cfg_dict["ignore_keys_at_rope_validation"] = ["mrope_section", "mrope_interleaved"]
reloaded = LlamaConfig.from_dict(cfg_dict)
reloaded.convert_rope_params_to_dict(partial_rotary_factor=0.25)
self.assertIsInstance(reloaded.ignore_keys_at_rope_validation, set)
# Also accept None (the class-level attribute can be cleared on an instance).
config_none = LlamaConfig(partial_rotary_factor=0.25)
config_none.ignore_keys_at_rope_validation = None
config_none.convert_rope_params_to_dict(partial_rotary_factor=0.25)
self.assertEqual(config_none.ignore_keys_at_rope_validation, {"partial_rotary_factor"})
def test_rope_validation_with_per_attention_type_nested_rope(self):
"""Mirrors `test_rope_validation` with `config.layer_types` set, so that
`rope_parameters` takes the per-attention-type nested shape."""
config = LlamaConfig()
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
config.layer_types = ["full_attention", "sliding_attention"]
def nest(full_attention_params):
return {
"full_attention": full_attention_params,
"sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
}
# Each non-default RoPE type with only `rope_theta` should still raise
# KeyError (missing required keys) when wrapped in the nested shape.
for rope_type in all_rope_types:
if rope_type in ("default", "proportional"):
continue
config.rope_parameters = nest({"rope_type": rope_type, "rope_theta": 10000.0})
with self.assertRaises(KeyError):
config.validate_rope()
# Parameters exclusive to a RoPE type should still raise when passed to
# the wrong type while in the nested shape.
valid_param_mapping = {
"factor": ["linear", "dynamic", "yarn", "longrope"],
"attention_factor": ["yarn", "longrope"],
"beta_fast": ["yarn"],
"beta_slow": ["yarn"],
"short_factor": ["longrope"],
"long_factor": ["longrope"],
}
for rope_type in all_rope_types:
if rope_type in ("default", "proportional"):
continue
for param, valid_rope_types in valid_param_mapping.items():
config.rope_parameters = nest({"rope_type": rope_type, "rope_theta": 10000.0, param: True})
if rope_type in valid_rope_types:
continue
with self.assertRaises(KeyError):
config.validate_rope()
# A complete yarn entry under the nested shape should validate cleanly.
# Regression: previously the implicit-factor check inside the yarn
# validator dereferenced `self.rope_parameters` (the full nested dict)
# rather than its per-type `rope_parameters` argument.
config.rope_parameters = nest(
{
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": 2.0,
"original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
}
)
config.validate_rope()
def test_default_rope_numerically(self):
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
# multiple RoPE strategies will fail.
# fmt: off
EXPECTED_INV_FREQ = torch.tensor(
[
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04
], device=torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for default RoPE
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
def test_linear_rope_numerically(self):
# This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
# frequencies (= the inverse frequencies are scaled **inversely**)
config = LlamaConfig()
default_rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
rope_fn = ROPE_INIT_FUNCTIONS["linear"]
for factor in (2.0, 10.0, 20.0):
config.rope_parameters = {"rope_type": "linear", "rope_theta": 10000.0, "factor": factor}
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for linear RoPE
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
def test_dynamic_rope_numerically(self):
# fmt: off
EXPECTED_INV_FREQ = torch.tensor(
[
1.0000e+00, 8.0931e-01, 6.5498e-01, 5.3008e-01, 4.2900e-01, 3.4720e-01,
2.8099e-01, 2.2741e-01, 1.8404e-01, 1.4895e-01, 1.2055e-01, 9.7558e-02,
7.8955e-02, 6.3899e-02, 5.1714e-02, 4.1853e-02, 3.3872e-02, 2.7413e-02,
2.2185e-02, 1.7955e-02, 1.4531e-02, 1.1760e-02, 9.5176e-03, 7.7027e-03,
6.2339e-03, 5.0451e-03, 4.0831e-03, 3.3045e-03, 2.6744e-03, 2.1644e-03,
1.7517e-03, 1.4176e-03, 1.1473e-03, 9.2852e-04, 7.5146e-04, 6.0817e-04,
4.9220e-04, 3.9834e-04, 3.2238e-04, 2.6091e-04, 2.1115e-04, 1.7089e-04,
1.3830e-04, 1.1193e-04, 9.0585e-05, 7.3312e-05, 5.9332e-05, 4.8018e-05,
3.8861e-05, 3.1451e-05, 2.5453e-05, 2.0600e-05, 1.6672e-05, 1.3492e-05,
1.0920e-05, 8.8374e-06, 7.1522e-06, 5.7883e-06, 4.6845e-06, 3.7912e-06,
3.0683e-06, 2.4832e-06, 2.0097e-06, 1.6265e-06
], device=torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
# Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
# model's original training sequence length
rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
for factor in (2.0, 10.0, 20.0):
config.rope_parameters = {"rope_type": "dynamic", "rope_theta": 10000.0, "factor": factor}
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for dynamic RoPE
torch.testing.assert_close(inv_freq, default_inv_freq)
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
torch.testing.assert_close(inv_freq, default_inv_freq)
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64))
torch.testing.assert_close(inv_freq, default_inv_freq)
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
# will scale up (i.e., the inverse frequencies will scale down).
factor = 10.0
config.rope_parameters = {"rope_type": "dynamic", "rope_theta": 10000.0, "factor": factor}
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=16384)
with self.assertRaises(AssertionError): # It is NOT a linear factor
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
def test_yarn_rope_numerically(self):
# fmt: off
EXPECTED_INV_FREQ = torch.tensor(
[
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.3479e-02,
6.9590e-02, 5.7925e-02, 4.8136e-02, 3.9931e-02, 3.3061e-02, 2.7315e-02,
2.2515e-02, 1.8512e-02, 1.5177e-02, 1.2403e-02, 1.0101e-02, 8.1924e-03,
6.6143e-03, 5.3120e-03, 4.2400e-03, 3.3599e-03, 2.6396e-03, 2.0520e-03,
1.5746e-03, 1.1882e-03, 8.7713e-04, 6.2810e-04, 4.3007e-04, 2.7384e-04,
2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
], device=torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
# `0.1 * math.log(factor) + 1.0`
rope_fn = ROPE_INIT_FUNCTIONS["yarn"]
for factor in (2.0, 10.0, 20.0):
config.rope_parameters = {"rope_type": "yarn", "rope_theta": 10000.0, "factor": factor}
_, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(attention_scale, 0.1 * math.log(factor) + 1.0)
config.rope_parameters = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": factor,
"attention_factor": 0.5,
}
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
self.assertEqual(attention_scale, 0.5)
# Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
# Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
# `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
# (note: adds a margin to the test for numerical stability)
factor = 10.0
margin = 1e-8
config.rope_parameters = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": factor,
"beta_fast": 32,
"beta_slow": 1,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
is_bounded_by_factor = [
((default_inv_freq[idx] / factor) - margin) <= yarn_inv_freq_value <= (default_inv_freq[idx] + margin)
for idx, yarn_inv_freq_value in enumerate(inv_freq)
]
self.assertTrue(all(is_bounded_by_factor))
# super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
# values (empirically checked for `beta_fast` = 1000) should be very small to linear scaling
config.rope_parameters = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": factor,
"beta_fast": 1000,
"beta_slow": 1,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
is_interpolating = [
yarn_inv_freq_value < (default_inv_freq[idx] + margin) for idx, yarn_inv_freq_value in enumerate(inv_freq)
]
self.assertFalse(is_interpolating[0])
self.assertTrue(all(is_interpolating[1:]))
torch.testing.assert_close(inv_freq[-20:], default_inv_freq[-20:] / factor)
# Check 3: numerical snapshot to avoid regressions
config.rope_parameters = {
"rope_type": "yarn",
"rope_theta": 10000.0,
"factor": factor,
"beta_fast": 32,
"beta_slow": 1,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
def test_longrope_rope_numerically(self):
# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on the seq_len
dim = config.hidden_size // config.num_attention_heads
short_factor = [2.0] * (dim // 2) # scaling applied when seq_len <= max_position_embeddings
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when seq_len > max_position_embeddings
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
# `math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))`
rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
for factor in (2.0, 10.0, 20.0):
config.rope_parameters = {
"rope_type": "longrope",
"rope_theta": 10000.0,
"factor": factor,
"short_factor": short_factor,
"long_factor": long_factor,
}
_, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(
attention_scale, math.sqrt(1 + math.log(factor) / math.log(config.max_position_embeddings))
)
config.rope_parameters = {
"rope_type": "longrope",
"rope_theta": 10000.0,
"factor": factor,
"short_factor": short_factor,
"long_factor": long_factor,
"attention_factor": 0.5,
}
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
self.assertEqual(attention_scale, 0.5)
config.rope_parameters = {
"rope_type": "longrope",
"rope_theta": 10000.0,
"factor": factor,
"short_factor": short_factor,
"long_factor": long_factor,
}
self.assertEqual(config.rope_parameters.get("attention_factor"), None)
# Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
config.standardize_rope_params()
config.validate_rope()
# Check 2: seq_len == 0 -> short factor is applied to the default frequencies
config.rope_parameters = {
"rope_type": "longrope",
"rope_theta": 10000.0,
"factor": 1.0,
"short_factor": short_factor,
"long_factor": long_factor,
}
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=0)
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))
# Check 3: seq_len > max_position_embeddings -> long factor is applied to the default frequencies
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=config.max_position_embeddings + 1)
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
def test_llama3_rope_numerically(self):
# fmt: off
EXPECTED_INV_FREQ = torch.tensor(
[
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
1.3335e-02, 1.0730e-02, 7.7785e-03, 5.6009e-03, 3.9991e-03, 2.8248e-03,
1.9675e-03, 1.3449e-03, 8.9549e-04, 5.7363e-04, 3.4539e-04, 2.7384e-04,
2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
], device=torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
# Check 1: `attention_factor` is always 1
rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
for factor in (2.0, 10.0, 20.0):
config.rope_parameters = {
"rope_type": "llama3",
"rope_theta": 10000.0,
"factor": factor,
"original_max_position_embeddings": 2048,
"low_freq_factor": 1,
"high_freq_factor": 4,
}
_, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(attention_scale, 1.0)
# Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
# `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequencies see no change, medium
# frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
# is considered low, medium, and high frequencies.
factor = 10.0
config.rope_parameters = {
"rope_type": "llama3",
"rope_theta": 10000.0,
"factor": factor,
"original_max_position_embeddings": 2048,
"low_freq_factor": 1,
"high_freq_factor": 4,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
is_bounded_by_factor = [
(default_inv_freq[idx] / factor) <= llama3_inv_freq_value <= default_inv_freq[idx]
for idx, llama3_inv_freq_value in enumerate(inv_freq)
]
self.assertTrue(all(is_bounded_by_factor))
# if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
# scaled
config.rope_parameters = config.rope_parameters = {
"rope_type": "llama3",
"rope_theta": 10000.0,
"factor": factor,
"original_max_position_embeddings": 2048,
"low_freq_factor": 1,
"high_freq_factor": 1000,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
is_scaled = [yarn_inv_freq_value < default_inv_freq[idx] for idx, yarn_inv_freq_value in enumerate(inv_freq)]
self.assertTrue(all(is_scaled))
# Check 3: numerical snapshot to avoid regressions
config.rope_parameters = {
"rope_type": "llama3",
"rope_theta": 10000.0,
"factor": factor,
"original_max_position_embeddings": 2048,
"low_freq_factor": 1,
"high_freq_factor": 4,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
def test_proportional_rope_numerically(self):
# fmt: off
EXPECTED_INV_FREQ = torch.tensor(
[
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00
], device=torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config = LlamaConfig()
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
self.assertEqual(config.hidden_size, 4096)
self.assertEqual(config.num_attention_heads, 32)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
head_dim = config.hidden_size // config.num_attention_heads # 128
rope_fn = ROPE_INIT_FUNCTIONS["proportional"]
default_rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
# Check 1: `attention_factor` is always 1.0, regardless of parameters
for partial_rotary_factor in (1.0, 0.5, 0.25):
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": partial_rotary_factor,
}
_, attention_scale = rope_fn(config=config, device=torch_device)
self.assertEqual(attention_scale, 1.0)
# Check 2: output shape is always head_dim // 2, regardless of partial_rotary_factor
for partial_rotary_factor in (1.0, 0.5, 0.25):
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": partial_rotary_factor,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
self.assertEqual(inv_freq.shape[0], head_dim // 2)
# Check 3: zero-padding behavior — when partial_rotary_factor < 1.0, the last (head_dim // 2 - rope_angles)
# entries must be exactly zero, and the first rope_angles entries must be non-zero
for partial_rotary_factor, expected_rope_angles in ((0.5, 32), (0.25, 16)):
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": partial_rotary_factor,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
# First rope_angles entries should be non-zero (rotated frequencies)
self.assertTrue(torch.all(inv_freq[:expected_rope_angles] != 0))
# Remaining entries should be exactly zero (NoPE angles)
expected_nope_angles = head_dim // 2 - expected_rope_angles
torch.testing.assert_close(
inv_freq[expected_rope_angles:],
torch.zeros(expected_nope_angles, device=torch_device),
)
# When partial_rotary_factor = 1.0, no entries should be zero
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": 1.0,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
self.assertTrue(torch.all(inv_freq != 0))
# Check 4: factor scaling equivalences with default and linear RoPE
# 4a: With partial_rotary_factor=1.0 and factor=1.0, proportional RoPE == default RoPE
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": 1.0,
"factor": 1.0,
}
inv_freq_prop, _ = rope_fn(config=config, device=torch_device)
config.rope_parameters = {"rope_type": "default", "rope_theta": 10000.0}
default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
torch.testing.assert_close(inv_freq_prop, default_inv_freq)
# 4b: With partial_rotary_factor=1.0 and factor=2.0, proportional RoPE == linear RoPE
linear_rope_fn = ROPE_INIT_FUNCTIONS["linear"]
for factor in (2.0, 10.0):
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": 1.0,
"factor": factor,
}
inv_freq_prop, _ = rope_fn(config=config, device=torch_device)
config.rope_parameters = {"rope_type": "linear", "rope_theta": 10000.0, "factor": factor}
inv_freq_linear, _ = linear_rope_fn(config=config, device=torch_device)
torch.testing.assert_close(inv_freq_prop, inv_freq_linear)
# 4c: With partial_rotary_factor=0.5 and factor=2.0, the non-zero portion should be the rotated subspace
# frequencies divided by factor
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": 0.5,
"factor": 2.0,
}
inv_freq_scaled, _ = rope_fn(config=config, device=torch_device)
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": 0.5,
"factor": 1.0,
}
inv_freq_unscaled, _ = rope_fn(config=config, device=torch_device)
torch.testing.assert_close(inv_freq_scaled, inv_freq_unscaled / 2.0)
# Check 5: numerical snapshot to avoid regressions (partial_rotary_factor=0.25, factor=1.0)
config.rope_parameters = {
"rope_type": "proportional",
"rope_theta": 10000.0,
"partial_rotary_factor": 0.25,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,82 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
import unittest
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import httpx
from transformers.utils.network_logging import (
_clear_network_debug_report,
_disable_network_debug_report,
_enable_network_debug_report,
_format_network_debug_report,
_get_network_debug_report,
)
class _SlowHandler(BaseHTTPRequestHandler):
def do_GET(self):
time.sleep(0.01)
response = b"ok"
self.send_response(200)
self.send_header("Content-Length", str(len(response)))
self.end_headers()
self.wfile.write(response)
def log_message(self, format, *args):
return
class NetworkLoggingTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._server = ThreadingHTTPServer(("127.0.0.1", 0), _SlowHandler)
cls._thread = threading.Thread(target=cls._server.serve_forever, daemon=True)
cls._thread.start()
cls._base_url = f"http://127.0.0.1:{cls._server.server_port}"
@classmethod
def tearDownClass(cls):
cls._server.shutdown()
cls._thread.join()
cls._server.server_close()
def tearDown(self):
_disable_network_debug_report()
def test_network_debug_report_records_httpx_requests(self):
_enable_network_debug_report()
_clear_network_debug_report()
response = httpx.get(f"{self._base_url}/slow")
self.assertEqual(response.text, "ok")
report = _get_network_debug_report()
matching_requests = [request for request in report["requests"] if request["url"].endswith("/slow")]
self.assertEqual(len(matching_requests), 1)
request = matching_requests[0]
self.assertEqual(request["method"], "GET")
self.assertEqual(request["status_code"], 200)
self.assertEqual(request["path"], "/slow")
self.assertGreater(request["total_ms"], 0)
self.assertIn("receive_response_headers", request["phases_ms"])
summary = _format_network_debug_report()
self.assertIn("Network debug report", summary)
self.assertIn("Slowest requests:", summary)
self.assertIn("/slow", summary)

220
tests/utils/test_offline.py Normal file
View File

@@ -0,0 +1,220 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import sys
import unittest
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
from transformers.testing_utils import TestCasePlus, require_torch
class OfflineTests(TestCasePlus):
@require_torch
@unittest.skip("This test is failing on main") # TODO matt/ydshieh, this test needs to be fixed
def test_offline_mode(self):
# this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
# `transformers` is loaded, and it's too late for inside pytest - so we are changing it
# while running an external program
# python one-liner segments
# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
"""
run = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
socket.socket = offline_socket
"""
# Force fetching the files so that we can use the cache
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipeline(task="fill-mask", model=mname)
# baseline - just load from_pretrained with normal network
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
stdout, _ = self._execute_with_env(load, run, mock, TRANSFORMERS_OFFLINE="1")
self.assertIn("success", stdout)
@require_torch
def test_offline_mode_no_internet(self):
# python one-liner segments
# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
"""
run = """
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipe = pipeline(task="fill-mask", model=mname)
print("success")
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
socket.socket = offline_socket
"""
# Force fetching the files so that we can use the cache
mname = "hf-internal-testing/tiny-random-bert"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
BertTokenizer.from_pretrained(mname)
pipeline(task="fill-mask", model=mname)
# baseline - just load from_pretrained with normal network
# should succeed
stdout, _ = self._execute_with_env(load, run, mock)
self.assertIn("success", stdout)
@require_torch
def test_offline_mode_sharded_checkpoint(self):
# this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
# `transformers` is loaded, and it's too late for inside pytest - so we are changing it
# while running an external program
# python one-liner segments
# this must be loaded before socket.socket is monkey-patched
load = """
from transformers import BertConfig, BertModel, BertTokenizer
"""
run = """
mname = "hf-internal-testing/tiny-random-bert-sharded"
BertConfig.from_pretrained(mname)
BertModel.from_pretrained(mname)
print("success")
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
socket.socket = offline_socket
"""
# baseline - just load from_pretrained with normal network
# should succeed
stdout, _ = self._execute_with_env(load, run)
self.assertIn("success", stdout)
# next emulate no network
# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
# self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="0")
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
stdout, _ = self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1")
self.assertIn("success", stdout)
@require_torch
def test_offline_mode_pipeline_exception(self):
load = """
from transformers import pipeline
"""
run = """
mname = "hf-internal-testing/tiny-random-bert"
pipe = pipeline(model=mname)
"""
mock = """
import socket
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
socket.socket = offline_socket
"""
_, stderr = self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="1")
self.assertIn(
"You cannot infer task automatically within `pipeline` when using offline mode",
stderr.replace("\n", ""),
)
@require_torch
def test_offline_model_dynamic_model(self):
load = """
from transformers import AutoModel
"""
run = """
mname = "hf-internal-testing/test_dynamic_model"
AutoModel.from_pretrained(mname, trust_remote_code=True)
print("success")
"""
# baseline - just load from_pretrained with normal network
# should succeed
stdout, _ = self._execute_with_env(load, run)
self.assertIn("success", stdout)
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
self.assertIn("success", stdout)
def test_is_offline_mode(self):
"""
Test `is_offline_mode` helper (should respect both HF_HUB_OFFLINE and legacy TRANSFORMERS_OFFLINE env vars)
"""
load = "from huggingface_hub import is_offline_mode"
run = "print(is_offline_mode())"
stdout, _ = self._execute_with_env(load, run)
self.assertIn("False", stdout)
stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
self.assertIn("True", stdout)
stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1")
self.assertIn("True", stdout)
def _execute_with_env(self, *commands: tuple[str, ...], should_fail: bool = False, **env) -> tuple[str, str]:
"""Execute Python code with a given environment and return the stdout/stderr as strings.
If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed.
Environment variables can be passed as keyword arguments.
"""
# Build command
cmd = [sys.executable, "-c", "\n".join(commands)]
# Configure env
new_env = self.get_env()
new_env.update(env)
# Run command
result = subprocess.run(cmd, env=new_env, check=False, capture_output=True)
# Check execution
if should_fail:
self.assertNotEqual(result.returncode, 0, result.stderr)
else:
self.assertEqual(result.returncode, 0, result.stderr)
# Return output
return result.stdout.decode(), result.stderr.decode()

View File

@@ -0,0 +1,124 @@
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
#
# this test validates that we can stack skip decorators in groups and whether
# they work correctly with other decorators
#
# since the decorators have already built their decision params (like checking
# env[], we can't mock the env and test each of the combinations), so ideally
# the following 4 should be run. But since we have different CI jobs running
# different configs, all combinations should get covered
#
# RUN_SLOW=1 pytest -rA tests/test_skip_decorators.py
# RUN_SLOW=1 CUDA_VISIBLE_DEVICES="" pytest -rA tests/test_skip_decorators.py
# RUN_SLOW=0 pytest -rA tests/test_skip_decorators.py
# RUN_SLOW=0 CUDA_VISIBLE_DEVICES="" pytest -rA tests/test_skip_decorators.py
import os
import unittest
import pytest
from parameterized import parameterized
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
# skipping in unittest tests
params = [(1,)]
# test that we can stack our skip decorators with 3rd party decorators
def check_slow():
run_slow = bool(os.getenv("RUN_SLOW", "0"))
if run_slow:
assert True
else:
assert False, "should have been skipped"
# test that we can stack our skip decorators
def check_slow_torch_cuda():
run_slow = bool(os.getenv("RUN_SLOW", "0"))
if run_slow and torch_device == "cuda":
assert True
else:
assert False, "should have been skipped"
def check_slow_torch_accelerator():
run_slow = bool(os.getenv("RUN_SLOW", "0"))
assert run_slow and torch_device in ["cuda", "xpu"], "should have been skipped"
@require_torch
class SkipTester(unittest.TestCase):
@slow
@require_torch_accelerator
def test_2_skips_slow_first(self):
check_slow_torch_accelerator()
@require_torch_accelerator
@slow
def test_2_skips_slow_last(self):
check_slow_torch_accelerator()
# The combination of any skip decorator, followed by parameterized fails to skip the tests
# 1. @slow manages to correctly skip `test_param_slow_first`
# 2. but then `parameterized` creates new tests, with a unique name for each parameter groups.
# It has no idea that they are to be skipped and so they all run, ignoring @slow
# Therefore skip decorators must come after `parameterized`
#
# @slow
# @parameterized.expand(params)
# def test_param_slow_first(self, param=None):
# check_slow()
# This works as expected:
# 1. `parameterized` creates new tests with unique names
# 2. each of them gets an opportunity to be skipped
@parameterized.expand(params)
@slow
def test_param_slow_last(self, param=None):
check_slow()
# skipping in non-unittest tests
# no problem at all here
@slow
@require_torch_accelerator
def test_pytest_2_skips_slow_first():
check_slow_torch_accelerator()
@require_torch_accelerator
@slow
def test_pytest_2_skips_slow_last():
check_slow_torch_accelerator()
@slow
@pytest.mark.parametrize("param", [1])
def test_pytest_param_slow_first(param):
check_slow()
@pytest.mark.parametrize("param", [1])
@slow
def test_pytest_param_slow_last(param):
check_slow()

View File

@@ -0,0 +1,308 @@
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path
import httpx
from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast, GPT2TokenizerFast, is_tokenizers_available
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_tokenizers
from transformers.tokenization_python import ExtensionsTrie, Trie
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
from test_module.custom_tokenization import CustomTokenizer # noqa E402
if is_tokenizers_available():
from test_module.custom_tokenization_fast import CustomTokenizerFast
class TokenizerUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
"failed", request=mock.Mock(), response=mock.Mock()
)
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
@require_tokenizers
def test_cached_files_are_used_when_internet_is_down_missing_files(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
"failed", request=mock.Mock(), response=mock.Mock()
)
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
_ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
_ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
# This check we did call the fake head request
mock_head.assert_called()
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
@classmethod
def setUpClass(cls):
cls._token = TOKEN
def test_push_to_hub(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_chat_templates(self):
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.chat_template = "test template"
with TemporaryHubRepo(token=self._token) as tmp_repo:
tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
with TemporaryHubRepo(token=self._token) as tmp_repo:
tokenizer.chat_template = {"default": "a", "secondary": "b"}
tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
def test_push_to_hub_via_save_pretrained(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
# Push to hub via save_pretrained
tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_in_organization(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
# Push to hub via save_pretrained
tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
CustomTokenizer.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
with TemporaryHubRepo(token=self._token) as tmp_repo:
CustomTokenizer.register_for_auto_class()
# Fast and slow custom tokenizer
CustomTokenizerFast.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
bert_tokenizer.save_pretrained(tmp_dir)
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, use_fast=False, trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
@require_tokenizers
class TokenizersBackendTest(unittest.TestCase):
def test_clean_up_tokenization_spaces(self):
tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
# GPT-2 is a BPE tokenizer — clean_up_tokenization is skipped because it
# was designed for WordPiece and is destructive for BPE (strips legitimate
# spaces before punctuation).
# Use text with spaces before punctuation that cleanup would strip if applied.
text = "x != y"
token_ids = tokenizer.encode(text)
decoded_no_cleanup = tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
self.assertEqual(decoded_no_cleanup, text)
# With BPE guard, cleanup=True also preserves the text
decoded_with_cleanup = tokenizer.decode(token_ids, clean_up_tokenization_spaces=True)
self.assertEqual(decoded_with_cleanup, text)
class TrieTest(unittest.TestCase):
def test_trie(self):
trie = Trie()
trie.add("Hello 友達")
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"": {"": {"": 1}}}}}}}}})
trie.add("Hello")
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"": {"": {"": 1}}}}}}}}})
def test_trie_split(self):
trie = Trie()
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
trie.add("[CLS]")
trie.add("extra_id_1")
trie.add("extra_id_100")
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
def test_trie_single(self):
trie = Trie()
trie.add("A")
self.assertEqual(trie.split("ABC"), ["A", "BC"])
self.assertEqual(trie.split("BCA"), ["BC", "A"])
def test_trie_final(self):
trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
def test_trie_subtokens(self):
trie = Trie()
trie.add("A")
trie.add("P")
trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
def test_trie_suffix_tokens(self):
trie = Trie()
trie.add("AB")
trie.add("B")
trie.add("C")
self.assertEqual(trie.split("ABC"), ["AB", "C"])
def test_trie_skip(self):
trie = Trie()
trie.add("ABC")
trie.add("B")
trie.add("CD")
self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
def test_cut_text_hardening(self):
# Even if the offsets are wrong, we necessarily output correct string
# parts.
trie = Trie()
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
self.assertEqual(parts, ["AB", "C"])
class ExtensionsTrieTest(unittest.TestCase):
def test_extensions(self):
# Test searching by prefix
trie = ExtensionsTrie()
trie.add("foo")
trie.add("food")
trie.add("foodie")
trie.add("helium")
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
self.assertEqual(trie.extensions("helium"), ["helium"])
def test_empty_prefix(self):
trie = ExtensionsTrie()
# Test searching with an empty prefix returns all values
trie.add("hello")
trie.add("bye")
self.assertEqual(trie.extensions(""), ["hello", "bye"])
def test_no_extension_match(self):
trie = ExtensionsTrie()
# Test searching for a prefix that doesn't match any key
values = trie.extensions("unknown")
self.assertEqual(len(values), 0)
def test_update_value(self):
trie = ExtensionsTrie()
# Test updating the value of an existing key
trie.add("hi")
trie.add("hi")
self.assertEqual(trie.extensions("hi"), ["hi"])

View File

@@ -0,0 +1,97 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import sys
from transformers.testing_utils import TestCasePlus
from transformers.utils.versions import require_version, require_version_core
numpy_ver = importlib.metadata.version("numpy")
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
class DependencyVersionCheckTest(TestCasePlus):
def test_core(self):
# lt + different version strings
require_version_core("numpy<1000.4.5")
require_version_core("numpy<1000.4")
require_version_core("numpy<1000")
# le
require_version_core("numpy<=1000.4.5")
require_version_core(f"numpy<={numpy_ver}")
# eq
require_version_core(f"numpy=={numpy_ver}")
# ne
require_version_core("numpy!=1000.4.5")
# ge
require_version_core("numpy>=1.0")
require_version_core("numpy>=1.0.0")
require_version_core(f"numpy>={numpy_ver}")
# gt
require_version_core("numpy>1.0.0")
# mix
require_version_core("numpy>1.0.0,<1000")
# requirement w/o version
require_version_core("numpy")
# unmet requirements due to version conflict
for req in ["numpy==1.0.0", "numpy>=1000.0.0", f"numpy<{numpy_ver}"]:
try:
require_version_core(req)
except ImportError as e:
self.assertIn(f"{req} is required", str(e))
self.assertIn("but found", str(e))
# unmet requirements due to missing module
for req in ["numpipypie>1", "numpipypie2"]:
try:
require_version_core(req)
except importlib.metadata.PackageNotFoundError as e:
self.assertIn(f"The '{req}' distribution was not found and is required by this application", str(e))
self.assertIn("Try: `pip install transformers -U`", str(e))
# bogus requirements formats:
# 1. whole thing
for req in ["numpy??1.0.0", "numpy1.0.0"]:
try:
require_version_core(req)
except ValueError as e:
self.assertIn("requirement needs to be in the pip package format", str(e))
# 2. only operators
for req in ["numpy=1.0.0", "numpy == 1.00", "numpy<>1.0.0", "numpy><1.00", "numpy>>1.0.0"]:
try:
require_version_core(req)
except ValueError as e:
self.assertIn("need one of ", str(e))
def test_python(self):
# matching requirement
require_version("python>=3.9.0")
# not matching requirements
for req in ["python>9.9.9", "python<3.0.0"]:
try:
require_version_core(req)
except ImportError as e:
self.assertIn(f"{req} is required", str(e))
self.assertIn(f"but found python=={python_ver}", str(e))

View File

@@ -0,0 +1,365 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import is_torch_available, is_vision_available
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import SizeDict
from transformers.processing_utils import VideosKwargs
from transformers.testing_utils import (
require_av,
require_cv2,
require_decord,
require_torch,
require_torchcodec,
require_torchvision,
require_vision,
)
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
if is_torch_available():
import torch
if is_vision_available():
import PIL
from transformers import BaseVideoProcessor
from transformers.video_utils import VideoMetadata, load_video
def get_random_video(height, width, num_frames=8, return_torch=False):
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
video = np.array([random_frame] * num_frames)
if return_torch:
# move channel first
return torch.from_numpy(video).permute(0, 3, 1, 2)
return video
@require_vision
@require_torchvision
class BaseVideoProcessorTester(unittest.TestCase):
"""
Tests that the `transforms` can be applied to a 4-dim array directly, i.e. to a whole video.
"""
def test_make_batched_videos_pil(self):
# Test a single image is converted to a list of 1 video with 1 frame
video = get_random_video(16, 32)
pil_image = PIL.Image.fromarray(video[0])
videos_list = make_batched_videos(pil_image)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0][0], np.array(pil_image)))
# Test a list of videos is converted to a list of 1 video
video = get_random_video(16, 32)
pil_video = [PIL.Image.fromarray(frame) for frame in video]
videos_list = make_batched_videos(pil_video)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
# Test a nested list of videos is not modified
video = get_random_video(16, 32)
pil_video = [PIL.Image.fromarray(frame) for frame in video]
videos = [pil_video, pil_video]
videos_list = make_batched_videos(videos)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
def test_make_batched_videos_numpy(self):
# Test a single image is converted to a list of 1 video with 1 frame
video = get_random_video(16, 32)[0]
videos_list = make_batched_videos(video)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0][0], video))
# Test a 4d array of videos is converted to a list of 1 video
video = get_random_video(16, 32)
videos_list = make_batched_videos(video)
self.assertIsInstance(videos_list, list)
self.assertTrue(len(videos_list), 1)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
# Test a 5d array of batch videos is converted to a list of videos
video = video[None, ...].repeat(4, 0)
videos_list = make_batched_videos(video)
self.assertIsInstance(videos_list, list)
self.assertTrue(len(videos_list), 4)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video[0]))
# Test a list of videos is converted to a list of videos
video = get_random_video(16, 32)
videos = [video, video]
videos_list = make_batched_videos(videos)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], np.ndarray)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
@require_torch
def test_make_batched_videos_torch(self):
# Test a single image is converted to a list of 1 video with 1 frame
video = get_random_video(16, 32)[0]
torch_video = torch.from_numpy(video)
videos_list = make_batched_videos(torch_video)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], torch.Tensor)
self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0][0], video))
# Test a 4d array of videos is converted to a list of 1 video
video = get_random_video(16, 32)
torch_video = torch.from_numpy(video)
videos_list = make_batched_videos(torch_video)
self.assertIsInstance(videos_list, list)
self.assertTrue(len(videos_list), 1)
self.assertIsInstance(videos_list[0], torch.Tensor)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
# Test a 5d array of batch videos is converted to a list of videos
torch_video = torch_video[None, ...].repeat(4, 1, 1, 1, 1)
videos_list = make_batched_videos(torch_video)
self.assertIsInstance(videos_list, list)
self.assertTrue(len(videos_list), 4)
self.assertIsInstance(videos_list[0], torch.Tensor)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
# Test a list of videos is converted to a list of videos
video = get_random_video(16, 32)
torch_video = torch.from_numpy(video)
videos = [torch_video, torch_video]
videos_list = make_batched_videos(videos)
self.assertIsInstance(videos_list, list)
self.assertIsInstance(videos_list[0], torch.Tensor)
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
self.assertTrue(np.array_equal(videos_list[0], video))
def test_resize(self):
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
video = get_random_video(16, 32, return_torch=True)
# Size can be an int or a tuple of ints.
size_dict = SizeDict(**get_size_dict((8, 8), param_name="size"))
resized_video = video_processor.resize(video, size=size_dict)
self.assertIsInstance(resized_video, torch.Tensor)
self.assertEqual(resized_video.shape, (8, 3, 8, 8))
def test_normalize(self):
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
array = torch.randn(4, 3, 16, 32)
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# mean and std can be passed as lists or NumPy arrays.
expected = (array - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
normalized_array = video_processor.normalize(array, mean, std)
torch.testing.assert_close(normalized_array, expected)
def test_center_crop(self):
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
video = get_random_video(16, 32, return_torch=True)
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
size_dict = SizeDict(**get_size_dict(size, default_to_square=True, param_name="crop_size"))
cropped_video = video_processor.center_crop(video, size_dict)
self.assertIsInstance(cropped_video, torch.Tensor)
expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_video.shape, (8, 3, *expected_size))
def test_convert_to_rgb(self):
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
video = get_random_video(20, 20, return_torch=True)
rgb_video = video_processor.convert_to_rgb(video[:, :1])
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
def test_group_and_reorder_videos(self):
"""Tests that videos can be grouped by frame size and number of frames"""
video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
# Group two videos of same size but different number of frames
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group two videos of different size but same number of frames
video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group all three videos where some have same size or same frame count
# But since none have frames and sizes identical, we'll have 3 groups
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
self.assertEqual(len(grouped_videos), 3)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 3)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group if we had some videos with identical shapes
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group if we had all videos with identical shapes
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
self.assertEqual(len(grouped_videos), 1)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 1)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
@require_vision
@require_av
class LoadVideoTester(unittest.TestCase):
def test_load_video_url(self):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
)
self.assertEqual(video.shape, (243, 360, 640, 3)) # 243 frames is the whole video, no sampling applied
def test_load_video_local(self):
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
video, _ = load_video(video_file_path)
self.assertEqual(video.shape, (243, 360, 640, 3)) # 243 frames is the whole video, no sampling applied
# FIXME: @raushan, yt-dlp downloading works for for some reason it cannot redirect to out buffer?
# @requires_yt_dlp
# def test_load_video_youtube(self):
# video = load_video("https://www.youtube.com/watch?v=QC8iQqtG0hg")
# self.assertEqual(video.shape, (243, 360, 640, 3)) # 243 frames is the whole video, no sampling applied
@require_decord
@require_torchvision
@require_torchcodec
@require_cv2
def test_load_video_backend_url(self):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
backend="decord",
)
self.assertEqual(video.shape, (243, 360, 640, 3))
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
backend="torchcodec",
)
self.assertEqual(video.shape, (243, 360, 640, 3))
# Can't use certain backends with url
with self.assertRaises(ValueError):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
backend="opencv",
)
with self.assertRaises(ValueError):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
backend="torchvision",
)
@require_decord
@require_torchvision
@require_torchcodec
@require_cv2
def test_load_video_backend_local(self):
video_file_path = hf_hub_download(
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
)
video, metadata = load_video(video_file_path, backend="decord")
self.assertEqual(video.shape, (243, 360, 640, 3))
self.assertIsInstance(metadata, VideoMetadata)
video, metadata = load_video(video_file_path, backend="opencv")
self.assertEqual(video.shape, (243, 360, 640, 3))
self.assertIsInstance(metadata, VideoMetadata)
video, metadata = load_video(video_file_path, backend="torchvision")
self.assertEqual(video.shape, (243, 360, 640, 3))
self.assertIsInstance(metadata, VideoMetadata)
video, metadata = load_video(video_file_path, backend="torchcodec")
self.assertEqual(video.shape, (243, 360, 640, 3))
self.assertIsInstance(metadata, VideoMetadata)
def test_load_video_num_frames(self):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
num_frames=16,
)
self.assertEqual(video.shape, (16, 360, 640, 3))
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
num_frames=22,
)
self.assertEqual(video.shape, (22, 360, 640, 3))
def test_load_video_fps(self):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", fps=1
)
self.assertEqual(video.shape, (9, 360, 640, 3))
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", fps=2
)
self.assertEqual(video.shape, (19, 360, 640, 3))
# `num_frames` is mutually exclusive with `video_fps`
with self.assertRaises(ValueError):
video, _ = load_video(
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
fps=1,
num_frames=10,
)

File diff suppressed because it is too large Load Diff