first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled

This commit is contained in:
陈赣
2026-06-05 16:53:03 +08:00
commit 06f1fd69a6
6047 changed files with 1895387 additions and 0 deletions

427
utils/add_dates.py Normal file
View File

@@ -0,0 +1,427 @@
import argparse
import os
import re
import subprocess
from datetime import date, datetime
from urllib.error import HTTPError
from urllib.request import Request, urlopen
from huggingface_hub import paper_info
from transformers import logging
# Metadata describing this checker: its name, display label, the file globs it depends on and the
# command-line arguments used for "check" and "fix" runs.
# NOTE(review): presumably consumed by an external checker runner that is not visible in this
# file — confirm how each key is interpreted there.
CHECKER_CONFIG = {
    "name": "add_dates",
    "label": "Model dates",
    # Approximate: also reads docs/source/en/model_doc/*.md and uses git log + network
    # calls to GitHub/HuggingFace for commit dates and paper metadata.
    "cache_globs": ["src/transformers/models/**/__init__.py", "docs/source/en/model_doc/**/*.md"],
    "check_args": ["--check-only"],
    "fix_args": [],
}

logger = logging.get_logger(__name__)

# Repository root: everything in the current working directory that precedes the first occurrence of
# the substring "utils". When the cwd does not contain "utils", this is the cwd itself, so the script
# works both from the repository root and from inside `utils/`.
# NOTE(review): a cwd whose path contains "utils" earlier (e.g. ".../myutils/repo") is truncated there.
ROOT = os.getcwd().split("utils")[0]
# Directory holding the English model cards (one `<model>.md` per model).
DOCS_PATH = os.path.join(ROOT, "docs/source/en/model_doc")
# Directory holding one sub-package per model.
MODELS_PATH = os.path.join(ROOT, "src/transformers/models")
GITHUB_REPO_URL = "https://github.com/huggingface/transformers"
# Base URL used to probe whether a file exists on the `main` branch.
GITHUB_RAW_URL = "https://raw.githubusercontent.com/huggingface/transformers/main"
# License header prepended to a model card when the card contains no `-->` marker at all
# (see `insert_dates`). The text is written to disk verbatim.
COPYRIGHT_DISCLAIMER = """<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->"""

# Model card file name -> arXiv identifier, for papers that (per the name) have no HF papers page.
# NOTE(review): this mapping is not referenced by any code visible in this file — confirm it is
# still needed.
ARXIV_PAPERS_NOT_IN_HF_PAPERS = {
    "gemma3n.md": "2506.06644",
    "xmod.md": "2205.06266",
}
def check_file_exists_on_github(file_path: str) -> bool:
    """Check if a file exists on the main branch of the GitHub repository.

    Args:
        file_path: Path of the file, either relative to the repository root or absolute (in which
            case the leading `ROOT` prefix is removed).

    Returns:
        `False` only when GitHub answers with a confirmed 404. `True` when the file exists, and also
        when the check could not be completed (other HTTP error, network failure, timeout).

    Note:
        Failures other than a 404 are logged as warnings and treated as "file exists", so that a
        temporary network issue does not make the script fail or mis-date a model.
    """
    # Convert absolute path to relative path from repository root if needed. URLs always use forward
    # slashes, so the local path separator is normalised as well (a no-op on POSIX).
    if file_path.startswith(ROOT):
        file_path = file_path[len(ROOT) :].replace(os.sep, "/").lstrip("/")
    # Construct the raw GitHub URL for the file
    url = f"{GITHUB_RAW_URL}/{file_path}"
    try:
        # Make a HEAD request to check if file exists (more efficient than GET)
        request = Request(url, method="HEAD")
        request.add_header("User-Agent", "transformers-add-dates-script")
        with urlopen(request, timeout=10) as response:
            return response.status == 200
    except HTTPError as e:
        if e.code == 404:
            # File doesn't exist on GitHub
            return False
        # HTTP error (non-404): assume file exists and continue with local git history
        logger.warning(f"Could not verify {url} on GitHub (HTTP {e.code}); assuming the file exists.")
        return True
    except Exception as e:
        # Deliberately broad: this is a best-effort probe. On network/timeout errors we assume the file
        # exists and continue with local git history.
        logger.warning(f"Could not verify {url} on GitHub ({e}); assuming the file exists.")
        return True
def get_modified_cards() -> list[str]:
    """Return the model names whose cards under docs/source/en/model_doc/ were modified.

    On the `main` branch only uncommitted changes (diff against `HEAD`) are considered; on any other
    branch the diff is taken against the merge base with `main`. Files that no longer exist locally
    and the `auto` / `timm_wrapper` cards are left out.
    """
    branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
    if branch == "main":
        # On main branch, only uncommitted changes detected
        diff_output = subprocess.check_output(["git", "diff", "--name-only", "HEAD"], text=True)
    else:
        merge_base = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
        diff_output = subprocess.check_output(f"git diff --name-only {merge_base}".split()).decode("utf-8")

    excluded = ("auto", "timm_wrapper")
    cards = []
    for changed_file in diff_output.strip().split("\n"):
        if not changed_file:
            continue
        # Only markdown files in the model_doc directory are model cards
        if not (changed_file.startswith("docs/source/en/model_doc/") and changed_file.endswith(".md")):
            continue
        if not os.path.exists(os.path.join(ROOT, changed_file)):
            continue
        stem = os.path.splitext(os.path.basename(changed_file))[0]
        if stem not in excluded:
            cards.append(stem)
    return cards
def get_paper_link(model_card: str | None, path: str | None) -> str:
"""Get the first paper link from the model card content."""
if model_card is not None and not model_card.endswith(".md"):
model_card = f"{model_card}.md"
file_path = path or os.path.join(DOCS_PATH, f"{model_card}")
model_card = os.path.basename(file_path)
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Find known paper links
paper_ids = re.findall(r"https://huggingface\.co/papers/\d+\.\d+", content)
paper_ids += re.findall(r"https://arxiv\.org/abs/\d+\.\d+", content)
paper_ids += re.findall(r"https://arxiv\.org/pdf/\d+\.\d+", content)
if len(paper_ids) == 0:
return "No_paper"
return paper_ids[0]
def get_first_commit_date(model_name: str | None) -> str:
    """Return the date (YYYY-MM-DD) on which the model is considered added to HF transformers.

    The date is that of the first commit touching the model's `__init__.py`; when that file does not
    exist locally (only true for legacy models) the model card is used instead. If the chosen file is
    confirmed missing from the GitHub main branch (new model), today's date is returned.
    """
    doc_stem = model_name[:-3] if model_name.endswith(".md") else model_name
    # Model packages use underscores where the card names use dashes
    package_dir = doc_stem.replace("-", "_")

    tracked_file = os.path.join(MODELS_PATH, package_dir, "__init__.py")
    if not os.path.exists(tracked_file):
        tracked_file = os.path.join(DOCS_PATH, f"{doc_stem}.md")

    if check_file_exists_on_github(tracked_file):
        # File exists on GitHub main branch: oldest commit first, take its date from local history
        raw_dates = subprocess.check_output(
            ["git", "log", "--reverse", "--pretty=format:%ad", "--date=iso", tracked_file], text=True
        )
    else:
        # File does not exist on GitHub main branch (new model), use today's date
        raw_dates = date.today().isoformat()

    oldest_entry = raw_dates.strip().split("\n")[0]
    # Keep only the YYYY-MM-DD part of the ISO timestamp
    return oldest_entry[:10]
def get_release_date(link: str) -> str:
if link.startswith("https://huggingface.co/papers/"):
link = link.replace("https://huggingface.co/papers/", "")
try:
info = paper_info(link)
return info.published_at.date().isoformat()
except Exception as e:
# Error fetching release date, function returns None (will use placeholder)
logger.debug(f"Could not fetch paper info for {link}: {e}")
elif link.startswith("https://arxiv.org/abs/") or link.startswith("https://arxiv.org/pdf/"):
return r"{release_date}"
def replace_paper_links(file_path: str) -> bool:
    """Rewrite paper links in a model card to their huggingface.co form.

    `https://hf.co/` is expanded to `https://huggingface.co/`, and every arxiv.org abs/pdf link whose
    paper id is known to `paper_info` is replaced by `https://huggingface.co/papers/<id>`. Links to
    papers that are not available on huggingface are kept as they are.

    Args:
        file_path: Path of the model card to update in place.

    Returns:
        True if the file was modified and written back, False if it was left untouched.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        original_content = f.read()

    # Replace hf.co with huggingface.co
    content = original_content.replace("https://hf.co/", "https://huggingface.co/")

    # Unique arxiv paper ids, so that each paper costs a single lookup however often it is linked
    paper_ids = dict.fromkeys(re.findall(r"https://arxiv\.org/(?:abs|pdf)/(\d+\.\d+)", content))
    for paper_id in paper_ids:
        try:
            # Check if paper exists on huggingface
            paper_info(paper_id)
        except Exception:
            # Paper not available on huggingface, keep arxiv link
            continue
        new_link = f"https://huggingface.co/papers/{paper_id}"
        for link_kind in ("abs", "pdf"):
            content = content.replace(f"https://arxiv.org/{link_kind}/{paper_id}", new_link)

    # Write back only if content changed
    if content == original_content:
        return False
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)
    return True
def _normalize_model_card_name(model_card: str) -> str:
"""Ensure model card has .md extension"""
return model_card if model_card.endswith(".md") else f"{model_card}.md"
def _should_skip_model_card(model_card: str) -> bool:
"""Check if model card should be skipped"""
return model_card in ("auto.md", "timm_wrapper.md")
def _read_model_card_content(model_card: str) -> str:
    """Return the full UTF-8 text of the card named `model_card` inside `DOCS_PATH`."""
    card_path = os.path.join(DOCS_PATH, model_card)
    with open(card_path, "r", encoding="utf-8") as handle:
        text = handle.read()
    return text
def _get_dates_pattern_match(content: str):
"""Search for the dates pattern in content and return match object"""
pattern = r"\n\*This model was (?:published in HF papers on (.*) and )?contributed to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*"
return re.search(pattern, content)
def _dates_differ_significantly(date1: str, date2: str) -> bool:
"""Check if two dates differ by more than 1 day"""
try:
d1 = datetime.strptime(date1, "%Y-%m-%d")
d2 = datetime.strptime(date2, "%Y-%m-%d")
return abs((d1 - d2).days) > 1
except Exception:
return True # If dates can't be parsed, consider them different
def check_missing_dates(model_card_list: list[str]) -> list[str]:
    """Return the cards (file names with ".md") of `model_card_list` that contain no dates line.

    The cards excluded by `_should_skip_model_card` are never reported.
    """
    cards_without_dates = []
    for raw_name in model_card_list:
        card = _normalize_model_card_name(raw_name)
        if _should_skip_model_card(card):
            continue
        if _get_dates_pattern_match(_read_model_card_content(card)) is None:
            cards_without_dates.append(card)
    return cards_without_dates
def check_incorrect_dates(model_card_list: list[str]) -> list[str]:
    """Check which model cards have incorrect model release/addition dates and return their names.

    A card is reported when its dates line exists and either
    - the stored contribution date differs by more than one day from the first commit date, or
    - the stored publication date differs from the expected one (an existing, non-placeholder
      publication date is always trusted and therefore never reported).

    Cards without any dates line are not reported here (see `check_missing_dates`).

    Args:
        model_card_list: Model card names, with or without the ".md" suffix.

    Returns:
        The file names (with ".md") of the cards whose dates need to be fixed.
    """
    incorrect_dates = []
    for model_card in model_card_list:
        model_card = _normalize_model_card_name(model_card)
        if _should_skip_model_card(model_card):
            continue
        content = _read_model_card_content(model_card)
        match = _get_dates_pattern_match(content)
        if not match:
            # Nothing to compare against: skip the paper lookup, which may hit the network
            continue

        file_path = os.path.join(DOCS_PATH, model_card)
        paper_link = get_paper_link(model_card=model_card, path=file_path)
        if paper_link in ("No_paper", "blog"):
            release_date = r"{release_date}"
        else:
            release_date = get_release_date(paper_link)

        # Preserve existing release date unless it's a placeholder
        existing_release_date = match.group(1)
        if existing_release_date not in (r"{release_date}", "None"):
            release_date = existing_release_date
        existing_hf_date = match.group(2)
        actual_hf_date = get_first_commit_date(model_name=model_card)
        if _dates_differ_significantly(existing_hf_date, actual_hf_date) or existing_release_date != release_date:
            incorrect_dates.append(model_card)
    return incorrect_dates
def _format_dates_line(release_date: str | None, hf_commit_date: str) -> str:
    """Build the dates line (with its leading newline) stored in a model card.

    The publication part is only written for a real date: `None` (lookup failed, or no publication
    part in the existing line) and the "{release_date}" placeholder both yield the short form.
    """
    if release_date is None or release_date == r"{release_date}":
        return f"\n*This model was contributed to Hugging Face Transformers on {hf_commit_date}.*"
    return f"\n*This model was published in HF papers on {release_date} and contributed to Hugging Face Transformers on {hf_commit_date}.*"


def insert_dates(model_card_list: list[str]):
    """Insert or update release and commit dates in model cards.

    For every card (except the skipped ones) this:
    1. rewrites arxiv / hf.co paper links to huggingface.co links where possible,
    2. prepends the copyright disclaimer when the card has no `-->` marker,
    3. updates the existing dates line when it is out of date, or inserts a new one right after the
       first `-->` marker.

    Args:
        model_card_list: Model card names, with or without the ".md" suffix. Files are edited in place.
    """
    for model_card in model_card_list:
        model_card = _normalize_model_card_name(model_card)
        if _should_skip_model_card(model_card):
            continue
        file_path = os.path.join(DOCS_PATH, model_card)

        # First replace arxiv paper links with hf paper link if possible
        replace_paper_links(file_path)

        # Read content and ensure copyright disclaimer exists
        content = _read_model_card_content(model_card)
        markers = list(re.finditer(r"-->", content))
        if len(markers) == 0:
            # No copyright marker found, adding disclaimer to the top
            content = COPYRIGHT_DISCLAIMER + "\n\n" + content
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            markers = list(re.finditer(r"-->", content))

        # Get dates
        hf_commit_date = get_first_commit_date(model_name=model_card)
        paper_link = get_paper_link(model_card=model_card, path=file_path)
        if paper_link in ("No_paper", "blog"):
            release_date = r"{release_date}"
        else:
            # May be None when the paper lookup fails; `_format_dates_line` handles that
            release_date = get_release_date(paper_link)

        match = _get_dates_pattern_match(content)
        if match:
            # Update the existing dates line.
            # Preserve existing release date unless it's a placeholder
            existing_release_date = match.group(1)
            existing_hf_date = match.group(2)
            if existing_release_date not in (r"{release_date}", "None"):
                release_date = existing_release_date
            if _dates_differ_significantly(existing_hf_date, hf_commit_date) or existing_release_date != release_date:
                content = content.replace(match.group(0), _format_dates_line(release_date, hf_commit_date))
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(content)
        else:
            # Insert new dates line after copyright marker
            insert_index = markers[0].end()
            date_info = _format_dates_line(release_date, hf_commit_date)
            content = content[:insert_index] + date_info + content[insert_index:]
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
def get_all_model_cards():
    """Return the sorted names (without ".md") of all cards in `DOCS_PATH`, minus auto/timm_wrapper."""
    excluded = ("auto", "timm_wrapper")
    stems = (os.path.splitext(entry)[0] for entry in os.listdir(DOCS_PATH) if entry.endswith(".md"))
    return sorted(stem for stem in stems if stem not in excluded)
def main(all=False, models=None, check_only=False):
    """Entry point: either verify the dates of the model cards or write them.

    With `check_only`, every card is checked for a missing dates line and the modified cards are
    checked for incorrect dates; a `ValueError` listing the offending cards is raised if any is found
    and nothing is written. Otherwise dates are inserted into all cards (`all`), the given `models`,
    or, by default, the cards modified according to git.
    """
    if check_only:
        # Check all model cards for missing dates
        missing = check_missing_dates(get_all_model_cards())
        # Check modified model cards for incorrect dates
        incorrect = check_incorrect_dates(get_modified_cards())

        flagged_cards = missing + incorrect
        if flagged_cards:
            flagged_names = [card.replace(".md", "") for card in flagged_cards]
            raise ValueError(
                f"Missing or incorrect dates in the following model cards: {' '.join(flagged_cards)}\n"
                f"Run `python utils/add_dates.py --models {' '.join(flagged_names)}` to fix them."
            )
        return

    # Determine which model cards to process
    if all:
        selected_cards = get_all_model_cards()
    elif models:
        selected_cards = models
    else:
        selected_cards = get_modified_cards()

    if selected_cards:
        insert_dates(selected_cards)
# Command-line entry point. `--models`, `--all` and `--check-only` are mutually exclusive; with none
# of them, the cards modified according to git are processed.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add release and commit dates to model cards")
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("--models", nargs="+", help="Specify model cards to process (without .md extension)")
    group.add_argument("--all", action="store_true", help="Process all model cards in the docs directory")
    group.add_argument("--check-only", action="store_true", help="Check if the dates are already present")
    args = parser.parse_args()
    try:
        main(args.all, args.models, args.check_only)
    except subprocess.CalledProcessError as e:
        # A failing git command is only reported, not re-raised, so the script still ends normally.
        # Any other exception (e.g. the ValueError raised in check-only mode) propagates.
        print(
            f"An error occurred while executing git commands but it can be ignored (git issue) most probably local: {e}"
        )

View File

@@ -0,0 +1,307 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to add and/or update the attribute `pipeline_model_mapping` in model test files.
This script will be (mostly) used in the following 2 situations:
- run within a (scheduled) CI job to:
- check if model test files in the library have updated `pipeline_model_mapping`,
- and/or update test files and (possibly) open a GitHub pull request automatically
- being run by a `transformers` member to quickly check and update some particular test file(s)
This script is **NOT** intended to be run (manually) by community contributors.
"""
import argparse
import glob
import inspect
import os
import re
import unittest
from get_test_info import get_test_classes
from tests.test_pipeline_mixin import pipeline_test_mapping
# Per-task cache filled lazily by `get_mapping_for_task`: every task of `pipeline_test_mapping`
# starts with `None`, meaning "not computed yet".
PIPELINE_TEST_MAPPING = {}
for task in pipeline_test_mapping:
    PIPELINE_TEST_MAPPING[task] = None

# Test files that the command-line entry point skips (paths start with the `tests` directory).
# DO **NOT** add item to this set (unless the reason is approved)
TEST_FILE_TO_IGNORE = {
    "tests/models/esm/test_modeling_esmfold.py",  # The pipeline test mapping is added to `test_modeling_esm.py`
}
def get_mapping_for_task(task):
    """Return the `model_mapping` of the pipeline test class registered for `task`, as a plain dict.

    Returns None when that test class defines no `model_mapping`. Non-None results are served from the
    `PIPELINE_TEST_MAPPING` cache on later calls.
    """
    cached = PIPELINE_TEST_MAPPING[task]
    if cached is not None:
        return cached

    test_cls = pipeline_test_mapping[task]["test"]
    resolved = getattr(test_cls, "model_mapping", None)
    if resolved is not None:
        resolved = dict(resolved.items())

    # Remember the result for the next call
    PIPELINE_TEST_MAPPING[task] = resolved
    return resolved
def get_model_for_pipeline_test(test_class, task):
    """Return the model class(es) that the pipeline `task` maps to for `test_class`.

    The lookup key is the single configuration class shared by `test_class.all_model_classes`; a
    `ValueError` is raised when there is not exactly one. Returns None when the task has no mapping or
    the configuration class is not in it. A list/tuple of classes is returned sorted by class name.
    """
    task_mapping = get_mapping_for_task(task)
    if task_mapping is None:
        return None

    distinct_configs = {candidate.config_class for candidate in test_class.all_model_classes}
    if len(distinct_configs) != 1:
        raise ValueError("There should be exactly one configuration class from `test_class.all_model_classes`.")
    (config_class,) = distinct_configs

    # This could be a list/tuple of model classes, but it's rare.
    resolved = task_mapping.get(config_class, None)
    if isinstance(resolved, (tuple, list)):
        resolved = sorted(resolved, key=lambda cls: cls.__name__)
    return resolved
def get_pipeline_model_mapping(test_class):
    """Build the task -> model class(es) dict for `test_class`, with keys in sorted task order.

    Tasks for which `get_model_for_pipeline_test` returns None are left out.
    """
    resolved = {}
    for task in pipeline_test_mapping:
        model = get_model_for_pipeline_test(test_class, task)
        if model is not None:
            resolved[task] = model
    return {task: resolved[task] for task in sorted(resolved)}
def get_pipeline_model_mapping_string(test_class):
    """Render the `pipeline_model_mapping` assignment for `test_class` as one line of source code.

    Returns an empty string when no task maps to the test class. The line is not formatted here; it
    is meant to be reformatted (`make style`) once it has been added to a test file.
    """
    mapping = get_pipeline_model_mapping(test_class)
    if not mapping:
        return ""

    def render(model_classes):
        # A list/tuple of model classes becomes a parenthesised list of names, a single class its name
        if isinstance(model_classes, (tuple, list)):
            return "(" + ", ".join([cls.__name__ for cls in model_classes]) + ")"
        return model_classes.__name__

    entries = ", ".join(f'"{task}": {render(model_classes)}' for task, model_classes in mapping.items())
    mapping_literal = "{" + entries + "}"
    fallback_literal = "{}"
    return f"pipeline_model_mapping = {mapping_literal} if is_torch_available() else {fallback_literal}"
def is_valid_test_class(test_class):
    """Tell whether `test_class` is a `unittest.TestCase` with `ModelTesterMixin` as a direct base.

    Only the names of the direct bases are inspected: inheriting `ModelTesterMixin` through an
    intermediate class does not count.
    """
    if issubclass(test_class, unittest.TestCase):
        return any(base.__name__ == "ModelTesterMixin" for base in test_class.__bases__)
    return False
def find_test_class(test_file):
    """Pick the test class of `test_file` that should carry `pipeline_model_mapping`.

    Among the valid test classes, the first one already defining a non-None `pipeline_model_mapping`
    wins; otherwise the one with the shortest name (ties broken alphabetically) is taken, which is
    just a heuristic. Returns None when the file has no valid test class.
    """
    candidates = [cls for cls in get_test_classes(test_file) if is_valid_test_class(cls)]

    for candidate in candidates:
        # If a test class has defined `pipeline_model_mapping`, let's take it
        if getattr(candidate, "pipeline_model_mapping", None) is not None:
            return candidate

    if candidates:
        return min(candidates, key=lambda cls: (len(cls.__name__), cls.__name__))
    return None
def find_block_ending(lines, start_idx, indent_level):
    """Return the index (inclusive) of the last line of the block starting at `lines[start_idx]`.

    The block consists of the starting line, the following lines indented deeper than `indent_level`,
    and a closing ")" standing alone at exactly `indent_level`. Scanning stops at the first other
    line whose indentation is not deeper than `indent_level`.
    """
    last_idx = start_idx
    for offset, text in enumerate(lines[start_idx:]):
        width = len(text) - len(text.lstrip())
        closes_block = width == indent_level and text.strip() == ")"
        if offset > 0 and width <= indent_level and not closes_block:
            # Outside the definition block
            break
        last_idx = start_idx + offset
    return last_idx
def add_pipeline_model_mapping(test_class, overwrite=False):
    """Add `pipeline_model_mapping` to `test_class`.

    The source of `test_class` is fetched with `inspect`, the one-line `pipeline_model_mapping`
    assignment is inserted after `all_model_classes` / `all_generative_model_classes` (or in place
    of an existing `pipeline_model_mapping` block), `PipelineTesterMixin` is added to the class
    declaration if missing, and the whole module file is written back.

    Args:
        test_class: The test class to update.
        overwrite: If False, a class that already has a non-None `pipeline_model_mapping` is left alone.

    Returns:
        The tuple `("", -1)` when nothing is written (mapping already defined and `overwrite` is False,
        no mapping to add, or no anchor line found in the class), otherwise the inserted line as a
        string (with indentation and trailing newline).
        NOTE(review): the two return shapes differ (tuple vs. str); the only caller visible in this
        file ignores the return value.
    """
    if getattr(test_class, "pipeline_model_mapping", None) is not None:
        if not overwrite:
            return "", -1

    line_to_add = get_pipeline_model_mapping_string(test_class)
    if len(line_to_add) == 0:
        return "", -1
    line_to_add = line_to_add + "\n"

    # The code defined the class `test_class`
    class_lines, class_start_line_no = inspect.getsourcelines(test_class)
    # `inspect` gives the code for an object, including decorator(s) if any.
    # We (only) need the exact line of the class definition.
    for idx, line in enumerate(class_lines):
        if line.lstrip().startswith("class "):
            class_lines = class_lines[idx:]
            class_start_line_no += idx
            break
    # Computed before `class_lines` is modified, so it still refers to the original file layout.
    class_end_line_no = class_start_line_no + len(class_lines) - 1

    # The index in `class_lines` that starts the definition of `all_model_classes`, `all_generative_model_classes` or
    # `pipeline_model_mapping`. This assumes they are defined in such order, and we take the start index of the last
    # block that appears in a `test_class`.
    start_idx = None
    # The indent level of the line at `class_lines[start_idx]` (if defined)
    indent_level = 0
    # To record if `pipeline_model_mapping` is found in `test_class`.
    def_line = None
    for idx, line in enumerate(class_lines):
        if line.strip().startswith("all_model_classes = "):
            indent_level = len(line) - len(line.lstrip())
            start_idx = idx
        elif line.strip().startswith("all_generative_model_classes = "):
            indent_level = len(line) - len(line.lstrip())
            start_idx = idx
        elif line.strip().startswith("pipeline_model_mapping = "):
            indent_level = len(line) - len(line.lstrip())
            start_idx = idx
            def_line = line
            break

    # None of the anchor assignments exists in the class: there is no place to insert the mapping.
    if start_idx is None:
        return "", -1
    # Find the ending index (inclusive) of the above found block.
    end_idx = find_block_ending(class_lines, start_idx, indent_level)

    # Extract `is_xxx_available()` from existing blocks: some models require specific libraries like `timm` and use
    # `is_timm_available()` instead of `is_torch_available()`.
    # Keep leading and trailing whitespaces
    r = re.compile(r"\s(is_\S+?_available\(\))\s")
    for line in class_lines[start_idx : end_idx + 1]:
        backend_condition = r.search(line)
        if backend_condition is not None:
            # replace the leading and trailing whitespaces to the space character " ".
            target = " " + backend_condition[0][1:-1] + " "
            line_to_add = r.sub(target, line_to_add)
            break

    if def_line is None:
        # `pipeline_model_mapping` is not defined. The target index is set to the ending index (inclusive) of
        # `all_model_classes` or `all_generative_model_classes`.
        target_idx = end_idx
    else:
        # `pipeline_model_mapping` is defined. The target index is set to be one **BEFORE** its start index.
        target_idx = start_idx - 1
        # mark the lines of the currently existing `pipeline_model_mapping` to be removed.
        for idx in range(start_idx, end_idx + 1):
            # These lines are going to be removed before writing to the test file.
            class_lines[idx] = None  # noqa

    # Make sure the test class is a subclass of `PipelineTesterMixin`.
    parent_classes = [x.__name__ for x in test_class.__bases__]
    if "PipelineTesterMixin" not in parent_classes:
        # Put `PipelineTesterMixin` just before `unittest.TestCase`
        _parent_classes = [x for x in parent_classes if x != "TestCase"] + ["PipelineTesterMixin"]
        if "TestCase" in parent_classes:
            # Here we **assume** the original string is always with `unittest.TestCase`.
            _parent_classes.append("unittest.TestCase")
        parent_classes = ", ".join(_parent_classes)
        for idx, line in enumerate(class_lines):
            # Find the ending of the declaration of `test_class`
            if line.strip().endswith("):"):
                # mark the lines of the declaration of `test_class` to be removed
                for _idx in range(idx + 1):
                    class_lines[_idx] = None  # noqa
                break
        # Add the new, one-line, class declaration for `test_class`
        class_lines[0] = f"class {test_class.__name__}({parent_classes}):\n"

    # Add indentation
    line_to_add = " " * indent_level + line_to_add
    # Insert `pipeline_model_mapping` to `class_lines`.
    # (The line at `target_idx` should be kept by definition!)
    class_lines = class_lines[: target_idx + 1] + [line_to_add] + class_lines[target_idx + 1 :]
    # Remove the lines that are marked to be removed
    class_lines = [x for x in class_lines if x is not None]

    # Move from test class to module (in order to write to the test file)
    module_lines = inspect.getsourcelines(inspect.getmodule(test_class))[0]
    # Be careful with the 1-off between line numbers and array indices
    module_lines = module_lines[: class_start_line_no - 1] + class_lines + module_lines[class_end_line_no:]
    code = "".join(module_lines)

    # Overwrite the module file with the updated source, always using "\n" line endings.
    moddule_file = inspect.getsourcefile(test_class)
    with open(moddule_file, "w", encoding="UTF-8", newline="\n") as fp:
        fp.write(code)

    return line_to_add
def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False):
    """Update `pipeline_model_mapping` in the test class selected from `test_file`, if there is one."""
    target_class = find_test_class(test_file)
    if not target_class:
        return
    add_pipeline_model_mapping(target_class, overwrite=overwrite)
# Command-line entry point: update either a single test file (`--test_file`) or all model test files
# (`--all`); exactly one of the two must be given.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_file", type=str, help="A path to the test file, starting with the repository's `tests` directory."
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="If to check and modify all test files.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="If to overwrite a test class if it has already defined `pipeline_model_mapping`.",
    )
    args = parser.parse_args()

    if not args.all and not args.test_file:
        raise ValueError("Please specify either `test_file` or pass `--all` to check/modify all test files.")
    elif args.all and args.test_file:
        raise ValueError("Only one of `--test_file` and `--all` could be specified.")

    test_files = []
    if args.test_file:
        test_files = [args.test_file]
    else:
        # The pattern is relative, so the script has to be run from the repository root.
        # `glob.glob` is called without `recursive=True`, so "**" behaves like "*" here: only files
        # exactly one directory below `tests/models` are matched.
        pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
        for test_file in glob.glob(pattern):
            test_files.append(test_file)

    for test_file in test_files:
        # NOTE(review): the ignore set uses "/" separators; paths produced by `glob` on Windows would
        # presumably not match — confirm whether that platform matters here.
        if test_file in TEST_FILE_TO_IGNORE:
            print(f"[SKIPPED] {test_file} is skipped as it is in `TEST_FILE_TO_IGNORE` in the file {__file__}.")
            continue
        add_pipeline_model_mapping_to_test_file(test_file, overwrite=args.overwrite)

View File

@@ -0,0 +1,80 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Aggregate multiple failure report JSON files into a single file.
This script reads all JSON files from a directory and combines them
into a single JSON array.
"""
import argparse
import json
import sys
from pathlib import Path
def aggregate_failures(input_dir, output_file):
    """
    Aggregate failure reports from multiple JSON files.

    Every `*.json` file directly inside `input_dir` is parsed and its content appended, in file name
    order, to a list that is written to `output_file` as a JSON array. Files that cannot be read or
    parsed are reported on stderr and skipped. A missing `input_dir` yields an empty array.

    Args:
        input_dir: Directory containing failure report JSON files
        output_file: Path to output aggregated JSON file

    Returns:
        Number of failures aggregated
    """
    failures = []
    input_path = Path(input_dir)
    output_path = Path(output_file).resolve()

    if input_path.is_dir():
        # Sorted so that the aggregated array does not depend on the file system's listing order
        for failure_file in sorted(input_path.glob("*.json")):
            if failure_file.resolve() == output_path:
                # A previous aggregate written into the input directory is not a failure report
                continue
            try:
                with open(failure_file, encoding="utf-8") as f:
                    failures.append(json.load(f))
            except (OSError, ValueError) as e:
                # ValueError covers invalid JSON and undecodable bytes; keep going with the other files
                print(f"Error reading {failure_file}: {e}", file=sys.stderr)

    # Write aggregated failures
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(failures, f, indent=2)

    print(f"Aggregated {len(failures)} failure(s) from {input_dir} to {output_file}")
    return len(failures)
def main():
    """Parse `--input-dir` and `--output` (both required), run the aggregation and return 0."""
    parser = argparse.ArgumentParser(description="Aggregate failure report JSON files")
    required_options = (
        ("--input-dir", "Directory containing failure report JSON files"),
        ("--output", "Output file path for aggregated JSON"),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)

    options = parser.parse_args()
    aggregate_failures(options.input_dir, options.output)
    return 0


if __name__ == "__main__":
    sys.exit(main())

309
utils/check_auto.py Normal file
View File

@@ -0,0 +1,309 @@
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import difflib
import glob
import os
import subprocess
import tempfile
from collections import Counter, OrderedDict
from typing import Any
from sort_auto_mappings import sort_auto_mapping
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES as COMPLETE_CONFIG_MAPPING_NAMES
from transformers.models.auto.image_processing_auto import MISSING_IMAGE_PROCESSOR_MAPPING_NAMES
from transformers.models.auto.video_processing_auto import MISSING_VIDEO_PROCESSOR_MAPPING_NAMES
# Registration metadata for this checker.
# NOTE(review): presumably consumed by a shared checker runner — confirm against the caller.
CHECKER_CONFIG = {
"name": "auto_mappings",
"label": "Generate auto mappings",
"cache_globs": [],
"check_args": [],
# Same flag as the `--fix_and_overwrite` argument parsed in this file's `__main__` section.
"fix_args": ["--fix_and_overwrite"],
}
# Warning banner and license text that `main` puts at the very top of the generated content.
# NOTE: the constant keeps its existing spelling (`HADER`) because `main` references it under this name.
AUTO_GENERATED_HADER = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from existing config files and their `model_type`s. Do NOT edit this file
# manually as any edits will be overwritten by auto-generation of the file. If any change should be done,
# please add the correct `cls.model_type` in your config class and run `python utils/check_auto.py --fix_and_overwrite`.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# Config class names that are allowed to appear as the value of more than one entry in `CONFIG_MAPPING_NAMES`
# (duplicates caused by incorrect naming when the models shipped, kept for backward compatibility).
# `main` excludes these names from its "values must be unique" check.
IGNORE_DUPLICATE_CONFIG = ["GPT2Config", "EvollaConfig", "MLCDVisionConfig"]
def _extract_model_type(class_node: ast.ClassDef):
    """Return the first `model_type` declared with a literal value in a class body, or `None`.

    Handles both `model_type = "x"` and `model_type: str = "x"`. Declarations without a literal value
    (annotation-only `model_type: str`, or a value computed from another name) cannot be resolved
    statically and are skipped instead of raising.
    """
    for stmt in class_node.body:
        if isinstance(stmt, ast.Assign):
            declares_model_type = any(
                isinstance(target, ast.Name) and target.id == "model_type" for target in stmt.targets
            )
        elif isinstance(stmt, ast.AnnAssign):
            declares_model_type = isinstance(stmt.target, ast.Name) and stmt.target.id == "model_type"
        else:
            continue
        # `stmt.value` is `None` for annotation-only declarations and a non-Constant node for expressions.
        if declares_model_type and isinstance(stmt.value, ast.Constant):
            return stmt.value.value
    return None


def build_config_mapping_names() -> tuple[dict, dict]:
    """Build the `model_type -> config class name` mapping by statically parsing all config files.

    Scans `src/transformers/models/**/configuration_*.py` (relative to the current working directory)
    and looks at every top-level class that directly lists `PreTrainedConfig` as a base.

    Returns:
        A `(model_type_map, special_mappings)` tuple of ordered dicts:
        - `model_type_map` maps each `model_type` to the name of its config class.
        - `special_mappings` maps each `model_type` that differs from the name of the folder
          containing its config file to that folder name.
    """
    model_type_map = OrderedDict()
    special_mappings = OrderedDict()
    # Track which model_types were resolved by a "natural" match (model_type == module_name)
    # so a later non-natural match (e.g. MaskFormerDetrConfig with model_type="detr" inside
    # models/maskformer/) does not silently overwrite the canonical class.
    natural_types: set[str] = set()

    # `glob.glob` is filesystem-order dependent — sort to make the output deterministic.
    all_files = sorted(glob.glob("src/transformers/models/**/configuration_*.py", recursive=True))
    for config_path in all_files:
        # Name of the folder containing the file, independent of the OS path separator.
        module_name = os.path.basename(os.path.dirname(config_path))
        # Explicit encoding: the default one is locale-dependent.
        with open(config_path, "r", encoding="utf-8") as f:
            tree = ast.parse(f.read())

        for node in tree.body:
            if not isinstance(node, ast.ClassDef):
                continue
            if not any(isinstance(base, ast.Name) and base.id == "PreTrainedConfig" for base in node.bases):
                continue

            model_type = _extract_model_type(node)
            if not model_type:
                continue

            is_natural = model_type == module_name
            # If we already recorded a natural match for this model_type, don't let a
            # non-natural one overwrite it — the natural class is the canonical owner.
            if model_type in natural_types and not is_natural:
                continue

            model_type_map[model_type] = node.name
            if is_natural:
                natural_types.add(model_type)
                special_mappings.pop(model_type, None)
            else:
                special_mappings[model_type] = module_name

    return model_type_map, special_mappings
def build_image_processor_mapping(
config_mapping: dict[str, str],
) -> OrderedDict[str, dict[str, str | None]]:
processor_mapping = OrderedDict()
for model_type in config_mapping:
module = model_type.replace("-", "_")
fast_processor_name = slow_processor_name = None
if os.path.exists(f"src/transformers/models/{module}/image_processing_pil_{module}.py"):
with open(f"src/transformers/models/{module}/image_processing_pil_{module}.py", "r") as f:
content = f.read()
tree = ast.parse(content)
for node in tree.body:
if isinstance(node, ast.ClassDef) and any(
base.id == "PilBackend" for base in node.bases if isinstance(base, ast.Name)
):
slow_processor_name = node.name
if os.path.exists(f"src/transformers/models/{module}/image_processing_{module}.py"):
with open(f"src/transformers/models/{module}/image_processing_{module}.py", "r") as f:
content = f.read()
tree = ast.parse(content)
for node in tree.body:
if isinstance(node, ast.ClassDef) and any(
base.id == "TorchvisionBackend" for base in node.bases if isinstance(base, ast.Name)
):
fast_processor_name = node.name
if slow_processor_name is not None or fast_processor_name is not None:
processor_mapping[model_type] = {
**({"pil": slow_processor_name} if slow_processor_name else {}),
**({"torchvision": fast_processor_name} if fast_processor_name else {}),
}
return processor_mapping
def build_video_processor_mapping(
    config_mapping: dict[str, str],
) -> OrderedDict[str, str]:
    """Map each model type to the video processor class found in its model folder.

    For every key of `config_mapping` (dashes replaced by underscores to get the folder name), the file
    `src/transformers/models/<module>/video_processing_<module>.py` is parsed statically and searched for a
    top-level class that directly lists `BaseVideoProcessor` as a base. If several classes match, the last
    one wins.

    Args:
        config_mapping: Mapping whose keys are model types; the values are not used here.

    Returns:
        Ordered dict `model_type -> video processor class name`. Model types without a video processor
        are omitted.
    """
    processor_mapping = OrderedDict()
    for model_type in config_mapping:
        module = model_type.replace("-", "_")
        video_file = f"src/transformers/models/{module}/video_processing_{module}.py"
        if not os.path.exists(video_file):
            continue

        # Explicit encoding: the default one is locale-dependent.
        with open(video_file, "r", encoding="utf-8") as f:
            tree = ast.parse(f.read())

        video_processor_name = None
        for node in tree.body:
            if isinstance(node, ast.ClassDef) and any(
                isinstance(base, ast.Name) and base.id == "BaseVideoProcessor" for base in node.bases
            ):
                video_processor_name = node.name

        if video_processor_name is not None:
            processor_mapping[model_type] = video_processor_name
    return processor_mapping
def run_ruff_and_sort(file: str):
    """Sort the mappings in `file` in place, then run the `ruff` linter (with `--fix`) and formatter on it, as in `make style`."""
    sort_auto_mapping(file, overwrite=True)
    # Lint first, then format; ruff's stdout is discarded and its exit status is not inspected.
    ruff_commands = (
        ["ruff", "check", file, "--fix"],
        ["ruff", "format", file],
    )
    for command in ruff_commands:
        subprocess.run(command, stdout=subprocess.DEVNULL)
def format_dict_value(v):
    """Render `v` as a snippet of Python source.

    Strings are wrapped in double quotes (no escaping is applied), dicts and lists are rendered
    recursively with double-quoted dict keys, and any other value falls back to `repr`.
    """
    if isinstance(v, str):
        return f'"{v}"'
    if isinstance(v, dict):
        rendered_pairs = [f'"{key}": {format_dict_value(item)}' for key, item in v.items()]
        return "{" + ", ".join(rendered_pairs) + "}"
    if isinstance(v, list):
        return "[" + ", ".join(map(format_dict_value, v)) + "]"
    return repr(v)
def format_ordered_dict(name: str, data: OrderedDict):
    """Return the source text of `name = OrderedDict([...])` with one `(key, value)` tuple per line.

    The result ends with the closing parenthesis followed by two newlines.
    """
    list_indent = " " * 4
    entry_indent = " " * 8
    entries = [f'{entry_indent}("{key}", {format_dict_value(value)}),' for key, value in data.items()]
    pieces = [f"{name} = OrderedDict(", f"{list_indent}[", *entries, f"{list_indent}]", ")\n\n"]
    return "\n".join(pieces)
def check_duplicates(mapping_for_special_models: dict[str, Any], auto_mapping: dict[str, Any]):
    """Raise `ValueError` if the two mappings share any key; otherwise do nothing.

    Args:
        mapping_for_special_models: Manually maintained mapping.
        auto_mapping: Auto-generated mapping it must not overlap with.
    """
    intersections = set(mapping_for_special_models.keys()) & set(auto_mapping.keys())
    if not intersections:
        return
    raise ValueError(
        "You have manually duplicated a model-type that is present in `auto_mappings.py`. "
        f"Please, delete the entries for {intersections} if they are identical to auto-generated dict, "
        "or use consistent naming across model files so that the names match."
    )
def main(overwrite: bool):
    """Regenerate the auto mappings and compare them with `models/auto/auto_mappings.py`.

    Args:
        overwrite: If `True`, write the regenerated content to the file when it differs. If `False`,
            raise instead, with a unified diff in the error message.

    Raises:
        ValueError: If a manually maintained processor mapping duplicates an auto-generated key, or if
            `CONFIG_MAPPING_NAMES` is not one-to-one (ignoring `IGNORE_DUPLICATE_CONFIG`).
        Exception: If the file on disk is out of date and `overwrite` is `False`.
    """
    filename = "src/transformers/models/auto/auto_mappings.py"

    # 1. Read existing file content if available
    # The generated content contains non-ASCII characters, so all reads/writes use an explicit UTF-8 encoding.
    old_content = ""
    if os.path.exists(filename):
        with open(filename, "r", encoding="utf-8") as f:
            old_content = f.read()

    # 2. Generate new config mapping dicts by parsing all model-config classes
    config_mapping, special_mapping = build_config_mapping_names()
    image_processor_mapping = build_image_processor_mapping(config_mapping=config_mapping)
    video_processor_mapping = build_video_processor_mapping(config_mapping=config_mapping)

    # Make sure users aren't duplicating the same keys manually
    check_duplicates(MISSING_IMAGE_PROCESSOR_MAPPING_NAMES, image_processor_mapping)
    check_duplicates(MISSING_VIDEO_PROCESSOR_MAPPING_NAMES, video_processor_mapping)

    # The config mapping has to be one-to-one for correct `AutoConfig.from_pretrained()` because `LazyMapping`
    # reverts keys/values and creates a dict from it. Duplicate values will be overwritten by whatever comes at last
    duplicate_keys = [n for n, c in Counter(COMPLETE_CONFIG_MAPPING_NAMES.keys()).items() if c > 1]
    if duplicate_keys:
        raise ValueError(
            f"Keys in `CONFIG_MAPPING_NAMES` contain duplicates = {duplicate_keys}. "
            "The mapping has to be one-to-one to ensure correct `AutoConfig` functionality!"
        )

    duplicate_values = [
        n
        for n, c in Counter(COMPLETE_CONFIG_MAPPING_NAMES.values()).items()
        if c > 1 and n not in IGNORE_DUPLICATE_CONFIG
    ]
    if duplicate_values:
        raise ValueError(
            f"Values in `CONFIG_MAPPING_NAMES` contain duplicates = {duplicate_values}. "
            "The mapping has to be one-to-one to ensure correct `AutoConfig` functionality!"
        )

    new_mappings = {
        "CONFIG_MAPPING_NAMES": config_mapping,
        "SPECIAL_MODEL_TYPE_TO_MODULE_NAME": special_mapping,
        "IMAGE_PROCESSOR_MAPPING_NAMES": image_processor_mapping,
        "VIDEO_PROCESSOR_MAPPING_NAMES": video_processor_mapping,
    }
    new_content = AUTO_GENERATED_HADER + "\nfrom collections import OrderedDict\n\n"
    new_content += "".join(format_ordered_dict(name=k, data=v) for k, v in new_mappings.items())

    # 3. If the new auto-generate content is different, overwrite it
    # Dirty hack to sort and apply ruff to the file content, for easier matching
    with tempfile.TemporaryDirectory() as temp_folder:
        temp_filename = os.path.join(temp_folder, "temp.py")
        with open(temp_filename, "w", encoding="utf-8") as temp_file:
            temp_file.write(new_content)
        run_ruff_and_sort(temp_filename)
        # Must be read back before the temporary directory is removed.
        with open(temp_filename, "r", encoding="utf-8") as temp_file:
            new_content = temp_file.read()

    if old_content == new_content:
        return

    if not overwrite:
        diff = "".join(
            difflib.unified_diff(
                old_content.splitlines(keepends=True),
                new_content.splitlines(keepends=True),
                fromfile=f"{filename} (on disk)",
                tofile=f"{filename} (regenerated)",
                n=3,
            )
        )
        raise Exception(
            "Generated auto-mapping is not consistent with the contents of `models/auto/auto_mappings.py`.\n"
            "Run `make fix-repo` or `python utils/check_auto.py --fix_and_overwrite` to fix them.\n\n"
            f"Diff (on disk → regenerated):\n{diff}"
        )

    with open(filename, "w", encoding="utf-8") as f:
        f.write(new_content)
if __name__ == "__main__":
    # `--fix_and_overwrite` switches from "check only" to "rewrite the generated file".
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    main(overwrite=cli.parse_args().fix_and_overwrite)

429
utils/check_bad_commit.py Normal file
View File

@@ -0,0 +1,429 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import os
import re
import subprocess
from collections import defaultdict
import git
import requests
def create_script(target_test):
"""Create a python script to be run by `git bisect run` to determine if `target_test` passes or fails.
If a test is not found in a commit, the script exits with code `0` (i.e. `Success`).
The script is written to `target_script.py` in the current working directory; nothing is returned.
Args:
target_test (`str`): The test to check.
"""
# `target_test` is interpolated right here, while the script text is built. Doubled braces (`{{...}}`) are
# emitted as single braces, so they are evaluated later by the generated script's own f-strings.
script = f"""
import os
import subprocess
_ = subprocess.run(
["python3", "-m", "pip", "install", "-e", "."],
capture_output = True,
text=True,
)
result = subprocess.run(
["python3", "-m", "pytest", "-v", "--flake-finder", "--flake-runs=4", "-rfEp", f"{target_test}"],
capture_output = True,
text=True,
)
print(result.stdout)
if f"FAILED {target_test}" in result.stdout:
print("test failed")
exit(1)
elif result.returncode != 0:
if "ERROR: file or directory not found: " in result.stderr:
print("test file or directory not found in this commit")
# git bisect treats exit code 125 as `test not found`. But this causes it not be able to make the conclusion
# if a test is added between the `good commit` (exclusive) and `bad commit` (inclusive) (in git bisect terminology).
# So we return 0 here in order to allow the process being able to identify the first commit that fails the test.
exit(0)
elif "ERROR: not found: " in result.stderr:
print("test not found in this commit")
exit(0)
else:
print(f"pytest gets unknown error: {{result.stderr}}")
exit(1)
print(f"pytest runs successfully.")
exit(0)
"""
# Overwrites any previous `target_script.py`; surrounding blank lines of the template are stripped.
with open("target_script.py", "w") as fp:
fp.write(script.strip())
def is_bad_commit(target_test, commit):
    """Run `target_test` at `commit` and summarize the outcome.

    The repository in the current directory is checked out at `commit`, the bisect script is generated and
    run, and the commit that was checked out before the call is restored afterwards (even if generating or
    running the script raises).

    Args:
        target_test (`str`): The test to check.
        commit (`str`): The commit to check out before running the test.

    Returns:
        `tuple`: `(failed, n_failed, n_passed, error_message)` where `failed` is `True` when the script exits
        with a non-zero code, `n_failed` / `n_passed` are parsed from the pytest summary line (0 when absent),
        and `error_message` is the first `FAILED ... - ...` line of the output (empty when nothing failed).
    """
    repo = git.Repo(".")  # or specify path to your repo

    # Save the current HEAD reference
    original_head = repo.head.commit

    # Checkout to the commit
    repo.git.checkout(commit)
    try:
        create_script(target_test=target_test)
        result = subprocess.run(
            ["python3", "target_script.py"],
            capture_output=True,
            text=True,
        )
    finally:
        # Restore to original commit, whatever happened above, so the caller never ends up on `commit`.
        repo.git.checkout(original_head)

    n_passed = 0
    o = re.findall(r"====.* (\d+) passed", result.stdout)
    if len(o) > 0:
        n_passed = int(o[0])

    n_failed = 0
    o = re.findall(r"====.* (\d+) failed", result.stdout)
    if len(o) > 0:
        n_failed = int(o[0])

    error_message = ""
    if n_failed > 0:
        match = re.search(r"^(FAILED .+ - .+)$", result.stdout, re.MULTILINE)
        error_message = match.group(1).strip() if match else "Cannot retrieve error message."

    return result.returncode != 0, n_failed, n_passed, error_message
def find_bad_commit(target_test, start_commit, end_commit):
"""Find (backward) the earliest commit between `start_commit` (inclusive) and `end_commit` (exclusive) at which `target_test` fails.
Args:
target_test (`str`): The test to check.
start_commit (`str`): The latest commit (inclusive).
end_commit (`str`): The earliest commit (exclusive).
Returns:
`dict`: A dict containing the info about the earliest commit at which `target_test` fails.
Its keys are `bad_commit`, `status`, `failure_at_workflow_commit`, `failure_at_base_commit` and
`failure_at_bad_commit`. `bad_commit` stays `None` e.g. when the failure is classified as flaky, when the
same failure already exists at `end_commit` (PR CI), or when `git bisect` fails.
"""
# Every early `return result` below leaves the fields it did not set at `None`.
result = {
"bad_commit": None,
"status": None,
"failure_at_workflow_commit": None,
"failure_at_base_commit": None,
"failure_at_bad_commit": None,
}
is_pr_ci = os.environ.get("GITHUB_EVENT_NAME") in ["issue_comment", "pull_request"]
# For PR comment CI, we "assume" all tests at `end_commit` passed, so any failing test during a PR CI run is
# "a new failing test", and we can perform more detailed checks with this script.
# For "a failing test at start_commit", we check the test against `end_commit` (run multiple times):
# - if all passing at end_commit: an actual new failing test at start_commit
# - if all failing at end_commit: get the failure message and compare it against the one from start_commit:
# - same failure message: not a new failing test --> don't report it
# - different failure message: kind of a new failing test --> need to report it
# - if both failing and passing at end_commit: mark it as flaky
# check if `end_commit` fails the test
failed_before, n_failed, n_passed, failure_at_base_commit = is_bad_commit(target_test, end_commit)
# We only need one failure to conclude the test is flaky on the previous run with `end_commit`.
# However, when running on CI, we need at least one failure and one pass to conclude.
is_flaky_at_end_commit = ((not is_pr_ci) and n_failed > 0) or (is_pr_ci and n_failed > 0 and n_passed > 0)
# `n_passed == 0` itself is not enough, as the test may not exist in the codebase at `end_commit`.
is_failing_at_end_commit = failed_before and n_passed == 0
if is_flaky_at_end_commit:
result["status"] = (
f"flaky: test both passed and failed during the check of the current run on the previous commit: {end_commit}"
)
return result
elif (not is_pr_ci) and is_failing_at_end_commit:
result["status"] = (
f"flaky: test passed in the previous run (commit: {end_commit}) but failed (on the same commit) during the check of the current run."
)
return result
# if there is no new commit (e.g. 2 different CI runs on the same commit):
# - failed once on `start_commit` but passed on `end_commit`, which are the same commit --> flaky (or something change externally) --> don't report
if start_commit == end_commit:
result["status"] = (
f"flaky: test fails on the current CI run but passed in the previous run which is running on the same commit {end_commit}."
)
return result
# Now, we are (almost) sure `target_test` is not failing at `end_commit`. (For a PR CI, it may fail at `end_commit`)
# Check if `start_commit` fails the test.
# **IMPORTANT** we only need one pass to conclude the test is flaky on the current run with `start_commit`!
_, n_failed, n_passed, failure_at_workflow_commit = is_bad_commit(target_test, start_commit)
if n_passed > 0:
# failed on CI run, but not reproducible here --> don't report
result["status"] = (
f"flaky: test fails on the current CI run (commit: {start_commit}) but passes during the check."
)
return result
# The test fails on `start_commit`, and
# - if the CI is run on PR: this block checks if the test also failed on `end_commit` by comparing the two
#   failure messages.
# - otherwise: the test passed on `end_commit` --> an actual new failing test, this block is skipped.
# TODO: A helper method to handle this and other possible error messages in a clean and centralized way.
failure_at_workflow_commit_processed = failure_at_workflow_commit
failure_at_base_commit_processed = failure_at_base_commit
# Collapse any CUDA OOM message to a fixed string so two OOM failures compare as the same failure.
if "torch.OutOfMemoryError: CUDA out of memory" in failure_at_workflow_commit_processed:
failure_at_workflow_commit_processed = "torch.OutOfMemoryError: CUDA out of memory"
if "torch.OutOfMemoryError: CUDA out of memory" in failure_at_base_commit_processed:
failure_at_base_commit_processed = "torch.OutOfMemoryError: CUDA out of memory"
different_failures = failure_at_workflow_commit_processed != failure_at_base_commit_processed
# An empty `failure_at_base_commit` means no failure message was captured at `end_commit`.
if is_pr_ci and failure_at_base_commit != "" and different_failures:
result["bad_commit"] = start_commit
result["status"] = (
f"test fails both on the current commit ({start_commit}) and the previous commit ({end_commit}), but with DIFFERENT error message!"
)
result["failure_at_workflow_commit"] = failure_at_workflow_commit
result["failure_at_base_commit"] = failure_at_base_commit
result["failure_at_bad_commit"] = failure_at_workflow_commit
return result
# Fail on both commits but with the same error message ==> don't include
elif is_pr_ci and not different_failures:
result["bad_commit"] = None
result["status"] = (
f"test fails both on the current commit ({start_commit}) and the previous commit ({end_commit}) with the SAME error message!"
)
result["failure_at_workflow_commit"] = failure_at_workflow_commit
result["failure_at_base_commit"] = failure_at_base_commit
result["failure_at_bad_commit"] = failure_at_workflow_commit
return result
# The test fails on `start_commit` but passed on `end_commit`.
# Regenerate `target_script.py` and let `git bisect run` drive it (bad = `start_commit`, good = `end_commit`).
create_script(target_test=target_test)
bash = f"""
git bisect reset
git bisect start --first-parent {start_commit} {end_commit}
git bisect run python3 target_script.py
"""
with open("run_git_bisect.sh", "w") as fp:
fp.write(bash.strip())
bash_result = subprocess.run(
["bash", "run_git_bisect.sh"],
check=False,
capture_output=True,
text=True,
)
print(bash_result.stdout)
# This happens if running the script gives exit code < 0 or other issues
if "error: bisect run failed" in bash_result.stderr:
error_msg = f"Error when running git bisect:\nbash error: {bash_result.stderr}\nbash output:\n{bash_result.stdout}\nset `bad_commit` to `None`."
print(error_msg)
result["status"] = "git bisect failed"
return result
pattern = r"(.+) is the first bad commit"
commits = re.findall(pattern, bash_result.stdout)
bad_commit = None
failure_at_bad_commit = ""
if len(commits) > 0:
bad_commit = commits[0]
# Re-run the test at the identified commit to capture its failure message.
_, _, _, failure_at_bad_commit = is_bad_commit(target_test, bad_commit)
print(f"Between `start_commit` {start_commit} and `end_commit` {end_commit}")
print(f"bad_commit: {bad_commit}\n")
# NOTE(review): the status below is set even when no "first bad commit" line was found in the bisect
# output (`bad_commit` is then `None`) — confirm this is intended.
result["bad_commit"] = bad_commit
result["status"] = "git bisect found the bad commit."
result["failure_at_workflow_commit"] = failure_at_workflow_commit
result["failure_at_base_commit"] = failure_at_base_commit
result["failure_at_bad_commit"] = failure_at_bad_commit
return result
def get_commit_info(commit, pr_number=None, github_token=None):
    """Get information for a commit via `api.github.com`.

    Args:
        commit (`str` or `None`): The commit SHA to look up. If `None`, no request is made.
        pr_number (optional): The PR number associated with the commit. Falls back to the `pr_number`
            environment variable, then to the PRs GitHub associates with the commit.
        github_token (`str`, optional): Token sent as a bearer token with each request.

    Returns:
        `dict`: A dict with the keys `commit`, `pr_number`, `author`, `merged_by` and `parent`. Values that
        could not be determined are `None`.

    Raises:
        `requests.exceptions.Timeout`: If GitHub does not answer within the timeout.
    """
    if commit is None:
        return {"commit": None, "pr_number": None, "author": None, "merged_by": None, "parent": None}

    # `requests` has no default timeout: without one, an unresponsive API would block the job indefinitely.
    request_timeout = 30

    author = None
    merged_author = None

    headers = (
        {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {github_token}"} if github_token else {}
    )

    # Use PR number from environment if not provided
    if pr_number is None:
        pr_number = os.environ.get("pr_number")

    # First, get commit info to check if it's a merge commit
    url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit}"
    commit_info = requests.get(url, headers=headers, timeout=request_timeout).json()

    commit_to_query = commit

    # Check if this is a merge commit created by GitHub
    if commit_info.get("parents") and len(commit_info["parents"]) > 1:
        commit_message = commit_info.get("commit", {}).get("message", "")
        # Parse message like "Merge 1ac46bed... into 5a67f0a7..."
        match = re.match(r"^Merge ([a-f0-9]{40}) into ([a-f0-9]{40})", commit_message)
        if match:
            # Use the first SHA (the PR commit)
            commit_to_query = match.group(1)

    # If no PR number yet, try to discover it from the commit.
    # The API can return an error dict (e.g. rate limit) instead of a list, so guard with isinstance.
    if not pr_number:
        url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit_to_query}/pulls"
        pr_info_for_commit = requests.get(url, headers=headers, timeout=request_timeout).json()
        if isinstance(pr_info_for_commit, list) and len(pr_info_for_commit) > 0:
            pr_number = pr_info_for_commit[0].get("number")

    # If we have a PR number, get author and merged_by info.
    # Use .get() throughout: on rate-limit/403 the API returns an error dict, not the expected PR object.
    if pr_number:
        url = f"https://api.github.com/repos/huggingface/transformers/pulls/{pr_number}"
        pr_for_commit = requests.get(url, headers=headers, timeout=request_timeout).json()
        # `or {}` (rather than a `.get` default) also covers a key that is present but set to `null`.
        author = (pr_for_commit.get("user") or {}).get("login")
        merged_by = pr_for_commit.get("merged_by")
        if merged_by is not None:
            merged_author = merged_by.get("login")

    parents = commit_info.get("parents", [])
    parent = parents[0]["sha"] if parents else None

    if author is None:
        author = (commit_info.get("author") or {}).get("login")

    return {"commit": commit, "pr_number": pr_number, "author": author, "merged_by": merged_author, "parent": parent}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--start_commit", type=str, required=True, help="The latest commit hash to check.")
parser.add_argument("--end_commit", type=str, required=True, help="The earliest commit hash to check.")
parser.add_argument("--test", type=str, help="The test to check.")
parser.add_argument("--file", type=str, help="The report file.")
parser.add_argument("--output_file", type=str, required=True, help="The path of the output file.")
parser.add_argument(
"--github_token",
type=str,
default=None,
help="GitHub token to avoid API rate limits. Falls back to GITHUB_TOKEN env var.",
)
args = parser.parse_args()
if args.github_token is None:
args.github_token = os.environ.get("GITHUB_TOKEN")
run_idx = os.environ.get("run_idx")
n_runners = os.environ.get("n_runners")
print(f"start_commit: {args.start_commit}")
print(f"end_commit: {args.end_commit}")
# Cache commit info to avoid redundant API calls and reduce rate limit pressure.
commit_info_cache = {}
if len({args.test is None, args.file is None}) != 2:
raise ValueError("Exactly one argument `test` or `file` must be specified.")
if args.test is not None:
commit, status = find_bad_commit(
target_test=args.test, start_commit=args.start_commit, end_commit=args.end_commit
)
with open(args.output_file, "w", encoding="UTF-8") as fp:
fp.write(f"{args.test}\n{commit}\n{status}")
elif os.path.isfile(args.file):
with open(args.file, "r", encoding="UTF-8") as fp:
reports = json.load(fp)
model_with_failures = []
for model in reports:
# We change the format of "new_failures.json" in PR #XXXXX, let's handle both formats for a few weeks.
if "failures" in reports[model]:
if "job_link" in reports[model]:
for device, device_failures in reports[model]["failures"].items():
if device in reports[model]["job_link"]:
for failure in device_failures:
failure["job_link"] = reports[model]["job_link"][device]
del reports[model]["job_link"]
reports[model] = reports[model]["failures"]
# TODO: make this script able to deal with both `single-gpu` and `multi-gpu` via a new argument.
reports[model].pop("multi-gpu", None)
failed_tests = reports[model].get("single-gpu", [])
model_with_failures.extend([(model, test) for test in failed_tests])
if run_idx is not None:
run_idx = int(run_idx)
n_runners = int(n_runners)
num_failed_tests_to_run = len(model_with_failures) // n_runners
start_idx = num_failed_tests_to_run * run_idx
end_idx = num_failed_tests_to_run * (run_idx + 1) if run_idx < n_runners - 1 else len(model_with_failures)
model_with_failures_to_check = model_with_failures[start_idx:end_idx]
model_with_failures = model_with_failures_to_check
failed_tests_with_bad_commits = defaultdict(list)
for model, failure in model_with_failures:
test = failure["line"]
bad_commit_info = find_bad_commit(
target_test=test, start_commit=args.start_commit, end_commit=args.end_commit
)
info = {"test": test}
info.update(bad_commit_info)
bad_commit = bad_commit_info["bad_commit"]
if bad_commit in commit_info_cache:
commit_info = commit_info_cache[bad_commit]
else:
commit_info = get_commit_info(bad_commit, github_token=args.github_token)
commit_info_cache[bad_commit] = commit_info
commit_info_copied = copy.deepcopy(commit_info)
commit_info_copied.pop("commit")
commit_info_copied.update({"workflow_commit": args.start_commit, "base_commit": args.end_commit})
info.update(commit_info_copied)
# put failure message toward the end
info = {k: v for k, v in info.items() if not k.startswith(("failure_at_", "job_link"))} | {
k: v for k, v in info.items() if k.startswith(("failure_at_", "job_link"))
}
failed_tests_with_bad_commits[model].append(info)
reports = {model: {"single-gpu": tests} for model, tests in failed_tests_with_bad_commits.items() if tests}
with open(args.output_file, "w", encoding="UTF-8") as fp:
json.dump(reports, fp, ensure_ascii=False, indent=4)

View File

@@ -0,0 +1,399 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import re
from transformers.configuration_utils import PreTrainedConfig
from transformers.utils import direct_transformers_import
CHECKER_CONFIG = {
"name": "config_attributes",
"label": "Config attributes",
# Approximate: iterates CONFIG_MAPPING at runtime and also reads modeling_*.py files
# in each config's directory via os.listdir(). Deprecated models are skipped.
"cache_globs": ["src/transformers/models/**/configuration_*.py", "src/transformers/models/**/modeling_*.py"],
"check_args": [],
"fix_args": None,
}
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_config_attributes.py
PATH_TO_TRANSFORMERS = "src/transformers"
# This is to make sure the transformers module imported is the one in the repo.
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
# Usually a small list of allowed attrs, but can be `True` to allow all attributes of that config class
SPECIAL_CASES_TO_ALLOW = {
"Gemma4UnifiedAudioConfig": ["audio_embed_dim"], # Used as meta data for other attributes/properties
"Gemma4UnifiedVisionConfig": [
"patch_size",
"pooling_kernel_size",
], # Used as meta data for other attributes/properties
"MiniCPMV4_6Config": ["drop_vision_last_layer"],
"OpenAIPrivacyFilterConfig": ["classifier_dropout", "output_router_logits", "router_aux_loss_coef"],
"HYV3Config": ["output_router_logits"],
"NougatConfig": ["decoder", "encoder"],
"PI0Config": ["vlm_projection_dim"],
"EuroBertConfig": ["is_causal"], # not used directly, allows causal-bidirectional switch
"Ernie4_5_VL_MoeConfig": ["args"], # BC Alias
"Ernie4_5_VL_MoeTextConfig": ["args"], # BC Alias
"Ernie4_5_VL_MoeVisionConfig": ["args"], # BC Alias
"ExaoneMoeConfig": ["first_k_dense_replace"], # BC for other frameworks
"AfmoeConfig": ["global_attn_every_n_layers", "rope_scaling"],
"LagunaConfig": ["moe_apply_router_weight_on_input"],
"xLSTMConfig": ["add_out_norm", "chunkwise_kernel", "sequence_kernel", "step_kernel"],
"Lfm2Config": ["full_attn_idxs"],
"DiaConfig": ["delay_pattern"],
"BambaConfig": ["attn_layer_indices"],
"Dots1Config": ["max_window_layers"],
"JambaConfig": ["attn_layer_offset", "attn_layer_period", "expert_layer_offset", "expert_layer_period"],
"JetMoeConfig": ["output_router_logits"],
"Phi3Config": ["embd_pdrop"],
"EncodecConfig": ["overlap"],
"XcodecConfig": ["sample_rate", "audio_channels"],
"RecurrentGemmaConfig": ["block_types", "attention_window_size"],
"MambaConfig": ["expand"],
"FalconMambaConfig": ["expand"],
"FSMTConfig": ["langs", "common_kwargs", "early_stopping", "length_penalty", "max_length", "num_beams"],
"GPTNeoConfig": ["attention_types"],
"BlenderbotConfig": ["encoder_no_repeat_ngram_size"],
"EsmConfig": ["is_folding_model"],
"Mask2FormerConfig": ["ignore_value"],
"OneFormerConfig": ["ignore_value", "norm"],
"T5Config": ["feed_forward_proj"],
"MT5Config": ["feed_forward_proj", "tokenizer_class"],
"UMT5Config": ["feed_forward_proj", "tokenizer_class"],
"LongT5Config": ["feed_forward_proj"],
"Pop2PianoConfig": ["feed_forward_proj"],
"BioGptConfig": ["layer_norm_eps"],
"GLPNConfig": ["layer_norm_eps"],
"SegformerConfig": ["layer_norm_eps"],
"CvtConfig": ["layer_norm_eps"],
"PerceiverConfig": ["layer_norm_eps"],
"InformerConfig": ["num_static_real_features", "num_time_features"],
"TimeSeriesTransformerConfig": ["num_static_real_features", "num_time_features"],
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
"SamVisionConfig": ["mlp_ratio"],
"DeepseekOcr2SamVisionConfig": ["mlp_ratio"],
"Sam3VisionConfig": ["backbone_feature_sizes"],
"SamHQVisionConfig": ["mlp_ratio"],
"ClapAudioConfig": ["num_classes"],
"ClvpDecoderConfig": ["add_cross_attention"],
"SpeechT5HifiGanConfig": ["sampling_rate"],
"UdopConfig": ["feed_forward_proj"],
"ZambaConfig": ["attn_layer_offset", "attn_layer_period"],
"MllamaVisionConfig": ["supported_aspect_ratios"],
"LEDConfig": ["classifier_dropout"],
"GPTNeoXConfig": ["rotary_emb_base"],
"ShieldGemma2Config": ["mm_tokens_per_image", "vision_config"],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
"ModernBertConfig": ["local_attention", "reference_compile"],
"ModernBertDecoderConfig": ["global_attn_every_n_layers", "local_attention", "local_rope_theta"],
"SmolLM3Config": ["no_rope_layer_interval"],
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"],
"HiggsAudioV2Config": ["audio_bos_token", "audio_stream_bos_id", "audio_stream_eos_id"],
"HiggsAudioV2TokenizerConfig": ["downsample_factor"],
"Cohere2MoeConfig": ["rope_scaling", "sliding_window_pattern"],
"CsmConfig": ["tie_codebooks_embeddings"],
"DeepseekV2Config": ["norm_topk_prob"],
"DeepseekV4Config": [
# All BC / config-compat surface that the modeling code never reads but
# checkpoints in the wild expose (so we keep accepting them in `__init__`):
# `attention_bias` — V4 has no bias on any linear; kept for parity with V3 configs.
# `n_shared_experts` — V4 always builds exactly one shared MLP; the count
# isn't read because there's no loop over shared experts.
# `norm_topk_prob` — V3 router knob; V4's `DeepseekV4TopKRouter` always normalises.
# `num_key_value_heads` — V4 is shared-KV MQA (always 1); not read at runtime.
# `num_nextn_predict_layers` — MTP layer count from upstream checkpoints; the
# MTP head isn't instantiated by transformers' V4 implementation.
# `router_jitter_noise` — inherited from Mixtral; V4 routers don't apply jitter.
"attention_bias",
"n_shared_experts",
"norm_topk_prob",
"num_key_value_heads",
"num_nextn_predict_layers",
"router_jitter_noise",
],
"EsmFoldConfig": ["esm_ablate_pairwise", "esm_ablate_sequence", "esm_input_dropout", "esm_type"],
"TrunkConfig": ["cpu_grad_checkpoint", "layer_drop"],
"SeamlessM4TConfig": True,
"SeamlessM4Tv2Config": True,
"ConditionalDetrConfig": True,
"DabDetrConfig": True,
"SwitchTransformersConfig": True,
"MaskFormerDetrConfig": True,
"DetrConfig": True,
"DFineConfig": True,
"Deimv2Config": True, # Mixed encoder variants (hybrid/lite) + DFine inheritance
"GroundingDinoConfig": True,
"MMGroundingDinoConfig": True,
"RTDetrConfig": True,
"RTDetrV2Config": True,
"YolosConfig": True,
"Llama4TextConfig": True,
"DPRConfig": True,
"FuyuConfig": True,
"LayoutXLMConfig": True,
"CLIPSegConfig": True,
"DeformableDetrConfig": True,
"DinatConfig": True,
"DonutSwinConfig": True,
"FastSpeech2ConformerConfig": True,
"LayoutLMv2Config": True,
"MaskFormerSwinConfig": True,
"MptConfig": True,
"MptAttentionConfig": True,
"RagConfig": True,
"SpeechT5Config": True,
"SwinConfig": True,
"Swin2SRConfig": True,
"Swinv2Config": True,
"TableTransformerConfig": True,
"TapasConfig": True,
"UniSpeechConfig": True,
"UniSpeechSatConfig": True,
"WavLMConfig": True,
"WhisperConfig": True,
"JukeboxPriorConfig": True,
"Pix2StructTextConfig": True,
"IdeficsConfig": True,
"IdeficsVisionConfig": True,
"IdeficsPerceiverConfig": True,
"GptOssConfig": True,
"LwDetrConfig": True,
"NemotronHConfig": True,
# RfDetr config attributes only used in loss code
"RfDetrConfig": [
"bbox_cost",
"bbox_loss_coefficient",
"class_cost",
"class_loss_coefficient",
"dice_loss_coefficient",
"eos_coefficient",
"focal_alpha",
"giou_cost",
"giou_loss_coefficient",
"mask_class_loss_coefficient",
"mask_dice_loss_coefficient",
"mask_loss_coefficient",
"mask_point_sample_ratio",
],
# Internally uses Got Ocr2 so no need to use in the modeling code as we remap in auto instead
"PPChart2TableConfig": True,
"PPChart2TableVisionConfig": True,
"GlmgaConfig": ["vision_config"],
"Sapiens2Config": [
"num_first_full_attention_layers", # builder attr consumed in __post_init__ to compute num_key_value_heads_per_layer
"num_key_value_attention_heads", # builder attr consumed in __post_init__ to compute num_key_value_heads_per_layer
"num_last_full_attention_layers", # builder attr consumed in __post_init__ to compute num_key_value_heads_per_layer
"flip_pairs", # used externally for post-processing keypoints, not in forward pass
],
}
# Common and important attributes, even if they do not always appear in the modeling files (can be a regex pattern)
ATTRIBUTES_TO_ALLOW = (
# Attr in base `PreTrainedConfig`
"transformers_version",
"architectures",
"chunk_size_feed_forward",
"dtype",
"id2label",
"label2id",
"problem_type",
"tokenizer_class",
"is_encoder_decoder",
"output_hidden_states",
"return_dict",
# Inits related
"initializer_range",
"init_std",
"initializer_factor",
"tie_word_embeddings",
# Special tokens
"bos_index",
"eos_index",
"pad_index",
"unk_index",
"mask_index",
r".+_token_id",
r".+_token_index",
# Processors
"image_seq_length",
"video_seq_length",
"image_size",
"text_config", # may appear as `get_text_config()`
"use_cache",
"out_features",
"out_indices",
"sampling_rate",
# backbone related arguments passed to load_backbone
"use_pretrained_backbone",
"backbone",
"backbone_config",
"use_timm_backbone",
"backbone_kwargs",
# rope attributes may not appear directly in the modeling but are used
"rope_theta",
"partial_rotary_factor",
"max_position_embeddings",
"pretraining_tp",
"use_sliding_window",
"max_window_layers",
# vision attributes that may be used indirectly via merge_with_config_defaults
"vision_feature_layer",
"vision_feature_select_strategy",
"vision_aspect_ratio",
)
def check_attribute_being_used(config_class, attributes, default_value, source_strings):
    """Tell whether one of the names in `attributes` is read by the modeling code, or is explicitly allowed.

    Args:
        config_class (`type`):
            The configuration class whose `__init__` argument is being checked.
        attributes (`List[str]`):
            The argument (or attribute) name, followed by its variant names if any.
        default_value (`Any`):
            The default value given to the attribute in the `__init__` of `config_class`.
        source_strings (`List[str]`):
            Source code of the modeling files located in the same directory as `config_class`. The file defining
            `config_class` itself should not be part of it.

    Returns:
        `bool`: `True` if any variant is found in the sources or is covered by one of the allow-lists, `False`
        otherwise.
    """

    def _appears_in(name, code):
        # Single-line forms: `config.xxx`, `getattr(config, "xxx", ...)` and `getattr(self.config, "xxx", ...)`
        direct_forms = (
            f"config.{name}",
            f'getattr(config, "{name}"',
            f'getattr(self.config, "{name}"',
        )
        if any(form in code for form in direct_forms):
            return True
        # Text configs may only be reached through `config.get_text_config()`
        if "TextConfig" in config_class.__name__ and f"config.get_text_config().{name}" in code:
            return True
        # `getattr(...)` calls whose arguments are spread over several lines
        multiline_getattr = rf'getattr[ \t\v\n\r\f]*\([ \t\v\n\r\f]*(self\.)?config,[ \t\v\n\r\f]*"{name}"'
        return re.search(multiline_getattr, code) is not None

    # A single usage of any variant in any modeling file is enough
    if any(_appears_in(name, code) for name in attributes for code in source_strings):
        return True

    # Not found in the sources: fall back to the allow-lists
    for name in attributes:
        # `tie_word_embeddings` is always accepted; `is_encoder_decoder` only when it defaults to `True`
        if name == "tie_word_embeddings" or (name == "is_encoder_decoder" and default_value is True):
            return True
        # Exceptions shared by all models (entries are used as regex patterns)
        if any(re.search(pattern, name) for pattern in ATTRIBUTES_TO_ALLOW):
            return True
        # Exceptions tied to this specific config class
        if config_class.__name__ in SPECIAL_CASES_TO_ALLOW:
            allowed = SPECIAL_CASES_TO_ALLOW[config_class.__name__]
            # Either `True` (everything is allowed) or an explicit list of attribute names
            if (isinstance(allowed, bool) and allowed) or name in allowed:
                return True

    return False
def check_config_attributes_being_used(config_class):
    """List the `__init__` arguments of `config_class` that no modeling file in the same directory uses.

    Args:
        config_class (`type`):
            The configuration class for which the arguments in its `__init__` will be checked.

    Returns:
        `List[str]`: The sorted names of the arguments that are neither used nor allowed.
    """
    # Arguments of `__init__` together with their defaults (`self` and `kwargs` are skipped further down)
    init_parameters = inspect.signature(config_class.__init__).parameters

    # `attribute_map` maps an alias to the real attribute name. Index it the other way round so that an argument
    # can also be looked up under its alias: one of the two names being used is enough.
    alias_of = {}
    if len(config_class.attribute_map) > 0:
        alias_of = {target: alias for alias, target in config_class.attribute_map.items()}

    # Read every `modeling_*` file sitting next to the configuration file
    model_dir = os.path.dirname(inspect.getsourcefile(config_class))
    modeling_sources = []
    for file_name in os.listdir(model_dir):
        if not file_name.startswith("modeling_"):
            continue
        path = os.path.join(model_dir, file_name)
        if os.path.isfile(path):
            with open(path, encoding="utf8") as fp:
                modeling_sources.append(fp.read())

    unused_attributes = []
    for param_name, parameter in init_parameters.items():
        if param_name in ("self", "kwargs"):
            continue
        # All the names under which this argument may show up in the modeling code
        variants = [param_name]
        if param_name in alias_of:
            variants.append(alias_of[param_name])
        if not check_attribute_being_used(config_class, variants, parameter.default, modeling_sources):
            unused_attributes.append(param_name)

    unused_attributes.sort()
    return unused_attributes
def check_config_attributes():
    """Check that the `__init__` arguments of all configuration classes are used in the modeling files.

    Raises:
        `ValueError`: If at least one configuration class has arguments that are neither used nor allowed. The
        message lists the offending classes with their unused attributes.
    """
    configs_with_unused_attributes = {}
    for mapped_class in list(CONFIG_MAPPING.values()):
        # Deprecated models are not checked
        if "models.deprecated" in mapped_class.__module__:
            continue

        module = inspect.getmodule(mapped_class)

        def _is_config_defined_here(obj, module=module):
            # Keep only config classes defined in this very module (not the ones it merely imports)
            return inspect.isclass(obj) and issubclass(obj, PreTrainedConfig) and inspect.getmodule(obj) == module

        # Going through the module also covers config classes that are not in `CONFIG_MAPPING`
        # (e.g. `CLIPVisionConfig`, `Blip2VisionConfig`, etc.)
        for _, config_class in inspect.getmembers(module, _is_config_defined_here):
            unused = check_config_attributes_being_used(config_class)
            if len(unused) > 0:
                configs_with_unused_attributes[config_class.__name__] = unused

    if len(configs_with_unused_attributes) > 0:
        header = "The following configuration classes contain unused attributes in the corresponding modeling files:\n"
        details = "".join(f"{name}: {attributes}\n" for name, attributes in configs_with_unused_attributes.items())
        raise ValueError(header + details)
# Script entry point: run the check; the `ValueError` raised on unused attributes is left uncaught.
if __name__ == "__main__":
    check_config_attributes()

View File

@@ -0,0 +1,95 @@
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from transformers.utils import direct_transformers_import
CHECKER_CONFIG = {
"name": "config_docstrings",
"label": "Config docstrings",
# Approximate: iterates CONFIG_MAPPING at runtime via inspect.getsource(), not cache globs.
# Only configs registered in CONFIG_MAPPING are checked; deprecated models are skipped.
"cache_globs": ["src/transformers/models/**/configuration_*.py"],
"check_args": [],
"fix_args": None,
}
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_config_docstrings.py
PATH_TO_TRANSFORMERS = "src/transformers"
# This is to make sure the transformers module imported is the one in the repo.
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
# Regex pattern used to find the checkpoint passed to the `auto_docstring` decorator of `config_class`.
# For example, it captures `google-bert/bert-base-uncased` in `@auto_docstring(checkpoint="google-bert/bert-base-uncased")`
_re_checkpoint = re.compile(r"""(?s)@auto_docstring\(.*?checkpoint\s*=\s*["']([^"']+)["']""")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"DecisionTransformerConfig",
"EncoderDecoderConfig",
"MusicgenConfig",
"RagConfig",
"SpeechEncoderDecoderConfig",
"TimmBackboneConfig",
"TimmWrapperConfig",
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
"GraniteConfig",
"GraniteMoeConfig",
"GraniteMoeHybridConfig",
"Qwen3MoeConfig",
"GraniteSpeechConfig",
}
def get_checkpoint_from_config_class(config_class):
    """Return the first checkpoint found by `_re_checkpoint` in the source of `config_class`, or `None`.

    Args:
        config_class (`type`):
            The configuration class whose source code is searched.

    Returns:
        `Optional[str]`: The captured checkpoint name, or `None` when the pattern does not match.
    """
    match = _re_checkpoint.search(inspect.getsource(config_class))
    if match is None:
        return None
    return match.group(1)
def check_config_docstrings_have_checkpoints():
    """Check that every (non-deprecated) config class registered in `CONFIG_MAPPING` declares a checkpoint.

    Classes listed in `CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK` are exempt.

    Raises:
        `ValueError`: If some configuration classes have no checkpoint. The message lists them in alphabetical
        order.
    """
    configs_without_checkpoint = []

    for config_class in list(CONFIG_MAPPING.values()):
        # Skip deprecated models
        if "models.deprecated" in config_class.__module__:
            continue
        checkpoint = get_checkpoint_from_config_class(config_class)

        name = config_class.__name__
        if checkpoint is None and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
            configs_without_checkpoint.append(name)

    if len(configs_without_checkpoint) > 0:
        message = "\n".join(sorted(configs_without_checkpoint))
        # Adjacent literals are concatenated as-is: each one needs its own trailing space, and the inline code
        # example needs its closing backtick.
        raise ValueError(
            f"The following configurations don't contain any valid checkpoint:\n{message}\n\n"
            "The requirement is to include a link pointing to one of the models of this architecture in the "
            "docstring of the config classes listed above. The link should be passed to an `auto_docstring` "
            "decorator as follows: `@auto_docstring(checkpoint='myorg/mymodel')`."
        )
# Script entry point: run the check; the `ValueError` raised on missing checkpoints is left uncaught.
if __name__ == "__main__":
    check_config_docstrings_have_checkpoints()

1055
utils/check_copies.py Normal file

File diff suppressed because it is too large Load Diff

179
utils/check_doc_toc.py Normal file
View File

@@ -0,0 +1,179 @@
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is responsible for ensuring that all model docs are part of the `_toctree.yml` and cleaning the model
section of the table of content by removing duplicates and sorting the entries in alphabetical order.
Usage (from the root of the repo):
Check that the table of content is properly sorted (used in `make check-repo`):
```bash
python utils/check_doc_toc.py
```
Auto-sort the table of content if it is not properly sorted (used in `make fix-repo`):
```bash
python utils/check_doc_toc.py --fix_and_overwrite
```
"""
import argparse
import os
from collections import defaultdict
import yaml
CHECKER_CONFIG = {
"name": "doc_toc",
"label": "Documentation table of contents",
# Also reads docs/source/en/_toctree.yml; the .md glob catches new/renamed doc files.
"cache_globs": ["docs/**/*.md", "docs/source/en/_toctree.yml"],
"check_args": [],
"fix_args": ["--fix_and_overwrite"],
}
ROOT = os.path.dirname(os.path.dirname(__file__))
TOCTREE_PATH = os.path.join(ROOT, "docs", "source", "en", "_toctree.yml")
DOC_PATH = os.path.join(ROOT, "docs", "source", "en", "model_doc")
def clean_model_doc_toc(model_doc: list[dict]) -> list[dict]:
    """
    Deduplicate and sort one modality section of the model documentation table of content.

    Args:
        model_doc (`List[dict]`):
            The list of dictionaries extracted from the `_toctree.yml` file for this specific modality.

    Returns:
        `List[dict]`: The entries with each `local` kept once, sorted by case-insensitive title.

    Raises:
        `ValueError`: If the same `local` appears several times with different titles.
    """
    occurrences = defaultdict(int)
    for entry in model_doc:
        occurrences[entry["local"]] += 1

    cleaned = []
    for local, count in occurrences.items():
        if count <= 1:
            continue
        # A repeated entry is only collapsed when all its occurrences agree on the title
        titles = list({entry["title"] for entry in model_doc if entry["local"] == local})
        if len(titles) > 1:
            raise ValueError(
                f"{local} is present several times in the documentation table of content at "
                "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
                "others."
            )
        # Keep a single entry for it
        cleaned.append({"local": local, "title": titles[0]})

    # Entries that were not duplicated are kept untouched
    cleaned += [entry for entry in model_doc if occurrences[entry["local"]] == 1]

    # Alphabetical order, ignoring case
    cleaned.sort(key=lambda entry: entry["title"].lower())
    return cleaned
def ensure_all_models_in_toctree(model_doc: list[dict]):
    """
    Verify that the files of the `model_doc` folder and the entries of the `_toctree.yml` match each other.

    Args:
        model_doc (`List[dict]`):
            The sections of the "Models" part of the table of content; each one must have a `sections` list.

    Raises:
        `ValueError`: If a model is documented but not listed in the table of content, or the other way round.
    """
    # File names of the `model_doc` folder without their `.md` suffix; `auto` is not expected in the toctree
    documented = {file_name.removesuffix(".md") for file_name in os.listdir(DOC_PATH)} - {"auto"}
    # `local` values of all toctree entries without their `model_doc/` prefix
    listed = {entry["local"].removeprefix("model_doc/") for section in model_doc for entry in section["sections"]}

    # everything alright
    if documented == listed:
        return

    missing_from_toctree = documented - listed
    missing_from_folder = listed - documented

    problems = []
    if len(missing_from_toctree) > 0:
        problems.append(
            f"{missing_from_toctree} appear(s) inside the folder `model_doc`, but not in the `_toctree.yml`. "
            "Please add it/them in their corresponding section inside the `_toctree.yml`."
        )
    if len(missing_from_folder) > 0:
        problems.append(
            f"{missing_from_folder} appear(s) in the `_toctree.yml`, but not inside the folder `model_doc`. "
            "Please add a corresponding `model.md` in `model_doc`."
        )

    raise ValueError("\n".join(problems))
def check_model_doc(overwrite: bool = False):
    """
    Check that the model part of `_toctree.yml` lists all models and is clean (no duplicates, sorted), and
    optionally rewrite it.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to just check if the TOC is clean or to auto-clean it (when `overwrite=True`).

    Raises:
        `ValueError`: If the table of content is not clean and `overwrite=False` (or if the models listed do not
        match the documented ones).
    """
    with open(TOCTREE_PATH, encoding="utf-8") as f:
        content = yaml.safe_load(f.read())

    def _position_of(sections, title):
        # Index of the first section carrying this title
        position = 0
        while sections[position]["title"] != title:
            position += 1
        return position

    # Walk down to the "API" part, then to its "Models" part
    api_idx = _position_of(content, "API")
    api_doc = content[api_idx]["sections"]
    model_idx = _position_of(api_doc, "Models")
    model_doc = api_doc[model_idx]["sections"]

    # Make sure the toctree contains all models
    ensure_all_models_in_toctree(model_doc)

    # Clean the modalities (the entries that have sub-sections) one by one
    needs_update = False
    for idx, section in enumerate(model_doc):
        if "sections" not in section:
            continue
        current = section["sections"]
        cleaned = clean_model_doc_toc(current)
        if current != cleaned:
            needs_update = True
            if overwrite:
                model_doc[idx]["sections"] = cleaned

    if not needs_update:
        return

    if not overwrite:
        raise ValueError(
            "The model doc part of the table of content is not properly sorted, run `make fix-repo` to fix this."
        )

    # Put the cleaned sections back in place and write the whole table of content
    api_doc[model_idx]["sections"] = model_doc
    content[api_idx]["sections"] = api_doc
    with open(TOCTREE_PATH, "w", encoding="utf-8") as f:
        f.write(yaml.dump(content, allow_unicode=True))
# Script entry point: check-only by default; `--fix_and_overwrite` rewrites the table of content in place.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_model_doc(args.fix_and_overwrite)

2196
utils/check_docstrings.py Normal file

File diff suppressed because it is too large Load Diff

100
utils/check_doctest_list.py Normal file
View File

@@ -0,0 +1,100 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is responsible for cleaning the list of doctests by making sure the entries all exist and are in
alphabetical order.
Usage (from the root of the repo):
Check that the doctest list is properly sorted and all files exist (used in `make check-repo`):
```bash
python utils/check_doctest_list.py
```
Auto-sort the doctest list if it is not properly sorted (used in `make fix-repo`):
```bash
python utils/check_doctest_list.py --fix_and_overwrite
```
"""
import argparse
import os
CHECKER_CONFIG = {
"name": "doctest_list",
"label": "Doctest list",
# Over-approximation: the checker validates that paths in .txt list files exist and are
# sorted. The broad globs ensure cache invalidation when source files are added/removed.
"cache_globs": [
"utils/not_doctested.txt",
"utils/slow_documentation_tests.txt",
"src/transformers/**/*.py",
"docs/**/*.md",
],
"check_args": [],
"fix_args": ["--fix_and_overwrite"],
}
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_doctest_list.py
REPO_PATH = "."
DOCTEST_FILE_PATHS = ["not_doctested.txt", "slow_documentation_tests.txt"]
def clean_doctest_list(doctest_file: str, overwrite: bool = False):
    """
    Check that the entries of a doctest list file exist and are sorted, and optionally rewrite the file.

    Only the part of each line before its first space is considered. Blank lines are ignored when checking and
    are dropped when the file is rewritten.

    Args:
        doctest_file (`str`):
            The path to the doctest file to check or clean.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether or not to fix problems. If `False`, will error when the file is not clean.

    Raises:
        `ValueError`: If an entry points to a path that does not exist, or if the entries are not in alphabetical
        order and `overwrite=False`.
    """
    non_existent_paths = []
    all_paths = []
    has_blank_lines = False
    with open(doctest_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip().split(" ")[0]
            if not line:
                # An empty entry would be joined to `REPO_PATH` into the repo root itself, pass the existence
                # check and then sort first, ending up as a leading empty line in the rewritten file.
                has_blank_lines = True
                continue
            path = os.path.join(REPO_PATH, line)
            if not (os.path.isfile(path) or os.path.isdir(path)):
                non_existent_paths.append(line)
            all_paths.append(line)

    if len(non_existent_paths) > 0:
        non_existent_paths = "\n".join([f"- {f}" for f in non_existent_paths])
        raise ValueError(f"`{doctest_file}` contains non-existent paths:\n{non_existent_paths}")

    sorted_paths = sorted(all_paths)
    is_sorted = all_paths == sorted_paths
    if not is_sorted and not overwrite:
        raise ValueError(
            f"Files in `{doctest_file}` are not in alphabetical order, run `make fix-repo` to fix "
            "this automatically."
        )

    # Only touch the file in fix mode, and only when there is something to clean up
    if overwrite and (not is_sorted or has_blank_lines):
        with open(doctest_file, "w", encoding="utf-8") as f:
            f.write("".join(f"{path}\n" for path in sorted_paths))
# Script entry point: check-only by default; `--fix_and_overwrite` rewrites unsorted list files in place.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    # Each list file is looked up in the `utils` folder under `REPO_PATH`
    for doctest_file in DOCTEST_FILE_PATHS:
        doctest_file = os.path.join(REPO_PATH, "utils", doctest_file)
        clean_doctest_list(doctest_file, args.fix_and_overwrite)

265
utils/check_dummies.py Normal file
View File

@@ -0,0 +1,265 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is responsible for making sure the dummies in utils/dummies_xxx.py are up to date with the main init.
Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
have the necessary extra libs installed. Those objects will then raise helpful error message whenever the user tries
to access one of their methods.
Usage (from the root of the repo):
Check that the dummy files are up to date (used in `make check-repo`):
```bash
python utils/check_dummies.py
```
Update the dummy files if needed (used in `make fix-repo`):
```bash
python utils/check_dummies.py --fix_and_overwrite
```
"""
import argparse
import os
import re
CHECKER_CONFIG = {
"name": "dummies",
"label": "Dummy objects",
# Over-approximation: only reads __init__.py and utils/dummy_*_objects.py, but any
# new public object added anywhere could require a dummy update.
"cache_globs": ["src/transformers/**/*.py"],
"check_args": [],
"fix_args": ["--fix_and_overwrite"],
}
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"
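# Matches is_xxx_available. NOTE: the trailing `()` is not escaped, so it is an empty second capture group (not
# literal parentheses); `findall` thus returns 2-tuples, and `find_backend` reads the backend name from `b[0]`.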
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Matches if not is_xxx_available()
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")
# Template for the dummy objects.
DUMMY_CONSTANT = """
{0} = None
"""
DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
_backends = {1}
def __init__(self, *args, **kwargs):
requires_backends(self, {1})
"""
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
requires_backends({0}, {1})
"""
def find_backend(line: str) -> str | None:
    """
    Extract the backend(s) tested by a code line of the init.

    Args:
        line (`str`): A code line in an init file.

    Returns:
        Optional[`str`]: `None` if the line is not matched by `_re_test_backend`. Otherwise the backend name; when
        several backends are tested on the same line (`if not (is_xxx_available() and is_yyy_available())`), their
        names sorted and joined on `_and_` (so `xxx_and_yyy` for instance).
    """
    if _re_test_backend.search(line) is None:
        return None
    # `findall` yields one tuple per match; its first item is the backend name
    names = sorted(match[0] for match in _re_backend.findall(line))
    return "_and_".join(names)
def read_init() -> dict[str, list[str]]:
    """
    Read the init and extract backend-specific objects.

    Returns:
        Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend.
    """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # One indentation level of the init. The backend objects are expected at two levels (8 spaces) and the names of
    # multi-line imports at three (12 spaces), so the `else:` paired with a backend test sits at one level.
    indent = " " * 4

    # Get to the point we do the actual imports for type checking
    line_index = 0
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        if backend is None:
            line_index += 1
            continue

        # Skip ahead to the `else:` branch that follows the availability test
        while not lines[line_index].startswith(indent + "else:"):
            line_index += 1
        line_index += 1

        objects = []
        # Until we unindent (or reach the end of the file), add backend objects to the list
        while line_index < len(lines) and (
            len(lines[line_index]) <= 1 or lines[line_index].startswith(indent * 2)
        ):
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                # Single-line imports
                objects.extend(single_line_import_search.groups()[0].split(", "))
            elif line.startswith(indent * 3):
                # Multiple-line imports (with 3 indent level): drop the indent and the trailing `,\n`
                objects.append(line[12:-2])
            line_index += 1

        backend_specific_objects[backend] = objects

    return backend_specific_objects
def create_dummy_object(name: str, backend_name: str) -> str:
    """
    Build the code of the dummy standing in for one backend-specific object.

    The template is picked from the casing of `name`: all upper case gives a constant, all lower case a function,
    anything else a class.

    Args:
        name (`str`): The name of the object.
        backend_name (`str`): The name of the backend required for that object.

    Returns:
        `str`: The code of the dummy object.
    """
    # Constants do not need the backend: the template only takes the name
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    template = DUMMY_FUNCTION if name.islower() else DUMMY_CLASS
    return template.format(name, backend_name)
def create_dummy_files(backend_specific_objects: dict[str, list[str]] | None = None) -> dict[str, str]:
    """
    Build the expected content of every dummy file.

    Args:
        backend_specific_objects (`Dict[str, List[str]]`, *optional*):
            The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling
            `read_init()`.

    Returns:
        `Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file.
    """
    if backend_specific_objects is None:
        backend_specific_objects = read_init()

    # Every generated file starts with the same two lines.
    header = (
        "# This file is autogenerated by the command `make fix-repo`, do not edit.\n"
        "from ..utils import DummyObject, requires_backends\n\n"
    )

    generated = {}
    for backend, objects in backend_specific_objects.items():
        # "a_and_b" becomes the list literal '["a", "b"]' that is pasted into the templates.
        quoted_parts = [f'"{part}"' for part in backend.split("_and_")]
        backend_name = "[" + ", ".join(quoted_parts) + "]"
        body = "\n".join(create_dummy_object(obj, backend_name) for obj in objects)
        generated[backend] = header + body
    return generated
def check_dummies(overwrite: bool = False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    Args:
        overwrite (`bool`, *optional*, default to `False`):
            Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
            when `overwrite=False`.

    Raises:
        ValueError: If a dummy file differs from the generated content and `overwrite=False`.
    """
    from itertools import zip_longest

    dummy_files = create_dummy_files()
    # For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}

    # Locate actual dummy modules and read their content.
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    dummy_file_paths = {
        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") for backend in dummy_files
    }

    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                actual_dummies[backend] = f.read()
        else:
            # A missing file is compared as if it were empty.
            actual_dummies[backend] = ""

    # Compare actual with what they should be.
    for backend in dummy_files:
        if dummy_files[backend] == actual_dummies[backend]:
            continue

        if overwrite:
            print(
                f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
                "__init__ has new objects."
            )
            with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                f.write(dummy_files[backend])
            continue

        # Temporary fix to help people identify which objects introduced are not correctly protected.
        # Compare the files of the backend that actually differs (not always "torch"), chunk by chunk.
        # `zip_longest` keeps going when one side has extra chunks; the `None` filler can never be equal to a
        # real chunk, so a differing pair is always found when the two contents differ.
        actual_broken = dummy_broken = ""
        for _actual, _dummy in zip_longest(
            actual_dummies[backend].split("class"), dummy_files[backend].split("class"), fillvalue=None
        ):
            if _actual != _dummy:
                actual_broken = "" if _actual is None else _actual
                dummy_broken = "" if _dummy is None else _dummy
                break

        raise ValueError(
            "The main __init__ has objects that are not present in "
            f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
            f" It is likely the following objects are responsible, see these excerpts: \n"
            f"---------------------------------- Actual -------------------------------------\n"
            f" \n {actual_broken} \n"
            f"---------------------------------- Dummy -------------------------------------\n"
            f" \n {dummy_broken} \n"
            "Run `make fix-repo` to fix this."
        )
if __name__ == "__main__":
    # `--fix_and_overwrite` switches from "check and fail" to "rewrite the dummy files".
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    check_dummies(cli.parse_args().fix_and_overwrite)

View File

@@ -0,0 +1,253 @@
#!/usr/bin/env python3
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check that `import transformers` does not pull in too many modules.
Traces the full import tree triggered by `import transformers` using a custom
``importlib.abc.MetaPathFinder`` and counts every module that gets loaded.
If the count exceeds ``MAX_IMPORT_COUNT`` the check fails, signalling a
potential regression in import speed.
Usage:
python utils/check_import_complexity.py # CI check mode
python utils/check_import_complexity.py --display # show the full import tree
"""
from __future__ import annotations
import argparse
import importlib
import importlib.abc
import sys
import threading
from dataclasses import dataclass, field
from types import ModuleType
from typing import Any
# Upper bound on the number of modules `import transformers` may load; used as the default of `--max-count`.
MAX_IMPORT_COUNT = 1000
# ---------------------------------------------------------------------------
# Import-tree data structures
# ---------------------------------------------------------------------------
@dataclass
class ImportNode:
    """One module in the traced import tree."""

    # Fully qualified name of the module.
    name: str
    # Nodes of the modules attached under this one by the tracer.
    children: list[ImportNode] = field(default_factory=list)
class LoaderProxy(importlib.abc.Loader):
    """Loader wrapper that tells the tracer which module is being executed.

    The module name is pushed on the tracer stack before the real loader runs and popped afterwards, even when
    the real loader raises. Any attribute not defined here is looked up on the wrapped loader.
    """

    def __init__(self, wrapped: Any, tracer: ImportTreeTracer, fullname: str):
        self._wrapped = wrapped
        self._tracer = tracer
        self._fullname = fullname

    def create_module(self, spec):
        # Returning None asks the import machinery to create the module itself.
        if not hasattr(self._wrapped, "create_module"):
            return None
        return self._wrapped.create_module(spec)

    def exec_module(self, module: ModuleType) -> None:
        self._tracer.push(self._fullname)
        try:
            real_loader = self._wrapped
            if hasattr(real_loader, "exec_module"):
                real_loader.exec_module(module)
                return
            if hasattr(real_loader, "load_module"):
                # Legacy loaders only expose `load_module`, which takes the module name.
                real_loader.load_module(self._fullname)
                return
            raise ImportError(f"Loader for {self._fullname!r} has neither exec_module nor load_module")
        finally:
            self._tracer.pop()

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes missing on the proxy itself.
        return getattr(self._wrapped, name)
class ImportTreeFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that records each newly resolved module in the tracer.

    Resolution itself is delegated to the finders captured at construction time; this class only records the
    module name and wraps the resolved loader so the execution of the module can be tracked.
    """

    def __init__(self, tracer: ImportTreeTracer, original_meta_path: list[Any]):
        self._tracer = tracer
        # Private copy so later changes to the caller's list do not affect delegation.
        self._original = list(original_meta_path)

    def find_spec(self, fullname: str, path=None, target=None):
        # Modules already recorded are left to the rest of the meta path.
        if self._tracer.is_seen(fullname):
            return None

        for candidate in self._original:
            try:
                if not hasattr(candidate, "find_spec"):
                    continue
                spec = candidate.find_spec(fullname, path, target)
            except Exception:
                # Best effort: a finder that fails is skipped, the next one gets a chance.
                continue
            if spec is None:
                continue

            self._tracer.record(fullname)
            if spec.loader is not None:
                spec.loader = LoaderProxy(spec.loader, self._tracer, fullname)
            return spec

        return None
class ImportTreeTracer:
    """Collect the modules seen during an import and arrange them in a parent/child tree.

    The "currently executing module" stack is kept per thread; the set of seen names and the nodes are shared.
    """

    def __init__(self) -> None:
        self._local = threading.local()
        self._nodes: dict[str, ImportNode] = {}
        self._roots: list[ImportNode] = []
        self._seen: set[str] = set()

    def _stack(self) -> list[str]:
        # The stack is created lazily the first time a thread asks for it.
        try:
            return self._local.stack
        except AttributeError:
            self._local.stack = []
            return self._local.stack

    def is_seen(self, fullname: str) -> bool:
        return fullname in self._seen

    def _get_or_create(self, fullname: str) -> ImportNode:
        node = self._nodes.get(fullname)
        if node is None:
            node = self._nodes[fullname] = ImportNode(name=fullname)
        return node

    def record(self, fullname: str) -> None:
        """Register `fullname` once, under the module on top of the stack (or as a root if the stack is empty)."""
        if fullname in self._seen:
            return
        self._seen.add(fullname)

        node = self._get_or_create(fullname)
        stack = self._stack()
        siblings = self._get_or_create(stack[-1]).children if stack else self._roots
        if not any(existing.name == fullname for existing in siblings):
            siblings.append(node)

    def push(self, fullname: str) -> None:
        self._stack().append(fullname)

    def pop(self) -> None:
        # Popping an empty stack is a no-op.
        stack = self._stack()
        if stack:
            stack.pop()

    @property
    def count(self) -> int:
        return len(self._seen)

    @property
    def roots(self) -> list[ImportNode]:
        return self._roots
# ---------------------------------------------------------------------------
# Tracing entry-point
# ---------------------------------------------------------------------------
def trace_import(target: str) -> ImportTreeTracer:
    """Import `target` with a tracing finder installed first on `sys.meta_path` and return the tracer.

    The finder is uninstalled afterwards, whether or not the import succeeded. An exception raised by the import
    propagates to the caller.
    """
    tracer = ImportTreeTracer()
    # The finder delegates to a snapshot of the meta path taken before it installs itself.
    hook = ImportTreeFinder(tracer, list(sys.meta_path))
    sys.meta_path.insert(0, hook)
    try:
        importlib.import_module(target)
    finally:
        try:
            sys.meta_path.remove(hook)
        except ValueError:
            # The hook is no longer in the list; nothing to undo.
            pass
    return tracer
# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------
def format_tree(nodes: list[ImportNode]) -> str:
    """Render the import tree as text, one module per line, using box-drawing connectors.

    Args:
        nodes: The root nodes of the tree. Each node exposes `name` and `children`.

    Returns:
        The rendered tree, lines joined with newlines (an empty string when `nodes` is empty).
    """
    lines: list[str] = []

    def _walk(node: ImportNode, prefix: str, is_last: bool) -> None:
        connector = "└── " if is_last else "├── "
        lines.append(f"{prefix}{connector}{node.name}")
        # The continuation must be as wide as the connector (4 characters), otherwise children are not
        # indented relative to their parent. A vertical bar keeps the branch visually open while the parent
        # still has siblings below it.
        child_prefix = prefix + ("    " if is_last else "│   ")
        last_index = len(node.children) - 1
        for i, child in enumerate(node.children):
            _walk(child, child_prefix, i == last_index)

    last_root = len(nodes) - 1
    for i, root in enumerate(nodes):
        _walk(root, "", i == last_root)
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """Trace `import transformers` and return the process exit code.

    Returns 0 when the number of imported modules is within the limit (or when `--display` is used), and 1 when
    the import fails or the limit is exceeded.
    """
    cli = argparse.ArgumentParser(description="Check import complexity for `import transformers`.")
    cli.add_argument(
        "--display",
        action="store_true",
        help="Display the full import tree (for debugging regressions).",
    )
    cli.add_argument(
        "--max-count",
        type=int,
        default=MAX_IMPORT_COUNT,
        help=f"Maximum allowed number of imported modules (default: {MAX_IMPORT_COUNT}).",
    )
    options = cli.parse_args()

    try:
        tracer = trace_import("transformers")
    except Exception as exc:
        print(f"ERROR: `import transformers` failed: {exc}", file=sys.stderr)
        return 1

    # Display mode never fails: it only prints the tree and the total.
    if options.display:
        print(format_tree(tracer.roots))
        print()
        print(f"Total modules imported: {tracer.count}")
        return 0

    if tracer.count <= options.max_count:
        print(f"Import complexity OK: {tracer.count} modules (max {options.max_count})")
        return 0

    print(
        f"Import complexity regression: `import transformers` triggered {tracer.count} module imports "
        f"(maximum allowed: {options.max_count}).\n"
        f"\n"
        f"Run the following command to display the full import tree and identify the cause:\n"
        f"\n"
        f" python utils/check_import_complexity.py --display\n"
    )
    return 1
if __name__ == "__main__":
    # The return value of `main` becomes the process exit status.
    sys.exit(main())

359
utils/check_inits.py Normal file
View File

@@ -0,0 +1,359 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this
script is to check the objects defined in both halves are the same.
This also checks the main init properly references all submodules, even if it doesn't import anything from them: every
submodule should be defined as a key of `_import_structure`, with an empty list as value potentially, or the submodule
won't be importable.
Use from the root of the repo with:
```bash
python utils/check_inits.py
```
for a check that will error in case of inconsistencies (used by `make check-repo`).
There is no auto-fix possible here sadly :-(
"""
import collections
import os
import re
from pathlib import Path
# Metadata describing this checker (name, label, file globs and command-line arguments).
CHECKER_CONFIG = {
    "name": "inits",
    "label": "Init files",
    "cache_globs": ["src/transformers/**/__init__.py"],
    "check_args": [],
    "fix_args": None,
}
# Path is set with the intent you should run this script from the root of the repo.
PATH_TO_TRANSFORMERS = "src/transformers"
# Matches is_xxx_available()
# NOTE(review): the trailing `()` is an empty capture group, not literal parentheses, so `findall` returns
# 2-tuples; `find_backend` relies on that by taking the first item of each match.
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Catches a one-line _import_struct = {xxx}
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if not is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
_re_quote_object = re.compile(r'^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Catches a line with try:
_re_try = re.compile(r"^\s*try:")
# Catches a line with else:
_re_else = re.compile(r"^\s*else:")
def find_backend(line: str) -> str | None:
    """
    Extract the backend (or backends) tested on a line of an init.

    Args:
        line (`str`): A code line of the main init.

    Returns:
        Optional[`str`]: `None` when the line is not an `if not is_xxx_available()` test. Otherwise the backend
        name, or, when several `is_xxx_available()` calls appear on the line, all names sorted and joined on
        `_and_` (so `xxx_and_yyy` for instance).
    """
    if _re_test_backend.search(line) is None:
        return None
    # Each match is a tuple; the backend name is its first item.
    names = sorted(match[0] for match in _re_backend.findall(line))
    return "_and_".join(names)
def parse_init(init_file) -> tuple[dict[str, list[str]], dict[str, list[str]]] | None:
    """
    Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
    defined.

    The file is scanned once, top to bottom, with a single line cursor: first the backend-independent part of
    `_import_structure`, then its backend-specific parts, then the same two phases for the `TYPE_CHECKING` half.

    Args:
        init_file (`str`): Path to the init file to inspect.

    Returns:
        `Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends to list of
        imported objects, one for the `_import_structure` part of the init and one for the `TYPE_CHECKING` part of the
        init. Objects that do not depend on a backend are stored under the key `"none"`. Returns `None` if the init
        is not a custom init.
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Get to the `_import_structure` definition.
    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # If this is a traditional init, just return.
    if line_index >= len(lines):
        return None

    # First grab the objects without a specific backend in _import_structure
    # NOTE(review): this loop and the next one have no end-of-file guard; they rely on the init containing an
    # `if TYPE_CHECKING` line after `_import_structure`.
    objects = []
    while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
        line = lines[line_index]
        # If we have everything on a single line, let's deal with it.
        if _re_one_line_import_struct.search(line):
            content = _re_one_line_import_struct.search(line).groups()[0]
            imports = re.findall(r"\[([^\]]+)\]", content)
            for imp in imports:
                # `obj[1:-1]` drops the first and last character of each item (the surrounding quotes).
                objects.extend([obj[1:-1] for obj in imp.split(", ")])
            line_index += 1
            continue
        single_line_import_search = _re_import_struct_key_value.search(line)
        if single_line_import_search is not None:
            imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        elif line.startswith(" " * 8 + '"'):
            # Drops the 8-space indent plus opening quote, and the last three characters (presumably `",` and
            # the newline).
            objects.append(line[9:-3])
        line_index += 1

    # Those are stored with the key "none".
    import_dict_objects = {"none": objects}
    # Let's continue with backend-specific objects in _import_structure
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        # If the line is an if not is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        # Check if the backend declaration is inside a try block:
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll until we hit the else block of try-except-else
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
                line = lines[line_index]
                # The patterns are tried from most to least specific; the first one that matches wins.
                if _re_import_struct_add_one.search(line) is not None:
                    objects.append(_re_import_struct_add_one.search(line).groups()[0])
                elif _re_import_struct_add_many.search(line) is not None:
                    imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_between_brackets.search(line) is not None:
                    imports = _re_between_brackets.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_quote_object.search(line) is not None:
                    objects.append(_re_quote_object.search(line).groups()[0])
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
    objects = []
    while (
        line_index < len(lines)
        and find_backend(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        single_line_import_search = _re_import.search(line)
        if single_line_import_search is not None:
            objects.extend(single_line_import_search.groups()[0].split(", "))
        elif line.startswith(" " * 8):
            # Drops the 8-space indent and the last two characters (presumably the trailing comma and newline).
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}

    # Let's continue with backend-specific objects
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        # Check if the backend declaration is inside a try block:
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll until we hit the else block of try-except-else
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects
def analyze_results(import_dict_objects: dict[str, list[str]], type_hint_objects: dict[str, list[str]]) -> list[str]:
    """
    Compare the objects found in the two halves of an init and describe every mismatch.

    Args:
        import_dict_objects (`Dict[str, List[str]]`):
            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
            list of imported objects.
        type_hint_objects (`Dict[str, List[str]]`):
            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
            list of imported objects.

    Returns:
        `List[str]`: The list of errors corresponding to mismatches.
    """
    # If one backend is missing from the other part of the init, error early. The comparison is on the key
    # sequences, so the backends must also appear in the same order on both sides.
    if list(import_dict_objects) != list(type_hint_objects):
        return ["Both sides of the init do not have the same backends!"]

    def repeated(names):
        return [name for name, occurrences in collections.Counter(names).items() if occurrences > 1]

    errors = []
    for key, structure_names in import_dict_objects.items():
        hinted_names = type_hint_objects[key]

        # Duplicate imports in any half.
        structure_duplicates = repeated(structure_names)
        if structure_duplicates:
            errors.append(f"Duplicate _import_structure definitions for: {structure_duplicates}")
        hinted_duplicates = repeated(hinted_names)
        if hinted_duplicates:
            errors.append(f"Duplicate TYPE_CHECKING objects for: {hinted_duplicates}")

        # Missing imports in either part of the init.
        if set(structure_names) == set(hinted_names):
            continue
        name = "base imports" if key == "none" else f"{key} backend"
        errors.append(f"Differences for {name}:")
        errors.extend(
            f" {a} in TYPE_HINT but not in _import_structure." for a in hinted_names if a not in structure_names
        )
        errors.extend(
            f" {a} in _import_structure but not in TYPE_HINT." for a in structure_names if a not in hinted_names
        )
    return errors
def get_transformers_submodules() -> list[str]:
"""
Returns the list of Transformers submodules.
"""
submodules = []
for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
for folder in directories:
# Ignore private modules
if folder.startswith("_"):
directories.remove(folder)
continue
# Ignore leftovers from branches (empty folders apart from pycache)
if len(list((Path(path) / folder).glob("*.py"))) == 0:
continue
short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
submodule = short_path.replace(os.path.sep, ".")
submodules.append(submodule)
for fname in files:
if fname == "__init__.py":
continue
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
if len(submodule.split(".")) == 1:
submodules.append(submodule)
return submodules
# Submodules that `check_submodules` does not require to be registered in the main init.
IGNORE_SUBMODULES = [
    "convert_pytorch_checkpoint_to_tf2",
    "models.esm.openfold_utils",
    "safetensors_conversion",
    "modeling_gguf_pytorch_utils",
    "kernels.falcon_mamba",
    "kernels",
]
def check_submodules():
    """
    Check all submodules of Transformers are properly registered in the main init. Error otherwise.

    Raises:
        ValueError: If a submodule is neither a key of `_import_structure` nor listed in `IGNORE_SUBMODULES`.
    """
    # This is to make sure the transformers module imported is the one in the repo.
    from transformers.utils import direct_transformers_import

    transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)

    import_structure_keys = set(transformers._import_structure.keys())
    # This contains all the base keys of the _import_structure object defined in the init, but if the user is missing
    # some optional dependencies, they may not have all of them. Thus we read the init to read all additions and
    # (potentially re-) add them.
    # The encoding is explicit so the result does not depend on the locale of the machine.
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
        init_content = f.read()
    import_structure_keys.update(re.findall(r"import_structure\[\"([^\"]*)\"\]", init_content))

    module_not_registered = [
        module
        for module in get_transformers_submodules()
        if module not in IGNORE_SUBMODULES and module not in import_structure_keys
    ]

    if len(module_not_registered) > 0:
        list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
        raise ValueError(
            "The following submodules are not properly registered in the main init of Transformers:\n"
            f"{list_of_modules}\n"
            "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
        )
if __name__ == "__main__":
    # This entire file needs an overhaul; running it as a script currently does nothing.
    pass

View File

@@ -0,0 +1,58 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from get_test_info import get_tester_classes
if __name__ == "__main__":
    # Largest value allowed for each size-related config attribute produced by a model tester.
    max_allowed = {
        "vocab_size": 100,
        "max_position_embeddings": 128,
        "hidden_size": 40,
        "d_model": 40,
        "num_layers": 5,
        "num_hidden_layers": 5,
        "num_encoder_layers": 5,
        "num_decoder_layers": 5,
    }

    failures = []

    pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
    test_files = glob.glob(pattern)

    for test_file in test_files:
        tester_classes = get_tester_classes(test_file)
        for tester_class in tester_classes:
            # A few tester classes don't have `parent` parameter in `__init__`.
            # TODO: deal with this better
            try:
                tester = tester_class(parent=None)
            except Exception:
                continue
            if not hasattr(tester, "get_config"):
                continue

            config = tester.get_config()
            for k, v in config.to_dict().items():
                if not isinstance(v, int):
                    continue
                # A dictionary lookup replaces the chain of comparisons; the layer-count branch of that chain
                # compared the key to a whole list with `==` and therefore never applied.
                target = max_allowed.get(k)
                if target is not None and v > target:
                    failures.append(
                        f"{tester_class.__name__} will produce a `config` of type `{config.__class__.__name__}`"
                        f' with config["{k}"] = {v} which is too large for testing! Set its value to be smaller'
                        f" than {target}."
                    )

    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))

View File

@@ -0,0 +1,124 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Keep `## Rules reference` section of docs/source/en/modeling_rules.md in sync
with the rules defined in utils/rules.toml via the installed mlinter package.
Usage (from the root of the repo):
Check everything is up to date (used in ``make check-repo``):
```bash
python utils/check_modeling_rules_doc.py
```
Auto-regenerate if out of date (used in ``make fix-repo``):
```bash
python utils/check_modeling_rules_doc.py --fix_and_overwrite
```
"""
import argparse
from pathlib import Path
# Metadata describing this checker (name, label, cache globs and the arguments for check / fix modes).
CHECKER_CONFIG = {
    "name": "modeling_rules_doc",
    "label": "Modeling rules documentation",
    # Depends on utils/rules.toml plus the installed `mlinter` package output,
    # which cannot be fully expressed as repo cache globs for the checker cache.
    "cache_globs": None,
    "check_args": ["--rules-toml", "utils/rules.toml"],
    "fix_args": ["--rules-toml", "utils/rules.toml", "--fix_and_overwrite"],
}
# Two levels above this file.
ROOT = Path(__file__).resolve().parent.parent
DOC_PATH = ROOT / "docs" / "source" / "en" / "modeling_rules.md"
RULES_TOML_PATH = ROOT / "utils" / "rules.toml"
# Markers delimiting the auto-generated section inside the documentation file.
BEGIN_MARKER = "<!-- BEGIN RULES REFERENCE -->"
END_MARKER = "<!-- END RULES REFERENCE -->"
def _require_mlinter():
    """Import `mlinter` lazily and return `(mlinter, mlinter.mlinter)`.

    Raises:
        ModuleNotFoundError: With installation instructions, when the import fails because a module is missing.
    """
    try:
        import mlinter
        from mlinter import mlinter as mlinter_impl
    except ModuleNotFoundError as missing:
        hint = (
            "This script requires the standalone `transformers-mlinter` package. "
            'Install the repo quality dependencies with `pip install -e ".[quality]"` and retry.'
        )
        raise ModuleNotFoundError(hint) from missing
    return mlinter, mlinter_impl
def _resolve_path(path: Path) -> Path:
    """Return `path` unchanged when it is absolute, otherwise anchor it at `ROOT`."""
    if path.is_absolute():
        return path
    return ROOT / path
def generate_rules_reference(rule_specs_path: Path = RULES_TOML_PATH) -> str:
    """Render the rules reference text with `mlinter`, using the rules file at `rule_specs_path`.

    A relative `rule_specs_path` is resolved by `_resolve_path`.
    """
    mlinter, mlinter_impl = _require_mlinter()
    specs_file = _resolve_path(rule_specs_path)
    # Reuse mlinter's registry-switching helper so docs rendering reflects the repo-local rule file.
    with mlinter_impl._using_rule_specs(specs_file):
        rendered = mlinter.render_rules_reference()
    return rendered
def check_modeling_rules_doc(overwrite: bool = False, rule_specs_path: Path = RULES_TOML_PATH):
    """
    Check that the rules reference section of the documentation matches the generated one.

    The section is the text between `BEGIN_MARKER` and the first `END_MARKER` that follows it in `DOC_PATH`.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to rewrite the section when it is out of date instead of raising.
        rule_specs_path (`Path`, *optional*, defaults to `RULES_TOML_PATH`):
            The rules file passed to `generate_rules_reference`.

    Raises:
        ValueError: If the markers cannot be found, or if the section is out of date and `overwrite=False`.
    """
    with DOC_PATH.open(encoding="utf-8") as f:
        content = f.read()

    begin_idx = content.find(BEGIN_MARKER)
    after_begin = begin_idx + len(BEGIN_MARKER)
    # Look for the end marker only after the begin marker: an end marker located earlier in the file would
    # otherwise produce an inverted slice and, when overwriting, duplicate part of the document.
    end_idx = content.find(END_MARKER, after_begin) if begin_idx != -1 else -1
    if begin_idx == -1 or end_idx == -1:
        raise ValueError(
            f"Could not find {BEGIN_MARKER} and {END_MARKER} markers in {DOC_PATH}. "
            "These markers delimit the auto-generated rules reference section."
        )

    expected = "\n\n" + generate_rules_reference(rule_specs_path) + "\n"
    current = content[after_begin:end_idx]
    if current == expected:
        return

    if overwrite:
        new_content = content[:after_begin] + expected + content[end_idx:]
        with DOC_PATH.open("w", encoding="utf-8") as f:
            f.write(new_content)
        print(f"Updated rules reference in {DOC_PATH}")
    else:
        raise ValueError(
            "The rules reference section in docs/source/en/modeling_rules.md is out of sync "
            "with utils/rules.toml. Run `make fix-repo` to regenerate it."
        )
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--rules-toml",
        type=Path,
        default=RULES_TOML_PATH,
        help="Path to a rules TOML file. Defaults to utils/rules.toml.",
    )
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    options = cli.parse_args()

    try:
        check_modeling_rules_doc(options.fix_and_overwrite, options.rules_toml)
    except ModuleNotFoundError as missing:
        # Exit with the message only, instead of a traceback.
        raise SystemExit(str(missing)) from missing

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Thin local entrypoint for the external mlinter package."""
import sys
from pathlib import Path
# Metadata describing this checker (name, label, cache globs and the arguments for check mode; no fix mode).
CHECKER_CONFIG = {
    "name": "modeling_structure",
    "label": "Modeling file structure",
    "cache_globs": [
        "src/transformers/models/**/modeling_*.py",
        "src/transformers/models/**/modular_*.py",
        "src/transformers/models/**/configuration_*.py",
    ],
    "check_args": ["--rules-toml", "utils/rules.toml"],
    "fix_args": None,
}
# `rules.toml` located in the same folder as this script.
RULES_TOML_PATH = Path(__file__).resolve().with_name("rules.toml")
def _require_mlinter():
    """Import `mlinter` lazily and return the module.

    Raises:
        ModuleNotFoundError: With installation instructions, when the import fails because a module is missing.
    """
    try:
        import mlinter
    except ModuleNotFoundError as missing:
        hint = (
            "This script requires the standalone `transformers-mlinter` package. "
            'Install the repo quality dependencies with `pip install -e ".[quality]"` and retry.'
        )
        raise ModuleNotFoundError(hint) from missing
    return mlinter
def _add_default_rules_toml(argv: list[str]) -> list[str]:
    """Return `argv` with `--rules-toml <RULES_TOML_PATH>` inserted after the program name, unless already given.

    The option counts as given when an argument is exactly `--rules-toml` or starts with `--rules-toml=`; in that
    case the very same list is returned.
    """
    for arg in argv[1:]:
        if arg == "--rules-toml" or arg.startswith("--rules-toml="):
            return argv
    program, *user_args = argv
    return [program, "--rules-toml", str(RULES_TOML_PATH), *user_args]
if __name__ == "__main__":
    try:
        sys.argv = _add_default_rules_toml(sys.argv)
        linter = _require_mlinter()
        exit_code = linter.main()
        raise SystemExit(exit_code)
    except ModuleNotFoundError as missing:
        # Exit with the message only, instead of a traceback.
        raise SystemExit(str(missing)) from missing

View File

@@ -0,0 +1,284 @@
import argparse
import difflib
import glob
import logging
import multiprocessing
import os
import shutil
import subprocess
from functools import partial
from create_dependency_mapping import find_priority_list
# Console for rich printing
from modular_model_converter import convert_modular_file, run_ruff
from rich.console import Console
from rich.syntax import Syntax
# Metadata describing this checker (name, label, cache globs and the arguments for check / fix modes).
CHECKER_CONFIG = {
    "name": "modular_conversion",
    "label": "Modular file conversions",
    # Globs the modular sources; also reads generated modeling_*.py at runtime for diffing.
    "cache_globs": ["src/transformers/models/**/modular_*.py", "src/transformers/models/**/modeling_*.py"],
    "check_args": [],
    "fix_args": ["--fix_and_overwrite"],
}
# Configure the root logger and only let errors (and above) through.
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
# Console for rich printing
console = Console()
# Suffix appended to a file path to name the backup copy made before the file is overwritten.
BACKUP_EXT = ".modular_backup"
def process_file(
    modular_file_path,
    generated_modeling_content,
    file_type="modeling_",
    show_diff=True,
):
    """Compare one generated file against its on-disk counterpart and overwrite it on mismatch.

    Args:
        modular_file_path: Path of the `modular_*.py` source; the target path is derived from it.
        generated_modeling_content: Mapping from file type to the generated file content.
        file_type: Key into `generated_modeling_content`; an optional `.*` splits it into a
            file-name prefix and a suffix inserted before `.py`.
        show_diff: When true, print the unified diff; otherwise print an "overwritten" notice.

    Returns:
        1 if the on-disk file differed (it has then been backed up with `BACKUP_EXT` and
        replaced by the generated content), 0 otherwise.
    """
    prefix = file_type.partition(".*")[0]
    suffix = file_type.rpartition(".*")[2] if ".*" in file_type else ""
    target_path = modular_file_path.replace("modular_", f"{prefix}_").replace(".py", f"{suffix}.py")

    # Current content of the checked-in file
    with open(target_path, "r", encoding="utf-8") as existing_file:
        on_disk = existing_file.read()

    delta = list(
        difflib.unified_diff(
            generated_modeling_content[file_type].splitlines(),
            on_disk.splitlines(),
            fromfile=f"{target_path}_generated",
            tofile=f"{target_path}",
            lineterm="",
        )
    )
    if not delta:
        return 0

    # Keep a copy of the original so the caller can restore it afterwards
    shutil.copy(target_path, target_path + BACKUP_EXT)
    # The generated content is always written, so that dependent files convert against it
    with open(target_path, "w", encoding="utf-8", newline="\n") as out_file:
        out_file.write(generated_modeling_content[file_type])

    if show_diff:
        console.print(f"\n[bold red]Differences found between the generated code and {target_path}:[/bold red]\n")
        rendered = Syntax("\n".join(delta), "diff", theme="ansi_dark", line_numbers=True)
        console.print(rendered)
    else:
        console.print(f"[bold blue]Overwritten {target_path} with the generated content.[/bold blue]")
    return 1
def convert_and_run_ruff(modular_file_path: str) -> dict[str, str]:
    """Convert a modular file and return every generated file's content after `ruff` has run on it.

    `ruff` needs the final file name to apply all rules correctly, so each generated content
    is written to a temporary file with a similar name, processed by `run_ruff`, read back,
    and the temporary file is removed.

    Args:
        modular_file_path: Path to the `modular_*.py` file to convert.

    Returns:
        The mapping returned by `convert_modular_file`, with each value replaced by the
        content read back after `run_ruff`.
    """
    # Generate the expected modeling content
    generated_modeling_content = convert_modular_file(modular_file_path)
    for file_type in generated_modeling_content:
        file_name_prefix = file_type.split(".*")[0]
        file_name_suffix = file_type.split(".*")[-1] if ".*" in file_type else ""
        temp_file_name = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(
            ".py", f"_temp_pattern__{file_name_suffix}.py"
        )
        try:
            # Explicit UTF-8, matching how `process_file` reads and writes the real files,
            # so the round trip does not depend on the platform's default encoding.
            with open(temp_file_name, "w", encoding="utf-8") as f:
                f.write(generated_modeling_content[file_type])
            # Run ruff on the new file (with similar name pattern as the original one)
            run_ruff(temp_file_name)
            with open(temp_file_name, "r", encoding="utf-8") as f:
                generated_modeling_content[file_type] = f.read()
        finally:
            # Always clean up, even if writing or ruff failed, so no stray
            # `*_temp_pattern__*.py` file is left in the model folder.
            if os.path.exists(temp_file_name):
                os.remove(temp_file_name)
    return generated_modeling_content
def compare_files(modular_file_path, show_diff=True):
    """Convert one modular file and count how many of its generated files differ on disk.

    Returns the sum of `process_file` results over every generated file type.
    """
    generated = convert_and_run_ruff(modular_file_path)
    return sum(process_file(modular_file_path, generated, file_type, show_diff) for file_type in generated)
# Changes to any of these files can alter the generated output for every modular model,
# so touching them must force a full re-check (see `converter_changed_in_diff`).
# Paths are compared verbatim against the names printed by `git diff --name-only`.
CONVERTER_FILES = {
"utils/modular_model_converter.py",
"utils/create_dependency_mapping.py",
}
def _get_modified_files():
    """Return the paths that differ from the merge base of `main` and `HEAD`.

    Deleted files are excluded (`--diff-filter=d`). The git output is split on
    whitespace, so each returned item is one whitespace-free token.
    """
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8")
    diff_command = f"git diff --diff-filter=d --name-only {fork_point_sha}".split()
    raw_listing = subprocess.check_output(diff_command)
    return raw_listing.decode("utf-8").split()
def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    A modified file counts when it is a `.py` file whose path contains `/models/` (this
    matches both modeling files and tests); the model name is its parent folder.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).
    """
    return {
        path.split("/")[-2]
        for path in _get_modified_files()
        if "/models/" in path and path.endswith(".py")
    }
def converter_changed_in_diff():
    """Whether the diff touches a file that can change conversion output for every model."""
    for modified_path in _get_modified_files():
        if modified_path in CONVERTER_FILES:
            return True
    return False
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Tell whether a modular file is guaranteed to match its generated files.

    The guarantee holds only when neither the model itself nor any model it depends on
    appears in `models_in_diff`.

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary containing the dependencies of each modular file.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    own_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if own_name in models_in_diff:
        return False
    # Dependencies look like `transformers.models.model_name.(...)` or `model_name.(...)`:
    # in both cases the model name is the second-to-last dotted component.
    return not any(dep.split(".")[-2] in models_in_diff for dep in dependencies[modular_file_path])
# Script entry point: parse the CLI, decide from git state which models need checking,
# convert the modular files level by level, restore any overwritten files when only
# checking, and raise if generated and checked-in files disagree.
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
parser.add_argument(
"--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
)
parser.add_argument(
"--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
)
parser.add_argument("--check_all", action="store_true", help="Check all files, not just the ones in the diff.")
parser.add_argument(
"--num_workers",
default=-1,
type=int,
help="The number of workers to run. Default is -1, which means the number of CPU cores.",
)
args = parser.parse_args()
# "all" expands to every modular file found under the models directory.
if args.files == ["all"]:
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
if args.num_workers == -1:
args.num_workers = multiprocessing.cpu_count()
# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
# are not in the diff, then there it is guaranteed to have no differences. If no models are in the diff, then this
# script will do nothing.
current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
if current_branch == "main":
console.print(
"[bold red]You are developing on the main branch. We cannot identify the list of changed files and will have to check all files. This may take a while.[/bold red]"
)
# Treat the folder of every requested file as modified.
models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
elif converter_changed_in_diff():
# The converter (or its dependency-mapping helper) is in the diff: its output can shift
# for any model, so restrict-by-diff would miss regressions. Force a full check.
console.print("[bold yellow]Converter change detected in diff; checking all modular files.[/bold yellow]")
args.check_all = True
models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
else:
models_in_diff = get_models_in_diff()
# Nothing relevant changed and no full check requested: exit successfully without converting.
if not models_in_diff and not args.check_all:
exit(0)
non_matching_files = []
ordered_files, dependencies = find_priority_list(args.files)
flat_ordered_files = [item for sublist in ordered_files for item in sublist]
# ordered_files is a *sorted* list of lists of filepaths
# - files from the first list do NOT depend on other files
# - files in the second list depend on files from the first list
# - files in the third list depend on files from the second and (optionally) the first list
# - ... and so on
# files (models) within the same list are *independent* of each other;
# we start applying modular conversion to each list in parallel, starting from the first list
try:
for dependency_level_files in ordered_files:
# Filter files guaranteed no diff
files_to_check = []
for file_path in dependency_level_files:
if args.check_all or not guaranteed_no_diff(file_path, dependencies, models_in_diff):
files_to_check.append(file_path)
if not files_to_check:
continue
# Process files with diff
num_workers = min(args.num_workers, len(files_to_check))
with multiprocessing.Pool(num_workers) as p:
try:
# In check mode (no --fix_and_overwrite) the diff is printed for each mismatching file.
is_changed_flags = p.map(
partial(compare_files, show_diff=not args.fix_and_overwrite),
files_to_check,
)
except Exception as e:
console.print(
f"[bold red]Failed to convert one or more files in batch: {files_to_check}[/bold red]"
)
console.print(f"[bold red]Error: {e}[/bold red]")
# Try to process files individually to identify which one failed
is_changed_flags = []
for file_path in files_to_check:
try:
result = compare_files(file_path, show_diff=not args.fix_and_overwrite)
is_changed_flags.append(result)
except Exception as individual_error:
console.print(f"[bold red]Failed to convert {file_path}: {individual_error}[/bold red]")
# NOTE(review): a file whose conversion failed is recorded as unchanged, so it is
# not listed in the final error below -- confirm this is intended.
is_changed_flags.append(0) # Mark as no change to continue processing
# Collect changed files and their original paths
for is_changed, file_path in zip(is_changed_flags, files_to_check):
if is_changed:
non_matching_files.append(file_path)
# Update changed models, after each round of conversions
# (save model folder name)
models_in_diff.add(file_path.split("/")[-2])
finally:
# Restore overwritten files by modular (if needed)
# Backups are always deleted; the original is copied back only in check mode.
backup_files = glob.glob("**/*" + BACKUP_EXT, recursive=True)
for backup_file_path in backup_files:
overwritten_path = backup_file_path.replace(BACKUP_EXT, "")
if not args.fix_and_overwrite and os.path.exists(overwritten_path):
shutil.copy(backup_file_path, overwritten_path)
os.remove(backup_file_path)
# In check mode any mismatch is a failure; report the affected model folders.
if non_matching_files and not args.fix_and_overwrite:
diff_models = set(file_path.split("/")[-2] for file_path in non_matching_files) # noqa
models_str = "\n - " + "\n - ".join(sorted(diff_models))
raise ValueError(f"Some diff and their modeling code did not match. Models in diff:{models_str}")

View File

@@ -0,0 +1,101 @@
import re
from transformers.pipelines import SUPPORTED_TASKS, Pipeline
# Registration metadata for this checker (keys: name, label, cache_globs, check_args, fix_args).
# Kept as a plain literal because utils/checkers.py reads it with `ast.literal_eval`.
CHECKER_CONFIG = {
"name": "pipeline_typing",
"label": "Pipeline type hints",
# Only the file that holds both the `pipeline` function and the generated overloads is hashed.
"cache_globs": ["src/transformers/pipelines/__init__.py"],
"check_args": [],
"fix_args": ["--fix_and_overwrite"],
}
# Text emitted before the generated overloads, inside the `# <generated-code>` markers.
HEADER = """
# fmt: off
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file below was automatically generated from the code.
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
# of the file. If any change should be done, please apply the changes to the `pipeline` function
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
from typing import Literal, overload
"""
# Text emitted after the generated overloads, inside the `# <generated-code>` markers.
FOOTER = """
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file above was automatically generated from the code.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# fmt: on
"""
# Exact parameter text in the `pipeline` signature that each overload replaces with a `Literal[...]` task.
TASK_PATTERN = "task: str | None = None"
def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
    """Check (or regenerate) the `@overload` stubs of `pipeline` in the pipelines file.

    One overload is generated for the base `Pipeline` plus one per entry of
    `SUPPORTED_TASKS`, by substituting `TASK_PATTERN` in the real `pipeline` signature.
    The result is compared with the text between the `# <generated-code>` and
    `# </generated-code>` markers.

    Args:
        pipeline_file_path: Path of the file containing `pipeline` and the generated block.
        fix_and_overwrite: When true, rewrite the generated block in place if it is outdated.

    Raises:
        ValueError: if the markers, the `pipeline` signature or `TASK_PATTERN` cannot be
            found, or if the generated block is outdated and `fix_and_overwrite` is false.
    """
    # The generated block contains non-ASCII characters (see HEADER/FOOTER), so the file is
    # read and written as UTF-8 explicitly instead of relying on the platform default.
    with open(pipeline_file_path, "r", encoding="utf-8") as file:
        content = file.read()

    # extract generated code in between <generated-code> and </generated-code>
    generated_match = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL)
    if generated_match is None:
        raise ValueError(
            f"Can't find the `# <generated-code>` / `# </generated-code>` markers in {pipeline_file_path}"
        )
    current_generated_code = generated_match.group(1)
    content_without_generated_code = content.replace(current_generated_code, "")

    # extract pipeline signature in between `def pipeline` and `-> Pipeline`
    signature_match = re.search(r"def pipeline(.*) -> Pipeline:", content_without_generated_code, re.DOTALL)
    if signature_match is None:
        raise ValueError(f"Can't find the `def pipeline(...) -> Pipeline:` signature in {pipeline_file_path}")
    pipeline_signature = signature_match.group(1)
    # Collapse the multi-line signature onto one line.
    # NOTE(review): these literals must match the exact indentation used by the `pipeline`
    # signature in the target file -- confirm against that file.
    pipeline_signature = pipeline_signature.replace("(\n ", "(")  # start of the signature
    pipeline_signature = pipeline_signature.replace(",\n ", ", ")  # intermediate arguments
    pipeline_signature = pipeline_signature.replace(",\n)", ")")  # end of the signature

    # The signature does not change between overloads, so validate it once up front.
    if TASK_PATTERN not in pipeline_signature:
        raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")

    # collect and sort available pipelines; the generic `Pipeline` overload always comes first
    pipelines = sorted(
        ((f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()),
        key=lambda x: x[0],
    )
    pipelines.insert(0, (None, Pipeline))

    # generate new `pipeline` signatures
    overloads = []
    for task, pipeline_class in pipelines:
        pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
        new_pipeline_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
        overloads.append(f"@overload\ndef pipeline{new_pipeline_signature} -> {pipeline_type}: ...\n")
    new_generated_code = HEADER + "".join(overloads) + FOOTER
    new_generated_code = new_generated_code.rstrip("\n") + "\n"

    if new_generated_code == current_generated_code:
        return

    if not fix_and_overwrite:
        message = (
            f"Found inconsistencies in {pipeline_file_path}. "
            "Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
        )
        raise ValueError(message)

    print(f"Updating {pipeline_file_path}...")
    wrapped_current_generated_code = "# <generated-code>" + current_generated_code + "# </generated-code>"
    wrapped_new_generated_code = "# <generated-code>" + new_generated_code + "# </generated-code>"
    content = content.replace(wrapped_current_generated_code, wrapped_new_generated_code)
    # write content to file
    with open(pipeline_file_path, "w", encoding="utf-8") as file:
        file.write(content)
if __name__ == "__main__":
    import argparse

    # Command-line wrapper around `main`: an optional fix flag and an overridable target file.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    cli.add_argument(
        "--pipeline_file_path",
        type=str,
        default="src/transformers/pipelines/__init__.py",
        help="Path to the pipeline file.",
    )
    options = cli.parse_args()
    main(options.pipeline_file_path, options.fix_and_overwrite)

1507
utils/check_repo.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,57 @@
import argparse
import json
import subprocess
def get_runner_status(target_runners, token):
    """Query the GitHub API for self-hosted runners and fail if any target runner is offline.

    The offline runners (possibly an empty list) are always written as JSON to
    `offline_runners.txt` so they can be reported on Slack.

    Args:
        target_runners: Names of the runners to check.
        token: A token sent as the `Authorization: Bearer` header.

    Raises:
        ValueError: if at least one of `target_runners` is reported as offline.
    """
    cmd = [
        "curl",
        "-H",
        "Accept: application/vnd.github+json",
        "-H",
        f"Authorization: Bearer {token}",
        "https://api.github.com/repos/huggingface/transformers/actions/runners",
    ]
    # The command is an argument list, so it must run without a shell: with `shell=True`
    # only the first item ("curl") is executed and the rest become arguments of the shell
    # itself, meaning no request is ever sent. Skipping the shell also keeps the token out
    # of shell parsing.
    output = subprocess.run(cmd, check=False, stdout=subprocess.PIPE)
    status = json.loads(output.stdout.decode("utf-8"))

    offline_runners = [
        runner
        for runner in status["runners"]
        if runner["name"] in target_runners and runner["status"] == "offline"
    ]

    # save the result so we can report them on Slack
    with open("offline_runners.txt", "w") as fp:
        fp.write(json.dumps(offline_runners))

    if offline_runners:
        failed = "\n".join([x["name"] for x in offline_runners])
        raise ValueError(f"The following runners are offline:\n{failed}")
if __name__ == "__main__":

    def list_str(values):
        """Turn a comma-separated command-line value into a list of strings."""
        return values.split(",")

    cli = argparse.ArgumentParser()
    # Required parameters
    cli.add_argument(
        "--target_runners",
        default=None,
        type=list_str,
        required=True,
        help="Comma-separated list of runners to check status.",
    )
    cli.add_argument(
        "--token",
        default=None,
        type=str,
        required=True,
        help="A token that has actions:read permission.",
    )
    options = cli.parse_args()
    get_runner_status(options.target_runners, options.token)

71
utils/check_types.py Normal file
View File

@@ -0,0 +1,71 @@
"""Run ty type checking on specified directories.
Usage:
python utils/check_types.py src/transformers/utils src/transformers/generation
"""
import subprocess
import sys
# Registration metadata for this checker (keys: name, label, cache_globs, check_args, fix_args).
# Kept as a plain literal because utils/checkers.py reads it with `ast.literal_eval`.
CHECKER_CONFIG = {
"name": "types",
"label": "Type annotations",
# For contributors:
# - `check_args` below are the exact roots passed to `ty check`.
# - `cache_globs` here are only used by `utils/checkers.py` to decide when a
# previously clean `types` run can be reused from cache.
# ty follows imports *beyond* the checked roots, so the cache key must cover every source file
# that could change a result -- not just the explicitly-checked paths. We hash the whole package
# (plus the standalone .circleci target) so any source edit busts the cache and forces a
# re-check. Otherwise a cached pass could silently hide a newly-introduced error in a
# transitively-imported module that ty pulls in but that isn't one of the checked roots.
"cache_globs": [
"src/transformers/**/*.py",
".circleci/create_circleci_config.py",
],
"check_args": [
"src/transformers/_typing.py",
"src/transformers/cli",
"src/transformers/modeling_utils.py",
"src/transformers/utils",
"src/transformers/generation",
"src/transformers/pipelines/__init__.py",
"src/transformers/pipelines/feature_extraction.py",
"src/transformers/pipelines/image_feature_extraction.py",
"src/transformers/pipelines/video_classification.py",
"src/transformers/quantizers",
".circleci/create_circleci_config.py",
"src/transformers/dependency_versions_table.py",
"src/transformers/dependency_versions_check.py",
"src/transformers/conversion_mapping.py",
"src/transformers/time_series_utils.py",
"src/transformers/debug_utils.py",
"src/transformers/hyperparameter_search.py",
"src/transformers/pytorch_utils.py",
"src/transformers/file_utils.py",
"src/transformers/trainer_jit_checkpoint.py",
"src/transformers/trainer_optimizer.py",
],
# None marks this checker as check-only: type errors are reported, never auto-fixed.
"fix_args": None,
}
def main():
    """Run `ty check` on the paths given on the command line and exit with its return code.

    Exits with status 1 after printing a usage line when no path is given.
    """
    directories = sys.argv[1:]
    if not directories:
        print(f"Usage: {sys.argv[0]} <directory> [<directory> ...]")
        sys.exit(1)

    print(f"Running ty check on: {', '.join(directories)}")
    # `--error-on-warning` makes ty exit non-zero on warning-level diagnostics (e.g.
    # possibly-missing-attribute), not just errors. Without it, warnings print but ty exits 0, so
    # `make typing` and CI both pass and the issue is never caught before commit.
    ty_command = ["ty", "check", "--respect-ignore-files", "--error-on-warning", "--exclude", "**/*_pb*"]
    ty_command.extend(directories)
    completed = subprocess.run(ty_command)
    sys.exit(completed.returncode)


if __name__ == "__main__":
    main()

641
utils/checkers.py Normal file
View File

@@ -0,0 +1,641 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified runner for check/fix scripts.
Usage:
python utils/checkers.py copies,modular_conversion,doc_toc
python utils/checkers.py copies,modular_conversion,doc_toc --fix
python utils/checkers.py copies,doc_toc --keep-going
python utils/checkers.py all
python utils/checkers.py all --fix
Plugin system
-------------
Each checker module declares a ``CHECKER_CONFIG`` dict (extracted via ``ast.literal_eval``,
no import needed — this keeps discovery fast and avoids executing checker code at scan time).
See any ``check_*.py`` file for the schema.
Cache semantics of ``cache_globs``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``cache_globs`` lists the file patterns whose content is hashed to decide whether a checker
can be skipped. **Not all globs are exact reflections of the checker's runtime behaviour.**
* Some checkers introspect the live ``transformers`` module (``check_repo``,
``check_config_docstrings``, ``check_config_attributes``, ``update_metadata``), so their
globs are necessarily *approximations* of the true dependency set.
* Some checkers over-approximate (``check_dummies``, ``check_doctest_list``): any change
inside the broad glob forces a re-run even if the checker wouldn't look at that file.
This is safe—just less cache-efficient.
* Some checkers rely on external state (network, git history, installed packages) that
cannot be captured by cache globs at all (``add_dates``, ``imports``).
Each ``CHECKER_CONFIG`` that is an approximation has an inline comment explaining the
gap. For contributors: ``check_args`` control what a checker runs on, while
``cache_globs`` only control when the cache is invalidated. When in doubt, use
``--no-cache`` to force a full run.
"""
import argparse
import ast
import hashlib
import itertools
import json
import os
import shutil
import subprocess
import sys
import threading
import time
import warnings
from collections import deque
from pathlib import Path
# Directory containing this file; checker scripts are discovered and run from here.
UTILS_DIR = Path(__file__).parent
# Parent of the utils directory; cache globs are resolved relative to it.
REPO_ROOT = UTILS_DIR.parent
# Default location of the on-disk cache of per-checker digests.
CACHE_PATH = UTILS_DIR / ".checkers_cache.json"
# Required keys in each module's CHECKER_CONFIG dict.
_CHECKER_CONFIG_KEYS = {"name", "label", "cache_globs", "check_args", "fix_args"}
def _literal_checker_config(tree):
    """Return the literal value of the first top-level ``CHECKER_CONFIG = ...`` in ``tree``.

    Returns None when there is no such assignment or when its value is not a pure literal.
    """
    for node in ast.iter_child_nodes(tree):
        if not isinstance(node, ast.Assign) or len(node.targets) != 1:
            continue
        target = node.targets[0]
        if not (isinstance(target, ast.Name) and target.id == "CHECKER_CONFIG"):
            continue
        # Only the first matching assignment is considered, whether or not it evaluates.
        try:
            return ast.literal_eval(node.value)
        except (ValueError, TypeError):
            return None
    return None


def _discover_checkers() -> tuple[dict, dict]:
    """Scan utils/*.py for CHECKER_CONFIG dicts using AST (no imports).

    Each checker module may define a top-level ``CHECKER_CONFIG`` dict with
    keys: name, label, cache_globs, check_args, fix_args. This file itself and files
    that fail to parse are ignored; configs with missing keys are skipped with a warning.

    Returns (checkers_dict, cache_globs_dict) matching the shapes of
    the old CHECKERS and CHECKER_CACHE_GLOBS registries.
    """
    registry = {}
    globs_by_name = {}
    own_name = Path(__file__).name
    for py_file in sorted(UTILS_DIR.glob("*.py")):
        if py_file.name == own_name:
            continue
        try:
            tree = ast.parse(py_file.read_text(encoding="utf-8"), filename=str(py_file))
        except SyntaxError:
            continue

        config = _literal_checker_config(tree)
        if config is None:
            continue

        missing = _CHECKER_CONFIG_KEYS - set(config)
        if missing:
            warnings.warn(
                f"CHECKER_CONFIG in {py_file.name} is missing keys: {', '.join(sorted(missing))}. Skipping.",
                stacklevel=1,
            )
            continue

        name = config["name"]
        if name in registry:
            # Warn only: the later file still replaces the earlier entry below.
            warnings.warn(
                f"Duplicate checker name {name!r} in {py_file.name}, already defined by {registry[name][1]}",
                stacklevel=1,
            )
        registry[name] = (config["label"], py_file.name, config["check_args"], config["fix_args"])
        if config["cache_globs"] is not None:
            globs_by_name[name] = config["cache_globs"]
    return registry, globs_by_name
# Inline checkers have no separate script file; they use custom runner functions below.
# fix_args=[] marks a checker as fix-capable (its custom runner handles --fix internally);
# fix_args=None marks a check-only entry that `make fix-repo` should silently skip.
# Each value is a (label, script, check_args, fix_args) tuple, the same shape that
# `_discover_checkers` builds for script-based checkers.
# NOTE(review): "import_complexity" does name a script and has no entry in CUSTOM_RUNNERS,
# so it goes through the generic script path in `run_checker` -- confirm it belongs here.
_INLINE_CHECKERS = {
"deps_table": ("Dependency versions table", None, None, []),
"imports": ("Public imports", None, None, None),
"import_complexity": ("Import complexity", "check_import_complexity.py", [], None),
"ruff_check": ("Ruff linting", None, None, []),
"ruff_format": ("Ruff formatting", None, None, []),
}
# Both ruff checkers look at the same set of files.
# Approximate: ruff applies its own ignore rules from pyproject.toml at runtime.
_RUFF_CACHE_GLOBS = [
    "examples/**/*.py",
    "tests/**/*.py",
    "src/**/*.py",
    "utils/**/*.py",
    "scripts/**/*.py",
    ".circleci/create_circleci_config.py",
    "benchmark/**/*.py",
    "benchmark_v2/**/*.py",
    "setup.py",
    "conftest.py",
]

# Cache inputs of the inline checkers, keyed by checker name.
_INLINE_CACHE_GLOBS = {
    # Also generates/checks src/transformers/dependency_versions_table.py.
    "deps_table": ["setup.py", "pyproject.toml", "src/transformers/dependency_versions_table.py"],
    # Approximate: runs `from transformers import *` at runtime; depends on the full
    # Python environment, not just these files. Broad globs used as a safe upper bound.
    "imports": ["src/transformers/**/__init__.py", "src/transformers/**/*.py"],
    # Separate copies so that the two entries never share one list object.
    "ruff_check": list(_RUFF_CACHE_GLOBS),
    "ruff_format": list(_RUFF_CACHE_GLOBS),
}
# Build the registries: discovered modules + inline custom runners.
# Discovery runs once, at import time. Inline entries are merged last, so on a name clash
# the inline definition replaces the discovered one.
_discovered_checkers, _discovered_cache_globs = _discover_checkers()
CHECKERS = {**_discovered_checkers, **_INLINE_CHECKERS}
CHECKER_CACHE_GLOBS = {**_discovered_cache_globs, **_INLINE_CACHE_GLOBS}
def get_checker_cache_globs(checker_name: str) -> list[str] | None:
    """Return the cache inputs for a checker, including its implementation files.

    The result is the checker's declared globs followed by this runner file and, when
    the checker has one, its own script under ``utils/``. Returns None for a checker
    that declares no cache globs (such a checker is never cached).
    """
    declared = CHECKER_CACHE_GLOBS.get(checker_name)
    if declared is None:
        return None

    utils_dir = Path("utils")
    implementation_files = [str(utils_dir / Path(__file__).name)]
    script = CHECKERS[checker_name][1]
    if script is not None:
        implementation_files.append(str(utils_dir / script))
    return [*declared, *implementation_files]
class CheckerCache:
"""Disk-backed cache that tracks file content hashes per checker.
For each checker that declares cache globs in CHECKER_CACHE_GLOBS, we compute
a single digest over all matching files. If the digest matches the stored
value from the last clean (rc == 0) run, the checker can be skipped.
"""
def __init__(self, path: Path | None = None):
self._path = CACHE_PATH if path is None else path
self._data = self._load()
def _load(self) -> dict:
try:
return json.loads(self._path.read_text(encoding="utf-8"))
except (FileNotFoundError, json.JSONDecodeError, OSError):
return {}
def save(self) -> None:
try:
self._path.write_text(json.dumps(self._data, sort_keys=True, indent=2) + "\n", encoding="utf-8")
except OSError:
pass
@staticmethod
def _digest_files(globs: list[str]) -> str:
"""Compute a single SHA-256 over sorted file paths + contents."""
h = hashlib.sha256()
paths = set()
for pattern in globs:
paths.update(REPO_ROOT.glob(pattern))
for p in sorted(paths):
if p.is_file():
h.update(str(p.relative_to(REPO_ROOT)).encode())
h.update(p.read_bytes())
return h.hexdigest()
def is_current(self, checker_name: str) -> bool:
"""Return True if the checker's files haven't changed since last clean run."""
globs = get_checker_cache_globs(checker_name)
if globs is None:
return False
return self._data.get(checker_name) == self._digest_files(globs)
def update(self, checker_name: str) -> None:
"""Record current digest for a checker (call after a clean run)."""
globs = get_checker_cache_globs(checker_name)
if globs is None:
return
self._data[checker_name] = self._digest_files(globs)
def invalidate(self, checker_name: str) -> None:
"""Remove a checker from the cache (call after a failed run)."""
self._data.pop(checker_name, None)
def _file_md5(path):
    """Return the hexadecimal MD5 digest of the bytes of the file at ``path``."""
    digest = hashlib.md5()
    digest.update(path.read_bytes())
    return digest.hexdigest()
# ANSI helpers
# Escape sequences that switch the terminal foreground colour; RESET restores the default.
ORANGE = "\033[38;5;214m"
GREEN = "\033[32m"
RED = "\033[31m"
RESET = "\033[0m"
# Frames cycled through by the spinner shown next to a running checker's title.
SPINNER_CHARS = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
def format_elapsed(seconds: float) -> str:
    """Format a duration for status output.

    Durations under a minute are shown as ``"S.SSs"``; longer ones as ``"MmSS.SSs"``.
    The value is rounded to hundredths of a second *before* it is split into minutes and
    seconds, so the seconds part can never be displayed as ``60.00`` (e.g. 119.999 is
    shown as ``2m00.00s`` rather than ``1m60.00s``).
    """
    # Only values that can reach a full minute after rounding need the split.
    if seconds >= 59.99:
        total_centiseconds = round(seconds * 100)
        if total_centiseconds >= 6000:
            minutes, centiseconds = divmod(total_centiseconds, 6000)
            return f"{minutes}m{centiseconds / 100:05.2f}s"
    return f"{seconds:.2f}s"
class SlidingWindow:
    """Displays a spinning title + sliding window of the last N output lines in a TTY.

    A background daemon thread redraws the title with the next spinner frame roughly
    every 0.08 seconds; `add_line` and the spinner both redraw under the same lock.
    """

    def __init__(self, label, max_lines=10):
        """Print the initial title line and start the spinner thread."""
        self.label = label
        self.max_lines = max_lines
        self.lines = deque(maxlen=max_lines)
        self.displayed = 0  # number of output lines currently on screen
        self.term_width = shutil.get_terminal_size().columns
        self._spinner = itertools.cycle(SPINNER_CHARS)
        self._stop = threading.Event()
        self._lock = threading.Lock()
        # Print initial title line (will be overwritten by spinner)
        print(f"{ORANGE}{next(self._spinner)} {label}{RESET}")
        self._title_on_screen = True
        self._thread = threading.Thread(target=self._spin, daemon=True)
        self._thread.start()

    def _spin(self):
        """Thread body: redraw periodically until the stop event is set."""
        while not self._stop.is_set():
            self._stop.wait(0.08)
            if self._stop.is_set():
                break
            with self._lock:
                self._redraw()

    def _erase(self):
        """Move the cursor up over the output lines (and title, if shown), clearing each row."""
        rows = self.displayed + (1 if self._title_on_screen else 0)
        for _ in range(rows):
            sys.stdout.write("\033[A\033[2K")

    def _redraw(self):
        """Clear output lines + title, redraw everything."""
        self._erase()
        self.displayed = 0
        # Title first, with the next spinner frame
        print(f"{ORANGE}{next(self._spinner)} {self.label}{RESET}")
        self._title_on_screen = True
        # Then the buffered tail of the output
        for buffered in self.lines:
            print(buffered)
        self.displayed = len(self.lines)
        sys.stdout.flush()

    def add_line(self, line):
        """Append one output line (right-stripped, cut to the terminal width) and redraw."""
        with self._lock:
            self.lines.append(line.rstrip()[: self.term_width])
            self._redraw()

    def finish(self, success, elapsed=None, show_lines=True):
        """Stop spinner and print final status title."""
        self._stop.set()
        self._thread.join()
        with self._lock:
            self._erase()
            self._title_on_screen = False
            self.displayed = 0
            # Final title: green on success, red on failure, with the elapsed time if given
            suffix = f" ({format_elapsed(elapsed)})" if elapsed is not None else ""
            colour = GREEN if success else RED
            print(f"{colour}{self.label}{suffix}{RESET}")
            # Reprint output lines when we want to preserve the tail summary.
            if show_lines:
                for buffered in self.lines:
                    print(buffered)
            sys.stdout.flush()
def _print_output(output: str) -> None:
    """Print captured output without truncation.

    Nothing is printed for empty output; otherwise the text is flushed immediately and
    ends with exactly the newline it already had, or one added newline.
    """
    if output:
        terminator = "" if output.endswith("\n") else "\n"
        print(output, end=terminator, flush=True)
def _run_cmd(cmd, line_callback=None):
    """Run a command, capturing output. Returns (returncode, output).

    stderr is merged into stdout and decoded as UTF-8 with undecodable bytes replaced.

    Args:
        cmd: The command as an argument list.
        line_callback: Optional callable invoked with each output line as soon as it is
            read. When given, the child runs with ``PYTHONUNBUFFERED=1`` so Python
            subprocesses emit their lines promptly.
    """
    if line_callback is None:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        return result.returncode, result.stdout.decode("utf-8", errors="replace")

    env = os.environ.copy()
    env["PYTHONUNBUFFERED"] = "1"
    output_lines = []
    # The context manager closes the pipe and waits for the child on exit, including when
    # `line_callback` raises, so no file descriptor or unreaped process is left behind.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env) as proc:
        for raw_line in proc.stdout:
            line = raw_line.decode("utf-8", errors="replace")
            output_lines.append(line)
            line_callback(line)
    return proc.returncode, "".join(output_lines)
def run_deps_table_checker(fix=False, line_callback=None):
    """Check or fix the dependency versions table.

    Both modes run ``setup.py deps_table_update``. In fix mode its result is returned
    directly. In check mode the table file is hashed before and after the command, and a
    changed hash is reported as a failure (return code 1) with an explanatory message.
    """
    table_path = REPO_ROOT / "src" / "transformers" / "dependency_versions_table.py"
    update_cmd = [sys.executable, str(REPO_ROOT / "setup.py"), "deps_table_update"]

    if fix:
        return _run_cmd(update_cmd, line_callback=line_callback)

    digest_before = _file_md5(table_path)
    rc, output = _run_cmd(update_cmd, line_callback=line_callback)
    if rc != 0:
        return rc, output
    if _file_md5(table_path) == digest_before:
        return 0, output

    msg = (
        "Error: the version dependency table is outdated.\n"
        "Please run 'make fix-repo' and commit the changes. This requires Python 3.10.\n"
    )
    return 1, output + msg
def run_imports_checker(fix=False, line_callback=None):
    """Check that all public imports work.

    Runs ``from transformers import *`` in a fresh interpreter; ``fix`` is accepted for
    interface parity with the other runners and has no effect.
    """
    import_cmd = [sys.executable, "-c", "from transformers import *"]
    rc, output = _run_cmd(import_cmd, line_callback=line_callback)
    if rc == 0:
        return 0, output
    return rc, output + "Import failed, this means you introduced unprotected imports!\n"
# Paths passed to both `ruff check` and `ruff format` by the runners below.
RUFF_TARGETS = [
"examples",
"tests",
"src",
"utils",
"scripts",
".circleci/create_circleci_config.py",
"benchmark",
"benchmark_v2",
"setup.py",
"conftest.py",
]
def run_ruff_check(fix=False, line_callback=None):
    """Run ruff linting.

    In fix mode ``--fix`` is added together with an empty ``--exclude`` value.
    """
    fix_options = ["--fix", "--exclude", ""] if fix else []
    return _run_cmd(["ruff", "check", *RUFF_TARGETS, *fix_options], line_callback=line_callback)
def run_ruff_format(fix=False, line_callback=None):
    """Run ruff formatting.

    Check mode adds ``--check`` so nothing is rewritten; fix mode formats in place and
    passes an empty ``--exclude`` value.
    """
    mode_options = ["--exclude", ""] if fix else ["--check"]
    return _run_cmd(["ruff", "format", *RUFF_TARGETS, *mode_options], line_callback=line_callback)
# Checkers that are run by a Python function instead of a `utils/` script. `run_checker`
# consults this mapping first and calls the function with `fix=` and `line_callback=`.
CUSTOM_RUNNERS = {
"deps_table": run_deps_table_checker,
"imports": run_imports_checker,
"ruff_check": run_ruff_check,
"ruff_format": run_ruff_format,
}
def get_checker_command(name, fix=False):
    """Return a shell-friendly command string for a checker.

    The string mirrors what the corresponding runner executes and is built with
    ``shlex.join`` so every argument survives a copy-paste into a shell. In particular,
    the empty ``--exclude`` value used by the ruff fix commands is rendered as ``''``
    instead of disappearing.

    Args:
        name: A key of ``CHECKERS``.
        fix: Whether to describe the fix-mode command instead of the check-mode one.

    Returns:
        The command string, or None when ``fix`` is requested for a script checker that
        has no fix mode.
    """
    import shlex

    if name == "deps_table":
        return "python setup.py deps_table_update"
    if name == "imports":
        return 'python -c "from transformers import *"'
    if name == "ruff_check":
        cmd = ["ruff", "check", *RUFF_TARGETS]
        if fix:
            cmd += ["--fix", "--exclude", ""]
        return shlex.join(cmd)
    if name == "ruff_format":
        cmd = ["ruff", "format", *RUFF_TARGETS]
        if not fix:
            cmd += ["--check"]
        else:
            cmd += ["--exclude", ""]
        return shlex.join(cmd)

    _, script, check_args, fix_args = CHECKERS[name]
    if fix and fix_args is None:
        return None
    args = fix_args if fix else check_args
    return shlex.join(["python", f"utils/{script}", *args])
def run_checker(name, fix=False, line_callback=None):
    """Run one checker and return its `(return_code, output)` result.

    Names present in `CUSTOM_RUNNERS` are delegated to their runner function. Any other name is looked
    up in `CHECKERS` and executed as a script under `UTILS_DIR` with the current interpreter.

    Args:
        name: Checker name.
        fix: Run with the checker's fix arguments instead of its check arguments.
        line_callback: Forwarded to the runner / `_run_cmd`.

    Returns:
        The runner's result, or `(0, "skipped (no fix mode)")` when `fix` is set but the checker
        has no fix arguments.
    """
    custom_runner = CUSTOM_RUNNERS.get(name)
    if custom_runner is not None:
        return custom_runner(fix=fix, line_callback=line_callback)

    _, script, check_args, fix_args = CHECKERS[name]
    script_path = UTILS_DIR / script
    if fix and fix_args is None:
        return 0, "skipped (no fix mode)"

    mode_args = fix_args if fix else check_args
    return _run_cmd([sys.executable, str(script_path), *mode_args], line_callback=line_callback)
def main():
    """Command-line entry point: parse arguments, run the selected checkers and report the outcome.

    Exits with status 1 when an unknown checker is requested or when a checker fails. Without
    `--keep-going` the first failure exits immediately; with it, failures are listed after all
    checkers have run. `--list` prints the available checkers and returns without running anything.
    """
    parser = argparse.ArgumentParser(description="Run check/fix scripts.")
    parser.add_argument(
        "checkers",
        # "*" rather than "+": with "+" argparse rejects a bare `--list` because the positional
        # would be required. The non-empty requirement is enforced explicitly below.
        nargs="*",
        help='Comma-separated checker names, or "all". Use --list to see available checkers.',
    )
    parser.add_argument("--fix", action="store_true", help="Run in fix mode instead of check mode.")
    parser.add_argument(
        "--keep-going", action="store_true", help="Run all checkers even if some fail (report failures at the end)."
    )
    parser.add_argument("--list", action="store_true", help="List available checkers and exit.")
    parser.add_argument("--no-cache", action="store_true", help="Ignore the disk cache and re-run every checker.")
    args = parser.parse_args()

    if args.list:
        for name, entry in sorted(CHECKERS.items()):
            label, script, _, fix_args = entry
            fixable = "fixable" if fix_args is not None else "check-only"
            script_display = script or "custom"
            print(f" {name:25s} {label:35s} ({script_display}, {fixable})")
        return

    # Same error (and exit status 2) argparse itself produced when the positional was mandatory.
    if not args.checkers:
        parser.error("the following arguments are required: checkers")

    # Join all positional args (shell line continuations may split them) and parse checker names
    raw = " ".join(args.checkers)
    if raw.strip() == "all":
        names = list(CHECKERS.keys())
    else:
        names = [n.strip() for n in raw.split(",") if n.strip()]

    unknown = [n for n in names if n not in CHECKERS]
    if unknown:
        print(f"Unknown checkers: {', '.join(unknown)}")
        print(f"Available: {', '.join(sorted(CHECKERS.keys()))}")
        sys.exit(1)

    # In --fix mode, drop checkers that have no fix capability (fix_args is None) so
    # they don't print bogus "(0.00s)" lines or inflate the final pass count. Print
    # one transparency line listing what we're skipping.
    if args.fix:
        not_fixable = [n for n in names if CHECKERS[n][3] is None]
        if not_fixable:
            names = [n for n in names if CHECKERS[n][3] is not None]
            print(
                f"Skipping {len(not_fixable)} check-only checker(s) in fix mode: {', '.join(not_fixable)}\n",
                flush=True,
            )

    is_ci = os.environ.get("GITHUB_ACTIONS") == "true" or os.environ.get("CIRCLECI") == "true"
    # The interactive sliding-window display is only used on a real terminal outside CI.
    is_tty = sys.stdout.isatty() and not is_ci
    if not is_tty and hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(line_buffering=True)

    # The cache is bypassed in fix mode and when --no-cache is given.
    use_cache = not args.no_cache and not args.fix
    cache = CheckerCache() if use_cache else None

    failures = []
    skipped = 0
    total_start = time.perf_counter()

    for name in names:
        label = CHECKERS[name][0]

        # Skip if all relevant files are unchanged since last clean run
        if cache is not None and cache.is_current(name):
            skipped += 1
            if is_tty:
                print(f"{GREEN}{label} (cached){RESET}\n")
            else:
                print(f"{label} (cached)\n", flush=True)
            continue

        cmd_str = get_checker_command(name, fix=args.fix)
        checker_start = time.perf_counter()

        if is_tty:
            window = SlidingWindow(label, max_lines=10)
            if cmd_str:
                window.add_line(f"$ {cmd_str}")
            rc, output = run_checker(name, fix=args.fix, line_callback=window.add_line)
            elapsed = time.perf_counter() - checker_start
            window.finish(success=(rc == 0), elapsed=elapsed, show_lines=(rc == 0))
            if rc != 0:
                print()
                _print_output(output)
                print()
            if rc == 0 and cache is not None:
                cache.update(name)
            elif rc != 0:
                if cache is not None:
                    cache.invalidate(name)
                failures.append(name)
                if not args.keep_going:
                    # Persist what was learned so far before bailing out.
                    if cache is not None:
                        cache.save()
                    sys.exit(1)
        else:
            print(f"{label}", flush=True)
            if cmd_str:
                print(f"$ {cmd_str}", flush=True)
            if is_ci:
                # In CI, stream lines as they arrive and remember them so that, on failure,
                # only the part of `output` that was not already streamed gets printed.
                streamed_output = []

                def print_line(line):
                    streamed_output.append(line)
                    print(line, end="", flush=True)

                rc, output = run_checker(name, fix=args.fix, line_callback=print_line)
                if rc != 0 and output:
                    streamed_text = "".join(streamed_output)
                    if output.startswith(streamed_text):
                        _print_output(output[len(streamed_text) :])
                    elif output != streamed_text:
                        _print_output(output)
            else:
                rc, output = run_checker(name, fix=args.fix)
                if rc == 0:
                    # On success only the last ten lines of output are shown.
                    tail = output.splitlines()[-10:]
                    if tail:
                        print("\n".join(tail), flush=True)
                else:
                    _print_output(output)
            elapsed = time.perf_counter() - checker_start
            status = "OK" if rc == 0 else "FAILED"
            print(f"{status} ({format_elapsed(elapsed)})", flush=True)
            print(flush=True)
            if rc == 0 and cache is not None:
                cache.update(name)
            elif rc != 0:
                if cache is not None:
                    cache.invalidate(name)
                failures.append(name)
                if not args.keep_going:
                    if cache is not None:
                        cache.save()
                    sys.exit(1)

    if cache is not None:
        cache.save()

    if failures:
        print(f"\n{len(failures)} failed: {', '.join(failures)}", flush=True)
        sys.exit(1)

    total_elapsed = format_elapsed(time.perf_counter() - total_start)
    passed = len(names) - skipped
    if skipped:
        print(f"\nAll {len(names)} checks passed in {total_elapsed} ({passed} ran, {skipped} cached).", flush=True)
    else:
        print(f"\nAll {len(names)} checks passed in {total_elapsed}.", flush=True)
# Run the CLI only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

217
utils/collated_reports.py Normal file
View File

@@ -0,0 +1,217 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import subprocess
from dataclasses import dataclass
from pathlib import Path
# Short GPU identifiers that `get_gpu_name` passes to `simplify_gpu_name` to collapse a full device name.
DEFAULT_GPU_NAMES = ["mi300", "mi325", "mi355", "h100", "a10"]
def simplify_gpu_name(gpu_name: str, simplified_names: list[str]) -> str:
    """Collapse `gpu_name` to a short name when exactly one candidate is a substring of it.

    Args:
        gpu_name: The full GPU name to simplify.
        simplified_names: Candidate short names, matched by substring containment.

    Returns:
        The single matching candidate, or `gpu_name` unchanged when zero or several candidates match.
    """
    hits = [candidate for candidate in simplified_names if candidate in gpu_name]
    return hits[0] if len(hits) == 1 else gpu_name
def parse_short_summary_line(line: str) -> tuple[str | None, int]:
if line.startswith("PASSED"):
return "passed", 1
if line.startswith("FAILED"):
return "failed", 1
if line.startswith("SKIPPED"):
line = line.split("[", maxsplit=1)[1]
line = line.split("]", maxsplit=1)[0]
return "skipped", int(line)
if line.startswith("ERROR"):
return "error", 1
return None, 0
def validate_path(p: str) -> Path:
    """Convert `p` to a `Path` and assert that it is an existing directory.

    Raises:
        AssertionError: If the path is not a directory.
    """
    candidate = Path(p)
    assert candidate.is_dir(), f"Path {candidate} is not a directory"
    return candidate
def get_gpu_name(gpu_name: str | None) -> str:
    """Return a normalised GPU name, detecting it through `torch` when none is given.

    The name is lower-cased, spaces become underscores, and it is then collapsed with
    `simplify_gpu_name` against `DEFAULT_GPU_NAMES`. The normalisation is applied on every path, so a
    name detected through `torch` and a name passed explicitly end up in the same form.

    Args:
        gpu_name: Explicit GPU name, or `None` to query `torch.cuda.get_device_name()`.

    Returns:
        The normalised name, or `"unknown"` when detection fails.
    """
    if gpu_name is None:
        try:
            import torch

            gpu_name = torch.cuda.get_device_name()
        except Exception as e:
            # Best effort: a missing torch install or GPU must not abort report collation.
            print(f"Failed to get GPU name with {e}")
            gpu_name = "unknown"
    gpu_name = gpu_name.replace(" ", "_").lower()
    return simplify_gpu_name(gpu_name, DEFAULT_GPU_NAMES)
def get_commit_hash(commit_hash: str | None) -> str:
# Get commit hash if available
if commit_hash is None:
try:
commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
except Exception as e:
print(f"Failed to get commit hash with {e}")
commit_hash = "unknown"
return commit_hash[:7]
@dataclass
class Args:
    """Resolved command-line arguments, as assembled by `get_arguments`."""

    # Reports folder; `get_arguments` passes the raw value through `validate_path`.
    path: Path
    # Value of `--machine-type`, forwarded unchanged.
    machine_type: str
    # Result of `get_gpu_name` on the `--gpu-name` value.
    gpu_name: str
    # Result of `get_commit_hash` on the `--commit-hash` value.
    commit_hash: str
    # Optional job name; the report is uploaded only when both `job` and `report_repo_id` are set.
    job: str | None
    # Optional repository ID used as the upload target.
    report_repo_id: str | None
def get_arguments(args: argparse.Namespace) -> Args:
    """Turn the parsed command-line namespace into a resolved `Args` record.

    The path is validated, and the GPU name and commit hash are resolved through their helpers;
    the remaining values are copied over unchanged.
    """
    return Args(
        path=validate_path(args.path),
        machine_type=args.machine_type,
        gpu_name=get_gpu_name(args.gpu_name),
        commit_hash=get_commit_hash(args.commit_hash),
        job=args.job,
        report_repo_id=args.report_repo_id,
    )
def upload_collated_report(job: str, report_repo_id: str, filename: str):
    """Upload the collated report file to a dataset repository.

    The destination folder is the date part of the workflow run's `created_at` value. When the
    `GITHUB_EVENT_NAME` environment variable is not `schedule`, a
    `runs/<GITHUB_RUN_NUMBER>-<GITHUB_RUN_ID>` subfolder is appended.

    Args:
        job: Job name, used in the `ci_results_<job>` path component.
        report_repo_id: Target repository ID.
        filename: Local file to upload; also used as the file name in the repository.
    """
    # Alternatively we can check for the existence of the collated_reports file and upload in notification_service.py
    import os

    from get_previous_daily_ci import get_last_daily_ci_run
    from huggingface_hub import HfApi

    api = HfApi()

    # Non-scheduled runs go to their own subfolder.
    is_scheduled = os.getenv("GITHUB_EVENT_NAME") == "schedule"
    subfolder = "" if is_scheduled else f"runs/{os.getenv('GITHUB_RUN_NUMBER')}-{os.getenv('GITHUB_RUN_ID')}"

    workflow_run = get_last_daily_ci_run(
        token=os.environ["ACCESS_REPO_INFO_TOKEN"], workflow_run_id=os.getenv("GITHUB_RUN_ID")
    )
    # Keep only what precedes the "T" in the `created_at` value.
    date_folder = workflow_run["created_at"].split("T")[0]
    target_folder = f"{date_folder}/{subfolder}" if subfolder else date_folder

    api.upload_file(
        path_or_fileobj=f"{filename}",
        path_in_repo=f"{target_folder}/ci_results_{job}/{filename}",
        repo_id=report_repo_id,
        repo_type="dataset",
        token=os.getenv("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN"),
    )
if __name__ == "__main__":
    # Command-line interface.
    parser = argparse.ArgumentParser(description="Post process models test reports.")
    parser.add_argument("--path", "-p", help="Path to the reports folder")
    parser.add_argument(
        "--machine-type", "-m", help="Process single or multi GPU results", choices=["single-gpu", "multi-gpu"]
    )
    parser.add_argument("--gpu-name", "-g", help="GPU name", default=None)
    parser.add_argument("--commit-hash", "-c", help="Commit hash", default=None)
    parser.add_argument("--job", "-j", help="Optional job name required for uploading reports", default=None)
    parser.add_argument(
        "--report-repo-id", "-r", help="Optional report repository ID required for uploading reports", default=None
    )
    args = get_arguments(parser.parse_args())

    path, machine_type = args.path, args.machine_type
    gpu_name, commit_hash = args.gpu_name, args.commit_hash
    job, report_repo_id = args.job, args.report_repo_id

    # Accumulators for the collated report. The `None` key counts unrecognised summary lines
    # (their count is always 0).
    total_status_count = dict.fromkeys(("passed", "failed", "skipped", "error", None), 0)
    collated_report_buffer = []

    # One entry per report directory whose name starts with the requested machine type.
    for model_dir in sorted(path.iterdir()):
        if not model_dir.name.startswith(machine_type):
            continue

        model_name = model_dir.name.split("models_")[-1].removesuffix("_test_reports")

        with open(model_dir / "summary_short.txt", "r") as f:
            short_summary_lines = f.readlines()

        # The first line of the summary is skipped.
        results = []
        for line in short_summary_lines[1:]:
            status, count = parse_short_summary_line(line)
            total_status_count[status] += count
            if not status:
                continue
            test_id = line.split(status.upper(), maxsplit=1)[1].strip()
            results.append({"status": status, "test": test_id, "count": count})

        collated_report_buffer.append({"model": model_name, "results": results})

    filename = f"collated_reports_{machine_type}_{commit_hash}.json"

    # Write collated report
    payload = {
        "gpu_name": gpu_name,
        "machine_type": machine_type,
        "commit_hash": commit_hash,
        "total_status_count": total_status_count,
        "results": collated_report_buffer,
    }
    with open(filename, "w") as f:
        json.dump(payload, f, indent=2)

    # Upload only when both the job name and the target repository are known.
    if job and report_repo_id:
        upload_collated_report(job, report_repo_id, filename)

View File

@@ -0,0 +1,91 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def normalize_test_line(line):
    """Reduce a summary line to a stable form so that two runs can be compared.

    - `SKIPPED`/`XFAIL`/`XPASS`/`EXPECTEDFAIL` lines become `"<STATUS> <path>:<lineno>"`, dropping the
      optional bracketed count and anything after the location.
    - `ERROR`/`FAILED` lines are cut at the first ` - ` separator, dropping the message.
    - Any other line is returned stripped.
    """
    stripped = line.strip()

    located = re.match(r"^(SKIPPED|XFAIL|XPASS|EXPECTEDFAIL)\s+\[?\d*\]?\s*(\S+:\d+)", stripped)
    if located is not None:
        return f"{located.group(1)} {located.group(2)}"

    if stripped.startswith(("ERROR", "FAILED")):
        return re.split(r"\s+-\s+", stripped)[0].strip()

    return stripped
def parse_summary_file(file_path):
    """Collect the normalised test lines found in a summary file.

    Lines starting with `===` (after stripping) act as toggles: collection starts at the first such
    line and flips at every following one. Non-empty lines met while collecting are normalised with
    `normalize_test_line`.

    Returns:
        The set of normalised lines.
    """
    collected = set()
    inside_summary = False
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text.startswith("==="):
                inside_summary = not inside_summary
            elif inside_summary and text:
                collected.add(normalize_test_line(text))
    return collected
def compare_job_sets(job_set1, job_set2):
    """Build a text report of the summary lines that differ between two sets of jobs.

    Args:
        job_set1: Mapping of job name to summary file path for the reference run.
        job_set2: Mapping of job name to summary file path for the current run.

    Returns:
        A report with one section per job that has differences, or `"No differences found."`.
        A job missing from one mapping (or mapped to a falsy path) is treated as having no tests there.
    """
    report_lines = []
    for job_name in sorted(set(job_set1) | set(job_set2)):
        previous_file = job_set1.get(job_name)
        current_file = job_set2.get(job_name)
        previous_tests = parse_summary_file(previous_file) if previous_file else set()
        current_tests = parse_summary_file(current_file) if current_file else set()

        added = current_tests - previous_tests
        removed = previous_tests - current_tests
        if not (added or removed):
            continue

        report_lines.append(f"=== Diff for job: {job_name} ===")
        if removed:
            report_lines.append("--- Absent in current run:")
            report_lines.extend(f" - {test}" for test in sorted(removed))
        if added:
            report_lines.append("+++ Appeared in current run:")
            report_lines.extend(f" + {test}" for test in sorted(added))
        report_lines.append("")  # blank line between job sections

    return "\n".join(report_lines) if report_lines else "No differences found."
# Example usage:
# job_set_1 = {
# "albert": "prev/multi-gpu_run_models_gpu_models/albert_test_reports/summary_short.txt",
# "bloom": "prev/multi-gpu_run_models_gpu_models/bloom_test_reports/summary_short.txt",
# }
# job_set_2 = {
# "albert": "curr/multi-gpu_run_models_gpu_models/albert_test_reports/summary_short.txt",
# "bloom": "curr/multi-gpu_run_models_gpu_models/bloom_test_reports/summary_short.txt",
# }
# report = compare_job_sets(job_set_1, job_set_2)
# print(report)

View File

@@ -0,0 +1,116 @@
import ast
import re
from collections import defaultdict
# Function to perform topological sorting
def topological_sort(dependencies: dict) -> list[list[str]]:
    """Given the dependencies graph, construct a sorted list of list of modular files.

    Args:
        dependencies (`dict`):
            Maps each modular file path (containing `modular_<model>.py`) to the dotted module names it
            imports. The model name of a dependency is its second-to-last dotted component.

    Returns:
        The modular files grouped by level; every file only depends on files in earlier levels.
        Files inside a level are sorted alphabetically so the result is reproducible.

    Raises:
        ValueError: If the modular files depend on each other in a cycle, which would otherwise make
            the level-by-level peeling loop forever.

    Examples:

        The returned list of lists might be:
        [
            ["../modular_mistral.py", "../modular_gemma.py"],  # level 0
            ["../modular_llama4.py", "../modular_gemma2.py"],  # level 1
            ["../modular_glm4.py"],                            # level 2
        ]
        which means mistral and gemma do not depend on any other modular models, while llama4 and gemma2
        depend on the models in the first list, and glm4 depends on the models in the second and (optionally)
        in the first list.
    """

    def _model_name(modular_path: str) -> str:
        return modular_path.rsplit("modular_", 1)[1].replace(".py", "")

    # Nodes are the name of the models to convert (we only add those to the graph)
    nodes = {_model_name(node) for node in dependencies}

    # Graph from each model to convert, to the models that must be converted before it
    graph = {}
    name_mapping = {}
    for node, deps in dependencies.items():
        node_name = _model_name(node)
        dep_names = {dep.split(".")[-2] for dep in deps}
        # Only keep dependencies that are themselves being converted, and ignore self-references
        graph[node_name] = {dep for dep in dep_names if dep in nodes and dep != node_name}
        name_mapping[node_name] = node

    sorting_list = []
    while graph:
        # Find the nodes with 0 out-degree
        leaf_nodes = {node for node in graph if not graph[node]}
        if not leaf_nodes:
            # Every remaining node still waits on another one: there is a cycle.
            remaining = ", ".join(sorted(name_mapping[node] for node in graph))
            raise ValueError(f"Circular dependency detected between modular files: {remaining}")
        # Add them to the list as next level
        sorting_list.append(sorted(name_mapping[node] for node in leaf_nodes))
        # Remove the leaves from the graph (and from the deps of other nodes)
        graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}

    return sorting_list
# All the model file types that may be imported in modular files.
# `is_model_import` joins these into a regex alternation to recognise `<model>.<file_type>_<name>` modules.
ALL_FILE_TYPES = (
    "modeling",
    "configuration",
    "tokenization",
    "processing",
    "image_processing",
    "video_processing",
    "feature_extraction",
)
def is_model_import(module: str | None) -> bool:
    """Check whether `module` is a model import or not.

    A module counts as a model import when it contains `<model>.<file_type>_<suffix>` with `<file_type>`
    one of `ALL_FILE_TYPES`, `<model>` a substring of `<suffix>`, and `<model>` different from `auto`.
    """
    # Happens for fully relative import, i.e. `from ... import initialization as init`
    if module is None:
        return False

    alternatives = "|".join(ALL_FILE_TYPES)
    found = re.search(rf"(\w+)\.(?:{alternatives})_(\w+)", module)
    if found is None:
        return False

    model_name, file_suffix = found.groups()
    return model_name in file_suffix and model_name != "auto"
def extract_model_imports_from_file(file_path):
    """Return the set of model-specific modules imported with `from ... import ...` in `file_path`.

    The file is parsed with `ast`; every `ImportFrom` node anywhere in the tree whose module satisfies
    `is_model_import` contributes its module name.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        tree = ast.parse(file.read(), filename=file_path)
    return {
        node.module
        for node in ast.walk(tree)
        if isinstance(node, ast.ImportFrom) and is_model_import(node.module)
    }
def find_priority_list(modular_files: list[str]) -> tuple[list[list[str]], dict[str, set]]:
    """
    Order a list of modular files topologically. Modular models that DON'T depend on other modular models come
    first.

    Args:
        modular_files (`list[str]`):
            List of paths to the modular files.

    Returns:
        A tuple `(ordered_files, dependencies)`.

        `ordered_files` is a list of lists holding the files at each level of the dependency graph, e.g.
        [
            ["../modular_mistral.py", "../modular_gemma.py"],  # level 0
            ["../modular_llama4.py", "../modular_gemma2.py"],  # level 1
            ["../modular_glm4.py"],                            # level 2
        ]
        meaning mistral and gemma depend on no other modular model, llama4 and gemma2 depend on models of
        level 0, and glm4 depends on models of level 1 (and optionally level 0).

        `dependencies` maps each modular file to the set of model modules it imports.
    """
    dependencies = defaultdict(set)
    for file_path in modular_files:
        dependencies[file_path] |= extract_model_imports_from_file(file_path)
    return topological_sort(dependencies), dependencies

1857
utils/create_dummy_models.py Normal file

File diff suppressed because it is too large Load Diff

338
utils/custom_init_isort.py Normal file
View File

@@ -0,0 +1,338 @@
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that sorts the imports in the custom inits of Transformers. Transformers uses init files that delay the
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
properly sort the second half, which looks like traditional imports; the goal of this script is to sort the first half.
Use from the root of the repo with:
```bash
python utils/custom_init_isort.py
```
which will auto-sort the imports (used in `make style`).
For a check only (as used in `make check-repo`) run:
```bash
python utils/custom_init_isort.py --check_only
```
"""
import argparse
import os
import re
from collections.abc import Callable
from typing import Any
# Declarative description of this script as a checker: a name, a display label, the file globs it
# covers, and the CLI arguments for check mode and fix mode.
# NOTE(review): presumably read by the checker runner (utils/checkers.py) — confirm against that file.
CHECKER_CONFIG = {
    "name": "init_isort",
    "label": "Import ordering",
    "cache_globs": ["src/transformers/**/__init__.py"],
    "check_args": ["--check_only"],
    "fix_args": [],
}
# Path is defined with the intent you should run this script from the root of the repo.
PATH_TO_TRANSFORMERS = "src/transformers"
# Pattern that captures the leading whitespace of a line containing at least one non-space character.
_re_indent = re.compile(r"^(\s*)\S")
# Pattern that matches `"key":` and captures `key` as its first group (`.groups()[0]`).
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
# Pattern that matches `_import_structure["key"]` and captures `key` as its first group (`.groups()[0]`).
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
# Pattern that matches a whole line of the form `"key",` and captures `key` as its first group (`.groups()[0]`).
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
# Pattern that matches any `[stuff]` and captures `stuff` as its first group (`.groups()[0]`).
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
def get_indent(line: str) -> str:
    """Return the leading whitespace of `line`, or `""` when the line has no non-space character."""
    found = _re_indent.search(line)
    if found is None:
        return ""
    return found.group(1)
def split_code_in_indented_blocks(
    code: str, indent_level: str = "", start_prompt: str | None = None, end_prompt: str | None = None
) -> list[str]:
    """
    Split some code into its indented blocks, starting at a given level.

    Args:
        code (`str`): The code to split.
        indent_level (`str`): The indent level (as string) to use for identifying the blocks to split.
        start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is.
        end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is.

    Warning:
        The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input `code`
        can thus be retrieved by joining the result.

    Returns:
        `List[str]`: The list of blocks.
    """
    # Let's split the code into lines and move to start_index.
    index = 0
    lines = code.split("\n")
    if start_prompt is not None:
        # NOTE: raises IndexError if no line starts with `start_prompt`.
        while not lines[index].startswith(start_prompt):
            index += 1
        # Everything before the start prompt is kept as a single, unsplit block.
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # This variable contains the block treated at a given time.
    current_block = [lines[index]]
    index += 1
    # We split into blocks until we get to the `end_prompt` (or the end of the file).
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        # We have a non-empty line with the proper indent -> start of a new block
        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
            # Store the current block in the result and rest. There are two cases: the line is part of the block (like
            # a closing parenthesis) or not.
            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
                # Line is part of the current block
                current_block.append(lines[index])
                blocks.append("\n".join(current_block))
                if index < len(lines) - 1:
                    # The next line opens the following block; consume it here (extra increment).
                    current_block = [lines[index + 1]]
                    index += 1
                else:
                    current_block = []
            else:
                # Line is not part of the current block
                blocks.append("\n".join(current_block))
                current_block = [lines[index]]
        else:
            # Just add the line to the current block
            current_block.append(lines[index])
        index += 1

    # Adds current block if it's nonempty.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Add final block after end_prompt if provided.
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks
def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
    """
    Wrap a sort key function so that its result is lower-cased and stripped of underscores.
    """

    def _normalized(item):
        lowered = key(item).lower()
        return lowered.replace("_", "")

    return _normalized
def sort_objects(objects: list[Any], key: Callable[[Any], str] | None = None) -> list[Any]:
    """
    Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
    last).

    Each object lands in exactly one group. An all-uppercase name that does not start with an uppercase letter
    (e.g. `_FOO`) is a constant and is not additionally listed with the functions.

    Args:
        objects (`List[Any]`):
            The list of objects to sort.
        key (`Callable[[Any], str]`, *optional*):
            A function taking an object as input and returning a string, used to sort them by alphabetical order.
            If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string).

    Returns:
        `List[Any]`: The sorted list with the same elements as in the inputs
    """

    # If no key is provided, we use a noop.
    def noop(x):
        return x

    if key is None:
        key = noop

    constants, classes, functions = [], [], []
    for obj in objects:
        name = key(obj)
        if name.isupper():
            # Constants are all uppercase, they go first.
            constants.append(obj)
        elif name[0].isupper():
            # Classes are not all uppercase but start with a capital, they go second.
            classes.append(obj)
        else:
            # Functions begin with a lowercase, they go last.
            functions.append(obj)

    # Then we sort each group.
    key1 = ignore_underscore_and_lowercase(key)
    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
def sort_objects_in_import(import_statement: str) -> str:
    """
    Sorts the imports in a single import statement.

    Three layouts are handled, chosen by the number of lines in the statement: more than three lines (one
    object per line), exactly three lines (all objects on the middle line), and anything shorter (objects
    inside brackets on the same line).

    Args:
        import_statement (`str`): The import statement in which to sort the imports.

    Returns:
        `str`: The same as the input, but with objects properly sorted.
    """

    # This inner function sort imports between [ ].
    def _replace(match):
        imports = match.groups()[0]
        # If there is one import only, nothing to do.
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # We will have a final empty element if the line finished with a comma.
        if len(keys[-1]) == 0:
            keys = keys[:-1]
        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"

    lines = import_statement.split("\n")
    if len(lines) > 3:
        # Here we have to sort internal imports that are on several lines (one per name):
        # key: [
        #     "object1",
        #     "object2",
        #     ...
        # ]

        # We may have to ignore one or two lines on each side.
        idx = 2 if lines[1].strip() == "[" else 1
        # Pair each object name with its position so the original lines can be picked back in sorted order.
        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
    elif len(lines) == 3:
        # Here we have to sort internal imports that are on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        if _re_bracket_content.search(lines[1]) is not None:
            lines[1] = _re_bracket_content.sub(_replace, lines[1])
        else:
            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
            # We will have a final empty element if the line finished with a comma.
            if len(keys[-1]) == 0:
                keys = keys[:-1]
            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
        return "\n".join(lines)
    else:
        # Finally we have to deal with imports fitting on one line
        import_statement = _re_bracket_content.sub(_replace, import_statement)
        return import_statement
def sort_imports(file: str, check_only: bool = True):
    """
    Sort the imports defined in the `_import_structure` of a given init.

    Args:
        file (`str`): The path to the init to check/fix.
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Returns:
        `True` when `check_only` is set and the file would be changed; `None` in every other case (including
        after the file has been rewritten).
    """
    with open(file, encoding="utf-8") as f:
        code = f.read()

    # If the file is not a custom init, there is nothing to do.
    if "_import_structure = {" not in code:
        return

    # Blocks of indent level 0
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
    for block_idx in range(1, len(main_blocks) - 1):
        # Check if the block contains some `_import_structure`s thingy to sort.
        block = main_blocks[block_idx]
        block_lines = block.split("\n")

        # Get to the start of the imports.
        line_idx = 0
        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
            # Skip dummy import blocks
            if "import dummy" in block_lines[line_idx]:
                # Jumping past the end makes the check below skip the whole block.
                line_idx = len(block_lines)
            else:
                line_idx += 1
        if line_idx >= len(block_lines):
            continue

        # Ignore beginning and last line: they don't contain anything.
        internal_block_code = "\n".join(block_lines[line_idx:-1])
        indent = get_indent(block_lines[1])
        # Split the internal block into blocks of indent level 1.
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
        # We have two categories of import key: list or _import_structure[key].append/extend
        pattern = _re_direct_key if "_import_structure = {" in block_lines[0] else _re_indirect_key
        # Grab the keys, but there is a trap: some lines are empty or just comments.
        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
        # We only sort the lines with a key.
        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
        count = 0
        reorderded_blocks = []
        for i in range(len(internal_blocks)):
            if keys[i] is None:
                reorderded_blocks.append(internal_blocks[i])
            else:
                # Each keyed slot receives the next block in sorted-key order, with its own objects sorted too.
                block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
                reorderded_blocks.append(block)
                count += 1

        # And we put our main block back together with its first and last line.
        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])

    if code != "\n".join(main_blocks):
        if check_only:
            return True
        else:
            print(f"Overwriting {file}.")
            with open(file, "w", encoding="utf-8") as f:
                f.write("\n".join(main_blocks))
def sort_imports_in_all_inits(check_only=True):
    """
    Sort the imports defined in the `_import_structure` of all inits in the repo.

    Args:
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Raises:
        ValueError: If at least one init would be modified. `sort_imports` only reports this in
            check-only mode.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            init_path = os.path.join(root, "__init__.py")
            if sort_imports(init_path, check_only=check_only):
                # Accumulate every offending init so the reported count is accurate.
                failures.append(init_path)
    if failures:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
if __name__ == "__main__":
    # Without `--check_only` the inits are rewritten in place; with it, a ValueError is raised
    # if any init would change.
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    args = parser.parse_args()

    sort_imports_in_all_inits(check_only=args.check_only)

377
utils/deprecate_models.py Normal file
View File

@@ -0,0 +1,377 @@
"""
Script which deprecates a list of given models
Example usage:
python utils/deprecate_models.py --models bert distilbert
"""
import argparse
import os
from collections import defaultdict
from pathlib import Path
import httpx
from custom_init_isort import sort_imports_in_all_inits
from git import Repo
from packaging import version
from transformers import CONFIG_MAPPING, logging
from transformers import __version__ as current_version
# Absolute path of the directory two levels above this file.
REPO_PATH = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
# GitPython handle on that directory; created at import time.
repo = Repo(REPO_PATH)
logger = logging.get_logger(__name__)
def get_last_stable_minor_release():
    """Return the highest PyPI release of `transformers` in the minor series preceding the current one.

    The current version is read from `transformers.__version__`; only its first two components are
    used, so both `X.Y.Z` and `X.Y.Z.dev0` style versions are accepted.

    Returns:
        The release string (e.g. `"4.39.3"` when the current version is `4.40.*`).

    Raises:
        ValueError: If PyPI lists no release for the previous minor series.
    """
    # Get the list of releases of transformers
    url = "https://pypi.org/pypi/transformers/json"
    release_data = httpx.get(url).json()

    # Find the last stable release of transformers (version below current version)
    major_version, minor_version = current_version.split(".")[:2]
    last_major_minor = f"{major_version}.{int(minor_version) - 1}"
    # Match on "X.Y." (or exactly "X.Y") so that e.g. "4.4" does not also select "4.40.x".
    series_prefix = f"{last_major_minor}."
    last_stable_minor_releases = [
        release
        for release in release_data["releases"]
        if release == last_major_minor or release.startswith(series_prefix)
    ]
    if not last_stable_minor_releases:
        raise ValueError(f"No release found on PyPI for the {last_major_minor} series.")
    return max(last_stable_minor_releases, key=version.parse)
def build_tip_message(last_stable_release):
    """Build the `<Tip warning={true}>` block announcing that a model is in maintenance mode.

    Args:
        last_stable_release: Version string interpolated into the reinstall instructions.

    Returns:
        The tip text, starting with a newline and ending with `</Tip>`.
    """
    # Only the second literal is an f-string; the first is plain so `{true}` needs no brace escaping.
    return (
        """
<Tip warning={true}>
This model is in maintenance mode only, we don't accept any new PRs changing its code.
"""
        + f"""If you run into any issues running this model, please reinstall the last version that supported this model: v{last_stable_release}.
You can do so by running the following command: `pip install -U transformers=={last_stable_release}`.
</Tip>"""
    )
def insert_tip_to_model_doc(model_doc_path, tip_message):
    """
    Insert the tip message into a model doc page, directly underneath its title.

    The title is the first line that starts with `"# "`. Later lines starting the same way (for instance comments
    inside code examples) are left untouched, so the tip is inserted exactly once.

    Args:
        model_doc_path (`str` or `os.PathLike`): Path of the doc page, which is rewritten in place.
        tip_message (`str`): The (possibly multi-line) message to insert.
    """
    tip_message_lines = tip_message.split("\n")

    with open(model_doc_path, "r") as f:
        model_doc = f.read()

    # Add the tip message to the model doc page directly underneath the title
    new_model_lines = []
    tip_inserted = False
    for line in model_doc.split("\n"):
        new_model_lines.append(line)
        if not tip_inserted and line.startswith("# "):
            new_model_lines.extend(tip_message_lines)
            tip_inserted = True

    with open(model_doc_path, "w") as f:
        f.write("\n".join(new_model_lines))
def get_model_doc_path(model: str) -> tuple[str | None, str | None]:
    """
    Locate the English doc page of a model.

    The name is tried as given, then with underscores replaced by dashes, then with underscores removed. The first
    variant for which `docs/source/en/model_doc/<variant>.md` exists is used.

    Args:
        model (`str`): The model name.

    Returns:
        The path of the doc page and the name variant that matched, or `(None, None)` if no page exists.
    """
    for candidate in (model, model.replace("_", "-"), model.replace("_", "")):
        candidate_path = REPO_PATH / f"docs/source/en/model_doc/{candidate}.md"
        if os.path.exists(candidate_path):
            return candidate_path, candidate
    return None, None
def extract_model_info(model):
    """
    Collect the paths needed to deprecate a model.

    Args:
        model (`str`): The model name, i.e. its folder name under `src/transformers/models`.

    Returns:
        A dict with the keys `"model_doc_path"`, `"model_doc_name"` and `"model_path"`, or `None` (after printing the
        reason) when either the doc page or the model folder does not exist.
    """
    doc_path, doc_name = get_model_doc_path(model)
    if doc_path is None:
        print(f"Model doc path does not exist for {model}")
        return None

    source_dir = REPO_PATH / f"src/transformers/models/{model}"
    if not os.path.exists(source_dir):
        print(f"Model path does not exist for {model}")
        return None

    return {
        "model_doc_path": doc_path,
        "model_doc_name": doc_name,
        "model_path": source_dir,
    }
def update_relative_imports(filename, model):
    """
    Add one level to the relative imports of a file, rewriting the file in place.

    On every line that starts with `from ..`, each `from ..` is replaced by `from ...`.

    Args:
        filename (`str` or `os.PathLike`): The file to rewrite.
        model (`str`): Not used by this function.
    """
    with open(filename, "r") as f:
        content = f.read()

    rewritten = [
        line.replace("from ..", "from ...") if line.startswith("from ..") else line for line in content.split("\n")
    ]

    with open(filename, "w") as f:
        f.write("\n".join(rewritten))
def remove_copied_from_statements(model):
    """
    Drop every line containing `# Copied from` from the files of a model, rewriting each file in place.

    Args:
        model (`str`): The model name, i.e. its folder name under `src/transformers/models`.
    """
    model_dir = REPO_PATH / f"src/transformers/models/{model}"
    for entry in os.listdir(model_dir):
        if entry == "__pycache__":
            continue

        entry_path = model_dir / entry
        with open(entry_path, "r") as f:
            source = f.read()

        kept_lines = [line for line in source.split("\n") if "# Copied from" not in line]

        with open(entry_path, "w") as f:
            f.write("\n".join(kept_lines))
def move_model_files_to_deprecated(model):
    """
    Move the files of a model to `src/transformers/models/deprecated/<model>` with `git mv`.

    Each moved file then has its relative imports updated through `update_relative_imports`.

    Args:
        model (`str`): The model name, i.e. its folder name under `src/transformers/models`.
    """
    source_dir = REPO_PATH / f"src/transformers/models/{model}"
    target_dir = REPO_PATH / f"src/transformers/models/deprecated/{model}"

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    entries = (entry for entry in os.listdir(source_dir) if entry != "__pycache__")
    for entry in entries:
        moved_file = f"{target_dir}/{entry}"
        repo.git.mv(f"{source_dir}/{entry}", moved_file)
        # For deprecated files, we then need to update the relative imports
        update_relative_imports(moved_file, model)
def delete_model_tests(model):
    """Remove `tests/models/<model>` with `git rm -r` if that folder exists."""
    tests_dir = REPO_PATH / f"tests/models/{model}"
    if not os.path.exists(tests_dir):
        return
    repo.git.rm("-r", tests_dir)
def get_line_indent(s):
    """Return the number of leading whitespace characters of `s`."""
    without_indent = s.lstrip()
    return len(s) - len(without_indent)
def update_main_init_file(models):
    """
    Point the references to the given models in `src/transformers/__init__.py` to `models.deprecated`.

    Both `models.<model>"` and `models.<model> import` are rewritten, after which the imports of all inits are
    re-sorted.

    Args:
        models (List[str]): The models to mark as deprecated
    """
    init_path = REPO_PATH / "src/transformers/__init__.py"
    with open(init_path, "r") as f:
        content = f.read()

    # 1. For each model, find all the instances of model.model_name and replace with model.deprecated.model_name
    for model in models:
        replacements = (
            (f'models.{model}"', f'models.deprecated.{model}"'),
            (f"models.{model} import", f"models.deprecated.{model} import"),
        )
        for old, new in replacements:
            content = content.replace(old, new)

    with open(init_path, "w") as f:
        f.write(content)

    # 2. Resort the imports
    sort_imports_in_all_inits(check_only=False)
def remove_model_references_from_file(filename, models, condition):
    """
    Drop from a file every line that references one of the given models, rewriting the file in place.

    Args:
        filename (str): The file to remove the references from, relative to the repository root
        models (List[str]): The models to remove
        condition (Callable): A function that takes the line and model and returns True if the line should be removed
    """
    target = REPO_PATH / filename
    with open(target, "r") as f:
        content = f.read()

    kept_lines = [line for line in content.split("\n") if not any(condition(line, model) for model in models)]

    with open(target, "w") as f:
        f.write("\n".join(kept_lines))
def remove_model_config_classes_from_config_check(model_config_classes):
    """
    Remove the deprecated model config classes from the check_config_attributes.py file

    The file is scanned line by line and rewritten in place. Every line that mentions one of the config classes is
    dropped. Inside a `SPECIAL_CASES_TO_ALLOW = {` / `SPECIAL_CASES_TO_ALLOW.update(` section, the comment lines
    directly above such a line are dropped too, and when the line opens a multi-line list (ends with `[`) the
    following lines are dropped up to and including the one that closes it.

    Args:
        model_config_classes (List[str]): The model config classes to remove e.g. ["BertConfig", "DistilBertConfig"]
    """
    filename = REPO_PATH / "utils/check_config_attributes.py"
    with open(filename, "r") as f:
        check_config_attributes = f.read()

    # Keep track as we have to delete comment above too
    in_special_cases_to_allow = False
    # True while skipping the remaining lines of a multi-line list that belongs to a removed entry
    in_indent = False
    new_file_lines = []

    for line in check_config_attributes.split("\n"):
        indent = get_line_indent(line)
        if (line.strip() == "SPECIAL_CASES_TO_ALLOW = {") or (line.strip() == "SPECIAL_CASES_TO_ALLOW.update("):
            in_special_cases_to_allow = True

        # An unindented closing bracket ends the section
        elif in_special_cases_to_allow and indent == 0 and line.strip() in ("}", ")"):
            in_special_cases_to_allow = False

        if in_indent:
            # The line closing the list is the last one skipped for the removed entry
            if line.strip().endswith(("]", "],")):
                in_indent = False
            continue

        if in_special_cases_to_allow and any(
            model_config_class in line for model_config_class in model_config_classes
        ):
            # Remove comments above the model config class to remove
            while new_file_lines[-1].strip().startswith("#"):
                new_file_lines.pop()

            # The entry spans several lines: keep skipping until its list is closed
            if line.strip().endswith("["):
                in_indent = True

            continue

        # Outside the section, only the line itself is dropped
        elif any(model_config_class in line for model_config_class in model_config_classes):
            continue

        new_file_lines.append(line)

    with open(filename, "w") as f:
        f.write("\n".join(new_file_lines))
def add_models_to_deprecated_models_in_config_auto(models):
    """
    Add the models to the DEPRECATED_MODELS list in configuration_auto.py and sorts the list
    to be in alphabetical order.

    The file is rewritten in place. The existing entries are collected as raw lines, so the sort is a plain string
    sort over those lines together with the newly formatted entries.

    Args:
        models (List[str]): The models to add to the list
    """
    filepath = REPO_PATH / "src/transformers/models/auto/configuration_auto.py"
    with open(filepath, "r") as f:
        config_auto = f.read()

    new_file_lines = []
    # Raw lines found between `DEPRECATED_MODELS = [` and its closing `]`
    deprecated_models_list = []
    in_deprecated_models = False
    for line in config_auto.split("\n"):
        if line.strip() == "DEPRECATED_MODELS = [":
            in_deprecated_models = True
            new_file_lines.append(line)
        elif in_deprecated_models and line.strip() == "]":
            in_deprecated_models = False
            # Add the new models to deprecated models list
            deprecated_models_list.extend([f' "{model}", ' for model in models])
            # Sort so they're in alphabetical order in the file
            deprecated_models_list = sorted(deprecated_models_list)
            new_file_lines.extend(deprecated_models_list)
            # Make sure we still have the closing bracket
            new_file_lines.append(line)
        elif in_deprecated_models:
            # Held back (not written yet) so it can be sorted together with the new entries
            deprecated_models_list.append(line)
        else:
            new_file_lines.append(line)

    with open(filepath, "w") as f:
        f.write("\n".join(new_file_lines))
def deprecate_models(models):
    """
    Deprecate the given models.

    A model is skipped when its doc page or its source folder cannot be found, or when no config class is registered
    for it in `CONFIG_MAPPING`. For every remaining model, a maintenance tip is added to its doc page, the
    `# Copied from` statements are removed from its files, the files are moved to `models/deprecated`, and its tests
    are deleted. The config checks, the main `__init__.py`, the models `__init__.py`, the doctest lists and
    `DEPRECATED_MODELS` are then updated for all of these models at once.

    Args:
        models (List[str]): The names of the models to deprecate
    """
    # Get model info
    skipped_models = []
    models_info = {}
    for model in models:
        single_model_info = extract_model_info(model)
        if single_model_info is None:
            skipped_models.append(model)
        else:
            models_info[model] = single_model_info

    model_config_classes = []
    for model, model_info in models_info.items():
        if model in CONFIG_MAPPING:
            model_config_classes.append(CONFIG_MAPPING[model].__name__)
        elif model_info["model_doc_name"] in CONFIG_MAPPING:
            model_config_classes.append(CONFIG_MAPPING[model_info["model_doc_name"]].__name__)
        else:
            skipped_models.append(model)
            print(f"Model config class not found for model: {model}")

    # Filter out skipped models
    models = [model for model in models if model not in skipped_models]
    # Also drop them from `models_info`: a model without a config class must not have its files moved or its tests
    # deleted by the per-model loop below
    models_info = {model: model_info for model, model_info in models_info.items() if model not in skipped_models}

    if skipped_models:
        print(f"Skipped models: {skipped_models} as the model doc or model path could not be found.")
    print(f"Models to deprecate: {models}")

    # Remove model config classes from config check
    print("Removing model config classes from config checks")
    remove_model_config_classes_from_config_check(model_config_classes)

    tip_message = build_tip_message(get_last_stable_minor_release())

    for model, model_info in models_info.items():
        print(f"Processing model: {model}")
        # Add the tip message to the model doc page directly underneath the title
        print("Adding tip message to model doc page")
        insert_tip_to_model_doc(model_info["model_doc_path"], tip_message)

        # Remove #Copied from statements from model's files
        print("Removing #Copied from statements from model's files")
        remove_copied_from_statements(model)

        # Move the model file to deprecated: src/transformers/models/model -> src/transformers/models/deprecated/model
        print("Moving model files to deprecated for model")
        move_model_files_to_deprecated(model)

        # Delete the model tests: tests/models/model
        print("Deleting model tests")
        delete_model_tests(model)

    # We do the following with all models passed at once to avoid having to re-write the file multiple times
    print("Updating __init__.py file to point to the deprecated models")
    update_main_init_file(models)

    # Remove model references from other files
    print("Removing model references from other files")
    remove_model_references_from_file(
        "src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",")
    )
    remove_model_references_from_file(
        "utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line
    )
    remove_model_references_from_file("utils/not_doctested.txt", models, lambda line, model: "/" + model + "/" in line)

    # Add models to DEPRECATED_MODELS in the configuration_auto.py
    print("Adding models to DEPRECATED_MODELS in configuration_auto.py")
    add_models_to_deprecated_models_in_config_auto(models)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Without `required=True`, omitting the flag leaves `args.models` as None and `deprecate_models` fails when it
    # tries to iterate over it
    parser.add_argument("--models", nargs="+", required=True, help="List of models to deprecate")
    args = parser.parse_args()
    deprecate_models(args.models)

160
utils/download_glue_data.py Normal file
View File

@@ -0,0 +1,160 @@
"""Script for downloading all GLUE data.
Original source: https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
Note: for legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized,
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
You should then rename and place specific files in a folder (see below for an example).
mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
"""
import argparse
import os
import sys
import urllib.request
import zipfile
# Task names accepted by `--tasks`; "all" selects every entry of this list
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
# Download URL of each task. The "MRPC" entry is the file with the dev ids (used by `format_mrpc`) and the
# "diagnostic" entry is a single TSV file (used by `download_diagnostic`); `download_and_extract` treats the
# remaining entries as zip archives.
TASK2PATH = {
    "CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",
    "SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",
    "MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",
    "QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5",
    "STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",
    "MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
    "SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",
    "QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",
    "RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",
    "WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",
    "diagnostic": "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",
}

# MRPC train/test files, downloaded by `format_mrpc` when no local copy is given
MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
def download_and_extract(task, data_dir):
    """
    Download the zip archive of a task and extract it into `data_dir`.

    The archive is stored as `<task>.zip` in the current working directory while it is processed and is removed
    afterwards, also when the download or the extraction fails.

    Args:
        task (`str`): A key of `TASK2PATH` whose URL points to a zip archive.
        data_dir (`str`): The directory to extract the archive into.
    """
    print(f"Downloading and extracting {task}...")
    task_url = TASK2PATH[task]
    data_file = f"{task}.zip"
    try:
        urllib.request.urlretrieve(task_url, data_file)
        with zipfile.ZipFile(data_file) as zip_ref:
            zip_ref.extractall(data_dir)
    finally:
        # Do not leave a partial or corrupt archive behind when something goes wrong
        if os.path.exists(data_file):
            os.remove(data_file)
    print("\tCompleted!")
def format_mrpc(data_dir, path_to_data):
    """
    Build the MRPC `train.tsv`, `dev.tsv` and `test.tsv` files under `<data_dir>/MRPC`.

    The raw `msr_paraphrase_train.txt` / `msr_paraphrase_test.txt` files are read from `path_to_data` when it is
    given, and downloaded otherwise. The dev ids are always downloaded; rows of the raw train file whose id pair is
    listed there go to `dev.tsv`, all other rows go to `train.tsv`. `test.tsv` gets a running index in place of the
    label column.

    Args:
        data_dir (`str`): The directory in which the `MRPC` folder is created.
        path_to_data (`str`): Directory holding the raw MRPC files, or an empty string to download them.

    Raises:
        ValueError: If the raw train or test file is missing, or if a data row does not have exactly five
            tab-separated fields.
    """
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    os.makedirs(mrpc_dir, exist_ok=True)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        print(f"Local MRPC data not specified, downloading data from {MRPC_TRAIN}")
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    if not os.path.isfile(mrpc_train_file):
        raise ValueError(f"Train data not found at {mrpc_train_file}")
    if not os.path.isfile(mrpc_test_file):
        raise ValueError(f"Test data not found at {mrpc_test_file}")

    dev_ids_file = os.path.join(mrpc_dir, "dev_ids.tsv")
    urllib.request.urlretrieve(TASK2PATH["MRPC"], dev_ids_file)

    # A set of tuples makes the membership test below O(1) per row instead of a scan over all dev ids
    with open(dev_ids_file, encoding="utf8") as ids_fh:
        dev_ids = {tuple(row.strip().split("\t")) for row in ids_fh}

    with (
        open(mrpc_train_file, encoding="utf8") as data_fh,
        open(os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8") as train_fh,
        open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh,
    ):
        # Both splits start with the header of the raw file
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split("\t")
            split_fh = dev_fh if (id1, id2) in dev_ids else train_fh
            split_fh.write(f"{label}\t{id1}\t{id2}\t{s1}\t{s2}\n")

    with (
        open(mrpc_test_file, encoding="utf8") as data_fh,
        open(os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8") as test_fh,
    ):
        # The header of the raw file is replaced: the test split has an index column instead of the label
        data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            _, id1, id2, s1, s2 = row.strip().split("\t")
            test_fh.write(f"{idx}\t{id1}\t{id2}\t{s1}\t{s2}\n")
    print("\tCompleted!")
def download_diagnostic(data_dir):
    """
    Download the diagnostic TSV file to `<data_dir>/diagnostic/diagnostic.tsv`.

    Args:
        data_dir (`str`): The directory in which the `diagnostic` folder is created if needed.
    """
    print("Downloading and extracting diagnostic...")
    diagnostic_dir = os.path.join(data_dir, "diagnostic")
    if not os.path.isdir(diagnostic_dir):
        os.mkdir(diagnostic_dir)
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], os.path.join(diagnostic_dir, "diagnostic.tsv"))
    print("\tCompleted!")
def get_tasks(task_names):
    """
    Turn a comma separated string of task names into a list of tasks.

    Args:
        task_names (`str`): Comma separated task names; if one of them is `all`, every task is selected.

    Returns:
        `list[str]`: `TASKS` when `all` was requested, otherwise the requested names in the given order.

    Raises:
        ValueError: If a requested name is not in `TASKS`.
    """
    requested = task_names.split(",")
    if "all" in requested:
        return TASKS

    for name in requested:
        if name not in TASKS:
            raise ValueError(f"Task {name} not found!")
    return requested
def main(arguments):
    """
    Download the requested GLUE tasks into the data directory.

    Args:
        arguments (`list[str]`): The command line arguments, without the program name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", help="directory to save data to", type=str, default="glue_data")
    parser.add_argument(
        "--tasks", help="tasks to download data for as a comma separated string", type=str, default="all"
    )
    parser.add_argument(
        "--path_to_mrpc",
        help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt",
        type=str,
        default="",
    )
    args = parser.parse_args(arguments)

    # `makedirs` also creates the missing parents of a nested --data_dir, which `mkdir` cannot do
    os.makedirs(args.data_dir, exist_ok=True)
    tasks = get_tasks(args.tasks)

    # MRPC and the diagnostic set are not distributed as zip archives and have their own handlers
    for task in tasks:
        if task == "MRPC":
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == "diagnostic":
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)
if __name__ == "__main__":
    # `main` has no return statement, so the process exits with status 0 unless an exception is raised
    sys.exit(main(sys.argv[1:]))

69
utils/extract_metadata.py Executable file
View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract metadata from setup.py for CI testing.
Usage:
python utils/extract_metadata.py extras # List all extras (one per line)
python utils/extract_metadata.py python-versions # Output JSON array of Python versions
"""
import json
import sys
from pathlib import Path
from types import ModuleType
def get_setup_module() -> ModuleType:
    """Return the `setup` module, imported after putting the directory two levels above this file on `sys.path`."""
    repo_root: Path = Path(__file__).parents[1]
    # Prepended so that it takes precedence over any other `setup` module on the path
    sys.path.insert(0, str(repo_root))

    import setup as setup_module

    return setup_module
def extract_extras() -> None:
    """Print the name of every extra declared in `setup.extras`, one per line, in definition order."""
    setup_module: ModuleType = get_setup_module()
    for extra_name in setup_module.extras.keys():
        print(extra_name)
def extract_python_versions() -> None:
    """Print a JSON array with every `3.<minor>` version in the inclusive range `setup.SUPPORTED_PYTHON_VERSIONS`."""
    setup_module: ModuleType = get_setup_module()
    lowest_minor, highest_minor = setup_module.SUPPORTED_PYTHON_VERSIONS
    supported: list[str] = [f"3.{minor}" for minor in range(lowest_minor, highest_minor + 1)]
    print(json.dumps(supported))
if __name__ == "__main__":
    usage = "Usage: python utils/extract_metadata.py {extras|python-versions}"

    if len(sys.argv) < 2:
        print(usage, file=sys.stderr)
        sys.exit(1)

    command: str = sys.argv[1]
    # Map each sub-command to the function that implements it
    handlers = {"extras": extract_extras, "python-versions": extract_python_versions}
    handler = handlers.get(command)
    if handler is None:
        print(f"Unknown command: {command}", file=sys.stderr)
        print(usage, file=sys.stderr)
        sys.exit(1)

    handler()

View File

@@ -0,0 +1,30 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Used by `.github/workflows/trigger_circleci.yml` to get the pull request number in CircleCI job runs."""
import os
if __name__ == "__main__":
    # First choice: the last path segment of the pull request URL
    pull_request_url = os.environ.get("CIRCLE_PULL_REQUEST", "")
    pr_number = pull_request_url.split("/")[-1] if pull_request_url else ""

    # Fallback: a branch named `pull/<number>[/...]`
    if not pr_number:
        branch = os.environ.get("CIRCLE_BRANCH", "")
        if branch.startswith("pull/"):
            pr_number = branch.split("/")[1]

    # An empty line is printed when no number could be determined
    print(pr_number)

134
utils/extract_warnings.py Normal file
View File

@@ -0,0 +1,134 @@
import argparse
import json
import os
import time
import zipfile
from get_ci_error_statistics import download_artifact, get_artifacts_links
from transformers import logging
logger = logging.get_logger(__name__)
def extract_warnings_from_single_artifact(artifact_path, targets):
    """
    Extract the targeted warnings from the `warnings.txt` file of a single artifact.

    The module-level flag `from_gh` (assigned in the `__main__` section of this script) selects how the artifact is
    read: when it is true, `artifact_path` is a directory that contains `warnings.txt`; otherwise it is a `.zip`
    file with a `warnings.txt` member. An archive that cannot be read is logged and skipped.

    Args:
        artifact_path (`str`): Path of the artifact directory or zip file.
        targets (`list[str]`): Warning categories to keep; a warning is kept if it contains `": <target>: "`.

    Returns:
        `set[str]`: The selected warnings, each with its stripped body lines joined by newlines.
    """
    selected_warnings = set()
    buffer = []

    def flush_buffer():
        """Move the warning collected in `buffer`, if any, to `selected_warnings` when it matches a target."""
        if not buffer:
            return
        warning = "\n".join(buffer)
        # Only keep the warnings specified in `targets`
        if any(f": {x}: " in warning for x in targets):
            selected_warnings.add(warning)
        buffer.clear()

    def parse_line(fp):
        for line in fp:
            if isinstance(line, bytes):
                line = line.decode("UTF-8")
            if "warnings summary (final)" in line:
                continue
            # This means we are outside the body of a warning
            if not line.startswith(" "):
                flush_buffer()
                continue
            buffer.append(line.strip())
        # A warning that runs up to the end of the file has no following line to terminate it
        flush_buffer()

    if from_gh:
        for filename in os.listdir(artifact_path):
            file_path = os.path.join(artifact_path, filename)
            if filename != "warnings.txt" or os.path.isdir(file_path):
                continue
            with open(file_path) as fp:
                parse_line(fp)
    else:
        try:
            with zipfile.ZipFile(artifact_path) as z:
                for filename in z.namelist():
                    # Archive member names are not paths on the local disk, so only the name is checked here
                    if filename != "warnings.txt":
                        continue
                    with z.open(filename) as fp:
                        parse_line(fp)
        except Exception:
            logger.warning(
                f"{artifact_path} is either an invalid zip file or something else wrong. This file is skipped."
            )

    return selected_warnings
def extract_warnings(artifact_dir, targets):
    """
    Extract the targeted warnings from every artifact found in a directory.

    Entries ending in `.zip` are processed; when the module-level flag `from_gh` is true, all entries are.

    Args:
        artifact_dir (`str`): The directory that contains the artifacts.
        targets (`list[str]`): Warning categories to keep.

    Returns:
        `set[str]`: The union of the warnings selected from each artifact.
    """
    artifact_paths = []
    for entry in os.listdir(artifact_dir):
        if entry.endswith(".zip") or from_gh:
            artifact_paths.append(os.path.join(artifact_dir, entry))

    collected = set()
    for artifact_path in artifact_paths:
        collected |= extract_warnings_from_single_artifact(artifact_path, targets)
    return collected
if __name__ == "__main__":

    def list_str(values):
        """Split a comma-separated command line value into a list of strings."""
        return values.split(",")

    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Where to store the downloaded artifacts and other result files.",
    )
    parser.add_argument("--token", default=None, type=str, help="A token that has actions:read permission.")
    # optional parameters
    parser.add_argument(
        "--targets",
        default="DeprecationWarning,UserWarning,FutureWarning",
        type=list_str,
        help="Comma-separated list of target warning(s) which we want to extract.",
    )
    parser.add_argument(
        "--from_gh",
        action="store_true",
        help="If running from a GitHub action workflow and collecting warnings from its artifacts.",
    )

    args = parser.parse_args()

    # Module-level flag read by the extraction functions defined in this file
    from_gh = args.from_gh

    # With `--from_gh`, the artifacts have to be downloaded beforehand using `actions/download-artifact@v4`
    if not from_gh:
        os.makedirs(args.output_dir, exist_ok=True)

        # get download links
        artifacts = get_artifacts_links(args.workflow_run_id, token=args.token)
        artifacts_file = os.path.join(args.output_dir, "artifacts.json")
        with open(artifacts_file, "w", encoding="UTF-8") as fp:
            json.dump(artifacts, fp, ensure_ascii=False, indent=4)

        # download artifacts
        for name, url in artifacts.items():
            print(name)
            print(url)
            print("=" * 80)
            download_artifact(name, url, args.output_dir, args.token)
            # Be gentle to GitHub
            time.sleep(1)

    # extract warnings from artifacts
    selected_warnings = sorted(extract_warnings(args.output_dir, args.targets))
    warnings_file = os.path.join(args.output_dir, "selected_warnings.json")
    with open(warnings_file, "w", encoding="UTF-8") as fp:
        json.dump(selected_warnings, fp, ensure_ascii=False, indent=4)

View File

@@ -0,0 +1,382 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script downloads files from the HuggingFace Hub to be used for CI tests.
"""
import os
import re
# Ensure we always download from the public HuggingFace Hub, not the CI staging endpoint.
# huggingface_hub reads HUGGINGFACE_CO_STAGING at import time and hardcodes hub-ci.huggingface.co.
_staging_mode = os.environ.pop("HUGGINGFACE_CO_STAGING", None)
import httpx # noqa: E402
from huggingface_hub import hf_hub_download, snapshot_download # noqa: E402
from transformers.testing_utils import _run_pipeline_tests, _run_staging # noqa: E402
from transformers.utils.import_utils import is_mistral_common_available # noqa: E402
# ruff: enable[E402]
# Restore so transformers.testing_utils._run_staging can still read it.
if _staging_mode is not None:
os.environ["HUGGINGFACE_CO_STAGING"] = _staging_mode
# Remote files used as test data. The helpers of this script derive the local file name from the last path
# segment of each URL.
# NOTE(review): `tiny_video.mp4` and `tennis.mp4` are listed twice below — presumably harmless, confirm with the
# code that consumes this list.
URLS_FOR_TESTING_DATA = [
    # TODO: copy those to our hf-internal-testing dataset and fix all tests using them
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_formula_rec_001.png",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/doc_test.jpg",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/chart_parsing_02.png",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout_demo.jpg",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/img_rot180_demo.jpg",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_001.png",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png",
    "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg",
    # Don't use the original COCO URLs anymore. Replace with images from https://huggingface.co/datasets/hf-internal-testing/fixtures-coco below
    "http://images.cocodataset.org/val2017/000000000139.jpg",
    "http://images.cocodataset.org/val2017/000000000285.jpg",
    "http://images.cocodataset.org/val2017/000000000632.jpg",
    "http://images.cocodataset.org/val2017/000000000724.jpg",
    "http://images.cocodataset.org/val2017/000000000776.jpg",
    "http://images.cocodataset.org/val2017/000000000785.jpg",
    "http://images.cocodataset.org/val2017/000000000802.jpg",
    "http://images.cocodataset.org/val2017/000000000872.jpg",
    "http://images.cocodataset.org/val2017/000000001000.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "http://images.cocodataset.org/val2017/000000077595.jpg",
    "http://images.cocodataset.org/val2017/000000136466.jpg",
    "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
    "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
    "https://llava-vl.github.io/static/images/view.jpg",
    "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png",
    "https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/f2641_0_throatclearing.wav",
    "https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/glass-breaking-151256.mp3",
    "https://huggingface.co/datasets/raushan-testing-hf/images_test/resolve/main/picsum_237_200x300.jpg",
    "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/Big_Buck_Bunny_720_10s_10MB.mp4",
    "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
    "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png",
    "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/two_dogs.jpg",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg",
    "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/tiny_video.mp4",
    "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/tiny_video.mp4",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    # we should rely on this single dataset for our tests
    "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/coco_annotations.txt",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/coco_panoptic_annotations.txt",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/coco_panoptic/000000039769.png",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000139.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000285.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000632.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000724.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000776.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000785.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000802.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000872.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000001000.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000004016.png",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000039769.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000039769.png",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000077595.jpg",
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000136466.jpg",
]
def url_to_local_path(url, return_url_if_not_found=True):
    """
    Map a URL to the local file named after its last path segment.

    Args:
        url (`str`): The URL of the file.
        return_url_if_not_found (`bool`, defaults to `True`): Whether to return `url` itself when no file with that
            name exists relative to the current working directory.

    Returns:
        `str`: The local file name, or `url` when the file is missing and `return_url_if_not_found` is set.
    """
    local_name = url.split("/")[-1]
    if return_url_if_not_found and not os.path.exists(local_name):
        return url
    return local_name
def parse_hf_url(url):
    """
    Split a HuggingFace Hub `resolve` URL into keyword arguments for `hf_hub_download`.

    Args:
        url: A URL of the form `https://huggingface.co/[datasets/]<org>/<repo>/resolve/<revision>/<path>`.

    Returns:
        A dict with keys `repo_id`, `filename`, `repo_type` (`"dataset"` or `"model"`) and
        `revision` (`None` when the URL points at `main`), or `None` if `url` does not match.
    """
    hub_url = re.match(r"https://huggingface\.co/(datasets/)?([^/]+/[^/]+)/resolve/([^/]+)/(.+)", url)
    if hub_url is None:
        return None

    dataset_prefix, repo_id, revision, path_in_repo = hub_url.groups()
    if revision == "main":
        # `main` is reported as `None` rather than as an explicit revision.
        revision = None
    repo_type = "model" if dataset_prefix is None else "dataset"
    return {
        "repo_id": repo_id,
        "filename": path_in_repo,
        "repo_type": repo_type,
        "revision": revision,
    }
def validate_downloaded_content(filepath):
    """
    Sanity-check a downloaded file so that error pages are not mistaken for real data.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        `True` when the file passes both checks.

    Raises:
        ValueError: If the file starts (after optional leading whitespace) with an HTML or
            JSON-error signature, or if it is smaller than 100 bytes.
    """
    with open(filepath, "rb") as f:
        # Read as many bytes as the error message below reports.
        header = f.read(200)

    # Error pages are frequently preceded by blank lines or indentation, so ignore leading
    # whitespace before comparing against the (lower-case) signatures.
    probe = header.lstrip().lower()
    if probe.startswith((b"<!doctype", b"<html", b'{"error', b'{"message')):
        raise ValueError(
            f"Downloaded file appears to be an HTML error page, not a valid media file. "
            f"This may indicate rate limiting. File starts with: {header[:200]!r}"
        )

    file_size = os.path.getsize(filepath)
    if file_size < 100:
        raise ValueError(f"Downloaded file is suspiciously small ({file_size} bytes).")
    return True
def download_test_file(url):
    """
    Download a URL to a local file, using hf_hub_download for HF URLs.

    For HuggingFace URLs, uses hf_hub_download which handles authentication
    automatically via the HF_TOKEN environment variable. Other URLs are streamed
    with httpx, validated with `validate_downloaded_content`, and retried up to
    three times with exponential back-off.

    Args:
        url: The URL to download. The local filename is its last `/`-separated component.

    Returns:
        The local filename.

    Raises:
        Exception: Whatever the last download/validation attempt raised once all
            attempts have failed (or the hf_hub_download error for HF URLs).
    """
    filename = url.split("/")[-1]

    # Skip if file already exists
    if os.path.exists(filename):
        print(f"File already exists: {filename}")
        return filename

    # Check if this is a HuggingFace URL
    hf_parts = parse_hf_url(url)
    if hf_parts:
        # Use hf_hub_download for HF URLs - handles auth automatically via HF_TOKEN env var
        # NOTE(review): with `local_dir="."` the file presumably lands under its in-repo path
        # (e.g. `./val2017/<name>`), not at `./<name>` as returned below - confirm.
        print(f"Downloading {filename} from HuggingFace Hub...")
        try:
            hf_hub_download(**hf_parts, local_dir=".")
            print(f"Successfully downloaded: {filename}")
        except Exception as e:
            print(f"Error downloading {filename} from HuggingFace Hub: {e}")
            raise
        return filename

    # Use httpx for non-HF URLs (COCO, Britannica, etc.)
    import time

    max_retries = 3
    for attempt in range(max_retries):
        try:
            print(f"Downloading {filename} from {url}")
            with open(filename, "wb") as f:
                with httpx.stream("GET", url, follow_redirects=True) as resp:
                    resp.raise_for_status()
                    f.writelines(resp.iter_bytes(chunk_size=8192))
            validate_downloaded_content(filename)
            print(f"Successfully downloaded: {filename}")
            break
        except Exception as e:
            # Always discard the partial/invalid file - including after the final attempt -
            # otherwise the "already exists" shortcut above would accept it on the next run.
            if os.path.exists(filename):
                os.remove(filename)
            if attempt >= max_retries - 1:
                raise
            wait = 2 ** (attempt + 1)
            print(f"Attempt {attempt + 1} failed for {filename}: {e}. Retrying in {wait}s...")
            time.sleep(wait)
    return filename
# Script entry point: fetches the Hub datasets/files listed below and then every URL in
# `URLS_FOR_TESTING_DATA` via `download_test_file`.
if __name__ == "__main__":
if _run_pipeline_tests:
import datasets
_ = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
_ = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", split="test", revision="refs/pr/1")
_ = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
# This file is pinned to a specific commit of the dataset repo.
hf_hub_download(
repo_id="hf-internal-testing/fixtures_docvqa",
filename="nougat_pdf.png",
repo_type="dataset",
revision="ec57bf8c8b1653a209c13f6e9ee66b12df0fc2db",
)
hf_hub_download(
repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
)
hf_hub_download(
repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
)
hf_hub_download(
repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset"
)
hf_hub_download(
repo_id="hf-internal-testing/spaghetti-video",
filename="eating_spaghetti_32_frames.npy",
repo_type="dataset",
)
hf_hub_download(
repo_id="hf-internal-testing/spaghetti-video",
filename="eating_spaghetti_8_frames.npy",
repo_type="dataset",
)
hf_hub_download(
repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
)
hf_hub_download(repo_id="huggyllama/llama-7b", filename="tokenizer.model")
hf_hub_download(
repo_id="nielsr/audio-spectogram-transformer-checkpoint", filename="sample_audio.flac", repo_type="dataset"
)
hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename="example_pdf.png")
hf_hub_download(
repo_id="nielsr/test-image",
filename="llava_1_6_input_ids.pt",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/test-image",
filename="llava_1_6_pixel_values.pt",
repo_type="dataset",
)
hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
hf_hub_download(
repo_id="raushan-testing-hf/images_test",
filename="emu3_image.npy",
repo_type="dataset",
)
hf_hub_download(repo_id="raushan-testing-hf/images_test", filename="llava_v1_5_radar.jpg", repo_type="dataset")
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset")
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset")
hf_hub_download(
repo_id="shumingh/perception_lm_test_images",
filename="14496_0.PNG",
repo_type="dataset",
)
hf_hub_download(
repo_id="shumingh/perception_lm_test_videos",
filename="GUWR5TyiY-M_000012_000022.mp4",
repo_type="dataset",
)
# NOTE(review): this `repo_id` variable is not referenced by the calls that follow (they repeat
# the literal); it is only reassigned further down in this block.
repo_id = "nielsr/image-segmentation-toy-data"
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="instance_segmentation_image_1.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="instance_segmentation_image_2.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="instance_segmentation_annotation_1.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="instance_segmentation_annotation_2.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="semantic_segmentation_annotation_1.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="semantic_segmentation_annotation_2.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="semantic_segmentation_image_1.png",
repo_type="dataset",
)
hf_hub_download(
repo_id="nielsr/image-segmentation-toy-data",
filename="semantic_segmentation_image_2.png",
repo_type="dataset",
)
hf_hub_download("shi-labs/oneformer_demo", "ade20k_panoptic.json", repo_type="dataset")
# NOTE(review): same repo and filename (`sample_audio.flac`) as an earlier call in this block.
hf_hub_download(
repo_id="nielsr/audio-spectogram-transformer-checkpoint", filename="sample_audio.flac", repo_type="dataset"
)
# Need to specify the username on the endpoint `hub-ci`, otherwise we get
# `fatal: could not read Username for 'https://hub-ci.huggingface.co': Success`
# But this repo. is never used in a test decorated by `is_staging_test`.
if not _run_staging:
# Only take the snapshot when the target directory is not already present.
if not os.path.isdir("tiny-random-custom-architecture"):
snapshot_download(
"hf-internal-testing/tiny-random-custom-architecture",
local_dir="tiny-random-custom-architecture",
)
# For `tests/test_tokenization_mistral_common.py:TestMistralCommonBackend`, which eventually calls
# `mistral_common.tokens.tokenizers.utils.download_tokenizer_from_hf_hub` which (probably) doesn't have the cache.
# For `revision=None`, see https://github.com/huggingface/transformers/pull/40623
if is_mistral_common_available():
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.utils import list_local_hf_repo_files
from transformers import AutoTokenizer
from transformers.tokenization_mistral_common import MistralCommonBackend
repo_id = "hf-internal-testing/namespace-mistralai-repo_name-Mistral-Small-3.1-24B-Instruct-2503"
# determine if we already have this downloaded
local_files_only = len(list_local_hf_repo_files(repo_id, revision=None)) > 0
# This will go the path `transformers/tokenization_mistral_common.py::MistralCommonBackend::from_pretrained --> mistral_common.tokens.tokenizers.utils.download_tokenizer_from_hf_hub`.
# No idea at all why we need the statement below again (`MistralCommonBackend.from_pretrained`).
AutoTokenizer.from_pretrained(
repo_id, tokenizer_type="mistral", local_files_only=local_files_only, revision=None
)
_ = MistralCommonBackend.from_pretrained(
repo_id,
local_files_only=local_files_only,
# This is a hack as `list_local_hf_repo_files` from `mistral_common` has a bug
# TODO: Discuss with `mistral-common` maintainers: after a fix being done there, remove this `revision` hack
revision=None,
)
MistralTokenizer.from_hf_hub(repo_id, local_files_only=local_files_only)
# Same procedure for a second repository.
repo_id = "mistralai/Voxtral-Mini-3B-2507"
local_files_only = len(list_local_hf_repo_files(repo_id, revision=None)) > 0
AutoTokenizer.from_pretrained(repo_id, local_files_only=local_files_only, revision=None)
MistralTokenizer.from_hf_hub(repo_id, local_files_only=local_files_only)
# Download files from URLs to local directory
for url in URLS_FOR_TESTING_DATA:
download_test_file(url)

View File

@@ -0,0 +1,116 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Format extras smoke test results for Slack notification.
This script reads failure reports from a JSON file and outputs environment
variables for GitHub Actions to use in Slack notifications.
"""
import argparse
import json
import os
import sys
def format_slack_message(failures_file, workflow_url, output_file=None):
    """
    Format extras smoke test results into Slack message components.

    The components are also appended to `output_file` as `SLACK_TITLE`, `SLACK_WORKFLOW_URL`
    and a multiline `SLACK_MESSAGE<<EOF ... EOF` entry.

    Args:
        failures_file: Path to JSON file containing failure reports (a list of objects with
            optional `python_version` and `extra` keys; missing keys are reported as "unknown")
        workflow_url: URL to the GitHub Actions workflow run
        output_file: Optional path to output file (defaults to the `GITHUB_ENV` environment variable)

    Returns:
        Dictionary with title, message, and workflow_url

    Raises:
        SystemExit: With code 1 when no output file is given and `GITHUB_ENV` is unset or empty.
    """
    # Read failures. JSON is UTF-8; do not depend on the locale's default encoding.
    with open(failures_file, encoding="utf-8") as f:
        failures = json.load(f)

    if not failures:
        # Success case
        title = "Extras Smoke Test - All tests passed"
        message = "All extras installed successfully across all Python versions."
    else:
        # Failure case - group by Python version
        failures_by_python = {}
        for failure in failures:
            py_ver = failure.get("python_version", "unknown")
            extra = failure.get("extra", "unknown")
            failures_by_python.setdefault(py_ver, []).append(extra)

        title = f"Extras Smoke Test Failed - {len(failures)} failure(s)"

        # Build failure details: one section per Python version, extras sorted alphabetically
        details = []
        for py_ver in sorted(failures_by_python):
            extras_list = "\n".join(f"• `{extra}`" for extra in sorted(failures_by_python[py_ver]))
            details.append(f"*Python {py_ver}*\n{extras_list}")
        message = "\n\n".join(details)

    # Determine output destination
    if output_file is None:
        output_file = os.environ.get("GITHUB_ENV")
        if not output_file:
            print("Error: GITHUB_ENV not set and no output file specified", file=sys.stderr)
            sys.exit(1)

    # Write environment variables. The message contains the non-ASCII bullet "•", so the
    # encoding is pinned to UTF-8 instead of relying on the locale default.
    with open(output_file, "a", encoding="utf-8") as f:
        f.write(f"SLACK_TITLE={title}\n")
        f.write(f"SLACK_WORKFLOW_URL={workflow_url}\n")
        # Use heredoc for multiline message
        f.write("SLACK_MESSAGE<<EOF\n")
        f.write(f"{message}\n")
        f.write("EOF\n")

    return {"title": title, "message": message, "workflow_url": workflow_url}
def main():
    """
    Command-line entry point.

    Parses `--failures`, `--workflow-url` and the optional `--output`, hands them to
    `format_slack_message`, prints the resulting title and returns exit code 0.
    """
    cli = argparse.ArgumentParser(description="Format extras smoke test results for Slack")
    cli.add_argument("--failures", required=True, help="Path to JSON file containing failure reports")
    cli.add_argument("--workflow-url", required=True, help="URL to the GitHub Actions workflow run")
    cli.add_argument("--output", help="Output file path (defaults to GITHUB_ENV)")
    options = cli.parse_args()

    summary = format_slack_message(options.failures, options.workflow_url, options.output)
    print(f"Formatted Slack message: {summary['title']}")
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,309 @@
import argparse
import json
import logging
import math
import os
import time
import traceback
import zipfile
from collections import Counter
import requests
logger = logging.getLogger(__name__)
def get_jobs(workflow_run_id, token=None):
    """
    Collect every job of a GitHub Actions workflow run.

    Args:
        workflow_run_id: Id of the workflow run in `huggingface/transformers`.
        token: Optional token sent as a `Bearer` authorization header.

    Returns:
        The list of job objects from all result pages (100 per page), or `[]` if reading
        the responses fails (the traceback is printed).
    """
    headers = (
        {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"} if token is not None else None
    )
    url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
    first_page = requests.get(url, headers=headers).json()

    collected = []
    try:
        collected.extend(first_page["jobs"])
        # The first page already holds up to 100 entries; compute how many more pages exist.
        remaining_pages = math.ceil((first_page["total_count"] - 100) / 100)
        for page_offset in range(remaining_pages):
            payload = requests.get(url + f"&page={page_offset + 2}", headers=headers).json()
            collected.extend(payload["jobs"])
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
        return []
    return collected
def get_job_links(workflow_run_id, token=None):
    """
    Map each job name of a GitHub Actions workflow run to its web page URL.

    Args:
        workflow_run_id: Id of the workflow run in `huggingface/transformers`.
        token: Optional token sent as a `Bearer` authorization header.

    Returns:
        A dict `{job name: job html_url}` built from all result pages, or `{}` if reading
        the responses fails (the traceback is printed).
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
    page = requests.get(url, headers=headers).json()

    try:
        links = {job["name"]: job["html_url"] for job in page["jobs"]}
        # Pages 2..N hold whatever did not fit in the first 100 results.
        extra_pages = math.ceil((page["total_count"] - 100) / 100)
        for i in range(extra_pages):
            page = requests.get(url + f"&page={i + 2}", headers=headers).json()
            for job in page["jobs"]:
                links[job["name"]] = job["html_url"]
        return links
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")

    return {}
def get_artifacts_links(workflow_run_id, token=None):
    """
    Map each artifact name of a workflow run to its archive download URL.

    Args:
        workflow_run_id: Id of the workflow run in `huggingface/transformers`.
        token: Optional token sent as a `Bearer` authorization header.

    Returns:
        A dict `{artifact name: archive_download_url}` built from all result pages, or `{}`
        if reading the responses fails (the traceback is printed).
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    url = (
        f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/artifacts?per_page=100"
    )

    def _links_of(payload):
        # One page of the listing -> {name: download url}
        return {artifact["name"]: artifact["archive_download_url"] for artifact in payload["artifacts"]}

    payload = requests.get(url, headers=headers).json()
    found = {}
    try:
        found.update(_links_of(payload))
        n_more_pages = math.ceil((payload["total_count"] - 100) / 100)
        for i in range(n_more_pages):
            payload = requests.get(url + f"&page={i + 2}", headers=headers).json()
            found.update(_links_of(payload))
        return found
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
    return {}
def download_artifact(artifact_name, artifact_url, output_dir, token):
    """Download a GitHub Action artifact from a URL.

    The URL is of the form `https://api.github.com/repos/huggingface/transformers/actions/artifacts/{ARTIFACT_ID}/zip`,
    but it can't be used to download directly. We need to get a redirect URL first.
    See https://docs.github.com/en/rest/actions/artifacts#download-an-artifact

    Args:
        artifact_name: Name of the artifact; the archive is saved as `{artifact_name}.zip`.
        artifact_url: The artifact's API download URL.
        output_dir: Directory in which the archive is written.
        token: Token sent as a `Bearer` authorization header on the first request, or `None`.
    """
    auth_headers = None
    if token is not None:
        auth_headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    # First request: do not follow the redirect, only read where it points to.
    redirect = requests.get(artifact_url, headers=auth_headers, allow_redirects=False)
    # Second request: fetch the archive itself (sent without the auth headers).
    archive = requests.get(redirect.headers["Location"], allow_redirects=True)

    destination = os.path.join(output_dir, f"{artifact_name}.zip")
    with open(destination, "wb") as fp:
        fp.write(archive.content)
def get_errors_from_single_artifact(artifact_zip_path, job_links=None):
    """Extract errors from a downloaded artifact (in .zip format)

    Three members of the archive are read: `failures_line.txt` (lines of the form
    `<location>: <error>`), `summary_short.txt` (lines starting with `FAILED `) and
    `job_name.txt`.

    Args:
        artifact_zip_path: Path to the artifact archive.
        job_links: Optional dict mapping job names to job URLs.

    Returns:
        A list with elements of the form `[line of error, error, failed test, job link]`, where
        `job link` is `None` when the job name is unknown or not present in `job_links`.

    Raises:
        ValueError: If the number of parsed errors differs from the number of failed tests.
    """
    errors = []
    failed_tests = []
    job_name = None

    # Members are selected by their name inside the archive only. (Probing the local
    # filesystem with the member name would make the result depend on the working directory.)
    report_files = ("failures_line.txt", "summary_short.txt", "job_name.txt")

    with zipfile.ZipFile(artifact_zip_path) as z:
        for filename in z.namelist():
            if filename not in report_files:
                continue
            with z.open(filename) as f:
                for raw_line in f:
                    line = raw_line.decode("UTF-8").strip()
                    if filename == "failures_line.txt":
                        # `error_line` is the place where `error` occurs
                        error_line, separator, error = line.partition(": ")
                        if separator:
                            errors.append([error_line, error])
                        else:
                            # skip un-related lines that don't match the expected format
                            logger.debug(f"Skipping unrelated line: {line}")
                    elif filename == "summary_short.txt" and line.startswith("FAILED "):
                        # `test` is the test method that failed
                        failed_tests.append(line[len("FAILED ") :])
                    elif filename == "job_name.txt":
                        job_name = line

    if len(errors) != len(failed_tests):
        raise ValueError(
            f"`errors` and `failed_tests` should have the same number of elements. Got {len(errors)} for `errors` "
            f"and {len(failed_tests)} for `failed_tests` instead. The test reports in {artifact_zip_path} have some"
            " problem."
        )

    job_link = None
    if job_name and job_links:
        job_link = job_links.get(job_name, None)

    # A list with elements of the form (line of error, error, failed test, job link)
    return [x + [y] + [job_link] for x, y in zip(errors, failed_tests)]
def get_all_errors(artifact_dir, job_links=None):
    """
    Gather the errors of every `.zip` artifact found directly inside `artifact_dir`.

    Args:
        artifact_dir: Directory holding the downloaded artifact archives.
        job_links: Optional dict mapping job names to job URLs, forwarded for each archive.

    Returns:
        The concatenated results of `get_errors_from_single_artifact` over all archives.
    """
    archives = []
    for entry in os.listdir(artifact_dir):
        if entry.endswith(".zip"):
            archives.append(os.path.join(artifact_dir, entry))

    collected = []
    for archive in archives:
        collected += get_errors_from_single_artifact(archive, job_links=job_links)
    return collected
def reduce_by_error(logs, error_filter=None):
    """
    Count how often each error occurs and list the tests that hit it.

    Args:
        logs: Sequence of entries indexed as `[error line, error, failed test, ...]`.
        error_filter: Optional container of errors to leave out of the result.

    Returns:
        A dict `{error: {"count": int, "failed_tests": [(failed test, error line), ...]}}`
        ordered by decreasing count.
    """
    tally = Counter(entry[1] for entry in logs)

    summary = {}
    for message, occurrences in tally.most_common():
        if error_filter is not None and message in error_filter:
            continue
        locations = [(entry[2], entry[0]) for entry in logs if entry[1] == message]
        summary[message] = {"count": occurrences, "failed_tests": locations}

    ranked = sorted(summary.items(), key=lambda pair: pair[1]["count"], reverse=True)
    return dict(ranked)
def get_model(test):
    """
    Get the model directory name from a test identifier.

    Args:
        test: A test id such as `tests/models/<model>/test_file.py::Class::method`.

    Returns:
        The path component following `tests/models/`, or `None` when the test file does not
        live under `tests/models/`.
    """
    test_file = test.split("::")[0]
    if not test_file.startswith("tests/models/"):
        return None
    return test_file.split("/")[2]
def reduce_by_model(logs, error_filter=None):
    """
    Count errors per model.

    Args:
        logs: Sequence of entries indexed as `[error line, error, failed test, ...]`. Entries
            whose test is not a model test (see `get_model`) are ignored.
        error_filter: Optional container of errors to leave out of the counts.

    Returns:
        A dict `{model: {"count": total errors, "errors": {error: count}}}` ordered by
        decreasing total; models whose remaining error count is 0 are omitted.
    """
    tagged = []
    for entry in logs:
        model = get_model(entry[2])
        if model is not None:
            tagged.append((entry[0], entry[1], model))

    per_model = {}
    for model in {item[2] for item in tagged}:
        # count by errors in `model`
        tally = Counter(item[1] for item in tagged if item[2] == model)
        kept = {}
        for error, count in tally.most_common():
            if error_filter is None or error not in error_filter:
                kept[error] = count
        total = sum(kept.values())
        if total > 0:
            per_model[model] = {"count": total, "errors": kept}

    ranked = sorted(per_model.items(), key=lambda pair: pair[1]["count"], reverse=True)
    return dict(ranked)
def make_github_table(reduced_by_error):
    """
    Render the output of `reduce_by_error` as a GitHub-flavored markdown table.

    Each row shows the count and the first 100 characters of the error; the `status`
    column is left empty.
    """
    rows = ["| no. | error | status |", "|-:|:-|:-|"]
    rows.extend(f"| {info['count']} | {error[:100]} | |" for error, info in reduced_by_error.items())
    return "\n".join(rows)
def make_github_table_per_model(reduced_by_model):
    """
    Render the output of `reduce_by_model` as a GitHub-flavored markdown table.

    Each row shows the model, its total error count, the first 60 characters of the first
    entry of its `errors` dict, and that entry's count.
    """
    rows = ["| model | no. of errors | major error | count |", "|-:|-:|-:|-:|"]
    for model, info in reduced_by_model.items():
        total = info["count"]
        # The first entry of `errors` is reported as the major error.
        major_error, major_count = list(info["errors"].items())[0]
        rows.append(f"| {model} | {total} | {major_error[:60]} | {major_count} |")
    return "\n".join(rows)
if __name__ == "__main__":
    # Downloads all artifacts of a workflow run, extracts the test errors from them and writes
    # JSON / markdown summaries into `--output_dir`.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Where to store the downloaded artifacts and other result files.",
    )
    parser.add_argument("--token", default=None, type=str, help="A token that has actions:read permission.")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    _job_links = get_job_links(args.workflow_run_id, token=args.token)
    job_links = {}
    # To deal with `workflow_call` event, where a job name is the combination of the job names in the caller and callee.
    # For example, `PyTorch 1.11 / Model tests (models/albert, single-gpu)`.
    if _job_links:
        for k, v in _job_links.items():
            # This is how GitHub actions combine job names: keep what follows the first " / ".
            _, separator, callee_name = k.partition(" / ")
            job_links[callee_name if separator else k] = v
    with open(os.path.join(args.output_dir, "job_links.json"), "w", encoding="UTF-8") as fp:
        json.dump(job_links, fp, ensure_ascii=False, indent=4)

    artifacts = get_artifacts_links(args.workflow_run_id, token=args.token)
    with open(os.path.join(args.output_dir, "artifacts.json"), "w", encoding="UTF-8") as fp:
        json.dump(artifacts, fp, ensure_ascii=False, indent=4)

    for name, url in artifacts.items():
        download_artifact(name, url, args.output_dir, args.token)
        # Be gentle to GitHub
        time.sleep(1)

    errors = get_all_errors(args.output_dir, job_links=job_links)

    # `e[1]` is the error
    counter = Counter(e[1] for e in errors)

    # print the top 30 most common test errors
    most_common = counter.most_common(30)
    for item in most_common:
        print(item)

    with open(os.path.join(args.output_dir, "errors.json"), "w", encoding="UTF-8") as fp:
        json.dump(errors, fp, ensure_ascii=False, indent=4)

    reduced_by_error = reduce_by_error(errors)
    reduced_by_model = reduce_by_model(errors)

    s1 = make_github_table(reduced_by_error)
    s2 = make_github_table_per_model(reduced_by_model)

    for report_name, report in (("reduced_by_error.txt", s1), ("reduced_by_model.txt", s2)):
        with open(os.path.join(args.output_dir, report_name), "w", encoding="UTF-8") as fp:
            fp.write(report)

View File

@@ -0,0 +1,71 @@
import argparse
import math
import traceback
import dateutil.parser as date_parser
import requests
def extract_time_from_single_job(job):
    """
    Extract time info from a single job in a GitHub Actions workflow run.

    Args:
        job: A job object providing `started_at` and `completed_at` timestamps.

    Returns:
        A dict with the two original timestamps (`started_at`, `completed_at`) and
        `duration`, the elapsed time rounded to whole minutes.
    """
    started, completed = job["started_at"], job["completed_at"]
    started_datetime = date_parser.parse(started)
    completed_datetime = date_parser.parse(completed)
    elapsed_seconds = (completed_datetime - started_datetime).total_seconds()
    return {
        "started_at": started,
        "completed_at": completed,
        "duration": round(elapsed_seconds / 60.0),
    }
def get_job_time(workflow_run_id, token=None):
    """
    Extract time info for all jobs in a GitHub Actions workflow run.

    Args:
        workflow_run_id: Id of the workflow run in `huggingface/transformers`.
        token: Optional token sent as a `Bearer` authorization header.

    Returns:
        A dict `{job name: time info}` (see `extract_time_from_single_job`) built from all
        result pages, or `{}` if reading the responses fails (the traceback is printed).
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
    page = requests.get(url, headers=headers).json()

    timings = {}
    try:
        timings.update({job["name"]: extract_time_from_single_job(job) for job in page["jobs"]})
        # The first request returned up to 100 jobs; fetch the remaining pages, if any.
        extra_pages = math.ceil((page["total_count"] - 100) / 100)
        for i in range(extra_pages):
            page = requests.get(url + f"&page={i + 2}", headers=headers).json()
            timings.update({job["name"]: extract_time_from_single_job(job) for job in page["jobs"]})
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
        return {}
    return timings
if __name__ == "__main__":
    # Example:
    #   python get_github_job_time.py --workflow_run_id 2945609517
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
    args = parser.parse_args()

    job_time = get_job_time(args.workflow_run_id)
    # Longest jobs first.
    longest_first = sorted(job_time.items(), key=lambda item: item[1]["duration"], reverse=True)
    job_time = dict(longest_first)

    for k, v in job_time.items():
        print(f"{k}: {v['duration']}")

View File

@@ -0,0 +1,35 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.:
# python ./utils/get_modified_files.py utils src tests examples
#
# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered
# since the output of this script is fed into Makefile commands it doesn't print a newline after the results
import re
import subprocess
import sys
# Commit at which the current branch forked from `main` (the output keeps its trailing newline,
# which the whitespace split of the next command discards).
fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8")

# Files changed since the fork point, excluding deleted ones (`--diff-filter=d`).
diff_command = f"git diff --diff-filter=d --name-only {fork_point_sha}".split()
modified_files = subprocess.check_output(diff_command).decode("utf-8").split()

# Keep only the `.py` files located under one of the directories given on the command line.
joined_dirs = "|".join(sys.argv[1:])
regex = re.compile(rf"^({joined_dirs}).*?\.py$")
relevant_modified_files = [path for path in modified_files if regex.match(path)]

# No trailing newline: the output is consumed by Makefile commands.
print(" ".join(relevant_modified_files), end="")

View File

@@ -0,0 +1,133 @@
import argparse
import json
import re
import string
MAX_NUM_JOBS_TO_SUGGEST = 16
def get_jobs_to_run():
# Derive the test directories to run from the files changed in a pull request.
# Reads `pr_files.txt` (JSON) from the current working directory and the module-level
# `repo_content` list; returns a sorted list of unique matched entries.
# The file `pr_files.txt` contains the information about the files changed in a pull request, and it is prepared by
# the caller (using GitHub api).
# We can also use the following api to get the information if we don't have them before calling this script.
# url = f"https://api.github.com/repos/huggingface/transformers/pulls/PULL_NUMBER/files?ref={pr_sha}"
with open("pr_files.txt") as fp:
pr_files = json.load(fp)
# Keep only the `filename` and `status` fields, then only the names of added/modified files.
pr_files = [{k: v for k, v in item.items() if k in ["filename", "status"]} for item in pr_files]
pr_files = [item["filename"] for item in pr_files if item["status"] in ["added", "modified"]]
# models or quantizers
re_1 = re.compile(r"src/transformers/(models/.*)/modeling_.*\.py")
re_2 = re.compile(r"src/transformers/(quantizers/quantizer_.*)\.py")
# tests for models or quantizers
re_3 = re.compile(r"tests/(models/.*)/test_.*\.py")
re_4 = re.compile(r"tests/(quantization/.*)/test_.*\.py")
# files in a model directory but not necessarily a modeling file
re_5 = re.compile(r"src/transformers/(models/.*)/.*\.py")
# The patterns are tried in this order for every changed file.
regexes = [re_1, re_2, re_3, re_4, re_5]
jobs_to_run = []
for pr_file in pr_files:
for regex in regexes:
matched = regex.findall(pr_file)
if len(matched) > 0:
item = matched[0]
# Map a quantizer source file onto the naming used under `quantization/`.
item = item.replace("quantizers/quantizer_", "quantization/")
# TODO: for files in `quantizers`, the processed item above may not exist. Try using a fuzzy matching
if item in repo_content:
jobs_to_run.append(item)
break
# De-duplicate and return in sorted order.
jobs_to_run = sorted(set(jobs_to_run))
return jobs_to_run
def parse_message(message: str) -> str:
    """
    Parses a GitHub pull request's comment to find the models specified in it to run slow CI.

    Args:
        message (`str`): The body of a GitHub pull request's comment.

    Returns:
        `str`: The (lower-cased) substring in `message` after `run-slow`, `run_slow` or `run slow`, with any
        leading `:` removed. If no such prefix is found, the empty string is returned.
    """
    if message is None:
        return ""

    normalized = message.strip().lower()

    # run-slow: model_1, model_2, quantization_1, quantization_2
    if not normalized.startswith(("run-slow", "run_slow", "run slow")):
        return ""
    # The three accepted prefixes all have the same length.
    remainder = normalized[len("run slow") :]

    # remove leading `:`
    while True:
        trimmed = remainder.strip()
        if not trimmed.startswith(":"):
            break
        remainder = trimmed[1:]
    return remainder
def get_jobs(message: str):
    """
    List the names given after the `run-slow` prefix of a comment.

    Commas are treated like whitespace, so `run-slow: a, b c` yields `["a", "b", "c"]`.
    """
    return parse_message(message).replace(",", " ").split()
def check_name(model_name: str):
    """
    Tell whether `model_name` is an acceptable job name.

    A name is accepted when it neither starts nor ends with `_` and consists only of ASCII
    letters, digits and underscores.
    """
    if model_name.startswith("_") or model_name.endswith("_"):
        return False
    permitted = set(string.ascii_letters + string.digits + "_")
    return all(char in permitted for char in model_name)
# Script entry point: prints either the jobs requested in a comment (`--message`) or a
# comma-separated suggestion computed from the files changed in the pull request.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--message", type=str, default="", help="The content of a comment.")
parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
args = parser.parse_args()
# The files are prepared by the caller (using GitHub api).
# We can also use the following api to get the information if we don't have them before calling this script.
# url = f"https://api.github.com/repos/OWNER/REPO/contents/PATH?ref={pr_sha}"
# (we avoid to checkout the repository using `actions/checkout` to reduce the run time, but mostly to avoid the potential security issue as much as possible)
repo_content = []
for filename in ["tests_dir.txt", "tests_models_dir.txt", "tests_quantization_dir.txt"]:
with open(filename) as fp:
data = json.load(fp)
# Keep directory entries only, with the leading `tests/` removed from their path.
data = [item["path"][len("tests/") :] for item in data if item["type"] == "dir"]
repo_content.extend(data)
# These don't have the prefix `models/` or `quantization/`, so we need to add them.
if args.message:
specified_jobs = get_jobs(args.message)
# Drop names that `check_name` rejects.
specified_jobs = [job for job in specified_jobs if check_name(job)]
# Add prefix (`models/` or `quantization`)
jobs_to_run = []
for job in specified_jobs:
if not args.quantization:
if f"models/{job}" in repo_content:
jobs_to_run.append(f"models/{job}")
elif job in repo_content and job != "quantization":
jobs_to_run.append(job)
elif f"quantization/{job}" in repo_content:
jobs_to_run.append(f"quantization/{job}")
print(sorted(set(jobs_to_run)))
else:
# Compute (from the added/modified files) the directories under `tests/`, `tests/models/` and `tests/quantization` to run tests.
# These are already with the prefix `models/` or `quantization/`, so we don't need to add them.
jobs_to_run = get_jobs_to_run()
jobs_to_run = [x.replace("models/", "").replace("quantization/", "") for x in jobs_to_run]
jobs_to_run = [job for job in jobs_to_run if check_name(job)]
# Cap the suggestion at `MAX_NUM_JOBS_TO_SUGGEST` entries.
if len(jobs_to_run) > MAX_NUM_JOBS_TO_SUGGEST:
jobs_to_run = jobs_to_run[:MAX_NUM_JOBS_TO_SUGGEST]
suggestion = f"{', '.join(jobs_to_run)}"
print(suggestion)

View File

@@ -0,0 +1,159 @@
import os
import zipfile
import requests
from get_ci_error_statistics import download_artifact, get_artifacts_links
def get_daily_ci_runs(token, num_runs=7, workflow_id=None):
    """Get the workflow runs of the scheduled (daily) CI.

    Runs on the `main` branch triggered by the `schedule` event are requested first; if that
    yields no run, runs triggered by the `workflow_run` event are requested instead.

    Args:
        token: Token sent as a `Bearer` authorization header, or `None`.
        num_runs: Number of runs requested (`per_page`).
        workflow_id: Id of the workflow (not of a workflow run). When falsy, it is looked up
            from the run identified by the `GITHUB_RUN_ID` environment variable.

    Returns:
        The `workflow_runs` list of the API response.
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    # The id of a workflow (not of a workflow run).
    # From a given workflow run (where we have workflow run id), we can get the workflow id by going to
    # https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}
    # and check the `workflow_id` key.
    if not workflow_id:
        workflow_run_id = os.environ["GITHUB_RUN_ID"]
        current_run = requests.get(
            f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}", headers=headers
        ).json()
        workflow_id = current_run["workflow_id"]

    # On `main` branch + not returning PRs + only `num_runs` results; the event is appended per request
    url = (
        f"https://api.github.com/repos/huggingface/transformers/actions/workflows/{workflow_id}/runs"
        f"?branch=main&exclude_pull_requests=true&per_page={num_runs}"
    )

    scheduled = requests.get(f"{url}&event=schedule", headers=headers).json()["workflow_runs"]
    if len(scheduled) != 0:
        return scheduled

    # No scheduled run found: fall back to runs triggered by `workflow_run`.
    return requests.get(f"{url}&event=workflow_run", headers=headers).json()["workflow_runs"]
def get_last_daily_ci_run(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
    """Return the last completed run of the scheduled (daily) CI, or a specific run when its id is given.

    Args:
        token: GitHub token sent as a bearer token. When `None`, the request is sent without headers.
        workflow_run_id: When set (neither `None` nor `""`), that run is fetched and returned as is, whatever its
            status.
        workflow_id: Forwarded to `get_daily_ci_runs` to select the workflow.
        commit_sha: When set (neither `None` nor `""`), only runs whose `head_sha` equals it are considered.

    Returns:
        The selected workflow run (decoded JSON), or `None` when no completed run matches.
    """
    if workflow_run_id is not None and workflow_run_id != "":
        auth_headers = None
        if token is not None:
            auth_headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
        return requests.get(
            f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}",
            headers=auth_headers,
        ).json()

    restrict_to_sha = commit_sha not in [None, ""]
    for candidate in get_daily_ci_runs(token, workflow_id=workflow_id):
        # if `commit_sha` is specified, only a run whose `head_sha` matches it can be selected.
        if restrict_to_sha and candidate["head_sha"] != commit_sha:
            continue
        if candidate["status"] == "completed":
            return candidate
    return None
def get_last_daily_ci_workflow_run_id(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
    """Return the id of the last completed run of the scheduled (daily) CI.

    A `workflow_run_id` that is neither `None` nor `""` is returned unchanged without any lookup. Otherwise the run
    is searched with `get_last_daily_ci_run`, and `None` is returned when no run is found.
    """
    explicit_id_given = workflow_run_id is not None and workflow_run_id != ""
    if explicit_id_given:
        return workflow_run_id

    last_run = get_last_daily_ci_run(token, workflow_id=workflow_id, commit_sha=commit_sha)
    return None if last_run is None else last_run["id"]
def get_last_daily_ci_run_commit(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
    """Return the `head_sha` of the run selected by `get_last_daily_ci_run`, or `None` when no run is found."""
    last_run = get_last_daily_ci_run(
        token, workflow_run_id=workflow_run_id, workflow_id=workflow_id, commit_sha=commit_sha
    )
    if last_run is None:
        return None
    return last_run["head_sha"]
def get_last_daily_ci_artifacts(
    output_dir,
    token,
    workflow_run_id=None,
    workflow_id=None,
    commit_sha=None,
    artifact_names=None,
):
    """Download the artifacts of the last completed run of the scheduled (daily) CI.

    Args:
        output_dir: Directory passed to `download_artifact` as the download destination.
        token: GitHub token forwarded to the lookup and download helpers.
        workflow_run_id: Explicit run id to use; when unset, the last completed run is looked up.
        workflow_id: Forwarded to the run lookup to select the workflow.
        commit_sha: Forwarded to the run lookup to restrict it to a given commit.
        artifact_names: Names of the artifacts to download. When `None`, every artifact of the run is downloaded.
            Requested names that the run does not provide are ignored.

    Returns:
        The list of artifact names that were downloaded. The list is empty when no workflow run could be found.
    """
    workflow_run_id = get_last_daily_ci_workflow_run_id(
        token, workflow_run_id=workflow_run_id, workflow_id=workflow_id, commit_sha=commit_sha
    )
    # Without a run there is nothing to download. Return an empty list so that callers iterating over the result
    # (e.g. to read the downloaded reports) keep working.
    if workflow_run_id is None:
        return []

    artifacts_links = get_artifacts_links(workflow_run_id=workflow_run_id, token=token)
    if artifact_names is None:
        artifact_names = artifacts_links.keys()

    downloaded_artifact_names = []
    for artifact_name in artifact_names:
        if artifact_name in artifacts_links:
            artifact_url = artifacts_links[artifact_name]
            download_artifact(
                artifact_name=artifact_name, artifact_url=artifact_url, output_dir=output_dir, token=token
            )
            downloaded_artifact_names.append(artifact_name)
    return downloaded_artifact_names
def get_last_daily_ci_reports(
    output_dir,
    token,
    workflow_run_id=None,
    workflow_id=None,
    commit_sha=None,
    artifact_names=None,
):
    """Download the artifacts of the last completed daily CI run and read the files they contain.

    Each downloaded `{artifact_name}.zip` found in `output_dir` is extracted to `output_dir/{artifact_name}`, and
    every non-directory entry at the top level of that folder is read as text.

    Returns:
        A dictionary mapping each extracted artifact name to a dictionary `{file name: file content}`.
    """
    downloaded = get_last_daily_ci_artifacts(
        output_dir,
        token,
        workflow_run_id=workflow_run_id,
        workflow_id=workflow_id,
        commit_sha=commit_sha,
        artifact_names=artifact_names,
    )

    reports = {}
    for artifact_name in downloaded:
        zip_path = os.path.join(output_dir, f"{artifact_name}.zip")
        if not os.path.isfile(zip_path):
            continue

        extract_dir = os.path.join(output_dir, artifact_name)
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(extract_dir)

        contents = {}
        reports[artifact_name] = contents
        for entry in os.listdir(extract_dir):
            entry_path = os.path.join(extract_dir, entry)
            # Only the files directly inside the artifact folder are read; sub-folders are skipped.
            if os.path.isdir(entry_path):
                continue
            with open(entry_path) as fp:
                contents[entry] = fp.read()
    return reports

217
utils/get_test_info.py Normal file
View File

@@ -0,0 +1,217 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import sys
import unittest
# This is required to make the module import work (when the python process is running from the root of the repo)
sys.path.append(".")
r"""
The argument `test_file` in this file refers to a model test file. This should be a string of the form
`tests/models/*/test_modeling_*.py`.
"""
def get_module_path(test_file):
    """Return the dotted module path of a model test file.

    Args:
        test_file: Path of the form `tests/models/*/test_modeling_*.py`, written with the OS specific path
            separator.

    Returns:
        The path with its separators replaced by `.` and the trailing `.py` extension removed, e.g.
        `tests.models.bert.test_modeling_bert`.

    Raises:
        ValueError: If the path does not start with `tests/models/`, is not a `.py` file, or its file name does not
            start with `test_modeling_`.
    """
    components = test_file.split(os.path.sep)
    if components[0:2] != ["tests", "models"]:
        raise ValueError(
            "`test_file` should start with `tests/models/` (with `/` being the OS specific path separator). Got "
            f"{test_file} instead."
        )
    test_fn = components[-1]
    # Check the real extension: `endswith("py")` would also accept a name such as `test_modeling_happy`.
    if not test_fn.endswith(".py"):
        raise ValueError(f"`test_file` should be a python file. Got {test_fn} instead.")
    if not test_fn.startswith("test_modeling_"):
        raise ValueError(
            f"`test_file` should point to a file name of the form `test_modeling_*.py`. Got {test_fn} instead."
        )
    # Remove only the trailing extension; `str.replace` would also drop a `.py` occurring inside the name.
    module_name = test_fn[: -len(".py")]
    return ".".join(components[:-1] + [module_name])
def get_test_module(test_file):
    """Import and return the module that corresponds to a model test file.

    Raises:
        ValueError: If the import fails with an `AttributeError` (see the comment in the handler), or if
            `test_file` is rejected by `get_module_path`.
    """
    module_path = get_module_path(test_file)
    try:
        return importlib.import_module(module_path)
    except AttributeError as exc:
        # e.g. if you have a `tests` folder in `site-packages`, created by another package, when trying to import
        # `tests.models...`
        raise ValueError(
            f"Could not import module {module_path}. Confirm that you don't have a package with the same root "
            "name installed or in your environment's `site-packages`."
        ) from exc
def get_tester_classes(test_file):
    """Return, sorted by name, the attributes of a model test module whose names end with `ModelTester`."""
    module = get_test_module(test_file)
    testers = [getattr(module, name) for name in dir(module) if name.endswith("ModelTester")]
    testers.sort(key=lambda cls: cls.__name__)
    return testers
def get_test_classes(test_file):
    """Return, sorted by name, the test classes of a model test file that have a non-empty `all_model_classes`.

    Only subclasses of `unittest.TestCase` found in the module are considered. When a class declares
    `all_model_classes` as `None`, it is instantiated and `setUp()` is called so that the value set on the instance
    can be inspected instead.
    """
    module = get_test_module(test_file)

    selected = []
    for name in dir(module):
        candidate = getattr(module, name)
        # Keep only real test classes; this leaves out everything that is not a `unittest.TestCase` subclass.
        if not (isinstance(candidate, type) and issubclass(candidate, unittest.TestCase)):
            continue

        declared = getattr(candidate, "all_model_classes", [])
        if declared is None:
            # The class-level value is `None`: the actual value is only available on an instance after `setUp`.
            instance = candidate()
            instance.setUp()
            declared = getattr(instance, "all_model_classes", [])

        if len(declared) > 0:
            selected.append(candidate)

    selected.sort(key=lambda cls: cls.__name__)
    return selected
def get_model_classes(test_file):
    """Return, sorted by name, every model class listed in an `all_model_classes` attribute of a model test file."""
    collected = set()
    for test_class in get_test_classes(test_file):
        declared = test_class.all_model_classes
        if declared is None:
            # The value is only set on the instance during `setUp`.
            instance = test_class()
            instance.setUp()
            declared = instance.all_model_classes
        collected.update(declared)
    return sorted(collected, key=lambda cls: cls.__name__)
def get_model_tester_from_test_class(test_class):
    """Return the class of the `model_tester` attribute of a test class instance.

    The test class is instantiated and, if it has one, its `setUp` method is called first. `None` is returned when
    the instance has no `model_tester` attribute or when that attribute is `None`.
    """
    instance = test_class()
    if hasattr(instance, "setUp"):
        instance.setUp()

    tester = getattr(instance, "model_tester", None)
    if tester is None:
        return None
    return tester.__class__
def get_test_classes_for_model(test_file, model_class):
    """Return, sorted by name, the test classes of `test_file` whose `all_model_classes` contains `model_class`."""
    matching = []
    for test_class in get_test_classes(test_file):
        declared = test_class.all_model_classes
        if declared is None:
            # The value is only set on the instance during `setUp`.
            instance = test_class()
            instance.setUp()
            declared = instance.all_model_classes
        if model_class in declared:
            matching.append(test_class)
    matching.sort(key=lambda cls: cls.__name__)
    return matching
def get_tester_classes_for_model(test_file, model_class):
    """Return, sorted by name, the model tester classes of the test classes in `test_file` that cover `model_class`.

    Test classes for which no model tester can be determined are left out.
    """
    testers = []
    for test_class in get_test_classes_for_model(test_file, model_class):
        tester = get_model_tester_from_test_class(test_class)
        if tester is None:
            continue
        testers.append(tester)
    testers.sort(key=lambda cls: cls.__name__)
    return testers
def get_test_to_tester_mapping(test_file):
    """Map each test class returned by `get_test_classes` for `test_file` to its model tester class.

    The value is `None` for a test class whose model tester cannot be determined.
    """
    mapping = {}
    for test_class in get_test_classes(test_file):
        mapping[test_class] = get_model_tester_from_test_class(test_class)
    return mapping
def get_model_to_test_mapping(test_file):
    """Map each model class found in `test_file` to the list of test classes that cover it."""
    mapping = {}
    for model_class in get_model_classes(test_file):
        mapping[model_class] = get_test_classes_for_model(test_file, model_class)
    return mapping
def get_model_to_tester_mapping(test_file):
    """Map each model class found in `test_file` to the list of model tester classes associated with it."""
    mapping = {}
    for model_class in get_model_classes(test_file):
        mapping[model_class] = get_tester_classes_for_model(test_file, model_class)
    return mapping
def to_json(o):
    """Convert classes to their names, recursively, to keep the displayed results short.

    Strings are returned unchanged, classes are replaced by their `__name__` (e.g. `BertForMaskedLM` rather than
    `<class 'transformers.models.bert.modeling_bert.BertForMaskedLM'>`), lists and tuples become lists of converted
    items, and dictionaries have both their keys and values converted. Any other object is returned as is.
    """
    if isinstance(o, str):
        return o
    if isinstance(o, type):
        return o.__name__
    if isinstance(o, (list, tuple)):
        return [to_json(item) for item in o]
    if isinstance(o, dict):
        return {to_json(key): to_json(value) for key, value in o.items()}
    return o

270
utils/get_test_reports.py Normal file
View File

@@ -0,0 +1,270 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This util provides a way to manually run the tests of the transformers repo as they would be run by the CI.
It was mainly used for models tests, so if you find features missing for another suite, do not hesitate to open a PR.
Functionalities:
- Running specific test suite (models, tokenizers, etc.)
- Parallel execution across multiple processes (each has to be launched separately with different `--processes` argument)
- GPU/CPU test filtering and slow tests filter
- Temporary cache management for isolated test runs
- Resume functionality for interrupted test runs
- Important models subset testing
Example usages are below.
"""
import argparse
import contextlib
import os
import subprocess
import tempfile
from pathlib import Path
import torch
from .important_files import IMPORTANT_MODELS
def is_valid_test_dir(path: Path) -> bool:
    """Tell whether `path` is a test directory: an existing directory whose name starts with neither '__' nor '.'"""
    if not path.is_dir():
        return False
    return not path.name.startswith(("__", "."))
def run_pytest(
    suite: str, subdir: Path, root_test_dir: Path, machine_type: str, dry_run: bool, tmp_cache: str, cpu_tests: bool
) -> None:
    """
    Build the pytest command for one test directory, print it, and run it unless `dry_run` is set:
    - suite (str): name of the test suite being run (e.g., 'models', 'tokenizers')
    - subdir (Path): the directory containing the tests to run
    - root_test_dir (Path): the root directory of all tests; `subdir` is reported relative to it
    - machine_type (str): type of machine/environment (e.g., 'cpu', 'single-gpu', 'multi-gpu'), used in the report name
    - dry_run (bool): if True, only print the command without executing it
    - tmp_cache (str): prefix for a temporary directory used as `HUGGINGFACE_HUB_CACHE`. If empty, no temp cache is used
    - cpu_tests (bool): if False, tests marked `not_device_test` are deselected
    The exit status of pytest is not checked.
    """
    relative_path = subdir.relative_to(root_test_dir)
    print(f"Suite: {suite} | Running on: {relative_path}")

    pytest_cmd = [
        "python3",
        "-m",
        "pytest",
        "-rsfE",
        "-v",
        f"--make-reports={machine_type}_{suite}_{relative_path}_test_reports",
        str(subdir),
    ]
    if not cpu_tests:
        pytest_cmd.extend(["-m", "not not_device_test"])

    with contextlib.ExitStack() as stack:
        child_env = os.environ.copy()
        if tmp_cache:
            # The temporary directory lives until the end of the `with` block, i.e. for the whole pytest run.
            tmp_dir = stack.enter_context(tempfile.TemporaryDirectory(prefix=tmp_cache))
            child_env["HUGGINGFACE_HUB_CACHE"] = tmp_dir
            print(f"Using temporary cache located at {tmp_dir = }")
        print("Command:", " ".join(pytest_cmd))
        if not dry_run:
            subprocess.run(pytest_cmd, check=False, env=child_env)
def handle_suite(
suite: str,
test_root: Path,
machine_type: str,
dry_run: bool,
tmp_cache: str = "",
resume_at: str | None = None,
only_in: list[str] | None = None,
cpu_tests: bool = False,
process_id: int = 1,
total_processes: int = 1,
) -> None:
"""
Handle execution of a complete test suite with advanced filtering and process distribution.
Args:
- suite (str): Name of the test suite to run (corresponds to a directory under test_root).
- test_root (Path): Root directory containing all test suites.
- machine_type (str): Machine/environment type for report naming and identification.
- dry_run (bool): If True, only print commands without executing them.
- tmp_cache (str, optional): Prefix for temporary cache directories. If empty, no temp cache is used.
- resume_at (str, optional): Resume execution starting from this subdirectory name.
Useful for restarting interrupted test runs. Defaults to None (run from the beginning).
- only_in (list[str], optional): Only run tests in these specific subdirectories.
Can include special values like IMPORTANT_MODELS. Defaults to None (run all tests).
- cpu_tests (bool, optional): Whether to include CPU-only tests. Defaults to False.
- process_id (int, optional): Current process ID for parallel execution (1-indexed). Defaults to 1.
- total_processes (int, optional): Total number of parallel processes. Defaults to 1.
"""
# Check path to suite
full_path = test_root / suite
if not full_path.exists():
print(f"Test folder does not exist: {full_path}")
return
# Establish the list of subdir to go through
subdirs = sorted(full_path.iterdir())
subdirs = [s for s in subdirs if is_valid_test_dir(s)]
if resume_at is not None:
subdirs = [s for s in subdirs if s.name >= resume_at]
if only_in is not None:
subdirs = [s for s in subdirs if s.name in only_in]
if subdirs and total_processes > 1:
# This interleaves the subdirs / files. For instance for subdirs = [A, B, C, D, E] and 2 processes:
# - script launcehd with `--processes 0 2` will run A, C, E
# - script launcehd with `--processes 1 2` will run B, D
subdirs = subdirs[process_id::total_processes]
# If the subdir list is not empty, go through each
if subdirs:
for subdir in subdirs:
run_pytest(suite, subdir, test_root, machine_type, dry_run, tmp_cache, cpu_tests)
# Otherwise, launch pytest from the full path
else:
run_pytest(suite, full_path, test_root, machine_type, dry_run, tmp_cache, cpu_tests)
if __name__ == "__main__":
    """Command-line interface for running test suite with comprehensive reporting. Check handle_suite for more details.
    Command-line Arguments:
    folder: Path to the root test directory (required)
    --suite: Test suite name to run (default: "models")
    --cpu-tests: Include CPU-only tests in addition to device tests
    --run-slow: Execute slow tests instead of skipping them
    --resume-at: Resume execution from a specific subdirectory
    --only-in: Run tests only in specified subdirectories (supports IMPORTANT_MODELS)
    --processes: Process distribution as "process_id total_processes"
    --dry-run: Print commands without executing them
    --tmp-cache: Use temporary cache directories for isolated runs
    --machine-type: Override automatic machine type detection
    Machine Type Detection:
    - 'cpu': No CUDA available
    - 'single-gpu': CUDA available with 1 GPU
    - 'multi-gpu': CUDA available with multiple GPUs
    Process Distribution:
    Use --processes to split work across multiple parallel processes:
    --processes 0 4 # This is process 0 of 4 total processes
    --processes 1 4 # This is process 1 of 4 total processes
    ...
    Usage Examples:
    # Basic model testing
    python3 -m utils.get_test_reports tests/ --suite models
    # Run slow tests for important models only
    python3 -m utils.get_test_reports tests/ --suite models --run-slow --only-in IMPORTANT_MODELS
    # Parallel execution across 4 processes, second process to launch (processes are 0-indexed)
    python3 -m utils.get_test_reports tests/ --suite models --processes 1 4
    # Resume interrupted run from 'bert' subdirectory with a tmp cache
    python3 -m utils.get_test_reports tests/ --suite models --resume-at bert --tmp-cache /tmp/
    # Run specific models with CPU tests
    python3 -m utils.get_test_reports tests/ --suite models --only-in bert gpt2 --cpu-tests
    # Run slow tests for only important models with a tmp cache
    python3 -m utils.get_test_reports tests/ --suite models --run-slow --only-in IMPORTANT_MODELS --tmp-cache /tmp/
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("folder", help="Path to test root folder (e.g., ./tests)")
    # Choose which tests to run (broad picture)
    parser.add_argument("--suite", type=str, default="models", help="Test suit to run")
    parser.add_argument("--cpu-tests", action="store_true", help="Also runs non-device tests")
    parser.add_argument("--run-slow", action="store_true", help="Run slow tests instead of skipping them")
    parser.add_argument("--collect-outputs", action="store_true", help="Collect outputs of the tests")
    # Fine-grain control over the tests to run
    parser.add_argument("--resume-at", type=str, default=None, help="Resume at a specific subdir / file in the suite")
    parser.add_argument(
        "--only-in",
        type=str,
        nargs="+",
        help="Only run tests in the given subdirs / file. Use IMPORTANT_MODELS to run only the important models tests.",
    )
    # How to run the test suite: is the work divided among processes, do a dry run, use temp cache?
    parser.add_argument(
        "--processes",
        type=int,
        nargs="+",
        help="Inform each CI process as to the work to do: format as `process_id total_processes`. "
        "In order to run with multiple (eg. 3) processes, you need to run the script multiple times (eg. 3 times).",
    )
    parser.add_argument("--dry-run", action="store_true", help="Only print commands without running them")
    parser.add_argument("--tmp-cache", type=str, help="Change HUGGINGFACE_HUB_CACHE to a tmp dir for each test")
    # This is a purely decorative argument, but it can be useful to distinguish between runs
    parser.add_argument(
        "--machine-type", type=str, default="", help="Machine type, automatically inferred if not provided"
    )
    args = parser.parse_args()

    # Handle run slow: the flag is passed to the tests through the `RUN_SLOW` environment variable
    if args.run_slow:
        os.environ["RUN_SLOW"] = "yes"
        print("[WARNING] Running slow tests.")
    else:
        print("[WARNING] Skipping slow tests.")

    # Handle multiple CI processes. Without `--processes`, a single process runs everything and the process id
    # is not used.
    if args.processes is None:
        process_id, total_processes = 1, 1
    elif len(args.processes) == 2:
        process_id, total_processes = args.processes
    else:
        raise ValueError(f"Invalid processes argument: {args.processes}")

    # Assert test root exists
    test_root = Path(args.folder).resolve()
    if not test_root.exists():
        print(f"Root test folder not found: {test_root}")
        # `exit()` is a helper injected by the `site` module for interactive use; raise `SystemExit` directly so
        # the script also terminates with status 1 when `site` is not loaded.
        raise SystemExit(1)

    # Handle collection of outputs: the output directory is `reports`, next to the test root
    if args.collect_outputs:
        os.environ["PATCH_TESTING_METHODS_TO_COLLECT_OUTPUTS"] = "yes"
        reports_dir = test_root.parent / "reports"
        os.environ["_PATCHED_TESTING_METHODS_OUTPUT_DIR"] = str(reports_dir)

    # Infer machine type if not provided
    if args.machine_type == "":
        if not torch.cuda.is_available():
            machine_type = "cpu"
        else:
            machine_type = "multi-gpu" if torch.cuda.device_count() > 1 else "single-gpu"
    else:
        machine_type = args.machine_type

    # Reduce the scope for models if necessary. The special value is only expanded when it is the sole entry.
    only_in = args.only_in if args.only_in else None
    if only_in == ["IMPORTANT_MODELS"]:
        only_in = IMPORTANT_MODELS

    # Launch suite
    handle_suite(
        suite=args.suite,
        test_root=test_root,
        machine_type=machine_type,
        dry_run=args.dry_run,
        tmp_cache=args.tmp_cache,
        resume_at=args.resume_at,
        only_in=only_in,
        cpu_tests=args.cpu_tests,
        process_id=process_id,
        total_processes=total_processes,
    )

30
utils/important_files.py Normal file
View File

@@ -0,0 +1,30 @@
# List here the models to always test.
# `utils/get_test_reports.py` uses this list in place of `--only-in IMPORTANT_MODELS`; the entries are compared with
# the names of the sub-directories of the selected test suite.
IMPORTANT_MODELS = [
    "auto",
    "bert",
    "gpt2",
    "t5",
    "modernbert",
    "vit",
    "clip",
    "detr",
    "table_transformer",
    "got_ocr2",
    "whisper",
    "wav2vec2",
    "qwen2_audio",
    "speech_t5",
    "csm",
    "llama",
    "gemma3",
    "qwen2",
    "mistral3",
    "qwen2_5_vl",
    "llava",
    "smolvlm",
    "internvl",
    "gemma3n",
    "gpt_oss",
    "qwen2_5_omni",
    "pi0",
]

View File

@@ -0,0 +1,4 @@
models/llama
models/mistral
models/mixtral
models/gemma

View File

@@ -0,0 +1,338 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last
commit.
"""
import argparse
import glob
import json
import os
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from git import Repo
from huggingface_hub import HfApi
from tqdm import tqdm
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES, DEPRECATED_MODELS
# Hub client shared by the helpers below (used to list models on the hub).
api = HfApi()
# Root of the repository: the parent of the directory that contains this file.
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
# Git handle on the repository, used to look up the commits that added the model files.
repo = Repo(PATH_TO_REPO)
# Used when the model tag on the hub does not match the folder name in `transformers/models`
# format = {folder name in `transformers/models`: expected tag on the hub}
MODEL_FOLDER_NAME_TO_TAG_MAPPING = {
    "audio_spectrogram_transformer": "audio-spectrogram-transformer",
    "bert_generation": "bert-generation",
    "blenderbot_small": "blenderbot-small",
    "blip_2": "blip-2",
    "dab_detr": "dab-detr",
    "data2vec": "data2vec-audio",  # actually, the base model is never used as a tag, but the sub models are
    "deberta_v2": "deberta-v2",
    "donut": "donut-swin",
    "encoder_decoder": "encoder-decoder",
    "grounding_dino": "grounding-dino",
    "kosmos2": "kosmos-2",
    "kosmos2_5": "kosmos-2.5",
    "megatron_bert": "megatron-bert",
    "mgp_str": "mgp-str",
    "mm_grounding_dino": "mm-grounding-dino",
    "modernbert_decoder": "modernbert-decoder",
    "nllb_moe": "nllb-moe",
    "omdet_turbo": "omdet-turbo",
    "openai": "openai-gpt",
    "roberta_prelayernorm": "roberta-prelayernorm",
    "sew_d": "sew-d",
    "speech_encoder_decoder": "speech-encoder-decoder",
    "table_transformer": "table-transformer",
    "unispeech_sat": "unispeech-sat",
    "vision_encoder_decoder": "vision-encoder-decoder",
    "vision_text_dual_encoder": "vision-text-dual-encoder",
    "wav2vec2_bert": "wav2vec2-bert",
    "wav2vec2_conformer": "wav2vec2-conformer",
    "x_clip": "xclip",
    "xlm_roberta": "xlm-roberta",
    "xlm_roberta_xl": "xlm-roberta-xl",
}
# Used on model architectures with multiple tags on the hub (e.g. on VLMs, we often support a text-only model).
# Applied after the model folder name mapping, so the keys are hub tags (e.g. `blip-2`), not folder names.
# format = {base model tag: [extra tags]}
EXTRA_TAGS_MAPPING = {
    "aimv2": ["aimv2_vision_model"],
    "aria": ["aria_text"],
    "bart": ["barthez", "bartpho"],
    "bert": ["bert-japanese", "bertweet", "herbert", "phobert"],
    "beit": ["dit"],
    "blip-2": ["blip_2_qformer"],
    "chinese_clip": ["chinese_clip_vision_model"],
    "clip": ["clip_text_model", "clip_vision_model"],
    "data2vec-audio": ["data2vec-text", "data2vec-vision"],
    "depth_anything": ["depth_anything_v2"],
    "donut-swin": ["nougat"],
    "edgetam": ["edgetam_vision_model"],
    "fastspeech2_conformer": ["fastspeech2_conformer_with_hifigan"],
    "gemma3": ["gemma3_text"],
    "gemma3n": ["gemma3n_audio", "gemma3n_text", "gemma3n_vision"],
    "gpt2": ["cpm", "dialogpt", "gpt-sw3", "megatron_gpt2"],
    "glm4v_moe": ["glm4v_moe_text", "glm4v_moe_vision"],
    "glm4_image": ["glm4_image_text", "glm4_image_vision"],
    "glm4v": ["glm4v_text", "glm4v_vision"],
    "idefics3": ["idefics3_vision"],
    "internvl": ["internvl_vision"],
    "layoutlmv2": ["layoutxlm"],
    "llama": ["code_llama", "falcon3", "llama2", "llama3"],
    "llama4": ["llama4_text"],
    "llava_next": ["granitevision"],
    "luke": ["mluke"],
    "m2m_100": ["nllb"],
    "maskformer": ["maskformer-swin"],
    "mbart": ["mbart50"],
    "parakeet": ["parakeet_ctc", "parakeet_encoder"],
    "lasr": ["lasr_ctc", "lasr_encoder"],
    "perception_lm": ["perception_encoder"],
    "pix2struct": ["deplot", "matcha"],
    "qwen2_5_vl": ["qwen2_5_vl_text"],
    "qwen2_audio": ["qwen2_audio_encoder"],
    "qwen2_vl": ["qwen2_vl_text"],
    "qwen3_vl_moe": ["qwen3_vl_moe_text"],
    "qwen3_vl": ["qwen3_vl_text"],
    # NOTE(review): the sibling entries use a `_text` suffix (e.g. `qwen3_5_moe_text`) -- confirm that
    # `qwen3_5text` is the actual hub tag and not a typo for `qwen3_5_text`.
    "qwen3_5": ["qwen3_5text"],
    "qwen3_5_moe": ["qwen3_5_moe_text"],
    "rt_detr": ["rt_detr_resnet"],
    "sam2": ["sam2_hiera_det_model", "sam2_vision_model"],
    "sam": ["sam_hq_vision_model", "sam_vision_model"],
    "siglip2": ["siglip2_vision_model"],
    "siglip": ["siglip_vision_model"],
    "smolvlm": ["smolvlm_vision"],
    "t5": ["byt5", "flan-t5", "flan-ul2", "madlad-400", "myt5", "t5v1.1", "ul2"],
    "voxtral": ["voxtral_encoder"],
    "wav2vec2": ["mms", "wav2vec2_phoneme", "xls_r", "xlsr_wav2vec2"],
    "xlm-roberta": ["xlm-v"],
}
# Similar to `DEPRECATED_MODELS`, but containing the hub tags of the deprecated models whose tag does not match the
# model folder name :'(
DEPRECATED_MODELS_TAGS = {"gptsan-japanese", "open-llama", "transfo-xl", "xlm-prophetnet"}
class HubModelLister:
    """
    Iterable over the hub models matching some tags. An error raised while iterating is printed and ends the
    iteration instead of crashing the script.
    """

    def __init__(self, tags):
        """Store `tags` and request the listing of the models filtered by them."""
        self.model_list = api.list_models(filter=tags)
        self.tags = tags

    def __iter__(self):
        """Yield the listed models one by one; on any `Exception`, print it and stop."""
        try:
            yield from self.model_list
        except Exception as e:
            # Best effort: report the error and end the iteration with the models obtained so far.
            print(f"Error: {e}")
def _extract_commit_hash(commits):
    """Return the hash found on the first line starting with `commit `, or an empty string if there is none.

    `commits` is an iterable of text lines (here, the lines of a `git log` output).
    """
    hashes = (line.split(" ")[1] for line in commits if line.startswith("commit "))
    return next(hashes, "")
def get_list_of_repo_model_paths(models_dir):
    """Return the paths of the `modeling_*.py` files located one folder below `models_dir`, after filtering.

    A path is dropped when it contains `/<name>/` for any entry `<name>` of `models_dir/deprecated` (this also
    removes the symlink paths of the deprecated models), when it contains `/deprecated`, or when it contains
    `/auto/`.
    """
    modeling_files = glob.glob(os.path.join(models_dir, "*/modeling_*.py"))
    deprecated_entries = glob.glob(os.path.join(models_dir, "deprecated", "*"))

    # Substrings that disqualify a path: one per deprecated model, then the deprecated folder and auto.
    excluded_markers = ["/" + entry.split("/")[-1] + "/" for entry in deprecated_entries]
    excluded_markers += ["/deprecated", "/auto/"]

    return [path for path in modeling_files if not any(marker in path for marker in excluded_markers)]
def get_list_of_models_to_deprecate(
    thresh_num_downloads=5_000,
    thresh_date=None,
    use_cache=False,
    save_model_info=False,
    max_num_models=-1,
):
    """
    Print and return the models that are candidates for deprecation.

    Unless a cache is used, the info of each model is built from the first commit that added its modeling file,
    models first committed on or after `thresh_date` are dropped, and the downloads of the public hub models
    carrying one of the model tags are summed. The summation stops as soon as the threshold is exceeded, so the
    stored count is a lower bound for popular models. Models below `thresh_num_downloads` are reported.

    Args:
        thresh_num_downloads (`int`, *optional*, defaults to 5000):
            Models with fewer downloads than this are reported.
        thresh_date (`str`, *optional*):
            `YYYY-MM-DD` date; only models first committed before it are considered. Defaults to one year before
            now (UTC).
        use_cache (`bool`, *optional*, defaults to `False`):
            Load the model info from `models_info.json` (when that file exists) instead of inspecting the repo and
            calling the hub. Cached entries are used as stored: no date filtering is applied to them.
        save_model_info (`bool`, *optional*, defaults to `False`):
            Write the model info to `models_info.json`. Skipped when the info was loaded from the cache.
        max_num_models (`int`, *optional*, defaults to -1):
            Maximum number of modeling files to inspect; -1 means all of them.

    Returns:
        `list[tuple[str, dict]]`: `(model, info)` pairs of the models to deprecate, lowest downloads first.

    Raises:
        ValueError: If all models are inspected and their tags do not match the keys of `CONFIG_MAPPING_NAMES`
            minus the deprecated models.
    """
    if thresh_date is None:
        now = datetime.now(timezone.utc)
        try:
            thresh_date = now.replace(year=now.year - 1)
        except ValueError:
            # Feb 29th does not exist in the previous year: fall back to Feb 28th
            thresh_date = now.replace(year=now.year - 1, day=28)
    else:
        thresh_date = datetime.strptime(thresh_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)

    models_dir = PATH_TO_REPO / "src/transformers/models"
    model_paths = get_list_of_repo_model_paths(models_dir=models_dir)

    loaded_from_cache = use_cache and os.path.exists("models_info.json")
    if loaded_from_cache:
        with open("models_info.json", "r") as f:
            models_info = json.load(f)
        # Convert datetimes back to datetime objects
        for model, info in models_info.items():
            info["first_commit_datetime"] = datetime.fromisoformat(info["first_commit_datetime"])
    else:
        print("Building a dictionary of basic model info...")
        models_info = defaultdict(dict)
        for i, model_path in enumerate(tqdm(sorted(model_paths))):
            # `i` is 0-based, so `>=` inspects exactly `max_num_models` paths
            if max_num_models != -1 and i >= max_num_models:
                break
            model = model_path.split("/")[-2]
            if model in models_info:
                continue
            # `--diff-filter=A` restricts the log to the commits that added the file
            commits = repo.git.log("--diff-filter=A", "--", model_path).split("\n")
            commit_hash = _extract_commit_hash(commits)
            commit_obj = repo.commit(commit_hash)
            committed_datetime = commit_obj.committed_datetime
            models_info[model]["commit_hash"] = commit_hash
            models_info[model]["first_commit_datetime"] = committed_datetime
            models_info[model]["model_path"] = model_path
            models_info[model]["downloads"] = 0
            models_info[model]["tags"] = [model]

        # The keys in the dictionary above are the model folder names. In some cases, the model tag on the hub does
        # not match the model folder name. We replace the key and append the expected tag.
        for folder_name, expected_tag in MODEL_FOLDER_NAME_TO_TAG_MAPPING.items():
            if folder_name in models_info:
                models_info[expected_tag] = models_info[folder_name]
                models_info[expected_tag]["tags"] = [expected_tag]
                del models_info[folder_name]

        # Some models have multiple tags on the hub. We add the expected tag to the list of tags.
        for model_name, extra_tags in EXTRA_TAGS_MAPPING.items():
            if model_name in models_info:
                models_info[model_name]["tags"].extend(extra_tags)

        # Sanity check for the case with all models: the model tags must match the keys in the CONFIG_MAPPING_NAMES
        # (= actual model tags on the hub)
        if max_num_models == -1:
            all_model_tags = set()
            for model_name in models_info:
                all_model_tags.update(models_info[model_name]["tags"])

            non_deprecated_model_tags = (
                set(CONFIG_MAPPING_NAMES.keys()) - set(DEPRECATED_MODELS_TAGS) - set(DEPRECATED_MODELS)
            )
            if all_model_tags != non_deprecated_model_tags:
                raise ValueError(
                    "The tags of the `models_info` dictionary must match the keys in the `CONFIG_MAPPING_NAMES`!"
                    "\nMissing tags in `model_info`: "
                    + str(sorted(non_deprecated_model_tags - all_model_tags))
                    + "\nExtra tags in `model_info`: "
                    + str(sorted(all_model_tags - non_deprecated_model_tags))
                    + "\n\nYou need to update one or more of the following: `CONFIG_MAPPING_NAMES`, "
                    "`EXTRA_TAGS_MAPPING` or `DEPRECATED_MODELS_TAGS`."
                )

        # Filter out models which were added less than a year ago
        models_info = {
            model: info for model, info in models_info.items() if info["first_commit_datetime"] < thresh_date
        }

        # We make successive calls to the hub, filtering based on the model tags
        print("Making calls to the hub to find models below the threshold number of downloads...")
        num_models = len(models_info)
        for i, (model, model_info) in enumerate(models_info.items()):
            print(f"{i + 1}/{num_models}: getting hub downloads for model='{model}' (tags={model_info['tags']})")
            for model_tag in model_info["tags"]:
                if model_info["downloads"] > thresh_num_downloads:
                    break
                model_list = HubModelLister(tags=model_tag)
                for hub_model in model_list:
                    if hub_model.private:
                        continue
                    model_info["downloads"] += hub_model.downloads
                    # No need to make further hub calls, it's above the set threshold
                    if model_info["downloads"] > thresh_num_downloads:
                        break

    if save_model_info and not loaded_from_cache:
        # Serialize a copy so that `models_info` keeps real datetime objects for the report and the return value
        serializable_info = {
            model: {**info, "first_commit_datetime": info["first_commit_datetime"].isoformat()}
            for model, info in models_info.items()
        }
        with open("models_info.json", "w") as f:
            json.dump(serializable_info, f, indent=4)

    print("\nFinding models to deprecate:")
    n_models_to_deprecate = 0
    models_to_deprecate = {}
    for model, info in models_info.items():
        n_downloads = info["downloads"]
        if n_downloads < thresh_num_downloads:
            n_models_to_deprecate += 1
            models_to_deprecate[model] = info
            print(f"\nModel: {model}")
            print(f"Downloads: {n_downloads}")
            print(f"Date: {info['first_commit_datetime']}")

    # sort models to deprecate by downloads (lowest downloads first)
    models_to_deprecate = sorted(models_to_deprecate.items(), key=lambda x: x[1]["downloads"])

    print("\nModels to deprecate: ", "\n" + "\n".join([model[0] for model in models_to_deprecate]))
    print(f"\nNumber of models to deprecate: {n_models_to_deprecate}")
    print("Before deprecating make sure to verify the models, including if they're used as a module in other models.")
    # The caller binds the result, so hand back the sorted list instead of an implicit `None`
    return models_to_deprecate
if __name__ == "__main__":
    # Command-line entry point: every option is forwarded to `get_list_of_models_to_deprecate` under the same name.
    cli_options = {
        "--save_model_info": {"action": "store_true", "help": "Save the retrieved model info to a json file."},
        "--use_cache": {
            "action": "store_true",
            "help": "Use the cached model info instead of calling the hub.",
        },
        "--thresh_num_downloads": {
            "type": int,
            "default": 5_000,
            "help": (
                "Threshold number of downloads below which a model should be deprecated. Default is 5,000. If you are "
                "considering a sweep and using a cache, set this to the highest number of the sweep."
            ),
        },
        "--thresh_date": {
            "type": str,
            "default": None,
            "help": (
                "Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from "
                "today."
            ),
        },
        "--max_num_models": {
            "type": int,
            "default": -1,
            "help": "Maximum number of models architectures to consider. -1 means all models. Useful for testing.",
        },
    }

    parser = argparse.ArgumentParser()
    # Dicts keep insertion order, so the options are registered in the order listed above
    for flag, spec in cli_options.items():
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    models_to_deprecate = get_list_of_models_to_deprecate(
        thresh_num_downloads=args.thresh_num_downloads,
        thresh_date=args.thresh_date,
        use_cache=args.use_cache,
        save_model_info=args.save_model_info,
        max_num_models=args.max_num_models,
    )

View File

@@ -0,0 +1,184 @@
import os
import libcst as cst
# Files from external libraries that should not be tracked
# E.g. for habana, we don't want to track the dependencies from `modeling_all_models.py` as it is not part of the transformers library
# Layout: library name -> list of `{"name": ..., "type": ...}` entries, where "name" is the file name without its
# `.py` extension (presumably "type" is the kind of file, e.g. "modeling" — confirm against the consumer).
EXCLUDED_EXTERNAL_FILES = {
    "habana": [{"name": "modeling_all_models", "type": "modeling"}],
}
def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Turn a relative `from ... import ...` node into the equivalent absolute import.

    The absolute module path is rebuilt from the location of the importing file: the path components found after
    `package_name`, minus one component per leading dot, followed by the module named in the import itself.

    Args:
        import_node: The import node to convert (e.g. `from ..utils import helper`). Nodes without leading dots are
            returned untouched.
        file_path: Path to the file containing the import (can be absolute or relative).
        package_name: The top-level package name (e.g. 'myproject').

    Returns:
        A new ImportFrom node with the absolute import path.

    Raises:
        ValueError: If `package_name` is not a component of `file_path`, or if the import has more leading dots than
            there are path components below the package.
    """
    dots = import_node.relative
    if not dots:
        # Nothing to do for an absolute import
        return import_node

    file_path = os.path.abspath(file_path)
    rel_level = len(dots)

    # Path components of the file, with the `.py` extension removed from the last one
    file_parts = file_path.removesuffix(".py").split(os.path.sep)
    if package_name not in file_parts:
        raise ValueError(f"Package name '{package_name}' not found in file path '{file_path}'")

    # Everything after the first occurrence of the package name, e.g. ['module', 'submodule', 'foo']
    pkg_index = file_parts.index(package_name)
    module_parts = file_parts[pkg_index + 1 :]
    if rel_level > len(module_parts):
        raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")

    # Each dot removes one trailing component; the first dot removes the file's own module name
    base_parts = module_parts[:-rel_level]

    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
        """Return the dotted components of a `Name`/`Attribute` chain, or `[]` for anything else."""
        if not module or not isinstance(module, (cst.Name, cst.Attribute)):
            return []
        names = []
        current = module
        while isinstance(current, cst.Attribute):
            names.append(current.attr.value)
            current = current.value
        if isinstance(current, cst.Name):
            names.append(current.value)
        # Components were collected from the innermost attribute outwards
        names.reverse()
        return names

    full_parts = [package_name, *base_parts, *flatten_module(import_node.module)]

    # Handle special case where the import comes from a namespace package (e.g. optimum with `optimum.habana`,
    # `optimum.intel` instead of `src.optimum`): prepend the parent directory name
    parent_dir = file_parts[pkg_index - 1]
    if package_name != "transformers" and parent_dir != "src":
        full_parts.insert(0, parent_dir)

    # Fold the components into a left-nested Attribute chain (a.b.c -> Attribute(Attribute(a, b), c))
    dotted_module: cst.BaseExpression = cst.Name(full_parts[0])
    for part in full_parts[1:]:
        dotted_module = cst.Attribute(value=dotted_module, attr=cst.Name(part))

    return import_node.with_changes(module=dotted_module, relative=[])
def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    An import belongs to the package when its module path is `package_name`, starts with `package_name.`, or starts
    with `optimum.package_name.`. The directories shared by the importing file and the imported module are factored
    out, so the result uses as few leading dots as possible.

    Parameters:
    - import_node: The ImportFrom node to possibly transform.
    - file_path: Path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
    - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
    - A possibly modified ImportFrom node. The node is returned unchanged when it is already relative, when it does
      not belong to the package, or when `package_name` is not a component of `file_path`.
    """
    if import_node.relative:
        return import_node  # Already relative import

    def get_module_parts(module) -> list[str]:
        """Return the dotted components of a `Name`/`Attribute` chain, or `[]` for anything else."""
        parts = []
        while isinstance(module, cst.Attribute):
            parts.append(module.attr.value)
            module = module.value
        if isinstance(module, cst.Name):
            parts.append(module.value)
        parts.reverse()
        return parts

    module_parts = get_module_parts(import_node.module)
    module_name = ".".join(module_parts)

    # Work out how many leading components form the package prefix; this depends on whether the import goes
    # through the `optimum.` namespace, so it must not be a fixed offset
    namespaced_prefix = "optimum." + package_name
    if module_name.startswith(namespaced_prefix + "."):
        prefix_len = len(namespaced_prefix.split("."))
    elif module_name == package_name or module_name.startswith(package_name + "."):
        prefix_len = len(package_name.split("."))
    else:
        return import_node  # Not from target package

    # Locate the package root inside the file path
    path_parts = os.path.normpath(file_path).split(os.sep)
    try:
        pkg_index = path_parts.index(package_name)
    except ValueError:
        # Package name not found in path — assume we can't resolve relative depth
        return import_node

    path_inside_package = path_parts[pkg_index + 1 :]  # directories, then the file itself
    module_inside_package = module_parts[prefix_len:]

    # One dot per path component below the package root: the first dot designates the file's own directory
    depth = len(path_inside_package)

    # Count the leading directories shared by the file and the imported module. `zip` stops at the shorter
    # sequence, and the file name itself is excluded from the comparison.
    shared = 0
    for directory, submodule in zip(path_inside_package[:-1], module_inside_package):
        if directory != submodule:
            break
        shared += 1
    depth -= shared

    # Create the correct number of dots, each one a distinct node
    relative = [cst.Dot() for _ in range(depth)] if depth > 0 else [cst.Dot()]

    # Build the new module node from what is left of the module path; `None` yields `from .. import x`
    new_module = None
    for part in module_inside_package[shared:]:
        name = cst.Name(part)
        new_module = name if new_module is None else cst.Attribute(value=new_module, attr=name)

    return import_node.with_changes(module=new_module, relative=relative)
class AbsoluteImportTransformer(cst.CSTTransformer):
    """CST transformer that passes each `from ... import` node through `convert_relative_import_to_absolute`."""

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        # Package name and file location used to resolve the imports
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        return convert_relative_import_to_absolute(updated_node, self.relative_path, self.source_library)
class RelativeImportTransformer(cst.CSTTransformer):
    """CST transformer that passes each `from ... import` node through `convert_to_relative_import`."""

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        # Package name and file location used to resolve the imports
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        return convert_to_relative_import(
            import_node=updated_node, file_path=self.relative_path, package_name=self.source_library
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,913 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 🔴🔴🔴 THIS IS AN INTERNAL TOOL. It WILL interact with the hub and use significant local compute resources. Use at your own risk.
"""
Modular model detector: utilities for detecting code similarities between model implementations.
This module provides tools to analyze and detect similarities between different model implementations
in the transformers library. It uses both embedding-based and token-based (Jaccard) similarity metrics
to identify similar code patterns across different model definitions.
Its purpose is to identify which models can be _modular_-ized, i.e. which classes that already exist in the
codebase look very similar to the ones in the file being analyzed.
Two scores are computed, one is a code embedding, and the other is a simple Jaccard bag-of-tokens index for overlap
of token sets. A score of 1.00 means the code is identical.
Usage:
```bash
cd transformers
# Run the util directly; it will download the embedding index from the Hub. This requires some RAM/VRAM.
>>> python utils/modular_model_detector.py --modeling-file my_new_beit3_modeling_file.py
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.62it/s]
encoding 21 query definitions with Qwen/Qwen3-Embedding-4B (device=cuda, batch=16, max_length=4096)
stuff.py::Beit3ImageTextMatchingOutput:
embedding:
blip_2::Blip2ImageTextMatchingModelOutput (0.9994)
chinese_clip::ChineseCLIPOutput (0.9818)
owlvit::OwlViTOutput (0.9818)
aimv2::Aimv2Output (0.9818)
blip::BlipOutput (0.9818)
jaccard:
owlv2::Owlv2Output (0.9667)
metaclip_2::MetaClip2Output (0.9667)
altclip::AltCLIPOutput (0.9667)
owlvit::OwlViTOutput (0.9667)
blip::BlipOutput (0.9667)
intersection:
blip::BlipOutput
owlvit::OwlViTOutput
stuff.py::Beit3MLP:
embedding:
efficientloftr::EfficientLoFTRMLP (0.9718)
seggpt::SegGptMlp (0.9650)
mgp_str::MgpstrMlp (0.9646)
vitpose_backbone::VitPoseBackboneMLP (0.9640)
granitemoeshared::GraniteMoeSharedMLP (0.9633)
jaccard:
chinese_clip::ChineseCLIPTextSelfOutput (0.5294)
convbert::ConvBertSelfOutput (0.5294)
bert::BertSelfOutput (0.5294)
roformer::RoFormerSelfOutput (0.5294)
layoutlmv3::LayoutLMv3SelfOutput (0.5294)
intersection:
stuff.py::Beit3FeedForwardNetwork:
embedding:
prophetnet::ProphetNetFeedForward (0.9766)
dab_detr::DabDetrDecoderLayerFFN (0.9730)
kosmos2::Kosmos2TextFFN (0.9697)
kosmos2_5::Kosmos2_5TextFFN (0.9697)
parakeet::ParakeetEncoderFeedForward (0.9678)
jaccard:
groupvit::GroupViTMLP (0.4898)
convbert::ConvBertOutput (0.4600)
chinese_clip::ChineseCLIPTextOutput (0.4565)
bert::BertOutput (0.4565)
roformer::RoFormerOutput (0.4565)
intersection:
```
# If you wish to build the index first, you can run
python utils/modular_model_detector.py --build
# You can also change the embedding model for a larger/smaller one.
"""
import argparse
import ast
import json
import logging
import os
import re
from datetime import datetime
from functools import cache
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import logging as huggingface_hub_logging
from safetensors.numpy import load_file as safetensors_load
from safetensors.numpy import save_file as safetensors_save
from tqdm import tqdm
import transformers
from transformers import AutoModel, AutoTokenizer
from transformers.utils import enable_tf32
from transformers.utils import logging as transformers_logging
# ANSI color codes for CLI output styling
ANSI_RESET = "\033[0m"
ANSI_BOLD = "\033[1m"
ANSI_HEADER = "\033[1;36m"
ANSI_SECTION = "\033[1;35m"
ANSI_ROW = "\033[0;37m"
ANSI_HIGHLIGHT_TOP = "\033[1;32m"
ANSI_HIGHLIGHT_OLD = "\033[1;33m"
ANSI_HIGHLIGHT_CANDIDATE = "\033[1;34m"
# Environment switches for the Hub progress bars and the transformers log verbosity.
# NOTE(review): these are assigned after the imports above — confirm both libraries read them lazily.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
# Folder scanned for `modeling_*.py` files; the path is relative, so it is resolved against the working directory
MODELS_ROOT = Path("src/transformers/models")
# File names of the three index artifacts: the embedding matrix, the row index -> identifier map, and the
# identifier -> tokens map
EMBEDDINGS_PATH = "embeddings.safetensors"
INDEX_MAP_PATH = "code_index_map.json"
TOKENS_PATH = "code_index_tokens.json"
# Hub dataset repository holding the index (presumably the default of a CLI option — confirm against the caller)
HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings"
# Checkpoint loaded for both the tokenizer and the embedding model
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
# Number of texts encoded per forward pass in `CodeSimilarityAnalyzer.encode`
BATCH_SIZE = 16
# `max_length` given to the tokenizer (with truncation enabled) in `CodeSimilarityAnalyzer._encode_batch`
MAX_LENGTH = 4096
def _normalize(string: str | None) -> str:
"""
Normalize a string by removing all non-alphanumeric characters and converting to lowercase.
Args:
string (`str` or `None`): The string to normalize.
Returns:
`str`: The normalized string, or empty string if input is None.
"""
return re.sub(r"[^a-z0-9]+", "", string.lower()) if string else ""
def _strip_source_for_tokens(code: str) -> str:
"""
Strip docstrings, comments, and import statements from source code.
Args:
code (`str`): The source code to strip.
Returns:
`str`: The stripped source code.
"""
code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)
code = re.sub(r"#.*", "", code)
return "\n".join(line for line in code.splitlines() if not re.match(r"\s*(from|import)\s+", line))
def _tokenize(code: str) -> set[str]:
"""
Extract all Python identifiers from source code.
Args:
code (`str`): The source code to tokenize.
Returns:
`set[str]`: A set of all identifiers found in the code.
"""
return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code))
def _leading_symbol_prefix(name: str) -> str:
"""
Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention').
Args:
name (`str`): The symbol name to extract prefix from.
Returns:
`str`: The leading prefix, or empty string if no match.
"""
match = re.match(r"^([A-Z][a-z0-9]+)", name) or re.match(r"^([A-Za-z0-9]+)", name)
return match.group(1) if match else ""
def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
    """
    Prepare code for embedding: strip it, then replace model-specific names with the placeholder 'Model'.

    The names to replace are derived from `model_hint` and from the leading prefix of `symbol_hint`: each one is
    used as is, without underscores, and without digits, plus the lowercase form of all of these. Only variants of
    at least 3 characters are replaced, longest first, ignoring case.

    Args:
        code (`str`): The source code to sanitize.
        model_hint (`str` or `None`): Hint about the model name (e.g., 'llama').
        symbol_hint (`str` or `None`): Hint about the symbol name (e.g., 'LlamaAttention').

    Returns:
        `str`: The sanitized code with model-specific identifiers replaced by 'Model'.
    """
    hints = [model_hint] if model_hint else []
    if symbol_hint:
        prefix = _leading_symbol_prefix(symbol_hint)
        if prefix:
            hints.append(prefix)

    variants = set()
    for hint in hints:
        variants.update((hint, hint.replace("_", ""), re.sub(r"\d+", "", hint)))
    variants |= {variant.lower() for variant in variants}

    # Longest variants first, so that a short variant cannot break up a longer one before it is replaced
    ordered_variants = sorted({variant for variant in variants if len(variant) >= 3}, key=len, reverse=True)

    sanitized = _strip_source_for_tokens(code)
    for variant in ordered_variants:
        sanitized = re.sub(re.escape(variant), "Model", sanitized, flags=re.IGNORECASE)
    return sanitized
class CodeSimilarityAnalyzer:
"""
Analyzer for detecting code similarities between model implementations.
This class uses embedding-based and token-based similarity metrics to identify similar
code patterns across different model definitions in the transformers library.
Args:
hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index.
"""
def __init__(self, hub_dataset: str):
for name in ("huggingface_hub", "httpx", "urllib3", "transformers"):
logging.getLogger(name).setLevel(logging.ERROR)
huggingface_hub_logging.set_verbosity_error()
transformers_logging.set_verbosity_error()
enable_tf32(True)
torch.set_grad_enabled(False)
self.models_root = MODELS_ROOT
self.hub_dataset = hub_dataset
self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval()
self.device = self.model.device
self.index_dir: Path | None = None
# ---------- HUB IO ----------
def _resolve_index_path(self, filename: str) -> Path:
if self.index_dir is None:
return Path(filename)
return self.index_dir / filename
def ensure_local_index(self) -> None:
"""Ensure index files are available locally, preferring Hub cache snapshots."""
if self.index_dir is not None and all(
(self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)
):
return
workspace_dir = Path.cwd()
if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)):
self.index_dir = workspace_dir
return
logging.info(f"downloading index from hub cache: {self.hub_dataset}")
snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset")
snapshot_dir = Path(snapshot_path)
missing = [
fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists()
]
if missing:
raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing))
self.index_dir = snapshot_dir
def push_index_to_hub(self) -> None:
"""Upload index files to the Hub dataset repository."""
api = HfApi()
api.create_repo(repo_id=self.hub_dataset, repo_type="dataset", exist_ok=True)
for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH):
logging.info(f"pushing {fname} -> {self.hub_dataset}")
api.upload_file(
path_or_fileobj=fname,
path_in_repo=os.path.basename(fname),
repo_id=self.hub_dataset,
repo_type="dataset",
)
# ---------- parsing & encoding ----------
def _extract_definitions(
self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None
) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]:
"""
Extract class and function definitions from a Python file.
Args:
file_path (`Path`): Path to the Python file to parse.
relative_to (`Path` or `None`): Base path for computing relative identifiers.
model_hint (`str` or `None`): Model name hint for sanitization.
Returns:
`tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing:
- definitions_raw: Mapping of identifiers to raw source code
- definitions_sanitized: Mapping of identifiers to sanitized source code
- definitions_tokens: Mapping of identifiers to sorted token lists
- definitions_kind: Mapping of identifiers to either "class" or "function"
"""
definitions_raw = {}
definitions_sanitized = {}
definitions_tokens = {}
definitions_kind = {}
source = file_path.read_text(encoding="utf-8")
lines = source.splitlines()
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
segment = ast.get_source_segment(source, node)
if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
start = max(0, node.lineno - 1)
end = node.end_lineno
segment = "\n".join(lines[start:end])
if segment:
identifier = (
f"{file_path.relative_to(relative_to)}:{node.name}"
if relative_to
else f"{file_path.name}:{node.name}"
)
definitions_raw[identifier] = segment
sanitized = _sanitize_for_embedding(segment, model_hint, node.name)
definitions_sanitized[identifier] = sanitized
definitions_tokens[identifier] = sorted(_tokenize(sanitized))
if isinstance(node, ast.ClassDef):
definitions_kind[identifier] = "class"
else:
definitions_kind[identifier] = "function"
return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind
def _infer_model_from_relative_path(self, relative_path: Path) -> str | None:
try:
relative = relative_path.resolve().relative_to(self.models_root.resolve())
return relative.parts[0]
except Exception:
return None
def _infer_query_model_name(self, modeling_file: Path) -> str | None:
model = self._infer_model_from_relative_path(modeling_file)
if model:
return model
stem = modeling_file.stem
if stem.startswith("modeling_") and len(stem) > len("modeling_"):
return stem[len("modeling_") :]
return None
def _encode_batch(self, texts: list[str]) -> np.ndarray:
"""
Encode a batch of texts into normalized embeddings.
Args:
texts (`list[str]`): List of text strings to encode.
Returns:
`np.ndarray`: Normalized embeddings as a float32 numpy array.
"""
encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
encoded = {key: value.to(self.device) for key, value in encoded.items()}
with (
torch.autocast(device_type=self.device.type, dtype=self.dtype)
if self.device.type == "cuda"
else torch.no_grad()
):
output = self.model(**encoded)
if hasattr(output, "last_hidden_state"):
embeddings = output.last_hidden_state
mask = encoded["attention_mask"].unsqueeze(-1)
embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
elif hasattr(output, "pooler_output"):
embeddings = output.pooler_output
else:
embeddings = output[0].mean(dim=1)
embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
return embeddings.cpu().numpy().astype("float32")
def encode(self, texts: list[str]) -> np.ndarray:
"""
Encode a list of texts into embeddings, processing in batches.
Args:
texts (`list[str]`): List of text strings to encode.
Returns:
`np.ndarray`: Stacked embeddings for all texts.
"""
output = []
for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="encode", leave=False):
output.append(self._encode_batch(texts[i : i + BATCH_SIZE]))
if self.device.type == "cuda":
torch.cuda.empty_cache()
return np.vstack(output) if output else np.zeros((0, 0), dtype="float32")
# ---------- build & search ----------
def build_index(self) -> None:
"""Build the code similarity index from all modeling files and save to disk."""
logging.info("collecting files")
files = list(self.models_root.rglob("modeling_*.py"))
logging.info(f"parsing {len(files)} files")
identifiers = []
sanitized_sources = []
tokens_map = {}
for file_path in tqdm(files, desc="parse", leave=False):
model_hint = self._infer_model_from_relative_path(file_path)
(
_,
definitions_sanitized,
definitions_tokens,
_,
) = self._extract_definitions(file_path, self.models_root, model_hint)
for identifier in definitions_sanitized.keys():
identifiers.append(identifier)
sanitized_sources.append(definitions_sanitized[identifier])
tokens_map[identifier] = definitions_tokens[identifier]
logging.info(
f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})"
)
embeddings = self.encode(sanitized_sources)
safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH)
with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file:
json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file)
with open(TOKENS_PATH, "w", encoding="utf-8") as file:
json.dump(tokens_map, file)
self.index_dir = Path.cwd()
def _topk_embedding(
self,
query_embedding_row: np.ndarray,
base_embeddings: np.ndarray,
identifier_map: dict[int, str],
self_model_normalized: str,
self_name: str,
k: int,
) -> list[tuple[str, float]]:
similarities = query_embedding_row @ base_embeddings.T
indices = np.argpartition(-similarities, k + 32)[: k + 32]
indices = indices[np.argsort(-similarities[indices])]
output = []
for match_id in indices:
identifier = identifier_map[int(match_id)]
parent_relative_path, match_name = identifier.split(":", 1)
parent_model = Path(parent_relative_path).parts[0]
if match_name == self_name:
continue
if self_model_normalized and _normalize(parent_model) == self_model_normalized:
continue
output.append((identifier, float(similarities[match_id])))
if len(output) >= k:
break
return output
def _topk_jaccard(
self,
query_tokens: set[str],
identifiers: list[str],
tokens_map: dict[str, list[str]],
self_model_normalized: str,
self_name: str,
k: int,
) -> list[tuple[str, float]]:
"""
Find top-k most similar definitions using Jaccard similarity on token sets.
Args:
query_tokens (`set[str]`): Set of tokens from the query definition.
identifiers (`list[str]`): List of all definition identifiers in the index.
tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists.
self_model_normalized (`str`): Normalized name of the query model to exclude.
self_name (`str`): Name of the query definition to exclude.
k (`int`): Number of top results to return.
Returns:
`list[tuple[str, float]]`: List of (identifier, score) tuples.
"""
scores = []
for identifier in identifiers:
parent_relative_path, match_name = identifier.split(":", 1)
parent_model = Path(parent_relative_path).parts[0]
if match_name == self_name:
continue
if self_model_normalized and _normalize(parent_model) == self_model_normalized:
continue
tokens = set(tokens_map.get(identifier, []))
if not tokens or not query_tokens:
continue
score = len(query_tokens & tokens) / len(query_tokens | tokens)
if score > 0:
scores.append((identifier, score))
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:k]
def analyze_file(
self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard=False
) -> dict[str, dict[str, list]]:
"""
Analyze a modeling file and find similar code definitions in the index.
Args:
modeling_file (`Path`): Path to the modeling file to analyze.
top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition.
allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally.
Returns:
`dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results.
Each result contains 'embedding', 'jaccard', and 'intersection' keys.
"""
if allow_hub_fallback:
self.ensure_local_index()
base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH)))
base_embeddings = base["embeddings"]
with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file:
identifier_map = {int(key): value for key, value in json.load(file).items()}
identifiers = [identifier_map[i] for i in range(len(identifier_map))]
with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file:
tokens_map = json.load(file)
self_model = self._infer_query_model_name(modeling_file)
definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions(
modeling_file, None, self_model
)
query_identifiers = list(definitions_raw.keys())
query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
query_tokens_list = [set(_tokenize(source)) for source in query_sources_sanitized]
self_model_normalized = _normalize(self_model)
logging.info(
f"encoding {len(query_sources_sanitized)} query definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})"
)
query_embeddings = self.encode(query_sources_sanitized)
output = {}
for i, query_identifier in enumerate(query_identifiers):
query_name = query_identifier.split(":")[-1]
embedding_top = self._topk_embedding(
query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item
)
embedding_set = {identifier for identifier, _ in embedding_top}
kind = definitions_kind.get(query_identifier, "function")
entry = {"kind": kind, "embedding": embedding_top}
if use_jaccard:
jaccard_top = self._topk_jaccard(
query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
)
jaccard_set = {identifier for identifier, _ in jaccard_top}
intersection = set(embedding_set & jaccard_set)
entry.update({"jaccard": jaccard_top, "intersection": intersection})
output[query_name] = entry
return output
# Matches "this model was released on YYYY-MM-DD" / "the model was released on YYYY-MM-DD", ignoring case. The
# phrase must sit at the very start of the text or right after whitespace or one of the characters `*`, `_`,
# backtick, `>`. Group 1 captures the date.
_RELEASE_RE = re.compile(
    r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE
)
def build_date_data() -> dict[str, str]:
    """
    Scan the Markdown model docs and build {model_id: date_released}.

    The docs directory is `docs/source/en/model_doc`, resolved relative to the part of
    `transformers.__file__` that precedes "src/transformers".

    - model_id is the filename without extension (e.g., "llama" for "llama.md")
    - date_released is the first YYYY-MM-DD matched after "...was released on ..."
    - Only `*.md` entries are considered; entries that cannot be read are logged and skipped.

    Returns:
        dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD).
        Files without a match are simply omitted.
    """
    root_dir = transformers.__file__.split("src/transformers")[0]
    root = Path(root_dir).joinpath("docs/source/en/model_doc")
    result: dict[str, str] = {}

    for md_path in root.glob("*.md"):
        try:
            text = md_path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            # Skip unreadable entries; without the `continue` the search below would run on an
            # unbound name or on the previous file's text and record a wrong date.
            logging.info(f"Failed to read md for {md_path}")
            continue

        m = _RELEASE_RE.search(text)
        if m:
            model_id = md_path.stem  # e.g., "llama" from "llama.md"
            result[model_id] = m.group(1)

    return result
def _format_table(headers: list[str], rows: list[tuple[str, ...] | None], row_styles: list[str] | None = None) -> str:
    """Render `headers` and `rows` as an ANSI-styled, column-aligned text table.

    A `None` entry in `rows` is drawn as a full-width dashed separator. `row_styles` is
    consumed positionally by the non-`None` rows only; a missing or empty style falls back
    to `ANSI_ROW`. When `rows` is empty a "(no matches)" placeholder is returned instead.
    """
    if not rows:
        return f"{ANSI_ROW}(no matches){ANSI_RESET}"

    # Each column is as wide as its widest cell (header included).
    column_widths = [len(title) for title in headers]
    for cells in rows:
        if cells is not None:
            for col, text in enumerate(cells):
                if len(text) > column_widths[col]:
                    column_widths[col] = len(text)

    header_line = " | ".join(title.ljust(width) for title, width in zip(headers, column_widths))
    divider = "-+-".join("-" * width for width in column_widths)
    # 3 characters of " | " padding sit between each pair of adjacent columns.
    total_width = sum(column_widths) + 3 * (len(headers) - 1)
    separator = f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}"

    lines = [f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider]
    next_style = 0
    for cells in rows:
        if cells is None:
            lines.append(separator)
            continue
        body = " | ".join(text.ljust(column_widths[col]) for col, text in enumerate(cells))
        override = None
        if row_styles and next_style < len(row_styles):
            override = row_styles[next_style]
        lines.append(f"{override or ANSI_ROW}{body}{ANSI_RESET}")
        next_style += 1

    return "\n".join(lines)
def _parse_release_date(value: str) -> datetime | None:
"""Return a datetime parsed from YYYY-MM-DD strings, otherwise None."""
try:
return datetime.strptime(value, "%Y-%m-%d")
except (TypeError, ValueError):
return None
@cache
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
    """Return {definition_name: line_number} for top-level definitions in the given file.

    `relative_path` is resolved against `MODELS_ROOT`. Only top-level functions, async
    functions and classes are recorded; if a name is defined more than once, the last
    definition wins. Results are memoised per path by `@cache`.

    Returns:
        The name -> line mapping, or an empty dict when the file cannot be read, decoded
        or parsed, so callers can keep going without location information.
    """
    file_path = MODELS_ROOT / relative_path
    try:
        tree = ast.parse(file_path.read_text(encoding="utf-8"))
    except (OSError, ValueError, SyntaxError):
        # OSError: missing/unreadable file. ValueError: covers UnicodeDecodeError for
        # non-UTF-8 content (and null bytes rejected by older parsers). SyntaxError: invalid code.
        return {}

    return {
        node.name: getattr(node, "lineno", None) or 1
        for node in ast.iter_child_nodes(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
    }
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]:
    """Locate `definition` inside the file at `relative_path` (relative to `MODELS_ROOT`).

    Returns:
        A `(path, line)` pair of strings: the joined path, and the line number recorded for
        the definition, or "?" when no line is known for it.
    """
    resolved = MODELS_ROOT / relative_path
    line_number = _load_definition_line_map(relative_path).get(definition)
    location = "?" if line_number is None else str(line_number)
    return str(resolved), location
def _colorize_heading(text: str) -> str:
    """Wrap `text` in the header and bold ANSI codes, followed by a reset."""
    return "{}{}{}{}".format(ANSI_HEADER, ANSI_BOLD, text, ANSI_RESET)
def main():
"""CLI entry point for the modular model detector."""
logging.basicConfig(level=logging.INFO, format="%(message)s")
parser = argparse.ArgumentParser(prog="hf-code-sim")
parser.add_argument("--build", action="store_true")
parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.')
parser.add_argument(
"--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset."
)
parser.add_argument(
"--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index."
)
parser.add_argument("--use_jaccard", type=bool, default=False, help="Whether or not to use jaccard index")
args = parser.parse_args()
analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset)
if args.build:
analyzer.build_index()
if args.push_new_index:
analyzer.push_index_to_hub()
return
if not args.modeling_file:
raise SystemExit("Provide --modeling-file or use --build")
dates = build_date_data()
modeling_file = args.modeling_file
if os.sep not in modeling_file:
modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py")
results = analyzer.analyze_file(
Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard
)
modeling_filename = Path(modeling_file).name
release_key = modeling_filename.split("modeling_")[-1][:-3]
release_date = dates.get(release_key, "unknown release date")
aggregate_scores: dict[str, float] = {}
for data in results.values():
for identifier, score in data.get("embedding", []):
try:
relative_path, _ = identifier.split(":", 1)
except ValueError:
continue
aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score
best_candidate_path: str | None = None
if aggregate_scores:
best_candidate_path = max(aggregate_scores.items(), key=lambda item: item[1])[0]
best_model = Path(best_candidate_path).parts[0] if Path(best_candidate_path).parts else "?"
best_release = dates.get(best_model, "unknown release date")
logging.info(
f"{ANSI_HIGHLIGHT_CANDIDATE}Closest overall candidate: {MODELS_ROOT / best_candidate_path}"
f" (release: {best_release}, total score: {aggregate_scores[best_candidate_path]:.4f}){ANSI_RESET}"
)
grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []}
for query_name, data in results.items():
kind = data.get("kind", "function")
grouped.setdefault(kind, []).append((query_name, data))
section_titles = [("class", "Classes"), ("function", "Functions")]
legend_shown = False
for kind, title in section_titles:
entries = grouped.get(kind, [])
if not entries:
continue
metrics_present: set[str] = set()
for _, data in entries:
if data.get("embedding"):
metrics_present.add("embedding")
if args.use_jaccard:
if data.get("jaccard"):
metrics_present.add("jaccard")
if data.get("intersection"):
metrics_present.add("intersection")
include_metric_column = bool(metrics_present - {"embedding"})
headers = ["Symbol", "Path", "Score", "Release"]
if include_metric_column:
headers = ["Symbol", "Metric", "Path", "Score", "Release"]
table_rows: list[tuple[str, ...] | None] = []
row_styles: list[str] = []
has_metric_rows = False
logging.info(_colorize_heading(title))
for query_name, data in entries:
if table_rows:
table_rows.append(None)
symbol_label = query_name
if release_date:
symbol_label = f"{symbol_label}"
symbol_row = (symbol_label,) + ("",) * (len(headers) - 1)
table_rows.append(symbol_row)
row_styles.append(ANSI_BOLD)
embedding_details: list[tuple[str, str, str, float, str]] = []
embedding_style_indices: list[int] = []
for identifier, score in data.get("embedding", []):
try:
relative_path, match_name = identifier.split(":", 1)
except ValueError:
continue
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
match_release = dates.get(model_id, "unknown release date")
full_path, line = _resolve_definition_location(relative_path, match_name)
display_path = f"{full_path}:{line} ({match_name})"
if include_metric_column:
row = ("", "embedding", display_path, f"{score:.4f}", match_release)
else:
row = ("", display_path, f"{score:.4f}", match_release)
table_rows.append(row)
row_styles.append(ANSI_ROW)
embedding_style_indices.append(len(row_styles) - 1)
embedding_details.append((relative_path, model_id, match_name, score, match_release))
has_metric_rows = True
if embedding_details:
highest_score = None
highest_idx = None
for idx, (_, _, _, score, _) in enumerate(embedding_details):
if highest_score is None or score > highest_score:
highest_score = score
highest_idx = idx
if highest_idx is not None:
row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP
if highest_score is not None:
oldest_idx = None
oldest_date = None
for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details):
if highest_score - score > 0.1:
continue
parsed = _parse_release_date(release_value)
if parsed is None:
continue
if oldest_date is None or parsed < oldest_date:
oldest_date = parsed
oldest_idx = idx
if (
oldest_idx is not None
and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP
):
row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD
if best_candidate_path is not None:
for idx, (relative_path, _, _, _, _) in enumerate(embedding_details):
style_position = embedding_style_indices[idx]
if row_styles[style_position] != ANSI_ROW:
continue
if relative_path == best_candidate_path:
row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE
if args.use_jaccard:
for identifier, score in data.get("jaccard", []):
try:
relative_path, match_name = identifier.split(":", 1)
except ValueError:
continue
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
match_release = dates.get(model_id, "unknown release date")
full_path, line = _resolve_definition_location(relative_path, match_name)
display_path = f"{full_path}:{line} ({match_name})"
if include_metric_column:
row = ("", "jaccard", display_path, f"{score:.4f}", match_release)
else:
row = ("", display_path, f"{score:.4f}", match_release)
table_rows.append(row)
row_styles.append(ANSI_ROW)
has_metric_rows = True
if best_candidate_path == relative_path:
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
for identifier in sorted(data.get("intersection", [])):
try:
relative_path, match_name = identifier.split(":", 1)
except ValueError:
continue
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
match_release = dates.get(model_id, "unknown release date")
full_path, line = _resolve_definition_location(relative_path, match_name)
display_path = f"{full_path}:{line} ({match_name})"
if include_metric_column:
row = ("", "intersection", display_path, "--", match_release)
else:
row = ("", display_path, "--", match_release)
table_rows.append(row)
row_styles.append(ANSI_ROW)
has_metric_rows = True
if best_candidate_path == relative_path:
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
if table_rows:
if not legend_shown and has_metric_rows:
logging.info(
"Legend: "
f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, "
f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, "
f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}"
)
legend_shown = True
logging.info(_format_table(headers, table_rows, row_styles))
logging.info("")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

750
utils/not_doctested.txt Normal file
View File

@@ -0,0 +1,750 @@
docs/source/en/_config.py
docs/source/en/accelerate.md
docs/source/en/add_new_model.md
docs/source/en/add_new_pipeline.md
docs/source/en/community.md
docs/source/en/contributing.md
docs/source/en/custom_models.md
docs/source/en/debugging.md
docs/source/en/fast_tokenizers.md
docs/source/en/glossary.md
docs/source/en/hpo_train.md
docs/source/en/index.md
docs/source/en/installation.md
docs/source/en/internal/audio_utils.md
docs/source/en/internal/file_utils.md
docs/source/en/internal/image_processing_utils.md
docs/source/en/internal/modeling_utils.md
docs/source/en/internal/pipelines_utils.md
docs/source/en/internal/time_series_utils.md
docs/source/en/internal/tokenization_utils.md
docs/source/en/internal/trainer_utils.md
docs/source/en/llm_tutorial.md
docs/source/en/main_classes/callback.md
docs/source/en/main_classes/configuration.md
docs/source/en/main_classes/data_collator.md
docs/source/en/main_classes/deepspeed.md
docs/source/en/main_classes/feature_extractor.md
docs/source/en/main_classes/image_processor.md
docs/source/en/main_classes/logging.md
docs/source/en/main_classes/model.md
docs/source/en/main_classes/optimizer_schedules.md
docs/source/en/main_classes/output.md
docs/source/en/main_classes/pipelines.md
docs/source/en/main_classes/processors.md
docs/source/en/main_classes/quantization.md
docs/source/en/main_classes/tokenizer.md
docs/source/en/main_classes/trainer.md
docs/source/en/model_doc/albert.md
docs/source/en/model_doc/align.md
docs/source/en/model_doc/altclip.md
docs/source/en/model_doc/audio-spectrogram-transformer.md
docs/source/en/model_doc/auto.md
docs/source/en/model_doc/autoformer.md
docs/source/en/model_doc/bark.md
docs/source/en/model_doc/bart.md
docs/source/en/model_doc/barthez.md
docs/source/en/model_doc/bartpho.md
docs/source/en/model_doc/beit.md
docs/source/en/model_doc/bert-generation.md
docs/source/en/model_doc/bert-japanese.md
docs/source/en/model_doc/bert.md
docs/source/en/model_doc/bertweet.md
docs/source/en/model_doc/big_bird.md
docs/source/en/model_doc/bigbird_pegasus.md
docs/source/en/model_doc/biogpt.md
docs/source/en/model_doc/bit.md
docs/source/en/model_doc/blenderbot-small.md
docs/source/en/model_doc/blenderbot.md
docs/source/en/model_doc/blip-2.md
docs/source/en/model_doc/blip.md
docs/source/en/model_doc/bloom.md
docs/source/en/model_doc/bridgetower.md
docs/source/en/model_doc/camembert.md
docs/source/en/model_doc/canine.md
docs/source/en/model_doc/chinese_clip.md
docs/source/en/model_doc/clap.md
docs/source/en/model_doc/clip.md
docs/source/en/model_doc/clipseg.md
docs/source/en/model_doc/codegen.md
docs/source/en/model_doc/conditional_detr.md
docs/source/en/model_doc/convbert.md
docs/source/en/model_doc/convnext.md
docs/source/en/model_doc/convnextv2.md
docs/source/en/model_doc/cpm.md
docs/source/en/model_doc/cpmant.md
docs/source/en/model_doc/ctrl.md
docs/source/en/model_doc/cvt.md
docs/source/en/model_doc/data2vec.md
docs/source/en/model_doc/deberta-v2.md
docs/source/en/model_doc/deberta.md
docs/source/en/model_doc/decision_transformer.md
docs/source/en/model_doc/deformable_detr.md
docs/source/en/model_doc/deit.md
docs/source/en/model_doc/deplot.md
docs/source/en/model_doc/detr.md
docs/source/en/model_doc/dialogpt.md
docs/source/en/model_doc/dinat.md
docs/source/en/model_doc/dinov2.md
docs/source/en/model_doc/distilbert.md
docs/source/en/model_doc/dit.md
docs/source/en/model_doc/dpr.md
docs/source/en/model_doc/efficientnet.md
docs/source/en/model_doc/electra.md
docs/source/en/model_doc/encodec.md
docs/source/en/model_doc/ernie.md
docs/source/en/model_doc/esm.md
docs/source/en/model_doc/flan-t5.md
docs/source/en/model_doc/flan-ul2.md
docs/source/en/model_doc/flaubert.md
docs/source/en/model_doc/flava.md
docs/source/en/model_doc/fnet.md
docs/source/en/model_doc/focalnet.md
docs/source/en/model_doc/fsmt.md
docs/source/en/model_doc/funnel.md
docs/source/en/model_doc/git.md
docs/source/en/model_doc/glpn.md
docs/source/en/model_doc/gpt-sw3.md
docs/source/en/model_doc/gpt2.md
docs/source/en/model_doc/gpt_bigcode.md
docs/source/en/model_doc/gpt_neo.md
docs/source/en/model_doc/gpt_neox.md
docs/source/en/model_doc/gpt_neox_japanese.md
docs/source/en/model_doc/gptj.md
docs/source/en/model_doc/groupvit.md
docs/source/en/model_doc/herbert.md
docs/source/en/model_doc/hubert.md
docs/source/en/model_doc/ibert.md
docs/source/en/model_doc/idefics.md
docs/source/en/model_doc/imagegpt.md
docs/source/en/model_doc/informer.md
docs/source/en/model_doc/instructblip.md
docs/source/en/model_doc/layoutlm.md
docs/source/en/model_doc/layoutlmv2.md
docs/source/en/model_doc/layoutlmv3.md
docs/source/en/model_doc/layoutxlm.md
docs/source/en/model_doc/led.md
docs/source/en/model_doc/levit.md
docs/source/en/model_doc/lilt.md
docs/source/en/model_doc/llama.md
docs/source/en/model_doc/llama2.md
docs/source/en/model_doc/llava.md
docs/source/en/model_doc/llava_next.md
docs/source/en/model_doc/longformer.md
docs/source/en/model_doc/longt5.md
docs/source/en/model_doc/luke.md
docs/source/en/model_doc/lxmert.md
docs/source/en/model_doc/m2m_100.md
docs/source/en/model_doc/madlad-400.md
docs/source/en/model_doc/marian.md
docs/source/en/model_doc/mask2former.md
docs/source/en/model_doc/maskformer.md
docs/source/en/model_doc/matcha.md
docs/source/en/model_doc/mbart.md
docs/source/en/model_doc/megatron-bert.md
docs/source/en/model_doc/megatron_gpt2.md
docs/source/en/model_doc/mgp-str.md
docs/source/en/model_doc/mistral.md
docs/source/en/model_doc/mixtral.md
docs/source/en/model_doc/mluke.md
docs/source/en/model_doc/mms.md
docs/source/en/model_doc/mobilebert.md
docs/source/en/model_doc/mobilenet_v1.md
docs/source/en/model_doc/mobilenet_v2.md
docs/source/en/model_doc/mobilevit.md
docs/source/en/model_doc/mobilevitv2.md
docs/source/en/model_doc/mpnet.md
docs/source/en/model_doc/mpt.md
docs/source/en/model_doc/mra.md
docs/source/en/model_doc/mt5.md
docs/source/en/model_doc/musicgen.md
docs/source/en/model_doc/musicgen_melody.md
docs/source/en/model_doc/mvp.md
docs/source/en/model_doc/nllb-moe.md
docs/source/en/model_doc/nllb.md
docs/source/en/model_doc/nystromformer.md
docs/source/en/model_doc/oneformer.md
docs/source/en/model_doc/openai-gpt.md
docs/source/en/model_doc/opt.md
docs/source/en/model_doc/owlvit.md
docs/source/en/model_doc/pegasus.md
docs/source/en/model_doc/pegasus_x.md
docs/source/en/model_doc/perceiver.md
docs/source/en/model_doc/phobert.md
docs/source/en/model_doc/pix2struct.md
docs/source/en/model_doc/plbart.md
docs/source/en/model_doc/poolformer.md
docs/source/en/model_doc/pop2piano.md
docs/source/en/model_doc/prophetnet.md
docs/source/en/model_doc/pvt.md
docs/source/en/model_doc/qwen2.md
docs/source/en/model_doc/qwen2_moe.md
docs/source/en/model_doc/rag.md
docs/source/en/model_doc/reformer.md
docs/source/en/model_doc/regnet.md
docs/source/en/model_doc/rembert.md
docs/source/en/model_doc/resnet.md
docs/source/en/model_doc/roberta-prelayernorm.md
docs/source/en/model_doc/roberta.md
docs/source/en/model_doc/roc_bert.md
docs/source/en/model_doc/roformer.md
docs/source/en/model_doc/rwkv.md
docs/source/en/model_doc/sam.md
docs/source/en/model_doc/sam_hq.md
docs/source/en/model_doc/segformer.md
docs/source/en/model_doc/sew-d.md
docs/source/en/model_doc/sew.md
docs/source/en/model_doc/speech-encoder-decoder.md
docs/source/en/model_doc/speecht5.md
docs/source/en/model_doc/splinter.md
docs/source/en/model_doc/squeezebert.md
docs/source/en/model_doc/swiftformer.md
docs/source/en/model_doc/swin.md
docs/source/en/model_doc/swin2sr.md
docs/source/en/model_doc/swinv2.md
docs/source/en/model_doc/table-transformer.md
docs/source/en/model_doc/tapas.md
docs/source/en/model_doc/time_series_transformer.md
docs/source/en/model_doc/timesformer.md
docs/source/en/model_doc/trocr.md
docs/source/en/model_doc/ul2.md
docs/source/en/model_doc/umt5.md
docs/source/en/model_doc/unispeech-sat.md
docs/source/en/model_doc/unispeech.md
docs/source/en/model_doc/upernet.md
docs/source/en/model_doc/videomae.md
docs/source/en/model_doc/vilt.md
docs/source/en/model_doc/vipllava.md
docs/source/en/model_doc/vision-encoder-decoder.md
docs/source/en/model_doc/vision-text-dual-encoder.md
docs/source/en/model_doc/visual_bert.md
docs/source/en/model_doc/vit.md
docs/source/en/model_doc/vit_mae.md
docs/source/en/model_doc/vit_msn.md
docs/source/en/model_doc/vivit.md
docs/source/en/model_doc/wav2vec2-conformer.md
docs/source/en/model_doc/wav2vec2.md
docs/source/en/model_doc/wav2vec2_phoneme.md
docs/source/en/model_doc/wavlm.md
docs/source/en/model_doc/whisper.md
docs/source/en/model_doc/xclip.md
docs/source/en/model_doc/xglm.md
docs/source/en/model_doc/xlm-roberta-xl.md
docs/source/en/model_doc/xlm-roberta.md
docs/source/en/model_doc/xlm-v.md
docs/source/en/model_doc/xlm.md
docs/source/en/model_doc/xlnet.md
docs/source/en/model_doc/xls_r.md
docs/source/en/model_doc/xlsr_wav2vec2.md
docs/source/en/model_doc/xmod.md
docs/source/en/model_doc/yolos.md
docs/source/en/model_doc/yoso.md
docs/source/en/model_memory_anatomy.md
docs/source/en/model_sharing.md
docs/source/en/notebooks.md
docs/source/en/peft.md
docs/source/en/perf_hardware.md
docs/source/en/perf_torch_compile.md
docs/source/en/perf_train_cpu.md
docs/source/en/perf_train_gpu_many.md
docs/source/en/perf_train_special.md
docs/source/en/perplexity.md
docs/source/en/philosophy.md
docs/source/en/pipeline_webserver.md
docs/source/en/pr_checks.md
docs/source/en/run_scripts.md
docs/source/en/serialization.md
docs/source/en/tasks/asr.md
docs/source/en/tasks/audio_classification.md
docs/source/en/tasks/document_question_answering.md
docs/source/en/tasks/idefics.md
docs/source/en/tasks/image_captioning.md
docs/source/en/tasks/image_classification.md
docs/source/en/tasks/language_modeling.md
docs/source/en/tasks/masked_language_modeling.md
docs/source/en/tasks/monocular_depth_estimation.md
docs/source/en/tasks/multiple_choice.md
docs/source/en/tasks/object_detection.md
docs/source/en/tasks/question_answering.md
docs/source/en/tasks/semantic_segmentation.md
docs/source/en/tasks/sequence_classification.md
docs/source/en/tasks/summarization.md
docs/source/en/tasks/text-to-speech.md
docs/source/en/tasks/token_classification.md
docs/source/en/tasks/translation.md
docs/source/en/tasks/video_classification.md
docs/source/en/tasks/visual_question_answering.md
docs/source/en/tasks/zero_shot_image_classification.md
docs/source/en/tasks/zero_shot_object_detection.md
docs/source/en/tokenizer_summary.md
docs/source/en/training.md
docs/source/en/troubleshooting.md
src/transformers/activations.py
src/transformers/audio_utils.py
src/transformers/cli/add_new_model_like.py
src/transformers/cli/chat.py
src/transformers/cli/download.py
src/transformers/cli/serve.py
src/transformers/cli/system.py
src/transformers/cli/transformers.py
src/transformers/configuration_utils.py
src/transformers/convert_slow_tokenizer.py
src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py
src/transformers/data/data_collator.py
src/transformers/data/datasets/glue.py
src/transformers/data/datasets/squad.py
src/transformers/data/metrics/squad_metrics.py
src/transformers/data/processors/glue.py
src/transformers/data/processors/squad.py
src/transformers/data/processors/utils.py
src/transformers/data/processors/xnli.py
src/transformers/debug_utils.py
src/transformers/dependency_versions_check.py
src/transformers/dependency_versions_table.py
src/transformers/dynamic_module_utils.py
src/transformers/feature_extraction_sequence_utils.py
src/transformers/feature_extraction_utils.py
src/transformers/file_utils.py
src/transformers/hf_argparser.py
src/transformers/hyperparameter_search.py
src/transformers/image_processing_utils.py
src/transformers/image_transforms.py
src/transformers/image_utils.py
src/transformers/integrations/bitsandbytes.py
src/transformers/integrations/deepspeed.py
src/transformers/integrations/integration_utils.py
src/transformers/integrations/peft.py
src/transformers/modelcard.py
src/transformers/modeling_outputs.py
src/transformers/modeling_utils.py
src/transformers/models/align/configuration_align.py
src/transformers/models/align/modeling_align.py
src/transformers/models/altclip/configuration_altclip.py
src/transformers/models/altclip/modeling_altclip.py
src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
src/transformers/models/audio_spectrogram_transformer/convert_audio_spectrogram_transformer_original_to_pytorch.py
src/transformers/models/auto/auto_factory.py
src/transformers/models/auto/configuration_auto.py
src/transformers/models/auto/modeling_auto.py
src/transformers/models/autoformer/configuration_autoformer.py
src/transformers/models/autoformer/modeling_autoformer.py
src/transformers/models/bark/convert_suno_to_hf.py
src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
src/transformers/models/bert_generation/modeling_bert_generation.py
src/transformers/models/biogpt/configuration_biogpt.py
src/transformers/models/biogpt/convert_biogpt_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/biogpt/modeling_biogpt.py
src/transformers/models/bit/configuration_bit.py
src/transformers/models/bit/convert_bit_to_pytorch.py
src/transformers/models/bit/modeling_bit.py
src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/blip/configuration_blip.py
src/transformers/models/blip/convert_blip_original_pytorch_to_hf.py
src/transformers/models/blip/modeling_blip_text.py
src/transformers/models/blip_2/configuration_blip_2.py
src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py
src/transformers/models/blip_2/modeling_blip_2.py
src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py
src/transformers/models/bloom/modeling_bloom.py
src/transformers/models/bridgetower/configuration_bridgetower.py
src/transformers/models/bridgetower/modeling_bridgetower.py
src/transformers/models/bros/convert_bros_to_pytorch.py
src/transformers/models/camembert/modeling_camembert.py
src/transformers/models/chinese_clip/configuration_chinese_clip.py
src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py
src/transformers/models/chinese_clip/modeling_chinese_clip.py
src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py
src/transformers/models/clip/modeling_clip.py
src/transformers/models/clipseg/configuration_clipseg.py
src/transformers/models/clipseg/convert_clipseg_original_pytorch_to_hf.py
src/transformers/models/codegen/modeling_codegen.py
src/transformers/models/conditional_detr/convert_conditional_detr_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/convbert/modeling_convbert.py
src/transformers/models/convnext/convert_convnext_to_pytorch.py
src/transformers/models/convnextv2/configuration_convnextv2.py
src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py
src/transformers/models/convnextv2/modeling_convnextv2.py
src/transformers/models/cpmant/configuration_cpmant.py
src/transformers/models/cpmant/modeling_cpmant.py
src/transformers/models/cpmant/tokenization_cpmant.py
src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/data2vec/modeling_data2vec_text.py
src/transformers/models/decision_transformer/modeling_decision_transformer.py
src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py
src/transformers/models/deit/convert_deit_timm_to_pytorch.py
src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/detr/convert_detr_to_pytorch.py
src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/dinov2/configuration_dinov2.py
src/transformers/models/dinov2/convert_dinov2_to_hf.py
src/transformers/models/dinov2/modeling_dinov2.py
src/transformers/models/distilbert/modeling_distilbert.py
src/transformers/models/dit/convert_dit_unilm_to_pytorch.py
src/transformers/models/donut/configuration_donut_swin.py
src/transformers/models/donut/convert_donut_to_pytorch.py
src/transformers/models/donut/modeling_donut_swin.py
src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
src/transformers/models/dpr/modeling_dpr.py
src/transformers/models/dpt/configuration_dpt.py
src/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py
src/transformers/models/dpt/convert_dpt_to_pytorch.py
src/transformers/models/efficientnet/configuration_efficientnet.py
src/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py
src/transformers/models/efficientnet/modeling_efficientnet.py
src/transformers/models/encodec/configuration_encodec.py
src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
src/transformers/models/ernie/modeling_ernie.py
src/transformers/models/esm/configuration_esm.py
src/transformers/models/esm/convert_esm.py
src/transformers/models/esm/modeling_esm.py
src/transformers/models/esm/modeling_esmfold.py
src/transformers/models/esm/openfold_utils/chunk_utils.py
src/transformers/models/esm/openfold_utils/data_transforms.py
src/transformers/models/esm/openfold_utils/feats.py
src/transformers/models/esm/openfold_utils/loss.py
src/transformers/models/esm/openfold_utils/protein.py
src/transformers/models/esm/openfold_utils/residue_constants.py
src/transformers/models/esm/openfold_utils/rigid_utils.py
src/transformers/models/esm/openfold_utils/tensor_utils.py
src/transformers/models/falcon/configuration_falcon.py
src/transformers/models/falcon/modeling_falcon.py
src/transformers/models/flaubert/configuration_flaubert.py
src/transformers/models/flaubert/modeling_flaubert.py
src/transformers/models/flava/convert_dalle_to_flava_codebook.py
src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
src/transformers/models/flava/modeling_flava.py
src/transformers/models/fnet/modeling_fnet.py
src/transformers/models/focalnet/configuration_focalnet.py
src/transformers/models/focalnet/convert_focalnet_to_hf_format.py
src/transformers/models/focalnet/modeling_focalnet.py
src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/fsmt/modeling_fsmt.py
src/transformers/models/funnel/configuration_funnel.py
src/transformers/models/funnel/modeling_funnel.py
src/transformers/models/fuyu/convert_fuyu_model_weights_to_hf.py
src/transformers/models/gemma/configuration_gemma.py
src/transformers/models/gemma/convert_gemma_weights_to_hf.py
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/git/configuration_git.py
src/transformers/models/git/convert_git_to_pytorch.py
src/transformers/models/glpn/configuration_glpn.py
src/transformers/models/glpn/convert_glpn_to_pytorch.py
src/transformers/models/gpt2/CONVERSION.md
src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
src/transformers/models/gpt_neo/modeling_gpt_neo.py
src/transformers/models/gpt_neox/modeling_gpt_neox.py
src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py
src/transformers/models/gptj/configuration_gptj.py
src/transformers/models/groupvit/configuration_groupvit.py
src/transformers/models/groupvit/convert_groupvit_nvlab_to_hf.py
src/transformers/models/hubert/configuration_hubert.py
src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/ibert/configuration_ibert.py
src/transformers/models/ibert/modeling_ibert.py
src/transformers/models/ibert/quant_modules.py
src/transformers/models/idefics/configuration_idefics.py
src/transformers/models/idefics/image_processing_idefics.py
src/transformers/models/idefics/modeling_idefics.py
src/transformers/models/idefics/perceiver.py
src/transformers/models/idefics/processing_idefics.py
src/transformers/models/idefics/vision.py
src/transformers/models/informer/configuration_informer.py
src/transformers/models/informer/modeling_informer.py
src/transformers/models/instructblip/configuration_instructblip.py
src/transformers/models/instructblip/convert_instructblip_original_to_pytorch.py
src/transformers/models/instructblip/modeling_instructblip.py
src/transformers/models/instructblip/processing_instructblip.py
src/transformers/models/jamba/configuration_jamba.py
src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/kosmos2/convert_kosmos2_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/led/configuration_led.py
src/transformers/models/led/modeling_led.py
src/transformers/models/levit/convert_levit_timm_to_pytorch.py
src/transformers/models/levit/modeling_levit.py
src/transformers/models/lilt/configuration_lilt.py
src/transformers/models/llama/configuration_llama.py
src/transformers/models/llama/convert_llama_weights_to_hf.py
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llava/configuration_llava.py
src/transformers/models/llava/modeling_llava.py
src/transformers/models/llava_next/configuration_llava_next.py
src/transformers/models/llava_next/modeling_llava_next.py
src/transformers/models/longformer/configuration_longformer.py
src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py
src/transformers/models/longt5/configuration_longt5.py
src/transformers/models/luke/configuration_luke.py
src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/luke/modeling_luke.py
src/transformers/models/lxmert/configuration_lxmert.py
src/transformers/models/lxmert/modeling_lxmert.py
src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py
src/transformers/models/m2m_100/modeling_m2m_100.py
src/transformers/models/marian/configuration_marian.py
src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py
src/transformers/models/marian/convert_marian_to_pytorch.py
src/transformers/models/markuplm/configuration_markuplm.py
src/transformers/models/markuplm/feature_extraction_markuplm.py
src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/maskformer/configuration_maskformer_swin.py
src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/maskformer/convert_maskformer_resnet_to_pytorch.py
src/transformers/models/maskformer/convert_maskformer_swin_to_pytorch.py
src/transformers/models/maskformer/modeling_maskformer_swin.py
src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py
src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
src/transformers/models/megatron_bert/modeling_megatron_bert.py
src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py
src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
src/transformers/models/mgp_str/configuration_mgp_str.py
src/transformers/models/mgp_str/modeling_mgp_str.py
src/transformers/models/mistral/configuration_mistral.py
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mixtral/configuration_mixtral.py
src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py
src/transformers/models/mobilenet_v2/configuration_mobilenet_v2.py
src/transformers/models/mobilevit/configuration_mobilevit.py
src/transformers/models/mobilevit/convert_mlcvnets_to_pytorch.py
src/transformers/models/mobilevitv2/convert_mlcvnets_to_pytorch.py
src/transformers/models/mpnet/configuration_mpnet.py
src/transformers/models/mpnet/modeling_mpnet.py
src/transformers/models/mpt/configuration_mpt.py
src/transformers/models/mpt/modeling_mpt.py
src/transformers/models/mra/configuration_mra.py
src/transformers/models/mra/convert_mra_pytorch_to_pytorch.py
src/transformers/models/mra/modeling_mra.py
src/transformers/models/mt5/configuration_mt5.py
src/transformers/models/mt5/modeling_mt5.py
src/transformers/models/musicgen/convert_musicgen_transformers.py
src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py
src/transformers/models/mvp/modeling_mvp.py
src/transformers/models/nllb_moe/configuration_nllb_moe.py
src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
src/transformers/models/nllb_moe/modeling_nllb_moe.py
src/transformers/models/nougat/convert_nougat_to_hf.py
src/transformers/models/nystromformer/configuration_nystromformer.py
src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/nystromformer/modeling_nystromformer.py
src/transformers/models/oneformer/convert_to_hf_oneformer.py
src/transformers/models/openai/modeling_openai.py
src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/owlvit/configuration_owlvit.py
src/transformers/models/pegasus_x/modeling_pegasus_x.py
src/transformers/models/perceiver/configuration_perceiver.py
src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py
src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py
src/transformers/models/persimmon/modeling_persimmon.py
src/transformers/models/pix2struct/configuration_pix2struct.py
src/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py
src/transformers/models/pix2struct/image_processing_pix2struct.py
src/transformers/models/pix2struct/processing_pix2struct.py
src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py
src/transformers/models/poolformer/convert_poolformer_original_to_pytorch.py
src/transformers/models/pop2piano/convert_pop2piano_weights_to_hf.py
src/transformers/models/pop2piano/feature_extraction_pop2piano.py
src/transformers/models/pop2piano/processing_pop2piano.py
src/transformers/models/pop2piano/tokenization_pop2piano.py
src/transformers/models/prophetnet/configuration_prophetnet.py
src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/prophetnet/modeling_prophetnet.py
src/transformers/models/pvt/configuration_pvt.py
src/transformers/models/pvt/convert_pvt_to_pytorch.py
src/transformers/models/pvt/image_processing_pvt.py
src/transformers/models/pvt/modeling_pvt.py
src/transformers/models/qwen2/configuration_qwen2.py
src/transformers/models/qwen2/modeling_qwen2.py
src/transformers/models/qwen2/tokenization_qwen2.py
src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
src/transformers/models/rag/configuration_rag.py
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/retrieval_rag.py
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
src/transformers/models/regnet/configuration_regnet.py
src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
src/transformers/models/regnet/convert_regnet_to_pytorch.py
src/transformers/models/rembert/configuration_rembert.py
src/transformers/models/rembert/modeling_rembert.py
src/transformers/models/resnet/convert_resnet_to_pytorch.py
src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/roc_bert/configuration_roc_bert.py
src/transformers/models/roformer/modeling_roformer.py
src/transformers/models/rwkv/configuration_rwkv.py
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/sam/configuration_sam.py
src/transformers/models/sam/convert_sam_to_hf.py
src/transformers/models/sam/image_processing_sam.py
src/transformers/models/sam/modeling_sam.py
src/transformers/models/sam/processing_sam.py
src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py
src/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py
src/transformers/models/segformer/configuration_segformer.py
src/transformers/models/segformer/convert_segformer_original_to_pytorch.py
src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
src/transformers/models/speecht5/configuration_speecht5.py
src/transformers/models/speecht5/convert_hifigan.py
src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/speecht5/number_normalizer.py
src/transformers/models/splinter/configuration_splinter.py
src/transformers/models/splinter/modeling_splinter.py
src/transformers/models/squeezebert/modeling_squeezebert.py
src/transformers/models/stablelm/modeling_stablelm.py
src/transformers/models/starcoder2/modeling_starcoder2.py
src/transformers/models/swiftformer/configuration_swiftformer.py
src/transformers/models/swiftformer/convert_swiftformer_original_to_hf.py
src/transformers/models/swiftformer/modeling_swiftformer.py
src/transformers/models/swin/convert_swin_simmim_to_pytorch.py
src/transformers/models/swin/convert_swin_timm_to_pytorch.py
src/transformers/models/swin2sr/configuration_swin2sr.py
src/transformers/models/swin2sr/convert_swin2sr_original_to_pytorch.py
src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py
src/transformers/models/swinv2/modeling_swinv2.py
src/transformers/models/switch_transformers/configuration_switch_transformers.py
src/transformers/models/switch_transformers/convert_big_switch.py
src/transformers/models/switch_transformers/modeling_switch_transformers.py
src/transformers/models/t5/configuration_t5.py
src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
src/transformers/models/t5/modeling_t5.py
src/transformers/models/table_transformer/configuration_table_transformer.py
src/transformers/models/table_transformer/convert_table_transformer_to_hf.py
src/transformers/models/table_transformer/convert_table_transformer_to_hf_no_timm.py
src/transformers/models/tapas/configuration_tapas.py
src/transformers/models/tapas/modeling_tapas.py
src/transformers/models/timesformer/convert_timesformer_to_pytorch.py
src/transformers/models/timm_backbone/configuration_timm_backbone.py
src/transformers/models/timm_backbone/modeling_timm_backbone.py
src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py
src/transformers/models/umt5/configuration_umt5.py
src/transformers/models/umt5/convert_umt5_checkpoint_to_pytorch.py
src/transformers/models/umt5/modeling_umt5.py
src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/upernet/configuration_upernet.py
src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py
src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py
src/transformers/models/videomae/configuration_videomae.py
src/transformers/models/videomae/convert_videomae_to_pytorch.py
src/transformers/models/vilt/configuration_vilt.py
src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
src/transformers/models/vipllava/configuration_vipllava.py
src/transformers/models/vipllava/modeling_vipllava.py
src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
src/transformers/models/visual_bert/convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/visual_bert/modeling_visual_bert.py
src/transformers/models/vit/convert_dino_to_pytorch.py
src/transformers/models/vit/convert_vit_timm_to_pytorch.py
src/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py
src/transformers/models/vit_msn/configuration_vit_msn.py
src/transformers/models/vit_msn/convert_msn_to_pytorch.py
src/transformers/models/vivit/configuration_vivit.py
src/transformers/models/vivit/image_processing_vivit.py
src/transformers/models/vivit/modeling_vivit.py
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py
src/transformers/models/whisper/convert_openai_to_hf.py
src/transformers/models/whisper/english_normalizer.py
src/transformers/models/x_clip/configuration_x_clip.py
src/transformers/models/x_clip/convert_x_clip_original_pytorch_to_hf.py
src/transformers/models/xglm/configuration_xglm.py
src/transformers/models/xglm/convert_xglm_original_ckpt_to_trfms.py
src/transformers/models/xglm/modeling_xglm.py
src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/xlm/modeling_xlm.py
src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
src/transformers/models/xlm_roberta_xl/convert_xlm_roberta_xl_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
src/transformers/models/xlnet/modeling_xlnet.py
src/transformers/models/xmod/convert_xmod_original_pytorch_checkpoint_to_pytorch.py
src/transformers/models/yolos/convert_yolos_to_pytorch.py
src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py
src/transformers/models/yoso/modeling_yoso.py
src/transformers/models/zamba/configuration_zamba.py
src/transformers/models/zamba/modeling_zamba.py
src/transformers/optimization.py
src/transformers/pipelines/audio_classification.py
src/transformers/pipelines/audio_utils.py
src/transformers/pipelines/automatic_speech_recognition.py
src/transformers/pipelines/base.py
src/transformers/pipelines/depth_estimation.py
src/transformers/pipelines/document_question_answering.py
src/transformers/pipelines/feature_extraction.py
src/transformers/pipelines/fill_mask.py
src/transformers/pipelines/image_classification.py
src/transformers/pipelines/image_segmentation.py
src/transformers/pipelines/mask_generation.py
src/transformers/pipelines/object_detection.py
src/transformers/pipelines/pt_utils.py
src/transformers/pipelines/table_question_answering.py
src/transformers/pipelines/text_classification.py
src/transformers/pipelines/token_classification.py
src/transformers/pipelines/video_classification.py
src/transformers/pipelines/zero_shot_audio_classification.py
src/transformers/pipelines/zero_shot_classification.py
src/transformers/pipelines/zero_shot_image_classification.py
src/transformers/pipelines/zero_shot_object_detection.py
src/transformers/processing_utils.py
src/transformers/pytorch_utils.py
src/transformers/quantizers/auto.py
src/transformers/quantizers/base.py
src/transformers/quantizers/quantizer_awq.py
src/transformers/quantizers/quantizer_bnb_4bit.py
src/transformers/quantizers/quantizer_bnb_8bit.py
src/transformers/quantizers/quantizer_gptq.py
src/transformers/quantizers/quantizers_utils.py
src/transformers/testing_utils.py
src/transformers/time_series_utils.py
src/transformers/tokenization_python.py
src/transformers/tokenization_utils_base.py
src/transformers/trainer.py
src/transformers/trainer_callback.py
src/transformers/trainer_pt_utils.py
src/transformers/trainer_seq2seq.py
src/transformers/trainer_utils.py
src/transformers/training_args.py
src/transformers/training_args_seq2seq.py
src/transformers/utils/backbone_utils.py
src/transformers/utils/constants.py
src/transformers/utils/doc.py
src/transformers/utils/dummy_detectron2_objects.py
src/transformers/utils/dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects.py
src/transformers/utils/dummy_music_objects.py
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_sentencepiece_and_tokenizers_objects.py
src/transformers/utils/dummy_speech_objects.py
src/transformers/utils/dummy_tokenizers_objects.py
src/transformers/utils/dummy_vision_objects.py
src/transformers/utils/generic.py
src/transformers/utils/hp_naming.py
src/transformers/utils/hub.py
src/transformers/utils/import_utils.py
src/transformers/utils/logging.py
src/transformers/utils/notebook.py
src/transformers/utils/peft_utils.py
src/transformers/utils/quantization_config.py
src/transformers/utils/sentencepiece_model_pb2.py
src/transformers/utils/sentencepiece_model_pb2_new.py
src/transformers/utils/versions.py

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,384 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
import time
from get_ci_error_statistics import get_jobs
from slack_sdk import WebClient
# Module-level Slack client used by the `Message` posting methods below.
# Importing this module raises KeyError when CI_SLACK_BOT_TOKEN is not set in the environment.
client = WebClient(token=os.environ["CI_SLACK_BOT_TOKEN"])
def handle_test_results(test_results):
    """Extract the failure count, success count and duration token from a test summary string.

    The string is split on single spaces. Every token containing "failed" (resp. "passed") adds the
    integer value of the token right before it to the failure (resp. success) count.

    Args:
        test_results (`str`): The summary text to parse.

    Returns:
        `tuple`: `(failed, success, time_spent)` where `time_spent` is the raw duration token.
    """
    tokens = test_results.split(" ")

    # When the output is short enough, it is surrounded by = signs ("== OUTPUT =="), so the duration is
    # the second-to-last token. When it is too long, those signs are absent and the duration comes last.
    final_token = tokens[-1]
    if "=" in final_token:
        time_spent = tokens[-2]
    else:
        time_spent = final_token

    n_failed = 0
    n_passed = 0
    for position, token in enumerate(tokens):
        # The two checks are independent on purpose: a token matching both words counts for both.
        if "failed" in token:
            n_failed = n_failed + int(tokens[position - 1])
        if "passed" in token:
            n_passed = n_passed + int(tokens[position - 1])

    return n_failed, n_passed, time_spent
def extract_first_line_failure(failures_short_lines):
    """Map each failing doctest file to the first non-numbered line that follows its header.

    A header is any line matching `_ [doctest]`; the file name is the third space-separated token of
    that line. After a header, lines whose first space-separated token is all digits are skipped, and
    the first other line is stored as the failure text for that file.

    Args:
        failures_short_lines (`str`): The multi-line short failure report.

    Returns:
        `dict`: `{file: first failure line}`.
    """
    first_lines = {}
    current_file = None
    awaiting_message = False

    for raw_line in failures_short_lines.split("\n"):
        if re.search(r"_ \[doctest\]", raw_line):
            # New failure section: remember which file it belongs to.
            awaiting_message = True
            current_file = raw_line.split(" ")[2]
            continue

        if not awaiting_message:
            continue

        # Lines starting with a number are skipped until the first line that does not.
        if raw_line.split(" ")[0].isdigit():
            continue

        first_lines[current_file] = raw_line
        awaiting_message = False

    return first_lines
class Message:
    """Slack report for the doc tests: one summary post plus one threaded reply per failing job.

    `doc_test_results` maps a job name to a dict. `__init__` reads `n_success` and `n_failures` from
    every entry, `time` reads `time_spent`, `category_failures` reads `failed`, and `post_reply`
    reads `failures`, `failed` and `job_link`.

    The posting methods rely on the module-level `client` and `SLACK_REPORT_CHANNEL_ID`.
    """

    def __init__(self, title: str, doc_test_results: dict):
        self.title = title

        self.n_success = sum(job_result["n_success"] for job_result in doc_test_results.values())
        self.n_failures = sum(job_result["n_failures"] for job_result in doc_test_results.values())
        self.n_tests = self.n_success + self.n_failures

        # Failures and success of the modeling tests
        self.doc_test_results = doc_test_results

        # Response of the summary post. It stays `None` until `post` has run, which is what
        # `post_reply` checks before replying in the thread.
        self.thread_ts = None

    @staticmethod
    def _actions_button() -> dict:
        """Return the "Check Action results" button linking to the current GitHub Actions run."""
        return {
            "type": "button",
            "text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
            "url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
        }

    @property
    def time(self) -> str:
        """Total time spent over all jobs, formatted as `{h}h{m}m{s}s`."""
        all_results = [*self.doc_test_results.values()]
        time_spent = [r["time_spent"].split(", ")[0] for r in all_results if len(r["time_spent"])]
        total_secs = 0

        for timings in time_spent:
            time_parts = timings.split(":")

            # Time can be formatted as xx:xx:xx, as .xx, or as x.xx if the time spent was less than a minute.
            if len(time_parts) == 1:
                time_parts = [0, 0, time_parts[0]]

            hours, minutes, seconds = int(time_parts[0]), int(time_parts[1]), float(time_parts[2])
            total_secs += hours * 3600 + minutes * 60 + seconds

        hours, minutes, seconds = total_secs // 3600, (total_secs % 3600) // 60, total_secs % 60
        return f"{int(hours)}h{int(minutes)}m{int(seconds)}s"

    @property
    def header(self) -> dict:
        """Header block carrying the report title."""
        return {"type": "header", "text": {"type": "plain_text", "text": self.title}}

    @property
    def no_failures(self) -> dict:
        """Section block used when every test passed."""
        return {
            "type": "section",
            "text": {
                "type": "plain_text",
                "text": f"🌞 There were no failures: all {self.n_tests} tests passed. The suite ran in {self.time}.",
                "emoji": True,
            },
            "accessory": self._actions_button(),
        }

    @property
    def failures(self) -> dict:
        """Section block summarising the number of failures."""
        return {
            "type": "section",
            "text": {
                "type": "plain_text",
                "text": (
                    f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
                    f" {self.time}."
                ),
                "emoji": True,
            },
            "accessory": self._actions_button(),
        }

    @property
    def category_failures(self) -> list[dict]:
        """One section block per job that has failed tests, listing those tests (truncated if too long)."""
        failure_blocks = []

        # Keep room for the fixed text wrapped around each report and for the truncation marker.
        MAX_ERROR_TEXT = 3000 - len("The following examples had failures:\n\n\n\n") - len("[Truncated]\n")
        line_length = 40

        # Read the results stored on the instance rather than a module-level global, so the property
        # also works when the class is used outside of the script entry point.
        category_failures = {k: v["failed"] for k, v in self.doc_test_results.items() if isinstance(v, dict)}

        def single_category_failures(category, failures):
            if len(failures) == 0:
                return ""

            # Pad the heading to half the line length.
            text = f"*{category} failures*:".ljust(line_length // 2) + "\n"

            for failure in failures:
                new_text = text + f"`{failure}`\n"
                if len(new_text) > MAX_ERROR_TEXT:
                    text = text + "[Truncated]\n"
                    break
                text = new_text

            return text

        for category, failures in category_failures.items():
            report = single_category_failures(category, failures)
            if len(report) == 0:
                continue
            block = {
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": f"The following examples had failures:\n\n\n{report}\n",
                },
            }
            failure_blocks.append(block)

        return failure_blocks

    @property
    def payload(self) -> str:
        """JSON-encoded list of blocks for the summary post."""
        blocks = [self.header]

        if self.n_failures > 0:
            blocks.append(self.failures)
            blocks.extend(self.category_failures)
        else:
            # `n_failures` is a sum of counts, so "not > 0" means there were no failures at all.
            blocks.append(self.no_failures)

        return json.dumps(blocks)

    @staticmethod
    def error_out():
        """Post a generic "there was an issue running the tests" message to the report channel."""
        payload = [
            {
                "type": "section",
                "text": {
                    "type": "plain_text",
                    "text": "There was an issue running the tests.",
                },
                "accessory": Message._actions_button(),
            }
        ]

        print("Sending the following payload")
        # `payload` is already a list of blocks (not a JSON string), so it is dumped directly.
        print(json.dumps({"blocks": payload}))

        client.chat_postMessage(
            channel=SLACK_REPORT_CHANNEL_ID,
            text="There was an issue running the tests.",
            blocks=payload,
        )

    def post(self):
        """Post the summary message and keep the response so replies can be threaded under it."""
        payload = self.payload

        print("Sending the following payload")
        print(json.dumps({"blocks": json.loads(payload)}))

        text = f"{self.n_failures} failures out of {self.n_tests} tests," if self.n_failures else "All tests passed."

        self.thread_ts = client.chat_postMessage(
            channel=SLACK_REPORT_CHANNEL_ID,
            blocks=payload,
            text=text,
        )

    def get_reply_blocks(self, job_name, job_link, failures, text):
        """Build the blocks of a threaded reply for one job.

        Args:
            job_name: Used as the header text.
            job_link: URL for the "GitHub Action job" button; no button is added when it is `None`.
            failures (`dict`): `{test: error line}` entries rendered one after another.
            text (`str`): Markdown text of the section placed under the header.

        Returns:
            `list[dict]`: header block, content block, and the block holding the failure text.
        """
        # `text` must be less than 3001 characters in Slack SDK
        # keep some room for adding "[Truncated]" when necessary
        MAX_ERROR_TEXT = 3000 - len("[Truncated]")

        failure_text = ""
        for key, value in failures.items():
            new_text = failure_text + f"*{key}*\n_{value}_\n\n"
            if len(new_text) > MAX_ERROR_TEXT:
                # `failure_text` here has length <= 3000
                failure_text = failure_text + "[Truncated]"
                break
            # `failure_text` here has length <= MAX_ERROR_TEXT
            failure_text = new_text

        title = job_name

        content = {"type": "section", "text": {"type": "mrkdwn", "text": text}}

        if job_link is not None:
            content["accessory"] = {
                "type": "button",
                "text": {"type": "plain_text", "text": "GitHub Action job", "emoji": True},
                "url": job_link,
            }

        return [
            {"type": "header", "text": {"type": "plain_text", "text": title, "emoji": True}},
            content,
            {"type": "section", "text": {"type": "mrkdwn", "text": failure_text}},
        ]

    def post_reply(self):
        """Post one threaded reply per job that has failures, in job-name order.

        Raises:
            ValueError: If `post` has not been called first.
        """
        if self.thread_ts is None:
            raise ValueError("Can only post reply if a post has been made.")

        sorted_dict = sorted(self.doc_test_results.items(), key=lambda t: t[0])
        for job_name, job_result in sorted_dict:
            if len(job_result["failures"]) > 0:
                text = f"*Num failures* :{len(job_result['failed'])} \n"
                failures = job_result["failures"]
                blocks = self.get_reply_blocks(job_name, job_result["job_link"], failures, text=text)

                print("Sending the following reply")
                print(json.dumps({"blocks": blocks}))

                client.chat_postMessage(
                    channel=SLACK_REPORT_CHANNEL_ID,
                    text=f"Results for {job_name}",
                    blocks=blocks,
                    thread_ts=self.thread_ts["ts"],
                )

                # Pause between consecutive replies.
                time.sleep(1)
def retrieve_artifact(name: str):
    """Read every file of the directory `name` into a dict.

    Args:
        name (`str`): Path of the artifact directory.

    Returns:
        `dict`: `{file name up to its first ".": file content}`; empty when `name` does not exist.

    Raises:
        ValueError: If a file cannot be decoded as UTF-8.
    """
    contents = {}
    if not os.path.exists(name):
        return contents

    for entry in os.listdir(name):
        try:
            with open(os.path.join(name, entry), encoding="utf-8") as handle:
                key = entry.split(".")[0]
                contents[key] = handle.read()
        except UnicodeDecodeError as e:
            raise ValueError(f"Could not open {os.path.join(name, entry)}.") from e

    return contents
def retrieve_available_artifacts():
    """List the directories of the current working directory as artifacts.

    Returns:
        `dict`: `{directory name: Artifact}` where each `Artifact` has `name` and a `paths` list of
        `{"name": ..., "path": ...}` entries.
    """

    class Artifact:
        """A named artifact together with the paths registered for it."""

        def __init__(self, name: str):
            self.name = name
            self.paths = []

        def __str__(self):
            return self.name

        def add_path(self, path: str):
            self.paths.append({"name": self.name, "path": path})

    found: dict[str, Artifact] = {}

    for entry in os.listdir():
        # Only directories are considered artifacts.
        if not os.path.isdir(entry):
            continue

        artifact = found.get(entry)
        if artifact is None:
            artifact = Artifact(entry)
            found[entry] = artifact
        artifact.add_path(entry)

    return found
# Script entry point: collect the doc-test artifacts found in the current directory, build a
# per-job result dict, dump it to `doc_test_results/doc_test_results.json` and report it on Slack.
if __name__ == "__main__":
# Module-level name read by `Message.post`, `Message.post_reply` and `Message.error_out`.
SLACK_REPORT_CHANNEL_ID = os.environ["SLACK_REPORT_CHANNEL"]
github_actions_jobs = get_jobs(
workflow_run_id=os.environ["GITHUB_RUN_ID"], token=os.environ["ACCESS_REPO_INFO_TOKEN"]
)
# Map each artifact name, taken from the first step named "Test suite reports artifacts: <name>",
# to the job that contains that step.
artifact_name_to_job_map = {}
for job in github_actions_jobs:
for step in job["steps"]:
if step["name"].startswith("Test suite reports artifacts: "):
artifact_name = step["name"][len("Test suite reports artifacts: ") :]
artifact_name_to_job_map[artifact_name] = job
break
available_artifacts = retrieve_available_artifacts()
# `doc_test_results` is also read as a module-level global by `Message.category_failures`.
doc_test_results = {}
# Each artifact object stores its directory under `paths[0]["path"]`; only doc-test report
# directories are processed.
for artifact_obj in available_artifacts.values():
artifact_path = artifact_obj.paths[0]
if not artifact_path["path"].startswith("doc_tests_gpu_test_reports_"):
continue
# change "_" back to "/" (to show the job name as path)
job_name = artifact_path["path"].replace("doc_tests_gpu_test_reports_", "").replace("_", "/")
# This dict (for each job) will contain all the information relative to each doc test job, in particular:
# - failed: list of failed tests
# - failures: dict in the format 'test': 'error_message'
job_result = {}
doc_test_results[job_name] = job_result
# Raises KeyError when no job step advertised an artifact with this name.
job = artifact_name_to_job_map[artifact_path["path"]]
job_result["job_link"] = job["html_url"]
job_result["category"] = "Python Examples" if job_name.startswith("src/") else "MD Examples"
artifact = retrieve_artifact(artifact_path["path"])
# NOTE(review): `n_failures` / `n_success` are only set when a "stats" file exists, yet
# `Message.__init__` reads them for every job — confirm every artifact ships a "stats" file.
if "stats" in artifact:
failed, success, time_spent = handle_test_results(artifact["stats"])
job_result["n_failures"] = failed
job_result["n_success"] = success
# Drop the first and last character of the duration token (presumably surrounding
# parentheses — TODO confirm); the ", " suffix is split off again in `Message.time`.
job_result["time_spent"] = time_spent[1:-1] + ", "
job_result["failed"] = []
job_result["failures"] = {}
all_failures = extract_first_line_failure(artifact["failures_short"])
for line in artifact["summary_short"].split("\n"):
if re.search("FAILED", line):
line = line.replace("FAILED ", "")
line = line.split()[0].replace("\n", "")
# A "file::test" entry is split into its two parts; otherwise the whole entry is used for both.
if "::" in line:
file_path, test = line.split("::")
else:
file_path, test = line, line
job_result["failed"].append(test)
# Falls back to "N/A" when no first failure line was extracted for this test.
failure = all_failures.get(test, "N/A")
job_result["failures"][test] = failure
# Save and to be uploaded as artifact
os.makedirs("doc_test_results", exist_ok=True)
with open("doc_test_results/doc_test_results.json", "w", encoding="UTF-8") as fp:
json.dump(doc_test_results, fp, ensure_ascii=False, indent=4)
message = Message("[INFO] Results of the doc tests.", doc_test_results)
message.post()
message.post_reply()

155
utils/patch_helper.py Normal file
View File

@@ -0,0 +1,155 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This should help you prepare a patch, automatically extracting the commits to cherry-pick
in chronological order to avoid merge conflicts. An equivalent way to do this is to use
`git log --pretty=oneline HEAD...v4.41.0` and grep.
Potential TODO: automatically cherry-pick them.
Pass in a list of PRs:
`python utils/patch_helper.py --prs 31108 31054 31008 31010 31004`
will produce the following:
```bash
Skipping invalid version tag: list
Skipping invalid version tag: localattn1
Git cherry-pick commands to run:
git cherry-pick 03935d300d60110bb86edb49d2315089cfb19789 #2024-05-24 11:00:59+02:00
git cherry-pick bdb9106f247fca48a71eb384be25dbbd29b065a8 #2024-05-24 19:02:55+02:00
git cherry-pick 84c4b72ee99e8e65a8a5754a5f9d6265b45cf67e #2024-05-27 10:34:14+02:00
git cherry-pick 936ab7bae5e040ec58994cb722dd587b9ab26581 #2024-05-28 11:56:05+02:00
git cherry-pick 0bef4a273825d2cfc52ddfe62ba486ee61cc116f #2024-05-29 13:33:26+01:00
```
"""
import json
import subprocess
import transformers
# Label that `main()` passes to `get_prs_by_label` to select the PRs to cherry-pick.
LABEL = "for patch" # Replace with your label
# NOTE(review): not referenced by any function visible in this file — confirm it is still needed.
REPO = "huggingface/transformers" # Optional if already in correct repo
def get_release_branch_name():
    """Derive the release branch name (`v{major}.{minor}-release`) from `transformers.__version__`.

    Raises:
        ValueError: If the minor version is 0, because finding the last minor of the previous major
            version is not implemented.
    """
    major_part, minor_part, *_ = transformers.__version__.split(".")
    major = int(major_part)
    minor = int(minor_part)

    if minor == 0:
        # Handle major version rollback, e.g., from 5.0 to 4.latest (if ever needed).
        # You'll need logic to determine the last minor of the previous major version.
        raise ValueError("Minor version is 0; need logic to find previous major version's last minor")

    return f"v{major}.{minor}-release"
def checkout_branch(branch):
    """Checkout the target branch, terminating the process with status 1 when `git checkout` fails.

    Args:
        branch (`str`): Name of the branch to check out.
    """
    # Local import: `sys.exit` is always available, whereas the bare `exit` builtin is injected by
    # the `site` module for interactive use and is missing when Python runs without it.
    import sys

    try:
        subprocess.run(["git", "checkout", branch], check=True)
        print(f"[SUCCESS] Checked out branch: {branch}")
    except subprocess.CalledProcessError:
        print(f"[FAIL] Failed to checkout branch: {branch}. Does it exist?")
        sys.exit(1)
def get_prs_by_label(label):
    """List the PRs carrying `label` through the `gh` CLI.

    Args:
        label (`str`): The label to filter on.

    Returns:
        `list[dict]`: The parsed JSON entries (`number`, `title`, `mergeCommit`, `url`). Entries with a
        non-empty `mergeCommit` additionally get an `oid` key copied from it.

    Raises:
        subprocess.CalledProcessError: If the `gh` command exits with a non-zero status.
    """
    gh_command = ["gh", "pr", "list"]
    gh_command += ["--label", label]
    gh_command += ["--state", "all"]
    gh_command += ["--json", "number,title,mergeCommit,url"]
    gh_command += ["--limit", "100"]

    completed = subprocess.run(gh_command, check=True, capture_output=True, text=True)
    pull_requests = json.loads(completed.stdout)

    for pull_request in pull_requests:
        merge_commit = pull_request.get("mergeCommit", {})
        # Entries without merge commit information are left without an `oid` key.
        if not merge_commit:
            continue
        pull_request["oid"] = merge_commit.get("oid")

    return pull_requests
def get_commit_timestamp(commit_sha):
    """Return the UNIX timestamp (`%ct` format of `git show`) of `commit_sha` as an int.

    Raises:
        subprocess.CalledProcessError: If the `git show` command exits with a non-zero status.
    """
    git_show = ["git", "show", "-s", "--format=%ct", commit_sha]
    completed = subprocess.run(git_show, check=True, capture_output=True, text=True)
    raw_timestamp = completed.stdout.strip()
    return int(raw_timestamp)
def cherry_pick_commit(sha):
    """Run `git cherry-pick` on `sha` and print whether it succeeded.

    A failing cherry-pick is only reported with a warning; no exception is propagated for it.
    """
    command = ["git", "cherry-pick", sha]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError:
        print(f"[WARNING] Failed to cherry-pick {sha}. Manual intervention required.")
    else:
        print(f"[SUCCESS] Cherry-picked commit {sha}")
def commit_in_history(commit_sha, base_branch="HEAD"):
    """Tell whether `git merge-base --is-ancestor commit_sha base_branch` exits with status 0.

    The command output is discarded; any non-zero exit status yields `False`.
    """
    ancestry_check = subprocess.run(
        ["git", "merge-base", "--is-ancestor", commit_sha, base_branch],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=False,
    )
    is_ancestor = ancestry_check.returncode == 0
    return is_ancestor
def main(verbose=False):
    """Check out the release branch and cherry-pick every labelled PR commit not yet in its history.

    PRs are processed in ascending order of their commit timestamp. PRs without an `oid` only
    trigger a warning and are left out.

    Args:
        verbose (`bool`): When true, also report the PRs that are skipped because they are already
            in history.
    """
    release_branch = get_release_branch_name()
    checkout_branch(release_branch)
    labelled_prs = get_prs_by_label(LABEL)

    # Attach commit timestamps
    banner = "=" * 80
    for pr in labelled_prs:
        sha = pr.get("oid")
        if sha:
            pr["timestamp"] = get_commit_timestamp(sha)
            continue
        print("\n" + banner)
        print(f"[WARNING] PR #{pr['number']} ({sha}) is NOT in main!")
        print("[WARNING] A core maintainer must review this before cherry-picking.")
        print(banner + "\n")

    # Sort by commit timestamp (ascending), keeping only the PRs that received one
    with_timestamp = [pr for pr in labelled_prs if pr.get("timestamp") is not None]
    ordered_prs = sorted(with_timestamp, key=lambda pr: pr["timestamp"])

    for pr in ordered_prs:
        sha = pr.get("oid")
        if not sha:
            continue
        if not commit_in_history(sha):
            print(f"[INFO] PR #{pr['number']} ({pr['title']}) not in history. Cherry-picking...")
            cherry_pick_commit(sha)
        elif verbose:
            print(f"[INFO] PR #{pr['number']} ({pr['title']}) already in history. Skipping.")


if __name__ == "__main__":
    main()

175
utils/pr_slow_ci_models.py Normal file
View File

@@ -0,0 +1,175 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to get the models for which to run slow CI.
A new model added in a pull request will be included, as well as models specified in a GitHub pull request's comment
with a prefix `run-slow`, `run_slow` or `run slow`. For example, the comment `run_slow: bert, gpt2` will give
`bert` and `gpt2`.
Usage:
```bash
python utils/pr_slow_ci_models.py
```
"""
import argparse
import json
import os.path
import re
import string
from pathlib import Path
from git import Repo
# Resolved path of the directory one level above the folder containing this file; it is what `Repo(...)` opens below.
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
def get_new_python_files_between_commits(base_commit: str, commits: list[str]) -> list[str]:
    """
    Collect the python files reported as added when diffing each of `commits` against `base_commit`.

    Args:
        base_commit (`str`):
            The commit reference every commit is diffed against. This is the current commit, not the branching
            point!
        commits (`List[str]`):
            The commits to compare with `base_commit` (so the branching point). Each item must expose a
            `diff(base_commit)` method whose results carry `change_type` and `b_path` attributes.

    Returns:
        `List[str]`: The `b_path` of every diff entry with change type `"A"` whose path ends with `.py`, in the order
        encountered. A path reported for several commits appears once per commit.
    """
    return [
        entry.b_path
        for commit in commits
        for entry in commit.diff(base_commit)
        # We always add new python files
        if entry.change_type == "A" and entry.b_path.endswith(".py")
    ]
def get_new_python_files(diff_with_last_commit=False) -> list[str]:
    """
    Return the python files added in the current head compared with a reference point on the main branch.

    Args:
        diff_with_last_commit (`bool`, *optional*, defaults to `False`):
            When `False`, the reference commits are the merge base(s) of main and the current head. When `True`, they
            are the parents of the commit main points to.

    Returns:
        `List[str]`: The list of python files added.
    """
    repo = Repo(PATH_TO_REPO)

    try:
        # For the cases where the main branch exists locally
        main = repo.refs.main
    except AttributeError:
        # On GitHub Actions runners, it doesn't have local main branch
        main = repo.remotes.origin.refs.main

    # Both modes start by reporting where main is.
    print(f"main is at {main.commit}")

    if diff_with_last_commit:
        commits = main.commit.parents
        for parent in commits:
            print(f"Parent commit: {parent}")
    else:
        print(f"Current head is at {repo.head.commit}")
        commits = repo.merge_base(main, repo.head)
        for branching_point in commits:
            print(f"Branching commit: {branching_point}")

    return get_new_python_files_between_commits(repo.head.commit, commits)
def get_new_model(diff_with_last_commit=False):
    """
    Find the model folder of the first newly added `modeling_*.py` file.

    Args:
        diff_with_last_commit (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_new_python_files`.

    Returns:
        `str`: The text captured between `src/transformers/models/` and `/modeling_` in the first matching new file,
        or the empty string when no new file matches.
    """
    modeling_file = re.compile(r"src/transformers/models/(.*)/modeling_.*\.py")
    for path in get_new_python_files(diff_with_last_commit):
        found = modeling_file.search(path)
        if found is not None:
            # It's unlikely we have 2 new modeling files in a pull request, so the first hit wins.
            return found.group(1)
    return ""
def parse_message(message: str) -> str:
    """
    Parses a GitHub pull request's comment to find the models specified in it to run slow CI.

    Args:
        message (`str`): The body of a GitHub pull request's comment. `None` is accepted and treated as empty.

    Returns:
        `str`: The lowercased substring of `message` after `run-slow`, `run_slow` or `run slow` and after any
        leading `:` characters. If no such prefix is found, the empty string is returned. Whitespace that follows
        the last removed `:` is kept.
    """
    if message is None:
        return ""

    normalized = message.strip().lower()

    # run-slow: model_1, model_2
    accepted_prefixes = ("run-slow", "run_slow", "run slow")
    if not normalized.startswith(accepted_prefixes):
        return ""

    # The three accepted prefixes have the same length, so any of them can be used for slicing.
    remainder = normalized[len("run slow") :]

    # remove leading `:`
    trimmed = remainder.strip()
    while trimmed.startswith(":"):
        remainder = trimmed[1:]
        trimmed = remainder.strip()

    return remainder
def get_models(message: str):
    """
    Split the model specification found in a pull request comment into individual names.

    Args:
        message (`str`): The body of a GitHub pull request's comment, handed to `parse_message`.

    Returns:
        `List[str]`: The names found in the parsed message, using commas and whitespace as separators.
    """
    specification = parse_message(message)
    return [name for chunk in specification.split(",") for name in chunk.split()]
def check_model_names(model_name: str):
    """
    Guard against strange model names.

    Args:
        model_name (`str`): The candidate name.

    Returns:
        `bool`: `True` when the name neither starts nor ends with `_` and only contains ASCII letters, digits and
        `_`. The empty string passes the check.
    """
    if model_name.startswith("_") or model_name.endswith("_"):
        return False
    allowed_characters = set(string.ascii_letters + string.digits + "_")
    unexpected_characters = set(model_name) - allowed_characters
    return not unexpected_characters
if __name__ == "__main__":
    # Prints, as a JSON list, the test folders (relative to `tests/`) selected for slow CI.
    parser = argparse.ArgumentParser()
    parser.add_argument("--message", type=str, default="", help="The content of a comment.")
    parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
    args = parser.parse_args()

    # A newly added model (if any) comes first, followed by the models named in the comment.
    new_model = get_new_model()
    candidates = [new_model] if new_model != "" else []
    candidates.extend(get_models(args.message))

    # a guard for strange model names
    candidates = [name for name in candidates if check_model_names(name)]

    # Add prefix, keeping only the candidates that have a matching test directory
    selected = set()
    for name in candidates:
        if args.quantization:
            if os.path.isdir(f"tests/quantization/{name}"):
                selected.add(f"quantization/{name}")
        elif os.path.isdir(f"tests/models/{name}"):
            selected.add(f"models/{name}")
        elif os.path.isdir(f"tests/{name}") and name != "quantization":
            selected.add(name)

    # Use `json.dumps` to get the double quotes instead of single quote, e.g. `["models/vit"]`.
    # (to avoid some shell expansion issues when this script is called from a Github Actions workflow)
    print(json.dumps(sorted(selected)))

73
utils/print_env.py Normal file
View File

@@ -0,0 +1,73 @@
#!/usr/bin/env python3
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this script dumps information about the environment
import sys
import transformers
from transformers import is_torch_hpu_available, is_torch_xpu_available
# Report the interpreter and library versions first; these two lines are printed unconditionally.
print("Python version:", sys.version)
print("transformers version:", transformers.__version__)

# Torch section: the whole section lives inside the `try`, so an `ImportError` raised anywhere in it ends up in the
# handler below that prints `None` as the torch version.
try:
    import torch

    print("Torch version:", torch.__version__)
    # The first backend reporting itself as available wins: CUDA, then XPU, then HPU. "NA" means none of them did.
    accelerator = "NA"
    if torch.cuda.is_available():
        accelerator = "CUDA"
    elif is_torch_xpu_available():
        accelerator = "XPU"
    elif is_torch_hpu_available():
        accelerator = "HPU"

    print("Torch accelerator:", accelerator)

    # Backend-specific details; nothing more is printed when the accelerator is "NA".
    if accelerator == "CUDA":
        print("Cuda version:", torch.version.cuda)
        print("CuDNN version:", torch.backends.cudnn.version())
        print("Number of GPUs available:", torch.cuda.device_count())
        print("NCCL version:", torch.cuda.nccl.version())
    elif accelerator == "XPU":
        print("SYCL version:", torch.version.xpu)
        print("Number of XPUs available:", torch.xpu.device_count())
    elif accelerator == "HPU":
        # Reported value is whatever follows the last `+` in the torch version string.
        print("HPU version:", torch.__version__.split("+")[-1])
        print("Number of HPUs available:", torch.hpu.device_count())
except ImportError:
    print("Torch version:", None)

# DeepSpeed section: a failed import is reported as `None`.
try:
    import deepspeed

    print("DeepSpeed version:", deepspeed.__version__)
except ImportError:
    print("DeepSpeed version:", None)

# FFmpeg section: the version is read through torchcodec's private `_core` module.
# A failed import is reported as `None`; a failed lookup gets its own message.
try:
    import torchcodec

    versions = torchcodec._core.get_ffmpeg_library_versions()
    print("FFmpeg version:", versions["ffmpeg_version"])
except ImportError:
    print("FFmpeg version:", None)
except (AttributeError, KeyError, RuntimeError):
    print("Failed to get FFmpeg version")

View File

@@ -0,0 +1,151 @@
"""An internal script to process `new_failures_with_bad_commit.json` produced by `utils/check_bad_commit.py`.
This is used by `.github/workflows/check_failed_model_tests.yml` to produce a slack report of the following form
```
<{url}|New failed tests>
{
"GH_ydshieh": {
"vit": 1
}
}
```
"""
import json
import os
from collections import Counter
from copy import deepcopy
from get_previous_daily_ci import get_last_daily_ci_run
from huggingface_hub import HfApi
if __name__ == "__main__":
    # Overall flow: enrich `new_failures_with_bad_commit.json` with job links, upload it, build a per-author view of
    # the same data, upload that too, and finally print a one-line report pointing at the per-author file.
    api = HfApi()

    # `JOB_NAME` is used to locate `ci_results_{job_name}/job_links.json` and to build the upload paths below.
    job_name = os.environ.get("JOB_NAME")

    # Upload to Hub and get the url
    # if it is not a scheduled run, upload the reports to a subfolder under `report_repo_folder`
    report_repo_subfolder = ""
    if os.getenv("GITHUB_EVENT_NAME") != "schedule":
        report_repo_subfolder = f"{os.getenv('GITHUB_RUN_NUMBER')}-{os.getenv('GITHUB_RUN_ID')}"
        report_repo_subfolder = f"runs/{report_repo_subfolder}"

    workflow_run = get_last_daily_ci_run(
        token=os.environ["ACCESS_REPO_INFO_TOKEN"], workflow_run_id=os.getenv("GITHUB_RUN_ID")
    )
    workflow_run_created_time = workflow_run["created_at"]

    # The folder name is the part of `created_at` before the first `T`.
    # NOTE(review): presumably the date part of an ISO timestamp; confirm against `get_last_daily_ci_run`.
    report_repo_folder = workflow_run_created_time.split("T")[0]

    if report_repo_subfolder:
        report_repo_folder = f"{report_repo_folder}/{report_repo_subfolder}"

    report_repo_id = os.getenv("REPORT_REPO_ID")

    with open("new_failures_with_bad_commit.json") as fp:
        data = json.load(fp)

    with open(f"ci_results_{job_name}/job_links.json") as fp:
        job_links = json.load(fp)

    # Update `new_failures_with_bad_commit.json` with job links information before uploading to Hub repository
    # - need to change `single-gpu` to `single` and same for `multi-gpu` to match the keys in `job_link`.
    for model, model_result in data.items():
        for device, failed_tests in model_result.items():
            for failed_test in failed_tests:
                key = model
                # When `job_links` has `job_name` as its only key, links are looked up under that key instead of
                # under the model name.
                if list(job_links.keys()) == [job_name]:
                    key = job_name
                failed_test["job_link"] = job_links[key][device.replace("-gpu", "")]

    # The input file is overwritten in place with the enriched data before being uploaded.
    with open("new_failures_with_bad_commit.json", "w") as fp:
        json.dump(data, fp, indent=4, ensure_ascii=False)

    commit_info = api.upload_file(
        path_or_fileobj="new_failures_with_bad_commit.json",
        path_in_repo=f"{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit.json",
        repo_id=report_repo_id,
        repo_type="dataset",
        token=os.environ.get("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN", None),
    )
    # The URL embeds the commit id returned by the upload call.
    url = f"https://huggingface.co/datasets/{report_repo_id}/raw/{commit_info.oid}/{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit.json"

    # Written to disk for later consumers; `url` itself is reassigned further down for the per-author file.
    with open("new_failures_with_bad_commit_url.txt", "w") as fp:
        fp.write(url)

    # TODO: extend
    team_members = [
        "ArthurZucker",
        "Cyrilvallez",
        "LysandreJik",
        "MekkCyber",
        "Rocketknight1",
        "SunMarc",
        "ebezzam",
        "eustlb",
        "gante",
        "itazap",
        "ivarflakstad",
        "molbap",
        "remi-or",
        "stevhliu",
        "vasqu",
        "ydshieh",
        "zucchini-nlp",
        "tarekziade",
    ]

    # Counting the number of failures grouped by authors
    new_data = {}
    for model, model_result in data.items():
        for device, failed_tests in model_result.items():
            for failed_test in failed_tests:
                author = failed_test["author"]
                # If author is not a team member, and the PR is already merged: change to the one who merged the PR
                if author not in team_members and failed_test["merged_by"] is not None:
                    author = failed_test["merged_by"]
                if author not in new_data:
                    new_data[author] = Counter()
                new_data[author].update([model])
    # Turn the counters into plain dicts of `{model: number of failures}`.
    for author in new_data:
        new_data[author] = dict(new_data[author])

    # Group by author
    # Every author starts from a full copy of `data`, which is then pruned down to that author's failures.
    new_data_full = {author: deepcopy(data) for author in new_data}
    for author, _data in new_data_full.items():
        for model, model_result in _data.items():
            for device, failed_tests in model_result.items():
                # A failure is kept when the author wrote the PR or merged it. Unlike the counting step above,
                # team membership plays no role in this filter.
                failed_tests = [
                    x
                    for x in failed_tests
                    if x["author"] == author or (x["merged_by"] is not None and x["merged_by"] == author)
                ]
                model_result[device] = failed_tests
            # Drop devices, then models, left without any failure for this author.
            _data[model] = {k: v for k, v in model_result.items() if len(v) > 0}
        new_data_full[author] = {k: v for k, v in _data.items() if len(v) > 0}

    with open("new_failures_with_bad_commit_grouped_by_authors.json", "w") as fp:
        json.dump(new_data_full, fp, ensure_ascii=False, indent=4)
    commit_info = api.upload_file(
        path_or_fileobj="new_failures_with_bad_commit_grouped_by_authors.json",
        path_in_repo=f"{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit_grouped_by_authors.json",
        repo_id=report_repo_id,
        repo_type="dataset",
        token=os.environ.get("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN", None),
    )
    url = f"https://huggingface.co/datasets/{report_repo_id}/raw/{commit_info.oid}/{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit_grouped_by_authors.json"

    # Add `GH_` prefix as keyword mention
    output = {}
    for author, item in new_data.items():
        author = f"GH_{author}"
        output[author] = item

    # The report is a single printed line: double quotes are backslash-escaped and real newlines are replaced by
    # the two characters `\n`.
    # NOTE(review): presumably so the text can be embedded in a JSON payload by the workflow; confirm there.
    report = f"<{url}|New failed tests>\\n\\n"
    report += json.dumps(output, indent=4).replace('"', '\\"').replace("\n", "\\n")
    print(report)

View File

@@ -0,0 +1,146 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
from collections import Counter
import requests
if __name__ == "__main__":
    # Overall flow: list the jobs of one CircleCI workflow, download the test report artifacts of the relevant jobs,
    # and write per-job summaries, a workflow-wide summary, and aggregated failure statistics under `outputs/`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--workflow_id", type=str, required=True)
    args = parser.parse_args()

    # Every request is authenticated with the `CIRCLE_TOKEN` environment variable (empty string when unset).
    r = requests.get(
        f"https://circleci.com/api/v2/workflow/{args.workflow_id}/job",
        headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")},
    )
    # NOTE(review): only the `items` of this single response are read; confirm whether the API can paginate here.
    jobs = r.json()["items"]

    os.makedirs("outputs", exist_ok=True)

    # job name -> {test id: "passed" | "failed"}
    workflow_summary = {}
    # One dict per failed test line found in any job (see the keys appended below).
    failure_entries = []
    for job in jobs:
        # Only jobs whose name starts with one of these prefixes are inspected.
        if job["name"].startswith(("tests_", "examples_", "pipelines_")):
            url = f"https://circleci.com/api/v2/project/{job['project_slug']}/{job['job_number']}/artifacts"
            r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
            job_artifacts = r.json()["items"]

            os.makedirs(f"outputs/{job['name']}", exist_ok=True)

            # Both dicts are keyed by the artifact's `node_index` and hold the raw text of the downloaded file.
            job_test_summaries = {}
            job_failure_lines = {}
            for artifact in job_artifacts:
                url = artifact["url"]
                if artifact["path"].endswith("/summary_short.txt"):
                    r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
                    job_test_summaries[artifact["node_index"]] = r.text
                elif artifact["path"].endswith("/failures_line.txt"):
                    r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
                    job_failure_lines[artifact["node_index"]] = r.text

            # Pass/fail status per test: the 7 skipped characters are the `PASSED ` / `FAILED ` prefix.
            # For failures only the first whitespace-separated token after the prefix is kept as the test id.
            summary = {}
            for node_index, node_test_summary in job_test_summaries.items():
                for line in node_test_summary.splitlines():
                    if line.startswith("PASSED "):
                        summary[line[7:]] = "passed"
                    elif line.startswith("FAILED "):
                        summary[line[7:].split()[0]] = "failed"

            # Sorted by status first, so "failed" entries come before "passed" ones, then by test id.
            summary = dict(sorted(summary.items(), key=lambda x: (x[1], x[0])))
            workflow_summary[job["name"]] = summary

            with open(f"outputs/{job['name']}/test_summary.json", "w") as fp:
                json.dump(summary, fp, indent=4)

            # Collect failure details
            for node_index, summary_text in job_test_summaries.items():
                # Candidate detailed error lines for this node: non-empty, containing ": ", and not starting with
                # "=", "_" or "short test summary".
                failure_lines_list = [
                    l.strip()
                    for l in job_failure_lines.get(node_index, "").splitlines()
                    if l.strip() and not l.strip().startswith(("=", "_", "short test summary")) and ": " in l
                ]
                # Failures are paired with detailed lines purely by position.
                # NOTE(review): this assumes both files list failures in the same order; confirm.
                failure_idx = 0
                for line in summary_text.splitlines():
                    if line.startswith("FAILED ") and " - Failed: (subprocess)" not in line:
                        test_name, _, short_error = line[7:].strip().partition(" - ")
                        test_name = test_name.strip()
                        # For ids under `tests/models/`, the third path component is taken as the model name.
                        parts = test_name.split("::", 1)[0].split("/")
                        model_name = parts[2] if len(parts) >= 3 and test_name.startswith("tests/models/") else None
                        # Fall back to the short error when no detailed line is left for this failure.
                        full_error = (
                            failure_lines_list[failure_idx] if failure_idx < len(failure_lines_list) else short_error
                        )
                        failure_entries.append(
                            {
                                "job_name": job["name"],
                                "test_name": test_name,
                                "short_error": short_error,
                                "error": full_error,
                                "model_name": model_name,
                            }
                        )
                        failure_idx += 1

    # Build workflow summary
    # Pivot from {job: {test: status}} to {test: {job: status}}, sorted by test and then by job name.
    new_workflow_summary = {}
    for job_name, job_summary in workflow_summary.items():
        for test, status in job_summary.items():
            new_workflow_summary.setdefault(test, {})[job_name] = status

    new_workflow_summary = {
        test: dict(sorted(result.items())) for test, result in sorted(new_workflow_summary.items())
    }
    with open("outputs/test_summary.json", "w") as fp:
        json.dump(new_workflow_summary, fp, indent=4)

    # Aggregate failures by test and model
    by_test, by_model = {}, {}
    for entry in failure_entries:
        # Normalize test name
        # Drop everything from the first `[`, then strip from the last `::` component a suffix that starts with an
        # underscore followed by at least two digits.
        normalized = entry["test_name"].split("[", 1)[0]
        parts = normalized.split("::")
        normalized = "::".join(parts[:-1] + [re.sub(r"_\d{2,}.*$", "", parts[-1])])

        by_test.setdefault(normalized, {"count": 0, "errors": Counter(), "jobs": set(), "variants": set()})
        by_test[normalized]["count"] += 1
        by_test[normalized]["errors"][entry["error"]] += 1
        by_test[normalized]["jobs"].add(entry["job_name"])
        by_test[normalized]["variants"].add(entry["test_name"])

        # Entries without a model name only contribute to the per-test aggregation.
        if entry["model_name"]:
            by_model.setdefault(entry["model_name"], {"count": 0, "errors": Counter(), "tests": set()})
            by_model[entry["model_name"]]["count"] += 1
            by_model[entry["model_name"]]["errors"][entry["error"]] += 1
            by_model[entry["model_name"]]["tests"].add(entry["test_name"])

    # Convert Counter and sets to dicts/lists for JSON serialization
    # Errors are ordered from most to least frequent; jobs, variants and tests are sorted.
    for info in by_test.values():
        info["errors"] = dict(info["errors"].most_common())
        info["jobs"] = sorted(info["jobs"])
        info["variants"] = sorted(info["variants"])
    for info in by_model.values():
        info["errors"] = dict(info["errors"].most_common())
        info["tests"] = sorted(info["tests"])

    with open("outputs/failure_summary.json", "w") as fp:
        json.dump({"failures": failure_entries, "by_test": by_test, "by_model": by_model}, fp, indent=4)

View File

@@ -0,0 +1,74 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This helper computes the "ideal" number of nodes to use in circle CI.
For each job, we compute this parameter and pass it to the `generated_config.yaml`.
"""
import json
import math
import os
# Upper bound applied to the node count returned by `compute_parallel_nodes`.
MAX_PARALLEL_NODES = 8 # TODO create a mapping!
# Divisor applied to a line count to estimate the number of nodes in `compute_parallel_nodes`.
AVERAGE_TESTS_PER_NODES = 5
def count_lines(filepath):
    """
    Count the newline-separated segments of a file.

    The count is the number of `\\n` characters plus one, so an empty file counts as 1 and a trailing newline adds
    one extra (empty) segment.

    Args:
        filepath (`str`): Path of the file to inspect.

    Returns:
        `int`: The number of segments, or 0 when the file does not exist.
    """
    try:
        with open(filepath, "r") as handle:
            content = handle.read()
    except FileNotFoundError:
        return 0
    return content.count("\n") + 1
def compute_parallel_nodes(line_count, max_tests_per_node=10):
    """
    Compute the number of parallel nodes required.

    Args:
        line_count (`int`): The number of lines counted in a test list.
        max_tests_per_node (`int`, *optional*, defaults to 10): Accepted but not used by the computation.

    Returns:
        `int`: 1 when `line_count` is below 4, otherwise `line_count / AVERAGE_TESTS_PER_NODES` rounded up and
        capped at `MAX_PARALLEL_NODES`.
    """
    estimated_nodes = math.ceil(line_count / AVERAGE_TESTS_PER_NODES)
    return 1 if line_count < 4 else min(MAX_PARALLEL_NODES, estimated_nodes)
def process_artifacts(input_file, output_file):
    """
    Turn an artifacts listing into a flat mapping of test-list URLs and parallelism values.

    For every item of the input's `items` list whose `path` contains `test_list`, two entries are written: the
    artifact URL under the file's base name (extension removed), and a node count under
    `<text before the first "_test">_parallelism`. The node count comes from the line count of
    `test_preparation/<base name>.txt`. A `generated_config` entry, if produced, is removed.

    Args:
        input_file (`str`): Path of the JSON file to read.
        output_file (`str`): Path of the JSON file to write.
    """
    # Read the JSON data from the input file
    with open(input_file, "r") as source:
        artifacts = json.load(source)

    # Process items and build the new JSON structure
    result = {}
    for item in artifacts.get("items", []):
        artifact_path = item["path"]
        if "test_list" not in artifact_path:
            continue
        base_name, _extension = os.path.splitext(os.path.basename(artifact_path))
        result[base_name] = item["url"]
        test_list_file = os.path.join("test_preparation", f"{base_name}.txt")
        job_prefix = base_name.partition("_test")[0]
        result[job_prefix + "_parallelism"] = compute_parallel_nodes(count_lines(test_list_file))

    # Remove the "generated_config" key if it exists
    result.pop("generated_config", None)

    # Write the transformed data to the output file
    with open(output_file, "w") as sink:
        json.dump(result, sink, indent=2)


if __name__ == "__main__":
    input_file = "test_preparation/artifacts.json"
    output_file = "test_preparation/transformed_artifacts.json"
    process_artifacts(input_file, output_file)

227
utils/release.py Normal file
View File

@@ -0,0 +1,227 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that prepares the repository for releases (or patches) by updating all versions in the relevant places. It
also performs some post-release cleanup, by updating the links in the main README to respective model doc pages (from
main to stable).
To prepare for a release, use from the root of the repo on the release branch with:
```bash
python release.py
```
or use `make pre-release`.
To prepare for a patch release, use from the root of the repo on the release branch with:
```bash
python release.py --patch
```
or use `make pre-patch`.
To do the post-release cleanup, use from the root of the repo on the main branch with:
```bash
python release.py --post_release
```
or use `make post-release`.
"""
import argparse
import os
import re
from pathlib import Path
import packaging.version
# All paths are defined with the intent that this script should be run from the root of the repo.
PATH_TO_EXAMPLES = "examples/"
PATH_TO_MODELS = "src/transformers/models"
PATH_TO_UTILS = "utils"
# This maps a type of file to the pattern to look for when searching where the version is defined, as well as the
# template to follow when replacing it with the new version.
# In every template, the literal text `VERSION` is substituted with the actual version before the replacement is
# applied (see `update_version_in_file`).
REPLACE_PATTERNS = {
    # A whole `check_min_version("...")` line.
    "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
    # A whole `__version__ = "..."` line; the captured group is the version string read by `get_version`.
    "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
    # A `version="...",` argument; the captured leading whitespace is re-emitted so indentation is preserved.
    "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
    # A `# "transformers...` comment line, with an optional `[extras]` group that is carried over; the release
    # template pins the dependency to `==VERSION`.
    "uv_script_release": (
        re.compile(r'^# "transformers(\[.+\])?.*$', re.MULTILINE),
        r'# "transformers\g<1>==VERSION",',
    ),
    # Same line as above, but the dev template points the dependency at the GitHub repository instead.
    "uv_script_dev": (
        re.compile(r'^# "transformers(\[.+\])?.*$', re.MULTILINE),
        r'# "transformers\g<1> @ git+https://github.com/huggingface/transformers.git",',
    ),
}
# This maps a type of file to its path in Transformers
REPLACE_FILES = {
    "init": "src/transformers/__init__.py",
    "setup": "setup.py",
}
README_FILE = "README.md"
# Text whose presence in an example file makes `update_version_in_examples` treat it as a UV script.
UV_SCRIPT_MARKER = "# /// script"
def update_version_in_file(fname: str, version: str, file_type: str):
    """
    Rewrite one file so that it refers to the given version of Transformers.

    Args:
        fname (`str`): The path to the file where we want to update the version.
        version (`str`): The new version to set in the file.
        file_type (`str`): The type of the file (should be a key in `REPLACE_PATTERNS`).
    """
    with open(fname, "r", encoding="utf-8", newline="\n") as reader:
        original = reader.read()

    version_pattern, template = REPLACE_PATTERNS[file_type]
    replacement = template.replace("VERSION", version)
    # The new content is computed before the file is reopened for writing, so a failure here leaves it untouched.
    updated = version_pattern.sub(replacement, original)

    with open(fname, "w", encoding="utf-8", newline="\n") as writer:
        writer.write(updated)
def update_version_in_examples(version: str, patch: bool = False):
    """
    Update the version in all examples files.

    Every `.py` file found under `PATH_TO_EXAMPLES` (folders named `legacy` excluded) is visited. Files containing
    `UV_SCRIPT_MARKER` get their `transformers` dependency line rewritten, and, unless this is a patch release, the
    `check_min_version(...)` line of every file is updated.

    Args:
        version (`str`): The new version to set in the examples.
        patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
    """
    # The UV dependency flavour only depends on the version, so it is decided once for the whole walk.
    uv_script_file_type = "uv_script_dev" if ".dev" in version else "uv_script_release"

    for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
        # Removing some of the folders with non-actively maintained examples from the walk
        if "legacy" in directories:
            directories.remove("legacy")
        for fname in fnames:
            if not fname.endswith(".py"):
                continue
            script_path = os.path.join(folder, fname)
            # Read as UTF-8, like `update_version_in_file` does, instead of relying on the locale's default encoding.
            if UV_SCRIPT_MARKER in Path(script_path).read_text(encoding="utf-8"):
                # Update the dependencies in UV scripts
                update_version_in_file(script_path, version, file_type=uv_script_file_type)
            if not patch:
                # We don't update the version in the examples for patch releases.
                update_version_in_file(script_path, version, file_type="examples")
def global_version_update(version: str, patch: bool = False):
    """
    Update the version in every file listed in `REPLACE_FILES`.

    Args:
        version (`str`): The new version to set everywhere.
        patch (`bool`, *optional*, defaults to `False`):
            Whether or not this is a patch release. Only relevant to the examples update, which is currently
            disabled.
    """
    for file_type in REPLACE_FILES:
        update_version_in_file(REPLACE_FILES[file_type], version, file_type)
    # REMOVED AFTER v5! Uncomment to start updating the version of the examples again
    # update_version_in_examples(version, patch=patch)
def remove_conversion_scripts():
    """
    Delete the scripts that convert models from older, unsupported formats. We don't want to include these
    in release wheels because they often have to open insecure file types (pickle, Torch .bin models). This results in
    vulnerability scanners flagging us and can cause compliance issues for users with strict security policies.
    """
    # Collect every match first, then delete, so the directory is not modified while it is being searched.
    conversion_scripts = list(Path(PATH_TO_MODELS).glob("**/convert*.py"))
    for script in conversion_scripts:
        script.unlink()
def remove_internal_utils():
    """
    Delete internal utils that should not be included in releases for security reasons.

    Currently this removes `modular_model_detector.py` from `PATH_TO_UTILS`; a missing file raises.
    """
    detector_script = Path(PATH_TO_UTILS, "modular_model_detector.py")
    detector_script.unlink()
def get_version() -> packaging.version.Version:
    """
    Reads the current version in the main __init__.

    Returns:
        `packaging.version.Version`: The parsed version found on the `__version__ = "..."` line.

    Raises:
        `ValueError`: If the file contains no line matching the `init` pattern of `REPLACE_PATTERNS`.
    """
    init_path = REPLACE_FILES["init"]
    # Read as UTF-8, like `update_version_in_file` does, instead of relying on the locale's default encoding.
    with open(init_path, "r", encoding="utf-8") as f:
        code = f.read()

    version_match = REPLACE_PATTERNS["init"][0].search(code)
    if version_match is None:
        # Without this check the failure would be an opaque `AttributeError` on `None`.
        raise ValueError(f'Could not find a `__version__ = "..."` line in {init_path}.')
    return packaging.version.parse(version_match.group(1))
def pre_release_work(patch: bool = False):
    """
    Do all the necessary pre-release steps:
    - figure out the version to release and ask confirmation
    - update the version everywhere
    - delete the conversion scripts and the internal utils

    Args:
        patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.

    Raises:
        `ValueError`: If a patch release is requested while the current version is a dev version.
    """
    # First let's get the default version: base version if we are in dev, bump micro for a patch, bump minor otherwise.
    current = get_version()
    if current.is_devrelease:
        if patch:
            raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
        default_version = current.base_version
    elif patch:
        default_version = f"{current.major}.{current.minor}.{current.micro + 1}"
    else:
        default_version = f"{current.major}.{current.minor + 1}.0"

    # Now let's ask nicely if we have found the right version. An empty answer accepts the suggestion.
    answer = input(f"Which version are you releasing? [{default_version}]")
    version = answer or default_version

    print(f"Updating version to {version}.")
    global_version_update(version, patch=patch)

    print("Deleting conversion and internal utils scripts.")
    remove_conversion_scripts()
    remove_internal_utils()
def post_release_work():
    """
    Do all the necessary post-release steps:
    - figure out the next dev version and ask confirmation
    - update the version everywhere
    """
    # The suggested dev version is the next minor of the current version, with a `.dev0` suffix.
    current_version = get_version()
    dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"

    # Check with the user we got that right. An empty answer accepts the suggestion.
    version = input(f"Which version are we developing now? [{dev_version}]")
    if len(version) == 0:
        version = dev_version

    print(f"Updating version to {version}.")
    global_version_update(version)
if __name__ == "__main__":
    # Entry point: pre-release work by default, post-release work with `--post_release`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
    parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
    args = parser.parse_args()

    if args.post_release and args.patch:
        # There is no dev version to bump after a patch release.
        print("Nothing to do after a patch :-)")
    elif args.post_release:
        post_release_work()
    else:
        pre_release_work(patch=args.patch)

251
utils/rules.toml Normal file
View File

@@ -0,0 +1,251 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file can carry repo-local rule overrides for faster iteration between
# `transformers-mlinter` releases.
# Keep it synced with the upstream package's rules.toml when possible so local
# behavior does not drift from the published checker longer than necessary.
version = 1
[rules.TRF001]
description = "Class-level config_class on <Model>PreTrainedModel should match <Model>Config naming."
default_enabled = true
allowlist_models = ["qwen3_omni_moe"]
[rules.TRF001.explanation]
what_it_does = "Checks naming consistency between <Model>PreTrainedModel and config_class."
why_bad = "Mismatched config_class can break loading, auto classes, and developer expectations."
diff = '''
class AcmePreTrainedModel(PreTrainedModel):
- config_class = WileConfig
+ config_class = AcmeConfig
'''
[rules.TRF002]
description = "base_model_prefix should be a non-empty canonical string when defined on PreTrainedModel classes."
default_enabled = true
allowlist_models = ["lighton_ocr"]
[rules.TRF002.explanation]
what_it_does = "Checks that base_model_prefix, when set, is a non-empty, whitespace-free string literal."
why_bad = "Invalid prefixes can break weight loading key mapping and base model access patterns."
diff = '''
class AcmePreTrainedModel(PreTrainedModel):
- base_model_prefix = ""
+ base_model_prefix = "model"
'''
[rules.TRF003]
description = "forward() should use capture_output/can_return_tuple decorators instead of manual return_dict branching."
default_enabled = false
allowlist_models = []
[rules.TRF003.explanation]
what_it_does = "Detects forward methods that use the old 'if not return_dict: return (x,)' pattern."
why_bad = "The old return_dict branching pattern is error-prone and verbose. Use the capture_output or can_return_tuple decorators instead."
diff = '''
-def forward(self, x, return_dict=None):
- if not return_dict:
- return (x,)
- return AcmeModelOutput(last_hidden_state=x)
+@can_return_tuple
+def forward(self, x):
+ return AcmeModelOutput(last_hidden_state=x)
'''
[rules.TRF004]
description = "Models must never override tie_weights. Use _tied_weights_keys instead."
default_enabled = true
allowlist_models = ["data2vec", "hubert", "sew", "sew_d", "unispeech", "unispeech_sat", "wav2vec2", "wav2vec2_conformer", "wavlm"]
[rules.TRF004.explanation]
what_it_does = "Checks that no model class defines a tie_weights method."
why_bad = "Overriding tie_weights leads to bad consequences for loading, device_map computation, and saving. Use _tied_weights_keys class attribute to declare tied weights instead."
diff = '''
-def tie_weights(self):
- self.lm_head.weight = self.emb.weight
+class AcmeForCausalLM(AcmePreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
'''
[rules.TRF005]
description = "_no_split_modules, when defined, should be a list/tuple of non-empty strings."
default_enabled = true
allowlist_models = ["d_fine", "deformable_detr", "glm46v", "lw_detr", "pp_doclayout_v3", "rt_detr", "rt_detr_v2", "voxtral", "voxtral_realtime"]
[rules.TRF005.explanation]
what_it_does = "Checks the shape of _no_split_modules when present."
why_bad = "Malformed values can break device-map partitioning and sharding behavior."
diff = '''
-_no_split_modules = [SomeLayerClass, ""]
+_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"]
'''
[rules.TRF006]
description = "forward with cache arguments should reference cache control/state variables consistently."
default_enabled = true
allowlist_models = ["chinese_clip", "evolla", "idefics2", "llama4"]
[rules.TRF006.explanation]
what_it_does = "Checks forward signatures that expose cache arguments for usage of those arguments in method body."
why_bad = "Unused cache arguments can indicate incomplete caching support and inconsistent API behavior."
diff = '''
def forward(self, x, past_key_values=None, use_cache=False):
+ if use_cache:
+ ...
return x
'''
[rules.TRF007]
description = "self.post_init() in __init__ should remain at the end of initialization for PreTrainedModel classes."
default_enabled = true
allowlist_models = ["distilbert", "lxmert", "mt5", "pix2struct", "pop2piano", "switch_transformers", "t5"]
[rules.TRF007.explanation]
what_it_does = "Checks for self attribute assignments after self.post_init() in __init__."
why_bad = "Mutating model structure after post_init can bypass intended initialization/finalization logic."
diff = '''
def __init__(self, config):
...
- self.post_init()
- self.proj = nn.Linear(...)
+ self.proj = nn.Linear(...)
+ self.post_init()
'''
[rules.TRF008]
description = "Doc decorators on PreTrainedModel classes should avoid empty add_start_docstrings usage."
default_enabled = true
[rules.TRF008.explanation]
what_it_does = "Checks add_start_docstrings usage on model classes for non-empty docstring arguments."
why_bad = "Empty decorator usage produces unclear docs and weakens generated API documentation quality."
diff = '''
-@add_start_docstrings("")
+@add_start_docstrings("The Acme model.")
class AcmeModel(AcmePreTrainedModel):
...
'''
[rules.TRF009]
description = "modeling_<name>.py should avoid importing implementation code from another model package."
default_enabled = true
allowlist_models = ["dpr", "maskformer", "sam3_video", "vision_text_dual_encoder"]
[rules.TRF009.explanation]
what_it_does = "Checks modeling files for cross-model imports such as transformers.models.other_model.* or from ..other_model.* imports."
why_bad = "Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain."
diff = '''
-from transformers.models.llama.modeling_llama import LlamaAttention
+# Keep implementation local to this file.
+# If reusing code, copy it with a # Copied from comment.
'''
[rules.TRF010]
description = "Direct config definitions must use @strict(accept_kwargs=True)."
default_enabled = true
allowlist_models = ["nemotron_h", "vibevoice_asr"]
[rules.TRF010.explanation]
what_it_does = "Checks direct PreTrainedConfig/PretrainedConfig subclasses in configuration_*.py and modular_*.py for an explicit @strict(accept_kwargs=True) decorator."
why_bad = "Without strict, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard."
diff = '''
+@strict(accept_kwargs=True)
class AcmeConfig(PreTrainedConfig):
...
'''
[rules.TRF011]
description = "forward() must not access non-nn.Module attributes on submodules (breaks pipeline parallelism with Identity replacement)."
default_enabled = true
allowlist_models = []
[rules.TRF011.explanation]
what_it_does = "In forward() methods of PreTrainedModel subclasses, checks for attribute accesses on submodules that would not exist on torch.nn.Identity. This includes attribute accesses on loop variables iterating over self.layers, and self.<submodule>.<attr> chains where <attr> is not a standard nn.Module attribute."
why_bad = "Pipeline parallelism may replace any submodule with torch.nn.Identity. Accessing custom attributes (e.g. decoder_layer.attention_type) on a replaced module raises AttributeError at runtime. Per-layer metadata should be read from self.config instead."
diff = '''
def forward(self, ...):
- for decoder_layer in self.layers:
+ for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
)
'''
[rules.TRF012]
description = "_init_weights must use init primitives, not in-place operations on module weights."
default_enabled = true
allowlist_models = []
[rules.TRF012.explanation]
what_it_does = "Checks that _init_weights(self, module) does not use in-place operations (e.g. .normal_(), .zero_()) directly on module weights."
why_bad = "We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead."
diff = '''
+from transformers import initialization as init
+
def _init_weights(self, module):
- module.weight.normal_(mean=0.0, std=0.02)
+ init.normal_(module.weight, mean=0.0, std=0.02)
'''
[rules.TRF013]
description = "PreTrainedModel __init__ must call self.post_init()."
default_enabled = true
allowlist_models = []
[rules.TRF013.explanation]
what_it_does = "Checks that every PreTrainedModel subclass with an __init__ method calls self.post_init(). In modular files, calling super().__init__() is also accepted since it propagates post_init from the parent."
why_bad = "post_init performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs."
diff = '''
class AcmeModel(AcmePreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.layers = nn.ModuleList(...)
+ self.post_init()
'''
[rules.TRF014]
description = "`trust_remote_code` should never be used in native model integrations."
default_enabled = true
allowlist_models = []
[rules.TRF014.explanation]
what_it_does = "Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files."
why_bad = "`trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers."
diff = '''
class AcmeModel(AcmePreTrainedModel):
def __init__(self, config):
super().__init__(config)
- self.model = AutoModel.from_pretrained(..., trust_remote_code=True)
+ self.model = AutoModel.from_pretrained(...)
'''
[rules.TRF015]
description = "Models with non-empty _tied_weights_keys must have tie_word_embeddings in their Config."
default_enabled = true
allowlist_models = []
[rules.TRF015.explanation]
what_it_does = "When a PreTrainedModel subclass defines _tied_weights_keys as a non-empty collection, checks that the corresponding configuration file declares a tie_word_embeddings field."
why_bad = "Without tie_word_embeddings in the config, users cannot control weight tying behavior. The model ties weights unconditionally, breaking serialization round-trips and preventing fine-tuning with untied heads."
diff = '''
# configuration_foo.py
@strict(accept_kwargs=True)
class FooConfig(PreTrainedConfig):
hidden_size: int = 768
+ tie_word_embeddings: bool = True
'''

198
utils/scan_skipped_tests.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import re
from pathlib import Path
# Repository root, taken to be the current working directory of the process.
REPO_ROOT = Path.cwd()

# Files that define the shared test suites, each paired with the origin tag reported for its tests.
COMMON_TEST_FILES: list[tuple[Path, str]] = [
    (REPO_ROOT / "tests" / "test_modeling_common.py", "common"),
    (REPO_ROOT / "tests" / "generation" / "test_utils.py", "GenerationMixin"),
]

# Directory that is searched recursively for `test_modeling_*.py` files.
MODELS_DIR = REPO_ROOT / "tests" / "models"
def get_common_tests(file_paths_with_origin: list[tuple[Path, str]]) -> dict[str, str]:
"""Extract all common test function names (e.g., 'test_forward')."""
tests_with_origin: dict[str, str] = {}
for file_path, origin_tag in file_paths_with_origin:
if not file_path.is_file():
continue
content = file_path.read_text(encoding="utf-8")
for test_name in re.findall(r"^\s*def\s+(test_[A-Za-z0-9_]+)", content, re.MULTILINE):
tests_with_origin[test_name] = origin_tag
return tests_with_origin
def get_models_and_test_files(models_dir: Path) -> tuple[list[str], list[Path]]:
    """Find every `test_modeling_*.py` file below *models_dir*.

    Args:
        models_dir: Directory that is searched recursively.

    Returns:
        A pair `(model_names, test_files)`: the sorted, de-duplicated names of the directories that directly
        contain a matching file, and the sorted list of matching files.

    Raises:
        FileNotFoundError: If *models_dir* is not an existing directory.
    """
    if models_dir.is_dir():
        found_files: list[Path] = sorted(models_dir.rglob("test_modeling_*.py"))
        parent_names: list[str] = sorted({path.parent.name for path in found_files})
        return parent_names, found_files
    raise FileNotFoundError(f"Models directory not found at {models_dir}")
def _extract_reason_from_decorators(decorators_block: str) -> str:
"""Extracts the reason string from a decorator block, if any."""
reason_match = re.search(r'reason\s*=\s*["\'](.*?)["\']', decorators_block)
if reason_match:
return reason_match.group(1)
reason_match = re.search(r'\((?:.*?,\s*)?["\'](.*?)["\']\)', decorators_block)
if reason_match:
return reason_match.group(1)
return decorators_block.strip().split("\n")[-1].strip()
def extract_test_info(file_content: str) -> dict[str, tuple[str, str]]:
"""
Parse a test file once and return a mapping of test functions to their
status and skip reason, e.g. {'test_forward': ('SKIPPED', 'too slow')}.
"""
result: dict[str, tuple[str, str]] = {}
pattern = re.compile(r"((?:^\s*@.*?\n)*?)^\s*def\s+(test_[A-Za-z0-9_]+)\b", re.MULTILINE)
for decorators_block, test_name in pattern.findall(file_content):
if "skip" in decorators_block:
result[test_name] = ("SKIPPED", _extract_reason_from_decorators(decorators_block))
else:
result[test_name] = ("RAN", "")
return result
def build_model_overrides(model_test_files: list[Path]) -> dict[str, dict[str, tuple[str, str]]]:
    """Build the *model_name → {test_name → (status, reason)}* mapping.

    The model name is the name of the directory containing each test file; results from several files in
    the same directory are merged, later files overriding earlier ones for identical test names.
    """
    overrides: dict[str, dict[str, tuple[str, str]]] = {}
    for test_file in model_test_files:
        owner = test_file.parent.name
        parsed = extract_test_info(test_file.read_text(encoding="utf-8"))
        overrides.setdefault(owner, {}).update(parsed)
    return overrides
def save_json(obj: dict, output_path: Path) -> None:
    """Write *obj* as 2-space-indented JSON to *output_path*, creating parent directories as needed."""
    parent_dir = output_path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(obj, indent=2)
    output_path.write_text(serialized, encoding="utf-8")
def summarize_single_test(
    test_name: str,
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> dict[str, object]:
    """Print a short report for *test_name* and return the underlying data.

    A model that has no entry for *test_name* in *model_overrides* is counted as having run the test.
    At most ten skip reasons are printed; the returned dict always contains all of them.
    """
    ran: list[str] = []
    skipped: list[str] = []
    skip_reasons: list[str] = []
    for model in model_names:
        status, reason = model_overrides.get(model, {}).get(test_name, ("RAN", ""))
        if status != "SKIPPED":
            ran.append(model)
            continue
        skipped.append(model)
        skip_reasons.append(f"{model}: {reason}")

    model_count = len(model_names)
    # Guard against division by zero when no model was discovered.
    ratio = len(skipped) / model_count if model_count else 0.0

    print(f"\n== {test_name} ==")
    print(f"Ran : {len(ran)}/{model_count}")
    print(f"Skipped : {len(skipped)}/{model_count} ({ratio:.1%})")
    for entry in skip_reasons[:10]:
        print(f" - {entry}")
    if len(skip_reasons) > 10:
        print(" - ...")

    return {
        "models_ran": sorted(ran),
        "models_skipped": sorted(skipped),
        "skipped_proportion": round(ratio, 4),
        "reasons_skipped": sorted(skip_reasons),
    }
def summarize_all_tests(
    tests_with_origin: dict[str, str],
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> dict[str, object]:
    """Aggregate ran/skipped information for every test listed in *tests_with_origin*.

    Progress is printed on a single, carriage-return-refreshed line. A model without an entry for a given
    test is counted as having run it.

    Returns:
        A mapping from test name to a dict with the keys `origin`, `models_ran`, `models_skipped`,
        `skipped_proportion` and `reasons_skipped`.
    """
    aggregated: dict[str, object] = {}
    model_count = len(model_names)
    ordered_tests = list(tests_with_origin)
    test_count = len(ordered_tests)

    print(f"[INFO] Aggregating {test_count} tests...")
    for position, test_fn in enumerate(ordered_tests, start=1):
        print(f" ({position}/{test_count}) {test_fn}", end="\r")

        ran: list[str] = []
        skipped: list[str] = []
        skip_reasons: list[str] = []
        for model in model_names:
            status, reason = model_overrides.get(model, {}).get(test_fn, ("RAN", ""))
            if status != "SKIPPED":
                ran.append(model)
                continue
            skipped.append(model)
            skip_reasons.append(f"{model}: {reason}")

        # Guard against division by zero when no model was discovered.
        ratio = len(skipped) / model_count if model_count else 0.0
        aggregated[test_fn] = {
            "origin": tests_with_origin[test_fn],
            "models_ran": sorted(ran),
            "models_skipped": sorted(skipped),
            "skipped_proportion": round(ratio, 4),
            "reasons_skipped": sorted(skip_reasons),
        }

    print("\n[INFO] Scan complete.")
    return aggregated
def main() -> None:
    """Command-line entry point: scan the model test files and write the result as JSON.

    With `--test_method_name`, only that test is summarized and the output file is
    `scan_<test_method_name>.json`; otherwise every common test is aggregated into
    `all_tests_scan_result.json`. Both files are written below `--output_dir`.
    """
    parser = argparse.ArgumentParser(
        description="Scan model tests for overridden or skipped common or generate tests.",
    )
    parser.add_argument("--output_dir", default=".", help="Directory for JSON output (default: %(default)s)")
    parser.add_argument("--test_method_name", help="Scan only this test method (singletest mode)")
    cli_args = parser.parse_args()

    destination = Path(cli_args.output_dir).expanduser()
    single_test = cli_args.test_method_name

    common_tests = get_common_tests(COMMON_TEST_FILES)
    if single_test:
        # Restrict the scan to the requested test; tests not found in the common files get origin "unknown".
        common_tests = {single_test: common_tests.get(single_test, "unknown")}

    model_names, model_test_files = get_models_and_test_files(MODELS_DIR)
    print(f"[INFO] Parsing {len(model_test_files)} model test files once each...")
    overrides = build_model_overrides(model_test_files)

    if single_test:
        report = summarize_single_test(single_test, model_names, overrides)
        report_path = destination / f"scan_{single_test}.json"
    else:
        report = summarize_all_tests(common_tests, model_names, overrides)
        report_path = destination / "all_tests_scan_result.json"

    save_json(report, report_path)
    print(f"\n[INFO] JSON saved to {report_path.resolve()}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,26 @@
"""A simple script to set flexibly CUDA_VISIBLE_DEVICES in GitHub Actions CI workflow files."""
import argparse
import os
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--test_folder",
type=str,
default=None,
help="The test folder name of the model being tested. For example, `models/cohere`.",
)
args = parser.parse_args()
# `test_eager_matches_sdpa_generate` for `cohere` needs a lot of GPU memory!
# This depends on the runners. At this moment we are targeting our AWS CI runners.
if args.test_folder == "models/cohere":
cuda_visible_devices = "0,1,2,3"
elif "CUDA_VISIBLE_DEVICES" in os.environ:
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
else:
cuda_visible_devices = "0"
print(cuda_visible_devices)

View File

@@ -0,0 +1,15 @@
docs/source/en/generation_strategies.md
docs/source/en/model_doc/code_llama.md
docs/source/en/model_doc/ctrl.md
docs/source/en/model_doc/kosmos-2.md
docs/source/en/model_doc/seamless_m4t.md
docs/source/en/model_doc/seamless_m4t_v2.md
docs/source/en/tasks/prompting.md
docs/source/ja/model_doc/code_llama.md
src/transformers/models/blip_2/modeling_blip_2.py
src/transformers/models/ctrl/modeling_ctrl.py
src/transformers/models/fuyu/modeling_fuyu.py
src/transformers/models/idefics2/modeling_idefics2.py
src/transformers/models/kosmos2/modeling_kosmos2.py
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
src/transformers/models/musicgen_melody/processing_musicgen_melody.py

130
utils/sort_auto_mappings.py Normal file
View File

@@ -0,0 +1,130 @@
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that sorts the names in the auto mappings defined in the auto modules in alphabetical order.
Use from the root of the repo with:
```bash
python utils/sort_auto_mappings.py
```
to auto-fix all the auto mappings (used in `make style`).
To only check if the mappings are properly sorted (as used in `make check-repo`), do:
```bash
python utils/sort_auto_mappings.py --check_only
```
"""
import argparse
import os
import re
# Checker metadata: name, display label, cache globs and the CLI arguments for check / fix mode.
# NOTE(review): presumably consumed by a repo-check runner defined outside this file — confirm.
CHECKER_CONFIG = {
    "name": "sort_auto_mappings",
    "label": "Sort auto mappings",
    "cache_globs": ["src/transformers/models/auto/*.py"],
    "check_args": ["--check_only"],
    "fix_args": [],
}
# Paths are set with the intent that you run this script from the root of the repo.
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
# re pattern that matches XXX_MAPPING_NAMES or SPECIAL_MODEL_TYPE_TO_MODULE_NAMES being assigned an
# `OrderedDict`; the negative lookahead excludes `OrderedDict(*...)` constructions.
_re_intro_mapping = re.compile(r"[A-Z_]+(_MAPPING|_MODEL_TYPE_TO_MODULE)(\s+|_[A-Z_]+\s+)=\s+OrderedDict(?!\(\*)")
# re pattern that matches identifiers in mappings: it captures the first double-quoted string that follows
# an opening parenthesis, which is used as the sort key of a mapping entry.
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
def sort_auto_mapping(fname: str, overwrite: bool = False) -> bool | None:
    """
    Sort all auto mappings in a file.

    Args:
        fname (`str`): The name of the file where we want to sort auto-mappings.
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Returns:
        `Optional[bool]`: Returns `None` if `overwrite=True`. Otherwise returns `True` if the file has an auto-mapping
        improperly sorted, `False` if the file is okay.
    """
    with open(fname, "r", encoding="utf-8") as handle:
        original = handle.read()

    source_lines = original.split("\n")
    sorted_lines = []
    cursor = 0
    while cursor < len(source_lines):
        if _re_intro_mapping.search(source_lines[cursor]) is None:
            # Not the start of a mapping: copy the line through unchanged.
            sorted_lines.append(source_lines[cursor])
            cursor += 1
            continue

        # Start of a new mapping! Its entries are expected 8 columns deeper than the introducing line.
        intro_indent = re.search(r"^(\s*)\S", source_lines[cursor]).group(1)
        entry_prefix = " " * (len(intro_indent) + 8)
        entry_open = entry_prefix + "("
        entry_close = entry_prefix + ")"

        # Copy everything up to (but excluding) the first entry.
        while not source_lines[cursor].startswith(entry_open):
            sorted_lines.append(source_lines[cursor])
            cursor += 1

        entries = []
        while source_lines[cursor].strip() != "]":
            if source_lines[cursor].strip() == "(":
                # Multi-line entry: gather every line up to and including its closing parenthesis.
                first_line = cursor
                while not source_lines[cursor].startswith(entry_close):
                    cursor += 1
                entries.append("\n".join(source_lines[first_line : cursor + 1]))
            else:
                # Entry that fits on a single line.
                entries.append(source_lines[cursor])
            cursor += 1

        # Sort entries by their identifiers. The closing "]" line is not consumed here; it is copied by the
        # next iteration of the outer loop.
        entries.sort(key=lambda entry: _re_identifier.search(entry).group(1))
        sorted_lines.extend(entries)

    rebuilt = "\n".join(sorted_lines)
    if overwrite:
        with open(fname, "w", encoding="utf-8") as handle:
            handle.write(rebuilt)
        return None
    return rebuilt != original
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Sort all auto mappings in the library.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Raises:
        ValueError: In check mode (`overwrite=False`), if at least one file has an unsorted auto mapping.
    """
    module_files = [
        os.path.join(PATH_TO_AUTO_MODULE, entry)
        for entry in os.listdir(PATH_TO_AUTO_MODULE)
        if entry.endswith(".py")
    ]
    # Every file is processed before anything is reported, so the error below lists all offenders at once.
    needs_sorting = [sort_auto_mapping(path, overwrite=overwrite) for path in module_files]
    if overwrite:
        return

    failures = [path for path, unsorted in zip(module_files, needs_sorting) if unsorted]
    if failures:
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
args = parser.parse_args()
sort_all_auto_mappings(not args.check_only)

View File

@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to get the files against which we will run doc testing.
This uses `tests_fetcher.get_all_doctest_files` then groups the test files by their directory paths.
The files in `docs/source/en/model_doc` or `docs/source/en/tasks` are **NOT** grouped together with other files in the
same directory: the objective is to run doctest against them in independent GitHub Actions jobs.
Assume we are under `transformers` root directory:
To get a map (dictionary) between directory (or file) paths and the corresponding files
```bash
python utils/split_doctest_jobs.py
```
or to get a list of lists of directory (or file) paths
```bash
python utils/split_doctest_jobs.py --only_return_keys --num_splits 4
```
(this is used to allow GitHub Actions to generate more than 256 jobs using matrix)
"""
import argparse
from collections import defaultdict
from pathlib import Path
from tests_fetcher import get_all_doctest_files
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--only_return_keys",
action="store_true",
help="if to only return the keys (which is a list of list of files' directory or file paths).",
)
parser.add_argument(
"--num_splits",
type=int,
default=1,
help="the number of splits into which the (flat) list of directory/file paths will be split. This has effect only if `only_return_keys` is `True`.",
)
args = parser.parse_args()
all_doctest_files = get_all_doctest_files()
raw_test_collection_map = defaultdict(list)
for file in all_doctest_files:
file_dir = "/".join(Path(file).parents[0].parts)
# not to run files in `src/` for now as it is completely broken at this moment. See issues/39159 and
# https://github.com/huggingface/transformers/actions/runs/15988670157
# TODO (ydshieh): fix the error, ideally before 2025/09
if file_dir.startswith("src/"):
continue
raw_test_collection_map[file_dir].append(file)
refined_test_collection_map = {}
for file_dir in raw_test_collection_map:
if file_dir in ["docs/source/en/model_doc", "docs/source/en/tasks"]:
for file in raw_test_collection_map[file_dir]:
refined_test_collection_map[file] = file
else:
refined_test_collection_map[file_dir] = " ".join(sorted(raw_test_collection_map[file_dir]))
sorted_file_dirs = sorted(refined_test_collection_map.keys())
test_collection_map = {}
for file_dir in sorted_file_dirs:
test_collection_map[file_dir] = refined_test_collection_map[file_dir]
num_jobs = len(test_collection_map)
num_jobs_per_splits = num_jobs // args.num_splits
file_directory_splits = []
end = 0
for idx in range(args.num_splits):
start = end
end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0)
file_directory_splits.append(sorted_file_dirs[start:end])
if args.only_return_keys:
print(file_directory_splits)
else:
print(dict(test_collection_map))

View File

@@ -0,0 +1,88 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to get the list of folders under `tests/models` and split the list into `NUM_SLICES` splits.
The main use case is a GitHub Actions workflow file calling this script to get the (nested) list of folders allowing it
to split the list of jobs to run into multiple slices each containing a smaller number of jobs. This way, we can bypass
the maximum of 256 jobs in a matrix.
See the `setup` and `run_models_gpu` jobs defined in the workflow file `.github/workflows/self-scheduled.yml` for more
details.
Usage:
This script is required to be run under `tests` folder of `transformers` root directory.
Assume we are under `transformers` root directory:
```bash
cd tests
python ../utils/split_model_tests.py --num_splits 64
```
"""
import argparse
import ast
import os
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--subdirs",
type=str,
default="",
help="the list of pre-computed model names (directory names under `tests/models`) or directory names under `tests` (except `models`).",
)
parser.add_argument(
"--num_splits",
type=int,
default=1,
help="the number of splits into which the (flat) list of folders will be split.",
)
args = parser.parse_args()
tests = os.getcwd()
model_tests = os.listdir(os.path.join(tests, "models"))
d1 = sorted(filter(os.path.isdir, os.listdir(tests)))
d2 = sorted(filter(os.path.isdir, [f"models/{x}" for x in model_tests]))
d1.remove("models")
d = d2 + d1
if args.subdirs != "":
model_tests = ast.literal_eval(args.subdirs)
# We handle both cases with and without prefix because `push-important-models.yml` returns the list without
# the prefix (i.e. `models`) but `utils/pr_slow_ci_models.py` (called by `self-comment-ci.yml`) returns the
# list with the prefix (`models`) and some directory names under `tests`.
d = []
for x in model_tests:
if os.path.isdir(x):
d.append(x)
if os.path.isdir(f"models/{x}"):
d.append(f"models/{x}")
d = sorted(d)
num_jobs = len(d)
num_jobs_per_splits = num_jobs // args.num_splits
model_splits = []
end = 0
for idx in range(args.num_splits):
start = end
end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0)
# Only add the slice if it is not an empty list
if len(d[start:end]) > 0:
model_splits.append(d[start:end])
print(model_splits)

View File

View File

@@ -0,0 +1,9 @@
from transformers import PreTrainedConfig
class CustomConfig(PreTrainedConfig):
    """Minimal configuration class with `model_type` "custom" and a single extra field, `attribute`."""

    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        # `attribute` is stored first; all remaining keyword arguments are forwarded to the parent class.
        self.attribute = attribute
        super().__init__(**kwargs)

View File

@@ -0,0 +1,5 @@
from transformers import Wav2Vec2FeatureExtractor
class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
    """Subclass of `Wav2Vec2FeatureExtractor` that adds and overrides nothing."""

    pass

View File

@@ -0,0 +1,5 @@
from transformers import CLIPImageProcessor
class CustomImageProcessor(CLIPImageProcessor):
    """Subclass of `CLIPImageProcessor` that adds and overrides nothing."""

    pass

View File

@@ -0,0 +1,20 @@
import torch
from transformers import PreTrainedModel
from .custom_configuration import CustomConfig
class CustomModel(PreTrainedModel):
    """Minimal model made of a single linear layer, using `CustomConfig` as its configuration class."""

    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        # Square projection: input and output widths are both `config.hidden_size`.
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()

    def forward(self, x):
        # Apply the linear layer to `x` and return the result.
        return self.linear(x)

    def _init_weights(self, module):
        # No-op: this class performs no custom weight initialization.
        pass

View File

@@ -0,0 +1,33 @@
import numpy as np
from transformers import Pipeline
def softmax(outputs):
    """Return the softmax of *outputs* along the last axis.

    The per-row maximum is subtracted before exponentiation, which leaves the result unchanged
    mathematically while keeping the exponentials from overflowing.
    """
    row_max = np.max(outputs, axis=-1, keepdims=True)
    exponentials = np.exp(outputs - row_max)
    normalizer = exponentials.sum(axis=-1, keepdims=True)
    return exponentials / normalizer
class PairClassificationPipeline(Pipeline):
    """Pipeline that classifies a text, optionally paired with a second text."""

    def _sanitize_parameters(self, **kwargs):
        # Only `second_text` is routed, to `preprocess`; the other two parameter dicts stay empty.
        preprocess_kwargs = {key: kwargs[key] for key in ("second_text",) if key in kwargs}
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, second_text=None):
        # Encode the text (and the optional second text as its pair) with the tokenizer.
        encoded = self.tokenizer(text, text_pair=second_text, return_tensors="pt")
        return encoded

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        # Only the first entry of the logits is used.
        first_logits = model_outputs.logits[0].numpy()
        probabilities = softmax(first_logits)
        top_index = np.argmax(probabilities)
        return {
            "label": self.model.config.id2label[top_index],
            "score": probabilities[top_index].item(),
            "logits": first_logits.tolist(),
        }

View File

@@ -0,0 +1,6 @@
from transformers import ProcessorMixin
class CustomProcessor(ProcessorMixin):
    """Processor built from a feature extractor and a tokenizer, both handed to `ProcessorMixin`."""

    def __init__(self, feature_extractor, tokenizer):
        # Both components are forwarded positionally, in this order, to the parent initializer.
        super().__init__(feature_extractor, tokenizer)

View File

@@ -0,0 +1,5 @@
from transformers import BertTokenizer
class CustomTokenizer(BertTokenizer):
    """Subclass of `BertTokenizer` that adds and overrides nothing."""

    pass

View File

@@ -0,0 +1,10 @@
from transformers import BertTokenizerFast
from .custom_tokenization import CustomTokenizer
class CustomTokenizerFast(BertTokenizerFast):
    """Subclass of `BertTokenizerFast` whose slow counterpart is `CustomTokenizer`."""

    slow_tokenizer_class = CustomTokenizer
    # Pairs the slow and fast class paths under the "AutoTokenizer" key.
    # NOTE(review): presumably read by the auto-class machinery when resolving custom code — confirm.
    _auto_map = {
        "AutoTokenizer": ("custom_tokenization.CustomTokenizer", "custom_tokenization_fast.CustomTokenizerFast")
    }

View File

@@ -0,0 +1,5 @@
from transformers import LlavaOnevisionVideoProcessor
class CustomVideoProcessor(LlavaOnevisionVideoProcessor):
    """Subclass of `LlavaOnevisionVideoProcessor` that adds and overrides nothing."""

    pass

1189
utils/tests_fetcher.py Normal file

File diff suppressed because it is too large Load Diff

360
utils/update_metadata.py Executable file
View File

@@ -0,0 +1,360 @@
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that updates the metadata of the Transformers library in the repository `huggingface/transformers-metadata`.
Usage for an update (as used by the GitHub action `update_metadata`):
```bash
python utils/update_metadata.py --token <token> --commit_sha <commit_sha>
```
Usage to check all pipelines are properly defined in the constant `PIPELINE_TAGS_AND_AUTO_MODELS` of this script, so
that new pipelines are properly added as metadata (as used in `make check-repo`):
```bash
python utils/update_metadata.py --check-only
```
"""
import argparse
import collections
import os
import re
import tempfile
import pandas as pd
from datasets import Dataset
from huggingface_hub import hf_hub_download, upload_folder
from transformers.utils import direct_transformers_import
# Declarative description of this checker. Nothing in this script reads it; it is presumably consumed by the
# tooling behind `make check-repo` / `make fix-repo` -- TODO confirm against that tooling.
CHECKER_CONFIG = {
    "name": "update_metadata",
    "label": "Model metadata",
    # Approximate: imports the transformers module and inspects pipeline/auto mappings
    # at runtime. Does not iterate over files matching these globs directly.
    "cache_globs": ["src/transformers/models/**/*.py", "docs/**/*.md"],
    # Arguments for the read-only mode of this script (see `--check-only` in the `__main__` block).
    "check_args": ["--check-only"],
    # No safe local "fix" mode: running without `--check-only` pushes to the
    # `huggingface/transformers-metadata` Hub dataset (requires an auth token).
    # `fix_args=None` makes `make fix-repo` skip this checker, like other check-only ones.
    "fix_args": None,
}
# All paths are set with the intent that you run this script from the root of the repo with the command
# python utils/update_metadata.py
TRANSFORMERS_PATH = "src/transformers"
# This is to make sure the transformers module imported is the one in the repo.
# Note that this import happens at module load time, i.e. as soon as this script is imported or run.
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
# Regex that matches model class names. Group 1 is the (greedy) part of the name that precedes the last
# occurrence of one of the listed words; `get_frameworks_table` uses that prefix to find the model type.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
# Fill this with tuples (pipeline_tag, model_mapping, auto_model)
# - `model_mapping` is the name of an attribute of `transformers.models.auto.modeling_auto`; entries whose
#   attribute does not exist are skipped by `update_pipeline_and_auto_class_table`.
# - The same pipeline tag may appear several times (e.g. "automatic-speech-recognition", "text-to-audio").
# - Entries are applied in order, so when a model class is listed in several mappings the last entry wins.
PIPELINE_TAGS_AND_AUTO_MODELS = [
    ("pretraining", "MODEL_FOR_PRETRAINING_MAPPING_NAMES", "AutoModelForPreTraining"),
    ("feature-extraction", "MODEL_MAPPING_NAMES", "AutoModel"),
    ("image-feature-extraction", "MODEL_FOR_IMAGE_MAPPING_NAMES", "AutoModel"),
    ("audio-classification", "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForAudioClassification"),
    ("text-generation", "MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "AutoModelForCausalLM"),
    ("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
    ("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
    ("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
    ("any-to-any", "MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES", "AutoModelForMultimodalLM"),
    ("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
    ("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
    ("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
    (
        "zero-shot-object-detection",
        "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
        "AutoModelForZeroShotObjectDetection",
    ),
    ("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
    ("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
    ("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),
    ("automatic-speech-recognition", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"),
    (
        "table-question-answering",
        "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES",
        "AutoModelForTableQuestionAnswering",
    ),
    ("token-classification", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", "AutoModelForTokenClassification"),
    ("multiple-choice", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES", "AutoModelForMultipleChoice"),
    (
        "next-sentence-prediction",
        "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
        "AutoModelForNextSentencePrediction",
    ),
    (
        "audio-frame-classification",
        "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
        "AutoModelForAudioFrameClassification",
    ),
    ("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
    (
        "document-question-answering",
        "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
        "AutoModelForDocumentQuestionAnswering",
    ),
    (
        "visual-question-answering",
        "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES",
        "AutoModelForVisualQuestionAnswering",
    ),
    (
        "zero-shot-image-classification",
        "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
        "AutoModelForZeroShotImageClassification",
    ),
    ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
    ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
    ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
    ("text-to-audio", "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", "AutoModelForTextToSpectrogram"),
    ("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"),
    ("keypoint-matching", "MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES", "AutoModelForKeypointMatching"),
]
def camel_case_split(identifier: str) -> list[str]:
    """
    Break a camel-cased identifier into its component words.

    A new word starts where a lowercase letter is followed by an uppercase one, or where a run of uppercase
    letters is followed by an uppercase letter that begins a lowercase word (so `"XLMRoberta"` gives
    `["XLM", "Roberta"]`). An empty identifier gives an empty list.

    Args:
        identifier (`str`): The camel-cased name to parse.

    Returns:
        `List[str]`: The words of the identifier, in order.

    Example:

    ```py
    >>> camel_case_split("CamelCasedClass")
    ["Camel", "Cased", "Class"]
    ```
    """
    # Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
    word_pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    # The pattern has no capturing group, so `findall` returns the full text of each match.
    return re.findall(word_pattern, identifier)
def get_frameworks_table() -> pd.DataFrame:
    """
    Build a dataframe listing, for each model type found in the auto modules, whether it has a PyTorch model and
    which auto class should be used to preprocess its inputs.

    Returns:
        `pd.DataFrame`: One row per model type (sorted), with the columns `model_type`, `pytorch` and `processor`.
    """
    auto = transformers_module.models.auto

    # Map each model class-name prefix (the config class name without "Config") to its model type.
    prefix_to_model_type = {
        config_name.replace("Config", ""): model_type
        for model_type, config_name in auto.configuration_auto.CONFIG_MAPPING_NAMES.items()
    }

    pt_models = collections.defaultdict(bool)

    # Go through every object exposed by transformers (once) and record which model types have a model class.
    for attr_name in dir(transformers_module):
        found = _re_pt_models.match(attr_name)
        if found is None:
            continue
        candidate = found.groups()[0]
        # Shorten the candidate prefix one camel-case word at a time until it names a known model.
        while len(candidate) > 0:
            if candidate in prefix_to_model_type:
                pt_models[prefix_to_model_type[candidate]] = True
                break
            candidate = "".join(camel_case_split(candidate)[:-1])

    all_models = sorted(pt_models)

    # Lookup order for the preprocessing class: Processor, then Tokenizer, then ImageProcessor, then
    # FeatureExtractor. The first mapping that knows the model type wins.
    lookup_order = (
        ("processing_auto", "PROCESSOR_MAPPING_NAMES", "AutoProcessor"),
        ("tokenization_auto", "TOKENIZER_MAPPING_NAMES", "AutoTokenizer"),
        ("image_processing_auto", "IMAGE_PROCESSOR_MAPPING_NAMES", "AutoImageProcessor"),
        ("feature_extraction_auto", "FEATURE_EXTRACTOR_MAPPING_NAMES", "AutoFeatureExtractor"),
    )
    processors = {}
    for model_type in all_models:
        # Default to AutoTokenizer if a model has nothing, for backward compatibility.
        chosen = "AutoTokenizer"
        for submodule_name, mapping_name, auto_class in lookup_order:
            if model_type in getattr(getattr(auto, submodule_name), mapping_name):
                chosen = auto_class
                break
        processors[model_type] = chosen

    return pd.DataFrame(
        {
            "model_type": all_models,
            "pytorch": [pt_models[model_type] for model_type in all_models],
            "processor": [processors[model_type] for model_type in all_models],
        }
    )
def update_pipeline_and_auto_class_table(table: dict[str, tuple[str, str]]) -> dict[str, tuple[str, str]]:
    """
    Refresh the model -> (pipeline tag, auto class) table from the auto-modeling mappings.

    Entries already in `table` are kept even when the model no longer appears in any mapping; entries for models
    found in a mapping are added or overwritten. `table` is modified in place and also returned.

    Args:
        table (`Dict[str, Tuple[str, str]]`):
            The existing table mapping model names to a tuple containing the pipeline tag and the auto-class name with
            which they should be used.

    Returns:
        `Dict[str, Tuple[str, str]]`: The updated table in the same format.
    """
    auto_modeling = transformers_module.models.auto.modeling_auto
    for pipeline_tag, mapping_name, auto_class in PIPELINE_TAGS_AND_AUTO_MODELS:
        # Mappings that do not exist in this version of the library are skipped.
        if not hasattr(auto_modeling, mapping_name):
            continue
        tag_and_class = (pipeline_tag, auto_class)
        for entry in getattr(auto_modeling, mapping_name).values():
            # A mapping value is either one class name or a collection of class names.
            class_names = [entry] if isinstance(entry, str) else list(entry)
            for class_name in class_names:
                table[class_name] = tag_and_class
    return table
def update_metadata(token: str, commit_sha: str):
    """
    Regenerate `frameworks.json` and `pipeline_tags.json` and push them to the
    `huggingface/transformers-metadata` dataset when they differ from what is on the Hub.

    Nothing is uploaded when both regenerated files are identical to the Hub versions.

    Args:
        token (`str`): A valid token giving write access to `huggingface/transformers-metadata`.
        commit_sha (`str`): The commit SHA on Transformers corresponding to this update. When `None`, a generic
            commit message is used.
    """
    metadata_repo = "huggingface/transformers-metadata"

    def _download(filename):
        # Fetch one file of the metadata dataset and return its local path.
        return hf_hub_download(repo_id=metadata_repo, filename=filename, repo_type="dataset", token=token)

    def _read_text(path):
        with open(path) as handle:
            return handle.read()

    frameworks_dataset = Dataset.from_pandas(get_frameworks_table())

    # Start from the table currently on the Hub so that model classes removed from the library are kept.
    hub_tags = Dataset.from_json(_download("pipeline_tags.json"))
    table = {}
    for index in range(len(hub_tags)):
        row = hub_tags[index]
        table[row["model_class"]] = (row["pipeline_tag"], row["auto_class"])
    table = update_pipeline_and_auto_class_table(table)

    # Sort the model classes to avoid some nondeterministic updates to create false update commits.
    model_classes = sorted(table)
    tags_dataset = Dataset.from_pandas(
        pd.DataFrame(
            {
                "model_class": model_classes,
                "pipeline_tag": [table[name][0] for name in model_classes],
                "auto_class": [table[name][1] for name in model_classes],
            }
        )
    )

    # Raw text of the two files as they currently are on the Hub, used for the comparison below.
    hub_frameworks_json = _read_text(_download("frameworks.json"))
    hub_pipeline_tags_json = _read_text(_download("pipeline_tags.json"))

    with tempfile.TemporaryDirectory() as tmp_dir:
        frameworks_path = os.path.join(tmp_dir, "frameworks.json")
        pipeline_tags_path = os.path.join(tmp_dir, "pipeline_tags.json")
        frameworks_dataset.to_json(frameworks_path)
        tags_dataset.to_json(pipeline_tags_path)

        frameworks_json = _read_text(frameworks_path)
        pipeline_tags_json = _read_text(pipeline_tags_path)

        if hub_frameworks_json == frameworks_json and hub_pipeline_tags_json == pipeline_tags_json:
            print("No updates on the Hub, not pushing the metadata files.")
            return

        commit_message = (
            "Update"
            if commit_sha is None
            else (
                f"Update with commit {commit_sha}\n\nSee: "
                f"https://github.com/huggingface/transformers/commit/{commit_sha}"
            )
        )

        # The upload has to happen while the temporary directory still exists.
        upload_folder(
            repo_id=metadata_repo,
            folder_path=tmp_dir,
            repo_type="dataset",
            token=token,
            commit_message=commit_message,
        )
def check_pipeline_tags():
    """
    Check all pipeline tags are properly defined in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant of this script.

    A pipeline task is considered covered when its name is one of the listed pipeline tags, or when its
    (first) PyTorch auto class is one of the listed auto classes.

    Raises:
        `ValueError`: If at least one supported pipeline task is not covered.
    """
    # Tags and auto classes are collected separately. `PIPELINE_TAGS_AND_AUTO_MODELS` lists some tags more than
    # once, so a `{tag: cls}` dict would keep only the last auto class of each repeated tag and forget the others.
    known_tags = {tag for tag, _, _ in PIPELINE_TAGS_AND_AUTO_MODELS}
    known_auto_classes = {cls for _, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}

    pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
    missing = []
    for key in pipeline_tasks:
        if key in known_tags:
            continue
        model = pipeline_tasks[key]["pt"]
        # The "pt" entry may be a single class or a collection of classes; only the first one is checked.
        if isinstance(model, (list, tuple)):
            model = model[0]
        if model.__name__ not in known_auto_classes:
            missing.append(key)

    if len(missing) > 0:
        msg = ", ".join(missing)
        raise ValueError(
            "The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside "
            f"`utils/update_metadata.py`: {msg}. Please add them!"
        )
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
    cli.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
    cli.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.")
    options = cli.parse_args()

    # Without `--check-only` the script regenerates the metadata and may push it to the Hub.
    if not options.check_only:
        update_metadata(options.token, options.commit_sha)
    else:
        check_pipeline_tags()

173
utils/update_tiny_models.py Normal file
View File

@@ -0,0 +1,173 @@
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script running `create_dummy_models.py` with a pre-defined set of arguments.
This file is intended to be used in a CI workflow file without the need of specifying arguments. It creates and uploads
tiny models for all model classes (if their tiny versions are not on the Hub yet), as well as produces an updated
version of `tests/utils/tiny_model_summary.json`. That updated file should be merged into the `main` branch of
`transformers` so the pipeline testing will use the latest created/updated tiny models.
"""
import argparse
import json
import multiprocessing
import os
import time
from create_dummy_models import COMPOSITE_MODELS, create_tiny_models
from huggingface_hub import HfApi
import transformers
from transformers import AutoFeatureExtractor, AutoImageProcessor, AutoTokenizer, logging
from transformers.image_processing_utils import BaseImageProcessor
logger = logging.get_logger(__name__)
def get_all_model_names():
    """
    Collect every model class name listed in the `MODEL_*_MAPPING_NAMES` mappings of the auto-modeling module.

    Returns:
        `List[str]`: The sorted, de-duplicated class names (empty if the auto-modeling module is unavailable).
    """
    auto_module = getattr(transformers.models.auto, "modeling_auto", None)
    if auto_module is None:
        return []

    collected = set()
    # all mappings in a single auto modeling file
    for attr in dir(auto_module):
        if not (attr.startswith("MODEL_") and attr.endswith("_MAPPING_NAMES")):
            continue
        mapping = getattr(auto_module, attr)
        if mapping is None:
            continue
        for entry in mapping.values():
            # A mapping value is a single class name or a list/tuple of them; anything else is ignored.
            if isinstance(entry, str):
                collected.add(entry)
            elif isinstance(entry, (list, tuple)):
                collected.update(entry)

    return sorted(collected)
def get_tiny_model_names_from_repo():
    """
    Read `tests/utils/tiny_model_summary.json` (relative to the current working directory) and return the
    sorted, de-duplicated model class names listed under the `"model_classes"` key of its entries.
    """
    with open("tests/utils/tiny_model_summary.json") as fp:
        summary = json.load(fp)

    return sorted({class_name for base_name in summary for class_name in summary[base_name]["model_classes"]})
def get_tiny_model_summary_from_hub(output_path):
    """
    Build a summary of the tiny models hosted under `hf-internal-testing` on the Hub and write it to
    `hub_tiny_model_summary.json` inside `output_path`.

    Only repositories named `tiny-random-<Name>` are considered, where `<Name>` starts with an uppercase letter
    and is either a known model class name or one of the `COMPOSITE_MODELS` values. For each of them the
    function tries to load the tokenizers, image processor, feature extractor and model; whatever loads
    successfully has its class name recorded. Loading failures are logged at debug level and otherwise ignored.

    The summary is keyed by the part of `<Name>` before the first `-`, so repositories sharing that part
    overwrite each other (the last one in sorted order wins).

    Args:
        output_path: Directory in which `hub_tiny_model_summary.json` is written. It must already exist.
    """
    prefix = "tiny-random-"
    api = HfApi()
    special_models = COMPOSITE_MODELS.values()

    # All tiny model base names on Hub
    # A set gives constant-time membership tests instead of scanning a list for every Hub repository.
    known_model_names = set(get_all_model_names())
    hub_model_names = set()
    for repo in api.list_models(author="hf-internal-testing"):
        _, repo_name = repo.id.split("/")
        if not repo_name.startswith(prefix):
            continue
        # Strip only the leading prefix; `str.replace` would also remove occurrences inside the name.
        name = repo_name[len(prefix) :]
        # `name[:1]` is "" for a repository called exactly "tiny-random-", which is skipped instead of
        # raising an `IndexError`.
        if not name[:1].isupper():
            continue
        if name not in known_model_names and name not in special_models:
            continue
        hub_model_names.add(name)

    # All tiny model names on Hub
    summary = {}
    for name in sorted(hub_model_names):
        repo_id = f"hf-internal-testing/tiny-random-{name}"
        model = name.split("-")[0]

        try:
            repo_info = api.repo_info(repo_id)
            content = {
                "tokenizer_classes": set(),
                "processor_classes": set(),
                "model_classes": set(),
                "sha": repo_info.sha,
            }
        except Exception:
            # Repositories whose info cannot be fetched are left out of the summary.
            continue

        # Each load attempt below is preceded by a one-second pause.
        try:
            time.sleep(1)
            tokenizer_fast = AutoTokenizer.from_pretrained(repo_id)
            content["tokenizer_classes"].add(tokenizer_fast.__class__.__name__)
        except Exception as e:
            logger.debug(f"Could not load fast tokenizer for {repo_id}: {e}")

        try:
            time.sleep(1)
            tokenizer_slow = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
            content["tokenizer_classes"].add(tokenizer_slow.__class__.__name__)
        except Exception as e:
            logger.debug(f"Could not load slow tokenizer for {repo_id}: {e}")

        try:
            time.sleep(1)
            img_p = AutoImageProcessor.from_pretrained(repo_id)
            content["processor_classes"].add(img_p.__class__.__name__)
        except Exception as e:
            logger.debug(f"Could not load image processor for {repo_id}: {e}")

        try:
            time.sleep(1)
            feat_p = AutoFeatureExtractor.from_pretrained(repo_id)
            # Objects that are image processors are not recorded as feature extractors.
            if not isinstance(feat_p, BaseImageProcessor):
                content["processor_classes"].add(feat_p.__class__.__name__)
        except Exception as e:
            logger.debug(f"Could not load feature extractor for {repo_id}: {e}")

        try:
            time.sleep(1)
            model_class = getattr(transformers, model)
            m = model_class.from_pretrained(repo_id)
            content["model_classes"].add(m.__class__.__name__)
        except Exception as e:
            logger.debug(f"Could not load model for {repo_id}: {e}")

        # Sets are not JSON serializable; sorted lists also keep the output deterministic.
        content["tokenizer_classes"] = sorted(content["tokenizer_classes"])
        content["processor_classes"] = sorted(content["processor_classes"])
        content["model_classes"] = sorted(content["model_classes"])

        summary[model] = content

    # `ensure_ascii=False` may emit non-ASCII characters, so the encoding is fixed rather than locale-dependent.
    with open(os.path.join(output_path, "hub_tiny_model_summary.json"), "w", encoding="utf-8") as fp:
        json.dump(summary, fp, ensure_ascii=False, indent=4)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
    args = parser.parse_args()

    # This has to be `spawn` to avoid hanging forever!
    multiprocessing.set_start_method("spawn")

    # Pre-defined arguments for `create_tiny_models`, passed positionally in the order below.
    output_path = "tiny_models"
    # Named `all_models` rather than `all` so that the builtin `all()` is not shadowed at module level.
    all_models = True
    model_types = None
    # Model classes already listed in `tests/utils/tiny_model_summary.json`.
    models_to_skip = get_tiny_model_names_from_repo()
    no_check = True
    upload = True
    organization = "hf-internal-testing"

    create_tiny_models(
        output_path,
        all_models,
        model_types,
        models_to_skip,
        no_check,
        upload,
        organization,
        # The Hub token is read from the `TOKEN` environment variable (`None` when it is not set).
        token=os.environ.get("TOKEN", None),
        num_workers=args.num_workers,
    )