first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled

This commit is contained in:
陈赣
2026-06-05 16:53:03 +08:00
commit 06f1fd69a6
6047 changed files with 1895387 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
import os
import sys
import unittest
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(ROOT_DIR, "utils"))
import create_dependency_mapping # noqa: E402
# This is equivalent to `all` in the current library state (as of 09/01/2025)
MODEL_ROOT = os.path.join("src", "transformers", "models")
FILES_TO_PARSE = [
os.path.join(MODEL_ROOT, "starcoder2", "modular_starcoder2.py"),
os.path.join(MODEL_ROOT, "gemma", "modular_gemma.py"),
os.path.join(MODEL_ROOT, "olmo2", "modular_olmo2.py"),
os.path.join(MODEL_ROOT, "diffllama", "modular_diffllama.py"),
os.path.join(MODEL_ROOT, "granite", "modular_granite.py"),
os.path.join(MODEL_ROOT, "gemma2", "modular_gemma2.py"),
os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"),
os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3.py"),
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),
os.path.join(MODEL_ROOT, "colpali", "modular_colpali.py"),
os.path.join(MODEL_ROOT, "deformable_detr", "modular_deformable_detr.py"),
os.path.join(MODEL_ROOT, "aria", "modular_aria.py"),
os.path.join(MODEL_ROOT, "ijepa", "modular_ijepa.py"),
os.path.join(MODEL_ROOT, "bamba", "modular_bamba.py"),
os.path.join(MODEL_ROOT, "dinov2_with_registers", "modular_dinov2_with_registers.py"),
os.path.join(MODEL_ROOT, "instructblipvideo", "modular_instructblipvideo.py"),
os.path.join(MODEL_ROOT, "glm", "modular_glm.py"),
os.path.join(MODEL_ROOT, "phi", "modular_phi.py"),
os.path.join(MODEL_ROOT, "mistral", "modular_mistral.py"),
os.path.join(MODEL_ROOT, "phi3", "modular_phi3.py"),
os.path.join(MODEL_ROOT, "cohere", "modular_cohere.py"),
os.path.join(MODEL_ROOT, "glm4", "modular_glm4.py"),
os.path.join(MODEL_ROOT, "seed_oss", "modular_seed_oss.py"),
]
def appear_after(model1: str, model2: str, priority_list: list[list[str]]) -> bool:
"""Return True if `model1` appear after `model2` in `priority_list`."""
model1_index, model2_index = None, None
for i, level in enumerate(priority_list):
if model1 in level:
model1_index = i
if model2 in level:
model2_index = i
if model1_index is None or model2_index is None:
raise ValueError(f"Model {model1} or {model2} not found in {priority_list}")
return model1_index > model2_index
class ConversionOrderTest(unittest.TestCase):
def test_conversion_order(self):
# Find the order
priority_list, _ = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
# Extract just the model names (list of lists)
model_priority_list = [[file.split("/")[-2] for file in level] for level in priority_list]
# These are based on what the current library order should be (as of 09/01/2025)
self.assertTrue(appear_after("mixtral", "mistral", model_priority_list))
self.assertTrue(appear_after("gemma2", "gemma", model_priority_list))
self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list))
self.assertTrue(appear_after("olmo2", "olmo", model_priority_list))
self.assertTrue(appear_after("diffllama", "mistral", model_priority_list))
self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list))
self.assertTrue(appear_after("cohere2", "cohere", model_priority_list))
self.assertTrue(appear_after("phi3", "mistral", model_priority_list))
self.assertTrue(appear_after("glm4", "glm", model_priority_list))

View File

@@ -0,0 +1,129 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import textwrap
import unittest
from contextlib import contextmanager
from pathlib import Path
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
utils_path = os.path.join(git_repo_path, "utils")
if utils_path not in sys.path:
sys.path.append(utils_path)
import check_auto # noqa: E402
@contextmanager
def cwd(path: Path):
old = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(old)
def _write_config(root: Path, module_name: str, classes: list[tuple[str, str]]) -> None:
"""Create `src/transformers/models/<module>/configuration_<module>.py` with the given classes.
`classes` is a list of (class_name, model_type) pairs, all subclassing PreTrainedConfig.
"""
module_dir = root / "src" / "transformers" / "models" / module_name
module_dir.mkdir(parents=True, exist_ok=True)
body = "from transformers import PreTrainedConfig\n\n"
for cls_name, model_type in classes:
body += textwrap.dedent(
f'''
class {cls_name}(PreTrainedConfig):
model_type = "{model_type}"
'''
)
(module_dir / f"configuration_{module_name}.py").write_text(body, encoding="utf-8")
class BuildConfigMappingNamesTest(unittest.TestCase):
"""Tests for the natural-match tie-break in `check_auto.build_config_mapping_names`.
A natural match is one where a config's `model_type` equals its module directory name
(e.g. `DetrConfig` with `model_type = "detr"` inside `models/detr/`). When two classes
share a `model_type`, the natural one must always win regardless of filesystem ordering.
"""
def test_single_natural_match(self):
"""Baseline: one config in its eponymous module → no special mapping."""
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
_write_config(root, "detr", [("DetrConfig", "detr")])
with cwd(root):
model_type_map, special_mappings = check_auto.build_config_mapping_names()
self.assertEqual(model_type_map, {"detr": "DetrConfig"})
self.assertEqual(special_mappings, {})
def test_natural_wins_when_encountered_first(self):
"""detr (natural) is alphabetically before maskformer (non-natural for model_type=detr).
This is the order modern Linux filesystems produce. The natural match must be kept
and the alias must not appear in the canonical mapping.
"""
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
_write_config(root, "detr", [("DetrConfig", "detr")])
_write_config(
root,
"maskformer",
[("MaskFormerConfig", "maskformer"), ("MaskFormerDetrConfig", "detr")],
)
with cwd(root):
model_type_map, special_mappings = check_auto.build_config_mapping_names()
self.assertEqual(model_type_map["detr"], "DetrConfig")
self.assertEqual(model_type_map["maskformer"], "MaskFormerConfig")
self.assertNotIn("detr", special_mappings)
def test_natural_wins_when_encountered_second(self):
"""The non-natural alias is alphabetically *before* the natural module.
Without the prefer-natural logic this is the case that breaks: the alias would be
recorded first and then never overwritten. The fix must still pick the natural class
and clear the now-stale special mapping.
"""
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
# `aaa_alias` sorts before `foo`, so its non-natural class is processed first.
_write_config(root, "aaa_alias", [("AaaConfig", "aaa_alias"), ("FooAliasConfig", "foo")])
_write_config(root, "foo", [("FooConfig", "foo")])
with cwd(root):
model_type_map, special_mappings = check_auto.build_config_mapping_names()
self.assertEqual(model_type_map["foo"], "FooConfig")
self.assertNotIn("foo", special_mappings, "stale alias entry must be cleared")
# The alias module's own natural entry is still recorded.
self.assertEqual(model_type_map["aaa_alias"], "AaaConfig")
def test_non_natural_only_records_special_mapping(self):
"""If a model_type has no natural match, the alias is the canonical entry."""
with tempfile.TemporaryDirectory() as tmp:
root = Path(tmp)
_write_config(root, "wrapper", [("WrapperConfig", "wrapper"), ("InnerConfig", "inner")])
with cwd(root):
model_type_map, special_mappings = check_auto.build_config_mapping_names()
self.assertEqual(model_type_map["inner"], "InnerConfig")
self.assertEqual(special_mappings["inner"], "wrapper")

View File

@@ -0,0 +1,450 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
import tempfile
import unittest
from contextlib import contextmanager
from pathlib import Path
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
import check_copies # noqa: E402
from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent # noqa: E402
# This is the reference code that will be used in the tests.
# If BertLMPredictionHead is changed in modeling_bert.py, this code needs to be manually updated.
REFERENCE_CODE = """ def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
"""
MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
def bert_function(x):
return x
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
class BertModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__()
self.bert = BertEncoder(config)
@add_docstring(BERT_DOCSTRING)
def forward(self, x):
return self.bert(x)
"""
MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
# Copied from transformers.models.bert.modeling_bert.bert_function
def bert_copy_function(x):
return x
# Copied from transformers.models.bert.modeling_bert.BertAttention
class BertCopyAttention(nn.Module):
def __init__(self, config):
super().__init__()
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
class BertCopyModel(BertCopyPreTrainedModel):
def __init__(self, config):
super().__init__()
self.bertcopy = BertCopyEncoder(config)
@add_docstring(BERTCOPY_DOCSTRING)
def forward(self, x):
return self.bertcopy(x)
"""
MOCK_DUMMY_BERT_CODE_MATCH = """
class BertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def existing_common(self, c):
return 4
def existing_diff_to_be_ignored(self, c):
return 9
"""
MOCK_DUMMY_ROBERTA_CODE_MATCH = """
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Ignore copy
def only_in_roberta_to_be_ignored(self, c):
return 3
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def existing_common(self, c):
return 4
# Ignore copy
def existing_diff_to_be_ignored(self, c):
return 6
"""
MOCK_DUMMY_BERT_CODE_NO_MATCH = """
class BertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def only_in_bert(self, c):
return 7
def existing_common(self, c):
return 4
def existing_diff_not_ignored(self, c):
return 8
def existing_diff_to_be_ignored(self, c):
return 9
"""
MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
attr_1 = 1
attr_2 = 3
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Ignore copy
def only_in_roberta_to_be_ignored(self, c):
return 3
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def only_in_roberta_not_ignored(self, c):
return 2
def existing_common(self, c):
return 4
def existing_diff_not_ignored(self, c):
return 5
# Ignore copy
def existing_diff_to_be_ignored(self, c):
return 6
"""
EXPECTED_REPLACED_CODE = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def only_in_bert(self, c):
return 7
def existing_common(self, c):
return 4
def existing_diff_not_ignored(self, c):
return 8
# Ignore copy
def existing_diff_to_be_ignored(self, c):
return 6
# Ignore copy
def only_in_roberta_to_be_ignored(self, c):
return 3
"""
def replace_in_file(filename, old, new):
with open(filename, encoding="utf-8") as f:
content = f.read()
content = content.replace(old, new)
with open(filename, "w", encoding="utf-8", newline="\n") as f:
f.write(content)
def create_tmp_repo(tmp_dir):
"""
Creates a mock repository in a temporary folder for testing.
"""
tmp_dir = Path(tmp_dir)
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True)
model_dir = tmp_dir / "src" / "transformers" / "models"
model_dir.mkdir(parents=True, exist_ok=True)
models = {
"bert": MOCK_BERT_CODE,
"bertcopy": MOCK_BERT_COPY_CODE,
"dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
"dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
"dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
"dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
}
for model, code in models.items():
model_subdir = model_dir / model
model_subdir.mkdir(exist_ok=True)
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
f.write(code)
@contextmanager
def patch_transformer_repo_path(new_folder):
"""
Temporarily patches the variables defines in `check_copies` to use a different location for the repo.
"""
old_repo_path = check_copies.REPO_PATH
old_doc_path = check_copies.PATH_TO_DOCS
old_transformer_path = check_copies.TRANSFORMERS_PATH
repo_path = Path(new_folder).resolve()
check_copies.REPO_PATH = str(repo_path)
check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
try:
yield
finally:
check_copies.REPO_PATH = old_repo_path
check_copies.PATH_TO_DOCS = old_doc_path
check_copies.TRANSFORMERS_PATH = old_transformer_path
class CopyCheckTester(unittest.TestCase):
def test_find_code_in_transformers(self):
with tempfile.TemporaryDirectory() as tmp_folder:
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")
reference_code = (
"class BertAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n"
)
self.assertEqual(code, reference_code)
def test_is_copy_consistent(self):
path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
with tempfile.TemporaryDirectory() as tmp_folder:
# Base check
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [])
# Base check with an inconsistency
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
_ = is_copy_consistent(file_to_check, overwrite=True)
with open(file_to_check, encoding="utf-8") as f:
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
def test_is_copy_consistent_with_ignored_match(self):
path_to_check = ["src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py"]
with tempfile.TemporaryDirectory() as tmp_folder:
# Base check
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [])
def test_is_copy_consistent_with_ignored_no_match(self):
path_to_check = [
"src",
"transformers",
"models",
"dummy_roberta_no_match",
"modeling_dummy_roberta_no_match.py",
]
with tempfile.TemporaryDirectory() as tmp_folder:
# Base check with an inconsistency
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
file_to_check = os.path.join(tmp_folder, *path_to_check)
diffs = is_copy_consistent(file_to_check)
# line 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`.
# (which has a leading `\n`.)
self.assertEqual(
diffs, [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]
)
_ = is_copy_consistent(file_to_check, overwrite=True)
with open(file_to_check, encoding="utf-8") as f:
self.assertEqual(f.read(), EXPECTED_REPLACED_CODE)
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
md_list = (
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
" Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
" Self-supervised Learning of Language Representations](https://huggingface.co/papers/1909.11942), by Zhenzhong"
" Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1."
" **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace),"
" released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
" lighter](https://huggingface.co/papers/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same"
" method has been applied to compress GPT2 into"
" [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
" [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
" Multilingual BERT into"
" [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
" version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)**"
" (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders"
" as discriminators rather than generators](https://huggingface.co/papers/2003.10555) by Kevin Clark, Minh-Thang"
" Luong, Quoc V. Le, Christopher D. Manning."
)
localized_md_list = (
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
" Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
)
converted_md_list_sample = (
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
" Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n1."
" **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (来自 HuggingFace) 伴随论文"
" [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
" lighter](https://huggingface.co/papers/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 The same"
" method has been applied to compress GPT2 into"
" [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
" [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
" Multilingual BERT into"
" [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
" version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (来自"
" Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather"
" than generators](https://huggingface.co/papers/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le,"
" Christopher D. Manning 发布。\n"
)
num_models_equal, converted_md_list = convert_to_localized_md(
md_list, localized_md_list, localized_readme["format_model_list"]
)
self.assertFalse(num_models_equal)
self.assertEqual(converted_md_list, converted_md_list_sample)
num_models_equal, converted_md_list = convert_to_localized_md(
md_list, converted_md_list, localized_readme["format_model_list"]
)
# Check whether the number of models is equal to README.md after conversion.
self.assertTrue(num_models_equal)
link_changed_md_list = (
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
" Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
" Self-supervised Learning of Language Representations](https://huggingface.co/papers/1909.11942), by Zhenzhong"
" Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
)
link_unchanged_md_list = (
"1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (来自 Google Research and"
" the Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
)
converted_md_list_sample = (
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
" Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
)
num_models_equal, converted_md_list = convert_to_localized_md(
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
)
# Check if the model link is synchronized.
self.assertEqual(converted_md_list, converted_md_list_sample)

View File

@@ -0,0 +1,347 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import inspect
import os
import sys
import tempfile
import textwrap
import unittest
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
from check_docstrings import ( # noqa: E402
_build_ast_indexes,
_find_typed_dict_classes,
_get_auto_docstring_names,
get_default_description,
has_auto_docstring_decorator,
replace_default_in_arg_description,
)
class CheckDostringsTested(unittest.TestCase):
def test_replace_default_in_arg_description(self):
# Standard docstring with default.
desc_with_default = "`float`, *optional*, defaults to 2.0"
self.assertEqual(
replace_default_in_arg_description(desc_with_default, 2.0), "`float`, *optional*, defaults to 2.0"
)
self.assertEqual(
replace_default_in_arg_description(desc_with_default, 1.0), "`float`, *optional*, defaults to 1.0"
)
self.assertEqual(replace_default_in_arg_description(desc_with_default, inspect._empty), "`float`")
# Standard docstring with default but optional is not using the stars.
desc_with_default_typo = "`float`, `optional`, defaults to 2.0"
self.assertEqual(
replace_default_in_arg_description(desc_with_default_typo, 2.0), "`float`, *optional*, defaults to 2.0"
)
self.assertEqual(
replace_default_in_arg_description(desc_with_default_typo, 1.0), "`float`, *optional*, defaults to 1.0"
)
# If the default is None we do not erase the value in the docstring.
self.assertEqual(
replace_default_in_arg_description(desc_with_default, None), "`float`, *optional*, defaults to 2.0"
)
# If the default is None (and set as such in the docstring), we do not include it.
desc_with_default = "`float`, *optional*, defaults to None"
self.assertEqual(replace_default_in_arg_description(desc_with_default, None), "`float`, *optional*")
desc_with_default = "`float`, *optional*, defaults to `None`"
self.assertEqual(replace_default_in_arg_description(desc_with_default, None), "`float`, *optional*")
# Operations are not replaced, but put in backtiks.
desc_with_default = "`float`, *optional*, defaults to 1/255"
self.assertEqual(
replace_default_in_arg_description(desc_with_default, 1 / 255), "`float`, *optional*, defaults to `1/255`"
)
desc_with_default = "`float`, *optional*, defaults to `1/255`"
self.assertEqual(
replace_default_in_arg_description(desc_with_default, 1 / 255), "`float`, *optional*, defaults to `1/255`"
)
desc_with_optional = "`float`, *optional*"
self.assertEqual(
replace_default_in_arg_description(desc_with_optional, 2.0), "`float`, *optional*, defaults to 2.0"
)
self.assertEqual(
replace_default_in_arg_description(desc_with_optional, 1.0), "`float`, *optional*, defaults to 1.0"
)
self.assertEqual(replace_default_in_arg_description(desc_with_optional, None), "`float`, *optional*")
self.assertEqual(replace_default_in_arg_description(desc_with_optional, inspect._empty), "`float`")
desc_with_no_optional = "`float`"
self.assertEqual(
replace_default_in_arg_description(desc_with_no_optional, 2.0), "`float`, *optional*, defaults to 2.0"
)
self.assertEqual(
replace_default_in_arg_description(desc_with_no_optional, 1.0), "`float`, *optional*, defaults to 1.0"
)
self.assertEqual(replace_default_in_arg_description(desc_with_no_optional, None), "`float`, *optional*")
self.assertEqual(replace_default_in_arg_description(desc_with_no_optional, inspect._empty), "`float`")
def test_get_default_description(self):
# Fake function to have arguments to test.
def _fake_function(a, b: int, c=1, d: float = 2.0, e: str = "blob"):
pass
params = inspect.signature(_fake_function).parameters
assert get_default_description(params["a"]) == "`<fill_type>`"
assert get_default_description(params["b"]) == "`int`"
assert get_default_description(params["c"]) == "`<fill_type>`, *optional*, defaults to 1"
assert get_default_description(params["d"]) == "`float`, *optional*, defaults to 2.0"
assert get_default_description(params["e"]) == '`str`, *optional*, defaults to `"blob"`'
class TestGetAutoDocstringNames(unittest.TestCase):
"""Tests for _get_auto_docstring_names and has_auto_docstring_decorator."""
def setUp(self):
self.cache = {}
def _write_temp(self, source):
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
f.write(source)
self.addCleanup(os.unlink, f.name)
return f.name
def test_detects_simple_decorator(self):
"""Test that a class decorated with @auto_docstring is detected."""
path = self._write_temp(
textwrap.dedent("""\
from transformers import auto_docstring
@auto_docstring
class Foo:
pass
""")
)
names = _get_auto_docstring_names(path, cache=self.cache)
self.assertEqual(names, {"Foo"})
def test_detects_decorator_with_call(self):
"""Test that a class decorated with @auto_docstring(args) (called form) is detected."""
path = self._write_temp(
textwrap.dedent("""\
@auto_docstring(custom_args='x')
class Bar:
pass
""")
)
names = _get_auto_docstring_names(path, cache=self.cache)
self.assertEqual(names, {"Bar"})
def test_ignores_other_decorators(self):
"""Test that classes with non-auto_docstring decorators are not detected."""
path = self._write_temp(
textwrap.dedent("""\
@dataclass
class Baz:
pass
""")
)
names = _get_auto_docstring_names(path, cache=self.cache)
self.assertEqual(names, set())
def test_multiple_classes(self):
"""Test that only decorated classes and functions are returned when multiple definitions exist."""
path = self._write_temp(
textwrap.dedent("""\
@auto_docstring
class A:
pass
class B:
pass
@auto_docstring()
def func_c():
pass
""")
)
names = _get_auto_docstring_names(path, cache=self.cache)
self.assertEqual(names, {"A", "func_c"})
def test_caching(self):
"""Test that repeated calls for the same file return the cached (identical) result object."""
path = self._write_temp(
textwrap.dedent("""\
@auto_docstring
class X:
pass
""")
)
result1 = _get_auto_docstring_names(path, cache=self.cache)
result2 = _get_auto_docstring_names(path, cache=self.cache)
self.assertIs(result1, result2)
def test_syntax_error_returns_empty(self):
"""Test that a file with a syntax error returns an empty set instead of raising."""
path = self._write_temp("def broken(\n")
names = _get_auto_docstring_names(path, cache=self.cache)
self.assertEqual(names, set())
def test_has_auto_docstring_decorator_uses_cache(self):
"""Test that has_auto_docstring_decorator looks up names from the pre-populated cache."""
from unittest.mock import patch
path = self._write_temp(
textwrap.dedent("""\
@auto_docstring
class Cached:
pass
""")
)
self.cache[path] = {"Cached"}
# Create classes whose __name__ matches/doesn't match the cache
Cached = type("Cached", (), {})
Other = type("Other", (), {})
with patch.object(inspect, "getfile", return_value=path):
self.assertTrue(has_auto_docstring_decorator(Cached, cache=self.cache))
self.assertFalse(has_auto_docstring_decorator(Other, cache=self.cache))
class TestBuildAstIndexes(unittest.TestCase):
"""Tests for _build_ast_indexes with pre-parsed tree."""
def test_finds_decorated_items(self):
"""Test that _build_ast_indexes finds a decorated class and extracts its __init__ args."""
source = textwrap.dedent("""\
@auto_docstring
class MyModel:
def __init__(self, hidden_size=768):
self.hidden_size = hidden_size
""")
items = _build_ast_indexes(source)
self.assertEqual(len(items), 1)
self.assertEqual(items[0].name, "MyModel")
self.assertEqual(items[0].kind, "class")
self.assertIn("hidden_size", items[0].args)
def test_shared_tree(self):
"""Test that passing a pre-parsed AST tree produces the same results as letting the function parse internally."""
source = textwrap.dedent("""\
@auto_docstring
class A:
pass
""")
tree = ast.parse(source)
items_with_tree = _build_ast_indexes(source, tree=tree)
items_without = _build_ast_indexes(source)
self.assertEqual(len(items_with_tree), len(items_without))
self.assertEqual(items_with_tree[0].name, items_without[0].name)
def test_no_decorated_items(self):
"""Test that a class without the auto_docstring decorator is not indexed."""
source = textwrap.dedent("""\
class Plain:
pass
""")
items = _build_ast_indexes(source)
self.assertEqual(items, [])
def test_function_decorated(self):
"""Test that a decorated function is indexed with its arguments."""
source = textwrap.dedent("""\
@auto_docstring
def my_func(x, y=10):
pass
""")
items = _build_ast_indexes(source)
self.assertEqual(len(items), 1)
self.assertEqual(items[0].name, "my_func")
self.assertEqual(items[0].kind, "function")
self.assertIn("x", items[0].args)
self.assertIn("y", items[0].args)
def test_custom_args_from_variable(self):
"""Test that custom_args passed as a module-level variable are resolved to their string value."""
source = textwrap.dedent("""\
MY_ARGS = "custom param docs"
@auto_docstring(custom_args=MY_ARGS)
class WithCustom:
def __init__(self):
pass
""")
items = _build_ast_indexes(source)
self.assertEqual(len(items), 1)
self.assertEqual(items[0].custom_args_text, "custom param docs")
class TestFindTypedDictClasses(unittest.TestCase):
"""Tests for _find_typed_dict_classes with pre-parsed tree."""
def test_finds_typed_dict(self):
"""Test that a TypedDict subclass is found and its public fields are extracted."""
source = textwrap.dedent("""\
from typing import TypedDict
class MyKwargs(TypedDict):
field_a: str
field_b: int
""")
result = _find_typed_dict_classes(source)
self.assertEqual(len(result), 1)
self.assertEqual(result[0]["name"], "MyKwargs")
self.assertIn("field_a", result[0]["all_fields"])
self.assertIn("field_b", result[0]["all_fields"])
def test_shared_tree(self):
"""Test that passing a pre-parsed AST tree produces the same results as internal parsing."""
source = textwrap.dedent("""\
class MyKwargs(TypedDict):
x: int
""")
tree = ast.parse(source)
r1 = _find_typed_dict_classes(source, tree=tree)
r2 = _find_typed_dict_classes(source)
self.assertEqual(len(r1), len(r2))
self.assertEqual(r1[0]["name"], r2[0]["name"])
def test_skips_standard_kwargs(self):
"""Test that well-known kwargs TypedDicts (e.g. TextKwargs) are excluded from results."""
source = textwrap.dedent("""\
class TextKwargs(TypedDict):
field: str
""")
result = _find_typed_dict_classes(source)
self.assertEqual(result, [])
def test_no_typed_dicts(self):
"""Test that source with no TypedDict subclasses returns an empty list."""
source = textwrap.dedent("""\
class Regular:
pass
""")
result = _find_typed_dict_classes(source)
self.assertEqual(result, [])
def test_skips_private_fields(self):
"""Test that fields starting with an underscore are excluded from the extracted TypedDict fields."""
source = textwrap.dedent("""\
class MyKwargs(TypedDict):
public: int
_private: str
""")
result = _find_typed_dict_classes(source)
self.assertEqual(len(result), 1)
self.assertIn("public", result[0]["all_fields"])
self.assertNotIn("_private", result[0]["all_fields"])

View File

@@ -0,0 +1,74 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
from unittest.mock import patch
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
utils_path = os.path.join(git_repo_path, "utils")
if utils_path not in sys.path:
sys.path.append(utils_path)
import check_modular_conversion # noqa: E402
class ConverterChangedInDiffTest(unittest.TestCase):
"""Regression guard for PR #45492: changes to the converter alone must force a full check."""
def _patch_modified(self, files):
return patch.object(check_modular_conversion, "_get_modified_files", return_value=files)
def test_returns_true_when_modular_model_converter_changed(self):
with self._patch_modified(
[
"utils/modular_model_converter.py",
"src/transformers/models/llava_onevision/modular_llava_onevision.py",
]
):
self.assertTrue(check_modular_conversion.converter_changed_in_diff())
def test_returns_true_when_create_dependency_mapping_changed(self):
with self._patch_modified(["utils/create_dependency_mapping.py"]):
self.assertTrue(check_modular_conversion.converter_changed_in_diff())
def test_returns_false_for_model_only_diff(self):
with self._patch_modified(
[
"src/transformers/models/llama/modular_llama.py",
"src/transformers/models/llama/modeling_llama.py",
]
):
self.assertFalse(check_modular_conversion.converter_changed_in_diff())
def test_returns_false_for_unrelated_utils_change(self):
with self._patch_modified(["utils/check_modular_conversion.py", "utils/check_copies.py"]):
self.assertFalse(check_modular_conversion.converter_changed_in_diff())
def test_converter_files_set_includes_expected_entries(self):
# Keep the allow-list grounded: if either file is renamed/removed, this test fails loudly
# so the detection logic is updated alongside the rename.
self.assertIn("utils/modular_model_converter.py", check_modular_conversion.CONVERTER_FILES)
self.assertIn("utils/create_dependency_mapping.py", check_modular_conversion.CONVERTER_FILES)
for rel_path in check_modular_conversion.CONVERTER_FILES:
self.assertTrue(
os.path.exists(os.path.join(git_repo_path, rel_path)),
f"{rel_path} listed in CONVERTER_FILES but does not exist on disk",
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,151 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest
from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
utils_path = os.path.join(git_repo_path, "utils")
if utils_path not in sys.path:
sys.path.append(utils_path)
import check_repo # noqa: E402
class RecordingNamespace:
"""Record directory listings and attribute access for cache tests."""
def __init__(self, mapping):
self._mapping = mapping
self.dir_calls = 0
self.getattr_calls = []
def __dir__(self):
self.dir_calls += 1
return list(self._mapping.keys()) + ["__doc__"]
def __getattr__(self, name):
self.getattr_calls.append(name)
try:
return self._mapping[name]
except KeyError as error:
raise AttributeError(name) from error
@contextmanager
def patch_transformers_path(path: Path):
"""Temporarily point `check_repo` at a temporary transformers source tree."""
old_path = check_repo.PATH_TO_TRANSFORMERS
check_repo.PATH_TO_TRANSFORMERS = str(path)
try:
yield
finally:
check_repo.PATH_TO_TRANSFORMERS = old_path
class CheckRepoTest(unittest.TestCase):
def setUp(self):
"""Reset the `get_model_modules` cache before each test."""
check_repo.get_model_modules.cache_clear()
self.addCleanup(check_repo.get_model_modules.cache_clear)
def _write_modeling_file(self, root: Path, model_name: str, content: str) -> None:
"""Create a temporary `modeling_*.py` file used by `check_models_have_kwargs`."""
model_dir = root / "src" / "transformers" / "models" / model_name
model_dir.mkdir(parents=True, exist_ok=True)
(model_dir / f"modeling_{model_name}.py").write_text(content, encoding="utf-8")
def test_get_model_modules_is_cached(self):
"""Repeated calls should reuse the cached module list instead of traversing models twice."""
alpha_modeling = object()
alpha_module = RecordingNamespace({"modeling_alpha": alpha_modeling, "configuration_alpha": object()})
fake_models = RecordingNamespace({"alpha": alpha_module, "deprecated_alpha": object()})
fake_transformers = SimpleNamespace(models=fake_models)
with patch.object(check_repo, "transformers", fake_transformers):
first = check_repo.get_model_modules()
second = check_repo.get_model_modules()
self.assertIs(first, second)
self.assertEqual(first, [alpha_modeling])
self.assertEqual(fake_models.dir_calls, 1)
self.assertEqual(fake_models.getattr_calls, ["alpha"])
self.assertEqual(alpha_module.dir_calls, 1)
self.assertEqual(alpha_module.getattr_calls, ["modeling_alpha"])
def test_check_models_have_kwargs_ignores_nested_classes(self):
"""Nested helper classes should not trigger missing-`**kwargs` failures."""
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
self._write_modeling_file(
root,
"foo",
"""
class PreTrainedModel:
pass
class FooModel(PreTrainedModel):
def forward(self, hidden_states, **kwargs):
return hidden_states
class HelperContainer:
class NestedModel(PreTrainedModel):
def forward(self, hidden_states):
return hidden_states
""".strip()
+ "\n",
)
with patch_transformers_path(root / "src" / "transformers"):
check_repo.check_models_have_kwargs()
def test_check_models_have_kwargs_still_checks_top_level_models(self):
"""Top-level model classes should still fail when `forward()` omits `**kwargs`."""
with tempfile.TemporaryDirectory() as tmpdir:
root = Path(tmpdir)
self._write_modeling_file(
root,
"foo",
"""
class PreTrainedModel:
pass
class FooModel(PreTrainedModel):
def forward(self, hidden_states):
return hidden_states
def make_nested_model():
class NestedModel(PreTrainedModel):
def forward(self, hidden_states, **kwargs):
return hidden_states
return NestedModel
""".strip()
+ "\n",
)
with patch_transformers_path(root / "src" / "transformers"):
with self.assertRaisesRegex(Exception, "FooModel"):
check_repo.check_models_have_kwargs()

View File

@@ -0,0 +1,198 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import sys
import tempfile
import unittest
from contextlib import ExitStack, contextmanager
from pathlib import Path
from unittest.mock import patch
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
utils_path = os.path.join(git_repo_path, "utils")
if utils_path not in sys.path:
sys.path.append(utils_path)
import checkers # noqa: E402
@contextmanager
def patch_checkers_paths(repo_root: Path):
cache_path = repo_root / "utils" / ".checkers_cache.json"
with ExitStack() as stack:
stack.enter_context(patch.object(checkers, "REPO_ROOT", repo_root))
stack.enter_context(patch.object(checkers, "CACHE_PATH", cache_path))
stack.enter_context(patch.object(checkers, "CHECKERS", {"demo": ("Demo checker", "fake_checker.py", [], [])}))
stack.enter_context(patch.object(checkers, "CHECKER_CACHE_GLOBS", {"demo": ["tracked/**/*.txt"]}))
yield cache_path
class CheckersCacheTest(unittest.TestCase):
class _TTYStringIO(io.StringIO):
def isatty(self) -> bool:
return True
def _create_fake_repo(self, tmpdir: str) -> Path:
"""Create a minimal repo layout for exercising checker cache inputs."""
repo_root = Path(tmpdir)
(repo_root / "tracked").mkdir()
(repo_root / "tracked" / "input.txt").write_text("tracked\n", encoding="utf-8")
(repo_root / "utils").mkdir()
(repo_root / "utils" / "fake_checker.py").write_text("# fake checker\n", encoding="utf-8")
return repo_root
def _run_main(self, *args: str, stdout=None) -> tuple[int | None, str]:
"""Run `checkers.main()` with patched argv/stdout and return the exit code and captured output."""
stdout = io.StringIO() if stdout is None else stdout
with (
patch.object(sys, "argv", ["checkers.py", *args]),
patch.object(sys, "stdout", new=stdout),
):
exit_code = None
try:
checkers.main()
except SystemExit as e:
exit_code = e.code
return exit_code, stdout.getvalue()
def test_checker_cache_detects_checker_script_changes(self):
"""Cache entries should become stale when the checker implementation file changes."""
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = self._create_fake_repo(tmpdir)
with patch_checkers_paths(repo_root) as cache_path:
cache = checkers.CheckerCache(path=cache_path)
self.assertFalse(cache.is_current("demo"))
cache.update("demo")
self.assertTrue(cache.is_current("demo"))
(repo_root / "utils" / "fake_checker.py").write_text("# fake checker changed\n", encoding="utf-8")
self.assertFalse(cache.is_current("demo"))
def test_main_skips_cached_runs(self):
"""Main should reuse cached results for repeated runs."""
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = self._create_fake_repo(tmpdir)
with (
patch_checkers_paths(repo_root),
patch.object(
checkers,
"run_checker",
return_value=(0, "first run"),
) as run_checker,
):
exit_code, _ = self._run_main("demo")
self.assertIsNone(exit_code)
self.assertEqual(run_checker.call_count, 1)
exit_code, output = self._run_main("demo")
self.assertIsNone(exit_code)
self.assertEqual(run_checker.call_count, 1)
self.assertIn("(cached)", output)
def test_main_reruns_with_no_cache(self):
"""Main should rerun when `--no-cache` is passed."""
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = self._create_fake_repo(tmpdir)
with (
patch_checkers_paths(repo_root),
patch.object(
checkers,
"run_checker",
side_effect=[(0, "first run"), (0, "forced rerun")],
) as run_checker,
):
exit_code, _ = self._run_main("demo")
self.assertIsNone(exit_code)
self.assertEqual(run_checker.call_count, 1)
exit_code, _ = self._run_main("demo", "--no-cache")
self.assertIsNone(exit_code)
self.assertEqual(run_checker.call_count, 2)
def test_main_prints_full_output_on_failure_without_tty(self):
"""Local non-TTY failures should print the full checker output instead of a cropped tail."""
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = self._create_fake_repo(tmpdir)
output = "\n".join(f"line {i}" for i in range(12)) + "\n"
with (
patch.dict(os.environ, {"GITHUB_ACTIONS": "false", "CIRCLECI": "false"}),
patch_checkers_paths(repo_root),
patch.object(checkers, "run_checker", return_value=(1, output)),
):
exit_code, stdout = self._run_main("demo", "--keep-going")
self.assertEqual(exit_code, 1)
self.assertIn("line 0", stdout)
self.assertIn("line 11", stdout)
def test_main_prints_full_output_on_failure_with_tty(self):
"""TTY failures should print the full checker output without reprinting the cropped window tail."""
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = self._create_fake_repo(tmpdir)
output = "\n".join(f"line {i}" for i in range(12)) + "\n"
class FakeSlidingWindow:
def __init__(self, label, max_lines=10):
self.label = label
self.max_lines = max_lines
def add_line(self, line):
pass
def finish(self, success, elapsed=None, show_lines=True):
print(f"window finished: {self.label} ({success}, {show_lines})")
with (
patch.dict(os.environ, {"GITHUB_ACTIONS": "false", "CIRCLECI": "false"}),
patch_checkers_paths(repo_root),
patch.object(checkers, "run_checker", return_value=(1, output)),
patch.object(checkers, "SlidingWindow", FakeSlidingWindow),
):
exit_code, stdout = self._run_main("demo", "--keep-going", stdout=self._TTYStringIO())
self.assertEqual(exit_code, 1)
self.assertIn("window finished: Demo checker (False, False)", stdout)
self.assertIn("line 0", stdout)
self.assertIn("line 11", stdout)
def test_main_prints_failure_suffix_in_ci(self):
"""CI failures should still print any extra captured output that was not streamed live."""
with tempfile.TemporaryDirectory() as tmpdir:
repo_root = self._create_fake_repo(tmpdir)
streamed_output = "line 0\nline 1\n"
failure_suffix = "summary line\n"
def run_checker(name, fix=False, line_callback=None):
self.assertEqual(name, "demo")
self.assertFalse(fix)
self.assertIsNotNone(line_callback)
for line in streamed_output.splitlines(keepends=True):
line_callback(line)
return 1, streamed_output + failure_suffix
with (
patch.dict(os.environ, {"GITHUB_ACTIONS": "true", "CIRCLECI": "false"}),
patch_checkers_paths(repo_root),
patch.object(checkers, "run_checker", side_effect=run_checker),
):
exit_code, stdout = self._run_main("demo", "--keep-going")
self.assertEqual(exit_code, 1)
self.assertIn("line 0", stdout)
self.assertIn("line 1", stdout)
self.assertIn("summary line", stdout)

View File

@@ -0,0 +1,108 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
import get_test_info # noqa: E402
from get_test_info import ( # noqa: E402
get_model_to_test_mapping,
get_model_to_tester_mapping,
get_test_to_tester_mapping,
)
BERT_TEST_FILE = os.path.join("tests", "models", "bert", "test_modeling_bert.py")
BLIP_TEST_FILE = os.path.join("tests", "models", "blip", "test_modeling_blip.py")
class GetTestInfoTester(unittest.TestCase):
def test_get_test_to_tester_mapping(self):
bert_test_tester_mapping = get_test_to_tester_mapping(BERT_TEST_FILE)
blip_test_tester_mapping = get_test_to_tester_mapping(BLIP_TEST_FILE)
EXPECTED_BERT_MAPPING = {"BertModelTest": "BertModelTester"}
EXPECTED_BLIP_MAPPING = {
"BlipModelTest": "BlipModelTester",
"BlipTextImageModelTest": "BlipTextImageModelsModelTester",
"BlipTextModelTest": "BlipTextModelTester",
"BlipTextRetrievalModelTest": "BlipTextRetrievalModelTester",
"BlipVQAModelTest": "BlipVQAModelTester",
"BlipVisionModelTest": "BlipVisionModelTester",
}
self.assertEqual(get_test_info.to_json(bert_test_tester_mapping), EXPECTED_BERT_MAPPING)
self.assertEqual(get_test_info.to_json(blip_test_tester_mapping), EXPECTED_BLIP_MAPPING)
def test_get_model_to_test_mapping(self):
bert_model_test_mapping = get_model_to_test_mapping(BERT_TEST_FILE)
blip_model_test_mapping = get_model_to_test_mapping(BLIP_TEST_FILE)
EXPECTED_BERT_MAPPING = {
"BertForMaskedLM": ["BertModelTest"],
"BertForMultipleChoice": ["BertModelTest"],
"BertForNextSentencePrediction": ["BertModelTest"],
"BertForPreTraining": ["BertModelTest"],
"BertForQuestionAnswering": ["BertModelTest"],
"BertForSequenceClassification": ["BertModelTest"],
"BertForTokenClassification": ["BertModelTest"],
"BertLMHeadModel": ["BertModelTest"],
"BertModel": ["BertModelTest"],
}
EXPECTED_BLIP_MAPPING = {
"BlipForConditionalGeneration": ["BlipTextImageModelTest"],
"BlipForImageTextRetrieval": ["BlipTextRetrievalModelTest"],
"BlipForQuestionAnswering": ["BlipVQAModelTest"],
"BlipModel": ["BlipModelTest"],
"BlipTextModel": ["BlipTextModelTest"],
"BlipVisionModel": ["BlipVisionModelTest"],
}
self.assertEqual(get_test_info.to_json(bert_model_test_mapping), EXPECTED_BERT_MAPPING)
self.assertEqual(get_test_info.to_json(blip_model_test_mapping), EXPECTED_BLIP_MAPPING)
def test_get_model_to_tester_mapping(self):
bert_model_tester_mapping = get_model_to_tester_mapping(BERT_TEST_FILE)
blip_model_tester_mapping = get_model_to_tester_mapping(BLIP_TEST_FILE)
EXPECTED_BERT_MAPPING = {
"BertForMaskedLM": ["BertModelTester"],
"BertForMultipleChoice": ["BertModelTester"],
"BertForNextSentencePrediction": ["BertModelTester"],
"BertForPreTraining": ["BertModelTester"],
"BertForQuestionAnswering": ["BertModelTester"],
"BertForSequenceClassification": ["BertModelTester"],
"BertForTokenClassification": ["BertModelTester"],
"BertLMHeadModel": ["BertModelTester"],
"BertModel": ["BertModelTester"],
}
EXPECTED_BLIP_MAPPING = {
"BlipForConditionalGeneration": ["BlipTextImageModelsModelTester"],
"BlipForImageTextRetrieval": ["BlipTextRetrievalModelTester"],
"BlipForQuestionAnswering": ["BlipVQAModelTester"],
"BlipModel": ["BlipModelTester"],
"BlipTextModel": ["BlipTextModelTester"],
"BlipVisionModel": ["BlipVisionModelTester"],
}
self.assertEqual(get_test_info.to_json(bert_model_tester_mapping), EXPECTED_BERT_MAPPING)
self.assertEqual(get_test_info.to_json(blip_model_tester_mapping), EXPECTED_BLIP_MAPPING)

View File

@@ -0,0 +1,806 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
import tempfile
import unittest
from contextlib import ExitStack, contextmanager
from pathlib import Path
from unittest.mock import patch
from git import Repo
from transformers.testing_utils import CaptureStdout
REPO_PATH = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(REPO_PATH, "utils"))
import tests_fetcher # noqa: E402
from tests_fetcher import ( # noqa: E402
checkout_commit,
clean_code,
create_reverse_dependency_map,
create_reverse_dependency_tree,
create_test_list_from_filter,
diff_is_docstring_only,
extract_imports,
get_all_tests,
get_diff,
get_module_dependencies,
get_repo_utils_tests,
get_tree_starting_at,
infer_tests_to_run,
init_test_examples_dependencies,
parse_commit_message,
print_tree_deps_of,
should_run_repo_utils_tests,
)
BERT_MODELING_FILE = "src/transformers/models/bert/modeling_bert.py"
BERT_MODEL_FILE = """from ...modeling_utils import PreTrainedModel
from ...utils import is_torch_available
from .configuration_bert import BertConfig
class BertModel:
'''
This is the docstring.
'''
This is the code
"""
BERT_MODEL_FILE_NEW_DOCSTRING = """from ...modeling_utils import PreTrainedModel
from ...utils import is_torch_available
from .configuration_bert import BertConfig
class BertModel:
'''
This is the docstring. It has been updated.
'''
This is the code
"""
BERT_MODEL_FILE_NEW_CODE = """from ...modeling_utils import PreTrainedModel
from ...utils import is_torch_available
from .configuration_bert import BertConfig
class BertModel:
'''
This is the docstring.
'''
This is the code. It has been updated
"""
def create_tmp_repo(tmp_dir, models=None):
"""
Creates a repository in a temporary directory mimicking the structure of Transformers. Uses the list of models
provided (which defaults to just `["bert"]`).
"""
tmp_dir = Path(tmp_dir)
if tmp_dir.exists():
shutil.rmtree(tmp_dir)
tmp_dir.mkdir(exist_ok=True)
repo = Repo.init(tmp_dir)
if models is None:
models = ["bert"]
class_names = [model[0].upper() + model[1:] for model in models]
transformers_dir = tmp_dir / "src" / "transformers"
transformers_dir.mkdir(parents=True, exist_ok=True)
with open(transformers_dir / "__init__.py", "w") as f:
init_lines = ["from .utils import cached_file, is_torch_available"]
init_lines.extend(
[f"from .models.{model} import {cls}Config, {cls}Model" for model, cls in zip(models, class_names)]
)
f.write("\n".join(init_lines) + "\n")
with open(transformers_dir / "configuration_utils.py", "w") as f:
f.write("from .utils import cached_file\n\ncode")
with open(transformers_dir / "modeling_utils.py", "w") as f:
f.write("from .utils import cached_file\n\ncode")
utils_dir = tmp_dir / "src" / "transformers" / "utils"
utils_dir.mkdir(exist_ok=True)
with open(utils_dir / "__init__.py", "w") as f:
f.write("from .hub import cached_file\nfrom .imports import is_torch_available\n")
with open(utils_dir / "hub.py", "w") as f:
f.write("import huggingface_hub\n\ncode")
with open(utils_dir / "imports.py", "w") as f:
f.write("code")
model_dir = tmp_dir / "src" / "transformers" / "models"
model_dir.mkdir(parents=True, exist_ok=True)
with open(model_dir / "__init__.py", "w") as f:
f.write("\n".join([f"import {model}" for model in models]))
for model, cls in zip(models, class_names):
model_dir = tmp_dir / "src" / "transformers" / "models" / model
model_dir.mkdir(parents=True, exist_ok=True)
with open(model_dir / "__init__.py", "w") as f:
f.write(f"from .configuration_{model} import {cls}Config\nfrom .modeling_{model} import {cls}Model\n")
with open(model_dir / f"configuration_{model}.py", "w") as f:
f.write("from ...configuration_utils import PreTrainedConfig\ncode")
with open(model_dir / f"modeling_{model}.py", "w") as f:
modeling_code = BERT_MODEL_FILE.replace("bert", model).replace("Bert", cls)
f.write(modeling_code)
test_dir = tmp_dir / "tests"
test_dir.mkdir(exist_ok=True)
with open(test_dir / "test_modeling_common.py", "w") as f:
f.write("from transformers.modeling_utils import PreTrainedModel\ncode")
for model, cls in zip(models, class_names):
test_model_dir = test_dir / "models" / model
test_model_dir.mkdir(parents=True, exist_ok=True)
(test_model_dir / "__init__.py").touch()
with open(test_model_dir / f"test_modeling_{model}.py", "w") as f:
f.write(
f"from transformers import {cls}Config, {cls}Model\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode"
)
example_dir = tmp_dir / "examples"
example_dir.mkdir(exist_ok=True)
framework_dir = example_dir / "pytorch"
framework_dir.mkdir(exist_ok=True)
with open(framework_dir / "test_pytorch_examples.py", "w") as f:
f.write("""test_args = "run_glue.py"\n""")
glue_dir = framework_dir / "text-classification"
glue_dir.mkdir(exist_ok=True)
with open(glue_dir / "run_glue.py", "w") as f:
f.write("from transformers import BertModel\n\ncode")
repo.index.add(["examples", "src", "tests"])
repo.index.commit("Initial commit")
if "main" not in repo.heads:
repo.create_head("main")
repo.head.reference = repo.refs.main
if "master" in repo.heads:
repo.delete_head("master")
return repo
@contextmanager
def patch_transformer_repo_path(new_folder):
"""
Temporarily patches the variables defines in `tests_fetcher` to use a different location for the repo.
"""
old_repo_path = tests_fetcher.PATH_TO_REPO
tests_fetcher.PATH_TO_REPO = Path(new_folder).resolve()
tests_fetcher.PATH_TO_EXAMPLES = tests_fetcher.PATH_TO_REPO / "examples"
tests_fetcher.PATH_TO_TRANSFORMERS = tests_fetcher.PATH_TO_REPO / "src/transformers"
tests_fetcher.PATH_TO_TESTS = tests_fetcher.PATH_TO_REPO / "tests"
try:
yield
finally:
tests_fetcher.PATH_TO_REPO = old_repo_path
tests_fetcher.PATH_TO_EXAMPLES = tests_fetcher.PATH_TO_REPO / "examples"
tests_fetcher.PATH_TO_TRANSFORMERS = tests_fetcher.PATH_TO_REPO / "src/transformers"
tests_fetcher.PATH_TO_TESTS = tests_fetcher.PATH_TO_REPO / "tests"
def commit_changes(filenames, contents, repo, commit_message="Commit"):
"""
Commit new `contents` to `filenames` inside a given `repo`.
"""
if not isinstance(filenames, list):
filenames = [filenames]
if not isinstance(contents, list):
contents = [contents]
folder = Path(repo.working_dir)
for filename, content in zip(filenames, contents):
with open(folder / filename, "w") as f:
f.write(content)
repo.index.add(filenames)
commit = repo.index.commit(commit_message)
return commit.hexsha
class TestFetcherTester(unittest.TestCase):
def test_checkout_commit(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
repo = create_tmp_repo(tmp_folder)
initial_sha = repo.head.commit.hexsha
new_sha = commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
assert repo.head.commit.hexsha == new_sha
with checkout_commit(repo, initial_sha):
assert repo.head.commit.hexsha == initial_sha
with open(tmp_folder / BERT_MODELING_FILE) as f:
assert f.read() == BERT_MODEL_FILE
assert repo.head.commit.hexsha == new_sha
with open(tmp_folder / BERT_MODELING_FILE) as f:
assert f.read() == BERT_MODEL_FILE_NEW_DOCSTRING
def test_clean_code(self):
# Clean code removes all strings in triple quotes
assert clean_code('"""\nDocstring\n"""\ncode\n"""Long string"""\ncode\n') == "code\ncode"
assert clean_code("'''\nDocstring\n'''\ncode\n'''Long string'''\ncode\n'''") == "code\ncode"
# Clean code removes all comments
assert clean_code("code\n# Comment\ncode") == "code\ncode"
assert clean_code("code # inline comment\ncode") == "code \ncode"
def test_get_all_tests(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
assert get_all_tests() == ["tests/models/bert", "tests/test_modeling_common.py"]
def test_get_all_tests_on_full_repo(self):
all_tests = get_all_tests()
assert "tests/models/albert" in all_tests
assert "tests/models/bert" in all_tests
assert "tests/repo_utils" in all_tests
assert "tests/test_pipeline_mixin.py" in all_tests
assert "tests/models" not in all_tests
assert "tests/__pycache__" not in all_tests
assert "tests/models/albert/test_modeling_albert.py" not in all_tests
assert "tests/repo_utils/test_tests_fetcher.py" not in all_tests
def test_get_repo_utils_tests_on_full_repo(self):
repo_utils_tests = get_repo_utils_tests()
assert "tests/repo_utils/test_tests_fetcher.py" in repo_utils_tests
def test_should_run_repo_utils_tests(self):
assert should_run_repo_utils_tests(["utils/check_modeling_structure.py"])
assert not should_run_repo_utils_tests(["src/transformers/modeling_utils.py"])
def test_create_test_list_from_filter_routes_repo_utils_tests(self):
with tempfile.TemporaryDirectory() as tmp_folder:
create_test_list_from_filter(
[
"tests/models/bert/test_modeling_bert.py",
"tests/repo_utils/test_tests_fetcher.py",
],
out_path=tmp_folder,
)
with open(Path(tmp_folder) / "tests_repo_utils_test_list.txt", encoding="utf-8") as f:
repo_utils_tests = f.read().splitlines()
assert repo_utils_tests == [
"tests/repo_utils/test_tests_fetcher.py",
]
def test_create_test_list_from_filter_does_not_create_hub_job(self):
with tempfile.TemporaryDirectory() as tmp_folder:
create_test_list_from_filter(["tests/models/bert/test_modeling_bert.py"], out_path=tmp_folder)
assert (Path(tmp_folder) / "tests_torch_test_list.txt").exists()
assert not (Path(tmp_folder) / "tests_hub_test_list.txt").exists()
def test_infer_tests_to_run_adds_repo_utils_for_utils_changes(self):
with ExitStack() as stack:
stack.enter_context(patch.object(tests_fetcher, "commit_flags", {"test_all": False}, create=True))
stack.enter_context(
patch.object(
tests_fetcher, "get_modified_python_files", return_value=["utils/check_modeling_structure.py"]
)
)
stack.enter_context(patch.object(tests_fetcher, "create_reverse_dependency_map", return_value={}))
stack.enter_context(
patch.object(tests_fetcher, "get_impacted_files_from_tiny_model_summary", return_value=[])
)
mock_create_test_list = stack.enter_context(patch.object(tests_fetcher, "create_test_list_from_filter"))
stack.enter_context(patch.object(tests_fetcher, "get_doctest_files", return_value=[]))
infer_tests_to_run("unused.txt", diff_with_last_commit=True)
test_files_to_run = mock_create_test_list.call_args.args[0]
assert "tests/repo_utils/test_tests_fetcher.py" in test_files_to_run
def test_diff_is_docstring_only(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
repo = create_tmp_repo(tmp_folder)
branching_point = repo.refs.main.commit
bert_file = BERT_MODELING_FILE
commit_changes(bert_file, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
assert diff_is_docstring_only(repo, branching_point, bert_file)
commit_changes(bert_file, BERT_MODEL_FILE_NEW_CODE, repo)
assert not diff_is_docstring_only(repo, branching_point, bert_file)
def test_get_diff_ignores_docstring_only_changes(self):
"""Files whose diff is only in docstrings/comments should be excluded from get_diff results."""
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
repo = create_tmp_repo(tmp_folder)
branching_commit = repo.head.commit
# Docstring-only change: should NOT appear in diff
commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
diff = get_diff(repo, repo.head.commit, [branching_commit])
assert BERT_MODELING_FILE not in diff
# Real code change: should appear in diff
commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_CODE, repo)
diff = get_diff(repo, repo.head.commit, [branching_commit])
assert BERT_MODELING_FILE in diff
def test_extract_imports_relative(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
expected_bert_imports = [
("src/transformers/modeling_utils.py", ["PreTrainedModel"]),
("src/transformers/utils/__init__.py", ["is_torch_available"]),
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
]
expected_utils_imports = [
("src/transformers/utils/hub.py", ["cached_file"]),
("src/transformers/utils/imports.py", ["is_torch_available"]),
]
with patch_transformer_repo_path(tmp_folder):
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
assert extract_imports("src/transformers/utils/__init__.py") == expected_utils_imports
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
f.write(
"from ...utils import cached_file, is_torch_available\nfrom .configuration_bert import BertConfig\n"
)
expected_bert_imports = [
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
]
with patch_transformer_repo_path(tmp_folder):
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
# Test with multi-line imports
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
f.write(
"from ...utils import (\n cached_file,\n is_torch_available\n)\nfrom .configuration_bert import BertConfig\n"
)
expected_bert_imports = [
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
]
with patch_transformer_repo_path(tmp_folder):
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
def test_extract_imports_absolute(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
f.write(
"from transformers.utils import cached_file, is_torch_available\nfrom transformers.models.bert.configuration_bert import BertConfig\n"
)
expected_bert_imports = [
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
]
with patch_transformer_repo_path(tmp_folder):
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
# Test with multi-line imports
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
f.write(
"from transformers.utils import (\n cached_file,\n is_torch_available\n)\nfrom transformers.models.bert.configuration_bert import BertConfig\n"
)
expected_bert_imports = [
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
]
with patch_transformer_repo_path(tmp_folder):
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
# Test with base imports
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
f.write(
"from transformers.utils import (\n cached_file,\n is_torch_available\n)\nfrom transformers import BertConfig\n"
)
expected_bert_imports = [
("src/transformers/__init__.py", ["BertConfig"]),
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
]
with patch_transformer_repo_path(tmp_folder):
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
def test_get_module_dependencies(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
expected_bert_dependencies = [
"src/transformers/modeling_utils.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/utils/imports.py",
]
with patch_transformer_repo_path(tmp_folder):
assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
expected_test_bert_dependencies = [
"tests/test_modeling_common.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
]
with patch_transformer_repo_path(tmp_folder):
assert (
get_module_dependencies("tests/models/bert/test_modeling_bert.py")
== expected_test_bert_dependencies
)
# Test with a submodule
(tmp_folder / "src/transformers/utils/logging.py").touch()
with open(tmp_folder / BERT_MODELING_FILE, "a") as f:
f.write("from ...utils import logging\n")
expected_bert_dependencies = [
"src/transformers/modeling_utils.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/utils/logging.py",
"src/transformers/utils/imports.py",
]
with patch_transformer_repo_path(tmp_folder):
assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
# Test with an object non-imported in the init
create_tmp_repo(tmp_folder)
with open(tmp_folder / BERT_MODELING_FILE, "a") as f:
f.write("from ...utils import CONSTANT\n")
expected_bert_dependencies = [
"src/transformers/modeling_utils.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/utils/__init__.py",
"src/transformers/utils/imports.py",
]
with patch_transformer_repo_path(tmp_folder):
assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
# Test with an example
create_tmp_repo(tmp_folder)
expected_example_dependencies = ["src/transformers/models/bert/modeling_bert.py"]
with patch_transformer_repo_path(tmp_folder):
assert (
get_module_dependencies("examples/pytorch/text-classification/run_glue.py")
== expected_example_dependencies
)
def test_create_reverse_dependency_tree(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
tree = create_reverse_dependency_tree()
init_edges = [
"src/transformers/utils/hub.py",
"src/transformers/utils/imports.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
]
assert {f for f, g in tree if g == "src/transformers/__init__.py"} == set(init_edges)
bert_edges = [
"src/transformers/modeling_utils.py",
"src/transformers/utils/imports.py",
"src/transformers/models/bert/configuration_bert.py",
]
assert {f for f, g in tree if g == "src/transformers/models/bert/modeling_bert.py"} == set(bert_edges)
test_bert_edges = [
"tests/test_modeling_common.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
]
assert {f for f, g in tree if g == "tests/models/bert/test_modeling_bert.py"} == set(test_bert_edges)
def test_get_tree_starting_at(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
edges = create_reverse_dependency_tree()
bert_tree = get_tree_starting_at("src/transformers/models/bert/modeling_bert.py", edges)
config_utils_tree = get_tree_starting_at("src/transformers/configuration_utils.py", edges)
expected_bert_tree = [
"src/transformers/models/bert/modeling_bert.py",
[("src/transformers/models/bert/modeling_bert.py", "tests/models/bert/test_modeling_bert.py")],
]
assert bert_tree == expected_bert_tree
expected_config_tree = [
"src/transformers/configuration_utils.py",
[("src/transformers/configuration_utils.py", "src/transformers/models/bert/configuration_bert.py")],
[
("src/transformers/models/bert/configuration_bert.py", "tests/models/bert/test_modeling_bert.py"),
(
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
),
],
]
# Order of the edges is random
assert [set(v) for v in config_utils_tree] == [set(v) for v in expected_config_tree]
def test_print_tree_deps_of(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
# There are two possible outputs since the order of the last two lines is non-deterministic.
expected_std_out = """src/transformers/models/bert/modeling_bert.py
tests/models/bert/test_modeling_bert.py
src/transformers/configuration_utils.py
src/transformers/models/bert/configuration_bert.py
src/transformers/models/bert/modeling_bert.py
tests/models/bert/test_modeling_bert.py"""
expected_std_out_2 = """src/transformers/models/bert/modeling_bert.py
tests/models/bert/test_modeling_bert.py
src/transformers/configuration_utils.py
src/transformers/models/bert/configuration_bert.py
tests/models/bert/test_modeling_bert.py
src/transformers/models/bert/modeling_bert.py"""
with patch_transformer_repo_path(tmp_folder), CaptureStdout() as cs:
print_tree_deps_of("src/transformers/models/bert/modeling_bert.py")
print_tree_deps_of("src/transformers/configuration_utils.py")
assert cs.out.strip() in [expected_std_out, expected_std_out_2]
def test_init_test_examples_dependencies(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder).resolve()
create_tmp_repo(tmp_folder)
expected_example_deps = {
"examples/pytorch/test_pytorch_examples.py": [
"examples/pytorch/text-classification/run_glue.py",
"examples/pytorch/test_pytorch_examples.py",
],
}
expected_examples = {
"examples/pytorch/test_pytorch_examples.py",
"examples/pytorch/text-classification/run_glue.py",
}
with patch_transformer_repo_path(tmp_folder):
example_deps, all_examples = init_test_examples_dependencies()
assert example_deps == expected_example_deps
assert {str(f.relative_to(tmp_folder)) for f in all_examples} == expected_examples
def test_create_reverse_dependency_map(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
create_tmp_repo(tmp_folder)
with patch_transformer_repo_path(tmp_folder):
reverse_map = create_reverse_dependency_map()
# impact of BERT modeling file (note that we stop at the inits and don't go down further)
expected_bert_deps = {
"src/transformers/__init__.py",
"src/transformers/models/bert/__init__.py",
"tests/models/bert/test_modeling_bert.py",
"examples/pytorch/test_pytorch_examples.py",
"examples/pytorch/text-classification/run_glue.py",
}
assert set(reverse_map["src/transformers/models/bert/modeling_bert.py"]) == expected_bert_deps
# init gets the direct deps (and their recursive deps)
expected_init_deps = {
"src/transformers/utils/__init__.py",
"src/transformers/utils/hub.py",
"src/transformers/utils/imports.py",
"src/transformers/models/bert/__init__.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
"src/transformers/configuration_utils.py",
"src/transformers/modeling_utils.py",
"tests/test_modeling_common.py",
"tests/models/bert/test_modeling_bert.py",
"examples/pytorch/test_pytorch_examples.py",
"examples/pytorch/text-classification/run_glue.py",
}
assert set(reverse_map["src/transformers/__init__.py"]) == expected_init_deps
expected_init_deps = {
"src/transformers/__init__.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
"tests/models/bert/test_modeling_bert.py",
"examples/pytorch/test_pytorch_examples.py",
"examples/pytorch/text-classification/run_glue.py",
}
assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps
# Test that with more models init of bert only gets deps to bert.
create_tmp_repo(tmp_folder, models=["bert", "gpt2"])
with patch_transformer_repo_path(tmp_folder):
reverse_map = create_reverse_dependency_map()
# init gets the direct deps (and their recursive deps)
expected_init_deps = {
"src/transformers/__init__.py",
"src/transformers/models/bert/configuration_bert.py",
"src/transformers/models/bert/modeling_bert.py",
"tests/models/bert/test_modeling_bert.py",
"examples/pytorch/test_pytorch_examples.py",
"examples/pytorch/text-classification/run_glue.py",
}
assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps
@unittest.skip("Broken for now TODO @ArthurZucker")
def test_infer_tests_to_run(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
repo = create_tmp_repo(tmp_folder, models=models)
commit_changes("src/transformers/models/bert/modeling_bert.py", BERT_MODEL_FILE_NEW_CODE, repo)
example_tests = {
"examples/pytorch/test_pytorch_examples.py",
}
with patch_transformer_repo_path(tmp_folder):
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
with open(tmp_folder / "test-output.txt") as f:
tests_to_run = f.read()
with open(tmp_folder / "examples_test_list.txt") as f:
example_tests_to_run = f.read()
assert tests_to_run == "tests/models/bert/test_modeling_bert.py"
assert set(example_tests_to_run.split(" ")) == example_tests
# Fake a new model addition
repo = create_tmp_repo(tmp_folder, models=models)
branch = repo.create_head("new_model")
branch.checkout()
with open(tmp_folder / "src/transformers/__init__.py", "a") as f:
f.write("from .models.t5 import T5Config, T5Model\n")
model_dir = tmp_folder / "src/transformers/models/t5"
model_dir.mkdir(exist_ok=True)
with open(model_dir / "__init__.py", "w") as f:
f.write("from .configuration_t5 import T5Config\nfrom .modeling_t5 import T5Model\n")
with open(model_dir / "configuration_t5.py", "w") as f:
f.write("from ...configuration_utils import PreTrainedConfig\ncode")
with open(model_dir / "modeling_t5.py", "w") as f:
modeling_code = BERT_MODEL_FILE.replace("bert", "t5").replace("Bert", "T5")
f.write(modeling_code)
test_dir = tmp_folder / "tests/models/t5"
test_dir.mkdir(exist_ok=True)
(test_dir / "__init__.py").touch()
with open(test_dir / "test_modeling_t5.py", "w") as f:
f.write(
"from transformers import T5Config, T5Model\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode"
)
repo.index.add(["src", "tests"])
repo.index.commit("Add T5 model")
with patch_transformer_repo_path(tmp_folder):
infer_tests_to_run(tmp_folder / "test-output.txt")
with open(tmp_folder / "test-output.txt") as f:
tests_to_run = f.read()
with open(tmp_folder / "examples_test_list.txt") as f:
example_tests_to_run = f.read()
expected_tests = {
"tests/models/bert/test_modeling_bert.py",
"tests/models/gpt2/test_modeling_gpt2.py",
"tests/models/t5/test_modeling_t5.py",
"tests/test_modeling_common.py",
}
assert set(tests_to_run.split(" ")) == expected_tests
assert set(example_tests_to_run.split(" ")) == example_tests
with patch_transformer_repo_path(tmp_folder):
infer_tests_to_run(tmp_folder / "test-output.txt")
with open(tmp_folder / "test-output.txt") as f:
tests_to_run = f.read()
with open(tmp_folder / "examples_test_list.txt") as f:
example_tests_to_run = f.read()
expected_tests = [f"tests/models/{name}/test_modeling_{name}.py" for name in models + ["t5"]]
expected_tests = set(expected_tests + ["tests/test_modeling_common.py"])
assert set(tests_to_run.split(" ")) == expected_tests
assert set(example_tests_to_run.split(" ")) == example_tests
@unittest.skip("Broken for now TODO @ArthurZucker")
def test_infer_tests_to_run_with_test_modifs(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
repo = create_tmp_repo(tmp_folder, models=models)
commit_changes(
"tests/models/bert/test_modeling_bert.py",
"from transformers import BertConfig, BertModel\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode1",
repo,
)
with patch_transformer_repo_path(tmp_folder):
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
with open(tmp_folder / "test-output.txt") as f:
tests_to_run = f.read()
assert tests_to_run == "tests/models/bert/test_modeling_bert.py"
@unittest.skip("Broken for now TODO @ArthurZucker")
def test_infer_tests_to_run_with_examples_modifs(self):
with tempfile.TemporaryDirectory() as tmp_folder:
tmp_folder = Path(tmp_folder)
models = ["bert", "gpt2"]
repo = create_tmp_repo(tmp_folder, models=models)
# Modification in one example trigger the corresponding test
commit_changes(
"examples/pytorch/text-classification/run_glue.py",
"from transformers import BertModeln\n\ncode1",
repo,
)
with patch_transformer_repo_path(tmp_folder):
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
with open(tmp_folder / "examples_test_list.txt") as f:
example_tests_to_run = f.read()
assert example_tests_to_run == "examples/pytorch/test_pytorch_examples.py"
# Modification in one test example file trigger that test
repo = create_tmp_repo(tmp_folder, models=models)
commit_changes(
"examples/pytorch/test_pytorch_examples.py",
"""test_args = "run_glue.py"\nmore_code""",
repo,
)
with patch_transformer_repo_path(tmp_folder):
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
with open(tmp_folder / "examples_test_list.txt") as f:
example_tests_to_run = f.read()
assert example_tests_to_run == "examples/pytorch/test_pytorch_examples.py"
def test_parse_commit_message(self):
assert parse_commit_message("Normal commit") == {"skip": False, "no_filter": False, "test_all": False}
assert parse_commit_message("[skip ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
assert parse_commit_message("[ci skip] commit") == {"skip": True, "no_filter": False, "test_all": False}
assert parse_commit_message("[skip-ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
assert parse_commit_message("[skip_ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
assert parse_commit_message("[no filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
assert parse_commit_message("[no-filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
assert parse_commit_message("[no_filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
assert parse_commit_message("[filter-no] commit") == {"skip": False, "no_filter": True, "test_all": False}
assert parse_commit_message("[test all] commit") == {"skip": False, "no_filter": False, "test_all": True}
assert parse_commit_message("[all test] commit") == {"skip": False, "no_filter": False, "test_all": True}
assert parse_commit_message("[test-all] commit") == {"skip": False, "no_filter": False, "test_all": True}
assert parse_commit_message("[all_test] commit") == {"skip": False, "no_filter": False, "test_all": True}