first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
75
tests/repo_utils/modular/test_conversion_order.py
Normal file
75
tests/repo_utils/modular/test_conversion_order.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(ROOT_DIR, "utils"))
|
||||
|
||||
import create_dependency_mapping # noqa: E402
|
||||
|
||||
|
||||
# This is equivalent to `all` in the current library state (as of 09/01/2025)
|
||||
MODEL_ROOT = os.path.join("src", "transformers", "models")
|
||||
FILES_TO_PARSE = [
|
||||
os.path.join(MODEL_ROOT, "starcoder2", "modular_starcoder2.py"),
|
||||
os.path.join(MODEL_ROOT, "gemma", "modular_gemma.py"),
|
||||
os.path.join(MODEL_ROOT, "olmo2", "modular_olmo2.py"),
|
||||
os.path.join(MODEL_ROOT, "diffllama", "modular_diffllama.py"),
|
||||
os.path.join(MODEL_ROOT, "granite", "modular_granite.py"),
|
||||
os.path.join(MODEL_ROOT, "gemma2", "modular_gemma2.py"),
|
||||
os.path.join(MODEL_ROOT, "mixtral", "modular_mixtral.py"),
|
||||
os.path.join(MODEL_ROOT, "olmo", "modular_olmo.py"),
|
||||
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3.py"),
|
||||
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
|
||||
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
|
||||
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),
|
||||
os.path.join(MODEL_ROOT, "colpali", "modular_colpali.py"),
|
||||
os.path.join(MODEL_ROOT, "deformable_detr", "modular_deformable_detr.py"),
|
||||
os.path.join(MODEL_ROOT, "aria", "modular_aria.py"),
|
||||
os.path.join(MODEL_ROOT, "ijepa", "modular_ijepa.py"),
|
||||
os.path.join(MODEL_ROOT, "bamba", "modular_bamba.py"),
|
||||
os.path.join(MODEL_ROOT, "dinov2_with_registers", "modular_dinov2_with_registers.py"),
|
||||
os.path.join(MODEL_ROOT, "instructblipvideo", "modular_instructblipvideo.py"),
|
||||
os.path.join(MODEL_ROOT, "glm", "modular_glm.py"),
|
||||
os.path.join(MODEL_ROOT, "phi", "modular_phi.py"),
|
||||
os.path.join(MODEL_ROOT, "mistral", "modular_mistral.py"),
|
||||
os.path.join(MODEL_ROOT, "phi3", "modular_phi3.py"),
|
||||
os.path.join(MODEL_ROOT, "cohere", "modular_cohere.py"),
|
||||
os.path.join(MODEL_ROOT, "glm4", "modular_glm4.py"),
|
||||
os.path.join(MODEL_ROOT, "seed_oss", "modular_seed_oss.py"),
|
||||
]
|
||||
|
||||
|
||||
def appear_after(model1: str, model2: str, priority_list: list[list[str]]) -> bool:
|
||||
"""Return True if `model1` appear after `model2` in `priority_list`."""
|
||||
model1_index, model2_index = None, None
|
||||
for i, level in enumerate(priority_list):
|
||||
if model1 in level:
|
||||
model1_index = i
|
||||
if model2 in level:
|
||||
model2_index = i
|
||||
if model1_index is None or model2_index is None:
|
||||
raise ValueError(f"Model {model1} or {model2} not found in {priority_list}")
|
||||
return model1_index > model2_index
|
||||
|
||||
|
||||
class ConversionOrderTest(unittest.TestCase):
|
||||
def test_conversion_order(self):
|
||||
# Find the order
|
||||
priority_list, _ = create_dependency_mapping.find_priority_list(FILES_TO_PARSE)
|
||||
# Extract just the model names (list of lists)
|
||||
model_priority_list = [[file.split("/")[-2] for file in level] for level in priority_list]
|
||||
|
||||
# These are based on what the current library order should be (as of 09/01/2025)
|
||||
self.assertTrue(appear_after("mixtral", "mistral", model_priority_list))
|
||||
self.assertTrue(appear_after("gemma2", "gemma", model_priority_list))
|
||||
self.assertTrue(appear_after("starcoder2", "mistral", model_priority_list))
|
||||
self.assertTrue(appear_after("olmo2", "olmo", model_priority_list))
|
||||
self.assertTrue(appear_after("diffllama", "mistral", model_priority_list))
|
||||
self.assertTrue(appear_after("cohere2", "gemma2", model_priority_list))
|
||||
self.assertTrue(appear_after("cohere2", "cohere", model_priority_list))
|
||||
self.assertTrue(appear_after("phi3", "mistral", model_priority_list))
|
||||
self.assertTrue(appear_after("glm4", "glm", model_priority_list))
|
||||
129
tests/repo_utils/test_check_auto.py
Normal file
129
tests/repo_utils/test_check_auto.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
utils_path = os.path.join(git_repo_path, "utils")
|
||||
if utils_path not in sys.path:
|
||||
sys.path.append(utils_path)
|
||||
|
||||
import check_auto # noqa: E402
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cwd(path: Path):
|
||||
old = os.getcwd()
|
||||
os.chdir(path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.chdir(old)
|
||||
|
||||
|
||||
def _write_config(root: Path, module_name: str, classes: list[tuple[str, str]]) -> None:
|
||||
"""Create `src/transformers/models/<module>/configuration_<module>.py` with the given classes.
|
||||
|
||||
`classes` is a list of (class_name, model_type) pairs, all subclassing PreTrainedConfig.
|
||||
"""
|
||||
module_dir = root / "src" / "transformers" / "models" / module_name
|
||||
module_dir.mkdir(parents=True, exist_ok=True)
|
||||
body = "from transformers import PreTrainedConfig\n\n"
|
||||
for cls_name, model_type in classes:
|
||||
body += textwrap.dedent(
|
||||
f'''
|
||||
class {cls_name}(PreTrainedConfig):
|
||||
model_type = "{model_type}"
|
||||
'''
|
||||
)
|
||||
(module_dir / f"configuration_{module_name}.py").write_text(body, encoding="utf-8")
|
||||
|
||||
|
||||
class BuildConfigMappingNamesTest(unittest.TestCase):
|
||||
"""Tests for the natural-match tie-break in `check_auto.build_config_mapping_names`.
|
||||
|
||||
A natural match is one where a config's `model_type` equals its module directory name
|
||||
(e.g. `DetrConfig` with `model_type = "detr"` inside `models/detr/`). When two classes
|
||||
share a `model_type`, the natural one must always win regardless of filesystem ordering.
|
||||
"""
|
||||
|
||||
def test_single_natural_match(self):
|
||||
"""Baseline: one config in its eponymous module → no special mapping."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp)
|
||||
_write_config(root, "detr", [("DetrConfig", "detr")])
|
||||
with cwd(root):
|
||||
model_type_map, special_mappings = check_auto.build_config_mapping_names()
|
||||
|
||||
self.assertEqual(model_type_map, {"detr": "DetrConfig"})
|
||||
self.assertEqual(special_mappings, {})
|
||||
|
||||
def test_natural_wins_when_encountered_first(self):
|
||||
"""detr (natural) is alphabetically before maskformer (non-natural for model_type=detr).
|
||||
|
||||
This is the order modern Linux filesystems produce. The natural match must be kept
|
||||
and the alias must not appear in the canonical mapping.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp)
|
||||
_write_config(root, "detr", [("DetrConfig", "detr")])
|
||||
_write_config(
|
||||
root,
|
||||
"maskformer",
|
||||
[("MaskFormerConfig", "maskformer"), ("MaskFormerDetrConfig", "detr")],
|
||||
)
|
||||
with cwd(root):
|
||||
model_type_map, special_mappings = check_auto.build_config_mapping_names()
|
||||
|
||||
self.assertEqual(model_type_map["detr"], "DetrConfig")
|
||||
self.assertEqual(model_type_map["maskformer"], "MaskFormerConfig")
|
||||
self.assertNotIn("detr", special_mappings)
|
||||
|
||||
def test_natural_wins_when_encountered_second(self):
|
||||
"""The non-natural alias is alphabetically *before* the natural module.
|
||||
|
||||
Without the prefer-natural logic this is the case that breaks: the alias would be
|
||||
recorded first and then never overwritten. The fix must still pick the natural class
|
||||
and clear the now-stale special mapping.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp)
|
||||
# `aaa_alias` sorts before `foo`, so its non-natural class is processed first.
|
||||
_write_config(root, "aaa_alias", [("AaaConfig", "aaa_alias"), ("FooAliasConfig", "foo")])
|
||||
_write_config(root, "foo", [("FooConfig", "foo")])
|
||||
with cwd(root):
|
||||
model_type_map, special_mappings = check_auto.build_config_mapping_names()
|
||||
|
||||
self.assertEqual(model_type_map["foo"], "FooConfig")
|
||||
self.assertNotIn("foo", special_mappings, "stale alias entry must be cleared")
|
||||
# The alias module's own natural entry is still recorded.
|
||||
self.assertEqual(model_type_map["aaa_alias"], "AaaConfig")
|
||||
|
||||
def test_non_natural_only_records_special_mapping(self):
|
||||
"""If a model_type has no natural match, the alias is the canonical entry."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
root = Path(tmp)
|
||||
_write_config(root, "wrapper", [("WrapperConfig", "wrapper"), ("InnerConfig", "inner")])
|
||||
with cwd(root):
|
||||
model_type_map, special_mappings = check_auto.build_config_mapping_names()
|
||||
|
||||
self.assertEqual(model_type_map["inner"], "InnerConfig")
|
||||
self.assertEqual(special_mappings["inner"], "wrapper")
|
||||
450
tests/repo_utils/test_check_copies.py
Normal file
450
tests/repo_utils/test_check_copies.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
|
||||
import check_copies # noqa: E402
|
||||
from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent # noqa: E402
|
||||
|
||||
|
||||
# This is the reference code that will be used in the tests.
|
||||
# If BertLMPredictionHead is changed in modeling_bert.py, this code needs to be manually updated.
|
||||
REFERENCE_CODE = """ def __init__(self, config):
|
||||
super().__init__()
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
"""
|
||||
|
||||
MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel
|
||||
|
||||
def bert_function(x):
|
||||
return x
|
||||
|
||||
|
||||
class BertAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class BertModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.bert = BertEncoder(config)
|
||||
|
||||
@add_docstring(BERT_DOCSTRING)
|
||||
def forward(self, x):
|
||||
return self.bert(x)
|
||||
"""
|
||||
|
||||
MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.bert_function
|
||||
def bert_copy_function(x):
|
||||
return x
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention
|
||||
class BertCopyAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
|
||||
class BertCopyModel(BertCopyPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.bertcopy = BertCopyEncoder(config)
|
||||
|
||||
@add_docstring(BERTCOPY_DOCSTRING)
|
||||
def forward(self, x):
|
||||
return self.bertcopy(x)
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_BERT_CODE_MATCH = """
|
||||
class BertDummyModel:
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 9
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_ROBERTA_CODE_MATCH = """
|
||||
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||
class RobertaBertDummyModel:
|
||||
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Ignore copy
|
||||
def only_in_roberta_to_be_ignored(self, c):
|
||||
return 3
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
# Ignore copy
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 6
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_BERT_CODE_NO_MATCH = """
|
||||
class BertDummyModel:
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def only_in_bert(self, c):
|
||||
return 7
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_not_ignored(self, c):
|
||||
return 8
|
||||
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 9
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
|
||||
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||
class RobertaBertDummyModel:
|
||||
|
||||
attr_1 = 1
|
||||
attr_2 = 3
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Ignore copy
|
||||
def only_in_roberta_to_be_ignored(self, c):
|
||||
return 3
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def only_in_roberta_not_ignored(self, c):
|
||||
return 2
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_not_ignored(self, c):
|
||||
return 5
|
||||
|
||||
# Ignore copy
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 6
|
||||
"""
|
||||
|
||||
|
||||
EXPECTED_REPLACED_CODE = """
|
||||
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||
class RobertaBertDummyModel:
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def only_in_bert(self, c):
|
||||
return 7
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_not_ignored(self, c):
|
||||
return 8
|
||||
|
||||
# Ignore copy
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 6
|
||||
|
||||
# Ignore copy
|
||||
def only_in_roberta_to_be_ignored(self, c):
|
||||
return 3
|
||||
"""
|
||||
|
||||
|
||||
def replace_in_file(filename, old, new):
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
content = content.replace(old, new)
|
||||
|
||||
with open(filename, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def create_tmp_repo(tmp_dir):
|
||||
"""
|
||||
Creates a mock repository in a temporary folder for testing.
|
||||
"""
|
||||
tmp_dir = Path(tmp_dir)
|
||||
if tmp_dir.exists():
|
||||
shutil.rmtree(tmp_dir)
|
||||
tmp_dir.mkdir(exist_ok=True)
|
||||
|
||||
model_dir = tmp_dir / "src" / "transformers" / "models"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
models = {
|
||||
"bert": MOCK_BERT_CODE,
|
||||
"bertcopy": MOCK_BERT_COPY_CODE,
|
||||
"dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
|
||||
"dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
|
||||
"dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
|
||||
"dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
|
||||
}
|
||||
for model, code in models.items():
|
||||
model_subdir = model_dir / model
|
||||
model_subdir.mkdir(exist_ok=True)
|
||||
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(code)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_transformer_repo_path(new_folder):
|
||||
"""
|
||||
Temporarily patches the variables defines in `check_copies` to use a different location for the repo.
|
||||
"""
|
||||
old_repo_path = check_copies.REPO_PATH
|
||||
old_doc_path = check_copies.PATH_TO_DOCS
|
||||
old_transformer_path = check_copies.TRANSFORMERS_PATH
|
||||
repo_path = Path(new_folder).resolve()
|
||||
check_copies.REPO_PATH = str(repo_path)
|
||||
check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
|
||||
check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
check_copies.REPO_PATH = old_repo_path
|
||||
check_copies.PATH_TO_DOCS = old_doc_path
|
||||
check_copies.TRANSFORMERS_PATH = old_transformer_path
|
||||
|
||||
|
||||
class CopyCheckTester(unittest.TestCase):
|
||||
def test_find_code_in_transformers(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")
|
||||
|
||||
reference_code = (
|
||||
"class BertAttention(nn.Module):\n def __init__(self, config):\n super().__init__()\n"
|
||||
)
|
||||
self.assertEqual(code, reference_code)
|
||||
|
||||
def test_is_copy_consistent(self):
|
||||
path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
# Base check
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
file_to_check = os.path.join(tmp_folder, *path_to_check)
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [])
|
||||
|
||||
# Base check with an inconsistency
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
file_to_check = os.path.join(tmp_folder, *path_to_check)
|
||||
|
||||
replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
|
||||
|
||||
_ = is_copy_consistent(file_to_check, overwrite=True)
|
||||
|
||||
with open(file_to_check, encoding="utf-8") as f:
|
||||
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
|
||||
|
||||
def test_is_copy_consistent_with_ignored_match(self):
|
||||
path_to_check = ["src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py"]
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
# Base check
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
file_to_check = os.path.join(tmp_folder, *path_to_check)
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [])
|
||||
|
||||
def test_is_copy_consistent_with_ignored_no_match(self):
|
||||
path_to_check = [
|
||||
"src",
|
||||
"transformers",
|
||||
"models",
|
||||
"dummy_roberta_no_match",
|
||||
"modeling_dummy_roberta_no_match.py",
|
||||
]
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
# Base check with an inconsistency
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
file_to_check = os.path.join(tmp_folder, *path_to_check)
|
||||
|
||||
diffs = is_copy_consistent(file_to_check)
|
||||
# line 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`.
|
||||
# (which has a leading `\n`.)
|
||||
self.assertEqual(
|
||||
diffs, [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]
|
||||
)
|
||||
|
||||
_ = is_copy_consistent(file_to_check, overwrite=True)
|
||||
|
||||
with open(file_to_check, encoding="utf-8") as f:
|
||||
self.assertEqual(f.read(), EXPECTED_REPLACED_CODE)
|
||||
|
||||
def test_convert_to_localized_md(self):
|
||||
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
|
||||
|
||||
md_list = (
|
||||
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
|
||||
" Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
|
||||
" Self-supervised Learning of Language Representations](https://huggingface.co/papers/1909.11942), by Zhenzhong"
|
||||
" Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1."
|
||||
" **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace),"
|
||||
" released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
|
||||
" lighter](https://huggingface.co/papers/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same"
|
||||
" method has been applied to compress GPT2 into"
|
||||
" [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
|
||||
" [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
|
||||
" Multilingual BERT into"
|
||||
" [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
|
||||
" version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)**"
|
||||
" (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders"
|
||||
" as discriminators rather than generators](https://huggingface.co/papers/2003.10555) by Kevin Clark, Minh-Thang"
|
||||
" Luong, Quoc V. Le, Christopher D. Manning."
|
||||
)
|
||||
localized_md_list = (
|
||||
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
|
||||
" Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
|
||||
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
|
||||
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
|
||||
)
|
||||
converted_md_list_sample = (
|
||||
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
|
||||
" Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
|
||||
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
|
||||
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n1."
|
||||
" **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (来自 HuggingFace) 伴随论文"
|
||||
" [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
|
||||
" lighter](https://huggingface.co/papers/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 The same"
|
||||
" method has been applied to compress GPT2 into"
|
||||
" [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
|
||||
" [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
|
||||
" Multilingual BERT into"
|
||||
" [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
|
||||
" version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (来自"
|
||||
" Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather"
|
||||
" than generators](https://huggingface.co/papers/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le,"
|
||||
" Christopher D. Manning 发布。\n"
|
||||
)
|
||||
|
||||
num_models_equal, converted_md_list = convert_to_localized_md(
|
||||
md_list, localized_md_list, localized_readme["format_model_list"]
|
||||
)
|
||||
|
||||
self.assertFalse(num_models_equal)
|
||||
self.assertEqual(converted_md_list, converted_md_list_sample)
|
||||
|
||||
num_models_equal, converted_md_list = convert_to_localized_md(
|
||||
md_list, converted_md_list, localized_readme["format_model_list"]
|
||||
)
|
||||
|
||||
# Check whether the number of models is equal to README.md after conversion.
|
||||
self.assertTrue(num_models_equal)
|
||||
|
||||
link_changed_md_list = (
|
||||
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
|
||||
" Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
|
||||
" Self-supervised Learning of Language Representations](https://huggingface.co/papers/1909.11942), by Zhenzhong"
|
||||
" Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
|
||||
)
|
||||
link_unchanged_md_list = (
|
||||
"1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (来自 Google Research and"
|
||||
" the Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
|
||||
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
|
||||
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
|
||||
)
|
||||
converted_md_list_sample = (
|
||||
"1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
|
||||
" Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
|
||||
" Language Representations](https://huggingface.co/papers/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
|
||||
" Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
|
||||
)
|
||||
|
||||
num_models_equal, converted_md_list = convert_to_localized_md(
|
||||
link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
|
||||
)
|
||||
|
||||
# Check if the model link is synchronized.
|
||||
self.assertEqual(converted_md_list, converted_md_list_sample)
|
||||
347
tests/repo_utils/test_check_docstrings.py
Normal file
347
tests/repo_utils/test_check_docstrings.py
Normal file
@@ -0,0 +1,347 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import unittest
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
|
||||
from check_docstrings import ( # noqa: E402
|
||||
_build_ast_indexes,
|
||||
_find_typed_dict_classes,
|
||||
_get_auto_docstring_names,
|
||||
get_default_description,
|
||||
has_auto_docstring_decorator,
|
||||
replace_default_in_arg_description,
|
||||
)
|
||||
|
||||
|
||||
class CheckDostringsTested(unittest.TestCase):
|
||||
def test_replace_default_in_arg_description(self):
|
||||
# Standard docstring with default.
|
||||
desc_with_default = "`float`, *optional*, defaults to 2.0"
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default, 2.0), "`float`, *optional*, defaults to 2.0"
|
||||
)
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default, 1.0), "`float`, *optional*, defaults to 1.0"
|
||||
)
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_default, inspect._empty), "`float`")
|
||||
|
||||
# Standard docstring with default but optional is not using the stars.
|
||||
desc_with_default_typo = "`float`, `optional`, defaults to 2.0"
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default_typo, 2.0), "`float`, *optional*, defaults to 2.0"
|
||||
)
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default_typo, 1.0), "`float`, *optional*, defaults to 1.0"
|
||||
)
|
||||
|
||||
# If the default is None we do not erase the value in the docstring.
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default, None), "`float`, *optional*, defaults to 2.0"
|
||||
)
|
||||
# If the default is None (and set as such in the docstring), we do not include it.
|
||||
desc_with_default = "`float`, *optional*, defaults to None"
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_default, None), "`float`, *optional*")
|
||||
desc_with_default = "`float`, *optional*, defaults to `None`"
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_default, None), "`float`, *optional*")
|
||||
|
||||
# Operations are not replaced, but put in backtiks.
|
||||
desc_with_default = "`float`, *optional*, defaults to 1/255"
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default, 1 / 255), "`float`, *optional*, defaults to `1/255`"
|
||||
)
|
||||
desc_with_default = "`float`, *optional*, defaults to `1/255`"
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_default, 1 / 255), "`float`, *optional*, defaults to `1/255`"
|
||||
)
|
||||
|
||||
desc_with_optional = "`float`, *optional*"
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_optional, 2.0), "`float`, *optional*, defaults to 2.0"
|
||||
)
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_optional, 1.0), "`float`, *optional*, defaults to 1.0"
|
||||
)
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_optional, None), "`float`, *optional*")
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_optional, inspect._empty), "`float`")
|
||||
|
||||
desc_with_no_optional = "`float`"
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_no_optional, 2.0), "`float`, *optional*, defaults to 2.0"
|
||||
)
|
||||
self.assertEqual(
|
||||
replace_default_in_arg_description(desc_with_no_optional, 1.0), "`float`, *optional*, defaults to 1.0"
|
||||
)
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_no_optional, None), "`float`, *optional*")
|
||||
self.assertEqual(replace_default_in_arg_description(desc_with_no_optional, inspect._empty), "`float`")
|
||||
|
||||
def test_get_default_description(self):
|
||||
# Fake function to have arguments to test.
|
||||
def _fake_function(a, b: int, c=1, d: float = 2.0, e: str = "blob"):
|
||||
pass
|
||||
|
||||
params = inspect.signature(_fake_function).parameters
|
||||
assert get_default_description(params["a"]) == "`<fill_type>`"
|
||||
assert get_default_description(params["b"]) == "`int`"
|
||||
assert get_default_description(params["c"]) == "`<fill_type>`, *optional*, defaults to 1"
|
||||
assert get_default_description(params["d"]) == "`float`, *optional*, defaults to 2.0"
|
||||
assert get_default_description(params["e"]) == '`str`, *optional*, defaults to `"blob"`'
|
||||
|
||||
|
||||
class TestGetAutoDocstringNames(unittest.TestCase):
|
||||
"""Tests for _get_auto_docstring_names and has_auto_docstring_decorator."""
|
||||
|
||||
def setUp(self):
|
||||
self.cache = {}
|
||||
|
||||
def _write_temp(self, source):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||
f.write(source)
|
||||
self.addCleanup(os.unlink, f.name)
|
||||
return f.name
|
||||
|
||||
def test_detects_simple_decorator(self):
|
||||
"""Test that a class decorated with @auto_docstring is detected."""
|
||||
path = self._write_temp(
|
||||
textwrap.dedent("""\
|
||||
from transformers import auto_docstring
|
||||
|
||||
@auto_docstring
|
||||
class Foo:
|
||||
pass
|
||||
""")
|
||||
)
|
||||
names = _get_auto_docstring_names(path, cache=self.cache)
|
||||
self.assertEqual(names, {"Foo"})
|
||||
|
||||
def test_detects_decorator_with_call(self):
|
||||
"""Test that a class decorated with @auto_docstring(args) (called form) is detected."""
|
||||
path = self._write_temp(
|
||||
textwrap.dedent("""\
|
||||
@auto_docstring(custom_args='x')
|
||||
class Bar:
|
||||
pass
|
||||
""")
|
||||
)
|
||||
names = _get_auto_docstring_names(path, cache=self.cache)
|
||||
self.assertEqual(names, {"Bar"})
|
||||
|
||||
def test_ignores_other_decorators(self):
|
||||
"""Test that classes with non-auto_docstring decorators are not detected."""
|
||||
path = self._write_temp(
|
||||
textwrap.dedent("""\
|
||||
@dataclass
|
||||
class Baz:
|
||||
pass
|
||||
""")
|
||||
)
|
||||
names = _get_auto_docstring_names(path, cache=self.cache)
|
||||
self.assertEqual(names, set())
|
||||
|
||||
def test_multiple_classes(self):
|
||||
"""Test that only decorated classes and functions are returned when multiple definitions exist."""
|
||||
path = self._write_temp(
|
||||
textwrap.dedent("""\
|
||||
@auto_docstring
|
||||
class A:
|
||||
pass
|
||||
|
||||
class B:
|
||||
pass
|
||||
|
||||
@auto_docstring()
|
||||
def func_c():
|
||||
pass
|
||||
""")
|
||||
)
|
||||
names = _get_auto_docstring_names(path, cache=self.cache)
|
||||
self.assertEqual(names, {"A", "func_c"})
|
||||
|
||||
def test_caching(self):
|
||||
"""Test that repeated calls for the same file return the cached (identical) result object."""
|
||||
path = self._write_temp(
|
||||
textwrap.dedent("""\
|
||||
@auto_docstring
|
||||
class X:
|
||||
pass
|
||||
""")
|
||||
)
|
||||
result1 = _get_auto_docstring_names(path, cache=self.cache)
|
||||
result2 = _get_auto_docstring_names(path, cache=self.cache)
|
||||
self.assertIs(result1, result2)
|
||||
|
||||
def test_syntax_error_returns_empty(self):
|
||||
"""Test that a file with a syntax error returns an empty set instead of raising."""
|
||||
path = self._write_temp("def broken(\n")
|
||||
names = _get_auto_docstring_names(path, cache=self.cache)
|
||||
self.assertEqual(names, set())
|
||||
|
||||
def test_has_auto_docstring_decorator_uses_cache(self):
|
||||
"""Test that has_auto_docstring_decorator looks up names from the pre-populated cache."""
|
||||
from unittest.mock import patch
|
||||
|
||||
path = self._write_temp(
|
||||
textwrap.dedent("""\
|
||||
@auto_docstring
|
||||
class Cached:
|
||||
pass
|
||||
""")
|
||||
)
|
||||
self.cache[path] = {"Cached"}
|
||||
|
||||
# Create classes whose __name__ matches/doesn't match the cache
|
||||
Cached = type("Cached", (), {})
|
||||
Other = type("Other", (), {})
|
||||
|
||||
with patch.object(inspect, "getfile", return_value=path):
|
||||
self.assertTrue(has_auto_docstring_decorator(Cached, cache=self.cache))
|
||||
self.assertFalse(has_auto_docstring_decorator(Other, cache=self.cache))
|
||||
|
||||
|
||||
class TestBuildAstIndexes(unittest.TestCase):
|
||||
"""Tests for _build_ast_indexes with pre-parsed tree."""
|
||||
|
||||
def test_finds_decorated_items(self):
|
||||
"""Test that _build_ast_indexes finds a decorated class and extracts its __init__ args."""
|
||||
source = textwrap.dedent("""\
|
||||
@auto_docstring
|
||||
class MyModel:
|
||||
def __init__(self, hidden_size=768):
|
||||
self.hidden_size = hidden_size
|
||||
""")
|
||||
items = _build_ast_indexes(source)
|
||||
self.assertEqual(len(items), 1)
|
||||
self.assertEqual(items[0].name, "MyModel")
|
||||
self.assertEqual(items[0].kind, "class")
|
||||
self.assertIn("hidden_size", items[0].args)
|
||||
|
||||
def test_shared_tree(self):
|
||||
"""Test that passing a pre-parsed AST tree produces the same results as letting the function parse internally."""
|
||||
source = textwrap.dedent("""\
|
||||
@auto_docstring
|
||||
class A:
|
||||
pass
|
||||
""")
|
||||
tree = ast.parse(source)
|
||||
items_with_tree = _build_ast_indexes(source, tree=tree)
|
||||
items_without = _build_ast_indexes(source)
|
||||
self.assertEqual(len(items_with_tree), len(items_without))
|
||||
self.assertEqual(items_with_tree[0].name, items_without[0].name)
|
||||
|
||||
def test_no_decorated_items(self):
|
||||
"""Test that a class without the auto_docstring decorator is not indexed."""
|
||||
source = textwrap.dedent("""\
|
||||
class Plain:
|
||||
pass
|
||||
""")
|
||||
items = _build_ast_indexes(source)
|
||||
self.assertEqual(items, [])
|
||||
|
||||
def test_function_decorated(self):
|
||||
"""Test that a decorated function is indexed with its arguments."""
|
||||
source = textwrap.dedent("""\
|
||||
@auto_docstring
|
||||
def my_func(x, y=10):
|
||||
pass
|
||||
""")
|
||||
items = _build_ast_indexes(source)
|
||||
self.assertEqual(len(items), 1)
|
||||
self.assertEqual(items[0].name, "my_func")
|
||||
self.assertEqual(items[0].kind, "function")
|
||||
self.assertIn("x", items[0].args)
|
||||
self.assertIn("y", items[0].args)
|
||||
|
||||
def test_custom_args_from_variable(self):
|
||||
"""Test that custom_args passed as a module-level variable are resolved to their string value."""
|
||||
source = textwrap.dedent("""\
|
||||
MY_ARGS = "custom param docs"
|
||||
|
||||
@auto_docstring(custom_args=MY_ARGS)
|
||||
class WithCustom:
|
||||
def __init__(self):
|
||||
pass
|
||||
""")
|
||||
items = _build_ast_indexes(source)
|
||||
self.assertEqual(len(items), 1)
|
||||
self.assertEqual(items[0].custom_args_text, "custom param docs")
|
||||
|
||||
|
||||
class TestFindTypedDictClasses(unittest.TestCase):
|
||||
"""Tests for _find_typed_dict_classes with pre-parsed tree."""
|
||||
|
||||
def test_finds_typed_dict(self):
|
||||
"""Test that a TypedDict subclass is found and its public fields are extracted."""
|
||||
source = textwrap.dedent("""\
|
||||
from typing import TypedDict
|
||||
|
||||
class MyKwargs(TypedDict):
|
||||
field_a: str
|
||||
field_b: int
|
||||
""")
|
||||
result = _find_typed_dict_classes(source)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["name"], "MyKwargs")
|
||||
self.assertIn("field_a", result[0]["all_fields"])
|
||||
self.assertIn("field_b", result[0]["all_fields"])
|
||||
|
||||
def test_shared_tree(self):
|
||||
"""Test that passing a pre-parsed AST tree produces the same results as internal parsing."""
|
||||
source = textwrap.dedent("""\
|
||||
class MyKwargs(TypedDict):
|
||||
x: int
|
||||
""")
|
||||
tree = ast.parse(source)
|
||||
r1 = _find_typed_dict_classes(source, tree=tree)
|
||||
r2 = _find_typed_dict_classes(source)
|
||||
self.assertEqual(len(r1), len(r2))
|
||||
self.assertEqual(r1[0]["name"], r2[0]["name"])
|
||||
|
||||
def test_skips_standard_kwargs(self):
|
||||
"""Test that well-known kwargs TypedDicts (e.g. TextKwargs) are excluded from results."""
|
||||
source = textwrap.dedent("""\
|
||||
class TextKwargs(TypedDict):
|
||||
field: str
|
||||
""")
|
||||
result = _find_typed_dict_classes(source)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_no_typed_dicts(self):
|
||||
"""Test that source with no TypedDict subclasses returns an empty list."""
|
||||
source = textwrap.dedent("""\
|
||||
class Regular:
|
||||
pass
|
||||
""")
|
||||
result = _find_typed_dict_classes(source)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
def test_skips_private_fields(self):
|
||||
"""Test that fields starting with an underscore are excluded from the extracted TypedDict fields."""
|
||||
source = textwrap.dedent("""\
|
||||
class MyKwargs(TypedDict):
|
||||
public: int
|
||||
_private: str
|
||||
""")
|
||||
result = _find_typed_dict_classes(source)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertIn("public", result[0]["all_fields"])
|
||||
self.assertNotIn("_private", result[0]["all_fields"])
|
||||
74
tests/repo_utils/test_check_modular_conversion.py
Normal file
74
tests/repo_utils/test_check_modular_conversion.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
utils_path = os.path.join(git_repo_path, "utils")
|
||||
if utils_path not in sys.path:
|
||||
sys.path.append(utils_path)
|
||||
|
||||
import check_modular_conversion # noqa: E402
|
||||
|
||||
|
||||
class ConverterChangedInDiffTest(unittest.TestCase):
|
||||
"""Regression guard for PR #45492: changes to the converter alone must force a full check."""
|
||||
|
||||
def _patch_modified(self, files):
|
||||
return patch.object(check_modular_conversion, "_get_modified_files", return_value=files)
|
||||
|
||||
def test_returns_true_when_modular_model_converter_changed(self):
|
||||
with self._patch_modified(
|
||||
[
|
||||
"utils/modular_model_converter.py",
|
||||
"src/transformers/models/llava_onevision/modular_llava_onevision.py",
|
||||
]
|
||||
):
|
||||
self.assertTrue(check_modular_conversion.converter_changed_in_diff())
|
||||
|
||||
def test_returns_true_when_create_dependency_mapping_changed(self):
|
||||
with self._patch_modified(["utils/create_dependency_mapping.py"]):
|
||||
self.assertTrue(check_modular_conversion.converter_changed_in_diff())
|
||||
|
||||
def test_returns_false_for_model_only_diff(self):
|
||||
with self._patch_modified(
|
||||
[
|
||||
"src/transformers/models/llama/modular_llama.py",
|
||||
"src/transformers/models/llama/modeling_llama.py",
|
||||
]
|
||||
):
|
||||
self.assertFalse(check_modular_conversion.converter_changed_in_diff())
|
||||
|
||||
def test_returns_false_for_unrelated_utils_change(self):
|
||||
with self._patch_modified(["utils/check_modular_conversion.py", "utils/check_copies.py"]):
|
||||
self.assertFalse(check_modular_conversion.converter_changed_in_diff())
|
||||
|
||||
def test_converter_files_set_includes_expected_entries(self):
|
||||
# Keep the allow-list grounded: if either file is renamed/removed, this test fails loudly
|
||||
# so the detection logic is updated alongside the rename.
|
||||
self.assertIn("utils/modular_model_converter.py", check_modular_conversion.CONVERTER_FILES)
|
||||
self.assertIn("utils/create_dependency_mapping.py", check_modular_conversion.CONVERTER_FILES)
|
||||
for rel_path in check_modular_conversion.CONVERTER_FILES:
|
||||
self.assertTrue(
|
||||
os.path.exists(os.path.join(git_repo_path, rel_path)),
|
||||
f"{rel_path} listed in CONVERTER_FILES but does not exist on disk",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
151
tests/repo_utils/test_check_repo.py
Normal file
151
tests/repo_utils/test_check_repo.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
utils_path = os.path.join(git_repo_path, "utils")
|
||||
if utils_path not in sys.path:
|
||||
sys.path.append(utils_path)
|
||||
|
||||
import check_repo # noqa: E402
|
||||
|
||||
|
||||
class RecordingNamespace:
|
||||
"""Record directory listings and attribute access for cache tests."""
|
||||
|
||||
def __init__(self, mapping):
|
||||
self._mapping = mapping
|
||||
self.dir_calls = 0
|
||||
self.getattr_calls = []
|
||||
|
||||
def __dir__(self):
|
||||
self.dir_calls += 1
|
||||
return list(self._mapping.keys()) + ["__doc__"]
|
||||
|
||||
def __getattr__(self, name):
|
||||
self.getattr_calls.append(name)
|
||||
try:
|
||||
return self._mapping[name]
|
||||
except KeyError as error:
|
||||
raise AttributeError(name) from error
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_transformers_path(path: Path):
|
||||
"""Temporarily point `check_repo` at a temporary transformers source tree."""
|
||||
old_path = check_repo.PATH_TO_TRANSFORMERS
|
||||
check_repo.PATH_TO_TRANSFORMERS = str(path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
check_repo.PATH_TO_TRANSFORMERS = old_path
|
||||
|
||||
|
||||
class CheckRepoTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Reset the `get_model_modules` cache before each test."""
|
||||
check_repo.get_model_modules.cache_clear()
|
||||
self.addCleanup(check_repo.get_model_modules.cache_clear)
|
||||
|
||||
def _write_modeling_file(self, root: Path, model_name: str, content: str) -> None:
|
||||
"""Create a temporary `modeling_*.py` file used by `check_models_have_kwargs`."""
|
||||
model_dir = root / "src" / "transformers" / "models" / model_name
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
(model_dir / f"modeling_{model_name}.py").write_text(content, encoding="utf-8")
|
||||
|
||||
def test_get_model_modules_is_cached(self):
|
||||
"""Repeated calls should reuse the cached module list instead of traversing models twice."""
|
||||
alpha_modeling = object()
|
||||
alpha_module = RecordingNamespace({"modeling_alpha": alpha_modeling, "configuration_alpha": object()})
|
||||
fake_models = RecordingNamespace({"alpha": alpha_module, "deprecated_alpha": object()})
|
||||
fake_transformers = SimpleNamespace(models=fake_models)
|
||||
|
||||
with patch.object(check_repo, "transformers", fake_transformers):
|
||||
first = check_repo.get_model_modules()
|
||||
second = check_repo.get_model_modules()
|
||||
|
||||
self.assertIs(first, second)
|
||||
self.assertEqual(first, [alpha_modeling])
|
||||
self.assertEqual(fake_models.dir_calls, 1)
|
||||
self.assertEqual(fake_models.getattr_calls, ["alpha"])
|
||||
self.assertEqual(alpha_module.dir_calls, 1)
|
||||
self.assertEqual(alpha_module.getattr_calls, ["modeling_alpha"])
|
||||
|
||||
def test_check_models_have_kwargs_ignores_nested_classes(self):
|
||||
"""Nested helper classes should not trigger missing-`**kwargs` failures."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
root = Path(tmpdir)
|
||||
self._write_modeling_file(
|
||||
root,
|
||||
"foo",
|
||||
"""
|
||||
class PreTrainedModel:
|
||||
pass
|
||||
|
||||
|
||||
class FooModel(PreTrainedModel):
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HelperContainer:
|
||||
class NestedModel(PreTrainedModel):
|
||||
def forward(self, hidden_states):
|
||||
return hidden_states
|
||||
""".strip()
|
||||
+ "\n",
|
||||
)
|
||||
|
||||
with patch_transformers_path(root / "src" / "transformers"):
|
||||
check_repo.check_models_have_kwargs()
|
||||
|
||||
def test_check_models_have_kwargs_still_checks_top_level_models(self):
|
||||
"""Top-level model classes should still fail when `forward()` omits `**kwargs`."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
root = Path(tmpdir)
|
||||
self._write_modeling_file(
|
||||
root,
|
||||
"foo",
|
||||
"""
|
||||
class PreTrainedModel:
|
||||
pass
|
||||
|
||||
|
||||
class FooModel(PreTrainedModel):
|
||||
def forward(self, hidden_states):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def make_nested_model():
|
||||
class NestedModel(PreTrainedModel):
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
return hidden_states
|
||||
|
||||
return NestedModel
|
||||
""".strip()
|
||||
+ "\n",
|
||||
)
|
||||
|
||||
with patch_transformers_path(root / "src" / "transformers"):
|
||||
with self.assertRaisesRegex(Exception, "FooModel"):
|
||||
check_repo.check_models_have_kwargs()
|
||||
198
tests/repo_utils/test_checkers.py
Normal file
198
tests/repo_utils/test_checkers.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
utils_path = os.path.join(git_repo_path, "utils")
|
||||
if utils_path not in sys.path:
|
||||
sys.path.append(utils_path)
|
||||
|
||||
import checkers # noqa: E402
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_checkers_paths(repo_root: Path):
|
||||
cache_path = repo_root / "utils" / ".checkers_cache.json"
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(patch.object(checkers, "REPO_ROOT", repo_root))
|
||||
stack.enter_context(patch.object(checkers, "CACHE_PATH", cache_path))
|
||||
stack.enter_context(patch.object(checkers, "CHECKERS", {"demo": ("Demo checker", "fake_checker.py", [], [])}))
|
||||
stack.enter_context(patch.object(checkers, "CHECKER_CACHE_GLOBS", {"demo": ["tracked/**/*.txt"]}))
|
||||
yield cache_path
|
||||
|
||||
|
||||
class CheckersCacheTest(unittest.TestCase):
|
||||
class _TTYStringIO(io.StringIO):
|
||||
def isatty(self) -> bool:
|
||||
return True
|
||||
|
||||
def _create_fake_repo(self, tmpdir: str) -> Path:
|
||||
"""Create a minimal repo layout for exercising checker cache inputs."""
|
||||
repo_root = Path(tmpdir)
|
||||
(repo_root / "tracked").mkdir()
|
||||
(repo_root / "tracked" / "input.txt").write_text("tracked\n", encoding="utf-8")
|
||||
(repo_root / "utils").mkdir()
|
||||
(repo_root / "utils" / "fake_checker.py").write_text("# fake checker\n", encoding="utf-8")
|
||||
return repo_root
|
||||
|
||||
def _run_main(self, *args: str, stdout=None) -> tuple[int | None, str]:
|
||||
"""Run `checkers.main()` with patched argv/stdout and return the exit code and captured output."""
|
||||
stdout = io.StringIO() if stdout is None else stdout
|
||||
with (
|
||||
patch.object(sys, "argv", ["checkers.py", *args]),
|
||||
patch.object(sys, "stdout", new=stdout),
|
||||
):
|
||||
exit_code = None
|
||||
try:
|
||||
checkers.main()
|
||||
except SystemExit as e:
|
||||
exit_code = e.code
|
||||
return exit_code, stdout.getvalue()
|
||||
|
||||
def test_checker_cache_detects_checker_script_changes(self):
|
||||
"""Cache entries should become stale when the checker implementation file changes."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_root = self._create_fake_repo(tmpdir)
|
||||
with patch_checkers_paths(repo_root) as cache_path:
|
||||
cache = checkers.CheckerCache(path=cache_path)
|
||||
self.assertFalse(cache.is_current("demo"))
|
||||
|
||||
cache.update("demo")
|
||||
self.assertTrue(cache.is_current("demo"))
|
||||
|
||||
(repo_root / "utils" / "fake_checker.py").write_text("# fake checker changed\n", encoding="utf-8")
|
||||
self.assertFalse(cache.is_current("demo"))
|
||||
|
||||
def test_main_skips_cached_runs(self):
|
||||
"""Main should reuse cached results for repeated runs."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_root = self._create_fake_repo(tmpdir)
|
||||
with (
|
||||
patch_checkers_paths(repo_root),
|
||||
patch.object(
|
||||
checkers,
|
||||
"run_checker",
|
||||
return_value=(0, "first run"),
|
||||
) as run_checker,
|
||||
):
|
||||
exit_code, _ = self._run_main("demo")
|
||||
self.assertIsNone(exit_code)
|
||||
self.assertEqual(run_checker.call_count, 1)
|
||||
|
||||
exit_code, output = self._run_main("demo")
|
||||
self.assertIsNone(exit_code)
|
||||
self.assertEqual(run_checker.call_count, 1)
|
||||
self.assertIn("(cached)", output)
|
||||
|
||||
def test_main_reruns_with_no_cache(self):
|
||||
"""Main should rerun when `--no-cache` is passed."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_root = self._create_fake_repo(tmpdir)
|
||||
with (
|
||||
patch_checkers_paths(repo_root),
|
||||
patch.object(
|
||||
checkers,
|
||||
"run_checker",
|
||||
side_effect=[(0, "first run"), (0, "forced rerun")],
|
||||
) as run_checker,
|
||||
):
|
||||
exit_code, _ = self._run_main("demo")
|
||||
self.assertIsNone(exit_code)
|
||||
self.assertEqual(run_checker.call_count, 1)
|
||||
|
||||
exit_code, _ = self._run_main("demo", "--no-cache")
|
||||
self.assertIsNone(exit_code)
|
||||
self.assertEqual(run_checker.call_count, 2)
|
||||
|
||||
def test_main_prints_full_output_on_failure_without_tty(self):
|
||||
"""Local non-TTY failures should print the full checker output instead of a cropped tail."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_root = self._create_fake_repo(tmpdir)
|
||||
output = "\n".join(f"line {i}" for i in range(12)) + "\n"
|
||||
with (
|
||||
patch.dict(os.environ, {"GITHUB_ACTIONS": "false", "CIRCLECI": "false"}),
|
||||
patch_checkers_paths(repo_root),
|
||||
patch.object(checkers, "run_checker", return_value=(1, output)),
|
||||
):
|
||||
exit_code, stdout = self._run_main("demo", "--keep-going")
|
||||
|
||||
self.assertEqual(exit_code, 1)
|
||||
self.assertIn("line 0", stdout)
|
||||
self.assertIn("line 11", stdout)
|
||||
|
||||
def test_main_prints_full_output_on_failure_with_tty(self):
|
||||
"""TTY failures should print the full checker output without reprinting the cropped window tail."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_root = self._create_fake_repo(tmpdir)
|
||||
output = "\n".join(f"line {i}" for i in range(12)) + "\n"
|
||||
|
||||
class FakeSlidingWindow:
|
||||
def __init__(self, label, max_lines=10):
|
||||
self.label = label
|
||||
self.max_lines = max_lines
|
||||
|
||||
def add_line(self, line):
|
||||
pass
|
||||
|
||||
def finish(self, success, elapsed=None, show_lines=True):
|
||||
print(f"window finished: {self.label} ({success}, {show_lines})")
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"GITHUB_ACTIONS": "false", "CIRCLECI": "false"}),
|
||||
patch_checkers_paths(repo_root),
|
||||
patch.object(checkers, "run_checker", return_value=(1, output)),
|
||||
patch.object(checkers, "SlidingWindow", FakeSlidingWindow),
|
||||
):
|
||||
exit_code, stdout = self._run_main("demo", "--keep-going", stdout=self._TTYStringIO())
|
||||
|
||||
self.assertEqual(exit_code, 1)
|
||||
self.assertIn("window finished: Demo checker (False, False)", stdout)
|
||||
self.assertIn("line 0", stdout)
|
||||
self.assertIn("line 11", stdout)
|
||||
|
||||
def test_main_prints_failure_suffix_in_ci(self):
|
||||
"""CI failures should still print any extra captured output that was not streamed live."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
repo_root = self._create_fake_repo(tmpdir)
|
||||
streamed_output = "line 0\nline 1\n"
|
||||
failure_suffix = "summary line\n"
|
||||
|
||||
def run_checker(name, fix=False, line_callback=None):
|
||||
self.assertEqual(name, "demo")
|
||||
self.assertFalse(fix)
|
||||
self.assertIsNotNone(line_callback)
|
||||
for line in streamed_output.splitlines(keepends=True):
|
||||
line_callback(line)
|
||||
return 1, streamed_output + failure_suffix
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"GITHUB_ACTIONS": "true", "CIRCLECI": "false"}),
|
||||
patch_checkers_paths(repo_root),
|
||||
patch.object(checkers, "run_checker", side_effect=run_checker),
|
||||
):
|
||||
exit_code, stdout = self._run_main("demo", "--keep-going")
|
||||
|
||||
self.assertEqual(exit_code, 1)
|
||||
self.assertIn("line 0", stdout)
|
||||
self.assertIn("line 1", stdout)
|
||||
self.assertIn("summary line", stdout)
|
||||
108
tests/repo_utils/test_get_test_info.py
Normal file
108
tests/repo_utils/test_get_test_info.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(git_repo_path, "utils"))
|
||||
|
||||
import get_test_info # noqa: E402
|
||||
from get_test_info import ( # noqa: E402
|
||||
get_model_to_test_mapping,
|
||||
get_model_to_tester_mapping,
|
||||
get_test_to_tester_mapping,
|
||||
)
|
||||
|
||||
|
||||
BERT_TEST_FILE = os.path.join("tests", "models", "bert", "test_modeling_bert.py")
|
||||
BLIP_TEST_FILE = os.path.join("tests", "models", "blip", "test_modeling_blip.py")
|
||||
|
||||
|
||||
class GetTestInfoTester(unittest.TestCase):
|
||||
def test_get_test_to_tester_mapping(self):
|
||||
bert_test_tester_mapping = get_test_to_tester_mapping(BERT_TEST_FILE)
|
||||
blip_test_tester_mapping = get_test_to_tester_mapping(BLIP_TEST_FILE)
|
||||
|
||||
EXPECTED_BERT_MAPPING = {"BertModelTest": "BertModelTester"}
|
||||
|
||||
EXPECTED_BLIP_MAPPING = {
|
||||
"BlipModelTest": "BlipModelTester",
|
||||
"BlipTextImageModelTest": "BlipTextImageModelsModelTester",
|
||||
"BlipTextModelTest": "BlipTextModelTester",
|
||||
"BlipTextRetrievalModelTest": "BlipTextRetrievalModelTester",
|
||||
"BlipVQAModelTest": "BlipVQAModelTester",
|
||||
"BlipVisionModelTest": "BlipVisionModelTester",
|
||||
}
|
||||
|
||||
self.assertEqual(get_test_info.to_json(bert_test_tester_mapping), EXPECTED_BERT_MAPPING)
|
||||
self.assertEqual(get_test_info.to_json(blip_test_tester_mapping), EXPECTED_BLIP_MAPPING)
|
||||
|
||||
def test_get_model_to_test_mapping(self):
|
||||
bert_model_test_mapping = get_model_to_test_mapping(BERT_TEST_FILE)
|
||||
blip_model_test_mapping = get_model_to_test_mapping(BLIP_TEST_FILE)
|
||||
|
||||
EXPECTED_BERT_MAPPING = {
|
||||
"BertForMaskedLM": ["BertModelTest"],
|
||||
"BertForMultipleChoice": ["BertModelTest"],
|
||||
"BertForNextSentencePrediction": ["BertModelTest"],
|
||||
"BertForPreTraining": ["BertModelTest"],
|
||||
"BertForQuestionAnswering": ["BertModelTest"],
|
||||
"BertForSequenceClassification": ["BertModelTest"],
|
||||
"BertForTokenClassification": ["BertModelTest"],
|
||||
"BertLMHeadModel": ["BertModelTest"],
|
||||
"BertModel": ["BertModelTest"],
|
||||
}
|
||||
|
||||
EXPECTED_BLIP_MAPPING = {
|
||||
"BlipForConditionalGeneration": ["BlipTextImageModelTest"],
|
||||
"BlipForImageTextRetrieval": ["BlipTextRetrievalModelTest"],
|
||||
"BlipForQuestionAnswering": ["BlipVQAModelTest"],
|
||||
"BlipModel": ["BlipModelTest"],
|
||||
"BlipTextModel": ["BlipTextModelTest"],
|
||||
"BlipVisionModel": ["BlipVisionModelTest"],
|
||||
}
|
||||
|
||||
self.assertEqual(get_test_info.to_json(bert_model_test_mapping), EXPECTED_BERT_MAPPING)
|
||||
self.assertEqual(get_test_info.to_json(blip_model_test_mapping), EXPECTED_BLIP_MAPPING)
|
||||
|
||||
def test_get_model_to_tester_mapping(self):
|
||||
bert_model_tester_mapping = get_model_to_tester_mapping(BERT_TEST_FILE)
|
||||
blip_model_tester_mapping = get_model_to_tester_mapping(BLIP_TEST_FILE)
|
||||
|
||||
EXPECTED_BERT_MAPPING = {
|
||||
"BertForMaskedLM": ["BertModelTester"],
|
||||
"BertForMultipleChoice": ["BertModelTester"],
|
||||
"BertForNextSentencePrediction": ["BertModelTester"],
|
||||
"BertForPreTraining": ["BertModelTester"],
|
||||
"BertForQuestionAnswering": ["BertModelTester"],
|
||||
"BertForSequenceClassification": ["BertModelTester"],
|
||||
"BertForTokenClassification": ["BertModelTester"],
|
||||
"BertLMHeadModel": ["BertModelTester"],
|
||||
"BertModel": ["BertModelTester"],
|
||||
}
|
||||
|
||||
EXPECTED_BLIP_MAPPING = {
|
||||
"BlipForConditionalGeneration": ["BlipTextImageModelsModelTester"],
|
||||
"BlipForImageTextRetrieval": ["BlipTextRetrievalModelTester"],
|
||||
"BlipForQuestionAnswering": ["BlipVQAModelTester"],
|
||||
"BlipModel": ["BlipModelTester"],
|
||||
"BlipTextModel": ["BlipTextModelTester"],
|
||||
"BlipVisionModel": ["BlipVisionModelTester"],
|
||||
}
|
||||
|
||||
self.assertEqual(get_test_info.to_json(bert_model_tester_mapping), EXPECTED_BERT_MAPPING)
|
||||
self.assertEqual(get_test_info.to_json(blip_model_tester_mapping), EXPECTED_BLIP_MAPPING)
|
||||
806
tests/repo_utils/test_tests_fetcher.py
Normal file
806
tests/repo_utils/test_tests_fetcher.py
Normal file
@@ -0,0 +1,806 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from git import Repo
|
||||
|
||||
from transformers.testing_utils import CaptureStdout
|
||||
|
||||
|
||||
REPO_PATH = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
sys.path.append(os.path.join(REPO_PATH, "utils"))
|
||||
|
||||
import tests_fetcher # noqa: E402
|
||||
from tests_fetcher import ( # noqa: E402
|
||||
checkout_commit,
|
||||
clean_code,
|
||||
create_reverse_dependency_map,
|
||||
create_reverse_dependency_tree,
|
||||
create_test_list_from_filter,
|
||||
diff_is_docstring_only,
|
||||
extract_imports,
|
||||
get_all_tests,
|
||||
get_diff,
|
||||
get_module_dependencies,
|
||||
get_repo_utils_tests,
|
||||
get_tree_starting_at,
|
||||
infer_tests_to_run,
|
||||
init_test_examples_dependencies,
|
||||
parse_commit_message,
|
||||
print_tree_deps_of,
|
||||
should_run_repo_utils_tests,
|
||||
)
|
||||
|
||||
|
||||
BERT_MODELING_FILE = "src/transformers/models/bert/modeling_bert.py"
|
||||
BERT_MODEL_FILE = """from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import is_torch_available
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
class BertModel:
|
||||
'''
|
||||
This is the docstring.
|
||||
'''
|
||||
This is the code
|
||||
"""
|
||||
|
||||
BERT_MODEL_FILE_NEW_DOCSTRING = """from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import is_torch_available
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
class BertModel:
|
||||
'''
|
||||
This is the docstring. It has been updated.
|
||||
'''
|
||||
This is the code
|
||||
"""
|
||||
|
||||
BERT_MODEL_FILE_NEW_CODE = """from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import is_torch_available
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
class BertModel:
|
||||
'''
|
||||
This is the docstring.
|
||||
'''
|
||||
This is the code. It has been updated
|
||||
"""
|
||||
|
||||
|
||||
def create_tmp_repo(tmp_dir, models=None):
|
||||
"""
|
||||
Creates a repository in a temporary directory mimicking the structure of Transformers. Uses the list of models
|
||||
provided (which defaults to just `["bert"]`).
|
||||
"""
|
||||
tmp_dir = Path(tmp_dir)
|
||||
if tmp_dir.exists():
|
||||
shutil.rmtree(tmp_dir)
|
||||
tmp_dir.mkdir(exist_ok=True)
|
||||
repo = Repo.init(tmp_dir)
|
||||
|
||||
if models is None:
|
||||
models = ["bert"]
|
||||
class_names = [model[0].upper() + model[1:] for model in models]
|
||||
|
||||
transformers_dir = tmp_dir / "src" / "transformers"
|
||||
transformers_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(transformers_dir / "__init__.py", "w") as f:
|
||||
init_lines = ["from .utils import cached_file, is_torch_available"]
|
||||
init_lines.extend(
|
||||
[f"from .models.{model} import {cls}Config, {cls}Model" for model, cls in zip(models, class_names)]
|
||||
)
|
||||
f.write("\n".join(init_lines) + "\n")
|
||||
with open(transformers_dir / "configuration_utils.py", "w") as f:
|
||||
f.write("from .utils import cached_file\n\ncode")
|
||||
with open(transformers_dir / "modeling_utils.py", "w") as f:
|
||||
f.write("from .utils import cached_file\n\ncode")
|
||||
|
||||
utils_dir = tmp_dir / "src" / "transformers" / "utils"
|
||||
utils_dir.mkdir(exist_ok=True)
|
||||
with open(utils_dir / "__init__.py", "w") as f:
|
||||
f.write("from .hub import cached_file\nfrom .imports import is_torch_available\n")
|
||||
with open(utils_dir / "hub.py", "w") as f:
|
||||
f.write("import huggingface_hub\n\ncode")
|
||||
with open(utils_dir / "imports.py", "w") as f:
|
||||
f.write("code")
|
||||
|
||||
model_dir = tmp_dir / "src" / "transformers" / "models"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(model_dir / "__init__.py", "w") as f:
|
||||
f.write("\n".join([f"import {model}" for model in models]))
|
||||
|
||||
for model, cls in zip(models, class_names):
|
||||
model_dir = tmp_dir / "src" / "transformers" / "models" / model
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(model_dir / "__init__.py", "w") as f:
|
||||
f.write(f"from .configuration_{model} import {cls}Config\nfrom .modeling_{model} import {cls}Model\n")
|
||||
with open(model_dir / f"configuration_{model}.py", "w") as f:
|
||||
f.write("from ...configuration_utils import PreTrainedConfig\ncode")
|
||||
with open(model_dir / f"modeling_{model}.py", "w") as f:
|
||||
modeling_code = BERT_MODEL_FILE.replace("bert", model).replace("Bert", cls)
|
||||
f.write(modeling_code)
|
||||
|
||||
test_dir = tmp_dir / "tests"
|
||||
test_dir.mkdir(exist_ok=True)
|
||||
with open(test_dir / "test_modeling_common.py", "w") as f:
|
||||
f.write("from transformers.modeling_utils import PreTrainedModel\ncode")
|
||||
|
||||
for model, cls in zip(models, class_names):
|
||||
test_model_dir = test_dir / "models" / model
|
||||
test_model_dir.mkdir(parents=True, exist_ok=True)
|
||||
(test_model_dir / "__init__.py").touch()
|
||||
with open(test_model_dir / f"test_modeling_{model}.py", "w") as f:
|
||||
f.write(
|
||||
f"from transformers import {cls}Config, {cls}Model\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode"
|
||||
)
|
||||
|
||||
example_dir = tmp_dir / "examples"
|
||||
example_dir.mkdir(exist_ok=True)
|
||||
framework_dir = example_dir / "pytorch"
|
||||
framework_dir.mkdir(exist_ok=True)
|
||||
with open(framework_dir / "test_pytorch_examples.py", "w") as f:
|
||||
f.write("""test_args = "run_glue.py"\n""")
|
||||
glue_dir = framework_dir / "text-classification"
|
||||
glue_dir.mkdir(exist_ok=True)
|
||||
with open(glue_dir / "run_glue.py", "w") as f:
|
||||
f.write("from transformers import BertModel\n\ncode")
|
||||
|
||||
repo.index.add(["examples", "src", "tests"])
|
||||
repo.index.commit("Initial commit")
|
||||
if "main" not in repo.heads:
|
||||
repo.create_head("main")
|
||||
repo.head.reference = repo.refs.main
|
||||
if "master" in repo.heads:
|
||||
repo.delete_head("master")
|
||||
return repo
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_transformer_repo_path(new_folder):
|
||||
"""
|
||||
Temporarily patches the variables defines in `tests_fetcher` to use a different location for the repo.
|
||||
"""
|
||||
old_repo_path = tests_fetcher.PATH_TO_REPO
|
||||
tests_fetcher.PATH_TO_REPO = Path(new_folder).resolve()
|
||||
tests_fetcher.PATH_TO_EXAMPLES = tests_fetcher.PATH_TO_REPO / "examples"
|
||||
tests_fetcher.PATH_TO_TRANSFORMERS = tests_fetcher.PATH_TO_REPO / "src/transformers"
|
||||
tests_fetcher.PATH_TO_TESTS = tests_fetcher.PATH_TO_REPO / "tests"
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
tests_fetcher.PATH_TO_REPO = old_repo_path
|
||||
tests_fetcher.PATH_TO_EXAMPLES = tests_fetcher.PATH_TO_REPO / "examples"
|
||||
tests_fetcher.PATH_TO_TRANSFORMERS = tests_fetcher.PATH_TO_REPO / "src/transformers"
|
||||
tests_fetcher.PATH_TO_TESTS = tests_fetcher.PATH_TO_REPO / "tests"
|
||||
|
||||
|
||||
def commit_changes(filenames, contents, repo, commit_message="Commit"):
|
||||
"""
|
||||
Commit new `contents` to `filenames` inside a given `repo`.
|
||||
"""
|
||||
if not isinstance(filenames, list):
|
||||
filenames = [filenames]
|
||||
if not isinstance(contents, list):
|
||||
contents = [contents]
|
||||
|
||||
folder = Path(repo.working_dir)
|
||||
for filename, content in zip(filenames, contents):
|
||||
with open(folder / filename, "w") as f:
|
||||
f.write(content)
|
||||
repo.index.add(filenames)
|
||||
commit = repo.index.commit(commit_message)
|
||||
return commit.hexsha
|
||||
|
||||
|
||||
class TestFetcherTester(unittest.TestCase):
|
||||
def test_checkout_commit(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
repo = create_tmp_repo(tmp_folder)
|
||||
initial_sha = repo.head.commit.hexsha
|
||||
new_sha = commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
|
||||
|
||||
assert repo.head.commit.hexsha == new_sha
|
||||
with checkout_commit(repo, initial_sha):
|
||||
assert repo.head.commit.hexsha == initial_sha
|
||||
with open(tmp_folder / BERT_MODELING_FILE) as f:
|
||||
assert f.read() == BERT_MODEL_FILE
|
||||
|
||||
assert repo.head.commit.hexsha == new_sha
|
||||
with open(tmp_folder / BERT_MODELING_FILE) as f:
|
||||
assert f.read() == BERT_MODEL_FILE_NEW_DOCSTRING
|
||||
|
||||
def test_clean_code(self):
|
||||
# Clean code removes all strings in triple quotes
|
||||
assert clean_code('"""\nDocstring\n"""\ncode\n"""Long string"""\ncode\n') == "code\ncode"
|
||||
assert clean_code("'''\nDocstring\n'''\ncode\n'''Long string'''\ncode\n'''") == "code\ncode"
|
||||
|
||||
# Clean code removes all comments
|
||||
assert clean_code("code\n# Comment\ncode") == "code\ncode"
|
||||
assert clean_code("code # inline comment\ncode") == "code \ncode"
|
||||
|
||||
def test_get_all_tests(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert get_all_tests() == ["tests/models/bert", "tests/test_modeling_common.py"]
|
||||
|
||||
def test_get_all_tests_on_full_repo(self):
|
||||
all_tests = get_all_tests()
|
||||
assert "tests/models/albert" in all_tests
|
||||
assert "tests/models/bert" in all_tests
|
||||
assert "tests/repo_utils" in all_tests
|
||||
assert "tests/test_pipeline_mixin.py" in all_tests
|
||||
assert "tests/models" not in all_tests
|
||||
assert "tests/__pycache__" not in all_tests
|
||||
assert "tests/models/albert/test_modeling_albert.py" not in all_tests
|
||||
assert "tests/repo_utils/test_tests_fetcher.py" not in all_tests
|
||||
|
||||
def test_get_repo_utils_tests_on_full_repo(self):
|
||||
repo_utils_tests = get_repo_utils_tests()
|
||||
assert "tests/repo_utils/test_tests_fetcher.py" in repo_utils_tests
|
||||
|
||||
def test_should_run_repo_utils_tests(self):
|
||||
assert should_run_repo_utils_tests(["utils/check_modeling_structure.py"])
|
||||
assert not should_run_repo_utils_tests(["src/transformers/modeling_utils.py"])
|
||||
|
||||
def test_create_test_list_from_filter_routes_repo_utils_tests(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
create_test_list_from_filter(
|
||||
[
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"tests/repo_utils/test_tests_fetcher.py",
|
||||
],
|
||||
out_path=tmp_folder,
|
||||
)
|
||||
|
||||
with open(Path(tmp_folder) / "tests_repo_utils_test_list.txt", encoding="utf-8") as f:
|
||||
repo_utils_tests = f.read().splitlines()
|
||||
|
||||
assert repo_utils_tests == [
|
||||
"tests/repo_utils/test_tests_fetcher.py",
|
||||
]
|
||||
|
||||
def test_create_test_list_from_filter_does_not_create_hub_job(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
create_test_list_from_filter(["tests/models/bert/test_modeling_bert.py"], out_path=tmp_folder)
|
||||
|
||||
assert (Path(tmp_folder) / "tests_torch_test_list.txt").exists()
|
||||
assert not (Path(tmp_folder) / "tests_hub_test_list.txt").exists()
|
||||
|
||||
def test_infer_tests_to_run_adds_repo_utils_for_utils_changes(self):
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(patch.object(tests_fetcher, "commit_flags", {"test_all": False}, create=True))
|
||||
stack.enter_context(
|
||||
patch.object(
|
||||
tests_fetcher, "get_modified_python_files", return_value=["utils/check_modeling_structure.py"]
|
||||
)
|
||||
)
|
||||
stack.enter_context(patch.object(tests_fetcher, "create_reverse_dependency_map", return_value={}))
|
||||
stack.enter_context(
|
||||
patch.object(tests_fetcher, "get_impacted_files_from_tiny_model_summary", return_value=[])
|
||||
)
|
||||
mock_create_test_list = stack.enter_context(patch.object(tests_fetcher, "create_test_list_from_filter"))
|
||||
stack.enter_context(patch.object(tests_fetcher, "get_doctest_files", return_value=[]))
|
||||
infer_tests_to_run("unused.txt", diff_with_last_commit=True)
|
||||
|
||||
test_files_to_run = mock_create_test_list.call_args.args[0]
|
||||
assert "tests/repo_utils/test_tests_fetcher.py" in test_files_to_run
|
||||
|
||||
def test_diff_is_docstring_only(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
repo = create_tmp_repo(tmp_folder)
|
||||
|
||||
branching_point = repo.refs.main.commit
|
||||
bert_file = BERT_MODELING_FILE
|
||||
commit_changes(bert_file, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
|
||||
assert diff_is_docstring_only(repo, branching_point, bert_file)
|
||||
|
||||
commit_changes(bert_file, BERT_MODEL_FILE_NEW_CODE, repo)
|
||||
assert not diff_is_docstring_only(repo, branching_point, bert_file)
|
||||
|
||||
def test_get_diff_ignores_docstring_only_changes(self):
|
||||
"""Files whose diff is only in docstrings/comments should be excluded from get_diff results."""
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
repo = create_tmp_repo(tmp_folder)
|
||||
branching_commit = repo.head.commit
|
||||
|
||||
# Docstring-only change: should NOT appear in diff
|
||||
commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_DOCSTRING, repo)
|
||||
diff = get_diff(repo, repo.head.commit, [branching_commit])
|
||||
assert BERT_MODELING_FILE not in diff
|
||||
|
||||
# Real code change: should appear in diff
|
||||
commit_changes(BERT_MODELING_FILE, BERT_MODEL_FILE_NEW_CODE, repo)
|
||||
diff = get_diff(repo, repo.head.commit, [branching_commit])
|
||||
assert BERT_MODELING_FILE in diff
|
||||
|
||||
def test_extract_imports_relative(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
|
||||
expected_bert_imports = [
|
||||
("src/transformers/modeling_utils.py", ["PreTrainedModel"]),
|
||||
("src/transformers/utils/__init__.py", ["is_torch_available"]),
|
||||
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
|
||||
]
|
||||
expected_utils_imports = [
|
||||
("src/transformers/utils/hub.py", ["cached_file"]),
|
||||
("src/transformers/utils/imports.py", ["is_torch_available"]),
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
|
||||
assert extract_imports("src/transformers/utils/__init__.py") == expected_utils_imports
|
||||
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
|
||||
f.write(
|
||||
"from ...utils import cached_file, is_torch_available\nfrom .configuration_bert import BertConfig\n"
|
||||
)
|
||||
expected_bert_imports = [
|
||||
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
|
||||
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
|
||||
|
||||
# Test with multi-line imports
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
|
||||
f.write(
|
||||
"from ...utils import (\n cached_file,\n is_torch_available\n)\nfrom .configuration_bert import BertConfig\n"
|
||||
)
|
||||
expected_bert_imports = [
|
||||
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
|
||||
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
|
||||
|
||||
def test_extract_imports_absolute(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
|
||||
f.write(
|
||||
"from transformers.utils import cached_file, is_torch_available\nfrom transformers.models.bert.configuration_bert import BertConfig\n"
|
||||
)
|
||||
expected_bert_imports = [
|
||||
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
|
||||
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
|
||||
|
||||
# Test with multi-line imports
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
|
||||
f.write(
|
||||
"from transformers.utils import (\n cached_file,\n is_torch_available\n)\nfrom transformers.models.bert.configuration_bert import BertConfig\n"
|
||||
)
|
||||
expected_bert_imports = [
|
||||
("src/transformers/models/bert/configuration_bert.py", ["BertConfig"]),
|
||||
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
|
||||
|
||||
# Test with base imports
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "w") as f:
|
||||
f.write(
|
||||
"from transformers.utils import (\n cached_file,\n is_torch_available\n)\nfrom transformers import BertConfig\n"
|
||||
)
|
||||
expected_bert_imports = [
|
||||
("src/transformers/__init__.py", ["BertConfig"]),
|
||||
("src/transformers/utils/__init__.py", ["cached_file", "is_torch_available"]),
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert extract_imports(BERT_MODELING_FILE) == expected_bert_imports
|
||||
|
||||
def test_get_module_dependencies(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
|
||||
expected_bert_dependencies = [
|
||||
"src/transformers/modeling_utils.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/utils/imports.py",
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
|
||||
|
||||
expected_test_bert_dependencies = [
|
||||
"tests/test_modeling_common.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
]
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert (
|
||||
get_module_dependencies("tests/models/bert/test_modeling_bert.py")
|
||||
== expected_test_bert_dependencies
|
||||
)
|
||||
|
||||
# Test with a submodule
|
||||
(tmp_folder / "src/transformers/utils/logging.py").touch()
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "a") as f:
|
||||
f.write("from ...utils import logging\n")
|
||||
|
||||
expected_bert_dependencies = [
|
||||
"src/transformers/modeling_utils.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/utils/logging.py",
|
||||
"src/transformers/utils/imports.py",
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
|
||||
|
||||
# Test with an object non-imported in the init
|
||||
create_tmp_repo(tmp_folder)
|
||||
with open(tmp_folder / BERT_MODELING_FILE, "a") as f:
|
||||
f.write("from ...utils import CONSTANT\n")
|
||||
|
||||
expected_bert_dependencies = [
|
||||
"src/transformers/modeling_utils.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/utils/__init__.py",
|
||||
"src/transformers/utils/imports.py",
|
||||
]
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert get_module_dependencies(BERT_MODELING_FILE) == expected_bert_dependencies
|
||||
|
||||
# Test with an example
|
||||
create_tmp_repo(tmp_folder)
|
||||
|
||||
expected_example_dependencies = ["src/transformers/models/bert/modeling_bert.py"]
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
assert (
|
||||
get_module_dependencies("examples/pytorch/text-classification/run_glue.py")
|
||||
== expected_example_dependencies
|
||||
)
|
||||
|
||||
def test_create_reverse_dependency_tree(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
tree = create_reverse_dependency_tree()
|
||||
|
||||
init_edges = [
|
||||
"src/transformers/utils/hub.py",
|
||||
"src/transformers/utils/imports.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
]
|
||||
assert {f for f, g in tree if g == "src/transformers/__init__.py"} == set(init_edges)
|
||||
|
||||
bert_edges = [
|
||||
"src/transformers/modeling_utils.py",
|
||||
"src/transformers/utils/imports.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
]
|
||||
assert {f for f, g in tree if g == "src/transformers/models/bert/modeling_bert.py"} == set(bert_edges)
|
||||
|
||||
test_bert_edges = [
|
||||
"tests/test_modeling_common.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
]
|
||||
assert {f for f, g in tree if g == "tests/models/bert/test_modeling_bert.py"} == set(test_bert_edges)
|
||||
|
||||
def test_get_tree_starting_at(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
edges = create_reverse_dependency_tree()
|
||||
|
||||
bert_tree = get_tree_starting_at("src/transformers/models/bert/modeling_bert.py", edges)
|
||||
config_utils_tree = get_tree_starting_at("src/transformers/configuration_utils.py", edges)
|
||||
|
||||
expected_bert_tree = [
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
[("src/transformers/models/bert/modeling_bert.py", "tests/models/bert/test_modeling_bert.py")],
|
||||
]
|
||||
assert bert_tree == expected_bert_tree
|
||||
|
||||
expected_config_tree = [
|
||||
"src/transformers/configuration_utils.py",
|
||||
[("src/transformers/configuration_utils.py", "src/transformers/models/bert/configuration_bert.py")],
|
||||
[
|
||||
("src/transformers/models/bert/configuration_bert.py", "tests/models/bert/test_modeling_bert.py"),
|
||||
(
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
),
|
||||
],
|
||||
]
|
||||
# Order of the edges is random
|
||||
assert [set(v) for v in config_utils_tree] == [set(v) for v in expected_config_tree]
|
||||
|
||||
def test_print_tree_deps_of(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
|
||||
# There are two possible outputs since the order of the last two lines is non-deterministic.
|
||||
expected_std_out = """src/transformers/models/bert/modeling_bert.py
|
||||
tests/models/bert/test_modeling_bert.py
|
||||
src/transformers/configuration_utils.py
|
||||
src/transformers/models/bert/configuration_bert.py
|
||||
src/transformers/models/bert/modeling_bert.py
|
||||
tests/models/bert/test_modeling_bert.py"""
|
||||
|
||||
expected_std_out_2 = """src/transformers/models/bert/modeling_bert.py
|
||||
tests/models/bert/test_modeling_bert.py
|
||||
src/transformers/configuration_utils.py
|
||||
src/transformers/models/bert/configuration_bert.py
|
||||
tests/models/bert/test_modeling_bert.py
|
||||
src/transformers/models/bert/modeling_bert.py"""
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder), CaptureStdout() as cs:
|
||||
print_tree_deps_of("src/transformers/models/bert/modeling_bert.py")
|
||||
print_tree_deps_of("src/transformers/configuration_utils.py")
|
||||
|
||||
assert cs.out.strip() in [expected_std_out, expected_std_out_2]
|
||||
|
||||
def test_init_test_examples_dependencies(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder).resolve()
|
||||
create_tmp_repo(tmp_folder)
|
||||
|
||||
expected_example_deps = {
|
||||
"examples/pytorch/test_pytorch_examples.py": [
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
],
|
||||
}
|
||||
|
||||
expected_examples = {
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
}
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
example_deps, all_examples = init_test_examples_dependencies()
|
||||
assert example_deps == expected_example_deps
|
||||
assert {str(f.relative_to(tmp_folder)) for f in all_examples} == expected_examples
|
||||
|
||||
def test_create_reverse_dependency_map(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
create_tmp_repo(tmp_folder)
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
reverse_map = create_reverse_dependency_map()
|
||||
|
||||
# impact of BERT modeling file (note that we stop at the inits and don't go down further)
|
||||
expected_bert_deps = {
|
||||
"src/transformers/__init__.py",
|
||||
"src/transformers/models/bert/__init__.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
}
|
||||
assert set(reverse_map["src/transformers/models/bert/modeling_bert.py"]) == expected_bert_deps
|
||||
|
||||
# init gets the direct deps (and their recursive deps)
|
||||
expected_init_deps = {
|
||||
"src/transformers/utils/__init__.py",
|
||||
"src/transformers/utils/hub.py",
|
||||
"src/transformers/utils/imports.py",
|
||||
"src/transformers/models/bert/__init__.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
"src/transformers/configuration_utils.py",
|
||||
"src/transformers/modeling_utils.py",
|
||||
"tests/test_modeling_common.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
}
|
||||
assert set(reverse_map["src/transformers/__init__.py"]) == expected_init_deps
|
||||
|
||||
expected_init_deps = {
|
||||
"src/transformers/__init__.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
}
|
||||
assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps
|
||||
|
||||
# Test that with more models init of bert only gets deps to bert.
|
||||
create_tmp_repo(tmp_folder, models=["bert", "gpt2"])
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
reverse_map = create_reverse_dependency_map()
|
||||
|
||||
# init gets the direct deps (and their recursive deps)
|
||||
expected_init_deps = {
|
||||
"src/transformers/__init__.py",
|
||||
"src/transformers/models/bert/configuration_bert.py",
|
||||
"src/transformers/models/bert/modeling_bert.py",
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
}
|
||||
assert set(reverse_map["src/transformers/models/bert/__init__.py"]) == expected_init_deps
|
||||
|
||||
@unittest.skip("Broken for now TODO @ArthurZucker")
|
||||
def test_infer_tests_to_run(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
|
||||
repo = create_tmp_repo(tmp_folder, models=models)
|
||||
|
||||
commit_changes("src/transformers/models/bert/modeling_bert.py", BERT_MODEL_FILE_NEW_CODE, repo)
|
||||
|
||||
example_tests = {
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
}
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
|
||||
with open(tmp_folder / "test-output.txt") as f:
|
||||
tests_to_run = f.read()
|
||||
with open(tmp_folder / "examples_test_list.txt") as f:
|
||||
example_tests_to_run = f.read()
|
||||
|
||||
assert tests_to_run == "tests/models/bert/test_modeling_bert.py"
|
||||
assert set(example_tests_to_run.split(" ")) == example_tests
|
||||
|
||||
# Fake a new model addition
|
||||
repo = create_tmp_repo(tmp_folder, models=models)
|
||||
|
||||
branch = repo.create_head("new_model")
|
||||
branch.checkout()
|
||||
|
||||
with open(tmp_folder / "src/transformers/__init__.py", "a") as f:
|
||||
f.write("from .models.t5 import T5Config, T5Model\n")
|
||||
|
||||
model_dir = tmp_folder / "src/transformers/models/t5"
|
||||
model_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(model_dir / "__init__.py", "w") as f:
|
||||
f.write("from .configuration_t5 import T5Config\nfrom .modeling_t5 import T5Model\n")
|
||||
with open(model_dir / "configuration_t5.py", "w") as f:
|
||||
f.write("from ...configuration_utils import PreTrainedConfig\ncode")
|
||||
with open(model_dir / "modeling_t5.py", "w") as f:
|
||||
modeling_code = BERT_MODEL_FILE.replace("bert", "t5").replace("Bert", "T5")
|
||||
f.write(modeling_code)
|
||||
|
||||
test_dir = tmp_folder / "tests/models/t5"
|
||||
test_dir.mkdir(exist_ok=True)
|
||||
(test_dir / "__init__.py").touch()
|
||||
with open(test_dir / "test_modeling_t5.py", "w") as f:
|
||||
f.write(
|
||||
"from transformers import T5Config, T5Model\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode"
|
||||
)
|
||||
|
||||
repo.index.add(["src", "tests"])
|
||||
repo.index.commit("Add T5 model")
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
infer_tests_to_run(tmp_folder / "test-output.txt")
|
||||
with open(tmp_folder / "test-output.txt") as f:
|
||||
tests_to_run = f.read()
|
||||
with open(tmp_folder / "examples_test_list.txt") as f:
|
||||
example_tests_to_run = f.read()
|
||||
|
||||
expected_tests = {
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"tests/models/gpt2/test_modeling_gpt2.py",
|
||||
"tests/models/t5/test_modeling_t5.py",
|
||||
"tests/test_modeling_common.py",
|
||||
}
|
||||
assert set(tests_to_run.split(" ")) == expected_tests
|
||||
assert set(example_tests_to_run.split(" ")) == example_tests
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
infer_tests_to_run(tmp_folder / "test-output.txt")
|
||||
with open(tmp_folder / "test-output.txt") as f:
|
||||
tests_to_run = f.read()
|
||||
with open(tmp_folder / "examples_test_list.txt") as f:
|
||||
example_tests_to_run = f.read()
|
||||
|
||||
expected_tests = [f"tests/models/{name}/test_modeling_{name}.py" for name in models + ["t5"]]
|
||||
expected_tests = set(expected_tests + ["tests/test_modeling_common.py"])
|
||||
assert set(tests_to_run.split(" ")) == expected_tests
|
||||
assert set(example_tests_to_run.split(" ")) == example_tests
|
||||
|
||||
@unittest.skip("Broken for now TODO @ArthurZucker")
|
||||
def test_infer_tests_to_run_with_test_modifs(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
models = ["bert", "gpt2"] + [f"bert{i}" for i in range(10)]
|
||||
repo = create_tmp_repo(tmp_folder, models=models)
|
||||
|
||||
commit_changes(
|
||||
"tests/models/bert/test_modeling_bert.py",
|
||||
"from transformers import BertConfig, BertModel\nfrom ...test_modeling_common import ModelTesterMixin\n\ncode1",
|
||||
repo,
|
||||
)
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
|
||||
with open(tmp_folder / "test-output.txt") as f:
|
||||
tests_to_run = f.read()
|
||||
|
||||
assert tests_to_run == "tests/models/bert/test_modeling_bert.py"
|
||||
|
||||
@unittest.skip("Broken for now TODO @ArthurZucker")
|
||||
def test_infer_tests_to_run_with_examples_modifs(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_folder:
|
||||
tmp_folder = Path(tmp_folder)
|
||||
models = ["bert", "gpt2"]
|
||||
repo = create_tmp_repo(tmp_folder, models=models)
|
||||
|
||||
# Modification in one example trigger the corresponding test
|
||||
commit_changes(
|
||||
"examples/pytorch/text-classification/run_glue.py",
|
||||
"from transformers import BertModeln\n\ncode1",
|
||||
repo,
|
||||
)
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
|
||||
with open(tmp_folder / "examples_test_list.txt") as f:
|
||||
example_tests_to_run = f.read()
|
||||
|
||||
assert example_tests_to_run == "examples/pytorch/test_pytorch_examples.py"
|
||||
|
||||
# Modification in one test example file trigger that test
|
||||
repo = create_tmp_repo(tmp_folder, models=models)
|
||||
commit_changes(
|
||||
"examples/pytorch/test_pytorch_examples.py",
|
||||
"""test_args = "run_glue.py"\nmore_code""",
|
||||
repo,
|
||||
)
|
||||
|
||||
with patch_transformer_repo_path(tmp_folder):
|
||||
infer_tests_to_run(tmp_folder / "test-output.txt", diff_with_last_commit=True)
|
||||
with open(tmp_folder / "examples_test_list.txt") as f:
|
||||
example_tests_to_run = f.read()
|
||||
|
||||
assert example_tests_to_run == "examples/pytorch/test_pytorch_examples.py"
|
||||
|
||||
def test_parse_commit_message(self):
|
||||
assert parse_commit_message("Normal commit") == {"skip": False, "no_filter": False, "test_all": False}
|
||||
|
||||
assert parse_commit_message("[skip ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
|
||||
assert parse_commit_message("[ci skip] commit") == {"skip": True, "no_filter": False, "test_all": False}
|
||||
assert parse_commit_message("[skip-ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
|
||||
assert parse_commit_message("[skip_ci] commit") == {"skip": True, "no_filter": False, "test_all": False}
|
||||
|
||||
assert parse_commit_message("[no filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
|
||||
assert parse_commit_message("[no-filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
|
||||
assert parse_commit_message("[no_filter] commit") == {"skip": False, "no_filter": True, "test_all": False}
|
||||
assert parse_commit_message("[filter-no] commit") == {"skip": False, "no_filter": True, "test_all": False}
|
||||
|
||||
assert parse_commit_message("[test all] commit") == {"skip": False, "no_filter": False, "test_all": True}
|
||||
assert parse_commit_message("[all test] commit") == {"skip": False, "no_filter": False, "test_all": True}
|
||||
assert parse_commit_message("[test-all] commit") == {"skip": False, "no_filter": False, "test_all": True}
|
||||
assert parse_commit_message("[all_test] commit") == {"skip": False, "no_filter": False, "test_all": True}
|
||||
Reference in New Issue
Block a user