Files
transformers/tests/utils/test_configuration_utils.py
陈赣 06f1fd69a6
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
first commit
2026-06-05 16:53:03 +08:00

400 lines
18 KiB
Python

# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
import warnings
from pathlib import Path
import httpx
from transformers import AutoConfig, BertConfig, Florence2Config, GPT2Config
from transformers.configuration_utils import PreTrainedConfig
from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_torch
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
# One non-default value for every common `PreTrainedConfig` kwarg.
# `ConfigTestUtils.test_config_common_kwargs_is_complete` checks that this mapping covers the base
# config's attributes and that none of these values coincides with the base default.
config_common_kwargs = dict(
    return_dict=False,
    output_hidden_states=True,
    output_attentions=True,
    dtype="float16",
    chunk_size_feed_forward=5,
    architectures=["BertModel"],
    id2label={0: "label"},
    label2id={"label": "0"},
    problem_type="regression",
)
@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
    """Round-trip tests: push a configuration to the (staging) Hub, reload it, and compare."""

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._token = TOKEN

    @staticmethod
    def _make_tiny_bert_config():
        """Build the small `BertConfig` shared by the push-to-hub tests."""
        return BertConfig(
            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
        )

    def _assert_config_matches_hub(self, config, repo_id):
        """Reload `repo_id` as a `BertConfig` and check that every entry of `config.to_dict()` round-tripped.

        `transformers_version` is excluded from the comparison.
        """
        new_config = BertConfig.from_pretrained(repo_id)
        for key, value in config.to_dict().items():
            if key == "transformers_version":
                continue
            # subTest reports which key mismatched instead of stopping at an anonymous assertion.
            with self.subTest(key=key):
                self.assertEqual(value, getattr(new_config, key))

    def test_push_to_hub(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = self._make_tiny_bert_config()
            config.push_to_hub(tmp_repo.repo_id, token=self._token)
            self._assert_config_matches_hub(config, tmp_repo.repo_id)

    def test_push_to_hub_via_save_pretrained(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            config = self._make_tiny_bert_config()
            # Push to hub via save_pretrained
            with tempfile.TemporaryDirectory() as tmp_dir:
                config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
            self._assert_config_matches_hub(config, tmp_repo.repo_id)

    def test_push_to_hub_in_organization(self):
        with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
            config = self._make_tiny_bert_config()
            config.push_to_hub(tmp_repo.repo_id, token=self._token)
            self._assert_config_matches_hub(config, tmp_repo.repo_id)

    def test_push_to_hub_in_organization_via_save_pretrained(self):
        with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
            config = self._make_tiny_bert_config()
            # Push to hub via save_pretrained
            with tempfile.TemporaryDirectory() as tmp_dir:
                config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
            self._assert_config_matches_hub(config, tmp_repo.repo_id)

    def test_push_to_hub_dynamic_config(self):
        with TemporaryHubRepo(token=self._token) as tmp_repo:
            CustomConfig.register_for_auto_class()
            config = CustomConfig(attribute=42)

            config.push_to_hub(tmp_repo.repo_id, token=self._token)

            # This has added the proper auto_map field to the config
            self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})

            new_config = AutoConfig.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
            # No isinstance check: new_config is an instance of the CustomConfig class loaded from a
            # dynamic module (trust_remote_code=True), not of the locally imported CustomConfig.
            self.assertEqual(new_config.__class__.__name__, "CustomConfig")
            self.assertEqual(new_config.attribute, 42)
class ConfigTestUtils(unittest.TestCase):
    """Tests for `PreTrainedConfig` loading, saving, versioning and helper methods."""

    def test_config_from_string(self):
        """`update_from_string` must parse and apply int, float, bool and str values."""
        c = GPT2Config()

        # attempt to modify each of int/float/bool/str config records and verify they were updated
        n_embd = c.n_embd + 1  # int
        resid_pdrop = c.resid_pdrop + 1.0  # float
        scale_attn_weights = not c.scale_attn_weights  # bool
        summary_type = c.summary_type + "foo"  # str
        c.update_from_string(
            f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
        )
        self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
        self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
        self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
        self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")

    def test_config_common_kwargs_is_complete(self):
        """`config_common_kwargs` must cover every base attribute and use only non-default values."""
        base_config = PreTrainedConfig()
        missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
        # If this part of the test fails, you have arguments to add in config_common_kwargs above.
        self.assertListEqual(
            missing_keys,
            [
                "transformers_version",
                "is_encoder_decoder",
                "_name_or_path",
                "_commit_hash",
                "_output_attentions",
                "_attn_implementation_internal",
                "_experts_implementation_internal",
            ],
        )
        keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
        if keys_with_defaults:
            # Report as a test failure (AssertionError) rather than an unexpected error.
            self.fail(
                "The following keys are set with the default values in"
                " `test_configuration_common.config_common_kwargs` pick another value for them:"
                f" {', '.join(keys_with_defaults)}."
            )

    def test_nested_config_load_from_dict(self):
        """A sub-config passed as a dict kwarg must be turned into its config class."""
        config = AutoConfig.from_pretrained(
            "hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
        )
        self.assertNotIsInstance(config.text_config, dict)
        self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")

    def test_from_pretrained_subfolder(self):
        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
        self.assertIsNotNone(config)

        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
        self.assertIsNotNone(config)

    def test_cached_files_are_used_when_internet_is_down(self):
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = httpx.HTTPStatusError(
            "failed", request=mock.Mock(), response=mock.Mock()
        )
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("httpx.Client.request", return_value=response_mock) as mock_head:
            _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
            # Check that the fake request was actually issued.
            mock_head.assert_called()

    def test_local_versioning(self):
        configuration = AutoConfig.from_pretrained("google-bert/bert-base-cased")
        configuration.configuration_files = ["config.4.0.0.json"]

        with tempfile.TemporaryDirectory() as tmp_dir:
            configuration.save_pretrained(tmp_dir)
            configuration.hidden_size = 2
            # Close (and therefore flush) the file before `from_pretrained` reads it back; a bare
            # `open(...)` here leaks the handle and relies on garbage collection to flush it.
            with open(os.path.join(tmp_dir, "config.4.0.0.json"), "w", encoding="utf-8") as f:
                json.dump(configuration.to_dict(), f)

            # This should pick the new configuration file as the version of Transformers is > 4.0.0
            new_configuration = AutoConfig.from_pretrained(tmp_dir)
            self.assertEqual(new_configuration.hidden_size, 2)

            # Will need to be adjusted if we reach v42 and this test is still here.
            # Should pick the old configuration file as the version of Transformers is < 4.42.0
            configuration.configuration_files = ["config.42.0.0.json"]
            configuration.hidden_size = 768
            configuration.save_pretrained(tmp_dir)
            shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
            new_configuration = AutoConfig.from_pretrained(tmp_dir)
            self.assertEqual(new_configuration.hidden_size, 768)

    def test_repo_versioning_before(self):
        # This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
        repo = "hf-internal-testing/test-two-configs"

        import transformers as new_transformers

        # Matt: Use a context manager to ensure everything is correctly reverted and we
        # don't leak state between tests
        with mock.patch.object(new_transformers.configuration_utils, "__version__", "v4.0.0"):
            new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
                repo, return_unused_kwargs=True
            )
            self.assertEqual(new_configuration.hidden_size, 2)
            # This checks `_configuration_file` is not kept in the kwargs by mistake.
            self.assertDictEqual(kwargs, {})

        # Testing an older version by monkey-patching the version in the module it's used.
        import transformers as old_transformers

        with mock.patch.object(old_transformers.configuration_utils, "__version__", "v3.0.0"):
            old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
            self.assertEqual(old_configuration.hidden_size, 768)

    def test_saving_config_with_custom_generation_kwargs_raises_error(self):
        config = BertConfig()
        config.min_length = 3  # `min_length = 3` is a non-default generation kwarg
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.assertRaises(ValueError):
                config.save_pretrained(tmp_dir)

    def test_get_generation_parameters(self):
        config = BertConfig()
        self.assertEqual(len(config._get_generation_parameters()), 0)
        config.min_length = 3
        self.assertGreater(len(config._get_generation_parameters()), 0)
        # Setting the attribute (even back to 0) keeps it reported as a generation parameter.
        config.min_length = 0
        self.assertGreater(len(config._get_generation_parameters()), 0)

    def test_loading_config_do_not_raise_future_warnings(self):
        """Regression test for https://github.com/huggingface/transformers/issues/31002."""
        # Loading config should not raise a FutureWarning. It was the case before.
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            PreTrainedConfig.from_pretrained("bert-base-uncased")

    def test_get_text_config(self):
        """Tests the `get_text_config` method."""
        # 1. model with only text input -> returns the original config instance
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
        self.assertEqual(config.get_text_config(), config)
        self.assertEqual(config.get_text_config(decoder=True), config)

        # 2. composite model (VLM) -> returns the text component
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")
        self.assertEqual(config.get_text_config(), config.text_config)
        self.assertEqual(config.get_text_config(decoder=True), config.text_config)

        # 3. ! corner case! : composite model whose sub-config is an old composite model (should behave as above)
        config = Florence2Config()
        self.assertEqual(config.get_text_config(), config.text_config)
        self.assertEqual(config.get_text_config(decoder=True), config.text_config)

        # 4. old composite model -> may remove components based on the `decoder` or `encoder` argument
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-bart")
        self.assertEqual(config.get_text_config(), config)
        # both encoder_ffn_dim and decoder_ffn_dim exist
        self.assertIsNotNone(getattr(config, "encoder_ffn_dim", None))
        self.assertIsNotNone(getattr(config, "decoder_ffn_dim", None))

        decoder_config = config.get_text_config(decoder=True)
        self.assertNotEqual(decoder_config, config)
        self.assertEqual(decoder_config.num_hidden_layers, config.decoder_layers)

        encoder_config = config.get_text_config(encoder=True)
        self.assertNotEqual(encoder_config, config)
        self.assertEqual(encoder_config.num_hidden_layers, config.encoder_layers)

    @require_torch
    def test_bc_torch_dtype(self):
        import torch

        config = PreTrainedConfig(dtype="bfloat16")
        self.assertEqual(config.dtype, torch.bfloat16)

        config = PreTrainedConfig(torch_dtype="bfloat16")
        self.assertEqual(config.dtype, torch.bfloat16)

        # Check that if we pass both, `dtype` is used
        config = PreTrainedConfig(dtype="bfloat16", torch_dtype="float32")
        self.assertEqual(config.dtype, torch.bfloat16)

        with tempfile.TemporaryDirectory() as tmpdirname:
            config.save_pretrained(tmpdirname)

            config = PreTrainedConfig.from_pretrained(tmpdirname)
            self.assertEqual(config.dtype, torch.bfloat16)

            config = PreTrainedConfig.from_pretrained(tmpdirname, dtype="float32")
            self.assertEqual(config.dtype, "float32")

            config = PreTrainedConfig.from_pretrained(tmpdirname, torch_dtype="float32")
            self.assertEqual(config.dtype, "float32")

    def test_unserializable_json_is_encoded(self):
        """Non-JSON floats (inf, -inf, NaN) must be encoded on save and restored on load."""
        import math

        class NewConfig(PreTrainedConfig):
            def __init__(
                self,
                inf_positive: float = float("inf"),
                inf_negative: float = float("-inf"),
                nan: float = float("nan"),
                **kwargs,
            ):
                self.inf_positive = inf_positive
                self.inf_negative = inf_negative
                self.nan = nan

                super().__init__(**kwargs)

        new_config = NewConfig()

        # All floats should remain as floats when being accessed in the config
        self.assertIsInstance(new_config.inf_positive, float)
        self.assertIsInstance(new_config.inf_negative, float)
        self.assertIsInstance(new_config.nan, float)

        with tempfile.TemporaryDirectory() as tmpdirname:
            new_config.save_pretrained(tmpdirname)
            config_file = Path(tmpdirname) / "config.json"
            config_contents = json.loads(config_file.read_text())
            new_config_instance = NewConfig.from_pretrained(tmpdirname)

        # In the serialized JSON file, the non-JSON compatible floats should be updated
        self.assertDictEqual(config_contents["inf_positive"], {"__float__": "Infinity"})
        self.assertDictEqual(config_contents["inf_negative"], {"__float__": "-Infinity"})
        self.assertDictEqual(config_contents["nan"], {"__float__": "NaN"})

        # When reloading the config, it should have the original float values back (not just floats)
        self.assertIsInstance(new_config_instance.inf_positive, float)
        self.assertIsInstance(new_config_instance.inf_negative, float)
        self.assertIsInstance(new_config_instance.nan, float)
        self.assertEqual(new_config_instance.inf_positive, float("inf"))
        self.assertEqual(new_config_instance.inf_negative, float("-inf"))
        # NaN never compares equal to itself, so check it explicitly.
        self.assertTrue(math.isnan(new_config_instance.nan))
class ConfigSubclassKwOnlyTest(unittest.TestCase):
    """Test that config subclasses with non-default fields following parent default fields
    no longer raise TypeError (fixed by kw_only=True in __init_subclass__). Regression tests."""

    def test_subclass_non_default_field_after_default(self):
        """A config subclass adding a required field after parent defaults must not raise."""

        class MyConfig(PreTrainedConfig):
            pooling: str  # no default — would fail under Python dataclass ordering rules

        # Should construct without TypeError
        cfg = MyConfig(pooling="mean")
        self.assertEqual(cfg.pooling, "mean")

    def test_subclass_multiple_non_default_fields(self):
        """Multiple non-default fields in the subclass should all work."""

        class EmbedConfig(PreTrainedConfig):
            dim: int
            pooling: str

        cfg = EmbedConfig(dim=128, pooling="cls")
        self.assertEqual(cfg.dim, 128)
        self.assertEqual(cfg.pooling, "cls")

    def test_inherited_defaults_still_work(self):
        """Inherited fields with defaults must still be accessible."""

        # BertConfig comes from the module-level import; no local re-import is needed.
        class BertWithPooling(BertConfig):
            pooling: str

        cfg = BertWithPooling(pooling="mean", hidden_size=256)
        self.assertEqual(cfg.pooling, "mean")
        self.assertEqual(cfg.hidden_size, 256)