first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
0
tests/utils/__init__.py
Normal file
0
tests/utils/__init__.py
Normal file
23
tests/utils/import_structures/failing_export.py
Normal file
23
tests/utils/import_structures/failing_export.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
|
||||
@requires(backends=("random_item_that_should_not_exist",))
|
||||
class A0:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
|
||||
@requires()
|
||||
class A0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires()
|
||||
def a0():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch",))
|
||||
class A1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch",))
|
||||
def a1():
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=("torch",)
|
||||
)
|
||||
class A2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=("torch",)
|
||||
)
|
||||
def a2():
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=(
|
||||
"torch",
|
||||
)
|
||||
)
|
||||
class A3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=(
|
||||
"torch",
|
||||
)
|
||||
)
|
||||
def a3():
|
||||
pass
|
||||
|
||||
@requires(backends=())
|
||||
class A4:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
|
||||
@requires(backends=("torch>=2.5",))
|
||||
class D0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>=2.5",))
|
||||
def d0():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>2.5",))
|
||||
class D1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>2.5",))
|
||||
def d1():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch<=2.5",))
|
||||
class D2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch<=2.5",))
|
||||
def d2():
|
||||
pass
|
||||
|
||||
@requires(backends=("torch<2.5",))
|
||||
class D3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch<2.5",))
|
||||
def d3():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch==2.5",))
|
||||
class D4:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch==2.5",))
|
||||
def d4():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch!=2.5",))
|
||||
class D5:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch!=2.5",))
|
||||
def d5():
|
||||
pass
|
||||
|
||||
@requires(backends=("torch>=2.5", "accelerate<0.20"))
|
||||
class D6:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch>=2.5", "accelerate<0.20"))
|
||||
def d6():
|
||||
pass
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
|
||||
@requires()
|
||||
# That's a statement
|
||||
class B0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires()
|
||||
# That's a statement
|
||||
def b0():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch",))
|
||||
# That's a statement
|
||||
class B1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch",))
|
||||
# That's a statement
|
||||
def b1():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch",))
|
||||
# That's a statement
|
||||
class B2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch",))
|
||||
# That's a statement
|
||||
def b2():
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=(
|
||||
"torch",
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
class B3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=(
|
||||
"torch",
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
def b3():
|
||||
pass
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# fmt: off
|
||||
|
||||
from transformers.utils.import_utils import requires
|
||||
|
||||
|
||||
@requires(backends=("torch", "torch"))
|
||||
class C0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch", "torch"))
|
||||
def c0():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
class C1:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
def c1():
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
class C2:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(backends=("torch", "torch"))
|
||||
# That's a statement
|
||||
def c2():
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=(
|
||||
"torch",
|
||||
"torch"
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
class C3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@requires(
|
||||
backends=(
|
||||
"torch",
|
||||
"torch"
|
||||
)
|
||||
)
|
||||
# That's a statement
|
||||
def c3():
|
||||
pass
|
||||
74
tests/utils/test_activations.py
Normal file
74
tests/utils/test_activations.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.activations import gelu_new, gelu_python, get_activation
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestActivations(unittest.TestCase):
|
||||
def test_gelu_versions(self):
|
||||
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
|
||||
torch_builtin = get_activation("gelu")
|
||||
torch.testing.assert_close(gelu_python(x), torch_builtin(x))
|
||||
self.assertFalse(torch.allclose(gelu_python(x), gelu_new(x)))
|
||||
|
||||
def test_gelu_10(self):
|
||||
x = torch.tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
|
||||
torch_builtin = get_activation("gelu")
|
||||
gelu10 = get_activation("gelu_10")
|
||||
|
||||
y_gelu = torch_builtin(x)
|
||||
y_gelu_10 = gelu10(x)
|
||||
|
||||
clipped_mask = torch.where(y_gelu_10 < 10.0, 1, 0)
|
||||
|
||||
self.assertTrue(torch.max(y_gelu_10).item() == 10.0)
|
||||
torch.testing.assert_close(y_gelu * clipped_mask, y_gelu_10 * clipped_mask)
|
||||
|
||||
def test_get_activation(self):
|
||||
get_activation("gelu")
|
||||
get_activation("gelu_10")
|
||||
get_activation("gelu_fast")
|
||||
get_activation("gelu_new")
|
||||
get_activation("gelu_python")
|
||||
get_activation("gelu_pytorch_tanh")
|
||||
get_activation("linear")
|
||||
get_activation("mish")
|
||||
get_activation("quick_gelu")
|
||||
get_activation("relu")
|
||||
get_activation("sigmoid")
|
||||
get_activation("silu")
|
||||
get_activation("swish")
|
||||
get_activation("tanh")
|
||||
with self.assertRaises(KeyError):
|
||||
get_activation("bogus")
|
||||
with self.assertRaises(KeyError):
|
||||
get_activation(None)
|
||||
|
||||
def test_activations_are_distinct_objects(self):
|
||||
act1 = get_activation("gelu")
|
||||
act1.a = 1
|
||||
act2 = get_activation("gelu")
|
||||
self.assertEqual(act1.a, 1)
|
||||
with self.assertRaises(AttributeError):
|
||||
_ = act2.a
|
||||
806
tests/utils/test_add_new_model_like.py
Normal file
806
tests/utils/test_add_new_model_like.py
Normal file
@@ -0,0 +1,806 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import textwrap
|
||||
import unittest
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.cli.add_new_model_like import ModelInfos, _add_new_model_like_internal
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
REPO_PATH = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
MODELS_TO_COPY = ("auto", "llama", "phi4_multimodal")
|
||||
CURRENT_YEAR = date.today().year
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestAddNewModelLike(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Create a temporary repo with the same structure as Transformers, with just 2 models.
|
||||
"""
|
||||
cls.tmp_dir = tempfile.TemporaryDirectory()
|
||||
cls.FAKE_REPO = cls.tmp_dir.name
|
||||
os.makedirs(os.path.join(cls.FAKE_REPO, "src", "transformers", "models"), exist_ok=True)
|
||||
os.makedirs(os.path.join(cls.FAKE_REPO, "tests", "models"), exist_ok=True)
|
||||
os.makedirs(os.path.join(cls.FAKE_REPO, "docs", "source", "en", "model_doc"), exist_ok=True)
|
||||
|
||||
# We need to copy the utils to run the cleanup commands
|
||||
utils_src = os.path.join(REPO_PATH, "utils")
|
||||
shutil.copytree(utils_src, utils_src.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
# Copy the __init__ files
|
||||
model_init = os.path.join(REPO_PATH, "src", "transformers", "models", "__init__.py")
|
||||
shutil.copy(model_init, model_init.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
doc_toc = os.path.join(REPO_PATH, "docs", "source", "en", "_toctree.yml")
|
||||
shutil.copy(doc_toc, doc_toc.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
# We need the pyproject for ruff as well
|
||||
pyproject = os.path.join(REPO_PATH, "pyproject.toml")
|
||||
shutil.copy(pyproject, pyproject.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
# Copy over all the specific model files
|
||||
for model in MODELS_TO_COPY:
|
||||
model_src = os.path.join(REPO_PATH, "src", "transformers", "models", model)
|
||||
shutil.copytree(model_src, model_src.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
|
||||
test_src = os.path.join(REPO_PATH, "tests", "models", model)
|
||||
shutil.copytree(test_src, test_src.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
|
||||
if model != "auto":
|
||||
doc_src = os.path.join(REPO_PATH, "docs", "source", "en", "model_doc", f"{model}.md")
|
||||
shutil.copy(doc_src, doc_src.replace(REPO_PATH, cls.FAKE_REPO))
|
||||
|
||||
# For convenience
|
||||
cls.MODEL_PATH = os.path.join(cls.FAKE_REPO, "src", "transformers", "models")
|
||||
cls.TESTS_MODEL_PATH = os.path.join(cls.FAKE_REPO, "tests", "models")
|
||||
cls.DOC_PATH = os.path.join(cls.FAKE_REPO, "docs", "source", "en")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.tmp_dir.cleanup()
|
||||
|
||||
def assertFileIsEqual(self, text: str, filepath: str):
|
||||
with open(filepath, "r") as f:
|
||||
file_text = f.read()
|
||||
self.assertEqual(file_text.strip(), text.strip())
|
||||
|
||||
def assertInFile(self, text: str, filepath: str):
|
||||
with open(filepath, "r") as f:
|
||||
file_text = f.read()
|
||||
self.assertTrue(text in file_text)
|
||||
|
||||
def test_llama_without_tokenizers(self):
|
||||
# This is the structure without adding the tokenizers
|
||||
filenames_to_add = (
|
||||
("configuration_llama.py", True),
|
||||
("modeling_llama.py", True),
|
||||
("tokenization_llama.py", False),
|
||||
("tokenization_llama_fast.py", False),
|
||||
("image_processing_llama_pil.py", False),
|
||||
("image_processing_llama.py", False),
|
||||
("video_processing_llama.py", False),
|
||||
("feature_extraction_llama.py", False),
|
||||
("processing_llama.py", False),
|
||||
)
|
||||
# Run the command
|
||||
_add_new_model_like_internal(
|
||||
repo_path=Path(self.FAKE_REPO),
|
||||
old_model_infos=ModelInfos("llama"),
|
||||
new_lowercase_name="my_test",
|
||||
new_model_paper_name="MyTest",
|
||||
filenames_to_add=filenames_to_add,
|
||||
)
|
||||
|
||||
# First assert that all files were created correctly
|
||||
model_repo = os.path.join(self.MODEL_PATH, "my_test")
|
||||
tests_repo = os.path.join(self.TESTS_MODEL_PATH, "my_test")
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modular_my_test.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modeling_my_test.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "configuration_my_test.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "__init__.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(self.DOC_PATH, "model_doc", "my_test.md")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "__init__.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_modeling_my_test.py")))
|
||||
|
||||
# Now assert the correct imports/auto mappings/toctree were added
|
||||
self.assertInFile(
|
||||
"from .my_test import *\n",
|
||||
os.path.join(self.MODEL_PATH, "__init__.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test", "MyTestConfig"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "auto_mappings.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test", "MyTestModel"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test", "MyTestForCausalLM"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test", "MyTestForSequenceClassification"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test", "MyTestForQuestionAnswering"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test", "MyTestForTokenClassification"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
"- local: model_doc/my_test\n title: MyTest\n",
|
||||
os.path.join(self.DOC_PATH, "_toctree.yml"),
|
||||
)
|
||||
|
||||
# Check some exact file creation. For model definition, only check modular as modeling/config/etc... are created
|
||||
# directly from it
|
||||
EXPECTED_MODULAR = textwrap.dedent(
|
||||
f"""
|
||||
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..llama.configuration_llama import LlamaConfig
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class MyTestConfig(LlamaConfig):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestRMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestMLP(LlamaMLP):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestAttention(LlamaAttention):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestDecoderLayer(LlamaDecoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestPreTrainedModel(LlamaPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestModel(LlamaModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestForCausalLM(LlamaForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestForQuestionAnswering(LlamaForQuestionAnswering):
|
||||
pass
|
||||
|
||||
|
||||
class MyTestForTokenClassification(LlamaForTokenClassification):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MyTestConfig",
|
||||
"MyTestForCausalLM",
|
||||
"MyTestModel",
|
||||
"MyTestPreTrainedModel",
|
||||
"MyTestForSequenceClassification",
|
||||
"MyTestForQuestionAnswering",
|
||||
"MyTestForTokenClassification",
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.assertFileIsEqual(EXPECTED_MODULAR, os.path.join(model_repo, "modular_my_test.py"))
|
||||
|
||||
EXPECTED_INIT = textwrap.dedent(
|
||||
f"""
|
||||
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_my_test import *
|
||||
from .modeling_my_test import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
|
||||
"""
|
||||
)
|
||||
self.assertFileIsEqual(EXPECTED_INIT, os.path.join(model_repo, "__init__.py"))
|
||||
|
||||
EXPECTED_DOC = textwrap.dedent(
|
||||
f"""
|
||||
<!--Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# MyTest
|
||||
|
||||
## Overview
|
||||
|
||||
The MyTest model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
<INSERT PAPER ABSTRACT HERE>
|
||||
|
||||
Tips:
|
||||
|
||||
<INSERT TIPS ABOUT MODEL HERE>
|
||||
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
## Usage examples
|
||||
|
||||
<INSERT SOME NICE EXAMPLES HERE>
|
||||
|
||||
## MyTestConfig
|
||||
|
||||
[[autodoc]] MyTestConfig
|
||||
|
||||
## MyTestForCausalLM
|
||||
|
||||
[[autodoc]] MyTestForCausalLM
|
||||
|
||||
## MyTestModel
|
||||
|
||||
[[autodoc]] MyTestModel
|
||||
- forward
|
||||
|
||||
## MyTestPreTrainedModel
|
||||
|
||||
[[autodoc]] MyTestPreTrainedModel
|
||||
- forward
|
||||
|
||||
## MyTestForSequenceClassification
|
||||
|
||||
[[autodoc]] MyTestForSequenceClassification
|
||||
|
||||
## MyTestForQuestionAnswering
|
||||
|
||||
[[autodoc]] MyTestForQuestionAnswering
|
||||
|
||||
## MyTestForTokenClassification
|
||||
|
||||
[[autodoc]] MyTestForTokenClassification
|
||||
"""
|
||||
)
|
||||
self.assertFileIsEqual(EXPECTED_DOC, os.path.join(self.DOC_PATH, "model_doc", "my_test.md"))
|
||||
|
||||
def test_phi4_with_all_processors(self):
|
||||
# This is the structure without adding the tokenizers
|
||||
filenames_to_add = (
|
||||
("configuration_phi4_multimodal.py", True),
|
||||
("modeling_phi4_multimodal.py", True),
|
||||
("tokenization_phi4_multimodal.py", False),
|
||||
("tokenization_phi4_multimodal_fast.py", False),
|
||||
("image_processing_phi4_multimodal_pil.py", False),
|
||||
("image_processing_phi4_multimodal.py", True),
|
||||
("video_processing_phi4_multimodal.py", False),
|
||||
("feature_extraction_phi4_multimodal.py", True),
|
||||
("processing_phi4_multimodal.py", True),
|
||||
)
|
||||
# Run the command
|
||||
_add_new_model_like_internal(
|
||||
repo_path=Path(self.FAKE_REPO),
|
||||
old_model_infos=ModelInfos("phi4_multimodal"),
|
||||
new_lowercase_name="my_test2",
|
||||
new_model_paper_name="MyTest2",
|
||||
filenames_to_add=filenames_to_add,
|
||||
)
|
||||
|
||||
# First assert that all files were created correctly
|
||||
model_repo = os.path.join(self.MODEL_PATH, "my_test2")
|
||||
tests_repo = os.path.join(self.TESTS_MODEL_PATH, "my_test2")
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modular_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "modeling_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "configuration_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "image_processing_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "feature_extraction_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "processing_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(model_repo, "__init__.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(self.DOC_PATH, "model_doc", "my_test2.md")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "__init__.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_modeling_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_feature_extraction_my_test2.py")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tests_repo, "test_image_processing_my_test2.py")))
|
||||
|
||||
# Now assert the correct imports/auto mappings/toctree were added
|
||||
self.assertInFile(
|
||||
"from .my_test2 import *\n",
|
||||
os.path.join(self.MODEL_PATH, "__init__.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test2", "MyTest2Config"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "auto_mappings.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test2", "MyTest2Model"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test2", "MyTest2ForCausalLM"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "modeling_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test2", {"torchvision": "MyTest2ImageProcessor"}),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "auto_mappings.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test2", "MyTest2FeatureExtractor"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "feature_extraction_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
'("my_test2", "MyTest2Processor"),\n',
|
||||
os.path.join(self.MODEL_PATH, "auto", "processing_auto.py"),
|
||||
)
|
||||
self.assertInFile(
|
||||
"- local: model_doc/my_test2\n title: MyTest2\n",
|
||||
os.path.join(self.DOC_PATH, "_toctree.yml"),
|
||||
)
|
||||
|
||||
# Check some exact file creation. For model definition, only check modular as modeling/config/etc... are created
|
||||
# directly from it
|
||||
EXPECTED_MODULAR = textwrap.dedent(
|
||||
f"""
|
||||
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..phi4_multimodal.configuration_phi4_multimodal import (
|
||||
Phi4MultimodalAudioConfig,
|
||||
Phi4MultimodalConfig,
|
||||
Phi4MultimodalVisionConfig,
|
||||
)
|
||||
from ..phi4_multimodal.feature_extraction_phi4_multimodal import Phi4MultimodalFeatureExtractor
|
||||
from ..phi4_multimodal.image_processing_phi4_multimodal import (
|
||||
Phi4MultimodalImageProcessor,
|
||||
Phi4MultimodalImageProcessorKwargs,
|
||||
)
|
||||
from ..phi4_multimodal.modeling_phi4_multimodal import (
|
||||
Phi4MultimodalAttention,
|
||||
Phi4MultimodalAudioAttention,
|
||||
Phi4MultimodalAudioConformerEncoderLayer,
|
||||
Phi4MultimodalAudioConvModule,
|
||||
Phi4MultimodalAudioDepthWiseSeparableConv1d,
|
||||
Phi4MultimodalAudioEmbedding,
|
||||
Phi4MultimodalAudioGluPointWiseConv,
|
||||
Phi4MultimodalAudioMeanVarianceNormLayer,
|
||||
Phi4MultimodalAudioMLP,
|
||||
Phi4MultimodalAudioModel,
|
||||
Phi4MultimodalAudioNemoConvSubsampling,
|
||||
Phi4MultimodalAudioPreTrainedModel,
|
||||
Phi4MultimodalAudioRelativeAttentionBias,
|
||||
Phi4MultimodalDecoderLayer,
|
||||
Phi4MultimodalFeatureEmbedding,
|
||||
Phi4MultimodalForCausalLM,
|
||||
Phi4MultimodalImageEmbedding,
|
||||
Phi4MultimodalMLP,
|
||||
Phi4MultimodalModel,
|
||||
Phi4MultimodalPreTrainedModel,
|
||||
Phi4MultimodalRMSNorm,
|
||||
Phi4MultimodalRotaryEmbedding,
|
||||
Phi4MultimodalVisionAttention,
|
||||
Phi4MultimodalVisionEmbeddings,
|
||||
Phi4MultimodalVisionEncoder,
|
||||
Phi4MultimodalVisionEncoderLayer,
|
||||
Phi4MultimodalVisionMLP,
|
||||
Phi4MultimodalVisionModel,
|
||||
Phi4MultimodalVisionMultiheadAttentionPoolingHead,
|
||||
Phi4MultimodalVisionPreTrainedModel,
|
||||
)
|
||||
from ..phi4_multimodal.processing_phi4_multimodal import Phi4MultimodalProcessor, Phi4MultimodalProcessorKwargs
|
||||
|
||||
|
||||
class MyTest2VisionConfig(Phi4MultimodalVisionConfig):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioConfig(Phi4MultimodalAudioConfig):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2Config(Phi4MultimodalConfig):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionMLP(Phi4MultimodalVisionMLP):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionAttention(Phi4MultimodalVisionAttention):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionEncoderLayer(Phi4MultimodalVisionEncoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionEncoder(Phi4MultimodalVisionEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionPreTrainedModel(Phi4MultimodalVisionPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionEmbeddings(Phi4MultimodalVisionEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionMultiheadAttentionPoolingHead(Phi4MultimodalVisionMultiheadAttentionPoolingHead):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2VisionModel(Phi4MultimodalVisionModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2ImageEmbedding(Phi4MultimodalImageEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioMLP(Phi4MultimodalAudioMLP):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioAttention(Phi4MultimodalAudioAttention):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioDepthWiseSeparableConv1d(Phi4MultimodalAudioDepthWiseSeparableConv1d):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioGluPointWiseConv(Phi4MultimodalAudioGluPointWiseConv):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioConvModule(Phi4MultimodalAudioConvModule):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioConformerEncoderLayer(Phi4MultimodalAudioConformerEncoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioNemoConvSubsampling(Phi4MultimodalAudioNemoConvSubsampling):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioRelativeAttentionBias(Phi4MultimodalAudioRelativeAttentionBias):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioMeanVarianceNormLayer(Phi4MultimodalAudioMeanVarianceNormLayer):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioPreTrainedModel(Phi4MultimodalAudioPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioModel(Phi4MultimodalAudioModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2AudioEmbedding(Phi4MultimodalAudioEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2RMSNorm(Phi4MultimodalRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2MLP(Phi4MultimodalMLP):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2Attention(Phi4MultimodalAttention):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2DecoderLayer(Phi4MultimodalDecoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2FeatureEmbedding(Phi4MultimodalFeatureEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2PreTrainedModel(Phi4MultimodalPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2RotaryEmbedding(Phi4MultimodalRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2Model(Phi4MultimodalModel):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2ForCausalLM(Phi4MultimodalForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2ImageProcessorKwargs(Phi4MultimodalImageProcessorKwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2ImageProcessor(Phi4MultimodalImageProcessor):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2FeatureExtractor(Phi4MultimodalFeatureExtractor):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2ProcessorKwargs(Phi4MultimodalProcessorKwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MyTest2Processor(Phi4MultimodalProcessor):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MyTest2VisionConfig",
|
||||
"MyTest2AudioConfig",
|
||||
"MyTest2Config",
|
||||
"MyTest2AudioPreTrainedModel",
|
||||
"MyTest2AudioModel",
|
||||
"MyTest2VisionPreTrainedModel",
|
||||
"MyTest2VisionModel",
|
||||
"MyTest2PreTrainedModel",
|
||||
"MyTest2Model",
|
||||
"MyTest2ForCausalLM",
|
||||
"MyTest2ImageProcessor",
|
||||
"MyTest2FeatureExtractor",
|
||||
"MyTest2Processor",
|
||||
]
|
||||
"""
|
||||
)
|
||||
self.assertFileIsEqual(EXPECTED_MODULAR, os.path.join(model_repo, "modular_my_test2.py"))
|
||||
|
||||
EXPECTED_INIT = textwrap.dedent(
|
||||
f"""
|
||||
# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_my_test2 import *
|
||||
from .feature_extraction_my_test2 import *
|
||||
from .image_processing_my_test2 import *
|
||||
from .modeling_my_test2 import *
|
||||
from .processing_my_test2 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||
"""
|
||||
)
|
||||
self.assertFileIsEqual(EXPECTED_INIT, os.path.join(model_repo, "__init__.py"))
|
||||
|
||||
EXPECTED_DOC = textwrap.dedent(
|
||||
f"""
|
||||
<!--Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
|
||||
# MyTest2
|
||||
|
||||
## Overview
|
||||
|
||||
The MyTest2 model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
||||
<INSERT SHORT SUMMARY HERE>
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
<INSERT PAPER ABSTRACT HERE>
|
||||
|
||||
Tips:
|
||||
|
||||
<INSERT TIPS ABOUT MODEL HERE>
|
||||
|
||||
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
||||
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
||||
|
||||
## Usage examples
|
||||
|
||||
<INSERT SOME NICE EXAMPLES HERE>
|
||||
|
||||
## MyTest2VisionConfig
|
||||
|
||||
[[autodoc]] MyTest2VisionConfig
|
||||
|
||||
## MyTest2AudioConfig
|
||||
|
||||
[[autodoc]] MyTest2AudioConfig
|
||||
|
||||
## MyTest2Config
|
||||
|
||||
[[autodoc]] MyTest2Config
|
||||
|
||||
## MyTest2AudioPreTrainedModel
|
||||
|
||||
[[autodoc]] MyTest2AudioPreTrainedModel
|
||||
- forward
|
||||
|
||||
## MyTest2AudioModel
|
||||
|
||||
[[autodoc]] MyTest2AudioModel
|
||||
- forward
|
||||
|
||||
## MyTest2VisionPreTrainedModel
|
||||
|
||||
[[autodoc]] MyTest2VisionPreTrainedModel
|
||||
- forward
|
||||
|
||||
## MyTest2VisionModel
|
||||
|
||||
[[autodoc]] MyTest2VisionModel
|
||||
- forward
|
||||
|
||||
## MyTest2PreTrainedModel
|
||||
|
||||
[[autodoc]] MyTest2PreTrainedModel
|
||||
- forward
|
||||
|
||||
## MyTest2Model
|
||||
|
||||
[[autodoc]] MyTest2Model
|
||||
- forward
|
||||
|
||||
## MyTest2ForCausalLM
|
||||
|
||||
[[autodoc]] MyTest2ForCausalLM
|
||||
|
||||
## MyTest2ImageProcessor
|
||||
|
||||
[[autodoc]] MyTest2ImageProcessor
|
||||
|
||||
## MyTest2FeatureExtractor
|
||||
|
||||
[[autodoc]] MyTest2FeatureExtractor
|
||||
|
||||
## MyTest2Processor
|
||||
|
||||
[[autodoc]] MyTest2Processor
|
||||
"""
|
||||
)
|
||||
self.assertFileIsEqual(EXPECTED_DOC, os.path.join(self.DOC_PATH, "model_doc", "my_test2.md"))
|
||||
125
tests/utils/test_attention_visualizer.py
Normal file
125
tests/utils/test_attention_visualizer.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import io
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
|
||||
|
||||
|
||||
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
|
||||
|
||||
|
||||
def _normalize(s: str) -> str:
|
||||
# drop ANSI (colors may be disabled on CI), normalize line endings,
|
||||
# and strip trailing spaces without touching alignment inside lines
|
||||
s = ANSI_RE.sub("", s)
|
||||
s = s.replace("\r\n", "\n").replace("\r", "\n")
|
||||
return "\n".join(line.rstrip() for line in s.split("\n")).strip()
|
||||
|
||||
|
||||
@require_torch
|
||||
class AttentionMaskVisualizerTester(unittest.TestCase):
|
||||
"""Test suite for AttentionMaskVisualizer"""
|
||||
|
||||
def test_paligemma_multimodal_visualization(self):
|
||||
"""Test AttentionMaskVisualizer with PaliGemma multimodal model"""
|
||||
model_name = "hf-internal-testing/namespace_google_repo_name_paligemma-3b-pt-224"
|
||||
input_text = "<img> What is in this image?"
|
||||
|
||||
buf = io.StringIO()
|
||||
orig_print = builtins.print
|
||||
|
||||
def _print(*args, **kwargs):
|
||||
kwargs.setdefault("file", buf)
|
||||
orig_print(*args, **kwargs)
|
||||
|
||||
try:
|
||||
builtins.print = _print
|
||||
visualizer = AttentionMaskVisualizer(model_name)
|
||||
visualizer(input_text)
|
||||
finally:
|
||||
builtins.print = orig_print
|
||||
output = buf.getvalue()
|
||||
|
||||
expected_output = """
|
||||
##########################################################################################################################################################################################################################################
|
||||
## Attention visualization for \033[1mpaligemma:hf-internal-testing/namespace_google_repo_name_paligemma-3b-pt-224\033[0m PaliGemmaModel ##
|
||||
##########################################################################################################################################################################################################################################
|
||||
\033[92m■\033[0m: i == j (diagonal) \033[93m■\033[0m: token_type_ids
|
||||
Attention Matrix
|
||||
|
||||
|
||||
\033[93m'<image>'\033[0m: 0 \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
\033[93m'<image>'\033[0m: 1 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
\033[93m'<image>'\033[0m: 2 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
\033[93m'<image>'\033[0m: 3 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
\033[93m'<image>'\033[0m: 4 \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m \033[93m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'<bos>' : 5 ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁What' : 6 ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁is' : 7 ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁in' : 8 ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁this' : 9 ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁image' : 10 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ |
|
||||
'?' : 11 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ |
|
||||
'\\n' : 12 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ |
|
||||
'<eos>' : 13 ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m |
|
||||
##########################################################################################################################################################################################################################################
|
||||
""" # noqa
|
||||
|
||||
self.assertEqual(_normalize(output), _normalize(expected_output))
|
||||
|
||||
def test_llama_text_only_visualization(self):
|
||||
"""Test AttentionMaskVisualizer with Llama text-only model"""
|
||||
model_name = "hf-internal-testing/namespace_meta-llama_repo_name_Llama-2-7b-hf"
|
||||
input_text = "Plants create energy through a process known as"
|
||||
|
||||
buf = io.StringIO()
|
||||
orig_print = builtins.print
|
||||
|
||||
def _print(*args, **kwargs):
|
||||
kwargs.setdefault("file", buf)
|
||||
orig_print(*args, **kwargs)
|
||||
|
||||
try:
|
||||
builtins.print = _print
|
||||
visualizer = AttentionMaskVisualizer(model_name)
|
||||
visualizer(input_text)
|
||||
finally:
|
||||
builtins.print = orig_print
|
||||
output = buf.getvalue()
|
||||
|
||||
expected_output = """
|
||||
##########################################################################################################################################################################################################
|
||||
## Attention visualization for \033[1mllama:hf-internal-testing/namespace_meta-llama_repo_name_Llama-2-7b-hf\033[0m LlamaModel ##
|
||||
##########################################################################################################################################################################################################
|
||||
\033[92m■\033[0m: i == j (diagonal) \033[93m■\033[0m: token_type_ids
|
||||
Attention Matrix
|
||||
|
||||
'▁Pl' : 0 \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'ants' : 1 ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁create' : 2 ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁energy' : 3 ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁through': 4 ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ ⬚ |
|
||||
'▁a' : 5 ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ ⬚ |
|
||||
'▁process': 6 ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ ⬚ |
|
||||
'▁known' : 7 ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m ⬚ |
|
||||
'▁as' : 8 ■ ■ ■ ■ ■ ■ ■ ■ \033[92m■\033[0m |
|
||||
##########################################################################################################################################################################################################
|
||||
""" # noqa
|
||||
|
||||
self.assertEqual(_normalize(output), _normalize(expected_output))
|
||||
1751
tests/utils/test_audio_utils.py
Normal file
1751
tests/utils/test_audio_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
761
tests/utils/test_auto_docstring.py
Normal file
761
tests/utils/test_auto_docstring.py
Normal file
@@ -0,0 +1,761 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Tests for auto_docstring decorator and check_auto_docstrings function.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import statistics
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.image_processing_backends import TorchvisionBackend
|
||||
from transformers.image_processing_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.utils.auto_docstring import (
|
||||
auto_docstring,
|
||||
)
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
sys.path.insert(0, str(_repo_root / "utils"))
|
||||
|
||||
from check_docstrings import ( # noqa: E402
|
||||
_build_ast_indexes,
|
||||
_find_typed_dict_classes,
|
||||
find_files_with_auto_docstring,
|
||||
update_file_with_new_docstrings,
|
||||
)
|
||||
|
||||
|
||||
class TestCheckDocstrings(unittest.TestCase):
|
||||
"""Test check_auto_docstrings static analysis tool for detecting and fixing docstring issues."""
|
||||
|
||||
def test_missing_args_detection_and_placeholder_generation(self):
|
||||
"""Test that missing custom args are detected and placeholders generated while preserving Examples and code."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = os.path.join(tmpdir, "model.py")
|
||||
|
||||
original = textwrap.dedent("""
|
||||
from transformers.utils.auto_docstring import auto_docstring
|
||||
|
||||
@auto_docstring
|
||||
def forward(self, input_ids, custom_temperature: float = 1.0):
|
||||
'''
|
||||
Example:
|
||||
```python
|
||||
>>> model.forward(input_ids, custom_temperature=0.7)
|
||||
```
|
||||
'''
|
||||
result = input_ids * custom_temperature
|
||||
return result
|
||||
""")
|
||||
|
||||
with open(test_file, "w") as f:
|
||||
f.write(original)
|
||||
|
||||
with open(test_file, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
items = _build_ast_indexes(content)
|
||||
lines = content.split("\n")
|
||||
|
||||
# Test detection (overwrite=False) - should detect missing arg
|
||||
missing, fill, redundant = update_file_with_new_docstrings(
|
||||
test_file, lines, items, content, overwrite=False
|
||||
)
|
||||
self.assertTrue(any("custom_temperature" in msg for msg in missing))
|
||||
|
||||
# Generate placeholders (overwrite=True)
|
||||
update_file_with_new_docstrings(test_file, lines, items, content, overwrite=True)
|
||||
|
||||
with open(test_file, "r") as f:
|
||||
updated = f.read()
|
||||
|
||||
# Verify results
|
||||
self.assertIn("custom_temperature", updated)
|
||||
self.assertIn("<fill_docstring>", updated) # Placeholder added
|
||||
self.assertIn("input_ids", updated) # Standard arg from ModelArgs
|
||||
self.assertIn("Example:", updated) # Example preserved
|
||||
self.assertIn("result = input_ids * custom_temperature", updated) # Code preserved
|
||||
|
||||
def test_multi_item_file_processing(self):
|
||||
"""Test processing files with multiple @auto_docstring decorators (class + method) in a single pass."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_file = os.path.join(tmpdir, "modeling.py")
|
||||
|
||||
original = textwrap.dedent("""
|
||||
from transformers.utils.auto_docstring import auto_docstring
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
@auto_docstring
|
||||
class MyModel(PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.layer = None
|
||||
|
||||
@auto_docstring
|
||||
def forward(self, input_ids, scale_factor: float = 1.0):
|
||||
'''
|
||||
Example:
|
||||
```python
|
||||
>>> outputs = model.forward(input_ids, scale_factor=2.0)
|
||||
```
|
||||
'''
|
||||
return self.layer(input_ids) * scale_factor
|
||||
""")
|
||||
|
||||
with open(test_file, "w") as f:
|
||||
f.write(original)
|
||||
|
||||
with open(test_file, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
items = _build_ast_indexes(content)
|
||||
|
||||
# Should find 2 decorated items
|
||||
self.assertEqual(len(items), 2)
|
||||
self.assertEqual(items[0].kind, "class")
|
||||
self.assertEqual(items[1].kind, "function")
|
||||
|
||||
lines = content.split("\n")
|
||||
|
||||
# Detect issues
|
||||
missing, fill, redundant = update_file_with_new_docstrings(
|
||||
test_file, lines, items, content, overwrite=False
|
||||
)
|
||||
|
||||
# Should detect missing scale_factor in forward method
|
||||
self.assertTrue(any("scale_factor" in msg for msg in missing))
|
||||
|
||||
# Update file
|
||||
update_file_with_new_docstrings(test_file, lines, items, content, overwrite=True)
|
||||
|
||||
with open(test_file, "r") as f:
|
||||
updated = f.read()
|
||||
|
||||
# Verify updates and preservation
|
||||
self.assertIn("scale_factor", updated) # Custom arg added with placeholder
|
||||
self.assertIn("<fill_docstring>", updated) # Placeholder present
|
||||
self.assertIn("Example:", updated) # Example preserved
|
||||
self.assertIn("self.layer = None", updated) # __init__ code preserved
|
||||
self.assertIn("return self.layer(input_ids) * scale_factor", updated) # forward code preserved
|
||||
|
||||
def test_typed_dict_field_detection(self):
|
||||
"""Test that _find_typed_dict_classes correctly identifies custom fields vs standard inherited fields."""
|
||||
content = textwrap.dedent("""
|
||||
from typing import TypedDict
|
||||
from transformers.processing_utils import ImagesKwargs
|
||||
|
||||
class CustomImageKwargs(ImagesKwargs, total=False):
|
||||
'''
|
||||
custom_mode (`str`):
|
||||
Custom processing mode.
|
||||
'''
|
||||
# Standard field from ImagesKwargs - should be in all_fields but not fields
|
||||
do_resize: bool
|
||||
# Custom fields - should be in both all_fields and fields
|
||||
custom_mode: str
|
||||
undocumented_custom: int
|
||||
""")
|
||||
|
||||
typed_dicts = _find_typed_dict_classes(content)
|
||||
|
||||
# Should find the TypedDict
|
||||
self.assertEqual(len(typed_dicts), 1)
|
||||
self.assertEqual(typed_dicts[0]["name"], "CustomImageKwargs")
|
||||
|
||||
# all_fields includes everything
|
||||
self.assertIn("do_resize", typed_dicts[0]["all_fields"])
|
||||
self.assertIn("custom_mode", typed_dicts[0]["all_fields"])
|
||||
self.assertIn("undocumented_custom", typed_dicts[0]["all_fields"])
|
||||
|
||||
# fields only includes custom fields (not standard args like do_resize)
|
||||
# Both documented and undocumented custom fields are included
|
||||
self.assertIn("custom_mode", typed_dicts[0]["fields"])
|
||||
self.assertIn("undocumented_custom", typed_dicts[0]["fields"])
|
||||
self.assertNotIn("do_resize", typed_dicts[0]["fields"]) # Standard arg excluded
|
||||
|
||||
def test_file_discovery_finds_decorated_files(self):
|
||||
"""Test that check_auto_docstrings can discover files containing @auto_docstring."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
has_decorator = os.path.join(tmpdir, "modeling.py")
|
||||
no_decorator = os.path.join(tmpdir, "utils.py")
|
||||
|
||||
with open(has_decorator, "w") as f:
|
||||
f.write("@auto_docstring\ndef forward(self): pass")
|
||||
|
||||
with open(no_decorator, "w") as f:
|
||||
f.write("def helper(): pass")
|
||||
|
||||
found = find_files_with_auto_docstring([has_decorator, no_decorator])
|
||||
|
||||
self.assertEqual(len(found), 1)
|
||||
self.assertEqual(found[0], has_decorator)
|
||||
|
||||
|
||||
class DummyConfig(PretrainedConfig):
|
||||
model_type = "dummy_test"
|
||||
|
||||
def __init__(self, vocab_size=1000, hidden_size=768, num_attention_heads=12, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DummyForTestModel(PreTrainedModel):
|
||||
config_class = DummyConfig
|
||||
|
||||
def __init__(self, config: DummyConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
temperature: float = 1.0,
|
||||
custom_dict: dict[str, int | float] | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
temperature (`float`, *optional*, defaults to 1.0):
|
||||
Temperature value for scaling logits during generation.
|
||||
custom_dict (`dict[str, Union[int, float]]`, *optional*):
|
||||
Custom dictionary parameter with string keys and numeric values.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, DummyForTestModel
|
||||
>>> import torch
|
||||
|
||||
>>> model = DummyForTestModel.from_pretrained("dummy-model")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("dummy-model")
|
||||
>>> inputs = tokenizer("Hello world", return_tensors="pt")
|
||||
>>> outputs = model.forward(**inputs, temperature=0.7)
|
||||
>>> logits = outputs.logits
|
||||
```
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ComplexProcessorKwargs(ProcessingKwargs, total=False):
|
||||
r"""
|
||||
custom_processing_mode (`str`, *optional*, defaults to `"standard"`):
|
||||
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
|
||||
enable_advanced_features (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable advanced processing features like custom tokenization strategies.
|
||||
custom_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Custom threshold value for filtering or processing decisions.
|
||||
output_format (`str`, *optional*, defaults to `"default"`):
|
||||
Output format specification. Can be 'default', 'extended', or 'minimal'.
|
||||
"""
|
||||
|
||||
custom_processing_mode: str
|
||||
enable_advanced_features: bool
|
||||
custom_threshold: float
|
||||
output_format: str
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DummyProcessorForTest(ProcessorMixin):
|
||||
def __init__(
|
||||
self,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
custom_processing_mode="standard",
|
||||
enable_advanced_features=False,
|
||||
custom_threshold=0.5,
|
||||
output_format="default",
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
custom_processing_mode (`str`, *optional*, defaults to `"standard"`):
|
||||
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
|
||||
enable_advanced_features (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable advanced processing features like custom tokenization strategies.
|
||||
custom_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Custom threshold value for filtering or processing decisions.
|
||||
output_format (`str`, *optional*, defaults to `"default"`):
|
||||
Output format specification. Can be 'default', 'extended', or 'minimal'.
|
||||
"""
|
||||
pass
|
||||
|
||||
@auto_docstring
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput | None = None,
|
||||
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
|
||||
**kwargs: Unpack[ComplexProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DummyProcessorForTest
|
||||
>>> processor = DummyProcessorForTest.from_pretrained("dummy-processor")
|
||||
>>> inputs = processor(text="Hello world", images=["image.jpg"], return_tensors="pt")
|
||||
```
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DummyImageProcessorKwargs(ImagesKwargs, total=False):
|
||||
r"""
|
||||
image_grid_pinpoints (`list[list[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method.
|
||||
custom_scale (`float`, *optional*, defaults to 255.0):
|
||||
Custom scale factor for preprocessing pipelines.
|
||||
"""
|
||||
|
||||
image_grid_pinpoints: list[list[int]]
|
||||
custom_scale: float
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Constructs a fast DummyForTest image processor.
|
||||
"""
|
||||
)
|
||||
class DummyForTestImageProcessorFast(TorchvisionBackend):
|
||||
model_input_names = ["pixel_values"]
|
||||
valid_kwargs = DummyImageProcessorKwargs
|
||||
|
||||
def __init__(self, **kwargs: Unpack[DummyImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
**kwargs: Unpack[DummyImageProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DummyForTestImageProcessorFast
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> processor = DummyForTestImageProcessorFast.from_pretrained("dummy-processor")
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> inputs = processor.preprocess(images=image, return_tensors="pt")
|
||||
```
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@strict
|
||||
@auto_docstring(custom_intro="A minimal configuration for testing that config docstrings only show own args.")
|
||||
class DummyStrictConfig(PretrainedConfig):
|
||||
r"""
|
||||
extra_param (`str`, *optional*, defaults to `"test"`):
|
||||
A custom parameter unique to this config, not inherited from PreTrainedConfig.
|
||||
"""
|
||||
|
||||
model_type = "dummy_strict_for_docstring_test"
|
||||
|
||||
hidden_size: int = 256
|
||||
num_layers: int = 6
|
||||
extra_param: str = "test"
|
||||
|
||||
|
||||
@require_torch
|
||||
class TestFullDocstringGeneration(unittest.TestCase):
|
||||
"""
|
||||
End-to-end tests for @auto_docstring runtime docstring generation.
|
||||
|
||||
Tests validate complete docstrings with single assertEqual assertions to ensure structure,
|
||||
formatting, standard args, custom params, and TypedDict unrolling work correctly.
|
||||
"""
|
||||
|
||||
def test_strict_config_docstring_only_documents_own_args(self):
|
||||
"""Test that config docstrings only document own class annotations, not inherited PreTrainedConfig args."""
|
||||
self.maxDiff = None
|
||||
actual_docstring = DummyStrictConfig.__doc__
|
||||
expected_docstring = """A minimal configuration for testing that config docstrings only show own args.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to `256`):
|
||||
Dimension of the hidden representations.
|
||||
num_layers (`int`, *optional*, defaults to `6`):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
extra_param (`str`, *optional*, defaults to `"test"`):
|
||||
A custom parameter unique to this config, not inherited from PreTrainedConfig.
|
||||
"""
|
||||
self.assertEqual(actual_docstring, expected_docstring)
|
||||
|
||||
def test_dummy_model_complete_docstring(self):
|
||||
self.maxDiff = None
|
||||
"""Test complete class and forward method docstrings for PreTrainedModel with ModelArgs and custom parameters."""
|
||||
actual_class_docstring = DummyForTestModel.__doc__
|
||||
expected_class_docstring = """
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`DummyConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
self.assertEqual(actual_class_docstring, expected_class_docstring)
|
||||
|
||||
actual_docstring = DummyForTestModel.forward.__doc__
|
||||
|
||||
expected_docstring = """ The [`DummyForTestModel`] forward method, overrides the `__call__` special method.
|
||||
|
||||
<Tip>
|
||||
|
||||
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
|
||||
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
|
||||
the latter silently ignores them.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
temperature (`float`, *optional*, defaults to 1.0):
|
||||
Temperature value for scaling logits during generation.
|
||||
custom_dict (`dict[str, Union[int, float]]`, *optional*):
|
||||
Custom dictionary parameter with string keys and numeric values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~modeling_outputs.CausalLMOutputWithPast`] or `tuple(torch.FloatTensor)`: A [`~modeling_outputs.CausalLMOutputWithPast`] or a tuple of
|
||||
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
|
||||
elements depending on the configuration ([`None`]) and inputs.
|
||||
|
||||
- **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) -- Language modeling loss (for next-token prediction).
|
||||
- **logits** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) -- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
- **past_key_values** (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
||||
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, DummyForTestModel
|
||||
>>> import torch
|
||||
|
||||
>>> model = DummyForTestModel.from_pretrained("dummy-model")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("dummy-model")
|
||||
>>> inputs = tokenizer("Hello world", return_tensors="pt")
|
||||
>>> outputs = model.forward(**inputs, temperature=0.7)
|
||||
>>> logits = outputs.logits
|
||||
```
|
||||
"""
|
||||
|
||||
self.assertEqual(actual_docstring, expected_docstring)
|
||||
|
||||
def test_dummy_processor_complete_docstring(self):
|
||||
self.maxDiff = None
|
||||
"""Test complete class and __call__ docstrings for ProcessorMixin with complex TypedDict kwargs unrolling."""
|
||||
|
||||
actual_docstring = DummyProcessorForTest.__call__.__doc__
|
||||
|
||||
expected_docstring = """ Args:
|
||||
images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[PIL.Image.Image], list[numpy.ndarray], list[torch.Tensor]]`, *optional*):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
text (`Union[str, list[str], list[list[str]]]`, *optional*):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If you pass a pretokenized input, set `is_split_into_words=True` to avoid ambiguity with batched inputs.
|
||||
custom_processing_mode (`str`, *kwargs*, *optional*, defaults to `"standard"`):
|
||||
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
|
||||
enable_advanced_features (`bool`, *kwargs*, *optional*, defaults to `False`):
|
||||
Whether to enable advanced processing features like custom tokenization strategies.
|
||||
custom_threshold (`float`, *kwargs*, *optional*, defaults to 0.5):
|
||||
Custom threshold value for filtering or processing decisions.
|
||||
output_format (`str`, *kwargs*, *optional*, defaults to `"default"`):
|
||||
Output format specification. Can be 'default', 'extended', or 'minimal'.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
**kwargs ([`ProcessingKwargs`], *optional*):
|
||||
Additional processing options for each modality (text, images, videos, audio). Model-specific parameters
|
||||
are listed above; see the TypedDict class for the complete list of supported arguments.
|
||||
|
||||
Returns:
|
||||
`~image_processing_base.BatchFeature`:
|
||||
- **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
|
||||
- **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
|
||||
initialization.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DummyProcessorForTest
|
||||
>>> processor = DummyProcessorForTest.from_pretrained("dummy-processor")
|
||||
>>> inputs = processor(text="Hello world", images=["image.jpg"], return_tensors="pt")
|
||||
```
|
||||
"""
|
||||
|
||||
self.assertEqual(actual_docstring, expected_docstring)
|
||||
|
||||
actual_class_docstring = DummyProcessorForTest.__doc__
|
||||
|
||||
expected_class_docstring = """Constructs a DummyProcessorForTest which wraps a image processor and a tokenizer into a single processor.
|
||||
|
||||
[`DummyProcessorForTest`] offers all the functionalities of [`image_processor_class`] and [`tokenizer_class`]. See the
|
||||
[`~image_processor_class`] and [`~tokenizer_class`] for more information.
|
||||
Parameters:
|
||||
image_processor (`image_processor_class`):
|
||||
The image processor is a required input.
|
||||
tokenizer (`tokenizer_class`):
|
||||
The tokenizer is a required input.
|
||||
custom_processing_mode (`str`, *optional*, defaults to `"standard"`):
|
||||
Custom processing mode for advanced text/image processing. Can be 'standard', 'enhanced', or 'experimental'.
|
||||
enable_advanced_features (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable advanced processing features like custom tokenization strategies.
|
||||
custom_threshold (`float`, *optional*, defaults to 0.5):
|
||||
Custom threshold value for filtering or processing decisions.
|
||||
output_format (`str`, *optional*, defaults to `"default"`):
|
||||
Output format specification. Can be 'default', 'extended', or 'minimal'.
|
||||
"""
|
||||
|
||||
self.assertEqual(actual_class_docstring, expected_class_docstring)
|
||||
|
||||
def test_dummy_image_processor_complete_docstring(self):
|
||||
self.maxDiff = None
|
||||
"""Test complete class and preprocess docstrings for DummyForTestImageProcessorFast with custom ImagesKwargs and custom_intro."""
|
||||
|
||||
actual_preprocess_docstring = DummyForTestImageProcessorFast.preprocess.__doc__
|
||||
|
||||
expected_preprocess_docstring = """ Args:
|
||||
images (`Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[PIL.Image.Image], list[numpy.ndarray], list[torch.Tensor]]`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
image_grid_pinpoints (`list[list[int]]`, *kwargs*, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method.
|
||||
custom_scale (`float`, *kwargs*, *optional*, defaults to 255.0):
|
||||
Custom scale factor for preprocessing pipelines.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
Returns stacked tensors if set to `'pt'`, otherwise returns a list of tensors.
|
||||
**kwargs ([`ImagesKwargs`], *optional*):
|
||||
Additional image preprocessing options. Model-specific kwargs are listed above; see the TypedDict class
|
||||
for the complete list of supported arguments.
|
||||
|
||||
Returns:
|
||||
`~image_processing_base.BatchFeature`:
|
||||
- **data** (`dict`) -- Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
|
||||
- **tensor_type** (`Union[None, str, TensorType]`, *optional*) -- You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at
|
||||
initialization.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DummyForTestImageProcessorFast
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> processor = DummyForTestImageProcessorFast.from_pretrained("dummy-processor")
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> inputs = processor.preprocess(images=image, return_tensors="pt")
|
||||
```
|
||||
"""
|
||||
|
||||
self.assertEqual(actual_preprocess_docstring, expected_preprocess_docstring)
|
||||
|
||||
actual_class_docstring = DummyForTestImageProcessorFast.__doc__
|
||||
|
||||
expected_class_docstring = """
|
||||
Constructs a fast DummyForTest image processor.
|
||||
|
||||
Args:
|
||||
image_grid_pinpoints (`list[list[int]]`, *kwargs*, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method.
|
||||
custom_scale (`float`, *kwargs*, *optional*, defaults to 255.0):
|
||||
Custom scale factor for preprocessing pipelines.
|
||||
**kwargs ([`ImagesKwargs`], *optional*):
|
||||
Additional image preprocessing options. Model-specific kwargs are listed above; see the TypedDict class
|
||||
for the complete list of supported arguments.
|
||||
"""
|
||||
|
||||
self.assertEqual(actual_class_docstring, expected_class_docstring)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Performance tests for auto_docstring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAutoDocstringPerformance:
|
||||
"""
|
||||
Performance tests for auto_docstring.
|
||||
|
||||
The decorator runs at *class-definition / import time*, so with hundreds of
|
||||
models in the library the cumulative cost matters even though each individual
|
||||
call looks cheap. These tests assert an upper bound to catch regressions.
|
||||
"""
|
||||
|
||||
# Upper bound (%) of total import time that auto_docstring overhead may take.
|
||||
# Relative metric; robust across CI vs local. Catches serious regressions.
|
||||
AUTO_DOCSTRING_COST_PCT_UPPER_BOUND = 70.0
|
||||
|
||||
def test_auto_docstring_import_time_upper_bound(self):
|
||||
"""
|
||||
Asserts that auto_docstring overhead stays below a percentage of total
|
||||
import time.
|
||||
|
||||
Method
|
||||
------
|
||||
1. Collect ``modeling_*.py``, ``image_processing_*.py``, ``processing_*.py``
|
||||
under ``transformers/models``, then sample every 10th for speed.
|
||||
2. Warmup: import the sampled modules once so Python's bytecode cache is hot.
|
||||
3. Measure WITH auto_docstring: clear cache, re-import, median over 5 runs.
|
||||
4. Measure WITHOUT auto_docstring: noop-patch, clear cache, re-import, median.
|
||||
5. cost_pct = (real - noop) / real * 100; assert cost_pct < upper bound.
|
||||
"""
|
||||
if "transformers.utils" not in sys.modules:
|
||||
importlib.import_module("transformers.utils")
|
||||
_utils_module = sys.modules["transformers.utils"]
|
||||
|
||||
src_root = Path(__file__).resolve().parent.parent.parent / "src"
|
||||
models_dir = src_root / "transformers" / "models"
|
||||
all_modules: list[str] = []
|
||||
for pattern in ("modeling_*.py", "image_processing_*.py", "processing_*.py"):
|
||||
for f in sorted(models_dir.rglob(pattern)):
|
||||
rel = f.with_suffix("").relative_to(src_root)
|
||||
all_modules.append(".".join(rel.parts))
|
||||
model_modules = all_modules[::10]
|
||||
|
||||
def _clear():
|
||||
for key in [k for k in sys.modules if k.startswith("transformers.models")]:
|
||||
del sys.modules[key]
|
||||
|
||||
def _import_all():
|
||||
for mod in model_modules:
|
||||
try:
|
||||
importlib.import_module(mod)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
_import_all() # warmup
|
||||
|
||||
# With auto_docstring (real)
|
||||
times_real: list[float] = []
|
||||
for _ in range(5):
|
||||
_clear()
|
||||
t0 = time.perf_counter()
|
||||
_import_all()
|
||||
times_real.append(time.perf_counter() - t0)
|
||||
|
||||
# Without auto_docstring (noop patch)
|
||||
_orig = _utils_module.auto_docstring
|
||||
_noop = lambda x=None, **kw: (lambda f: f) if x is None else x # noqa: E731
|
||||
times_noop: list[float] = []
|
||||
for _ in range(5):
|
||||
_utils_module.auto_docstring = _noop
|
||||
try:
|
||||
_clear()
|
||||
t0 = time.perf_counter()
|
||||
_import_all()
|
||||
times_noop.append(time.perf_counter() - t0)
|
||||
finally:
|
||||
_utils_module.auto_docstring = _orig
|
||||
|
||||
median_real = statistics.median(times_real)
|
||||
median_noop = statistics.median(times_noop)
|
||||
cost_pct = (median_real - median_noop) / median_real * 100 if median_real > 0 else 0.0
|
||||
print(f"Cost percentage: {cost_pct:.1f}%")
|
||||
assert cost_pct < self.AUTO_DOCSTRING_COST_PCT_UPPER_BOUND, (
|
||||
f"auto_docstring cost {cost_pct:.1f}% of import time exceeds upper bound "
|
||||
f"{self.AUTO_DOCSTRING_COST_PCT_UPPER_BOUND}% "
|
||||
f"({len(model_modules)} of {len(all_modules)} modules)"
|
||||
)
|
||||
151
tests/utils/test_backbone_utils.py
Normal file
151
tests/utils/test_backbone_utils.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import PreTrainedConfig
|
||||
from transformers.backbone_utils import (
|
||||
BackboneConfigMixin,
|
||||
BackboneMixin,
|
||||
)
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
class AnyBackboneConfig(BackboneConfigMixin, PreTrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
stage_names: list | None = None,
|
||||
out_indices: list | None = None,
|
||||
out_features: list | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.stage_names = stage_names
|
||||
self.set_output_features_output_indices(out_features=out_features, out_indices=out_indices)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@require_torch
|
||||
class AnyBackbone(BackboneMixin, PreTrainedModel): ...
|
||||
|
||||
|
||||
class BackboneUtilsTester(unittest.TestCase):
|
||||
def test_get_aligned_output_features_output_indices(self):
|
||||
stage_names = ["a", "b", "c"]
|
||||
|
||||
# Defaults to last layer if both, `out_indices` and `out_features`, are None
|
||||
config = AnyBackboneConfig(stage_names)
|
||||
self.assertEqual(config.out_features, ["c"])
|
||||
self.assertEqual(config.out_indices, [2])
|
||||
|
||||
# Out indices set to match out features
|
||||
config = AnyBackboneConfig(stage_names=stage_names, out_features=["a", "c"])
|
||||
self.assertEqual(config.out_features, ["a", "c"])
|
||||
self.assertEqual(config.out_indices, [0, 2])
|
||||
|
||||
# Out features set to match out indices
|
||||
config = AnyBackboneConfig(stage_names=stage_names, out_indices=[0, 2])
|
||||
self.assertEqual(config.out_features, ["a", "c"])
|
||||
self.assertEqual(config.out_indices, [0, 2])
|
||||
|
||||
# Out features selected from negative indices
|
||||
config = AnyBackboneConfig(stage_names=stage_names, out_indices=[-3, -1])
|
||||
self.assertEqual(config.out_features, ["a", "c"])
|
||||
self.assertEqual(config.out_indices, [-3, -1])
|
||||
|
||||
def test_config_verify_out_features_out_indices(self):
|
||||
# Stage names must be set
|
||||
with pytest.raises(ValueError, match="Stage_names must be set for transformers backbones"):
|
||||
AnyBackboneConfig(stage_names=None, out_features=["a", "b"], out_indices=(0, 1))
|
||||
|
||||
# Out features must be a list
|
||||
with pytest.raises(ValueError, match="out_features must be a list got <class 'tuple'>"):
|
||||
AnyBackboneConfig(stage_names=["a", "b"], out_features=("a", "b"), out_indices=[0, 1])
|
||||
|
||||
# Out features must be a subset of stage names
|
||||
with pytest.raises(
|
||||
ValueError, match=r"out_features must be a subset of stage_names: \['a'\] got \['a', 'b'\]"
|
||||
):
|
||||
AnyBackboneConfig(stage_names=["a"], out_features=["a", "b"], out_indices=[0, 1])
|
||||
|
||||
# Out features must contain no duplicates
|
||||
with pytest.raises(ValueError, match=r"out_features must not contain any duplicates, got \['a', 'a'\]"):
|
||||
AnyBackboneConfig(stage_names=["a"], out_features=["a", "a"], out_indices=None)
|
||||
|
||||
# Out indices must be a list
|
||||
with pytest.raises(ValueError, match="out_indices must be a list, got <class 'int'>"):
|
||||
AnyBackboneConfig(stage_names=["a", "b"], out_features=None, out_indices=0)
|
||||
|
||||
# Out indices must be a subset of stage names
|
||||
with pytest.raises(
|
||||
ValueError, match=r"out_indices must be valid indices for stage_names \['a'\], got \[0, 1\]"
|
||||
):
|
||||
AnyBackboneConfig(stage_names=["a"], out_features=None, out_indices=[0, 1])
|
||||
|
||||
# Out indices must contain no duplicates
|
||||
with pytest.raises(ValueError, match=r"out_indices must not contain any duplicates, got \[0, 0\]"):
|
||||
AnyBackboneConfig(stage_names=["a"], out_features=None, out_indices=[0, 0])
|
||||
|
||||
# Out features and out indices must be the same length
|
||||
with pytest.raises(
|
||||
ValueError, match="out_features and out_indices should have the same length if both are set"
|
||||
):
|
||||
AnyBackboneConfig(stage_names=["a", "b", "c"], out_features=["a", "b"], out_indices=[0])
|
||||
|
||||
# Out features should match out indices
|
||||
with pytest.raises(
|
||||
ValueError, match="out_features and out_indices should correspond to the same stages if both are set"
|
||||
):
|
||||
AnyBackboneConfig(stage_names=["a", "b", "c"], out_features=["a", "b"], out_indices=[0, 2])
|
||||
|
||||
# Out features and out indices should be in order
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r"out_features must be in the same order as stage_names, expected \['a', 'b'\] got \['b', 'a'\]",
|
||||
):
|
||||
AnyBackboneConfig(stage_names=["a", "b"], out_features=["b", "a"], out_indices=[0, 1])
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r"out_indices must be in the same order as stage_names, expected \[-2, 1\] got \[1, -2\]"
|
||||
):
|
||||
AnyBackboneConfig(stage_names=["a", "b"], out_features=["a", "b"], out_indices=[1, -2])
|
||||
|
||||
# Check passes with valid inputs
|
||||
AnyBackboneConfig(stage_names=["a", "b", "c", "d"], out_features=["a", "b", "d"], out_indices=[0, 1, -1])
|
||||
|
||||
@require_torch
|
||||
def test_backbone_mixin(self):
|
||||
config = AnyBackboneConfig(stage_names=["a", "b", "c"], out_features=["a", "c"], out_indices=[0, 2])
|
||||
backbone = AnyBackbone(config)
|
||||
backbone.config = config
|
||||
|
||||
# Check that the output features and indices are set correctly
|
||||
self.assertEqual(backbone.out_features, ["a", "c"])
|
||||
self.assertEqual(backbone.out_indices, [0, 2])
|
||||
|
||||
# Check out features and indices are updated correctly
|
||||
backbone.out_features = ["a", "b"]
|
||||
self.assertEqual(backbone.out_features, ["a", "b"])
|
||||
self.assertEqual(backbone.out_indices, [0, 1])
|
||||
|
||||
backbone.out_indices = [-3, -1]
|
||||
self.assertEqual(backbone.out_features, ["a", "c"])
|
||||
self.assertEqual(backbone.out_indices, [-3, -1])
|
||||
1235
tests/utils/test_cache_utils.py
Normal file
1235
tests/utils/test_cache_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
623
tests/utils/test_chat_parsing_utils.py
Normal file
623
tests/utils/test_chat_parsing_utils.py
Normal file
@@ -0,0 +1,623 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.testing_utils import require_jmespath
|
||||
from transformers.utils.chat_parsing_utils import recursive_parse
|
||||
|
||||
|
||||
cohere_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"},
|
||||
"thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"},
|
||||
"tool_calls": {
|
||||
"x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)",
|
||||
"x-parser": "json",
|
||||
"x-parser-args": {
|
||||
"transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}"
|
||||
},
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"additionalProperties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ernie_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string", "x-regex": "<response>\n(.*?)\n?</response>"},
|
||||
"thinking": {"type": "string", "x-regex": r"(?:^|<think>\s*)(.*?)\s*<\/think>"},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": "<tool_call>(.*?)</tool_call>",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"x-parser": "json",
|
||||
"x-parser-args": {"transform": "{type: 'function', function: @}"},
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"additionalProperties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
gpt_oss_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"},
|
||||
"thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-regex": r"<\|message\|>(.*)",
|
||||
"x-parser": "json",
|
||||
"additionalProperties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
smollm_schema = {
|
||||
"x-regex": r"(?:<think>\n?(?P<thinking>.+?)\n?</think>)?\s*(?:<tool_call>(?P<tool_calls>.+?)</tool_call>)?\s*(?P<content>.+?)?\s*(?:<\|im_end\|>|$)",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string"},
|
||||
"thinking": {"type": "string"},
|
||||
"tool_calls": {
|
||||
"x-parser": "json",
|
||||
"x-parser-args": {"transform": "[{type: 'function', function: @}]"},
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"additionalProperties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
qwen3_schema = {
|
||||
"x-regex": r"^(?:(?:<think>)?\s*(?P<thinking>.+?)\s*</think>)?\s*(?:<tool_call>(?P<tool_calls>.*?)\s*</tool_call>)?\s*(?P<content>.+?)?\s*$",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string"},
|
||||
"thinking": {"type": "string"},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "x-regex": r"<function=(\w+)>"},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-regex-key-value": r"<parameter=(?P<key>\w+)>\n(?P<value>.*?)\n</parameter>",
|
||||
"additionalProperties": {
|
||||
"x-parser": "json",
|
||||
"x-parser-args": {"allow_non_json": True},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
re_sub_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"thinking": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": r"<\|tool_call>(.*?)<tool_call\|>",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-regex-key-value": r'(?P<key>\w+):(?P<value><\|"\|>.*?<\|"\|>|[^,}]+)',
|
||||
"additionalProperties": {
|
||||
"x-regex-substitutions": [[r'^<\|"\|>|<\|"\|>$', ""]],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?",
|
||||
}
|
||||
|
||||
gemma4_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"thinking": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"tool_calls": {
|
||||
"x-regex-iterator": r"<\|tool_call>(.*?)<tool_call\|>",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"const": "function"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"x-regex": r"call\:(?P<name>\w+)(?P<arguments>\{.*\})",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
},
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"x-parser": "gemma4-tool-call",
|
||||
"additionalProperties": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"x-regex": r"(\<\|channel\>thought\n(?P<thinking>.*?)\<channel\|\>)?(?P<content>(?:(?!\<\|tool_call\>).)+)?(?P<tool_calls>\<\|tool_call\>.*\<tool_call\|\>)?",
|
||||
}
|
||||
|
||||
prefix_items_schema = {
|
||||
# Not intended to be "realistic", just checks that prefixItems can handle a heterogeneous array
|
||||
"x-regex-iterator": r"<block>(.*?)<\/block>",
|
||||
"type": "array",
|
||||
"prefixItems": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@require_jmespath
|
||||
class ChatSchemaParserTest(unittest.TestCase):
|
||||
def test_schema_save_load(self):
|
||||
# Has no schema by default
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.response_schema = ernie_schema
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tokenizer.save_pretrained(tmpdir)
|
||||
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmpdir)
|
||||
self.assertEqual(reloaded_tokenizer.response_schema, ernie_schema)
|
||||
|
||||
def test_tokenizer_method(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
|
||||
parsed_chat = recursive_parse(model_out, cohere_schema)
|
||||
tokenizer.response_schema = cohere_schema
|
||||
tokenizer_parsed_chat = tokenizer.parse_response(model_out)
|
||||
self.assertEqual(tokenizer_parsed_chat, parsed_chat)
|
||||
|
||||
def test_batched_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
|
||||
tokenizer.response_schema = cohere_schema
|
||||
parsed_chat = tokenizer.parse_response(model_out)
|
||||
self.assertEqual(tokenizer.parse_response([model_out]), [parsed_chat])
|
||||
self.assertEqual(tokenizer.parse_response([model_out] * 2), [parsed_chat] * 2)
|
||||
|
||||
def test_token_id_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") # Need an actual tokenizer to encode
|
||||
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
|
||||
tokenizer.response_schema = cohere_schema
|
||||
parsed_chat = tokenizer.parse_response(model_out)
|
||||
tokenized_out = tokenizer(model_out).input_ids
|
||||
self.assertEqual(tokenizer.parse_response(tokenized_out), parsed_chat)
|
||||
self.assertEqual(tokenizer.parse_response([tokenized_out]), [parsed_chat])
|
||||
self.assertEqual(tokenizer.parse_response([tokenized_out] * 2), [parsed_chat] * 2)
|
||||
|
||||
def test_numpy_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") # Need an actual tokenizer to encode
|
||||
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
|
||||
tokenizer.response_schema = cohere_schema
|
||||
parsed_chat = tokenizer.parse_response(model_out)
|
||||
tokenized_out = tokenizer(model_out, return_tensors="np").input_ids
|
||||
self.assertEqual(tokenizer.parse_response(tokenized_out), [parsed_chat])
|
||||
|
||||
def test_tensor_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") # Need an actual tokenizer to encode
|
||||
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
|
||||
tokenizer.response_schema = cohere_schema
|
||||
parsed_chat = tokenizer.parse_response(model_out)
|
||||
tokenized_out = tokenizer(model_out, return_tensors="pt").input_ids
|
||||
self.assertEqual(tokenizer.parse_response(tokenized_out), [parsed_chat])
|
||||
|
||||
def test_cohere_template(self):
|
||||
model_out = '<|START_THINKING|>I should call a tool.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "simple_tool", "parameters": {"temperature_format": "Celsius"}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>'
|
||||
parsed_chat = recursive_parse(model_out, cohere_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": "I should call a tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "simple_tool", "arguments": {"temperature_format": "Celsius"}},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_ernie_template_with_tools(self):
|
||||
model_out = 'The user is asking about the weather in Paris today. Let me check the available tools. There\'s a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to "Paris". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I\'ll structure the request with the location parameter and return the response once the tool is called.\n</think>\n\n<tool_call>\n{"name": "get_current_temperature", "arguments": {"location": "Paris"}}\n</tool_call>\n</s>'
|
||||
parsed_chat = recursive_parse(model_out, ernie_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": "The user is asking about the weather in Paris today. Let me check the available tools. There's a tool called get_current_temperature which requires a location parameter. Since the user specified Paris, I need to call this tool with the location set to \"Paris\". I should make sure the argument is correctly formatted as a string. No other tools are available, so this is the right one to use. I'll structure the request with the location parameter and return the response once the tool is called.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "get_current_temperature", "arguments": {"location": "Paris"}},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_ernie_template_no_tools(self):
|
||||
model_out = "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.\n</think>\n\n<response>\nHello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!\n</response>\n</s>"
|
||||
parsed_chat = recursive_parse(model_out, ernie_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! I'm doing well, thank you for asking. How about you? Is there something specific you'd like help with today? I'm here to assist you with any questions or problems you have!",
|
||||
"thinking": "The user just greeted me with \"Hi! How are you?\" I need to respond in a friendly and helpful manner. Let me start by acknowledging their greeting. I should ask them how they're doing to engage in conversation.\n\nFirst, I'll say hello back and then ask how they're feeling. It's important to show genuine interest. Maybe mention that I'm here to help with anything they need. Keep the tone warm and positive. Let me make sure the response is concise but friendly. Alright, that should work.",
|
||||
},
|
||||
)
|
||||
|
||||
def test_gpt_oss_template_with_tool_call(self):
|
||||
model_out = '<|channel|>analysis<|message|>We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\n "location": "San Francisco, CA"\n}'
|
||||
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": 'We need to respond in riddles. The user asks: "What is the weather like in SF?" We need to get the location of the user? The user explicitly asks about SF (San Francisco). So we need to get the current weather in San Francisco, CA. We need to call get_current_weather function. The developer instruction says "Always respond in riddles". So the final answer should be in a riddle form. But we need to call function to get weather data. So we should call get_current_weather with location "San Francisco, CA". Possibly specify format "celsius" (default). Let\'s do that.\n\nWe will call function get_current_weather.',
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "get_current_weather", "arguments": {"location": "San Francisco, CA"}},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_gpt_oss_template_no_tool_call(self):
|
||||
model_out = "<|channel|>analysis<|message|>User asks a simple math question: 2+2 = 4. Provide answer.<|end|><|start|>assistant<|channel|>final<|message|>2"
|
||||
parsed_chat = recursive_parse(model_out, gpt_oss_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "2",
|
||||
"thinking": "User asks a simple math question: 2+2 = 4. Provide answer.",
|
||||
},
|
||||
)
|
||||
|
||||
def test_smollm_template_thinking_and_tool_call(self):
|
||||
model_out = '<think>\nOkay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.\n</think>\n\n<tool_call>{"name": "greet_user", "arguments": {"greeting": "Hello! I\'m doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"}}</tool_call>'
|
||||
parsed_chat = recursive_parse(model_out, smollm_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": 'Okay, the user said, "Hello! How are you?" I need to respond appropriately. Since this is the first message, I should greet them back and ask how I can assist. I should keep it friendly and open-ended. Let me make sure the response is welcoming and encourages them to share what they need help with. I\'ll avoid any technical jargon and keep it simple. Let me check for any typos and ensure the tone is positive.',
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "greet_user",
|
||||
"arguments": {
|
||||
"greeting": "Hello! I'm doing well, thanks for asking. How can I assist you today? Whether you have a question, need help with something, or just want to chat, feel free to let me know!"
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_smollm_template_tool_call_no_thinking(self):
|
||||
model_out = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
|
||||
parsed_chat = recursive_parse(model_out, smollm_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_smollm_template_thinking_no_tool_call(self):
|
||||
model_out = '<think>\nOkay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.</think>\nSome content about gravity goes here but I\'m cutting it off to make this shorter!'
|
||||
parsed_chat = recursive_parse(model_out, smollm_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Some content about gravity goes here but I'm cutting it off to make this shorter!",
|
||||
"thinking": 'Okay, the user asked, "Hey! Can you tell me about gravity?" Let me start by breaking down what they might be looking for. They probably want a basic understanding of gravity, maybe for a school project or just personal curiosity. I should explain what gravity is, how it works, and maybe some examples.',
|
||||
},
|
||||
)
|
||||
|
||||
def test_qwen3_tool_calls(self):
|
||||
model_out = '<tool_call>\n<function=get_weather>\n<parameter=locations>\n[{"country": "France", "city": "Paris"}]\n</parameter>\n<parameter=temp_units>\ncelsius\n</parameter>\n</function>\n</tool_call>'
|
||||
parsed_chat = recursive_parse(model_out, qwen3_schema)
|
||||
self.assertEqual(
|
||||
parsed_chat,
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"locations": [{"country": "France", "city": "Paris"}],
|
||||
"temp_units": "celsius",
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_re_sub_schema(self):
|
||||
"""Test that a schema doing re substitutions to enable JSON parsing works."""
|
||||
model_out = '<|channel>thought\nThe user is asking for the current temperature in Paris. I should check the available tools to see if there\'s a function that can provide this information.<channel|><|tool_call>call:get_current_temperature{detail_level:0,location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>'
|
||||
parsed = recursive_parse(model_out, re_sub_schema)
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": "The user is asking for the current temperature in Paris. I should check the available tools to see if there's a function that can provide this information.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_temperature",
|
||||
"arguments": {"detail_level": "0", "location": "Paris, France", "unit": "celsius"},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_gemma4_tool_call(self):
|
||||
model_out = '<|channel>thought\nThe user is asking for the current temperature in Paris. I should check the available tools to see if there\'s a function that can provide this information.<channel|><|tool_call>call:get_current_temperature{detail_level:0,location:<|"|>Paris, France<|"|>,unit:<|"|>celsius<|"|>}<tool_call|><|tool_response>'
|
||||
parsed = recursive_parse(model_out, gemma4_schema)
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": "The user is asking for the current temperature in Paris. I should check the available tools to see if there's a function that can provide this information.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_temperature",
|
||||
"arguments": {"detail_level": 0, "location": "Paris, France", "unit": "celsius"},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_gemma4_complex_tool_call(self):
|
||||
model_out = (
|
||||
"<|channel>thought\nLet me call the tool.<channel|>"
|
||||
'<|tool_call>call:foo{bool_value:true,list_value:[<|"|>foo<|"|>,<|"|>bar<|"|>],'
|
||||
'null_value:null,number_value:1,string_value:<|"|>foo<|"|>,'
|
||||
'struct_value:{foo:<|"|>bar<|"|>}}<tool_call|>'
|
||||
)
|
||||
parsed = recursive_parse(model_out, gemma4_schema)
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
{
|
||||
"role": "assistant",
|
||||
"thinking": "Let me call the tool.",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "foo",
|
||||
"arguments": {
|
||||
"bool_value": True,
|
||||
"list_value": ["foo", "bar"],
|
||||
"null_value": None,
|
||||
"number_value": 1,
|
||||
"string_value": "foo",
|
||||
"struct_value": {"foo": "bar"},
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_required_fields_present(self):
|
||||
"""Test that required fields pass validation when present in the output."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"required": ["role", "content"],
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string", "x-regex": r"<response>(.*?)</response>"},
|
||||
"thinking": {"type": "string", "x-regex": r"<think>(.*?)</think>"},
|
||||
},
|
||||
}
|
||||
model_out = "<think>Let me think.</think><response>Hello!</response>"
|
||||
parsed = recursive_parse(model_out, schema)
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
{"role": "assistant", "content": "Hello!", "thinking": "Let me think."},
|
||||
)
|
||||
|
||||
def test_required_field_missing_raises(self):
|
||||
"""Test that a missing required field raises ValueError with a helpful message."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"required": ["role", "content"],
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string", "x-regex": r"<response>(.*?)</response>"},
|
||||
"thinking": {"type": "string", "x-regex": r"<think>(.*?)</think>"},
|
||||
},
|
||||
}
|
||||
# This output has thinking but no <response> tags, so content will be missing
|
||||
model_out = "<think>Let me think about this.</think>Some plain text without response tags"
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
recursive_parse(model_out, schema)
|
||||
self.assertIn("content", str(cm.exception))
|
||||
self.assertIn("missing", str(cm.exception).lower())
|
||||
|
||||
def test_required_not_enforced_when_absent(self):
|
||||
"""Test that schemas without 'required' still silently omit missing fields."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {"const": "assistant"},
|
||||
"content": {"type": "string", "x-regex": r"<response>(.*?)</response>"},
|
||||
"thinking": {"type": "string", "x-regex": r"<think>(.*?)</think>"},
|
||||
},
|
||||
}
|
||||
# No <response> tags, but content is not required — should succeed
|
||||
model_out = "<think>Just thinking.</think>"
|
||||
parsed = recursive_parse(model_out, schema)
|
||||
self.assertEqual(parsed, {"role": "assistant", "thinking": "Just thinking."})
|
||||
|
||||
def test_prefix_items(self):
|
||||
model_out = "<block>hello</block><block>42</block><block>world</block>"
|
||||
parsed = recursive_parse(model_out, prefix_items_schema)
|
||||
self.assertEqual(parsed, ["hello", 42, "world"])
|
||||
|
||||
def test_prefix_items_wrong_length_raises(self):
|
||||
model_out = "<block>hello</block><block>42</block>"
|
||||
with self.assertRaises(ValueError):
|
||||
recursive_parse(model_out, prefix_items_schema)
|
||||
|
||||
def test_prefix_items_wrong_type_raises(self):
|
||||
model_out = "<block>hello</block><block>world</block><block>42</block>"
|
||||
with self.assertRaises(ValueError):
|
||||
recursive_parse(model_out, prefix_items_schema)
|
||||
|
||||
def test_type_any_passthrough(self):
|
||||
"""Test that type 'any' passes content through without transformation."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"x-regex": r"<data>(?P<value>.*?)</data>",
|
||||
"properties": {
|
||||
"value": {"type": "any"},
|
||||
},
|
||||
}
|
||||
model_out = "<data>some arbitrary content 123</data>"
|
||||
parsed = recursive_parse(model_out, schema)
|
||||
self.assertEqual(parsed, {"value": "some arbitrary content 123"})
|
||||
|
||||
def test_type_any_in_additional_properties(self):
|
||||
"""Test that type 'any' works in additionalProperties, matching the docs example."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"x-parser": "json",
|
||||
"additionalProperties": {"type": "any"},
|
||||
}
|
||||
node_content = '{"location": "San Francisco, CA", "units": "celsius"}'
|
||||
parsed = recursive_parse(node_content, schema)
|
||||
self.assertEqual(parsed, {"location": "San Francisco, CA", "units": "celsius"})
|
||||
613
tests/utils/test_chat_template_utils.py
Normal file
613
tests/utils/test_chat_template_utils.py
Normal file
@@ -0,0 +1,613 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import Literal
|
||||
|
||||
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
|
||||
|
||||
|
||||
class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
def test_simple_function(self):
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_no_arguments(self):
|
||||
def fn():
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
return True
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_union(self):
|
||||
def fn(x: int | float):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_optional(self):
|
||||
def fn(x: int | None):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_default_arg(self):
|
||||
def fn(x: int = 42):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_nested_list(self):
|
||||
def fn(x: list[list[str | int]]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "array",
|
||||
"items": {"type": "array", "items": {"type": ["integer", "string"]}},
|
||||
"description": "The input",
|
||||
}
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_arguments(self):
|
||||
def fn(x: int, y: str):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
y: Also the input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The input"},
|
||||
"y": {"type": "string", "description": "Also the input"},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_complex_arguments(self):
|
||||
def fn(x: list[int | float], y: int | str | None = None):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
y: Also the input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
|
||||
"y": {
|
||||
"type": ["integer", "string"],
|
||||
"nullable": True,
|
||||
"description": "Also the input",
|
||||
},
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_missing_docstring(self):
|
||||
def fn(x: int):
|
||||
return x
|
||||
|
||||
with self.assertRaises(DocstringParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_missing_param_docstring(self):
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
return x
|
||||
|
||||
with self.assertRaises(DocstringParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_missing_type_hint(self):
|
||||
def fn(x):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
with self.assertRaises(TypeHintParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_return_value(self):
|
||||
def fn(x: int) -> int:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "integer"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_return_value_docstring(self):
|
||||
def fn(x: int) -> int:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "integer", "description": "The output"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_tuple(self):
|
||||
def fn(x: tuple[int, str]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||
"description": "The input",
|
||||
}
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_single_element_tuple_fails(self):
|
||||
def fn(x: tuple[int]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
# Single-element tuples should just be the type itself, or List[type] for variable-length inputs
|
||||
with self.assertRaises(TypeHintParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_ellipsis_type_fails(self):
|
||||
def fn(x: tuple[int, ...]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
# Variable length inputs should be specified with List[type], not Tuple[type, ...]
|
||||
with self.assertRaises(TypeHintParsingException):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_enum_extraction(self):
|
||||
def fn(temperature_format: str):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
|
||||
|
||||
|
||||
Returns:
|
||||
The temperature
|
||||
"""
|
||||
return -40.0
|
||||
|
||||
# Let's see if that gets correctly parsed as an enum
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature_format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature format to use",
|
||||
}
|
||||
},
|
||||
"required": ["temperature_format"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_literal(self):
|
||||
def fn(
|
||||
temperature_format: Literal["celsius", "fahrenheit"],
|
||||
booleanish: Literal[True, False, 0, 1, "y", "n"] = False,
|
||||
):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
temperature_format: The temperature format to use
|
||||
booleanish: A value that can be regarded as boolean
|
||||
|
||||
|
||||
Returns:
|
||||
The temperature
|
||||
"""
|
||||
return -40.0
|
||||
|
||||
# Let's see if that gets correctly parsed as an enum
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature_format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature format to use",
|
||||
},
|
||||
"booleanish": {
|
||||
"type": ["boolean", "integer", "string"],
|
||||
"enum": [True, False, 0, 1, "y", "n"],
|
||||
"description": "A value that can be regarded as boolean",
|
||||
},
|
||||
},
|
||||
"required": ["temperature_format"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiline_docstring_with_types(self):
|
||||
def fn(x: int, y: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The first input
|
||||
|
||||
y: The second input. This is a longer description
|
||||
that spans multiple lines with indentation and stuff.
|
||||
|
||||
Returns:
|
||||
God knows what
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The first input"},
|
||||
"y": {
|
||||
"type": "integer",
|
||||
"description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
|
||||
},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_return_none(self):
|
||||
def fn(x: int) -> None:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The first input
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The first input"},
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "null"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_instance_method(self):
|
||||
class Tool:
|
||||
def fn(self, x: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema) # unbound case
|
||||
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema) # bound case
|
||||
|
||||
def test_static_method(self):
|
||||
class Tool:
|
||||
@staticmethod
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema)
|
||||
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema)
|
||||
|
||||
def test_class_method(self):
|
||||
class Tool:
|
||||
@classmethod
|
||||
def fn(cls, x: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema)
|
||||
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema)
|
||||
|
||||
def test_everything_all_at_once(self):
|
||||
def fn(x: str, y: list[str | int] | None, z: tuple[str | int, str] = (42, "hello")) -> tuple[int, str]:
|
||||
"""
|
||||
Test function with multiple args, and docstring args that we have to strip out.
|
||||
|
||||
Args:
|
||||
x: The first input. It's got a big multiline
|
||||
description and also contains
|
||||
(choices: ["a", "b", "c"])
|
||||
|
||||
y: The second input. It's a big list with a single-line description.
|
||||
|
||||
z: The third input. It's some kind of tuple with a default arg.
|
||||
|
||||
Returns:
|
||||
The output. The return description is also a big multiline
|
||||
description that spans multiple lines.
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function with multiple args, and docstring args that we have to strip out.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b", "c"],
|
||||
"description": "The first input. It's got a big multiline description and also contains",
|
||||
},
|
||||
"y": {
|
||||
"type": "array",
|
||||
"items": {"type": ["integer", "string"]},
|
||||
"nullable": True,
|
||||
"description": "The second input. It's a big list with a single-line description.",
|
||||
},
|
||||
"z": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": ["integer", "string"]}, {"type": "string"}],
|
||||
"description": "The third input. It's some kind of tuple with a default arg.",
|
||||
},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
"return": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
399
tests/utils/test_configuration_utils.py
Normal file
399
tests/utils/test_configuration_utils.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
|
||||
from transformers.configuration_utils import PreTrainedConfig
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_torch
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||
|
||||
|
||||
config_common_kwargs = {
|
||||
"return_dict": False,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"dtype": "float16",
|
||||
"chunk_size_feed_forward": 5,
|
||||
"architectures": ["BertModel"],
|
||||
"id2label": {0: "label"},
|
||||
"label2id": {"label": "0"},
|
||||
"problem_type": "regression",
|
||||
}
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
config.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
config.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_dynamic_config(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
CustomConfig.register_for_auto_class()
|
||||
config = CustomConfig(attribute=42)
|
||||
|
||||
config.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
|
||||
|
||||
new_config = AutoConfig.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
|
||||
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
|
||||
self.assertEqual(new_config.attribute, 42)
|
||||
|
||||
|
||||
class ConfigTestUtils(unittest.TestCase):
|
||||
def test_config_from_string(self):
|
||||
c = GPT2Config()
|
||||
|
||||
# attempt to modify each of int/float/bool/str config records and verify they were updated
|
||||
n_embd = c.n_embd + 1 # int
|
||||
resid_pdrop = c.resid_pdrop + 1.0 # float
|
||||
scale_attn_weights = not c.scale_attn_weights # bool
|
||||
summary_type = c.summary_type + "foo" # str
|
||||
c.update_from_string(
|
||||
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
|
||||
)
|
||||
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
|
||||
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
|
||||
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
|
||||
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
|
||||
|
||||
def test_config_common_kwargs_is_complete(self):
|
||||
base_config = PreTrainedConfig()
|
||||
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
||||
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
|
||||
self.assertListEqual(
|
||||
missing_keys,
|
||||
[
|
||||
"transformers_version",
|
||||
"is_encoder_decoder",
|
||||
"_name_or_path",
|
||||
"_commit_hash",
|
||||
"_output_attentions",
|
||||
"_attn_implementation_internal",
|
||||
"_experts_implementation_internal",
|
||||
],
|
||||
)
|
||||
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
||||
if len(keys_with_defaults) > 0:
|
||||
raise ValueError(
|
||||
"The following keys are set with the default values in"
|
||||
" `test_configuration_common.config_common_kwargs` pick another value for them:"
|
||||
f" {', '.join(keys_with_defaults)}."
|
||||
)
|
||||
|
||||
def test_nested_config_load_from_dict(self):
|
||||
config = AutoConfig.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
|
||||
)
|
||||
self.assertNotIsInstance(config.text_config, dict)
|
||||
self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")
|
||||
|
||||
def test_from_pretrained_subfolder(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
|
||||
self.assertIsNotNone(config)
|
||||
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
|
||||
self.assertIsNotNone(config)
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"failed", request=mock.Mock(), response=mock.Mock()
|
||||
)
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_local_versioning(self):
|
||||
configuration = AutoConfig.from_pretrained("google-bert/bert-base-cased")
|
||||
configuration.configuration_files = ["config.4.0.0.json"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
configuration.hidden_size = 2
|
||||
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
|
||||
|
||||
# This should pick the new configuration file as the version of Transformers is > 4.0.0
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 2)
|
||||
|
||||
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
||||
configuration.configuration_files = ["config.42.0.0.json"]
|
||||
configuration.hidden_size = 768
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 768)
|
||||
|
||||
def test_repo_versioning_before(self):
|
||||
# This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
|
||||
repo = "hf-internal-testing/test-two-configs"
|
||||
|
||||
import transformers as new_transformers
|
||||
|
||||
# Matt: Use a context manager to ensure everything is correctly reverted and we
|
||||
# don't leak state between tests
|
||||
with mock.patch.object(new_transformers.configuration_utils, "__version__", "v4.0.0"):
|
||||
new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
|
||||
repo, return_unused_kwargs=True
|
||||
)
|
||||
self.assertEqual(new_configuration.hidden_size, 2)
|
||||
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
|
||||
self.assertDictEqual(kwargs, {})
|
||||
|
||||
# Testing an older version by monkey-patching the version in the module it's used.
|
||||
import transformers as old_transformers
|
||||
|
||||
with mock.patch.object(old_transformers.configuration_utils, "__version__", "v3.0.0"):
|
||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(old_configuration.hidden_size, 768)
|
||||
|
||||
def test_saving_config_with_custom_generation_kwargs_raises_error(self):
|
||||
config = BertConfig()
|
||||
config.min_length = 3 # `min_length = 3` is a non-default generation kwarg
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertRaises(ValueError):
|
||||
config.save_pretrained(tmp_dir)
|
||||
|
||||
def test_get_generation_parameters(self):
|
||||
config = BertConfig()
|
||||
self.assertFalse(len(config._get_generation_parameters()) > 0)
|
||||
config.min_length = 3
|
||||
self.assertTrue(len(config._get_generation_parameters()) > 0)
|
||||
config.min_length = 0
|
||||
self.assertTrue(len(config._get_generation_parameters()) > 0)
|
||||
|
||||
def test_loading_config_do_not_raise_future_warnings(self):
|
||||
"""Regression test for https://github.com/huggingface/transformers/issues/31002."""
|
||||
# Loading config should not raise a FutureWarning. It was the case before.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
PreTrainedConfig.from_pretrained("bert-base-uncased")
|
||||
|
||||
def test_get_text_config(self):
|
||||
"""Tests the `get_text_config` method."""
|
||||
# 1. model with only text input -> returns the original config instance
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
|
||||
self.assertEqual(config.get_text_config(), config)
|
||||
self.assertEqual(config.get_text_config(decoder=True), config)
|
||||
|
||||
# 2. composite model (VLM) -> returns the text component
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
|
||||
self.assertEqual(config.get_text_config(), config.text_config)
|
||||
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
|
||||
|
||||
# 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
|
||||
config = Florence2Config()
|
||||
self.assertEqual(config.get_text_config(), config.text_config)
|
||||
self.assertEqual(config.get_text_config(decoder=True), config.text_config)
|
||||
|
||||
# 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
self.assertEqual(config.get_text_config(), config)
|
||||
# both encoder_layers and decoder_layers exist
|
||||
self.assertTrue(getattr(config, "encoder_ffn_dim", None) is not None)
|
||||
self.assertTrue(getattr(config, "decoder_ffn_dim", None) is not None)
|
||||
|
||||
decoder_config = config.get_text_config(decoder=True)
|
||||
self.assertNotEqual(decoder_config, config)
|
||||
self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)
|
||||
|
||||
encoder_config = config.get_text_config(encoder=True)
|
||||
self.assertNotEqual(encoder_config, config)
|
||||
self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)
|
||||
|
||||
@require_torch
|
||||
def test_bc_torch_dtype(self):
|
||||
import torch
|
||||
|
||||
config = PreTrainedConfig(dtype="bfloat16")
|
||||
self.assertEqual(config.dtype, torch.bfloat16)
|
||||
|
||||
config = PreTrainedConfig(torch_dtype="bfloat16")
|
||||
self.assertEqual(config.dtype, torch.bfloat16)
|
||||
|
||||
# Check that if we pass both, `dtype` is used
|
||||
config = PreTrainedConfig(dtype="bfloat16", torch_dtype="float32")
|
||||
self.assertEqual(config.dtype, torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
config.save_pretrained(tmpdirname)
|
||||
|
||||
config = PreTrainedConfig.from_pretrained(tmpdirname)
|
||||
self.assertEqual(config.dtype, torch.bfloat16)
|
||||
|
||||
config = PreTrainedConfig.from_pretrained(tmpdirname, dtype="float32")
|
||||
self.assertEqual(config.dtype, "float32")
|
||||
|
||||
config = PreTrainedConfig.from_pretrained(tmpdirname, torch_dtype="float32")
|
||||
self.assertEqual(config.dtype, "float32")
|
||||
|
||||
def test_unserializable_json_is_encoded(self):
|
||||
class NewConfig(PreTrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
inf_positive: float = float("inf"),
|
||||
inf_negative: float = float("-inf"),
|
||||
nan: float = float("nan"),
|
||||
**kwargs,
|
||||
):
|
||||
self.inf_positive = inf_positive
|
||||
self.inf_negative = inf_negative
|
||||
self.nan = nan
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
new_config = NewConfig()
|
||||
|
||||
# All floats should remain as floats when being accessed in the config
|
||||
self.assertIsInstance(new_config.inf_positive, float)
|
||||
self.assertIsInstance(new_config.inf_negative, float)
|
||||
self.assertIsInstance(new_config.nan, float)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
new_config.save_pretrained(tmpdirname)
|
||||
config_file = Path(tmpdirname) / "config.json"
|
||||
config_contents = json.loads(config_file.read_text())
|
||||
new_config_instance = NewConfig.from_pretrained(tmpdirname)
|
||||
|
||||
# In the serialized JSON file, the non-JSON compatible floats should be updated
|
||||
self.assertDictEqual(config_contents["inf_positive"], {"__float__": "Infinity"})
|
||||
self.assertDictEqual(config_contents["inf_negative"], {"__float__": "-Infinity"})
|
||||
self.assertDictEqual(config_contents["nan"], {"__float__": "NaN"})
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
new_config.save_pretrained(tmpdirname)
|
||||
|
||||
# When reloading the config, it should have correct float values
|
||||
self.assertIsInstance(new_config_instance.inf_positive, float)
|
||||
self.assertIsInstance(new_config_instance.inf_negative, float)
|
||||
self.assertIsInstance(new_config_instance.nan, float)
|
||||
|
||||
|
||||
class ConfigSubclassKwOnlyTest(unittest.TestCase):
|
||||
"""Test that config subclasses with non-default fields following parent default fields
|
||||
no longer raise TypeError (fixed by kw_only=True in __init_subclass__). Regression
|
||||
test for https://github.com/huggingface/transformers/issues/XXXX."""
|
||||
|
||||
def test_subclass_non_default_field_after_default(self):
|
||||
"""A config subclass adding a required field after parent defaults must not raise."""
|
||||
|
||||
class MyConfig(PreTrainedConfig):
|
||||
pooling: str # no default — would fail under Python dataclass ordering rules
|
||||
|
||||
# Should construct without TypeError
|
||||
cfg = MyConfig(pooling="mean")
|
||||
self.assertEqual(cfg.pooling, "mean")
|
||||
|
||||
def test_subclass_multiple_non_default_fields(self):
|
||||
"""Multiple non-default fields in the subclass should all work."""
|
||||
|
||||
class EmbedConfig(PreTrainedConfig):
|
||||
dim: int
|
||||
pooling: str
|
||||
|
||||
cfg = EmbedConfig(dim=128, pooling="cls")
|
||||
self.assertEqual(cfg.dim, 128)
|
||||
self.assertEqual(cfg.pooling, "cls")
|
||||
|
||||
def test_inherited_defaults_still_work(self):
|
||||
"""Inherited fields with defaults must still be accessible."""
|
||||
from transformers import BertConfig
|
||||
|
||||
class BertWithPooling(BertConfig):
|
||||
pooling: str
|
||||
|
||||
cfg = BertWithPooling(pooling="mean", hidden_size=256)
|
||||
self.assertEqual(cfg.pooling, "mean")
|
||||
self.assertEqual(cfg.hidden_size, 256)
|
||||
39
tests/utils/test_convert_slow_tokenizer.py
Normal file
39
tests/utils/test_convert_slow_tokenizer.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import unittest
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
|
||||
from transformers.convert_slow_tokenizer import SpmConverter
|
||||
from transformers.testing_utils import get_tests_dir
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeOriginalTokenizer:
|
||||
vocab_file: str
|
||||
|
||||
|
||||
class ConvertSlowTokenizerTest(unittest.TestCase):
|
||||
def test_spm_converter_bytefallback_warning(self):
|
||||
spm_model_file_without_bytefallback = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
spm_model_file_with_bytefallback = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.model")
|
||||
|
||||
original_tokenizer_without_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_without_bytefallback)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
_ = SpmConverter(original_tokenizer_without_bytefallback)
|
||||
# We are looking for if there is any `UserWarning` with
|
||||
# `The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers.`
|
||||
w = [x for x in w if x.category.__name__ != "DeprecationWarning"]
|
||||
self.assertEqual(len(w), 0)
|
||||
|
||||
original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
_ = SpmConverter(original_tokenizer_with_bytefallback)
|
||||
w = [x for x in w if x.category.__name__ != "DeprecationWarning"]
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
self.assertIn(
|
||||
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
||||
" which is not implemented in the fast tokenizers.",
|
||||
str(w[0].message),
|
||||
)
|
||||
1385
tests/utils/test_core_model_loading.py
Normal file
1385
tests/utils/test_core_model_loading.py
Normal file
File diff suppressed because it is too large
Load Diff
197
tests/utils/test_deprecation.py
Normal file
197
tests/utils/test_deprecation.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import __version__, is_torch_available
|
||||
from transformers.testing_utils import require_torch_accelerator, torch_device
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
INFINITE_VERSION = "9999.0.0"
|
||||
|
||||
|
||||
class DeprecationDecoratorTester(unittest.TestCase):
|
||||
def test_rename_kwarg(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
|
||||
def dummy_function(new_name=None, other_name=None):
|
||||
return new_name, other_name
|
||||
|
||||
# Test keyword argument is renamed
|
||||
value, other_value = dummy_function(deprecated_name="old_value")
|
||||
self.assertEqual(value, "old_value")
|
||||
self.assertIsNone(other_value)
|
||||
|
||||
# Test deprecated keyword argument not passed
|
||||
value, other_value = dummy_function(new_name="new_value")
|
||||
self.assertEqual(value, "new_value")
|
||||
self.assertIsNone(other_value)
|
||||
|
||||
# Test other keyword argument
|
||||
value, other_value = dummy_function(other_name="other_value")
|
||||
self.assertIsNone(value)
|
||||
self.assertEqual(other_value, "other_value")
|
||||
|
||||
# Test deprecated and new args are passed, the new one should be returned
|
||||
value, other_value = dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||
self.assertEqual(value, "new_value")
|
||||
self.assertIsNone(other_value)
|
||||
|
||||
def test_rename_multiple_kwargs(self):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
@deprecate_kwarg("deprecated_name1", new_name="new_name1", version=INFINITE_VERSION)
|
||||
@deprecate_kwarg("deprecated_name2", new_name="new_name2", version=INFINITE_VERSION)
|
||||
def dummy_function(new_name1=None, new_name2=None, other_name=None):
|
||||
return new_name1, new_name2, other_name
|
||||
|
||||
# Test keyword argument is renamed
|
||||
value1, value2, other_value = dummy_function(deprecated_name1="old_value1", deprecated_name2="old_value2")
|
||||
self.assertEqual(value1, "old_value1")
|
||||
self.assertEqual(value2, "old_value2")
|
||||
self.assertIsNone(other_value)
|
||||
|
||||
# Test deprecated keyword argument is not passed
|
||||
value1, value2, other_value = dummy_function(new_name1="new_value1", new_name2="new_value2")
|
||||
self.assertEqual(value1, "new_value1")
|
||||
self.assertEqual(value2, "new_value2")
|
||||
self.assertIsNone(other_value)
|
||||
|
||||
# Test other keyword argument is passed and correctly returned
|
||||
value1, value2, other_value = dummy_function(other_name="other_value")
|
||||
self.assertIsNone(value1)
|
||||
self.assertIsNone(value2)
|
||||
self.assertEqual(other_value, "other_value")
|
||||
|
||||
def test_warnings(self):
|
||||
# Test warning is raised for future version
|
||||
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
|
||||
def dummy_function(new_name=None, other_name=None):
|
||||
return new_name, other_name
|
||||
|
||||
with self.assertWarns(FutureWarning):
|
||||
dummy_function(deprecated_name="old_value")
|
||||
|
||||
# Test warning is not raised for past version, but arg is still renamed
|
||||
@deprecate_kwarg("deprecated_name", new_name="new_name", version="0.0.0")
|
||||
def dummy_function(new_name=None, other_name=None):
|
||||
return new_name, other_name
|
||||
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
value, other_value = dummy_function(deprecated_name="old_value")
|
||||
|
||||
self.assertEqual(value, "old_value")
|
||||
self.assertIsNone(other_value)
|
||||
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
|
||||
|
||||
# Test warning is raised for future version if warn_if_greater_or_equal_version is set
|
||||
@deprecate_kwarg("deprecated_name", version="0.0.0", warn_if_greater_or_equal_version=True)
|
||||
def dummy_function(deprecated_name=None):
|
||||
return deprecated_name
|
||||
|
||||
with self.assertWarns(FutureWarning):
|
||||
value = dummy_function(deprecated_name="deprecated_value")
|
||||
self.assertEqual(value, "deprecated_value")
|
||||
|
||||
# Test arg is not renamed if new_name is not specified, but warning is raised
|
||||
@deprecate_kwarg("deprecated_name", version=INFINITE_VERSION)
|
||||
def dummy_function(deprecated_name=None):
|
||||
return deprecated_name
|
||||
|
||||
with self.assertWarns(FutureWarning):
|
||||
value = dummy_function(deprecated_name="deprecated_value")
|
||||
self.assertEqual(value, "deprecated_value")
|
||||
|
||||
def test_raises(self):
|
||||
# Test if deprecated name and new name are both passed and raise_if_both_names is set -> raise error
|
||||
@deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION, raise_if_both_names=True)
|
||||
def dummy_function(new_name=None, other_name=None):
|
||||
return new_name, other_name
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||
|
||||
# Test for current version == deprecation version
|
||||
@deprecate_kwarg("deprecated_name", version=__version__, raise_if_greater_or_equal_version=True)
|
||||
def dummy_function(deprecated_name=None):
|
||||
return deprecated_name
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dummy_function(deprecated_name="old_value")
|
||||
|
||||
# Test for current version > deprecation version
|
||||
@deprecate_kwarg("deprecated_name", version="0.0.0", raise_if_greater_or_equal_version=True)
|
||||
def dummy_function(deprecated_name=None):
|
||||
return deprecated_name
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
dummy_function(deprecated_name="old_value")
|
||||
|
||||
def test_additional_message(self):
|
||||
# Test additional message is added to the warning
|
||||
@deprecate_kwarg("deprecated_name", version=INFINITE_VERSION, additional_message="Additional message")
|
||||
def dummy_function(deprecated_name=None):
|
||||
return deprecated_name
|
||||
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter("always")
|
||||
dummy_function(deprecated_name="old_value")
|
||||
|
||||
self.assertTrue("Additional message" in str(raised_warnings[0].message))
|
||||
|
||||
@parameterized.expand(["0.0.0", __version__, INFINITE_VERSION])
|
||||
def test_warning_for_both_names(self, version):
|
||||
# We should raise warning if both names are passed for any specified version
|
||||
@deprecate_kwarg("deprecated_name", new_name="new_name", version=version)
|
||||
def dummy_function(new_name=None, **kwargs):
|
||||
return new_name
|
||||
|
||||
with self.assertWarns(FutureWarning):
|
||||
result = dummy_function(deprecated_name="old_value", new_name="new_value")
|
||||
self.assertEqual(result, "new_value")
|
||||
|
||||
@pytest.mark.torch_compile_test
|
||||
@require_torch_accelerator
|
||||
def test_compile_safe(self):
|
||||
@deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION)
|
||||
def dummy_function(new_factor=None, **kwargs):
|
||||
return new_factor * torch.ones(1, device=torch_device)
|
||||
|
||||
compiled_function = torch.compile(dummy_function, fullgraph=True)
|
||||
|
||||
# Check that we can correctly call the compiled function with the old name, without raising errors
|
||||
out = compiled_function(deprecated_factor=2)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
# Check that we can correctly call the compiled function with the new name, without raising errors
|
||||
out = compiled_function(new_factor=2)
|
||||
self.assertEqual(out.item(), 2)
|
||||
|
||||
# Check that we can correctly call the compiled function with both names, without raising errors
|
||||
out = compiled_function(new_factor=2, deprecated_factor=10)
|
||||
self.assertEqual(out.item(), 2)
|
||||
110
tests/utils/test_doc_samples.py
Normal file
110
tests/utils/test_doc_samples.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import doctest
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import transformers
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@unittest.skip(reason="Temporarily disable the doc tests.")
|
||||
@require_torch
|
||||
@slow
|
||||
class TestCodeExamples(unittest.TestCase):
|
||||
def analyze_directory(
|
||||
self,
|
||||
directory: Path,
|
||||
identifier: str | None = None,
|
||||
ignore_files: list[str] | None = None,
|
||||
n_identifier: str | list[str] | None = None,
|
||||
only_modules: bool = True,
|
||||
):
|
||||
"""
|
||||
Runs through the specific directory, looking for the files identified with `identifier`. Executes
|
||||
the doctests in those files
|
||||
|
||||
Args:
|
||||
directory (`Path`): Directory containing the files
|
||||
identifier (`str`): Will parse files containing this
|
||||
ignore_files (`List[str]`): List of files to skip
|
||||
n_identifier (`str` or `List[str]`): Will not parse files containing this/these identifiers.
|
||||
only_modules (`bool`): Whether to only analyze modules
|
||||
"""
|
||||
files = [file for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file))]
|
||||
|
||||
if identifier is not None:
|
||||
files = [file for file in files if identifier in file]
|
||||
|
||||
if n_identifier is not None:
|
||||
if isinstance(n_identifier, list):
|
||||
for n_ in n_identifier:
|
||||
files = [file for file in files if n_ not in file]
|
||||
else:
|
||||
files = [file for file in files if n_identifier not in file]
|
||||
|
||||
ignore_files = ignore_files or []
|
||||
ignore_files.append("__init__.py")
|
||||
files = [file for file in files if file not in ignore_files]
|
||||
|
||||
for file in files:
|
||||
# Open all files
|
||||
print("Testing", file)
|
||||
|
||||
if only_modules:
|
||||
module_identifier = file.split(".")[0]
|
||||
try:
|
||||
module_identifier = getattr(transformers, module_identifier)
|
||||
suite = doctest.DocTestSuite(module_identifier)
|
||||
result = unittest.TextTestRunner().run(suite)
|
||||
self.assertIs(len(result.failures), 0)
|
||||
except AttributeError:
|
||||
logger.info(f"{module_identifier} is not a module.")
|
||||
else:
|
||||
result = doctest.testfile(str(".." / directory / file), optionflags=doctest.ELLIPSIS)
|
||||
self.assertIs(result.failed, 0)
|
||||
|
||||
def test_modeling_examples(self):
|
||||
transformers_directory = Path("src/transformers")
|
||||
files = "modeling"
|
||||
ignore_files = [
|
||||
"modeling_ctrl.py",
|
||||
"modeling_tf_ctrl.py",
|
||||
]
|
||||
self.analyze_directory(transformers_directory, identifier=files, ignore_files=ignore_files)
|
||||
|
||||
def test_tokenization_examples(self):
|
||||
transformers_directory = Path("src/transformers")
|
||||
files = "tokenization"
|
||||
self.analyze_directory(transformers_directory, identifier=files)
|
||||
|
||||
def test_configuration_examples(self):
|
||||
transformers_directory = Path("src/transformers")
|
||||
files = "configuration"
|
||||
self.analyze_directory(transformers_directory, identifier=files)
|
||||
|
||||
def test_remaining_examples(self):
|
||||
transformers_directory = Path("src/transformers")
|
||||
n_identifiers = ["configuration", "modeling", "tokenization"]
|
||||
self.analyze_directory(transformers_directory, n_identifier=n_identifiers)
|
||||
|
||||
def test_doc_sources(self):
|
||||
doc_source_directory = Path("docs/source")
|
||||
ignore_files = ["favicon.ico"]
|
||||
self.analyze_directory(doc_source_directory, ignore_files=ignore_files, only_modules=False)
|
||||
239
tests/utils/test_dynamic_module_utils.py
Normal file
239
tests/utils/test_dynamic_module_utils.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import dynamic_module_utils
|
||||
from transformers.dynamic_module_utils import get_cached_module_file, get_imports
|
||||
|
||||
|
||||
TOP_LEVEL_IMPORT = """
|
||||
import os
|
||||
"""
|
||||
|
||||
IMPORT_IN_FUNCTION = """
|
||||
def foo():
|
||||
import os
|
||||
return False
|
||||
"""
|
||||
|
||||
DEEPLY_NESTED_IMPORT = """
|
||||
def foo():
|
||||
def bar():
|
||||
if True:
|
||||
import os
|
||||
return False
|
||||
return bar()
|
||||
"""
|
||||
|
||||
TOP_LEVEL_TRY_IMPORT = """
|
||||
import os
|
||||
|
||||
try:
|
||||
import bar
|
||||
except ImportError:
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
TRY_IMPORT_IN_FUNCTION = """
|
||||
import os
|
||||
|
||||
def foo():
|
||||
try:
|
||||
import bar
|
||||
except ImportError:
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
MULTIPLE_EXCEPTS_IMPORT = """
|
||||
import os
|
||||
|
||||
try:
|
||||
import bar
|
||||
except (ImportError, AttributeError):
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
EXCEPT_AS_IMPORT = """
|
||||
import os
|
||||
|
||||
try:
|
||||
import bar
|
||||
except ImportError as e:
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
GENERIC_EXCEPT_IMPORT = """
|
||||
import os
|
||||
|
||||
try:
|
||||
import bar
|
||||
except:
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
MULTILINE_TRY_IMPORT = """
|
||||
import os
|
||||
|
||||
try:
|
||||
import bar
|
||||
import baz
|
||||
except ImportError:
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
MULTILINE_BOTH_IMPORT = """
|
||||
import os
|
||||
|
||||
try:
|
||||
import bar
|
||||
import baz
|
||||
except ImportError:
|
||||
x = 1
|
||||
raise ValueError()
|
||||
"""
|
||||
|
||||
CASES = [
|
||||
TOP_LEVEL_IMPORT,
|
||||
IMPORT_IN_FUNCTION,
|
||||
DEEPLY_NESTED_IMPORT,
|
||||
TOP_LEVEL_TRY_IMPORT,
|
||||
GENERIC_EXCEPT_IMPORT,
|
||||
MULTILINE_TRY_IMPORT,
|
||||
MULTILINE_BOTH_IMPORT,
|
||||
MULTIPLE_EXCEPTS_IMPORT,
|
||||
EXCEPT_AS_IMPORT,
|
||||
TRY_IMPORT_IN_FUNCTION,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", CASES)
|
||||
def test_import_parsing(tmp_path, case):
|
||||
tmp_file_path = os.path.join(tmp_path, "test_file.py")
|
||||
with open(tmp_file_path, "w") as _tmp_file:
|
||||
_tmp_file.write(case)
|
||||
|
||||
parsed_imports = get_imports(tmp_file_path)
|
||||
assert parsed_imports == ["os"]
|
||||
|
||||
|
||||
def _create_local_module(module_dir: Path, module_code: str, helper_code: str | None = None):
|
||||
module_dir.mkdir(parents=True, exist_ok=True)
|
||||
(module_dir / "custom_model.py").write_text(module_code, encoding="utf-8")
|
||||
if helper_code is not None:
|
||||
(module_dir / "helper.py").write_text(helper_code, encoding="utf-8")
|
||||
|
||||
|
||||
def test_get_cached_module_file_local_cache_key_uses_basename_and_content_hash(monkeypatch, tmp_path):
|
||||
modules_cache = tmp_path / "hf_modules_cache"
|
||||
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
|
||||
|
||||
model_dir_a = tmp_path / "pretrained_a" / "subdir"
|
||||
model_dir_b = tmp_path / "pretrained_b" / "subdir"
|
||||
model_dir_c = tmp_path / "pretrained_c" / "subdir"
|
||||
|
||||
_create_local_module(model_dir_a, 'MAGIC = "A"\n')
|
||||
_create_local_module(model_dir_b, 'MAGIC = "B"\n')
|
||||
_create_local_module(model_dir_c, 'MAGIC = "A"\n')
|
||||
|
||||
cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py")
|
||||
cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py")
|
||||
cached_module_c = get_cached_module_file(str(model_dir_c), "custom_model.py")
|
||||
|
||||
cached_module_path_a = Path(cached_module_a)
|
||||
assert cached_module_path_a.parent.parent.name == "subdir"
|
||||
assert len(cached_module_path_a.parent.name) == 16
|
||||
assert cached_module_a != cached_module_b
|
||||
assert cached_module_a == cached_module_c
|
||||
|
||||
|
||||
def test_get_cached_module_file_local_cache_key_includes_relative_import_sources(monkeypatch, tmp_path):
|
||||
modules_cache = tmp_path / "hf_modules_cache"
|
||||
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
|
||||
|
||||
model_dir_a = tmp_path / "pretrained_a" / "subdir"
|
||||
model_dir_b = tmp_path / "pretrained_b" / "subdir"
|
||||
|
||||
module_code = "from .helper import MAGIC\nVALUE = MAGIC\n"
|
||||
_create_local_module(model_dir_a, module_code, 'MAGIC = "A"\n')
|
||||
_create_local_module(model_dir_b, module_code, 'MAGIC = "B"\n')
|
||||
|
||||
cached_module_a = get_cached_module_file(str(model_dir_a), "custom_model.py")
|
||||
cached_module_b = get_cached_module_file(str(model_dir_b), "custom_model.py")
|
||||
|
||||
cached_helper_a = modules_cache / Path(cached_module_a).parent / "helper.py"
|
||||
cached_helper_b = modules_cache / Path(cached_module_b).parent / "helper.py"
|
||||
|
||||
assert cached_module_a != cached_module_b
|
||||
assert cached_helper_a.read_text(encoding="utf-8") == 'MAGIC = "A"\n'
|
||||
assert cached_helper_b.read_text(encoding="utf-8") == 'MAGIC = "B"\n'
|
||||
|
||||
|
||||
def test_get_cached_module_file_local_copies_transitive_relative_imports(monkeypatch, tmp_path):
|
||||
modules_cache = tmp_path / "hf_modules_cache"
|
||||
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
|
||||
|
||||
model_dir = tmp_path / "pretrained" / "subdir"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
# A → B → C: only A is the entry point; C is a transitive dep that must still be copied
|
||||
(model_dir / "custom_model.py").write_text("from .helper import VALUE\n", encoding="utf-8")
|
||||
(model_dir / "helper.py").write_text("from .base import BASE\nVALUE = BASE\n", encoding="utf-8")
|
||||
(model_dir / "base.py").write_text('BASE = "transitive"\n', encoding="utf-8")
|
||||
|
||||
cached_module = get_cached_module_file(str(model_dir), "custom_model.py")
|
||||
cache_dir = modules_cache / Path(cached_module).parent
|
||||
|
||||
assert (cache_dir / "helper.py").exists(), "direct import must be copied"
|
||||
assert (cache_dir / "base.py").exists(), "transitive import must be copied"
|
||||
|
||||
|
||||
def test_get_cached_module_file_local_cache_key_includes_transitive_import_sources(monkeypatch, tmp_path):
|
||||
modules_cache = tmp_path / "hf_modules_cache"
|
||||
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
|
||||
|
||||
for model_dir, base_val in [
|
||||
(tmp_path / "pretrained_a" / "subdir", '"X"'),
|
||||
(tmp_path / "pretrained_b" / "subdir", '"Y"'),
|
||||
]:
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
(model_dir / "custom_model.py").write_text("from .helper import VALUE\n", encoding="utf-8")
|
||||
(model_dir / "helper.py").write_text("from .base import BASE\nVALUE = BASE\n", encoding="utf-8")
|
||||
(model_dir / "base.py").write_text(f"BASE = {base_val}\n", encoding="utf-8")
|
||||
|
||||
cached_a = get_cached_module_file(str(tmp_path / "pretrained_a" / "subdir"), "custom_model.py")
|
||||
cached_b = get_cached_module_file(str(tmp_path / "pretrained_b" / "subdir"), "custom_model.py")
|
||||
|
||||
# Different content in transitive dep → different hash → different cache dirs
|
||||
assert cached_a != cached_b
|
||||
|
||||
|
||||
def test_get_cached_module_file_local_cache_key_keeps_hash_stable_with_different_basenames(monkeypatch, tmp_path):
|
||||
modules_cache = tmp_path / "hf_modules_cache"
|
||||
monkeypatch.setattr(dynamic_module_utils, "HF_MODULES_CACHE", str(modules_cache))
|
||||
|
||||
model_dir_a = tmp_path / "pretrained_a" / "alpha_subdir"
|
||||
model_dir_b = tmp_path / "pretrained_b" / "beta_subdir"
|
||||
|
||||
_create_local_module(model_dir_a, 'MAGIC = "A"\n')
|
||||
_create_local_module(model_dir_b, 'MAGIC = "A"\n')
|
||||
|
||||
cached_module_a = Path(get_cached_module_file(str(model_dir_a), "custom_model.py"))
|
||||
cached_module_b = Path(get_cached_module_file(str(model_dir_b), "custom_model.py"))
|
||||
|
||||
assert cached_module_a.parent.parent.name == "alpha_subdir"
|
||||
assert cached_module_b.parent.parent.name == "beta_subdir"
|
||||
assert cached_module_a.parent.name == cached_module_b.parent.name
|
||||
38
tests/utils/test_expectations.py
Normal file
38
tests/utils/test_expectations.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import Expectations
|
||||
|
||||
|
||||
class ExpectationsTest(unittest.TestCase):
|
||||
def test_expectations(self):
|
||||
# We use the expectations below to make sure the right expectations are found for the right devices.
|
||||
# Each value is just a unique ID.
|
||||
expectations = Expectations(
|
||||
{
|
||||
(None, None): 1,
|
||||
("cuda", 8): 2,
|
||||
("cuda", 7): 3,
|
||||
("rocm", 8): 4,
|
||||
("rocm", None): 5,
|
||||
("cpu", None): 6,
|
||||
("xpu", 3): 7,
|
||||
}
|
||||
)
|
||||
|
||||
def check(expected_id, device_prop):
|
||||
found_id = expectations.find_expectation(device_prop)
|
||||
assert found_id == expected_id, f"Expected {expected_id} for {device_prop}, found {found_id}"
|
||||
|
||||
# npu has no matches so should find default expectation
|
||||
check(1, ("npu", None, None))
|
||||
check(7, ("xpu", 3, None))
|
||||
check(2, ("cuda", 8, None))
|
||||
check(3, ("cuda", 7, None))
|
||||
check(4, ("rocm", 9, None))
|
||||
check(4, ("rocm", None, None))
|
||||
check(2, ("cuda", 2, None))
|
||||
|
||||
# We also test that if there is no default excpectation and no match is found, a ValueError is raised.
|
||||
expectations = Expectations({("cuda", 8): 1})
|
||||
with self.assertRaises(ValueError):
|
||||
expectations.find_expectation(("xpu", None))
|
||||
301
tests/utils/test_feature_extraction_utils.py
Normal file
301
tests/utils/test_feature_extraction_utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test, require_torch
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
|
||||
|
||||
|
||||
class BatchFeatureTester(unittest.TestCase):
|
||||
"""Tests for the BatchFeature class and tensor conversion."""
|
||||
|
||||
def test_batch_feature_basic_access_and_no_conversion(self):
|
||||
"""Test basic dict/attribute access and no conversion when tensor_type=None."""
|
||||
data = {"input_values": [[1, 2, 3], [4, 5, 6]], "labels": [0, 1]}
|
||||
batch = BatchFeature(data)
|
||||
|
||||
# Dict-style and attribute-style access
|
||||
self.assertEqual(batch["input_values"], [[1, 2, 3], [4, 5, 6]])
|
||||
self.assertEqual(batch.labels, [0, 1])
|
||||
|
||||
# No conversion without tensor_type
|
||||
self.assertIsInstance(batch["input_values"], list)
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_numpy_conversion(self):
|
||||
"""Test conversion to numpy arrays from lists and existing numpy arrays."""
|
||||
# From lists
|
||||
batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="np")
|
||||
self.assertIsInstance(batch["input_values"], np.ndarray)
|
||||
self.assertEqual(batch["input_values"].shape, (2, 3))
|
||||
|
||||
# From numpy arrays (should remain numpy)
|
||||
numpy_data = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
batch_arrays = BatchFeature({"input_values": numpy_data}, tensor_type="np")
|
||||
np.testing.assert_array_equal(batch_arrays["input_values"], numpy_data)
|
||||
|
||||
# From list of numpy arrays with same shape should stack
|
||||
numpy_data = [np.array([[1, 2, 3], [4, 5, 6]]), np.array([[7, 8, 9], [10, 11, 12]])]
|
||||
batch_stacked = BatchFeature({"input_values": numpy_data}, tensor_type="np")
|
||||
np.testing.assert_array_equal(
|
||||
batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
|
||||
)
|
||||
|
||||
# from tensor
|
||||
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||
batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="np")
|
||||
np.testing.assert_array_equal(batch_tensor["input_values"], tensor.numpy())
|
||||
|
||||
# from list of tensors with same shape should stack
|
||||
tensors = [torch.tensor([[1, 2, 3], [4, 5, 6]]), torch.tensor([[7, 8, 9], [10, 11, 12]])]
|
||||
batch_stacked = BatchFeature({"input_values": tensors}, tensor_type="np")
|
||||
self.assertIsInstance(batch_stacked["input_values"], np.ndarray)
|
||||
np.testing.assert_array_equal(
|
||||
batch_stacked["input_values"], np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_pytorch_conversion(self):
|
||||
"""Test conversion to PyTorch tensors from various input types."""
|
||||
# From lists
|
||||
batch = BatchFeature({"input_values": [[1, 2, 3], [4, 5, 6]]}, tensor_type="pt")
|
||||
self.assertIsInstance(batch["input_values"], torch.Tensor)
|
||||
self.assertEqual(batch["input_values"].shape, (2, 3))
|
||||
|
||||
# from tensor (should be returned as-is)
|
||||
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||
batch_tensor = BatchFeature({"input_values": tensor}, tensor_type="pt")
|
||||
torch.testing.assert_close(batch_tensor["input_values"], tensor)
|
||||
|
||||
# From numpy arrays
|
||||
batch_numpy = BatchFeature({"input_values": np.array([[1, 2]])}, tensor_type="pt")
|
||||
self.assertIsInstance(batch_numpy["input_values"], torch.Tensor)
|
||||
|
||||
# List of same-shape tensors should stack
|
||||
tensors = [torch.randn(3, 10, 10) for _ in range(3)]
|
||||
batch_stacked = BatchFeature({"pixel_values": tensors}, tensor_type="pt")
|
||||
self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
|
||||
|
||||
# List of same-shape numpy arrays should stack
|
||||
numpy_arrays = [np.random.randn(3, 10, 10) for _ in range(3)]
|
||||
batch_stacked = BatchFeature({"pixel_values": numpy_arrays}, tensor_type="pt")
|
||||
self.assertIsInstance(batch_stacked["pixel_values"], torch.Tensor)
|
||||
self.assertEqual(batch_stacked["pixel_values"].shape, (3, 3, 10, 10))
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_error_handling(self):
|
||||
"""Test clear error messages for common conversion failures."""
|
||||
# Ragged tensors (different shapes)
|
||||
data_ragged = {"values": [torch.randn(3, 224, 224), torch.randn(3, 448, 448)]}
|
||||
with self.assertRaises(ValueError) as context:
|
||||
BatchFeature(data_ragged, tensor_type="pt")
|
||||
error_msg = str(context.exception)
|
||||
self.assertIn("stack expects each tensor to be equal size", error_msg.lower())
|
||||
self.assertIn("return_tensors=None", error_msg)
|
||||
|
||||
# Ragged numpy arrays (different shapes)
|
||||
data_ragged = {"values": [np.random.randn(3, 224, 224), np.random.randn(3, 448, 448)]}
|
||||
with self.assertRaises(ValueError) as context:
|
||||
BatchFeature(data_ragged, tensor_type="np")
|
||||
error_msg = str(context.exception)
|
||||
self.assertIn("inhomogeneous", error_msg.lower())
|
||||
self.assertIn("return_tensors=None", error_msg)
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_auto_skip_non_array_like(self):
|
||||
"""Test that non-array-like values are automatically skipped during tensor conversion."""
|
||||
data = {
|
||||
"values": [[1, 2]],
|
||||
"metadata": {"key": "val"},
|
||||
"image_path": "/path/to/image.jpg",
|
||||
"tags": ["tag1", "tag2"],
|
||||
"extra": None,
|
||||
}
|
||||
batch = BatchFeature(data, tensor_type="pt")
|
||||
|
||||
# values should be converted
|
||||
self.assertIsInstance(batch["values"], torch.Tensor)
|
||||
|
||||
# Non-array-like values should remain unchanged
|
||||
self.assertIsInstance(batch["metadata"], dict)
|
||||
self.assertEqual(batch["metadata"], {"key": "val"})
|
||||
self.assertIsInstance(batch["image_path"], str)
|
||||
self.assertIsInstance(batch["tags"], list)
|
||||
self.assertEqual(batch["tags"], ["tag1", "tag2"])
|
||||
self.assertIsNone(batch["extra"])
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_skip_tensor_conversion(self):
|
||||
"""Test skip_tensor_conversion parameter for metadata fields."""
|
||||
import torch
|
||||
|
||||
data = {"pixel_values": [[1, 2, 3]], "num_crops": [1, 2], "sizes": [(224, 224)]}
|
||||
batch = BatchFeature(data, tensor_type="pt", skip_tensor_conversion=["num_crops", "sizes"])
|
||||
|
||||
# pixel_values should be converted
|
||||
self.assertIsInstance(batch["pixel_values"], torch.Tensor)
|
||||
# num_crops and sizes should remain as lists
|
||||
self.assertIsInstance(batch["num_crops"], list)
|
||||
self.assertIsInstance(batch["sizes"], list)
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_convert_to_tensors_method(self):
|
||||
"""Test convert_to_tensors method can be called after initialization."""
|
||||
import torch
|
||||
|
||||
data = {"input_values": [[1, 2, 3]], "metadata": [1, 2]}
|
||||
batch = BatchFeature(data) # No conversion initially
|
||||
self.assertIsInstance(batch["input_values"], list)
|
||||
|
||||
# Convert with skip parameter
|
||||
batch.convert_to_tensors(tensor_type="pt", skip_tensor_conversion=["metadata"])
|
||||
self.assertIsInstance(batch["input_values"], torch.Tensor)
|
||||
self.assertIsInstance(batch["metadata"], list)
|
||||
|
||||
@require_torch
|
||||
def test_batch_feature_to_with_nested_tensors(self):
|
||||
"""Test .to() method works recursively with nested lists and tuples of tensors."""
|
||||
batch = BatchFeature(
|
||||
{
|
||||
"list_tensors": [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])],
|
||||
"nested_list": [[torch.tensor([1.0]), torch.tensor([2.0])]],
|
||||
"tuple_tensors": (torch.tensor([5.0]), torch.tensor([6.0])),
|
||||
}
|
||||
)
|
||||
|
||||
batch_fp16 = batch.to(torch.float16)
|
||||
|
||||
# Check lists of tensors are converted
|
||||
self.assertIsInstance(batch_fp16["list_tensors"], list)
|
||||
self.assertEqual(batch_fp16["list_tensors"][0].dtype, torch.float16)
|
||||
self.assertEqual(batch_fp16["list_tensors"][1].dtype, torch.float16)
|
||||
|
||||
# Check nested lists are converted
|
||||
self.assertIsInstance(batch_fp16["nested_list"][0], list)
|
||||
self.assertEqual(batch_fp16["nested_list"][0][0].dtype, torch.float16)
|
||||
|
||||
# Check tuples are preserved and converted
|
||||
self.assertIsInstance(batch_fp16["tuple_tensors"], tuple)
|
||||
self.assertEqual(batch_fp16["tuple_tensors"][0].dtype, torch.float16)
|
||||
|
||||
|
||||
class FeatureExtractorUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"failed", request=mock.Mock(), response=mock.Mock()
|
||||
)
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class FeatureExtractorPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_in_organization_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
CustomFeatureExtractor.register_for_auto_class()
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(
|
||||
feature_extractor.auto_map,
|
||||
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
|
||||
)
|
||||
|
||||
new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
|
||||
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
|
||||
102
tests/utils/test_file_utils.py
Normal file
102
tests/utils/test_file_utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
# Try to import everything from transformers to ensure every object can be loaded.
|
||||
from transformers import * # noqa F406
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_torch
|
||||
from transformers.utils import ContextManagers, find_labels, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
|
||||
|
||||
|
||||
MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER
|
||||
# An actual model hosted on huggingface.co
|
||||
|
||||
REVISION_ID_DEFAULT = "main"
|
||||
# Default branch name
|
||||
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
|
||||
# One particular commit (not the top of `main`)
|
||||
REVISION_ID_INVALID = "aaaaaaa"
|
||||
# This commit does not exist, so we should 404.
|
||||
|
||||
PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
|
||||
# Sha-1 of config.json on the top of `main`, for checking purposes
|
||||
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
|
||||
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
|
||||
|
||||
|
||||
# Dummy contexts to test `ContextManagers`
|
||||
@contextlib.contextmanager
|
||||
def context_en():
|
||||
print("Welcome!")
|
||||
yield
|
||||
print("Bye!")
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def context_fr():
|
||||
print("Bonjour!")
|
||||
yield
|
||||
print("Au revoir!")
|
||||
|
||||
|
||||
class TestImportMechanisms(unittest.TestCase):
|
||||
def test_module_spec_available(self):
|
||||
# If the spec is missing, importlib would not be able to import the module dynamically.
|
||||
assert transformers.__spec__ is not None
|
||||
assert importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
class GenericUtilTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_context_managers_no_context(self, mock_stdout):
|
||||
with ContextManagers([]):
|
||||
print("Transformers are awesome!")
|
||||
# The print statement adds a new line at the end of the output
|
||||
self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_context_managers_one_context(self, mock_stdout):
|
||||
with ContextManagers([context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
|
||||
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
def test_context_managers_two_context(self, mock_stdout):
|
||||
with ContextManagers([context_fr(), context_en()]):
|
||||
print("Transformers are awesome!")
|
||||
# The output should be wrapped with an English and French welcome and goodbye
|
||||
self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
|
||||
|
||||
@require_torch
|
||||
def test_find_labels_pt(self):
|
||||
self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
|
||||
self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
|
||||
self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
|
||||
|
||||
# find_labels works regardless of the class name (it detects the framework through inheritance)
|
||||
class DummyModel(BertForSequenceClassification):
|
||||
pass
|
||||
|
||||
self.assertEqual(find_labels(DummyModel), ["labels"])
|
||||
195
tests/utils/test_fusion_mapping.py
Normal file
195
tests/utils/test_fusion_mapping.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
import transformers.conversion_mapping as conversion_mapping
|
||||
import transformers.fusion_mapping as fusion_mapping
|
||||
import transformers.monkey_patching as monkey_patching
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.conversion_mapping import get_checkpoint_conversion_mapping, register_checkpoint_conversion_mapping
|
||||
from transformers.core_model_loading import Conv3dToLinear, WeightConverter
|
||||
from transformers.fusion_mapping import register_fusion_patches
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.monkey_patching import apply_patches, get_patch_mapping
|
||||
|
||||
|
||||
DUMMY_TRANSFORMERS_MODULE_NAME = "transformers.test_fusion_mapping_dummy"
|
||||
# `apply_patches()` scans `sys.modules` and only rewrites class attributes exposed
|
||||
# from `transformers.*` modules, so this dummy class must be reachable through a
|
||||
# fake `transformers` module instead of only through a local symbol.
|
||||
DUMMY_TRANSFORMERS_MODULE = types.ModuleType(DUMMY_TRANSFORMERS_MODULE_NAME)
|
||||
sys.modules[DUMMY_TRANSFORMERS_MODULE_NAME] = DUMMY_TRANSFORMERS_MODULE
|
||||
|
||||
|
||||
class DummyVisionConfig(PretrainedConfig):
|
||||
model_type = "dummy_fusion_vision"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
patch_size=2,
|
||||
temporal_patch_size=2,
|
||||
patch_embed_stride=(2, 2, 2),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.patch_embed_stride = patch_embed_stride
|
||||
|
||||
|
||||
class DummyFusionConfig(PretrainedConfig):
|
||||
model_type = "dummy_fusion"
|
||||
sub_configs = {"vision_config": DummyVisionConfig}
|
||||
|
||||
def __init__(self, vision_config=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if vision_config is None:
|
||||
vision_config = DummyVisionConfig()
|
||||
elif isinstance(vision_config, dict):
|
||||
vision_config = DummyVisionConfig(**vision_config)
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
|
||||
class DummyPatchEmbedding(nn.Module):
|
||||
def __init__(self, stride=(2, 2, 2), bias=False):
|
||||
super().__init__()
|
||||
self.embed_dim = 8
|
||||
self.proj = nn.Conv3d(3, self.embed_dim, kernel_size=(2, 2, 2), stride=stride, bias=bias)
|
||||
|
||||
|
||||
DUMMY_PATCHABLE_CLASSES = {"DummyPatchEmbedding": DummyPatchEmbedding}
|
||||
|
||||
for class_name, patchable_class in DUMMY_PATCHABLE_CLASSES.items():
|
||||
setattr(DUMMY_TRANSFORMERS_MODULE, class_name, patchable_class)
|
||||
|
||||
|
||||
class DummyFusionModel(PreTrainedModel):
|
||||
config_class = DummyFusionConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Resolve the class through the fake `transformers.*` module so monkey patching
|
||||
# can replace it before instantiation.
|
||||
self.patch_embed = DUMMY_TRANSFORMERS_MODULE.DummyPatchEmbedding(
|
||||
stride=config.vision_config.patch_embed_stride, bias=True
|
||||
)
|
||||
self.post_init()
|
||||
|
||||
|
||||
class FusionMappingTest(unittest.TestCase):
|
||||
"""Covers registration, no-match, and conflict handling for fusion mapping."""
|
||||
|
||||
fusion_config = {"patch_embeddings": True}
|
||||
|
||||
def setUp(self):
|
||||
self.patch_mapping_patcher = patch.object(monkey_patching, "_monkey_patch_mapping_cache", {})
|
||||
self.patch_mapping_patcher.start()
|
||||
self.discovery_cache_patcher = patch.object(fusion_mapping, "_FUSION_DISCOVERY_CACHE", {})
|
||||
self.discovery_cache_patcher.start()
|
||||
self.checkpoint_conversion_mapping_cache = deepcopy(conversion_mapping._checkpoint_conversion_mapping_cache)
|
||||
|
||||
def tearDown(self):
|
||||
self.patch_mapping_patcher.stop()
|
||||
self.discovery_cache_patcher.stop()
|
||||
conversion_mapping._checkpoint_conversion_mapping_cache = deepcopy(self.checkpoint_conversion_mapping_cache)
|
||||
|
||||
def test_register_fusion_patches_is_effective_on_dummy_model(self):
|
||||
# Registers and applies a fusion on a dummy model.
|
||||
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
|
||||
config = DummyFusionConfig()
|
||||
|
||||
self.assertEqual(get_patch_mapping(), {})
|
||||
self.assertIsNone(get_checkpoint_conversion_mapping(config.model_type))
|
||||
self.assertIsInstance(DummyFusionModel(config).patch_embed.proj, nn.Conv3d)
|
||||
|
||||
register_fusion_patches(DummyFusionModel, config, fusion_config=self.fusion_config)
|
||||
|
||||
self.assertEqual(len(get_patch_mapping()), 1)
|
||||
self.assertEqual(len(get_checkpoint_conversion_mapping(config.model_type)), 2)
|
||||
|
||||
with apply_patches():
|
||||
fused_model = DummyFusionModel(config)
|
||||
|
||||
fused_projection = getattr(
|
||||
fused_model.patch_embed, "linear_proj", getattr(fused_model.patch_embed, "proj", None)
|
||||
)
|
||||
self.assertIsInstance(fused_projection, nn.Linear)
|
||||
|
||||
def test_register_fusion_patches_skips_when_no_modules_match(self):
|
||||
# Leaves registries untouched when nothing is fusable.
|
||||
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
|
||||
config = DummyFusionConfig(vision_config={"patch_embed_stride": (1, 1, 1)})
|
||||
|
||||
register_fusion_patches(DummyFusionModel, config, fusion_config=self.fusion_config)
|
||||
|
||||
self.assertEqual(get_patch_mapping(), {})
|
||||
self.assertIsNone(get_checkpoint_conversion_mapping(config.model_type))
|
||||
|
||||
def test_register_fusion_patches_raises_on_transform_conflicts(self):
|
||||
# Rejects transforms that would shadow an existing source pattern.
|
||||
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
|
||||
config = DummyFusionConfig()
|
||||
model_type = config.model_type
|
||||
|
||||
# build a conflicting conversion mapping with the same source pattern but different target pattern
|
||||
register_checkpoint_conversion_mapping(
|
||||
model_type,
|
||||
[
|
||||
WeightConverter(
|
||||
source_patterns=r"patch_embed\.proj\.weight$",
|
||||
target_patterns=r"patch_embed\.other_linear_proj\.weight$",
|
||||
operations=[Conv3dToLinear(in_channels=3, kernel_size=(2, 2, 2))],
|
||||
)
|
||||
],
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "conflicts with an existing conversion mapping"):
|
||||
register_fusion_patches(DummyFusionModel, config, fusion_config=self.fusion_config)
|
||||
|
||||
def test_from_pretrained_uses_serialized_fusion_config(self):
|
||||
# A serialized `fusion_config` is reused on a later load.
|
||||
DummyFusionConfig.model_type = f"dummy_fusion_{self._testMethodName}"
|
||||
|
||||
with tempfile.TemporaryDirectory() as source_dir, tempfile.TemporaryDirectory() as fused_dir:
|
||||
DummyFusionModel(DummyFusionConfig()).save_pretrained(source_dir)
|
||||
|
||||
fused_model = DummyFusionModel.from_pretrained(source_dir, fusion_config=self.fusion_config)
|
||||
fused_model.save_pretrained(fused_dir)
|
||||
|
||||
# Simulate a fresh process so the second load comes only from the serialized config.
|
||||
monkey_patching._monkey_patch_mapping_cache.clear()
|
||||
fusion_mapping._FUSION_DISCOVERY_CACHE.clear()
|
||||
conversion_mapping._checkpoint_conversion_mapping_cache = deepcopy(
|
||||
self.checkpoint_conversion_mapping_cache
|
||||
)
|
||||
|
||||
reloaded_model = DummyFusionModel.from_pretrained(fused_dir)
|
||||
|
||||
fused_projection = getattr(
|
||||
reloaded_model.patch_embed, "linear_proj", getattr(reloaded_model.patch_embed, "proj", None)
|
||||
)
|
||||
self.assertIsInstance(fused_projection, nn.Linear)
|
||||
482
tests/utils/test_generic.py
Normal file
482
tests/utils/test_generic.py
Normal file
@@ -0,0 +1,482 @@
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers.configuration_utils import PreTrainedConfig
|
||||
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, ModelOutput
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils import (
|
||||
can_return_tuple,
|
||||
expand_dims,
|
||||
filter_out_non_signature_kwargs,
|
||||
flatten_dict,
|
||||
is_torch_available,
|
||||
reshape,
|
||||
squeeze,
|
||||
to_py_obj,
|
||||
transpose,
|
||||
)
|
||||
from transformers.utils.generic import retry, split_attention_implementation
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class GenericTester(unittest.TestCase):
|
||||
def test_flatten_dict(self):
|
||||
input_dict = {
|
||||
"task_specific_params": {
|
||||
"summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
|
||||
"summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
|
||||
"summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
|
||||
}
|
||||
}
|
||||
expected_dict = {
|
||||
"task_specific_params.summarization.length_penalty": 1.0,
|
||||
"task_specific_params.summarization.max_length": 128,
|
||||
"task_specific_params.summarization.min_length": 12,
|
||||
"task_specific_params.summarization.num_beams": 4,
|
||||
"task_specific_params.summarization_cnn.length_penalty": 2.0,
|
||||
"task_specific_params.summarization_cnn.max_length": 142,
|
||||
"task_specific_params.summarization_cnn.min_length": 56,
|
||||
"task_specific_params.summarization_cnn.num_beams": 4,
|
||||
"task_specific_params.summarization_xsum.length_penalty": 1.0,
|
||||
"task_specific_params.summarization_xsum.max_length": 62,
|
||||
"task_specific_params.summarization_xsum.min_length": 11,
|
||||
"task_specific_params.summarization_xsum.num_beams": 6,
|
||||
}
|
||||
|
||||
self.assertEqual(flatten_dict(input_dict), expected_dict)
|
||||
|
||||
def test_transpose_numpy(self):
|
||||
x = np.random.randn(3, 4)
|
||||
self.assertTrue(np.allclose(transpose(x), x.transpose()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0))))
|
||||
|
||||
@require_torch
|
||||
def test_transpose_torch(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
|
||||
|
||||
@require_torch
|
||||
def test_reshape_torch(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
|
||||
|
||||
x = np.random.randn(3, 4, 5)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
|
||||
|
||||
@require_torch
|
||||
def test_squeeze_torch(self):
|
||||
x = np.random.randn(1, 3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
|
||||
|
||||
x = np.random.randn(1, 4, 1, 5)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
|
||||
|
||||
def test_expand_dims_numpy(self):
|
||||
x = np.random.randn(3, 4)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
|
||||
|
||||
@require_torch
|
||||
def test_expand_dims_torch(self):
|
||||
x = np.random.randn(3, 4)
|
||||
t = torch.tensor(x)
|
||||
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
|
||||
|
||||
def test_to_py_obj_native(self):
|
||||
self.assertTrue(to_py_obj(1) == 1)
|
||||
self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3])
|
||||
self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]])
|
||||
|
||||
def test_to_py_obj_numpy(self):
|
||||
x1 = [[1, 2, 3], [4, 5, 6]]
|
||||
t1 = np.array(x1)
|
||||
self.assertTrue(to_py_obj(t1) == x1)
|
||||
|
||||
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||
t2 = np.array(x2)
|
||||
self.assertTrue(to_py_obj(t2) == x2)
|
||||
|
||||
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
|
||||
|
||||
def test_split_attention_implementation(self):
|
||||
self.assertEqual(split_attention_implementation(None), (False, None))
|
||||
self.assertEqual(split_attention_implementation("sdpa"), (False, "sdpa"))
|
||||
self.assertEqual(split_attention_implementation("paged|flash_attention_2"), (True, "flash_attention_2"))
|
||||
self.assertEqual(
|
||||
split_attention_implementation("paged|kernels-community/flash-attn3"),
|
||||
(True, "kernels-community/flash-attn3"),
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_to_py_obj_torch(self):
|
||||
x1 = [[1, 2, 3], [4, 5, 6]]
|
||||
t1 = torch.tensor(x1)
|
||||
self.assertTrue(to_py_obj(t1) == x1)
|
||||
|
||||
x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
|
||||
t2 = torch.tensor(x2)
|
||||
self.assertTrue(to_py_obj(t2) == x2)
|
||||
|
||||
self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
|
||||
|
||||
def test_model_output_subclass(self):
|
||||
# testing with “dict-like init” case
|
||||
out = CausalLMOutputWithPast({"logits": torch.ones(2, 3, 4)})
|
||||
self.assertNotEqual(out["logits"], None)
|
||||
self.assertEqual(out.loss, None)
|
||||
self.assertEqual(len(out.to_tuple()), 1)
|
||||
|
||||
# testing with dataclass init case
|
||||
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
|
||||
self.assertNotEqual(out["logits"], None)
|
||||
self.assertEqual(out.loss, None)
|
||||
self.assertEqual(len(out.to_tuple()), 1)
|
||||
|
||||
# testing with updating a previously-None key after init with attribute assignment
|
||||
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
|
||||
out.loss = torch.tensor(0.5)
|
||||
self.assertEqual(out.loss, torch.tensor(0.5))
|
||||
self.assertEqual(len(out.to_tuple()), 2)
|
||||
|
||||
# testing with updating a previously-None key after init with dictionary assignment
|
||||
out = CausalLMOutputWithPast(logits=torch.ones(2, 3, 4))
|
||||
out["loss"] = torch.tensor(0.5)
|
||||
self.assertEqual(out.loss, torch.tensor(0.5))
|
||||
self.assertEqual(len(out.to_tuple()), 2)
|
||||
|
||||
@require_torch
|
||||
def test_register_model_output_pytree_node_skipped_during_compile(self):
|
||||
# Regression test: on AMD CI (PyTorch 2.8.0+rocm), `set.__contains__` is not
|
||||
# traceable by TorchDynamo. `_register_model_output_pytree_node` must return
|
||||
# early when called inside a compiled context, before touching the set.
|
||||
from transformers.utils.generic import _register_model_output_pytree_node
|
||||
|
||||
@dataclass
|
||||
class DummyOutput(ModelOutput):
|
||||
last_hidden_state: "torch.Tensor" = None
|
||||
|
||||
# Eager registration works normally
|
||||
_register_model_output_pytree_node(DummyOutput)
|
||||
|
||||
# Simulate being inside torch.compile — must not raise
|
||||
with patch("torch.compiler.is_compiling", return_value=True):
|
||||
_register_model_output_pytree_node(DummyOutput)
|
||||
|
||||
|
||||
class ValidationDecoratorTester(unittest.TestCase):
|
||||
def test_cases_no_warning(self):
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
# basic test
|
||||
@filter_out_non_signature_kwargs()
|
||||
def func1(a):
|
||||
return a
|
||||
|
||||
result = func1(1)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
# include extra kwarg
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg"])
|
||||
def func2(a, **kwargs):
|
||||
return a, kwargs
|
||||
|
||||
a, kwargs = func2(1)
|
||||
self.assertEqual(a, 1)
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
a, kwargs = func2(1, extra_arg=2)
|
||||
self.assertEqual(a, 1)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2})
|
||||
|
||||
# multiple extra kwargs
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
|
||||
def func3(a, **kwargs):
|
||||
return a, kwargs
|
||||
|
||||
a, kwargs = func3(2)
|
||||
self.assertEqual(a, 2)
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
a, kwargs = func3(3, extra_arg2=3)
|
||||
self.assertEqual(a, 3)
|
||||
self.assertEqual(kwargs, {"extra_arg2": 3})
|
||||
|
||||
a, kwargs = func3(1, extra_arg=2, extra_arg2=3)
|
||||
self.assertEqual(a, 1)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||
|
||||
# Check that no warnings were raised
|
||||
self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
|
||||
|
||||
def test_cases_with_warnings(self):
|
||||
@filter_out_non_signature_kwargs()
|
||||
def func1(a):
|
||||
return a
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
func1(1, extra_arg=2)
|
||||
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg"])
|
||||
def func2(a, **kwargs):
|
||||
return kwargs
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
kwargs = func2(1, extra_arg=2, extra_arg2=3)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2})
|
||||
|
||||
@filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
|
||||
def func3(a, **kwargs):
|
||||
return kwargs
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
|
||||
self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
|
||||
|
||||
|
||||
@require_torch
|
||||
class CanReturnTupleDecoratorTester(unittest.TestCase):
|
||||
def _get_model(self, config, store_config=True, raise_in_forward=False):
|
||||
# Simple model class for testing can_return_tuple decorator.
|
||||
class SimpleTestModel(torch.nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if store_config:
|
||||
self.config = config
|
||||
|
||||
@can_return_tuple
|
||||
def forward(self, x):
|
||||
if raise_in_forward:
|
||||
raise ValueError("Test error")
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=x,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
)
|
||||
|
||||
return SimpleTestModel(config)
|
||||
|
||||
def test_decorator_eager(self):
|
||||
"""Test that the can_return_tuple decorator works with eager mode."""
|
||||
|
||||
# test nothing is set
|
||||
config = PreTrainedConfig()
|
||||
model = self._get_model(config)
|
||||
inputs = torch.tensor(10)
|
||||
output = model(inputs)
|
||||
self.assertIsInstance(
|
||||
output, BaseModelOutput, "output should be a BaseModelOutput when return_dict is not set"
|
||||
)
|
||||
|
||||
# test all explicit cases
|
||||
for config_return_dict in [True, False, None]:
|
||||
for return_dict in [True, False, None]:
|
||||
config = PreTrainedConfig(return_dict=config_return_dict)
|
||||
model = self._get_model(config)
|
||||
output = model(torch.tensor(10), return_dict=return_dict)
|
||||
|
||||
expected_type = (
|
||||
tuple
|
||||
if return_dict is False
|
||||
else (tuple if config_return_dict is False and return_dict is None else BaseModelOutput)
|
||||
)
|
||||
if config_return_dict is None and return_dict is None:
|
||||
expected_type = tuple
|
||||
message = f"output should be a {expected_type.__name__} when config.return_dict={config_return_dict} and return_dict={return_dict}"
|
||||
self.assertIsInstance(output, expected_type, message)
|
||||
|
||||
@pytest.mark.torch_compile_test
|
||||
def test_decorator_compiled(self):
|
||||
"""Test that the can_return_tuple decorator works with compiled mode."""
|
||||
config = PreTrainedConfig()
|
||||
|
||||
# Output object
|
||||
model = self._get_model(config)
|
||||
compiled_model = torch.compile(model)
|
||||
output = compiled_model(torch.tensor(10))
|
||||
self.assertIsInstance(output, BaseModelOutput)
|
||||
|
||||
# Tuple output
|
||||
model = self._get_model(config)
|
||||
compiled_model = torch.compile(model)
|
||||
output = compiled_model(torch.tensor(10), return_dict=False)
|
||||
self.assertIsInstance(output, tuple)
|
||||
|
||||
@pytest.mark.torch_export_test
|
||||
def test_decorator_torch_export(self):
|
||||
"""Test that the can_return_tuple decorator works with torch.export."""
|
||||
config = PreTrainedConfig()
|
||||
model = self._get_model(config)
|
||||
torch.export.export(model, args=(torch.tensor(10),))
|
||||
|
||||
def test_attribute_cleanup(self):
|
||||
"""Test that the `_is_top_level_module` attribute is removed after the forward call."""
|
||||
|
||||
config = PreTrainedConfig(return_dict=False)
|
||||
inputs = torch.tensor(10)
|
||||
|
||||
# working case
|
||||
model = self._get_model(config)
|
||||
output = model(inputs)
|
||||
|
||||
self.assertIsInstance(output, tuple)
|
||||
for name, module in model.named_modules():
|
||||
self.assertFalse(
|
||||
hasattr(module, "_is_top_level_module"),
|
||||
f"Module `{name}` should not have `_is_top_level_module` attribute",
|
||||
)
|
||||
|
||||
# model without config
|
||||
no_config_model = self._get_model(config, store_config=False)
|
||||
output = no_config_model(inputs)
|
||||
|
||||
self.assertIsInstance(output, BaseModelOutput)
|
||||
for name, module in no_config_model.named_modules():
|
||||
self.assertFalse(
|
||||
hasattr(module, "_is_top_level_module"),
|
||||
f"Module `{name}` should not have `_is_top_level_module` attribute",
|
||||
)
|
||||
|
||||
# model with raise in forward
|
||||
model_with_raise = self._get_model(config, raise_in_forward=True)
|
||||
with self.assertRaises(ValueError):
|
||||
model_with_raise(inputs)
|
||||
|
||||
for name, module in model_with_raise.named_modules():
|
||||
self.assertFalse(
|
||||
hasattr(module, "_is_top_level_module"),
|
||||
f"Module `{name}` should not have `_is_top_level_module` attribute",
|
||||
)
|
||||
|
||||
|
||||
class RetryTest(unittest.TestCase):
|
||||
def test_succeeds_on_first_attempt(self):
|
||||
"""Test that retry returns immediately when the wrapped call succeeds."""
|
||||
|
||||
@retry(max_retries=3, exceptions=(ValueError,))
|
||||
def succeed():
|
||||
return "ok"
|
||||
|
||||
self.assertEqual(succeed(), "ok")
|
||||
|
||||
@patch("transformers.utils.generic.time.sleep")
|
||||
def test_retries_then_succeeds(self, mock_sleep):
|
||||
"""Test that retry sleeps and eventually returns after transient failures."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
@retry(max_retries=3, initial_delay=1.0, jitter=False, exceptions=(ValueError,))
|
||||
def fail_twice():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ValueError("transient")
|
||||
return "recovered"
|
||||
|
||||
self.assertEqual(fail_twice(), "recovered")
|
||||
self.assertEqual(call_count, 3)
|
||||
self.assertEqual(mock_sleep.call_count, 2)
|
||||
|
||||
@patch("transformers.utils.generic.time.sleep")
|
||||
def test_raises_after_max_retries(self, mock_sleep):
|
||||
"""Test that retry re-raises the configured exception after exhausting retries."""
|
||||
|
||||
@retry(max_retries=2, initial_delay=0.1, jitter=False, exceptions=(RuntimeError,))
|
||||
def always_fail():
|
||||
raise RuntimeError("permanent")
|
||||
|
||||
with self.assertRaises(RuntimeError, msg="permanent"):
|
||||
always_fail()
|
||||
self.assertEqual(mock_sleep.call_count, 1)
|
||||
|
||||
@patch("transformers.utils.generic.time.sleep")
|
||||
def test_non_matching_exception_propagates_immediately(self, mock_sleep):
|
||||
"""Test that retry does not intercept exceptions outside the configured set."""
|
||||
|
||||
@retry(max_retries=5, exceptions=(ValueError,))
|
||||
def raise_type_error():
|
||||
raise TypeError("wrong type")
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
raise_type_error()
|
||||
self.assertEqual(mock_sleep.call_count, 0)
|
||||
|
||||
@patch("transformers.utils.generic.time.sleep")
|
||||
def test_exponential_backoff(self, mock_sleep):
|
||||
"""Test that retry doubles the delay between attempts when jitter is disabled."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
@retry(max_retries=4, initial_delay=1.0, max_delay=10.0, jitter=False, exceptions=(ValueError,))
|
||||
def fail_thrice():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 4:
|
||||
raise ValueError("retry")
|
||||
return "done"
|
||||
|
||||
fail_thrice()
|
||||
delays = [call[0][0] for call in mock_sleep.call_args_list]
|
||||
self.assertEqual(delays, [1.0, 2.0, 4.0])
|
||||
|
||||
@patch("transformers.utils.generic.time.sleep")
|
||||
def test_max_delay_cap(self, mock_sleep):
|
||||
"""Test that retry caps exponential backoff at the configured maximum delay."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
@retry(max_retries=5, initial_delay=8.0, max_delay=10.0, jitter=False, exceptions=(ValueError,))
|
||||
def fail_four():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 5:
|
||||
raise ValueError("retry")
|
||||
return "done"
|
||||
|
||||
fail_four()
|
||||
delays = [call[0][0] for call in mock_sleep.call_args_list]
|
||||
# 8.0, then min(16, 10)=10, min(20, 10)=10, min(20, 10)=10
|
||||
self.assertEqual(delays, [8.0, 10.0, 10.0, 10.0])
|
||||
|
||||
def test_preserves_function_metadata(self):
|
||||
"""Test that retry preserves the wrapped function metadata."""
|
||||
|
||||
@retry(exceptions=(ValueError,))
|
||||
def my_func():
|
||||
"""My docstring."""
|
||||
pass
|
||||
|
||||
self.assertEqual(my_func.__name__, "my_func")
|
||||
self.assertEqual(my_func.__doc__, "My docstring.")
|
||||
492
tests/utils/test_hf_argparser.py
Normal file
492
tests/utils/test_hf_argparser.py
Normal file
@@ -0,0 +1,492 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union, get_args, get_origin
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
|
||||
# Since Python 3.10, we can use the builtin `|` operator for Union types
|
||||
# See PEP 604: https://peps.python.org/pep-0604
|
||||
is_python_no_less_than_3_10 = sys.version_info >= (3, 10)
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
return field(default_factory=lambda: default, metadata=metadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicExample:
|
||||
foo: int
|
||||
bar: float
|
||||
baz: str
|
||||
flag: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class WithDefaultExample:
|
||||
foo: int = 42
|
||||
baz: str = field(default="toto", metadata={"help": "help message"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class WithDefaultBoolExample:
|
||||
foo: bool = False
|
||||
baz: bool = True
|
||||
opt: bool | None = None
|
||||
|
||||
|
||||
class BasicEnum(Enum):
|
||||
titi = "titi"
|
||||
toto = "toto"
|
||||
|
||||
|
||||
class MixedTypeEnum(Enum):
|
||||
titi = "titi"
|
||||
toto = "toto"
|
||||
fourtytwo = 42
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumExample:
|
||||
foo: BasicEnum = "toto"
|
||||
|
||||
def __post_init__(self):
|
||||
self.foo = BasicEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixedTypeEnumExample:
|
||||
foo: MixedTypeEnum = "toto"
|
||||
|
||||
def __post_init__(self):
|
||||
self.foo = MixedTypeEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionalExample:
|
||||
foo: int | None = None
|
||||
bar: float | None = field(default=None, metadata={"help": "help message"})
|
||||
baz: str | None = None
|
||||
ces: list[str] | None = list_field(default=[])
|
||||
des: list[int] | None = list_field(default=[])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListExample:
|
||||
foo_int: list[int] = list_field(default=[])
|
||||
bar_int: list[int] = list_field(default=[1, 2, 3])
|
||||
foo_str: list[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
foo_float: list[float] = list_field(default=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequiredExample:
|
||||
required_list: list[int] = field()
|
||||
required_str: str = field()
|
||||
required_enum: BasicEnum = field()
|
||||
|
||||
def __post_init__(self):
|
||||
self.required_enum = BasicEnum(self.required_enum)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StringLiteralAnnotationExample:
|
||||
foo: int
|
||||
required_enum: "BasicEnum" = field()
|
||||
opt: "bool | None" = None
|
||||
baz: "str" = field(default="toto", metadata={"help": "help message"})
|
||||
foo_str: "list[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
|
||||
|
||||
|
||||
if is_python_no_less_than_3_10:
|
||||
|
||||
@dataclass
|
||||
class WithDefaultBoolExamplePep604:
|
||||
foo: bool = False
|
||||
baz: bool = True
|
||||
opt: bool | None = None
|
||||
|
||||
@dataclass
|
||||
class OptionalExamplePep604:
|
||||
foo: int | None = None
|
||||
bar: float | None = field(default=None, metadata={"help": "help message"})
|
||||
baz: str | None = None
|
||||
ces: list[str] | None = list_field(default=[])
|
||||
des: list[int] | None = list_field(default=[])
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser):
|
||||
"""
|
||||
Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
|
||||
"""
|
||||
self.assertEqual(len(a._actions), len(b._actions))
|
||||
for x, y in zip(a._actions, b._actions):
|
||||
xx = {k: v for k, v in vars(x).items() if k != "container"}
|
||||
yy = {k: v for k, v in vars(y).items() if k != "container"}
|
||||
|
||||
# Choices with mixed type have custom function as "type"
|
||||
# So we need to compare results directly for equality
|
||||
if xx.get("choices") and yy.get("choices"):
|
||||
for expected_choice in yy["choices"] + xx["choices"]:
|
||||
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
|
||||
del xx["type"], yy["type"]
|
||||
|
||||
self.assertEqual(xx, yy)
|
||||
|
||||
def test_00_basic(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument("--bar", type=float, required=True)
|
||||
expected.add_argument("--baz", type=str, required=True)
|
||||
expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
|
||||
(example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
|
||||
self.assertFalse(example.flag)
|
||||
|
||||
def test_01_with_default(self):
|
||||
parser = HfArgumentParser(WithDefaultExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default=42, type=int)
|
||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_02_with_default_bool(self):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
|
||||
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
# A boolean no_* argument always has to come after its "default: True" regular counter-part
|
||||
# and its default must be set to False
|
||||
expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
|
||||
dataclass_types = [WithDefaultBoolExample]
|
||||
if is_python_no_less_than_3_10:
|
||||
dataclass_types.append(WithDefaultBoolExamplePep604)
|
||||
|
||||
for dataclass_type in dataclass_types:
|
||||
parser = HfArgumentParser(dataclass_type)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--no_baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--no-baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
|
||||
|
||||
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||
|
||||
def test_03_with_enum(self):
|
||||
parser = HfArgumentParser(MixedTypeEnumExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument(
|
||||
"--foo",
|
||||
default="toto",
|
||||
choices=["titi", "toto", 42],
|
||||
type=make_choice_type_function(["titi", "toto", 42]),
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, "toto")
|
||||
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, "titi")
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
|
||||
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
|
||||
|
||||
def test_04_with_literal(self):
|
||||
@dataclass
|
||||
class LiteralExample:
|
||||
foo: Literal["titi", "toto", 42] = "toto"
|
||||
|
||||
parser = HfArgumentParser(LiteralExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument(
|
||||
"--foo",
|
||||
default="toto",
|
||||
choices=("titi", "toto", 42),
|
||||
type=make_choice_type_function(["titi", "toto", 42]),
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, "toto")
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, "titi")
|
||||
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
|
||||
def test_05_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
|
||||
expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
|
||||
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
||||
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(
|
||||
args,
|
||||
Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
|
||||
)
|
||||
|
||||
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
def test_06_with_optional(self):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default=None, type=int)
|
||||
expected.add_argument("--bar", default=None, type=float, help="help message")
|
||||
expected.add_argument("--baz", default=None, type=str)
|
||||
expected.add_argument("--ces", nargs="+", default=[], type=str)
|
||||
expected.add_argument("--des", nargs="+", default=[], type=int)
|
||||
|
||||
dataclass_types = [OptionalExample]
|
||||
if is_python_no_less_than_3_10:
|
||||
dataclass_types.append(OptionalExamplePep604)
|
||||
|
||||
for dataclass_type in dataclass_types:
|
||||
parser = HfArgumentParser(dataclass_type)
|
||||
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
|
||||
|
||||
args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
|
||||
self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
|
||||
|
||||
def test_07_with_required(self):
|
||||
parser = HfArgumentParser(RequiredExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
|
||||
expected.add_argument("--required_str", "--required-str", type=str, required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
"--required-enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_08_with_string_literal_annotation(self):
|
||||
parser = HfArgumentParser(StringLiteralAnnotationExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
"--required-enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_09_parse_dict(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict = {
|
||||
"foo": 12,
|
||||
"bar": 3.14,
|
||||
"baz": "42",
|
||||
"flag": True,
|
||||
}
|
||||
|
||||
parsed_args = parser.parse_dict(args_dict)[0]
|
||||
args = BasicExample(**args_dict)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_10_parse_dict_extra_key(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict = {
|
||||
"foo": 12,
|
||||
"bar": 3.14,
|
||||
"baz": "42",
|
||||
"flag": True,
|
||||
"extra": 42,
|
||||
}
|
||||
|
||||
self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
|
||||
|
||||
def test_11_parse_json(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict_for_json = {
|
||||
"foo": 12,
|
||||
"bar": 3.14,
|
||||
"baz": "42",
|
||||
"flag": True,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_local_path = os.path.join(tmp_dir, "temp_json")
|
||||
os.mkdir(temp_local_path)
|
||||
with open(temp_local_path + ".json", "w+") as f:
|
||||
json.dump(args_dict_for_json, f)
|
||||
parsed_args = parser.parse_json_file(Path(temp_local_path + ".json"))[0]
|
||||
|
||||
args = BasicExample(**args_dict_for_json)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_12_parse_yaml(self):
|
||||
parser = HfArgumentParser(BasicExample)
|
||||
|
||||
args_dict_for_yaml = {
|
||||
"foo": 12,
|
||||
"bar": 3.14,
|
||||
"baz": "42",
|
||||
"flag": True,
|
||||
}
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
temp_local_path = os.path.join(tmp_dir, "temp_yaml")
|
||||
os.mkdir(temp_local_path)
|
||||
with open(temp_local_path + ".yaml", "w+") as f:
|
||||
yaml.dump(args_dict_for_yaml, f)
|
||||
parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
|
||||
args = BasicExample(**args_dict_for_yaml)
|
||||
self.assertEqual(parsed_args, args)
|
||||
|
||||
def test_13_valid_dict_annotation(self):
|
||||
"""
|
||||
Tests to make sure that `dict` based annotations
|
||||
are correctly made in the `TrainingArguments`.
|
||||
|
||||
If this fails, a type annotation change is
|
||||
needed on a new input
|
||||
"""
|
||||
base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
|
||||
args = TrainingArguments
|
||||
|
||||
# First find any annotations that contain `dict`
|
||||
fields = args.__dataclass_fields__
|
||||
|
||||
raw_dict_fields = []
|
||||
optional_dict_fields = []
|
||||
|
||||
for field_ in fields.values():
|
||||
# First verify raw dict
|
||||
if field_.type is dict:
|
||||
raw_dict_fields.append(field_)
|
||||
# Next check for `Union` or `Optional`
|
||||
elif get_origin(field_.type) == Union:
|
||||
if any(arg is dict for arg in get_args(field_.type)):
|
||||
optional_dict_fields.append(field_)
|
||||
|
||||
# First check: anything in `raw_dict_fields` is very bad
|
||||
self.assertEqual(
|
||||
len(raw_dict_fields),
|
||||
0,
|
||||
f"Found invalid raw `dict` types in the `TrainingArgument` typings, which are {raw_dict_fields}. "
|
||||
"This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
|
||||
)
|
||||
|
||||
# Next check raw annotations
|
||||
for field_ in optional_dict_fields:
|
||||
args = get_args(field_.type)
|
||||
# These should be returned as `dict`, `str`, ...
|
||||
# we only care about the first two
|
||||
self.assertIn(
|
||||
dict,
|
||||
args,
|
||||
f"Expected field `{field_.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `dict` not found. Please fix this.",
|
||||
)
|
||||
self.assertIn(
|
||||
str,
|
||||
args,
|
||||
f"Expected field `{field_.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, but `str` not found. Please fix this.",
|
||||
)
|
||||
|
||||
# Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
|
||||
for field_ in optional_dict_fields:
|
||||
self.assertIn(
|
||||
field.name,
|
||||
base_list,
|
||||
f"Optional dict field `{field_.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_14_valid_dict_input_parsing(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
|
||||
)
|
||||
self.assertEqual(args.accelerator_config.split_batches, True)
|
||||
self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
|
||||
|
||||
def test_15_integration_training_args(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
self.assertIsNotNone(parser)
|
||||
|
||||
@require_torch
|
||||
@patch("sys.argv", ["test.py", "--accelerator_config", '{"gradient_accumulation_kwargs": {"num_steps": 2}}'])
|
||||
def test_16_cli_input_parsing(self):
|
||||
parser = HfArgumentParser(TrainingArguments)
|
||||
training_args = parser.parse_args_into_dataclasses()[0]
|
||||
self.assertEqual(training_args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
|
||||
207
tests/utils/test_hub_utils.py
Normal file
207
tests/utils/test_hub_utils.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import constants, hf_hub_download
|
||||
from huggingface_hub.errors import HfHubHTTPError, LocalEntryNotFoundError, OfflineModeIsEnabled
|
||||
|
||||
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, has_file, list_repo_templates
|
||||
|
||||
|
||||
RANDOM_BERT = "hf-internal-testing/tiny-random-bert"
|
||||
TINY_BERT_PT_ONLY = "hf-internal-testing/tiny-bert-pt-only"
|
||||
CACHE_DIR = os.path.join(constants.HF_HUB_CACHE, "models--hf-internal-testing--tiny-random-bert")
|
||||
FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6"
|
||||
|
||||
GATED_REPO = "hf-internal-testing/dummy-gated-model"
|
||||
README_FILE = "README.md"
|
||||
|
||||
|
||||
class GetFromCacheTests(unittest.TestCase):
|
||||
def test_cached_file(self):
|
||||
archive_file = cached_file(RANDOM_BERT, CONFIG_NAME)
|
||||
# Should have downloaded the file in here
|
||||
self.assertTrue(os.path.isdir(CACHE_DIR))
|
||||
# Cache should contain at least those three subfolders:
|
||||
for subfolder in ["blobs", "refs", "snapshots"]:
|
||||
self.assertTrue(os.path.isdir(os.path.join(CACHE_DIR, subfolder)))
|
||||
with open(os.path.join(CACHE_DIR, "refs", "main")) as f:
|
||||
main_commit = f.read()
|
||||
self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", main_commit, CONFIG_NAME))
|
||||
self.assertTrue(os.path.isfile(archive_file))
|
||||
|
||||
# File is cached at the same place the second time.
|
||||
new_archive_file = cached_file(RANDOM_BERT, CONFIG_NAME)
|
||||
self.assertEqual(archive_file, new_archive_file)
|
||||
|
||||
# Using a specific revision to test the full commit hash.
|
||||
archive_file = cached_file(RANDOM_BERT, CONFIG_NAME, revision="9b8c223")
|
||||
self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, CONFIG_NAME))
|
||||
|
||||
def test_cached_file_errors(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
_ = cached_file("tiny-random-bert", CONFIG_NAME)
|
||||
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
|
||||
_ = cached_file(RANDOM_BERT, CONFIG_NAME, revision="aaaa")
|
||||
|
||||
with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"):
|
||||
_ = cached_file(RANDOM_BERT, "conf")
|
||||
|
||||
def test_non_existence_is_cached(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"):
|
||||
_ = cached_file(RANDOM_BERT, "conf")
|
||||
|
||||
with open(os.path.join(CACHE_DIR, "refs", "main")) as f:
|
||||
main_commit = f.read()
|
||||
self.assertTrue(os.path.isfile(os.path.join(CACHE_DIR, ".no_exist", main_commit, "conf")))
|
||||
|
||||
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_missing_entries=False)
|
||||
self.assertIsNone(path)
|
||||
|
||||
path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
|
||||
self.assertIsNone(path)
|
||||
|
||||
# Under the mock environment, hf_hub_download will always raise an HTTPError
|
||||
with mock.patch(
|
||||
"transformers.utils.hub.hf_hub_download",
|
||||
side_effect=HfHubHTTPError("failed", response=mock.Mock(status_code=404)),
|
||||
) as mock_head:
|
||||
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
|
||||
self.assertIsNone(path)
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_has_file(self):
|
||||
self.assertTrue(has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME))
|
||||
self.assertFalse(has_file(TINY_BERT_PT_ONLY, "tf_model.h5"))
|
||||
self.assertFalse(has_file(TINY_BERT_PT_ONLY, "flax_model.msgpack"))
|
||||
|
||||
def test_has_file_in_cache(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Empty cache dir + offline mode => return False
|
||||
assert not has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
|
||||
|
||||
# Populate cache dir
|
||||
# TODO: only necessary for read-only cache systems; replace with a shared helper
|
||||
with unittest.mock.patch.dict(os.environ, {"HF_XET_CACHE": tmp_dir}):
|
||||
hf_hub_download(TINY_BERT_PT_ONLY, WEIGHTS_NAME, cache_dir=tmp_dir)
|
||||
|
||||
# Cache dir + offline mode => return True
|
||||
assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
|
||||
|
||||
def test_get_file_from_repo_distant(self):
|
||||
# should return None if the file does not exist
|
||||
self.assertIsNone(
|
||||
cached_file(
|
||||
"google-bert/bert-base-cased",
|
||||
"ahah.txt",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
)
|
||||
|
||||
# The function raises if the repository does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
|
||||
cached_file(
|
||||
"bert-base-case",
|
||||
CONFIG_NAME,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
# The function raises if the revision does not exist.
|
||||
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
|
||||
cached_file(
|
||||
"google-bert/bert-base-cased",
|
||||
CONFIG_NAME,
|
||||
revision="ahaha",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
|
||||
resolved_file = cached_file(
|
||||
"google-bert/bert-base-cased",
|
||||
CONFIG_NAME,
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
# The name is the cached name which is not very easy to test, so instead we load the content.
|
||||
config = json.loads(open(resolved_file).read())
|
||||
self.assertEqual(config["hidden_size"], 768)
|
||||
|
||||
def test_get_file_from_repo_local(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
filename = Path(tmp_dir) / "a.txt"
|
||||
filename.touch()
|
||||
self.assertEqual(
|
||||
cached_file(
|
||||
tmp_dir,
|
||||
"a.txt",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
),
|
||||
str(filename),
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
cached_file(
|
||||
tmp_dir,
|
||||
"b.txt",
|
||||
_raise_exceptions_for_gated_repo=False,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
_raise_exceptions_for_connection_errors=False,
|
||||
)
|
||||
)
|
||||
|
||||
def test_get_file_gated_repo(self):
|
||||
"""Test download file from a gated repo fails with correct message when not authenticated."""
|
||||
with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."):
|
||||
# All files except README.md are protected on a gated repo.
|
||||
cached_file(GATED_REPO, "gated_file.txt", token=False)
|
||||
|
||||
def test_has_file_gated_repo(self):
|
||||
"""Test check file existence from a gated repo fails with correct message when not authenticated."""
|
||||
with self.assertRaisesRegex(EnvironmentError, "is a gated repository"):
|
||||
# All files except README.md are protected on a gated repo.
|
||||
has_file(GATED_REPO, "gated_file.txt", token=False)
|
||||
|
||||
def test_cached_files_exception_raised(self):
|
||||
"""Test that unhadled exceptions, e.g. ModuleNotFoundError, is properly re-raised by cached_files when hf_hub_download fails."""
|
||||
with mock.patch(
|
||||
"transformers.utils.hub.hf_hub_download", side_effect=ModuleNotFoundError("No module named 'MockModule'")
|
||||
):
|
||||
with self.assertRaises(ModuleNotFoundError):
|
||||
# The error should be re-raised by cached_files, not caught in the exception handling block
|
||||
cached_file(RANDOM_BERT, "nonexistent.json")
|
||||
|
||||
|
||||
class OfflineModeTests(unittest.TestCase):
|
||||
def test_list_repo_templates_w_offline(self):
|
||||
with mock.patch("transformers.utils.hub.HfApi.list_repo_tree", side_effect=OfflineModeIsEnabled()):
|
||||
with mock.patch(
|
||||
"transformers.utils.hub.HfApi.snapshot_download",
|
||||
side_effect=LocalEntryNotFoundError("no snapshot found"),
|
||||
):
|
||||
self.assertEqual(list_repo_templates(RANDOM_BERT, local_files_only=False), [])
|
||||
224
tests/utils/test_image_processing_utils.py
Normal file
224
tests/utils/test_image_processing_utils.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from transformers import AutoImageProcessor, ViTImageProcessor, ViTImageProcessorFast
|
||||
from transformers.image_processing_utils import get_size_dict
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_image_processing import CustomImageProcessor # noqa E402
|
||||
|
||||
|
||||
SAMPLE_IMAGE_PROCESSING_CONFIG_DIR = get_tests_dir("fixtures")
|
||||
|
||||
|
||||
class ImageProcessorUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"failed", request=mock.Mock(), response=mock.Mock()
|
||||
)
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||
_ = ViTImageProcessorFast.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
|
||||
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||
_ = ViTImageProcessorFast.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_image_processor_from_pretrained_subfolder(self):
|
||||
with self.assertRaises(OSError):
|
||||
# config is in subfolder, the following should not work without specifying the subfolder
|
||||
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
|
||||
|
||||
config = AutoImageProcessor.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
|
||||
)
|
||||
|
||||
self.assertIsNotNone(config)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ImageProcessorPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_fast(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_via_save_pretrained_fast(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_in_organization_fast(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_in_organization_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_in_organization_via_save_pretrained_fast(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
image_processor = ViTImageProcessorFast.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessorFast.from_pretrained(tmp_repo.repo_id)
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_image_processor(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
CustomImageProcessor.register_for_auto_class()
|
||||
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
|
||||
image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(
|
||||
image_processor.auto_map,
|
||||
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
|
||||
)
|
||||
|
||||
new_image_processor = AutoImageProcessor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
|
||||
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
|
||||
|
||||
|
||||
class ImageProcessingUtilsTester(unittest.TestCase):
|
||||
def test_get_size_dict(self):
|
||||
# Test a dict with the wrong keys raises an error
|
||||
inputs = {"wrong_key": 224}
|
||||
with self.assertRaises(ValueError):
|
||||
get_size_dict(inputs)
|
||||
|
||||
inputs = {"height": 224}
|
||||
with self.assertRaises(ValueError):
|
||||
get_size_dict(inputs)
|
||||
|
||||
inputs = {"width": 224, "shortest_edge": 224}
|
||||
with self.assertRaises(ValueError):
|
||||
get_size_dict(inputs)
|
||||
|
||||
# Test a dict with the correct keys is returned as is
|
||||
inputs = {"height": 224, "width": 224}
|
||||
outputs = get_size_dict(inputs)
|
||||
self.assertEqual(outputs, inputs)
|
||||
|
||||
inputs = {"shortest_edge": 224}
|
||||
outputs = get_size_dict(inputs)
|
||||
self.assertEqual(outputs, {"shortest_edge": 224})
|
||||
|
||||
inputs = {"longest_edge": 224, "shortest_edge": 224}
|
||||
outputs = get_size_dict(inputs)
|
||||
self.assertEqual(outputs, {"longest_edge": 224, "shortest_edge": 224})
|
||||
|
||||
# Test a single int value which represents (size, size)
|
||||
outputs = get_size_dict(224)
|
||||
self.assertEqual(outputs, {"height": 224, "width": 224})
|
||||
|
||||
# Test a single int value which represents the shortest edge
|
||||
outputs = get_size_dict(224, default_to_square=False)
|
||||
self.assertEqual(outputs, {"shortest_edge": 224})
|
||||
|
||||
# Test a tuple of ints which represents (height, width)
|
||||
outputs = get_size_dict((150, 200))
|
||||
self.assertEqual(outputs, {"height": 150, "width": 200})
|
||||
|
||||
# Test a tuple of ints which represents (width, height)
|
||||
outputs = get_size_dict((150, 200), height_width_order=False)
|
||||
self.assertEqual(outputs, {"height": 200, "width": 150})
|
||||
|
||||
# Test an int representing the shortest edge and max_size which represents the longest edge
|
||||
outputs = get_size_dict(224, max_size=256, default_to_square=False)
|
||||
self.assertEqual(outputs, {"shortest_edge": 224, "longest_edge": 256})
|
||||
|
||||
# Test int with default_to_square=True and max_size fails
|
||||
with self.assertRaises(ValueError):
|
||||
get_size_dict(224, max_size=256, default_to_square=True)
|
||||
899
tests/utils/test_image_utils.py
Normal file
899
tests/utils/test_image_utils.py
Normal file
@@ -0,0 +1,899 @@
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import codecs
|
||||
import unittest
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
|
||||
from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.image_utils import (
|
||||
ChannelDimension,
|
||||
get_channel_dimension_axis,
|
||||
make_flat_list_of_images,
|
||||
make_list_of_images,
|
||||
make_nested_list_of_images,
|
||||
)
|
||||
from transformers.testing_utils import is_flaky, require_torch, require_vision
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
import PIL.Image
|
||||
|
||||
from transformers import ImageFeatureExtractionMixin
|
||||
from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
|
||||
|
||||
|
||||
def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: str | None = None) -> "PIL.Image.Image":
|
||||
path = hf_hub_download(dataset_id, filename, repo_type="dataset", revision=revision)
|
||||
return PIL.Image.open(path)
|
||||
|
||||
|
||||
def get_random_image(height, width):
|
||||
random_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
return PIL.Image.fromarray(random_array)
|
||||
|
||||
|
||||
@require_vision
|
||||
class ImageFeatureExtractionTester(unittest.TestCase):
|
||||
def test_conversion_image_to_array(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
|
||||
# Conversion with defaults (rescale + channel first)
|
||||
array1 = feature_extractor.to_numpy_array(image)
|
||||
self.assertTrue(array1.dtype, np.float32)
|
||||
self.assertEqual(array1.shape, (3, 16, 32))
|
||||
|
||||
# Conversion with rescale and not channel first
|
||||
array2 = feature_extractor.to_numpy_array(image, channel_first=False)
|
||||
self.assertTrue(array2.dtype, np.float32)
|
||||
self.assertEqual(array2.shape, (16, 32, 3))
|
||||
self.assertTrue(np.array_equal(array1, array2.transpose(2, 0, 1)))
|
||||
|
||||
# Conversion with no rescale and channel first
|
||||
array3 = feature_extractor.to_numpy_array(image, rescale=False)
|
||||
self.assertTrue(array3.dtype, np.uint8)
|
||||
self.assertEqual(array3.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array1, array3.astype(np.float32) * (1 / 255.0)))
|
||||
|
||||
# Conversion with no rescale and not channel first
|
||||
array4 = feature_extractor.to_numpy_array(image, rescale=False, channel_first=False)
|
||||
self.assertTrue(array4.dtype, np.uint8)
|
||||
self.assertEqual(array4.shape, (16, 32, 3))
|
||||
self.assertTrue(np.array_equal(array2, array4.astype(np.float32) * (1 / 255.0)))
|
||||
|
||||
def test_conversion_array_to_array(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
|
||||
|
||||
# By default, rescale (for an array of ints) and channel permute
|
||||
array1 = feature_extractor.to_numpy_array(array)
|
||||
self.assertTrue(array1.dtype, np.float32)
|
||||
self.assertEqual(array1.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)))
|
||||
|
||||
# Same with no permute
|
||||
array2 = feature_extractor.to_numpy_array(array, channel_first=False)
|
||||
self.assertTrue(array2.dtype, np.float32)
|
||||
self.assertEqual(array2.shape, (16, 32, 3))
|
||||
self.assertTrue(np.array_equal(array2, array.astype(np.float32) * (1 / 255.0)))
|
||||
|
||||
# Force rescale to False
|
||||
array3 = feature_extractor.to_numpy_array(array, rescale=False)
|
||||
self.assertTrue(array3.dtype, np.uint8)
|
||||
self.assertEqual(array3.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
|
||||
|
||||
# Force rescale to False and no channel permute
|
||||
array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
|
||||
self.assertTrue(array4.dtype, np.uint8)
|
||||
self.assertEqual(array4.shape, (16, 32, 3))
|
||||
self.assertTrue(np.array_equal(array4, array))
|
||||
|
||||
# Now test the default rescale for a float array (defaults to False)
|
||||
array5 = feature_extractor.to_numpy_array(array2)
|
||||
self.assertTrue(array5.dtype, np.float32)
|
||||
self.assertEqual(array5.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array5, array1))
|
||||
|
||||
def test_make_list_of_images_pil(self):
|
||||
# Test a single image is converted to a list of 1 image
|
||||
pil_image = get_random_image(16, 32)
|
||||
images_list = make_list_of_images(pil_image)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertIsInstance(images_list[0], PIL.Image.Image)
|
||||
|
||||
# Test a list of images is not modified
|
||||
images = [get_random_image(16, 32) for _ in range(4)]
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertIsInstance(images_list[0], PIL.Image.Image)
|
||||
|
||||
def test_make_list_of_images_numpy(self):
|
||||
# Test a single image is converted to a list of 1 image
|
||||
images = np.random.randint(0, 256, (16, 32, 3))
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertTrue(np.array_equal(images_list[0], images))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test a batch of images is converted to a list of images
|
||||
images = np.random.randint(0, 256, (4, 16, 32, 3))
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test a list of images is not modified
|
||||
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test batched masks with no channel dimension are converted to a list of masks
|
||||
masks = np.random.randint(0, 2, (4, 16, 32))
|
||||
masks_list = make_list_of_images(masks, expected_ndims=2)
|
||||
self.assertEqual(len(masks_list), 4)
|
||||
self.assertTrue(np.array_equal(masks_list[0], masks[0]))
|
||||
self.assertIsInstance(masks_list, list)
|
||||
|
||||
@require_torch
|
||||
def test_make_list_of_images_torch(self):
|
||||
# Test a single image is converted to a list of 1 image
|
||||
images = torch.randint(0, 256, (16, 32, 3))
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertTrue(np.array_equal(images_list[0], images))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test a batch of images is converted to a list of images
|
||||
images = torch.randint(0, 256, (4, 16, 32, 3))
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test a list of images is left unchanged
|
||||
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
images_list = make_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
def test_make_flat_list_of_images_pil(self):
|
||||
# Test a single image is converted to a list of 1 image
|
||||
pil_image = get_random_image(16, 32)
|
||||
images_list = make_flat_list_of_images(pil_image)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertIsInstance(images_list[0], PIL.Image.Image)
|
||||
|
||||
# Test a list of images is not modified
|
||||
images = [get_random_image(16, 32) for _ in range(4)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertIsInstance(images_list[0], PIL.Image.Image)
|
||||
|
||||
# Test a nested list of images is flattened
|
||||
images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertIsInstance(images_list[0], PIL.Image.Image)
|
||||
|
||||
def test_make_flat_list_of_images_numpy(self):
|
||||
# Test a single image is converted to a list of 1 image
|
||||
images = np.random.randint(0, 256, (16, 32, 3))
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertTrue(np.array_equal(images_list[0], images))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test a 4d array of images is changed to a list of images
|
||||
images = np.random.randint(0, 256, (4, 16, 32, 3))
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertIsInstance(images_list[0], np.ndarray)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
|
||||
# Test a list of images is not modified
|
||||
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test list of 4d array images is flattened
|
||||
images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 8)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertIsInstance(images_list[0], np.ndarray)
|
||||
|
||||
# Test nested list of images is flattened
|
||||
images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
@require_torch
|
||||
def test_make_flat_list_of_images_torch(self):
|
||||
# Test a single image is converted to a list of 1 image
|
||||
images = torch.randint(0, 256, (16, 32, 3))
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertTrue(np.array_equal(images_list[0], images))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test a 4d tensors of images is changed to a list of images
|
||||
images = torch.randint(0, 256, (4, 16, 32, 3))
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertIsInstance(images_list[0], torch.Tensor)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
|
||||
# Test a list of images is not modified
|
||||
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
# Test list of 4d tensors of imagess is flattened
|
||||
images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 8)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
self.assertIsInstance(images_list[0], torch.Tensor)
|
||||
|
||||
# Test nested list of images is flattened
|
||||
images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
|
||||
images_list = make_flat_list_of_images(images)
|
||||
self.assertEqual(len(images_list), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0], images[0][0]))
|
||||
self.assertIsInstance(images_list, list)
|
||||
|
||||
def test_make_nested_list_of_images_pil(self):
|
||||
# Test a single image is converted to a nested list of 1 image
|
||||
pil_image = get_random_image(16, 32)
|
||||
images_list = make_nested_list_of_images(pil_image)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list[0]), 1)
|
||||
self.assertIsInstance(images_list[0][0], PIL.Image.Image)
|
||||
|
||||
# Test a list of images is converted to a nested list of images
|
||||
images = [get_random_image(16, 32) for _ in range(4)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertIsInstance(images_list[0][0], PIL.Image.Image)
|
||||
|
||||
# Test a nested list of images is not modified
|
||||
images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 2)
|
||||
self.assertEqual(len(images_list[0]), 2)
|
||||
self.assertIsInstance(images_list[0][0], PIL.Image.Image)
|
||||
|
||||
def test_make_nested_list_of_images_numpy(self):
|
||||
# Test a single image is converted to a nested list of 1 image
|
||||
images = np.random.randint(0, 256, (16, 32, 3))
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images))
|
||||
|
||||
# Test a 4d array of images is converted to a nested list of images
|
||||
images = np.random.randint(0, 256, (4, 16, 32, 3))
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertIsInstance(images_list[0][0], np.ndarray)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
|
||||
|
||||
# Test a list of images is converted to a nested list of images
|
||||
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
|
||||
|
||||
# Test a nested list of images is left unchanged
|
||||
images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 2)
|
||||
self.assertEqual(len(images_list[0]), 2)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
|
||||
|
||||
# Test a list of 4d array images is converted to a nested list of images
|
||||
images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertIsInstance(images_list[0][0], np.ndarray)
|
||||
self.assertEqual(len(images_list), 2)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
|
||||
|
||||
@require_torch
|
||||
def test_make_nested_list_of_images_torch(self):
|
||||
# Test a single image is converted to a nested list of 1 image
|
||||
images = torch.randint(0, 256, (16, 32, 3))
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list[0]), 1)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images))
|
||||
|
||||
# Test a 4d tensor of images is converted to a nested list of images
|
||||
images = torch.randint(0, 256, (4, 16, 32, 3))
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertIsInstance(images_list[0][0], torch.Tensor)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
|
||||
|
||||
# Test a list of images is converted to a nested list of images
|
||||
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 1)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0]))
|
||||
|
||||
# Test a nested list of images is left unchanged
|
||||
images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertEqual(len(images_list), 2)
|
||||
self.assertEqual(len(images_list[0]), 2)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
|
||||
|
||||
# Test a list of 4d tensor images is converted to a nested list of images
|
||||
images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
|
||||
images_list = make_nested_list_of_images(images)
|
||||
self.assertIsInstance(images_list[0], list)
|
||||
self.assertIsInstance(images_list[0][0], torch.Tensor)
|
||||
self.assertEqual(len(images_list), 2)
|
||||
self.assertEqual(len(images_list[0]), 4)
|
||||
self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
|
||||
|
||||
@require_torch
|
||||
def test_conversion_torch_to_array(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
tensor = torch.randint(0, 256, (16, 32, 3))
|
||||
array = tensor.numpy()
|
||||
|
||||
# By default, rescale (for a tensor of ints) and channel permute
|
||||
array1 = feature_extractor.to_numpy_array(array)
|
||||
self.assertTrue(array1.dtype, np.float32)
|
||||
self.assertEqual(array1.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)))
|
||||
|
||||
# Same with no permute
|
||||
array2 = feature_extractor.to_numpy_array(array, channel_first=False)
|
||||
self.assertTrue(array2.dtype, np.float32)
|
||||
self.assertEqual(array2.shape, (16, 32, 3))
|
||||
self.assertTrue(np.array_equal(array2, array.astype(np.float32) * (1 / 255.0)))
|
||||
|
||||
# Force rescale to False
|
||||
array3 = feature_extractor.to_numpy_array(array, rescale=False)
|
||||
self.assertTrue(array3.dtype, np.uint8)
|
||||
self.assertEqual(array3.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
|
||||
|
||||
# Force rescale to False and no channel permute
|
||||
array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
|
||||
self.assertTrue(array4.dtype, np.uint8)
|
||||
self.assertEqual(array4.shape, (16, 32, 3))
|
||||
self.assertTrue(np.array_equal(array4, array))
|
||||
|
||||
# Now test the default rescale for a float tensor (defaults to False)
|
||||
array5 = feature_extractor.to_numpy_array(array2)
|
||||
self.assertTrue(array5.dtype, np.float32)
|
||||
self.assertEqual(array5.shape, (3, 16, 32))
|
||||
self.assertTrue(np.array_equal(array5, array1))
|
||||
|
||||
def test_conversion_image_to_image(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
|
||||
# On an image, `to_pil_image1` is a noop.
|
||||
image1 = feature_extractor.to_pil_image(image)
|
||||
self.assertTrue(isinstance(image, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image), np.array(image1)))
|
||||
|
||||
def test_conversion_array_to_image(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
|
||||
|
||||
# By default, no rescale (for an array of ints)
|
||||
image1 = feature_extractor.to_pil_image(array)
|
||||
self.assertTrue(isinstance(image1, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image1), array))
|
||||
|
||||
# If the array is channel-first, proper reordering of the channels is done.
|
||||
image2 = feature_extractor.to_pil_image(array.transpose(2, 0, 1))
|
||||
self.assertTrue(isinstance(image2, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image2), array))
|
||||
|
||||
# If the array has floating type, it's rescaled by default.
|
||||
image3 = feature_extractor.to_pil_image(array.astype(np.float32) * (1 / 255.0))
|
||||
self.assertTrue(isinstance(image3, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image3), array))
|
||||
|
||||
# You can override the default to rescale.
|
||||
image4 = feature_extractor.to_pil_image(array.astype(np.float32), rescale=False)
|
||||
self.assertTrue(isinstance(image4, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image4), array))
|
||||
|
||||
# And with floats + channel first.
|
||||
image5 = feature_extractor.to_pil_image(array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0))
|
||||
self.assertTrue(isinstance(image5, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image5), array))
|
||||
|
||||
@require_torch
|
||||
def test_conversion_tensor_to_image(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
tensor = torch.randint(0, 256, (16, 32, 3))
|
||||
array = tensor.numpy()
|
||||
|
||||
# By default, no rescale (for a tensor of ints)
|
||||
image1 = feature_extractor.to_pil_image(tensor)
|
||||
self.assertTrue(isinstance(image1, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image1), array))
|
||||
|
||||
# If the tensor is channel-first, proper reordering of the channels is done.
|
||||
image2 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1))
|
||||
self.assertTrue(isinstance(image2, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image2), array))
|
||||
|
||||
# If the tensor has floating type, it's rescaled by default.
|
||||
image3 = feature_extractor.to_pil_image(tensor.float() / 255.0)
|
||||
self.assertTrue(isinstance(image3, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image3), array))
|
||||
|
||||
# You can override the default to rescale.
|
||||
image4 = feature_extractor.to_pil_image(tensor.float(), rescale=False)
|
||||
self.assertTrue(isinstance(image4, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image4), array))
|
||||
|
||||
# And with floats + channel first.
|
||||
image5 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1).float() * (1 / 255.0))
|
||||
self.assertTrue(isinstance(image5, PIL.Image.Image))
|
||||
self.assertTrue(np.array_equal(np.array(image5), array))
|
||||
|
||||
def test_resize_image_and_array(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
array = np.array(image)
|
||||
|
||||
# Size can be an int or a tuple of ints.
|
||||
resized_image = feature_extractor.resize(image, 8)
|
||||
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
|
||||
self.assertEqual(resized_image.size, (8, 8))
|
||||
|
||||
resized_image1 = feature_extractor.resize(image, (8, 16))
|
||||
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
|
||||
self.assertEqual(resized_image1.size, (8, 16))
|
||||
|
||||
# Passing an array converts it to a PIL Image.
|
||||
resized_image2 = feature_extractor.resize(array, 8)
|
||||
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
|
||||
self.assertEqual(resized_image2.size, (8, 8))
|
||||
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
|
||||
|
||||
resized_image3 = feature_extractor.resize(image, (8, 16))
|
||||
self.assertTrue(isinstance(resized_image3, PIL.Image.Image))
|
||||
self.assertEqual(resized_image3.size, (8, 16))
|
||||
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
|
||||
|
||||
def test_resize_image_and_array_non_default_to_square(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
|
||||
heights_widths = [
|
||||
# height, width
|
||||
# square image
|
||||
(28, 28),
|
||||
(27, 27),
|
||||
# rectangular image: h < w
|
||||
(28, 34),
|
||||
(29, 35),
|
||||
# rectangular image: h > w
|
||||
(34, 28),
|
||||
(35, 29),
|
||||
]
|
||||
|
||||
# single integer or single integer in tuple/list
|
||||
sizes = [22, 27, 28, 36, [22], (27,)]
|
||||
|
||||
for (height, width), size in zip(heights_widths, sizes):
|
||||
for max_size in (None, 37, 1000):
|
||||
image = get_random_image(height, width)
|
||||
array = np.array(image)
|
||||
|
||||
size = size[0] if isinstance(size, (list, tuple)) else size
|
||||
# Size can be an int or a tuple of ints.
|
||||
# If size is an int, smaller edge of the image will be matched to this number.
|
||||
# i.e, if height > width, then image will be rescaled to (size * height / width, size).
|
||||
if height < width:
|
||||
exp_w, exp_h = (int(size * width / height), size)
|
||||
if max_size is not None and max_size < exp_w:
|
||||
exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
|
||||
elif width < height:
|
||||
exp_w, exp_h = (size, int(size * height / width))
|
||||
if max_size is not None and max_size < exp_h:
|
||||
exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
|
||||
else:
|
||||
exp_w, exp_h = (size, size)
|
||||
if max_size is not None and max_size < size:
|
||||
exp_w, exp_h = max_size, max_size
|
||||
|
||||
resized_image = feature_extractor.resize(image, size=size, default_to_square=False, max_size=max_size)
|
||||
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
|
||||
self.assertEqual(resized_image.size, (exp_w, exp_h))
|
||||
|
||||
# Passing an array converts it to a PIL Image.
|
||||
resized_image2 = feature_extractor.resize(array, size=size, default_to_square=False, max_size=max_size)
|
||||
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
|
||||
self.assertEqual(resized_image2.size, (exp_w, exp_h))
|
||||
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
|
||||
|
||||
@require_torch
|
||||
def test_resize_tensor(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
tensor = torch.randint(0, 256, (16, 32, 3))
|
||||
array = tensor.numpy()
|
||||
|
||||
# Size can be an int or a tuple of ints.
|
||||
resized_image = feature_extractor.resize(tensor, 8)
|
||||
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
|
||||
self.assertEqual(resized_image.size, (8, 8))
|
||||
|
||||
resized_image1 = feature_extractor.resize(tensor, (8, 16))
|
||||
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
|
||||
self.assertEqual(resized_image1.size, (8, 16))
|
||||
|
||||
# Check we get the same results as with NumPy arrays.
|
||||
resized_image2 = feature_extractor.resize(array, 8)
|
||||
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
|
||||
|
||||
resized_image3 = feature_extractor.resize(array, (8, 16))
|
||||
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
|
||||
|
||||
def test_normalize_image(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
array = np.array(image)
|
||||
mean = [0.1, 0.5, 0.9]
|
||||
std = [0.2, 0.4, 0.6]
|
||||
|
||||
# PIL Image are converted to NumPy arrays for the normalization
|
||||
normalized_image = feature_extractor.normalize(image, mean, std)
|
||||
self.assertTrue(isinstance(normalized_image, np.ndarray))
|
||||
self.assertEqual(normalized_image.shape, (3, 16, 32))
|
||||
|
||||
# During the conversion rescale and channel first will be applied.
|
||||
expected = array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)
|
||||
np_mean = np.array(mean).astype(np.float32)[:, None, None]
|
||||
np_std = np.array(std).astype(np.float32)[:, None, None]
|
||||
expected = (expected - np_mean) / np_std
|
||||
self.assertTrue(np.array_equal(normalized_image, expected))
|
||||
|
||||
def test_normalize_array(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
array = np.random.random((16, 32, 3))
|
||||
mean = [0.1, 0.5, 0.9]
|
||||
std = [0.2, 0.4, 0.6]
|
||||
|
||||
# mean and std can be passed as lists or NumPy arrays.
|
||||
expected = (array - np.array(mean)) / np.array(std)
|
||||
normalized_array = feature_extractor.normalize(array, mean, std)
|
||||
self.assertTrue(np.array_equal(normalized_array, expected))
|
||||
|
||||
normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
|
||||
self.assertTrue(np.array_equal(normalized_array, expected))
|
||||
|
||||
# Normalize will detect automatically if channel first or channel last is used.
|
||||
array = np.random.random((3, 16, 32))
|
||||
expected = (array - np.array(mean)[:, None, None]) / np.array(std)[:, None, None]
|
||||
normalized_array = feature_extractor.normalize(array, mean, std)
|
||||
self.assertTrue(np.array_equal(normalized_array, expected))
|
||||
|
||||
normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
|
||||
self.assertTrue(np.array_equal(normalized_array, expected))
|
||||
|
||||
@require_torch
|
||||
def test_normalize_tensor(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
tensor = torch.rand(16, 32, 3)
|
||||
mean = [0.1, 0.5, 0.9]
|
||||
std = [0.2, 0.4, 0.6]
|
||||
|
||||
# mean and std can be passed as lists or tensors.
|
||||
expected = (tensor - torch.tensor(mean)) / torch.tensor(std)
|
||||
normalized_tensor = feature_extractor.normalize(tensor, mean, std)
|
||||
self.assertTrue(torch.equal(normalized_tensor, expected))
|
||||
|
||||
normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
|
||||
self.assertTrue(torch.equal(normalized_tensor, expected))
|
||||
|
||||
# Normalize will detect automatically if channel first or channel last is used.
|
||||
tensor = torch.rand(3, 16, 32)
|
||||
expected = (tensor - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
|
||||
normalized_tensor = feature_extractor.normalize(tensor, mean, std)
|
||||
self.assertTrue(torch.equal(normalized_tensor, expected))
|
||||
|
||||
normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
|
||||
self.assertTrue(torch.equal(normalized_tensor, expected))
|
||||
|
||||
def test_center_crop_image(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
|
||||
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
|
||||
crop_sizes = [8, (8, 64), 20, (32, 64)]
|
||||
for size in crop_sizes:
|
||||
cropped_image = feature_extractor.center_crop(image, size)
|
||||
self.assertTrue(isinstance(cropped_image, PIL.Image.Image))
|
||||
|
||||
# PIL Image.size is transposed compared to NumPy or PyTorch (width first instead of height first).
|
||||
expected_size = (size, size) if isinstance(size, int) else (size[1], size[0])
|
||||
self.assertEqual(cropped_image.size, expected_size)
|
||||
|
||||
def test_center_crop_array(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
array = feature_extractor.to_numpy_array(image)
|
||||
|
||||
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
|
||||
crop_sizes = [8, (8, 64), 20, (32, 64)]
|
||||
for size in crop_sizes:
|
||||
cropped_array = feature_extractor.center_crop(array, size)
|
||||
self.assertTrue(isinstance(cropped_array, np.ndarray))
|
||||
|
||||
expected_size = (size, size) if isinstance(size, int) else size
|
||||
self.assertEqual(cropped_array.shape[-2:], expected_size)
|
||||
|
||||
# Check result is consistent with PIL.Image.crop
|
||||
cropped_image = feature_extractor.center_crop(image, size)
|
||||
self.assertTrue(np.array_equal(cropped_array, feature_extractor.to_numpy_array(cropped_image)))
|
||||
|
||||
@require_torch
|
||||
def test_center_crop_tensor(self):
|
||||
feature_extractor = ImageFeatureExtractionMixin()
|
||||
image = get_random_image(16, 32)
|
||||
array = feature_extractor.to_numpy_array(image)
|
||||
tensor = torch.tensor(array)
|
||||
|
||||
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
|
||||
crop_sizes = [8, (8, 64), 20, (32, 64)]
|
||||
for size in crop_sizes:
|
||||
cropped_tensor = feature_extractor.center_crop(tensor, size)
|
||||
self.assertTrue(isinstance(cropped_tensor, torch.Tensor))
|
||||
|
||||
expected_size = (size, size) if isinstance(size, int) else size
|
||||
self.assertEqual(cropped_tensor.shape[-2:], expected_size)
|
||||
|
||||
# Check result is consistent with PIL.Image.crop
|
||||
cropped_image = feature_extractor.center_crop(image, size)
|
||||
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
|
||||
|
||||
|
||||
@require_vision
|
||||
class LoadImageTester(unittest.TestCase):
|
||||
def test_load_img_url(self):
|
||||
img = load_image(INVOICE_URL)
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(img_arr.shape, (1061, 750, 3))
|
||||
|
||||
@is_flaky()
|
||||
def test_load_img_url_timeout(self):
|
||||
with self.assertRaises(httpx.ConnectTimeout):
|
||||
load_image(INVOICE_URL, timeout=0.001)
|
||||
|
||||
def test_load_img_local(self):
|
||||
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(480, 640, 3),
|
||||
)
|
||||
|
||||
def test_load_img_base64_prefix(self):
|
||||
path = hf_hub_download(
|
||||
repo_id="hf-internal-testing/dummy-base64-images", filename="image_0.txt", repo_type="dataset"
|
||||
)
|
||||
with open(path, encoding="utf-8") as b64:
|
||||
img = load_image(b64.read())
|
||||
img_arr = np.array(img)
|
||||
self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_base64(self):
|
||||
path = hf_hub_download(
|
||||
repo_id="hf-internal-testing/dummy-base64-images", filename="image_1.txt", repo_type="dataset"
|
||||
)
|
||||
with open(path, encoding="utf-8") as b64:
|
||||
img = load_image(b64.read())
|
||||
img_arr = np.array(img)
|
||||
self.assertEqual(img_arr.shape, (64, 32, 3))
|
||||
|
||||
def test_load_img_base64_encoded_bytes(self):
|
||||
path = hf_hub_download(
|
||||
repo_id="hf-internal-testing/dummy-base64-images", filename="image_2.txt", repo_type="dataset"
|
||||
)
|
||||
with codecs.open(path, encoding="unicode_escape") as b64:
|
||||
img = load_image(b64.read())
|
||||
img_arr = np.array(img)
|
||||
self.assertEqual(img_arr.shape, (256, 256, 3))
|
||||
|
||||
def test_load_img_rgba(self):
|
||||
# we use revision="refs/pr/1" until the PR is merged
|
||||
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
|
||||
img = get_image_from_hub_dataset(
|
||||
"hf-internal-testing/fixtures_image_utils", "0-test-lena.png", revision="refs/pr/1"
|
||||
)
|
||||
|
||||
img = load_image(img) # img with mode RGBA
|
||||
img_arr = np.array(img)
|
||||
self.assertEqual(img_arr.shape, (512, 512, 3))
|
||||
|
||||
def test_load_img_la(self):
|
||||
# we use revision="refs/pr/1" until the PR is merged
|
||||
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
|
||||
img = get_image_from_hub_dataset(
|
||||
"hf-internal-testing/fixtures_image_utils", "1-test-parrots.png", revision="refs/pr/1"
|
||||
)
|
||||
|
||||
img = load_image(img) # img with mode LA
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(512, 768, 3),
|
||||
)
|
||||
|
||||
def test_load_img_l(self):
|
||||
# we use revision="refs/pr/1" until the PR is merged
|
||||
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
|
||||
img = get_image_from_hub_dataset(
|
||||
"hf-internal-testing/fixtures_image_utils", "2-test-tree.png", revision="refs/pr/1"
|
||||
)
|
||||
|
||||
img = load_image(img) # img with mode L
|
||||
img_arr = np.array(img)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr.shape,
|
||||
(381, 225, 3),
|
||||
)
|
||||
|
||||
def test_load_img_exif_transpose(self):
|
||||
# we use revision="refs/pr/1" until the PR is merged
|
||||
# https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
|
||||
|
||||
img_without_exif_transpose = get_image_from_hub_dataset(
|
||||
"hf-internal-testing/fixtures_image_utils", "3-test-cat-rotated.jpg", revision="refs/pr/1"
|
||||
)
|
||||
img_arr_without_exif_transpose = np.array(img_without_exif_transpose)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr_without_exif_transpose.shape,
|
||||
(333, 500, 3),
|
||||
)
|
||||
|
||||
img_with_exif_transpose = load_image(img_without_exif_transpose)
|
||||
img_arr_with_exif_transpose = np.array(img_with_exif_transpose)
|
||||
|
||||
self.assertEqual(
|
||||
img_arr_with_exif_transpose.shape,
|
||||
(500, 333, 3),
|
||||
)
|
||||
|
||||
|
||||
class UtilFunctionTester(unittest.TestCase):
|
||||
def test_get_image_size(self):
|
||||
# Test we can infer the size and channel dimension of an image.
|
||||
image = np.random.randint(0, 256, (32, 64, 3))
|
||||
self.assertEqual(get_image_size(image), (32, 64))
|
||||
|
||||
image = np.random.randint(0, 256, (3, 32, 64))
|
||||
self.assertEqual(get_image_size(image), (32, 64))
|
||||
|
||||
# Test the channel dimension can be overridden
|
||||
image = np.random.randint(0, 256, (3, 32, 64))
|
||||
self.assertEqual(get_image_size(image, channel_dim=ChannelDimension.LAST), (3, 32))
|
||||
|
||||
def test_infer_channel_dimension(self):
|
||||
# Test we fail with invalid input
|
||||
with pytest.raises(ValueError):
|
||||
infer_channel_dimension_format(np.random.randint(0, 256, (10, 10)))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
infer_channel_dimension_format(np.random.randint(0, 256, (10, 10, 10, 10, 10)))
|
||||
|
||||
# Test we fail if neither first not last dimension is of size 3 or 1
|
||||
with pytest.raises(ValueError):
|
||||
infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
|
||||
|
||||
# But if we explicitly set one of the number of channels to 50 it works
|
||||
inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.LAST)
|
||||
|
||||
# Test we correctly identify the channel dimension
|
||||
image = np.random.randint(0, 256, (3, 4, 5))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
|
||||
|
||||
image = np.random.randint(0, 256, (1, 4, 5))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
|
||||
|
||||
image = np.random.randint(0, 256, (4, 5, 3))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.LAST)
|
||||
|
||||
image = np.random.randint(0, 256, (4, 5, 1))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.LAST)
|
||||
|
||||
# We can take a batched array of images and find the dimension
|
||||
image = np.random.randint(0, 256, (1, 3, 4, 5))
|
||||
inferred_dim = infer_channel_dimension_format(image)
|
||||
self.assertEqual(inferred_dim, ChannelDimension.FIRST)
|
||||
|
||||
def test_get_channel_dimension_axis(self):
|
||||
# Test we correctly identify the channel dimension
|
||||
image = np.random.randint(0, 256, (3, 4, 5))
|
||||
inferred_axis = get_channel_dimension_axis(image)
|
||||
self.assertEqual(inferred_axis, 0)
|
||||
|
||||
image = np.random.randint(0, 256, (1, 4, 5))
|
||||
inferred_axis = get_channel_dimension_axis(image)
|
||||
self.assertEqual(inferred_axis, 0)
|
||||
|
||||
image = np.random.randint(0, 256, (4, 5, 3))
|
||||
inferred_axis = get_channel_dimension_axis(image)
|
||||
self.assertEqual(inferred_axis, 2)
|
||||
|
||||
image = np.random.randint(0, 256, (4, 5, 1))
|
||||
inferred_axis = get_channel_dimension_axis(image)
|
||||
self.assertEqual(inferred_axis, 2)
|
||||
|
||||
# We can take a batched array of images and find the dimension
|
||||
image = np.random.randint(0, 256, (1, 3, 4, 5))
|
||||
inferred_axis = get_channel_dimension_axis(image)
|
||||
self.assertEqual(inferred_axis, 1)
|
||||
208
tests/utils/test_import_structure.py
Normal file
208
tests/utils/test_import_structure.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.utils.import_utils import (
|
||||
Backend,
|
||||
VersionComparison,
|
||||
define_import_structure,
|
||||
spread_import_structure,
|
||||
)
|
||||
|
||||
|
||||
import_structures = Path(__file__).parent / "import_structures"
|
||||
|
||||
|
||||
def fetch__all__(file_content):
|
||||
"""
|
||||
Returns the content of the __all__ variable in the file content.
|
||||
Returns None if not defined, otherwise returns a list of strings.
|
||||
"""
|
||||
lines = file_content.split("\n")
|
||||
for line_index in range(len(lines)):
|
||||
line = lines[line_index]
|
||||
if line.startswith("__all__ = "):
|
||||
# __all__ is defined on a single line
|
||||
if line.endswith("]"):
|
||||
return [obj.strip("\"' ") for obj in line.split("=")[1].strip(" []").split(",")]
|
||||
|
||||
# __all__ is defined on multiple lines
|
||||
else:
|
||||
_all = []
|
||||
for __all__line_index in range(line_index + 1, len(lines)):
|
||||
if lines[__all__line_index].strip() == "]":
|
||||
return _all
|
||||
else:
|
||||
_all.append(lines[__all__line_index].strip("\"', "))
|
||||
|
||||
|
||||
class TestImportStructures(unittest.TestCase):
|
||||
base_transformers_path = Path(__file__).parent.parent.parent
|
||||
models_path = base_transformers_path / "src" / "transformers" / "models"
|
||||
models_import_structure = spread_import_structure(define_import_structure(models_path))
|
||||
|
||||
def test_definition(self):
|
||||
import_structure = define_import_structure(import_structures)
|
||||
valid_frozensets: dict[frozenset | frozenset[str], dict[str, set[str]]] = {
|
||||
frozenset(): {
|
||||
"import_structure_raw_register": {"A0", "A4", "a0"},
|
||||
"import_structure_register_with_comments": {"B0", "b0"},
|
||||
},
|
||||
frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
|
||||
frozenset({"torch"}): {
|
||||
"import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
|
||||
"import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"},
|
||||
"import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
|
||||
},
|
||||
frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
|
||||
frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
|
||||
frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
|
||||
frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
|
||||
frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
|
||||
frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
|
||||
frozenset({"torch>=2.5", "accelerate<0.20"}): {
|
||||
"import_structure_raw_register_with_versions": {"D6", "d6"}
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(len(import_structure.keys()), len(valid_frozensets.keys()))
|
||||
for _frozenset in valid_frozensets:
|
||||
self.assertTrue(_frozenset in import_structure)
|
||||
self.assertListEqual(
|
||||
sorted(import_structure[_frozenset].keys()), sorted(valid_frozensets[_frozenset].keys())
|
||||
)
|
||||
for module, objects in valid_frozensets[_frozenset].items():
|
||||
self.assertTrue(module in import_structure[_frozenset])
|
||||
self.assertSetEqual(objects, import_structure[_frozenset][module])
|
||||
|
||||
def test_transformers_specific_model_import(self):
|
||||
"""
|
||||
This test ensures that there is equivalence between what is written down in __all__ and what is
|
||||
written down with register().
|
||||
|
||||
It doesn't test the backends attributed to register().
|
||||
"""
|
||||
for architecture in os.listdir(self.models_path):
|
||||
if (
|
||||
os.path.isfile(self.models_path / architecture)
|
||||
or architecture.startswith("_")
|
||||
or architecture == "deprecated"
|
||||
):
|
||||
continue
|
||||
|
||||
with self.subTest(f"Testing arch {architecture}"):
|
||||
import_structure = define_import_structure(self.models_path / architecture)
|
||||
backend_agnostic_import_structure = {}
|
||||
for module_object_mapping in import_structure.values():
|
||||
for module, objects in module_object_mapping.items():
|
||||
if module not in backend_agnostic_import_structure:
|
||||
backend_agnostic_import_structure[module] = []
|
||||
|
||||
backend_agnostic_import_structure[module].extend(objects)
|
||||
|
||||
for module, objects in backend_agnostic_import_structure.items():
|
||||
with open(self.models_path / architecture / f"{module}.py") as f:
|
||||
content = f.read()
|
||||
_all = fetch__all__(content)
|
||||
|
||||
if _all is None:
|
||||
raise ValueError(f"{module} doesn't have __all__ defined.")
|
||||
|
||||
error_message = (
|
||||
f"self.models_path / architecture / f'{module}.py doesn't seem to be defined correctly:\n"
|
||||
f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
|
||||
)
|
||||
self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)
|
||||
|
||||
def test_import_spread(self):
|
||||
"""
|
||||
This test is specifically designed to test that varying levels of depth across import structures are
|
||||
respected.
|
||||
|
||||
In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
|
||||
- models.{frozensets}
|
||||
- models.albert.{frozensets}
|
||||
- models.deprecated.transfo_xl.{frozensets}
|
||||
"""
|
||||
initial_import_structure = {
|
||||
frozenset(): {"dummy_non_model": {"DummyObject"}},
|
||||
"models": {
|
||||
frozenset(): {"dummy_config": {"DummyConfig"}},
|
||||
"albert": {
|
||||
frozenset(): {"configuration_albert": {"AlbertConfig"}},
|
||||
frozenset({"torch"}): {
|
||||
"modeling_albert": {
|
||||
"AlbertForMaskedLM",
|
||||
}
|
||||
},
|
||||
},
|
||||
"llama": {
|
||||
frozenset(): {"configuration_llama": {"LlamaConfig"}},
|
||||
frozenset({"torch"}): {
|
||||
"modeling_llama": {
|
||||
"LlamaForCausalLM",
|
||||
}
|
||||
},
|
||||
},
|
||||
"deprecated": {
|
||||
"transfo_xl": {
|
||||
frozenset({"torch"}): {
|
||||
"modeling_transfo_xl": {
|
||||
"TransfoXLModel",
|
||||
}
|
||||
},
|
||||
frozenset(): {
|
||||
"configuration_transfo_xl": {"TransfoXLConfig"},
|
||||
"tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
|
||||
},
|
||||
},
|
||||
"deta": {
|
||||
frozenset({"torch"}): {
|
||||
"modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
|
||||
},
|
||||
frozenset(): {"configuration_deta": {"DetaConfig"}},
|
||||
frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ground_truth_spread_import_structure = {
|
||||
frozenset(): {
|
||||
"dummy_non_model": {"DummyObject"},
|
||||
"models.dummy_config": {"DummyConfig"},
|
||||
"models.albert.configuration_albert": {"AlbertConfig"},
|
||||
"models.llama.configuration_llama": {"LlamaConfig"},
|
||||
"models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
|
||||
"models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
|
||||
"models.deprecated.deta.configuration_deta": {"DetaConfig"},
|
||||
},
|
||||
frozenset({"torch"}): {
|
||||
"models.albert.modeling_albert": {"AlbertForMaskedLM"},
|
||||
"models.llama.modeling_llama": {"LlamaForCausalLM"},
|
||||
"models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
|
||||
"models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
|
||||
},
|
||||
frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
|
||||
}
|
||||
|
||||
newly_spread_import_structure = spread_import_structure(initial_import_structure)
|
||||
|
||||
self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend,package_name,version_comparison,version",
|
||||
[
|
||||
pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL, "2.5"),
|
||||
pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL, "0.19.1"),
|
||||
],
|
||||
)
|
||||
def test_backend_specification(
|
||||
backend: Backend, package_name: str, version_comparison: VersionComparison, version: str
|
||||
):
|
||||
assert backend.package_name == package_name
|
||||
assert VersionComparison.from_string(backend.version_comparison) == version_comparison
|
||||
assert backend.version == version
|
||||
50
tests/utils/test_import_utils.py
Normal file
50
tests/utils/test_import_utils.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import run_test_using_subprocess
|
||||
from transformers.utils.import_utils import _is_package_available, clear_import_cache
|
||||
|
||||
|
||||
@run_test_using_subprocess
|
||||
def test_clear_import_cache():
|
||||
"""Test the clear_import_cache function."""
|
||||
|
||||
# Save initial state
|
||||
initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
|
||||
assert len(initial_modules) > 0, "No transformers modules loaded before test"
|
||||
|
||||
# Execute clear_import_cache() function
|
||||
clear_import_cache()
|
||||
|
||||
# Verify modules were removed
|
||||
remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
|
||||
assert len(remaining_modules) < len(initial_modules), "No modules were removed"
|
||||
|
||||
# Import and verify module exists
|
||||
from transformers.models.auto import modeling_auto
|
||||
|
||||
assert "transformers.models.auto.modeling_auto" in sys.modules
|
||||
assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto"
|
||||
|
||||
|
||||
def test_is_package_available_edge_cases():
|
||||
pkg_name = "definitely_not_a_real_pkg_xyz"
|
||||
|
||||
namespace_shadow = ModuleType(pkg_name)
|
||||
versionless_install = ModuleType(pkg_name)
|
||||
versionless_install.__file__ = f"/path/to/site-packages/{pkg_name}/__init__.py"
|
||||
with_version = ModuleType(pkg_name)
|
||||
with_version.__version__ = "1.2.3"
|
||||
|
||||
cases = [
|
||||
(namespace_shadow, (False, "N/A")),
|
||||
(versionless_install, (True, "N/A")),
|
||||
(with_version, (True, "1.2.3")),
|
||||
]
|
||||
for fake_module, expected in cases:
|
||||
with (
|
||||
patch("transformers.utils.import_utils.importlib.util.find_spec", return_value=object()),
|
||||
patch("transformers.utils.import_utils.importlib.import_module", return_value=fake_module),
|
||||
):
|
||||
assert _is_package_available(pkg_name, return_version=True) == expected
|
||||
135
tests/utils/test_logging.py
Normal file
135
tests/utils/test_logging.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from huggingface_hub.utils import are_progress_bars_disabled
|
||||
|
||||
import transformers.models.roberta.tokenization_roberta
|
||||
from transformers import logging
|
||||
from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
|
||||
from transformers.utils.logging import disable_progress_bar, enable_progress_bar
|
||||
|
||||
|
||||
class HfArgumentParserTest(unittest.TestCase):
|
||||
def test_set_level(self):
|
||||
logger = logging.get_logger()
|
||||
|
||||
# the current default level is logging.WARNING
|
||||
level_origin = logging.get_verbosity()
|
||||
|
||||
logging.set_verbosity_error()
|
||||
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
|
||||
|
||||
logging.set_verbosity_info()
|
||||
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
|
||||
|
||||
logging.set_verbosity_debug()
|
||||
self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
|
||||
|
||||
# restore to the original level
|
||||
logging.set_verbosity(level_origin)
|
||||
|
||||
def test_integration(self):
|
||||
level_origin = logging.get_verbosity()
|
||||
|
||||
logger = logging.get_logger("transformers.models.roberta.tokenization_roberta")
|
||||
msg = "Testing 1, 2, 3"
|
||||
|
||||
# should be able to log warnings (if default settings weren't overridden by `pytest --log-level-all`)
|
||||
if level_origin <= logging.WARNING:
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.warning(msg)
|
||||
self.assertEqual(cl.out, msg + "\n")
|
||||
|
||||
# this is setting the level for all of `transformers.*` loggers
|
||||
logging.set_verbosity_error()
|
||||
|
||||
# should not be able to log warnings
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.warning(msg)
|
||||
self.assertEqual(cl.out, "")
|
||||
|
||||
# should be able to log warnings again
|
||||
logging.set_verbosity_warning()
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.warning(msg)
|
||||
self.assertEqual(cl.out, msg + "\n")
|
||||
|
||||
# restore to the original level
|
||||
logging.set_verbosity(level_origin)
|
||||
|
||||
@mockenv(TRANSFORMERS_VERBOSITY="error")
|
||||
def test_env_override(self):
|
||||
# reset for the env var to take effect, next time some logger call is made
|
||||
transformers.utils.logging._reset_library_root_logger()
|
||||
# this action activates the env var
|
||||
_ = logging.get_logger("transformers.models.roberta.tokenization_roberta")
|
||||
|
||||
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
|
||||
env_level = logging.log_levels[env_level_str]
|
||||
|
||||
current_level = logging.get_verbosity()
|
||||
self.assertEqual(
|
||||
env_level,
|
||||
current_level,
|
||||
f"TRANSFORMERS_VERBOSITY={env_level_str}/{env_level}, but internal verbosity is {current_level}",
|
||||
)
|
||||
|
||||
# restore to the original level
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = ""
|
||||
transformers.utils.logging._reset_library_root_logger()
|
||||
|
||||
@mockenv(TRANSFORMERS_VERBOSITY="super-error")
|
||||
def test_env_invalid_override(self):
|
||||
# reset for the env var to take effect, next time some logger call is made
|
||||
transformers.utils.logging._reset_library_root_logger()
|
||||
logger = logging.logging.getLogger()
|
||||
with CaptureLogger(logger) as cl:
|
||||
# this action activates the env var
|
||||
logging.get_logger("transformers.models.roberta.tokenization_roberta")
|
||||
self.assertIn("Unknown option TRANSFORMERS_VERBOSITY=super-error", cl.out)
|
||||
|
||||
# no need to restore as nothing was changed
|
||||
|
||||
def test_advisory_warnings(self):
|
||||
# testing `logger.warning_advice()`
|
||||
transformers.utils.logging._reset_library_root_logger()
|
||||
|
||||
logger = logging.get_logger("transformers.models.roberta.tokenization_roberta")
|
||||
msg = "Testing 1, 2, 3"
|
||||
|
||||
with mockenv_context(TRANSFORMERS_NO_ADVISORY_WARNINGS="1"):
|
||||
# nothing should be logged as env var disables this method
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.warning_advice(msg)
|
||||
self.assertEqual(cl.out, "")
|
||||
|
||||
with mockenv_context(TRANSFORMERS_NO_ADVISORY_WARNINGS=""):
|
||||
# should log normally as TRANSFORMERS_NO_ADVISORY_WARNINGS is unset
|
||||
with CaptureLogger(logger) as cl:
|
||||
logger.warning_advice(msg)
|
||||
self.assertEqual(cl.out, msg + "\n")
|
||||
|
||||
|
||||
def test_set_progress_bar_enabled():
|
||||
disable_progress_bar()
|
||||
assert are_progress_bars_disabled()
|
||||
|
||||
enable_progress_bar()
|
||||
assert not are_progress_bars_disabled()
|
||||
375
tests/utils/test_masking_utils.py
Normal file
375
tests/utils/test_masking_utils.py
Normal file
@@ -0,0 +1,375 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
is_torch_available,
|
||||
require_torch,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import create_block_mask
|
||||
|
||||
from transformers import DynamicCache, LlamaConfig
|
||||
from transformers.cache_utils import DynamicSlidingWindowLayer
|
||||
from transformers.masking_utils import (
|
||||
create_bidirectional_mask,
|
||||
create_causal_mask,
|
||||
create_chunked_causal_mask,
|
||||
find_packed_sequence_indices,
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_PACKED_MASK = torch.tensor([[[
|
||||
[ True, False, False, False, False, False, False, False, False, False],
|
||||
[ True, True, False, False, False, False, False, False, False, False],
|
||||
[ True, True, True, False, False, False, False, False, False, False],
|
||||
[ True, True, True, True, False, False, False, False, False, False],
|
||||
[False, False, False, False, True, False, False, False, False, False],
|
||||
[False, False, False, False, True, True, False, False, False, False],
|
||||
[False, False, False, False, False, False, True, False, False, False],
|
||||
[False, False, False, False, False, False, True, True, False, False],
|
||||
[False, False, False, False, False, False, True, True, True, False],
|
||||
[False, False, False, False, False, False, True, True, True, True]]],
|
||||
|
||||
|
||||
[[[ True, False, False, False, False, False, False, False, False, False],
|
||||
[ True, True, False, False, False, False, False, False, False, False],
|
||||
[ True, True, True, False, False, False, False, False, False, False],
|
||||
[ True, True, True, True, False, False, False, False, False, False],
|
||||
[ True, True, True, True, True, False, False, False, False, False],
|
||||
[ True, True, True, True, True, True, False, False, False, False],
|
||||
[False, False, False, False, False, False, True, False, False, False],
|
||||
[False, False, False, False, False, False, True, True, False, False],
|
||||
[False, False, False, False, False, False, True, True, True, False],
|
||||
[False, False, False, False, False, False, True, True, True, True]
|
||||
]]], dtype=torch.bool)
|
||||
# fmt: on
|
||||
|
||||
|
||||
@require_torch
|
||||
class MaskTest(unittest.TestCase):
|
||||
def setup(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_packed_sequence_mask_sdpa(self):
|
||||
config = LlamaConfig()
|
||||
config._attn_implementation = "sdpa"
|
||||
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
|
||||
# First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all())
|
||||
|
||||
def test_packed_sequence_mask_eager(self):
|
||||
config = LlamaConfig()
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
|
||||
# First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
min_dtype = torch.finfo(torch.float16).min
|
||||
self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all())
|
||||
|
||||
def test_packed_sequence_mask_flex_attention(self):
|
||||
config = LlamaConfig()
|
||||
config._attn_implementation = "flex_attention"
|
||||
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
|
||||
# First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
def dummy_mask_mod(b, h, q, kv):
|
||||
return EXPECTED_PACKED_MASK[b, h, q, kv]
|
||||
|
||||
EXPECTED_BLOCK_MASK = create_block_mask(dummy_mask_mod, 2, None, 10, 10, device="cpu")
|
||||
|
||||
# We compatre the str representations, as the BlockMask objects themselves cannot easily be compared
|
||||
self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())
|
||||
|
||||
def test_find_packed_sequence_indices(self):
|
||||
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])
|
||||
EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
|
||||
self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
|
||||
|
||||
def test_nonpacked_sequence_mask_skip(self):
|
||||
config = LlamaConfig()
|
||||
config._attn_implementation = "sdpa"
|
||||
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
|
||||
# Non-packed sequences
|
||||
position_ids = torch.arange(sequence_length)[None, :]
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
# packed sequence should be skipped
|
||||
self.assertTrue(causal_mask is None)
|
||||
|
||||
create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead")
|
||||
causal_mask = create_causal_mask_compiled(
|
||||
config=config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
inputs_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
|
||||
attention_mask=None,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
# cannot be skipped under compile, should result into a triu mask
|
||||
self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask))
|
||||
|
||||
def test_chunked_mask_with_left_padding_and_large_prefill(self):
|
||||
# Make sure we have an attention_chunk_size in the config
|
||||
config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")
|
||||
|
||||
batch_size = 2
|
||||
sequence_length = 8
|
||||
pad_tokens = 4
|
||||
|
||||
input_ids = torch.randint(100, 200, (batch_size, sequence_length))
|
||||
attention_mask = torch.tensor(
|
||||
[[0 if i < pad_tokens else 1 for i in range(sequence_length)], [1] * sequence_length]
|
||||
)
|
||||
inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
|
||||
positions = torch.arange(sequence_length)
|
||||
position_ids = torch.empty(batch_size, sequence_length, dtype=positions.dtype)
|
||||
position_ids[0, :pad_tokens] = 1
|
||||
position_ids[0, pad_tokens:] = torch.arange(sequence_length - pad_tokens)
|
||||
position_ids[1, :] = positions
|
||||
|
||||
chunked_attention_mask = create_chunked_causal_mask(
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=None,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_CHUNKED_MASK = torch.tensor(
|
||||
# Here, for the padded sequence, the chunk size should start correctly at index 4 (otherwise, with 4 padding
|
||||
# tokens are chunk_size=3, the first chunk is from indices 0-2, then 3-6 if we don't account for the padding correctly)
|
||||
[[[[False, False, False, False, False, False, False, False],
|
||||
[False, False, False, False, False, False, False, False],
|
||||
[False, False, False, False, False, False, False, False],
|
||||
[False, False, False, False, False, False, False, False],
|
||||
[False, False, False, False, True, False, False, False],
|
||||
[False, False, False, False, True, True, False, False],
|
||||
[False, False, False, False, True, True, True, False],
|
||||
[False, False, False, False, False, False, False, True]]],
|
||||
|
||||
|
||||
[[[ True, False, False, False, False, False, False, False],
|
||||
[ True, True, False, False, False, False, False, False],
|
||||
[ True, True, True, False, False, False, False, False],
|
||||
[False, False, False, True, False, False, False, False],
|
||||
[False, False, False, True, True, False, False, False],
|
||||
[False, False, False, True, True, True, False, False],
|
||||
[False, False, False, False, False, False, True, False],
|
||||
[False, False, False, False, False, False, True, True]]]],
|
||||
dtype=torch.bool)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
|
||||
|
||||
def test_chunked_mask_with_left_padding_decoding(self):
|
||||
# Make sure we have an attention_chunk_size in the config
|
||||
config = LlamaConfig(attention_chunk_size=4, attn_implementation="sdpa", num_hidden_layers=1)
|
||||
|
||||
cache = DynamicCache(config=config)
|
||||
# Sanity check
|
||||
self.assertEqual(len(cache), 1)
|
||||
self.assertTrue(isinstance(cache.layers[0], DynamicSlidingWindowLayer))
|
||||
|
||||
# Fill-in the Cache (sequence length is bigger than chunk size here)
|
||||
batch_size = 2
|
||||
prefill_size = 8
|
||||
pad_tokens = 7
|
||||
fake_kv = torch.rand(batch_size, 32, prefill_size, 32)
|
||||
cache.update(fake_kv, fake_kv, 0, torch.arange(prefill_size))
|
||||
|
||||
# Create a new input after the prefill
|
||||
input_ids = torch.randint(100, 200, (batch_size, 1))
|
||||
attention_mask = torch.tensor(
|
||||
[[0 if i < pad_tokens else 1 for i in range(prefill_size + 1)], [1] * (prefill_size + 1)]
|
||||
)
|
||||
inputs_embeds = torch.empty_like(input_ids, dtype=torch.float16)
|
||||
position_ids = torch.tensor([[prefill_size - pad_tokens], [prefill_size]])
|
||||
|
||||
chunked_attention_mask = create_chunked_causal_mask(
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=cache,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# To understand a bit more the following expected mask, here is the full 2d mask, where the "|" characters are the chunk
|
||||
# separators (where the tokens should stop seeing each other)
|
||||
# [0, 0, 0, 0, 0, 0, 0, | 1, 1], -> due to left padding, the first chunk only starts after the padding tokens
|
||||
# [| 1, 1, 1, 1, | 1, 1, 1, 1, | 1]]) -> easy case, each 4 tokens is a new chunk
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_CHUNKED_MASK = torch.tensor(
|
||||
# Here, for the padded sequence, the chunk size should start correctly at index 7 (the first unpadded
|
||||
# index), and so only indices 7 and 8 should be True
|
||||
[[[[False, False, True, True]]],
|
||||
|
||||
# Here, for the unpadded sequence, the chunks start at index 0. Since we have 9 tokens in total, the last
|
||||
# token (index 8) will only see itself (we have 2 full chunks before)
|
||||
[[[False, False, False, True]]]],
|
||||
dtype=torch.bool)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue((chunked_attention_mask == EXPECTED_CHUNKED_MASK).all())
|
||||
|
||||
@staticmethod
|
||||
def _run_bidirectional_mask(mask_fn, attn_implementation):
|
||||
def run_mask_creation(mask_fn, config, inputs_embeds, encoder_mask, cross_mask, encoder_hidden_states):
|
||||
encoder_attn_mask = mask_fn(
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=encoder_mask,
|
||||
)
|
||||
cross_attn_mask = mask_fn(
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=cross_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
return encoder_attn_mask, cross_attn_mask
|
||||
|
||||
# We use llama but could be also bert/bart --> we only need the `_attn_implementation` here
|
||||
config = LlamaConfig()
|
||||
config._attn_implementation = attn_implementation
|
||||
|
||||
# Meta data
|
||||
batch_size = 2
|
||||
q_length = 10
|
||||
kv_length = 5
|
||||
|
||||
inputs_embeds = torch.ones((batch_size, q_length, 1), device=torch_device, dtype=torch.float16)
|
||||
encoder_hidden_states = torch.ones((batch_size, kv_length, 1), device=torch_device, dtype=torch.float16)
|
||||
|
||||
encoder_mask = torch.ones_like(inputs_embeds)[..., 0]
|
||||
cross_mask = torch.ones_like(encoder_hidden_states)[..., 0]
|
||||
|
||||
# Case 1: Full mask
|
||||
full_mask_encoder_1, full_mask_cross_1 = run_mask_creation(
|
||||
mask_fn=mask_fn,
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_mask=encoder_mask,
|
||||
cross_mask=cross_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
full_mask_encoder_2, full_mask_cross_2 = run_mask_creation(
|
||||
mask_fn=mask_fn,
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_mask=None,
|
||||
cross_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
# Case 2: Padding involved
|
||||
cross_mask[:, -1] = 0
|
||||
encoder_mask[:, -1] = 0
|
||||
|
||||
padded_mask_encoder, padded_mask_cross = run_mask_creation(
|
||||
mask_fn=mask_fn,
|
||||
config=config,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_mask=encoder_mask,
|
||||
cross_mask=cross_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
full_masks = (full_mask_encoder_1, full_mask_encoder_2), (full_mask_cross_1, full_mask_cross_2)
|
||||
padded_masks = (padded_mask_encoder, padded_mask_cross)
|
||||
return full_masks, padded_masks
|
||||
|
||||
def test_bidirectional_mask_cudagraphs(self):
|
||||
"""
|
||||
Checks whether the bidirectional mask creation is compatible with cuda graphs, i.e. we do not into any error
|
||||
during this test.
|
||||
"""
|
||||
mask_creation_function = torch.compile(create_bidirectional_mask, mode="reduce-overhead")
|
||||
self._run_bidirectional_mask(mask_fn=mask_creation_function, attn_implementation="sdpa")
|
||||
|
||||
def test_bidirectional_mask_skip_eager(self):
|
||||
"""
|
||||
Checks whether the bidirectional mask creation can skip the mask creation if we have a full mask.
|
||||
"""
|
||||
full_masks, padded_mask = self._run_bidirectional_mask(
|
||||
mask_fn=create_bidirectional_mask, attn_implementation="eager"
|
||||
)
|
||||
|
||||
for alternative_masks in full_masks:
|
||||
self.assertTrue(alternative_masks[0] is None)
|
||||
self.assertTrue(alternative_masks[1] is None)
|
||||
|
||||
self.assertTrue(padded_mask[0] is not None)
|
||||
self.assertTrue(padded_mask[1] is not None)
|
||||
121
tests/utils/test_model_debugging_utils.py
Normal file
121
tests/utils/test_model_debugging_utils.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.model_debugging_utils import model_addition_debugger_context
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
class ToyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(10, 4)
|
||||
self.linear_1 = nn.Linear(4, 8)
|
||||
self.linear_2 = nn.Linear(8, 2)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
def forward(self, input_ids: str):
|
||||
hidden_states = self.embed(input_ids).mean(dim=1)
|
||||
hidden_states = self.act(self.linear_1(hidden_states))
|
||||
return self.linear_2(hidden_states)
|
||||
|
||||
class TestModelAdditionDebugger(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model = ToyModel()
|
||||
self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))}
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
|
||||
def test_debugger_outputs(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with model_addition_debugger_context(self.model, debug_path=str(tmpdir)):
|
||||
_ = self.model.forward(**self.inputs)
|
||||
|
||||
base = f"{self.model.__class__.__name__}_debug_tree"
|
||||
summary = Path(os.path.join(tmpdir, f"{base}_SUMMARY.json"))
|
||||
full = Path(os.path.join(tmpdir, f"{base}_FULL_TENSORS.json"))
|
||||
self.assertTrue(os.path.isfile(summary) and os.path.isfile(full))
|
||||
data = json.loads(summary.read_text())
|
||||
self.assertTrue({"module_path", "inputs", "children"} <= data.keys())
|
||||
self.assertTrue(data["children"])
|
||||
|
||||
class ToyLayer(nn.Module):
|
||||
def __init__(self, layer_index):
|
||||
super().__init__()
|
||||
self.layer_index = layer_index
|
||||
self.layer_operation = nn.Linear(4, 4)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.layer_operation(hidden_states)
|
||||
|
||||
class ToyModelWithLayers(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.input_proj = nn.Linear(4, 4)
|
||||
self.layers = nn.ModuleList([ToyLayer(layer_index) for layer_index in range(6)])
|
||||
self.output_proj = nn.Linear(4, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_proj(x)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return self.output_proj(x)
|
||||
|
||||
class TestModelWithLayers(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))}
|
||||
self.model_with_layers = ToyModelWithLayers()
|
||||
self.dense_input = {"x": torch.randn(1, 4)}
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
|
||||
def test_layer_pruning_behavior(self):
|
||||
# No pruning: expect all 6 layers
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=False):
|
||||
_ = self.model_with_layers(**self.dense_input)
|
||||
|
||||
summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json")
|
||||
with open(summary_path) as f:
|
||||
data = json.load(f)
|
||||
self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"})
|
||||
for layer_index in range(6):
|
||||
self.assertEqual(
|
||||
data["children"][layer_index + 1]["module_path"],
|
||||
f"ToyModelWithLayers.layers.{int(layer_index)}",
|
||||
)
|
||||
|
||||
# Pruning: expect only 2 layers (0 and 5)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=True):
|
||||
_ = self.model_with_layers(**self.dense_input)
|
||||
|
||||
summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json")
|
||||
with open(summary_path) as f:
|
||||
data = json.load(f)
|
||||
self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"})
|
||||
self.assertEqual(data["children"][1]["module_path"], "ToyModelWithLayers.layers.0")
|
||||
self.assertEqual(data["children"][2]["module_path"], "ToyModelWithLayers.layers.5")
|
||||
197
tests/utils/test_model_output.py
Normal file
197
tests/utils/test_model_output.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# Copyright 2020 The Hugging Face Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import AlbertForMaskedLM
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils import ModelOutput, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelOutputTest(ModelOutput):
|
||||
a: float
|
||||
b: float | None = None
|
||||
c: float | None = None
|
||||
|
||||
|
||||
class ModelOutputTester(unittest.TestCase):
|
||||
def test_get_attributes(self):
|
||||
x = ModelOutputTest(a=30)
|
||||
self.assertEqual(x.a, 30)
|
||||
self.assertIsNone(x.b)
|
||||
self.assertIsNone(x.c)
|
||||
with self.assertRaises(AttributeError):
|
||||
_ = x.d
|
||||
|
||||
def test_index_with_ints_and_slices(self):
|
||||
x = ModelOutputTest(a=30, b=10)
|
||||
self.assertEqual(x[0], 30)
|
||||
self.assertEqual(x[1], 10)
|
||||
self.assertEqual(x[:2], (30, 10))
|
||||
self.assertEqual(x[:], (30, 10))
|
||||
|
||||
x = ModelOutputTest(a=30, c=10)
|
||||
self.assertEqual(x[0], 30)
|
||||
self.assertEqual(x[1], 10)
|
||||
self.assertEqual(x[:2], (30, 10))
|
||||
self.assertEqual(x[:], (30, 10))
|
||||
|
||||
def test_index_with_strings(self):
|
||||
x = ModelOutputTest(a=30, b=10)
|
||||
self.assertEqual(x["a"], 30)
|
||||
self.assertEqual(x["b"], 10)
|
||||
with self.assertRaises(KeyError):
|
||||
_ = x["c"]
|
||||
|
||||
x = ModelOutputTest(a=30, c=10)
|
||||
self.assertEqual(x["a"], 30)
|
||||
self.assertEqual(x["c"], 10)
|
||||
with self.assertRaises(KeyError):
|
||||
_ = x["b"]
|
||||
|
||||
def test_dict_like_properties(self):
|
||||
x = ModelOutputTest(a=30)
|
||||
self.assertEqual(list(x.keys()), ["a"])
|
||||
self.assertEqual(list(x.values()), [30])
|
||||
self.assertEqual(list(x.items()), [("a", 30)])
|
||||
self.assertEqual(list(x), ["a"])
|
||||
|
||||
x = ModelOutputTest(a=30, b=10)
|
||||
self.assertEqual(list(x.keys()), ["a", "b"])
|
||||
self.assertEqual(list(x.values()), [30, 10])
|
||||
self.assertEqual(list(x.items()), [("a", 30), ("b", 10)])
|
||||
self.assertEqual(list(x), ["a", "b"])
|
||||
|
||||
x = ModelOutputTest(a=30, c=10)
|
||||
self.assertEqual(list(x.keys()), ["a", "c"])
|
||||
self.assertEqual(list(x.values()), [30, 10])
|
||||
self.assertEqual(list(x.items()), [("a", 30), ("c", 10)])
|
||||
self.assertEqual(list(x), ["a", "c"])
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
x = x.update({"d": 20})
|
||||
with self.assertRaises(Exception):
|
||||
del x["a"]
|
||||
with self.assertRaises(Exception):
|
||||
_ = x.pop("a")
|
||||
with self.assertRaises(Exception):
|
||||
_ = x.setdefault("d", 32)
|
||||
|
||||
def test_set_attributes(self):
|
||||
x = ModelOutputTest(a=30)
|
||||
x.a = 10
|
||||
self.assertEqual(x.a, 10)
|
||||
self.assertEqual(x["a"], 10)
|
||||
|
||||
def test_set_keys(self):
|
||||
x = ModelOutputTest(a=30)
|
||||
x["a"] = 10
|
||||
self.assertEqual(x.a, 10)
|
||||
self.assertEqual(x["a"], 10)
|
||||
|
||||
def test_instantiate_from_dict(self):
|
||||
x = ModelOutputTest({"a": 30, "b": 10})
|
||||
self.assertEqual(list(x.keys()), ["a", "b"])
|
||||
self.assertEqual(x.a, 30)
|
||||
self.assertEqual(x.b, 10)
|
||||
|
||||
def test_instantiate_from_iterator(self):
|
||||
x = ModelOutputTest([("a", 30), ("b", 10)])
|
||||
self.assertEqual(list(x.keys()), ["a", "b"])
|
||||
self.assertEqual(x.a, 30)
|
||||
self.assertEqual(x.b, 10)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = ModelOutputTest([("a", 30), (10, 10)])
|
||||
|
||||
x = ModelOutputTest(a=(30, 30))
|
||||
self.assertEqual(list(x.keys()), ["a"])
|
||||
self.assertEqual(x.a, (30, 30))
|
||||
|
||||
@require_torch
|
||||
def test_torch_pytree(self):
|
||||
# ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
|
||||
# this is important for DistributedDataParallel gradient synchronization with static_graph=True
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
x = ModelOutput({"a": 1.0, "c": 2.0})
|
||||
self.assertFalse(pytree._is_leaf(x))
|
||||
|
||||
x = ModelOutputTest(a=1.0, c=2.0)
|
||||
self.assertFalse(pytree._is_leaf(x))
|
||||
|
||||
expected_flat_outs = [1.0, 2.0]
|
||||
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
|
||||
|
||||
actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
|
||||
self.assertEqual(expected_flat_outs, actual_flat_outs)
|
||||
self.assertEqual(expected_tree_spec, actual_tree_spec)
|
||||
|
||||
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
|
||||
self.assertEqual(x, unflattened_x)
|
||||
|
||||
self.assertEqual(
|
||||
pytree.treespec_dumps(actual_tree_spec),
|
||||
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": "[\\"a\\", \\"c\\"]", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
|
||||
)
|
||||
|
||||
# TODO: @ydshieh
|
||||
@unittest.skip(reason="CPU OOM")
|
||||
@require_torch
|
||||
@pytest.mark.torch_export_test
|
||||
def test_export_serialization(self):
|
||||
model_cls = AlbertForMaskedLM
|
||||
model_config = model_cls.config_class()
|
||||
model = model_cls(model_config)
|
||||
|
||||
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||
|
||||
ep = torch.export.export(model, (), input_dict)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.export.save(ep, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_ep = torch.export.load(buffer)
|
||||
|
||||
input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
|
||||
assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
|
||||
|
||||
|
||||
class ModelOutputTestNoDataclass(ModelOutput):
|
||||
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
|
||||
|
||||
a: float
|
||||
b: float | None = None
|
||||
c: float | None = None
|
||||
|
||||
|
||||
class ModelOutputSubclassTester(unittest.TestCase):
|
||||
def test_direct_model_output(self):
|
||||
# Check that direct usage of ModelOutput instantiates without errors
|
||||
ModelOutput({"a": 1.1})
|
||||
|
||||
def test_subclass_no_dataclass(self):
|
||||
# Check that a subclass of ModelOutput without @dataclass is invalid
|
||||
# A valid subclass is inherently tested other unit tests above.
|
||||
with self.assertRaises(TypeError):
|
||||
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)
|
||||
708
tests/utils/test_modeling_rope_utils.py
Normal file
708
tests/utils/test_modeling_rope_utils.py
Normal file
@@ -0,0 +1,708 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from transformers import LlamaConfig
|
||||
from transformers.testing_utils import is_torch_available, require_torch, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import ROPE_INIT_FUNCTIONS
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
||||
|
||||
|
||||
@require_torch
|
||||
class RopeTest(unittest.TestCase):
|
||||
def test_rope_validation(self):
|
||||
config = LlamaConfig()
|
||||
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
|
||||
|
||||
# The base config is always valid (default RoPE)
|
||||
config.validate_rope()
|
||||
|
||||
# If we explicitly set the other (non-default) RoPE types with only rope_theta,
|
||||
# validation should fail because required keys are missing (e.g. factor, short_factor)
|
||||
for rope_type in all_rope_types:
|
||||
if rope_type == "default":
|
||||
continue # "default" is always valid with just rope_theta
|
||||
# proportional is same as default wrt to expected keys
|
||||
if rope_type == "proportional":
|
||||
continue
|
||||
config.rope_parameters = {"rope_type": rope_type, "rope_theta": 10000.0}
|
||||
with self.assertRaises(KeyError):
|
||||
config.validate_rope()
|
||||
|
||||
# Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
|
||||
valid_param_mapping = {
|
||||
"factor": ["linear", "dynamic", "yarn", "longrope"],
|
||||
"attention_factor": ["yarn", "longrope"],
|
||||
"beta_fast": ["yarn"],
|
||||
"beta_slow": ["yarn"],
|
||||
"short_factor": ["longrope"],
|
||||
"long_factor": ["longrope"],
|
||||
}
|
||||
for rope_type in all_rope_types:
|
||||
if rope_type == "default":
|
||||
continue # "default" only warns about unrecognised keys, never raises KeyError
|
||||
# proportional is same as default wrt to expected keys
|
||||
if rope_type == "proportional":
|
||||
continue
|
||||
for param, valid_rope_types in valid_param_mapping.items():
|
||||
# Set `param` with a dummy value -- we want to test the dict key
|
||||
config.rope_parameters = {"rope_type": rope_type, "rope_theta": 10000.0, param: True}
|
||||
if rope_type in valid_rope_types:
|
||||
continue
|
||||
else:
|
||||
with self.assertRaises(KeyError):
|
||||
config.validate_rope()
|
||||
|
||||
# Any other parameters passed to RoPE will raise a warning that a particular key is not used
|
||||
# But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys`
|
||||
config.ignore_keys_at_rope_validation = {"mrope_sections"} # e,g in Qwen2-VL
|
||||
config.rope_parameters = {"rope_type": "default", "rope_theta": 10000.0, "mrope_sections": True}
|
||||
config.validate_rope()
|
||||
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config.ignore_keys_at_rope_validation = set()
|
||||
config.validate_rope()
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("mrope_sections", logs.output[0])
|
||||
|
||||
# We can indicate Different RoPE params for each attention type
|
||||
# We can also have only one RoPE params defined for all layer, we don't raise an error
|
||||
# because it is not required to have separate RoPE per layer type
|
||||
config.layer_types = ["full_attention", "sliding_attention"]
|
||||
config.rope_parameters = {
|
||||
"full_attention": {"rope_type": "default", "rope_theta": 10000},
|
||||
"sliding_attention": {"rope_type": "linear", "rope_theta": 10000, "factor": 2.0},
|
||||
}
|
||||
config.validate_rope()
|
||||
|
||||
config.rope_parameters = config.rope_parameters["full_attention"]
|
||||
config.validate_rope()
|
||||
|
||||
def test_yarn_original_original_max_position_embeddings_validation(self):
|
||||
"""Tests that models with no/bad `original_max_position_embeddings` raise a warning"""
|
||||
config = LlamaConfig()
|
||||
|
||||
# good rope config: has a factor AND original_max_position_embeddings -> no warnings
|
||||
rope_config = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": 2.0,
|
||||
"original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
|
||||
}
|
||||
config.rope_parameters = rope_config
|
||||
with self.assertRaises(AssertionError): # confirm that no warnings are thrown
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config.validate_rope()
|
||||
|
||||
# bad rope config, no `original_max_position_embeddings` -> raise error
|
||||
rope_config = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": 2.0,
|
||||
}
|
||||
config.rope_parameters = rope_config
|
||||
with self.assertRaises(KeyError):
|
||||
config.validate_rope()
|
||||
|
||||
# bad rope config, bad implicit fator -> warning
|
||||
rope_config = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": 2.0,
|
||||
"original_max_position_embeddings": 1,
|
||||
}
|
||||
config.rope_parameters = rope_config
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config.validate_rope()
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("implicit factor", logs.output[0])
|
||||
|
||||
def test_convert_rope_params_to_dict_with_list_ignore_keys(self):
|
||||
# Regression test for #46121: `ignore_keys_at_rope_validation` becomes a list when loaded from a config.json
|
||||
# (JSON has no set type). `convert_rope_params_to_dict` used to do `list | set` and crash with
|
||||
# TypeError when `partial_rotary_factor` was also set.
|
||||
config = LlamaConfig(partial_rotary_factor=0.25)
|
||||
config.ignore_keys_at_rope_validation = ["mrope_section", "mrope_interleaved"]
|
||||
|
||||
config.convert_rope_params_to_dict(partial_rotary_factor=0.25)
|
||||
|
||||
self.assertIsInstance(config.ignore_keys_at_rope_validation, set)
|
||||
self.assertEqual(
|
||||
config.ignore_keys_at_rope_validation,
|
||||
{"mrope_section", "mrope_interleaved", "partial_rotary_factor"},
|
||||
)
|
||||
|
||||
# Round-trip through from_dict to mimic the JSON-deserialized path that triggered this in production.
|
||||
cfg_dict = config.to_dict()
|
||||
cfg_dict["ignore_keys_at_rope_validation"] = ["mrope_section", "mrope_interleaved"]
|
||||
reloaded = LlamaConfig.from_dict(cfg_dict)
|
||||
reloaded.convert_rope_params_to_dict(partial_rotary_factor=0.25)
|
||||
self.assertIsInstance(reloaded.ignore_keys_at_rope_validation, set)
|
||||
|
||||
# Also accept None (the class-level attribute can be cleared on an instance).
|
||||
config_none = LlamaConfig(partial_rotary_factor=0.25)
|
||||
config_none.ignore_keys_at_rope_validation = None
|
||||
config_none.convert_rope_params_to_dict(partial_rotary_factor=0.25)
|
||||
self.assertEqual(config_none.ignore_keys_at_rope_validation, {"partial_rotary_factor"})
|
||||
|
||||
def test_rope_validation_with_per_attention_type_nested_rope(self):
|
||||
"""Mirrors `test_rope_validation` with `config.layer_types` set, so that
|
||||
`rope_parameters` takes the per-attention-type nested shape."""
|
||||
config = LlamaConfig()
|
||||
all_rope_types = ROPE_INIT_FUNCTIONS.keys()
|
||||
config.layer_types = ["full_attention", "sliding_attention"]
|
||||
|
||||
def nest(full_attention_params):
|
||||
return {
|
||||
"full_attention": full_attention_params,
|
||||
"sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
|
||||
}
|
||||
|
||||
# Each non-default RoPE type with only `rope_theta` should still raise
|
||||
# KeyError (missing required keys) when wrapped in the nested shape.
|
||||
for rope_type in all_rope_types:
|
||||
if rope_type in ("default", "proportional"):
|
||||
continue
|
||||
config.rope_parameters = nest({"rope_type": rope_type, "rope_theta": 10000.0})
|
||||
with self.assertRaises(KeyError):
|
||||
config.validate_rope()
|
||||
|
||||
# Parameters exclusive to a RoPE type should still raise when passed to
|
||||
# the wrong type while in the nested shape.
|
||||
valid_param_mapping = {
|
||||
"factor": ["linear", "dynamic", "yarn", "longrope"],
|
||||
"attention_factor": ["yarn", "longrope"],
|
||||
"beta_fast": ["yarn"],
|
||||
"beta_slow": ["yarn"],
|
||||
"short_factor": ["longrope"],
|
||||
"long_factor": ["longrope"],
|
||||
}
|
||||
for rope_type in all_rope_types:
|
||||
if rope_type in ("default", "proportional"):
|
||||
continue
|
||||
for param, valid_rope_types in valid_param_mapping.items():
|
||||
config.rope_parameters = nest({"rope_type": rope_type, "rope_theta": 10000.0, param: True})
|
||||
if rope_type in valid_rope_types:
|
||||
continue
|
||||
with self.assertRaises(KeyError):
|
||||
config.validate_rope()
|
||||
|
||||
# A complete yarn entry under the nested shape should validate cleanly.
|
||||
# Regression: previously the implicit-factor check inside the yarn
|
||||
# validator dereferenced `self.rope_parameters` (the full nested dict)
|
||||
# rather than its per-type `rope_parameters` argument.
|
||||
config.rope_parameters = nest(
|
||||
{
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": 2.0,
|
||||
"original_max_position_embeddings": int(config.max_position_embeddings / 2.0),
|
||||
}
|
||||
)
|
||||
config.validate_rope()
|
||||
|
||||
def test_default_rope_numerically(self):
|
||||
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
|
||||
# multiple RoPE strategies will fail.
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
|
||||
7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
|
||||
3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
|
||||
1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
|
||||
5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
|
||||
2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
|
||||
1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
|
||||
4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
|
||||
1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
|
||||
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for default RoPE
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_linear_rope_numerically(self):
|
||||
# This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
|
||||
# frequencies (= the inverse frequencies are scaled **inversely**)
|
||||
config = LlamaConfig()
|
||||
default_rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["linear"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_parameters = {"rope_type": "linear", "rope_theta": 10000.0, "factor": factor}
|
||||
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for linear RoPE
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
|
||||
|
||||
def test_dynamic_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.0931e-01, 6.5498e-01, 5.3008e-01, 4.2900e-01, 3.4720e-01,
|
||||
2.8099e-01, 2.2741e-01, 1.8404e-01, 1.4895e-01, 1.2055e-01, 9.7558e-02,
|
||||
7.8955e-02, 6.3899e-02, 5.1714e-02, 4.1853e-02, 3.3872e-02, 2.7413e-02,
|
||||
2.2185e-02, 1.7955e-02, 1.4531e-02, 1.1760e-02, 9.5176e-03, 7.7027e-03,
|
||||
6.2339e-03, 5.0451e-03, 4.0831e-03, 3.3045e-03, 2.6744e-03, 2.1644e-03,
|
||||
1.7517e-03, 1.4176e-03, 1.1473e-03, 9.2852e-04, 7.5146e-04, 6.0817e-04,
|
||||
4.9220e-04, 3.9834e-04, 3.2238e-04, 2.6091e-04, 2.1115e-04, 1.7089e-04,
|
||||
1.3830e-04, 1.1193e-04, 9.0585e-05, 7.3312e-05, 5.9332e-05, 4.8018e-05,
|
||||
3.8861e-05, 3.1451e-05, 2.5453e-05, 2.0600e-05, 1.6672e-05, 1.3492e-05,
|
||||
1.0920e-05, 8.8374e-06, 7.1522e-06, 5.7883e-06, 4.6845e-06, 3.7912e-06,
|
||||
3.0683e-06, 2.4832e-06, 2.0097e-06, 1.6265e-06
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
|
||||
# model's original training sequence length
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_parameters = {"rope_type": "dynamic", "rope_theta": 10000.0, "factor": factor}
|
||||
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for dynamic RoPE
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=torch.tensor(1, dtype=torch.int64))
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||
|
||||
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
|
||||
# will scale up (i.e., the inverse frequencies will scale down).
|
||||
factor = 10.0
|
||||
config.rope_parameters = {"rope_type": "dynamic", "rope_theta": 10000.0, "factor": factor}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=16384)
|
||||
with self.assertRaises(AssertionError): # It is NOT a linear factor
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_yarn_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.3479e-02,
|
||||
6.9590e-02, 5.7925e-02, 4.8136e-02, 3.9931e-02, 3.3061e-02, 2.7315e-02,
|
||||
2.2515e-02, 1.8512e-02, 1.5177e-02, 1.2403e-02, 1.0101e-02, 8.1924e-03,
|
||||
6.6143e-03, 5.3120e-03, 4.2400e-03, 3.3599e-03, 2.6396e-03, 2.0520e-03,
|
||||
1.5746e-03, 1.1882e-03, 8.7713e-04, 6.2810e-04, 4.3007e-04, 2.7384e-04,
|
||||
2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
|
||||
1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
|
||||
4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
|
||||
1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
|
||||
# `0.1 * math.log(factor) + 1.0`
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["yarn"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_parameters = {"rope_type": "yarn", "rope_theta": 10000.0, "factor": factor}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 0.1 * math.log(factor) + 1.0)
|
||||
|
||||
config.rope_parameters = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"attention_factor": 0.5,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||
self.assertEqual(attention_scale, 0.5)
|
||||
|
||||
# Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
|
||||
# Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
|
||||
# `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
|
||||
# (note: adds a margin to the test for numerical stability)
|
||||
factor = 10.0
|
||||
margin = 1e-8
|
||||
config.rope_parameters = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_bounded_by_factor = [
|
||||
((default_inv_freq[idx] / factor) - margin) <= yarn_inv_freq_value <= (default_inv_freq[idx] + margin)
|
||||
for idx, yarn_inv_freq_value in enumerate(inv_freq)
|
||||
]
|
||||
self.assertTrue(all(is_bounded_by_factor))
|
||||
|
||||
# super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
|
||||
# values (empirically checked for `beta_fast` = 1000) should be very small to linear scaling
|
||||
config.rope_parameters = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"beta_fast": 1000,
|
||||
"beta_slow": 1,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_interpolating = [
|
||||
yarn_inv_freq_value < (default_inv_freq[idx] + margin) for idx, yarn_inv_freq_value in enumerate(inv_freq)
|
||||
]
|
||||
self.assertFalse(is_interpolating[0])
|
||||
self.assertTrue(all(is_interpolating[1:]))
|
||||
torch.testing.assert_close(inv_freq[-20:], default_inv_freq[-20:] / factor)
|
||||
|
||||
# Check 3: numerical snapshot to avoid regressions
|
||||
config.rope_parameters = {
|
||||
"rope_type": "yarn",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_longrope_rope_numerically(self):
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on the seq_len
|
||||
dim = config.hidden_size // config.num_attention_heads
|
||||
short_factor = [2.0] * (dim // 2) # scaling applied when seq_len <= max_position_embeddings
|
||||
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when seq_len > max_position_embeddings
|
||||
|
||||
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
|
||||
# `math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))`
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_parameters = {
|
||||
"rope_type": "longrope",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(
|
||||
attention_scale, math.sqrt(1 + math.log(factor) / math.log(config.max_position_embeddings))
|
||||
)
|
||||
|
||||
config.rope_parameters = {
|
||||
"rope_type": "longrope",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
"attention_factor": 0.5,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||
self.assertEqual(attention_scale, 0.5)
|
||||
|
||||
config.rope_parameters = {
|
||||
"rope_type": "longrope",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
self.assertEqual(config.rope_parameters.get("attention_factor"), None)
|
||||
# Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
|
||||
config.standardize_rope_params()
|
||||
config.validate_rope()
|
||||
|
||||
# Check 2: seq_len == 0 -> short factor is applied to the default frequencies
|
||||
config.rope_parameters = {
|
||||
"rope_type": "longrope",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": 1.0,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=0)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))
|
||||
|
||||
# Check 3: seq_len > max_position_embeddings -> long factor is applied to the default frequencies
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=config.max_position_embeddings + 1)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
|
||||
|
||||
def test_llama3_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
|
||||
7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
|
||||
3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
|
||||
1.3335e-02, 1.0730e-02, 7.7785e-03, 5.6009e-03, 3.9991e-03, 2.8248e-03,
|
||||
1.9675e-03, 1.3449e-03, 8.9549e-04, 5.7363e-04, 3.4539e-04, 2.7384e-04,
|
||||
2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
|
||||
1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
|
||||
4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
|
||||
1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: `attention_factor` is always 1
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_parameters = {
|
||||
"rope_type": "llama3",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 4,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0)
|
||||
|
||||
# Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
|
||||
# `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequencies see no change, medium
|
||||
# frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
|
||||
# is considered low, medium, and high frequencies.
|
||||
factor = 10.0
|
||||
config.rope_parameters = {
|
||||
"rope_type": "llama3",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 4,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_bounded_by_factor = [
|
||||
(default_inv_freq[idx] / factor) <= llama3_inv_freq_value <= default_inv_freq[idx]
|
||||
for idx, llama3_inv_freq_value in enumerate(inv_freq)
|
||||
]
|
||||
self.assertTrue(all(is_bounded_by_factor))
|
||||
|
||||
# if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
|
||||
# scaled
|
||||
config.rope_parameters = config.rope_parameters = {
|
||||
"rope_type": "llama3",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 1000,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_scaled = [yarn_inv_freq_value < default_inv_freq[idx] for idx, yarn_inv_freq_value in enumerate(inv_freq)]
|
||||
self.assertTrue(all(is_scaled))
|
||||
|
||||
# Check 3: numerical snapshot to avoid regressions
|
||||
config.rope_parameters = {
|
||||
"rope_type": "llama3",
|
||||
"rope_theta": 10000.0,
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 4,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_proportional_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
|
||||
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_parameters, {"rope_type": "default", "rope_theta": 10000.0})
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
head_dim = config.hidden_size // config.num_attention_heads # 128
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["proportional"]
|
||||
default_rope_fn = LlamaRotaryEmbedding.compute_default_rope_parameters
|
||||
|
||||
# Check 1: `attention_factor` is always 1.0, regardless of parameters
|
||||
for partial_rotary_factor in (1.0, 0.5, 0.25):
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": partial_rotary_factor,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0)
|
||||
|
||||
# Check 2: output shape is always head_dim // 2, regardless of partial_rotary_factor
|
||||
for partial_rotary_factor in (1.0, 0.5, 0.25):
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": partial_rotary_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(inv_freq.shape[0], head_dim // 2)
|
||||
|
||||
# Check 3: zero-padding behavior — when partial_rotary_factor < 1.0, the last (head_dim // 2 - rope_angles)
|
||||
# entries must be exactly zero, and the first rope_angles entries must be non-zero
|
||||
for partial_rotary_factor, expected_rope_angles in ((0.5, 32), (0.25, 16)):
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": partial_rotary_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# First rope_angles entries should be non-zero (rotated frequencies)
|
||||
self.assertTrue(torch.all(inv_freq[:expected_rope_angles] != 0))
|
||||
# Remaining entries should be exactly zero (NoPE angles)
|
||||
expected_nope_angles = head_dim // 2 - expected_rope_angles
|
||||
torch.testing.assert_close(
|
||||
inv_freq[expected_rope_angles:],
|
||||
torch.zeros(expected_nope_angles, device=torch_device),
|
||||
)
|
||||
|
||||
# When partial_rotary_factor = 1.0, no entries should be zero
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": 1.0,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
self.assertTrue(torch.all(inv_freq != 0))
|
||||
|
||||
# Check 4: factor scaling equivalences with default and linear RoPE
|
||||
# 4a: With partial_rotary_factor=1.0 and factor=1.0, proportional RoPE == default RoPE
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": 1.0,
|
||||
"factor": 1.0,
|
||||
}
|
||||
inv_freq_prop, _ = rope_fn(config=config, device=torch_device)
|
||||
config.rope_parameters = {"rope_type": "default", "rope_theta": 10000.0}
|
||||
default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq_prop, default_inv_freq)
|
||||
|
||||
# 4b: With partial_rotary_factor=1.0 and factor=2.0, proportional RoPE == linear RoPE
|
||||
linear_rope_fn = ROPE_INIT_FUNCTIONS["linear"]
|
||||
for factor in (2.0, 10.0):
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": 1.0,
|
||||
"factor": factor,
|
||||
}
|
||||
inv_freq_prop, _ = rope_fn(config=config, device=torch_device)
|
||||
config.rope_parameters = {"rope_type": "linear", "rope_theta": 10000.0, "factor": factor}
|
||||
inv_freq_linear, _ = linear_rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq_prop, inv_freq_linear)
|
||||
|
||||
# 4c: With partial_rotary_factor=0.5 and factor=2.0, the non-zero portion should be the rotated subspace
|
||||
# frequencies divided by factor
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": 0.5,
|
||||
"factor": 2.0,
|
||||
}
|
||||
inv_freq_scaled, _ = rope_fn(config=config, device=torch_device)
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": 0.5,
|
||||
"factor": 1.0,
|
||||
}
|
||||
inv_freq_unscaled, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq_scaled, inv_freq_unscaled / 2.0)
|
||||
|
||||
# Check 5: numerical snapshot to avoid regressions (partial_rotary_factor=0.25, factor=1.0)
|
||||
config.rope_parameters = {
|
||||
"rope_type": "proportional",
|
||||
"rope_theta": 10000.0,
|
||||
"partial_rotary_factor": 0.25,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
3569
tests/utils/test_modeling_utils.py
Normal file
3569
tests/utils/test_modeling_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
82
tests/utils/test_network_logging.py
Normal file
82
tests/utils/test_network_logging.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
|
||||
import httpx
|
||||
|
||||
from transformers.utils.network_logging import (
|
||||
_clear_network_debug_report,
|
||||
_disable_network_debug_report,
|
||||
_enable_network_debug_report,
|
||||
_format_network_debug_report,
|
||||
_get_network_debug_report,
|
||||
)
|
||||
|
||||
|
||||
class _SlowHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
time.sleep(0.01)
|
||||
response = b"ok"
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Length", str(len(response)))
|
||||
self.end_headers()
|
||||
self.wfile.write(response)
|
||||
|
||||
def log_message(self, format, *args):
|
||||
return
|
||||
|
||||
|
||||
class NetworkLoggingTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._server = ThreadingHTTPServer(("127.0.0.1", 0), _SlowHandler)
|
||||
cls._thread = threading.Thread(target=cls._server.serve_forever, daemon=True)
|
||||
cls._thread.start()
|
||||
cls._base_url = f"http://127.0.0.1:{cls._server.server_port}"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._server.shutdown()
|
||||
cls._thread.join()
|
||||
cls._server.server_close()
|
||||
|
||||
def tearDown(self):
|
||||
_disable_network_debug_report()
|
||||
|
||||
def test_network_debug_report_records_httpx_requests(self):
|
||||
_enable_network_debug_report()
|
||||
_clear_network_debug_report()
|
||||
|
||||
response = httpx.get(f"{self._base_url}/slow")
|
||||
self.assertEqual(response.text, "ok")
|
||||
|
||||
report = _get_network_debug_report()
|
||||
matching_requests = [request for request in report["requests"] if request["url"].endswith("/slow")]
|
||||
self.assertEqual(len(matching_requests), 1)
|
||||
|
||||
request = matching_requests[0]
|
||||
self.assertEqual(request["method"], "GET")
|
||||
self.assertEqual(request["status_code"], 200)
|
||||
self.assertEqual(request["path"], "/slow")
|
||||
self.assertGreater(request["total_ms"], 0)
|
||||
self.assertIn("receive_response_headers", request["phases_ms"])
|
||||
|
||||
summary = _format_network_debug_report()
|
||||
self.assertIn("Network debug report", summary)
|
||||
self.assertIn("Slowest requests:", summary)
|
||||
self.assertIn("/slow", summary)
|
||||
220
tests/utils/test_offline.py
Normal file
220
tests/utils/test_offline.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||
from transformers.testing_utils import TestCasePlus, require_torch
|
||||
|
||||
|
||||
class OfflineTests(TestCasePlus):
|
||||
@require_torch
|
||||
@unittest.skip("This test is failing on main") # TODO matt/ydshieh, this test needs to be fixed
|
||||
def test_offline_mode(self):
|
||||
# this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
|
||||
# `transformers` is loaded, and it's too late for inside pytest - so we are changing it
|
||||
# while running an external program
|
||||
|
||||
# python one-liner segments
|
||||
|
||||
# this must be loaded before socket.socket is monkey-patched
|
||||
load = """
|
||||
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||
"""
|
||||
|
||||
run = """
|
||||
mname = "hf-internal-testing/tiny-random-bert"
|
||||
BertConfig.from_pretrained(mname)
|
||||
BertModel.from_pretrained(mname)
|
||||
BertTokenizer.from_pretrained(mname)
|
||||
pipe = pipeline(task="fill-mask", model=mname)
|
||||
print("success")
|
||||
"""
|
||||
|
||||
mock = """
|
||||
import socket
|
||||
def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
|
||||
socket.socket = offline_socket
|
||||
"""
|
||||
|
||||
# Force fetching the files so that we can use the cache
|
||||
mname = "hf-internal-testing/tiny-random-bert"
|
||||
BertConfig.from_pretrained(mname)
|
||||
BertModel.from_pretrained(mname)
|
||||
BertTokenizer.from_pretrained(mname)
|
||||
pipeline(task="fill-mask", model=mname)
|
||||
|
||||
# baseline - just load from_pretrained with normal network
|
||||
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
|
||||
stdout, _ = self._execute_with_env(load, run, mock, TRANSFORMERS_OFFLINE="1")
|
||||
self.assertIn("success", stdout)
|
||||
|
||||
@require_torch
|
||||
def test_offline_mode_no_internet(self):
|
||||
# python one-liner segments
|
||||
# this must be loaded before socket.socket is monkey-patched
|
||||
load = """
|
||||
from transformers import BertConfig, BertModel, BertTokenizer, pipeline
|
||||
"""
|
||||
|
||||
run = """
|
||||
mname = "hf-internal-testing/tiny-random-bert"
|
||||
BertConfig.from_pretrained(mname)
|
||||
BertModel.from_pretrained(mname)
|
||||
BertTokenizer.from_pretrained(mname)
|
||||
pipe = pipeline(task="fill-mask", model=mname)
|
||||
print("success")
|
||||
"""
|
||||
|
||||
mock = """
|
||||
import socket
|
||||
def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
|
||||
socket.socket = offline_socket
|
||||
"""
|
||||
|
||||
# Force fetching the files so that we can use the cache
|
||||
mname = "hf-internal-testing/tiny-random-bert"
|
||||
BertConfig.from_pretrained(mname)
|
||||
BertModel.from_pretrained(mname)
|
||||
BertTokenizer.from_pretrained(mname)
|
||||
pipeline(task="fill-mask", model=mname)
|
||||
|
||||
# baseline - just load from_pretrained with normal network
|
||||
# should succeed
|
||||
stdout, _ = self._execute_with_env(load, run, mock)
|
||||
self.assertIn("success", stdout)
|
||||
|
||||
@require_torch
|
||||
def test_offline_mode_sharded_checkpoint(self):
|
||||
# this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
|
||||
# `transformers` is loaded, and it's too late for inside pytest - so we are changing it
|
||||
# while running an external program
|
||||
|
||||
# python one-liner segments
|
||||
|
||||
# this must be loaded before socket.socket is monkey-patched
|
||||
load = """
|
||||
from transformers import BertConfig, BertModel, BertTokenizer
|
||||
"""
|
||||
|
||||
run = """
|
||||
mname = "hf-internal-testing/tiny-random-bert-sharded"
|
||||
BertConfig.from_pretrained(mname)
|
||||
BertModel.from_pretrained(mname)
|
||||
print("success")
|
||||
"""
|
||||
|
||||
mock = """
|
||||
import socket
|
||||
def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
|
||||
socket.socket = offline_socket
|
||||
"""
|
||||
|
||||
# baseline - just load from_pretrained with normal network
|
||||
# should succeed
|
||||
stdout, _ = self._execute_with_env(load, run)
|
||||
self.assertIn("success", stdout)
|
||||
|
||||
# next emulate no network
|
||||
# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
|
||||
# self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="0")
|
||||
|
||||
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
|
||||
stdout, _ = self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1")
|
||||
self.assertIn("success", stdout)
|
||||
|
||||
@require_torch
|
||||
def test_offline_mode_pipeline_exception(self):
|
||||
load = """
|
||||
from transformers import pipeline
|
||||
"""
|
||||
run = """
|
||||
mname = "hf-internal-testing/tiny-random-bert"
|
||||
pipe = pipeline(model=mname)
|
||||
"""
|
||||
|
||||
mock = """
|
||||
import socket
|
||||
def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
|
||||
socket.socket = offline_socket
|
||||
"""
|
||||
|
||||
_, stderr = self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="1")
|
||||
self.assertIn(
|
||||
"You cannot infer task automatically within `pipeline` when using offline mode",
|
||||
stderr.replace("\n", ""),
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_offline_model_dynamic_model(self):
|
||||
load = """
|
||||
from transformers import AutoModel
|
||||
"""
|
||||
run = """
|
||||
mname = "hf-internal-testing/test_dynamic_model"
|
||||
AutoModel.from_pretrained(mname, trust_remote_code=True)
|
||||
print("success")
|
||||
"""
|
||||
|
||||
# baseline - just load from_pretrained with normal network
|
||||
# should succeed
|
||||
stdout, _ = self._execute_with_env(load, run)
|
||||
self.assertIn("success", stdout)
|
||||
|
||||
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
|
||||
stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
|
||||
self.assertIn("success", stdout)
|
||||
|
||||
def test_is_offline_mode(self):
|
||||
"""
|
||||
Test `is_offline_mode` helper (should respect both HF_HUB_OFFLINE and legacy TRANSFORMERS_OFFLINE env vars)
|
||||
"""
|
||||
load = "from huggingface_hub import is_offline_mode"
|
||||
run = "print(is_offline_mode())"
|
||||
|
||||
stdout, _ = self._execute_with_env(load, run)
|
||||
self.assertIn("False", stdout)
|
||||
|
||||
stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
|
||||
self.assertIn("True", stdout)
|
||||
|
||||
stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1")
|
||||
self.assertIn("True", stdout)
|
||||
|
||||
def _execute_with_env(self, *commands: tuple[str, ...], should_fail: bool = False, **env) -> tuple[str, str]:
|
||||
"""Execute Python code with a given environment and return the stdout/stderr as strings.
|
||||
|
||||
If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed.
|
||||
Environment variables can be passed as keyword arguments.
|
||||
"""
|
||||
# Build command
|
||||
cmd = [sys.executable, "-c", "\n".join(commands)]
|
||||
|
||||
# Configure env
|
||||
new_env = self.get_env()
|
||||
new_env.update(env)
|
||||
|
||||
# Run command
|
||||
result = subprocess.run(cmd, env=new_env, check=False, capture_output=True)
|
||||
|
||||
# Check execution
|
||||
if should_fail:
|
||||
self.assertNotEqual(result.returncode, 0, result.stderr)
|
||||
else:
|
||||
self.assertEqual(result.returncode, 0, result.stderr)
|
||||
|
||||
# Return output
|
||||
return result.stdout.decode(), result.stderr.decode()
|
||||
124
tests/utils/test_skip_decorators.py
Normal file
124
tests/utils/test_skip_decorators.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright 2019-present, the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
#
|
||||
#
|
||||
# this test validates that we can stack skip decorators in groups and whether
|
||||
# they work correctly with other decorators
|
||||
#
|
||||
# since the decorators have already built their decision params (like checking
|
||||
# env[], we can't mock the env and test each of the combinations), so ideally
|
||||
# the following 4 should be run. But since we have different CI jobs running
|
||||
# different configs, all combinations should get covered
|
||||
#
|
||||
# RUN_SLOW=1 pytest -rA tests/test_skip_decorators.py
|
||||
# RUN_SLOW=1 CUDA_VISIBLE_DEVICES="" pytest -rA tests/test_skip_decorators.py
|
||||
# RUN_SLOW=0 pytest -rA tests/test_skip_decorators.py
|
||||
# RUN_SLOW=0 CUDA_VISIBLE_DEVICES="" pytest -rA tests/test_skip_decorators.py
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
# skipping in unittest tests
|
||||
|
||||
params = [(1,)]
|
||||
|
||||
|
||||
# test that we can stack our skip decorators with 3rd party decorators
|
||||
def check_slow():
|
||||
run_slow = bool(os.getenv("RUN_SLOW", "0"))
|
||||
if run_slow:
|
||||
assert True
|
||||
else:
|
||||
assert False, "should have been skipped"
|
||||
|
||||
|
||||
# test that we can stack our skip decorators
|
||||
def check_slow_torch_cuda():
|
||||
run_slow = bool(os.getenv("RUN_SLOW", "0"))
|
||||
if run_slow and torch_device == "cuda":
|
||||
assert True
|
||||
else:
|
||||
assert False, "should have been skipped"
|
||||
|
||||
|
||||
def check_slow_torch_accelerator():
|
||||
run_slow = bool(os.getenv("RUN_SLOW", "0"))
|
||||
assert run_slow and torch_device in ["cuda", "xpu"], "should have been skipped"
|
||||
|
||||
|
||||
@require_torch
|
||||
class SkipTester(unittest.TestCase):
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
def test_2_skips_slow_first(self):
|
||||
check_slow_torch_accelerator()
|
||||
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
def test_2_skips_slow_last(self):
|
||||
check_slow_torch_accelerator()
|
||||
|
||||
# The combination of any skip decorator, followed by parameterized fails to skip the tests
|
||||
# 1. @slow manages to correctly skip `test_param_slow_first`
|
||||
# 2. but then `parameterized` creates new tests, with a unique name for each parameter groups.
|
||||
# It has no idea that they are to be skipped and so they all run, ignoring @slow
|
||||
# Therefore skip decorators must come after `parameterized`
|
||||
#
|
||||
# @slow
|
||||
# @parameterized.expand(params)
|
||||
# def test_param_slow_first(self, param=None):
|
||||
# check_slow()
|
||||
|
||||
# This works as expected:
|
||||
# 1. `parameterized` creates new tests with unique names
|
||||
# 2. each of them gets an opportunity to be skipped
|
||||
@parameterized.expand(params)
|
||||
@slow
|
||||
def test_param_slow_last(self, param=None):
|
||||
check_slow()
|
||||
|
||||
|
||||
# skipping in non-unittest tests
|
||||
# no problem at all here
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
def test_pytest_2_skips_slow_first():
|
||||
check_slow_torch_accelerator()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
def test_pytest_2_skips_slow_last():
|
||||
check_slow_torch_accelerator()
|
||||
|
||||
|
||||
@slow
|
||||
@pytest.mark.parametrize("param", [1])
|
||||
def test_pytest_param_slow_first(param):
|
||||
check_slow()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("param", [1])
|
||||
@slow
|
||||
def test_pytest_param_slow_last(param):
|
||||
check_slow()
|
||||
308
tests/utils/test_tokenization_utils.py
Normal file
308
tests/utils/test_tokenization_utils.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast, GPT2TokenizerFast, is_tokenizers_available
|
||||
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_tokenizers
|
||||
from transformers.tokenization_python import ExtensionsTrie, Trie
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
from test_module.custom_tokenization_fast import CustomTokenizerFast
|
||||
|
||||
|
||||
class TokenizerUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"failed", request=mock.Mock(), response=mock.Mock()
|
||||
)
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
||||
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
|
||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
@require_tokenizers
|
||||
def test_cached_files_are_used_when_internet_is_down_missing_files(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"failed", request=mock.Mock(), response=mock.Mock()
|
||||
)
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
||||
with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
|
||||
_ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
def test_push_to_hub_chat_templates(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
tokenizer.chat_template = "test template"
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
|
||||
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
|
||||
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
|
||||
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
tokenizer.chat_template = {"default": "a", "secondary": "b"}
|
||||
tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
|
||||
reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
|
||||
self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
|
||||
|
||||
def test_push_to_hub_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
def test_push_to_hub_in_organization_via_save_pretrained(self):
|
||||
with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
|
||||
|
||||
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
@require_tokenizers
|
||||
def test_push_to_hub_dynamic_tokenizer(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
CustomTokenizer.register_for_auto_class()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = CustomTokenizer(vocab_file)
|
||||
|
||||
# No fast custom tokenizer
|
||||
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
|
||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
||||
|
||||
@require_tokenizers
|
||||
def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
|
||||
with TemporaryHubRepo(token=self._token) as tmp_repo:
|
||||
CustomTokenizer.register_for_auto_class()
|
||||
|
||||
# Fast and slow custom tokenizer
|
||||
CustomTokenizerFast.register_for_auto_class()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
|
||||
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
|
||||
bert_tokenizer.save_pretrained(tmp_dir)
|
||||
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
|
||||
|
||||
tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
|
||||
tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, use_fast=False, trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
class TokenizersBackendTest(unittest.TestCase):
|
||||
def test_clean_up_tokenization_spaces(self):
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# GPT-2 is a BPE tokenizer — clean_up_tokenization is skipped because it
|
||||
# was designed for WordPiece and is destructive for BPE (strips legitimate
|
||||
# spaces before punctuation).
|
||||
# Use text with spaces before punctuation that cleanup would strip if applied.
|
||||
text = "x != y"
|
||||
token_ids = tokenizer.encode(text)
|
||||
|
||||
decoded_no_cleanup = tokenizer.decode(token_ids, clean_up_tokenization_spaces=False)
|
||||
self.assertEqual(decoded_no_cleanup, text)
|
||||
|
||||
# With BPE guard, cleanup=True also preserves the text
|
||||
decoded_with_cleanup = tokenizer.decode(token_ids, clean_up_tokenization_spaces=True)
|
||||
self.assertEqual(decoded_with_cleanup, text)
|
||||
|
||||
|
||||
class TrieTest(unittest.TestCase):
|
||||
def test_trie(self):
|
||||
trie = Trie()
|
||||
trie.add("Hello 友達")
|
||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
|
||||
trie.add("Hello")
|
||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
|
||||
|
||||
def test_trie_split(self):
|
||||
trie = Trie()
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
|
||||
trie.add("[CLS]")
|
||||
trie.add("extra_id_1")
|
||||
trie.add("extra_id_100")
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||
|
||||
def test_trie_single(self):
|
||||
trie = Trie()
|
||||
trie.add("A")
|
||||
self.assertEqual(trie.split("ABC"), ["A", "BC"])
|
||||
self.assertEqual(trie.split("BCA"), ["BC", "A"])
|
||||
|
||||
def test_trie_final(self):
|
||||
trie = Trie()
|
||||
trie.add("TOKEN]")
|
||||
trie.add("[SPECIAL_TOKEN]")
|
||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
def test_trie_subtokens(self):
|
||||
trie = Trie()
|
||||
trie.add("A")
|
||||
trie.add("P")
|
||||
trie.add("[SPECIAL_TOKEN]")
|
||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
def test_trie_suffix_tokens(self):
|
||||
trie = Trie()
|
||||
trie.add("AB")
|
||||
trie.add("B")
|
||||
trie.add("C")
|
||||
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
||||
|
||||
def test_trie_skip(self):
|
||||
trie = Trie()
|
||||
trie.add("ABC")
|
||||
trie.add("B")
|
||||
trie.add("CD")
|
||||
self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
|
||||
|
||||
def test_cut_text_hardening(self):
|
||||
# Even if the offsets are wrong, we necessarily output correct string
|
||||
# parts.
|
||||
trie = Trie()
|
||||
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||
self.assertEqual(parts, ["AB", "C"])
|
||||
|
||||
|
||||
class ExtensionsTrieTest(unittest.TestCase):
|
||||
def test_extensions(self):
|
||||
# Test searching by prefix
|
||||
trie = ExtensionsTrie()
|
||||
trie.add("foo")
|
||||
trie.add("food")
|
||||
trie.add("foodie")
|
||||
trie.add("helium")
|
||||
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
|
||||
self.assertEqual(trie.extensions("helium"), ["helium"])
|
||||
|
||||
def test_empty_prefix(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching with an empty prefix returns all values
|
||||
trie.add("hello")
|
||||
trie.add("bye")
|
||||
self.assertEqual(trie.extensions(""), ["hello", "bye"])
|
||||
|
||||
def test_no_extension_match(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching for a prefix that doesn't match any key
|
||||
values = trie.extensions("unknown")
|
||||
|
||||
self.assertEqual(len(values), 0)
|
||||
|
||||
def test_update_value(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test updating the value of an existing key
|
||||
trie.add("hi")
|
||||
trie.add("hi")
|
||||
self.assertEqual(trie.extensions("hi"), ["hi"])
|
||||
97
tests/utils/test_versions_utils.py
Normal file
97
tests/utils/test_versions_utils.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib.metadata
|
||||
import sys
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
from transformers.utils.versions import require_version, require_version_core
|
||||
|
||||
|
||||
numpy_ver = importlib.metadata.version("numpy")
|
||||
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||
|
||||
|
||||
class DependencyVersionCheckTest(TestCasePlus):
|
||||
def test_core(self):
|
||||
# lt + different version strings
|
||||
require_version_core("numpy<1000.4.5")
|
||||
require_version_core("numpy<1000.4")
|
||||
require_version_core("numpy<1000")
|
||||
|
||||
# le
|
||||
require_version_core("numpy<=1000.4.5")
|
||||
require_version_core(f"numpy<={numpy_ver}")
|
||||
|
||||
# eq
|
||||
require_version_core(f"numpy=={numpy_ver}")
|
||||
|
||||
# ne
|
||||
require_version_core("numpy!=1000.4.5")
|
||||
|
||||
# ge
|
||||
require_version_core("numpy>=1.0")
|
||||
require_version_core("numpy>=1.0.0")
|
||||
require_version_core(f"numpy>={numpy_ver}")
|
||||
|
||||
# gt
|
||||
require_version_core("numpy>1.0.0")
|
||||
|
||||
# mix
|
||||
require_version_core("numpy>1.0.0,<1000")
|
||||
|
||||
# requirement w/o version
|
||||
require_version_core("numpy")
|
||||
|
||||
# unmet requirements due to version conflict
|
||||
for req in ["numpy==1.0.0", "numpy>=1000.0.0", f"numpy<{numpy_ver}"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except ImportError as e:
|
||||
self.assertIn(f"{req} is required", str(e))
|
||||
self.assertIn("but found", str(e))
|
||||
|
||||
# unmet requirements due to missing module
|
||||
for req in ["numpipypie>1", "numpipypie2"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except importlib.metadata.PackageNotFoundError as e:
|
||||
self.assertIn(f"The '{req}' distribution was not found and is required by this application", str(e))
|
||||
self.assertIn("Try: `pip install transformers -U`", str(e))
|
||||
|
||||
# bogus requirements formats:
|
||||
# 1. whole thing
|
||||
for req in ["numpy??1.0.0", "numpy1.0.0"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except ValueError as e:
|
||||
self.assertIn("requirement needs to be in the pip package format", str(e))
|
||||
# 2. only operators
|
||||
for req in ["numpy=1.0.0", "numpy == 1.00", "numpy<>1.0.0", "numpy><1.00", "numpy>>1.0.0"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except ValueError as e:
|
||||
self.assertIn("need one of ", str(e))
|
||||
|
||||
def test_python(self):
|
||||
# matching requirement
|
||||
require_version("python>=3.9.0")
|
||||
|
||||
# not matching requirements
|
||||
for req in ["python>9.9.9", "python<3.0.0"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except ImportError as e:
|
||||
self.assertIn(f"{req} is required", str(e))
|
||||
self.assertIn(f"but found python=={python_ver}", str(e))
|
||||
365
tests/utils/test_video_utils.py
Normal file
365
tests/utils/test_video_utils.py
Normal file
@@ -0,0 +1,365 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import is_torch_available, is_vision_available
|
||||
from transformers.image_processing_utils import get_size_dict
|
||||
from transformers.image_utils import SizeDict
|
||||
from transformers.processing_utils import VideosKwargs
|
||||
from transformers.testing_utils import (
|
||||
require_av,
|
||||
require_cv2,
|
||||
require_decord,
|
||||
require_torch,
|
||||
require_torchcodec,
|
||||
require_torchvision,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
from transformers import BaseVideoProcessor
|
||||
from transformers.video_utils import VideoMetadata, load_video
|
||||
|
||||
|
||||
def get_random_video(height, width, num_frames=8, return_torch=False):
|
||||
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
video = np.array([random_frame] * num_frames)
|
||||
if return_torch:
|
||||
# move channel first
|
||||
return torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||
return video
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
class BaseVideoProcessorTester(unittest.TestCase):
|
||||
"""
|
||||
Tests that the `transforms` can be applied to a 4-dim array directly, i.e. to a whole video.
|
||||
"""
|
||||
|
||||
def test_make_batched_videos_pil(self):
|
||||
# Test a single image is converted to a list of 1 video with 1 frame
|
||||
video = get_random_video(16, 32)
|
||||
pil_image = PIL.Image.fromarray(video[0])
|
||||
videos_list = make_batched_videos(pil_image)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0][0], np.array(pil_image)))
|
||||
|
||||
# Test a list of videos is converted to a list of 1 video
|
||||
video = get_random_video(16, 32)
|
||||
pil_video = [PIL.Image.fromarray(frame) for frame in video]
|
||||
videos_list = make_batched_videos(pil_video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
# Test a nested list of videos is not modified
|
||||
video = get_random_video(16, 32)
|
||||
pil_video = [PIL.Image.fromarray(frame) for frame in video]
|
||||
videos = [pil_video, pil_video]
|
||||
videos_list = make_batched_videos(videos)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
def test_make_batched_videos_numpy(self):
|
||||
# Test a single image is converted to a list of 1 video with 1 frame
|
||||
video = get_random_video(16, 32)[0]
|
||||
videos_list = make_batched_videos(video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0][0], video))
|
||||
|
||||
# Test a 4d array of videos is converted to a list of 1 video
|
||||
video = get_random_video(16, 32)
|
||||
videos_list = make_batched_videos(video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertTrue(len(videos_list), 1)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
# Test a 5d array of batch videos is converted to a list of videos
|
||||
video = video[None, ...].repeat(4, 0)
|
||||
videos_list = make_batched_videos(video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertTrue(len(videos_list), 4)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video[0]))
|
||||
|
||||
# Test a list of videos is converted to a list of videos
|
||||
video = get_random_video(16, 32)
|
||||
videos = [video, video]
|
||||
videos_list = make_batched_videos(videos)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], np.ndarray)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
@require_torch
|
||||
def test_make_batched_videos_torch(self):
|
||||
# Test a single image is converted to a list of 1 video with 1 frame
|
||||
video = get_random_video(16, 32)[0]
|
||||
torch_video = torch.from_numpy(video)
|
||||
videos_list = make_batched_videos(torch_video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], torch.Tensor)
|
||||
self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0][0], video))
|
||||
|
||||
# Test a 4d array of videos is converted to a list of 1 video
|
||||
video = get_random_video(16, 32)
|
||||
torch_video = torch.from_numpy(video)
|
||||
videos_list = make_batched_videos(torch_video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertTrue(len(videos_list), 1)
|
||||
self.assertIsInstance(videos_list[0], torch.Tensor)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
# Test a 5d array of batch videos is converted to a list of videos
|
||||
torch_video = torch_video[None, ...].repeat(4, 1, 1, 1, 1)
|
||||
videos_list = make_batched_videos(torch_video)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertTrue(len(videos_list), 4)
|
||||
self.assertIsInstance(videos_list[0], torch.Tensor)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
# Test a list of videos is converted to a list of videos
|
||||
video = get_random_video(16, 32)
|
||||
torch_video = torch.from_numpy(video)
|
||||
videos = [torch_video, torch_video]
|
||||
videos_list = make_batched_videos(videos)
|
||||
self.assertIsInstance(videos_list, list)
|
||||
self.assertIsInstance(videos_list[0], torch.Tensor)
|
||||
self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
|
||||
self.assertTrue(np.array_equal(videos_list[0], video))
|
||||
|
||||
def test_resize(self):
|
||||
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
|
||||
video = get_random_video(16, 32, return_torch=True)
|
||||
|
||||
# Size can be an int or a tuple of ints.
|
||||
size_dict = SizeDict(**get_size_dict((8, 8), param_name="size"))
|
||||
resized_video = video_processor.resize(video, size=size_dict)
|
||||
self.assertIsInstance(resized_video, torch.Tensor)
|
||||
self.assertEqual(resized_video.shape, (8, 3, 8, 8))
|
||||
|
||||
def test_normalize(self):
|
||||
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
|
||||
array = torch.randn(4, 3, 16, 32)
|
||||
mean = [0.1, 0.5, 0.9]
|
||||
std = [0.2, 0.4, 0.6]
|
||||
|
||||
# mean and std can be passed as lists or NumPy arrays.
|
||||
expected = (array - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
|
||||
normalized_array = video_processor.normalize(array, mean, std)
|
||||
torch.testing.assert_close(normalized_array, expected)
|
||||
|
||||
def test_center_crop(self):
|
||||
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
|
||||
video = get_random_video(16, 32, return_torch=True)
|
||||
|
||||
# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
|
||||
crop_sizes = [8, (8, 64), 20, (32, 64)]
|
||||
for size in crop_sizes:
|
||||
size_dict = SizeDict(**get_size_dict(size, default_to_square=True, param_name="crop_size"))
|
||||
cropped_video = video_processor.center_crop(video, size_dict)
|
||||
self.assertIsInstance(cropped_video, torch.Tensor)
|
||||
|
||||
expected_size = (size, size) if isinstance(size, int) else size
|
||||
self.assertEqual(cropped_video.shape, (8, 3, *expected_size))
|
||||
|
||||
def test_convert_to_rgb(self):
|
||||
video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
|
||||
video = get_random_video(20, 20, return_torch=True)
|
||||
|
||||
rgb_video = video_processor.convert_to_rgb(video[:, :1])
|
||||
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
|
||||
|
||||
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
|
||||
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
|
||||
|
||||
def test_group_and_reorder_videos(self):
|
||||
"""Tests that videos can be grouped by frame size and number of frames"""
|
||||
video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
|
||||
video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
|
||||
|
||||
# Group two videos of same size but different number of frames
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
|
||||
self.assertEqual(len(grouped_videos), 2)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 2)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group two videos of different size but same number of frames
|
||||
video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
|
||||
self.assertEqual(len(grouped_videos), 2)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 2)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group all three videos where some have same size or same frame count
|
||||
# But since none have frames and sizes identical, we'll have 3 groups
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
|
||||
self.assertEqual(len(grouped_videos), 3)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 3)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group if we had some videos with identical shapes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
|
||||
self.assertEqual(len(grouped_videos), 2)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 2)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group if we had all videos with identical shapes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
|
||||
self.assertEqual(len(grouped_videos), 1)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 1)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_av
|
||||
class LoadVideoTester(unittest.TestCase):
|
||||
def test_load_video_url(self):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
)
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3)) # 243 frames is the whole video, no sampling applied
|
||||
|
||||
def test_load_video_local(self):
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
video, _ = load_video(video_file_path)
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3)) # 243 frames is the whole video, no sampling applied
|
||||
|
||||
# FIXME: @raushan, yt-dlp downloading works for for some reason it cannot redirect to out buffer?
|
||||
# @requires_yt_dlp
|
||||
# def test_load_video_youtube(self):
|
||||
# video = load_video("https://www.youtube.com/watch?v=QC8iQqtG0hg")
|
||||
# self.assertEqual(video.shape, (243, 360, 640, 3)) # 243 frames is the whole video, no sampling applied
|
||||
|
||||
@require_decord
|
||||
@require_torchvision
|
||||
@require_torchcodec
|
||||
@require_cv2
|
||||
def test_load_video_backend_url(self):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
backend="decord",
|
||||
)
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
backend="torchcodec",
|
||||
)
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||
|
||||
# Can't use certain backends with url
|
||||
with self.assertRaises(ValueError):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
backend="opencv",
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
backend="torchvision",
|
||||
)
|
||||
|
||||
@require_decord
|
||||
@require_torchvision
|
||||
@require_torchcodec
|
||||
@require_cv2
|
||||
def test_load_video_backend_local(self):
|
||||
video_file_path = hf_hub_download(
|
||||
repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
|
||||
)
|
||||
video, metadata = load_video(video_file_path, backend="decord")
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||
self.assertIsInstance(metadata, VideoMetadata)
|
||||
|
||||
video, metadata = load_video(video_file_path, backend="opencv")
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||
self.assertIsInstance(metadata, VideoMetadata)
|
||||
|
||||
video, metadata = load_video(video_file_path, backend="torchvision")
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||
self.assertIsInstance(metadata, VideoMetadata)
|
||||
|
||||
video, metadata = load_video(video_file_path, backend="torchcodec")
|
||||
self.assertEqual(video.shape, (243, 360, 640, 3))
|
||||
self.assertIsInstance(metadata, VideoMetadata)
|
||||
|
||||
def test_load_video_num_frames(self):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
num_frames=16,
|
||||
)
|
||||
self.assertEqual(video.shape, (16, 360, 640, 3))
|
||||
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
num_frames=22,
|
||||
)
|
||||
self.assertEqual(video.shape, (22, 360, 640, 3))
|
||||
|
||||
def test_load_video_fps(self):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", fps=1
|
||||
)
|
||||
self.assertEqual(video.shape, (9, 360, 640, 3))
|
||||
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4", fps=2
|
||||
)
|
||||
self.assertEqual(video.shape, (19, 360, 640, 3))
|
||||
|
||||
# `num_frames` is mutually exclusive with `video_fps`
|
||||
with self.assertRaises(ValueError):
|
||||
video, _ = load_video(
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
fps=1,
|
||||
num_frames=10,
|
||||
)
|
||||
7234
tests/utils/tiny_model_summary.json
Normal file
7234
tests/utils/tiny_model_summary.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user