first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
427
utils/add_dates.py
Normal file
427
utils/add_dates.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from datetime import date, datetime
|
||||
from urllib.error import HTTPError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from huggingface_hub import paper_info
|
||||
|
||||
from transformers import logging
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "add_dates",
|
||||
"label": "Model dates",
|
||||
# Approximate: also reads docs/source/en/model_doc/*.md and uses git log + network
|
||||
# calls to GitHub/HuggingFace for commit dates and paper metadata.
|
||||
"cache_globs": ["src/transformers/models/**/__init__.py", "docs/source/en/model_doc/**/*.md"],
|
||||
"check_args": ["--check-only"],
|
||||
"fix_args": [],
|
||||
}
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
ROOT = os.getcwd().split("utils")[0]
|
||||
DOCS_PATH = os.path.join(ROOT, "docs/source/en/model_doc")
|
||||
MODELS_PATH = os.path.join(ROOT, "src/transformers/models")
|
||||
GITHUB_REPO_URL = "https://github.com/huggingface/transformers"
|
||||
GITHUB_RAW_URL = "https://raw.githubusercontent.com/huggingface/transformers/main"
|
||||
|
||||
COPYRIGHT_DISCLAIMER = """<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->"""
|
||||
|
||||
ARXIV_PAPERS_NOT_IN_HF_PAPERS = {
|
||||
"gemma3n.md": "2506.06644",
|
||||
"xmod.md": "2205.06266",
|
||||
}
|
||||
|
||||
|
||||
def check_file_exists_on_github(file_path: str) -> bool:
|
||||
"""Check if a file exists on the main branch of the GitHub repository.
|
||||
|
||||
Args:
|
||||
file_path: Relative path from repository root
|
||||
|
||||
Returns:
|
||||
True if file exists on GitHub main branch (or if check failed), False only if confirmed 404
|
||||
|
||||
Note:
|
||||
On network errors or other issues, returns True (assumes file exists) with a warning.
|
||||
This prevents the script from failing due to temporary network issues.
|
||||
"""
|
||||
# Convert absolute path to relative path from repository root if needed
|
||||
if file_path.startswith(ROOT):
|
||||
file_path = file_path[len(ROOT) :].lstrip("/")
|
||||
|
||||
# Construct the raw GitHub URL for the file
|
||||
url = f"{GITHUB_RAW_URL}/{file_path}"
|
||||
|
||||
try:
|
||||
# Make a HEAD request to check if file exists (more efficient than GET)
|
||||
request = Request(url, method="HEAD")
|
||||
request.add_header("User-Agent", "transformers-add-dates-script")
|
||||
|
||||
with urlopen(request, timeout=10) as response:
|
||||
return response.status == 200
|
||||
except HTTPError as e:
|
||||
if e.code == 404:
|
||||
# File doesn't exist on GitHub
|
||||
return False
|
||||
# HTTP error (non-404): assume file exists and continue with local git history
|
||||
return True
|
||||
except Exception:
|
||||
# Network/timeout error: assume file exists and continue with local git history
|
||||
return True
|
||||
|
||||
|
||||
def get_modified_cards() -> list[str]:
|
||||
"""Get the list of model names from modified files in docs/source/en/model_doc/"""
|
||||
|
||||
current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
|
||||
if current_branch == "main":
|
||||
# On main branch, only uncommitted changes detected
|
||||
result = subprocess.check_output(["git", "diff", "--name-only", "HEAD"], text=True)
|
||||
else:
|
||||
fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
|
||||
result = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8")
|
||||
|
||||
model_names = []
|
||||
for line in result.strip().split("\n"):
|
||||
if line:
|
||||
# Check if the file is in the model_doc directory
|
||||
if line.startswith("docs/source/en/model_doc/") and line.endswith(".md"):
|
||||
file_path = os.path.join(ROOT, line)
|
||||
if os.path.exists(file_path):
|
||||
model_name = os.path.splitext(os.path.basename(line))[0]
|
||||
if model_name not in ["auto", "timm_wrapper"]:
|
||||
model_names.append(model_name)
|
||||
|
||||
return model_names
|
||||
|
||||
|
||||
def get_paper_link(model_card: str | None, path: str | None) -> str:
|
||||
"""Get the first paper link from the model card content."""
|
||||
|
||||
if model_card is not None and not model_card.endswith(".md"):
|
||||
model_card = f"{model_card}.md"
|
||||
file_path = path or os.path.join(DOCS_PATH, f"{model_card}")
|
||||
model_card = os.path.basename(file_path)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find known paper links
|
||||
paper_ids = re.findall(r"https://huggingface\.co/papers/\d+\.\d+", content)
|
||||
paper_ids += re.findall(r"https://arxiv\.org/abs/\d+\.\d+", content)
|
||||
paper_ids += re.findall(r"https://arxiv\.org/pdf/\d+\.\d+", content)
|
||||
|
||||
if len(paper_ids) == 0:
|
||||
return "No_paper"
|
||||
|
||||
return paper_ids[0]
|
||||
|
||||
|
||||
def get_first_commit_date(model_name: str | None) -> str:
|
||||
"""Get the first commit date of the model's init file or model.md. This date is considered as the date the model was added to HF transformers"""
|
||||
|
||||
if model_name.endswith(".md"):
|
||||
model_name = f"{model_name[:-3]}"
|
||||
|
||||
model_name_src = model_name
|
||||
if "-" in model_name:
|
||||
model_name_src = model_name.replace("-", "_")
|
||||
file_path = os.path.join(MODELS_PATH, model_name_src, "__init__.py")
|
||||
|
||||
# If the init file is not found (only true for legacy models), the doc's first commit date is used
|
||||
if not os.path.exists(file_path):
|
||||
file_path = os.path.join(DOCS_PATH, f"{model_name}.md")
|
||||
|
||||
# Check if file exists on GitHub main branch
|
||||
file_exists_on_github = check_file_exists_on_github(file_path)
|
||||
|
||||
if not file_exists_on_github:
|
||||
# File does not exist on GitHub main branch (new model), use today's date
|
||||
final_date = date.today().isoformat()
|
||||
else:
|
||||
# File exists on GitHub main branch, get the first commit date from local git history
|
||||
final_date = subprocess.check_output(
|
||||
["git", "log", "--reverse", "--pretty=format:%ad", "--date=iso", file_path], text=True
|
||||
)
|
||||
return final_date.strip().split("\n")[0][:10]
|
||||
|
||||
|
||||
def get_release_date(link: str) -> str:
|
||||
if link.startswith("https://huggingface.co/papers/"):
|
||||
link = link.replace("https://huggingface.co/papers/", "")
|
||||
|
||||
try:
|
||||
info = paper_info(link)
|
||||
return info.published_at.date().isoformat()
|
||||
except Exception as e:
|
||||
# Error fetching release date, function returns None (will use placeholder)
|
||||
logger.debug(f"Could not fetch paper info for {link}: {e}")
|
||||
|
||||
elif link.startswith("https://arxiv.org/abs/") or link.startswith("https://arxiv.org/pdf/"):
|
||||
return r"{release_date}"
|
||||
|
||||
|
||||
def replace_paper_links(file_path: str) -> bool:
|
||||
"""Replace arxiv links with huggingface links if valid, and replace hf.co with huggingface.co"""
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
original_content = content
|
||||
|
||||
# Replace hf.co with huggingface.co
|
||||
content = content.replace("https://hf.co/", "https://huggingface.co/")
|
||||
|
||||
# Find all arxiv links
|
||||
arxiv_links = re.findall(r"https://arxiv\.org/abs/(\d+\.\d+)", content)
|
||||
arxiv_links += re.findall(r"https://arxiv\.org/pdf/(\d+\.\d+)", content)
|
||||
|
||||
for paper_id in arxiv_links:
|
||||
try:
|
||||
# Check if paper exists on huggingface
|
||||
paper_info(paper_id)
|
||||
# If no exception, replace the link
|
||||
old_link = f"https://arxiv.org/abs/{paper_id}"
|
||||
if old_link not in content:
|
||||
old_link = f"https://arxiv.org/pdf/{paper_id}"
|
||||
new_link = f"https://huggingface.co/papers/{paper_id}"
|
||||
content = content.replace(old_link, new_link)
|
||||
|
||||
except Exception:
|
||||
# Paper not available on huggingface, keep arxiv link
|
||||
continue
|
||||
|
||||
# Write back only if content changed
|
||||
if content != original_content:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_model_card_name(model_card: str) -> str:
|
||||
"""Ensure model card has .md extension"""
|
||||
return model_card if model_card.endswith(".md") else f"{model_card}.md"
|
||||
|
||||
|
||||
def _should_skip_model_card(model_card: str) -> bool:
|
||||
"""Check if model card should be skipped"""
|
||||
return model_card in ("auto.md", "timm_wrapper.md")
|
||||
|
||||
|
||||
def _read_model_card_content(model_card: str) -> str:
|
||||
"""Read and return the content of a model card"""
|
||||
file_path = os.path.join(DOCS_PATH, model_card)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _get_dates_pattern_match(content: str):
|
||||
"""Search for the dates pattern in content and return match object"""
|
||||
pattern = r"\n\*This model was (?:published in HF papers on (.*) and )?contributed to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*"
|
||||
return re.search(pattern, content)
|
||||
|
||||
|
||||
def _dates_differ_significantly(date1: str, date2: str) -> bool:
|
||||
"""Check if two dates differ by more than 1 day"""
|
||||
try:
|
||||
d1 = datetime.strptime(date1, "%Y-%m-%d")
|
||||
d2 = datetime.strptime(date2, "%Y-%m-%d")
|
||||
return abs((d1 - d2).days) > 1
|
||||
except Exception:
|
||||
return True # If dates can't be parsed, consider them different
|
||||
|
||||
|
||||
def check_missing_dates(model_card_list: list[str]) -> list[str]:
|
||||
"""Check which model cards are missing release dates and return their names"""
|
||||
missing_dates = []
|
||||
|
||||
for model_card in model_card_list:
|
||||
model_card = _normalize_model_card_name(model_card)
|
||||
if _should_skip_model_card(model_card):
|
||||
continue
|
||||
|
||||
content = _read_model_card_content(model_card)
|
||||
if not _get_dates_pattern_match(content):
|
||||
missing_dates.append(model_card)
|
||||
|
||||
return missing_dates
|
||||
|
||||
|
||||
def check_incorrect_dates(model_card_list: list[str]) -> list[str]:
|
||||
"""Check which model cards have incorrect model release/addition dates and return their names"""
|
||||
incorrect_dates = []
|
||||
|
||||
for model_card in model_card_list:
|
||||
model_card = _normalize_model_card_name(model_card)
|
||||
if _should_skip_model_card(model_card):
|
||||
continue
|
||||
|
||||
content = _read_model_card_content(model_card)
|
||||
match = _get_dates_pattern_match(content)
|
||||
|
||||
file_path = os.path.join(DOCS_PATH, model_card)
|
||||
paper_link = get_paper_link(model_card=model_card, path=file_path)
|
||||
|
||||
if paper_link in ("No_paper", "blog"):
|
||||
release_date = r"{release_date}"
|
||||
else:
|
||||
release_date = get_release_date(paper_link)
|
||||
|
||||
if match:
|
||||
# Preserve existing release date unless it's a placeholder
|
||||
existing_release_date = match.group(1)
|
||||
if existing_release_date not in (r"{release_date}", "None"):
|
||||
release_date = existing_release_date
|
||||
|
||||
existing_hf_date = match.group(2)
|
||||
actual_hf_date = get_first_commit_date(model_name=model_card)
|
||||
|
||||
if _dates_differ_significantly(existing_hf_date, actual_hf_date) or existing_release_date != release_date:
|
||||
incorrect_dates.append(model_card)
|
||||
|
||||
return incorrect_dates
|
||||
|
||||
|
||||
def insert_dates(model_card_list: list[str]):
|
||||
"""Insert or update release and commit dates in model cards"""
|
||||
for model_card in model_card_list:
|
||||
model_card = _normalize_model_card_name(model_card)
|
||||
if _should_skip_model_card(model_card):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(DOCS_PATH, model_card)
|
||||
|
||||
# First replace arxiv paper links with hf paper link if possible
|
||||
replace_paper_links(file_path)
|
||||
|
||||
# Read content and ensure copyright disclaimer exists
|
||||
content = _read_model_card_content(model_card)
|
||||
markers = list(re.finditer(r"-->", content))
|
||||
|
||||
if len(markers) == 0:
|
||||
# No copyright marker found, adding disclaimer to the top
|
||||
content = COPYRIGHT_DISCLAIMER + "\n\n" + content
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
markers = list(re.finditer(r"-->", content))
|
||||
|
||||
# Get dates
|
||||
hf_commit_date = get_first_commit_date(model_name=model_card)
|
||||
paper_link = get_paper_link(model_card=model_card, path=file_path)
|
||||
|
||||
if paper_link in ("No_paper", "blog"):
|
||||
release_date = r"{release_date}"
|
||||
else:
|
||||
release_date = get_release_date(paper_link)
|
||||
|
||||
match = _get_dates_pattern_match(content)
|
||||
|
||||
# Update or insert the dates line
|
||||
if match:
|
||||
# Preserve existing release date unless it's a placeholder
|
||||
existing_release_date = match.group(1)
|
||||
existing_hf_date = match.group(2)
|
||||
|
||||
if existing_release_date not in (r"{release_date}", "None"):
|
||||
release_date = existing_release_date
|
||||
|
||||
if _dates_differ_significantly(existing_hf_date, hf_commit_date) or existing_release_date != release_date:
|
||||
old_line = match.group(0)
|
||||
if release_date != r"{release_date}":
|
||||
new_line = f"\n*This model was published in HF papers on {release_date} and contributed to Hugging Face Transformers on {hf_commit_date}.*"
|
||||
else:
|
||||
new_line = f"\n*This model was contributed to Hugging Face Transformers on {hf_commit_date}.*"
|
||||
content = content.replace(old_line, new_line)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
else:
|
||||
# Insert new dates line after copyright marker
|
||||
insert_index = markers[0].end()
|
||||
if release_date != r"{release_date}":
|
||||
date_info = f"\n*This model was published in HF papers on {release_date} and contributed to Hugging Face Transformers on {hf_commit_date}.*"
|
||||
else:
|
||||
date_info = f"\n*This model was contributed to Hugging Face Transformers on {hf_commit_date}.*"
|
||||
content = content[:insert_index] + date_info + content[insert_index:]
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def get_all_model_cards():
|
||||
"""Get all model cards from the docs path"""
|
||||
|
||||
all_files = os.listdir(DOCS_PATH)
|
||||
model_cards = []
|
||||
for file in all_files:
|
||||
if file.endswith(".md"):
|
||||
model_name = os.path.splitext(file)[0]
|
||||
if model_name not in ["auto", "timm_wrapper"]:
|
||||
model_cards.append(model_name)
|
||||
return sorted(model_cards)
|
||||
|
||||
|
||||
def main(all=False, models=None, check_only=False):
|
||||
if check_only:
|
||||
# Check all model cards for missing dates
|
||||
all_model_cards = get_all_model_cards()
|
||||
missing_dates = check_missing_dates(all_model_cards)
|
||||
|
||||
# Check modified model cards for incorrect dates
|
||||
modified_cards = get_modified_cards()
|
||||
incorrect_dates = check_incorrect_dates(modified_cards)
|
||||
|
||||
if missing_dates or incorrect_dates:
|
||||
problematic_cards = missing_dates + incorrect_dates
|
||||
model_names = [card.replace(".md", "") for card in problematic_cards]
|
||||
raise ValueError(
|
||||
f"Missing or incorrect dates in the following model cards: {' '.join(problematic_cards)}\n"
|
||||
f"Run `python utils/add_dates.py --models {' '.join(model_names)}` to fix them."
|
||||
)
|
||||
return
|
||||
|
||||
# Determine which model cards to process
|
||||
if all:
|
||||
model_cards = get_all_model_cards()
|
||||
elif models:
|
||||
model_cards = models
|
||||
else:
|
||||
model_cards = get_modified_cards()
|
||||
if not model_cards:
|
||||
return
|
||||
|
||||
insert_dates(model_cards)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Add release and commit dates to model cards")
|
||||
group = parser.add_mutually_exclusive_group(required=False)
|
||||
group.add_argument("--models", nargs="+", help="Specify model cards to process (without .md extension)")
|
||||
group.add_argument("--all", action="store_true", help="Process all model cards in the docs directory")
|
||||
group.add_argument("--check-only", action="store_true", help="Check if the dates are already present")
|
||||
|
||||
args = parser.parse_args()
|
||||
try:
|
||||
main(args.all, args.models, args.check_only)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(
|
||||
f"An error occurred while executing git commands but it can be ignored (git issue) most probably local: {e}"
|
||||
)
|
||||
307
utils/add_pipeline_model_mapping_to_test.py
Normal file
307
utils/add_pipeline_model_mapping_to_test.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A script to add and/or update the attribute `pipeline_model_mapping` in model test files.
|
||||
|
||||
This script will be (mostly) used in the following 2 situations:
|
||||
|
||||
- run within a (scheduled) CI job to:
|
||||
- check if model test files in the library have updated `pipeline_model_mapping`,
|
||||
- and/or update test files and (possibly) open a GitHub pull request automatically
|
||||
- being run by a `transformers` member to quickly check and update some particular test file(s)
|
||||
|
||||
This script is **NOT** intended to be run (manually) by community contributors.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from get_test_info import get_test_classes
|
||||
|
||||
from tests.test_pipeline_mixin import pipeline_test_mapping
|
||||
|
||||
|
||||
PIPELINE_TEST_MAPPING = {}
|
||||
for task in pipeline_test_mapping:
|
||||
PIPELINE_TEST_MAPPING[task] = None
|
||||
|
||||
|
||||
# DO **NOT** add item to this set (unless the reason is approved)
|
||||
TEST_FILE_TO_IGNORE = {
|
||||
"tests/models/esm/test_modeling_esmfold.py", # The pipeline test mapping is added to `test_modeling_esm.py`
|
||||
}
|
||||
|
||||
|
||||
def get_mapping_for_task(task):
|
||||
"""Get mappings defined in `XXXPipelineTests` for the task `task`."""
|
||||
# Use the cached results
|
||||
if PIPELINE_TEST_MAPPING[task] is not None:
|
||||
return PIPELINE_TEST_MAPPING[task]
|
||||
|
||||
pipeline_test_class = pipeline_test_mapping[task]["test"]
|
||||
mapping = getattr(pipeline_test_class, "model_mapping", None)
|
||||
|
||||
if mapping is not None:
|
||||
mapping = dict(mapping.items())
|
||||
|
||||
# cache the results
|
||||
PIPELINE_TEST_MAPPING[task] = mapping
|
||||
return mapping
|
||||
|
||||
|
||||
def get_model_for_pipeline_test(test_class, task):
|
||||
"""Get the model architecture(s) related to the test class `test_class` for a pipeline `task`."""
|
||||
mapping = get_mapping_for_task(task)
|
||||
if mapping is None:
|
||||
return None
|
||||
|
||||
config_classes = list({model_class.config_class for model_class in test_class.all_model_classes})
|
||||
if len(config_classes) != 1:
|
||||
raise ValueError("There should be exactly one configuration class from `test_class.all_model_classes`.")
|
||||
|
||||
# This could be a list/tuple of model classes, but it's rare.
|
||||
model_class = mapping.get(config_classes[0], None)
|
||||
if isinstance(model_class, (tuple, list)):
|
||||
model_class = sorted(model_class, key=lambda x: x.__name__)
|
||||
|
||||
return model_class
|
||||
|
||||
|
||||
def get_pipeline_model_mapping(test_class):
|
||||
"""Get `pipeline_model_mapping` for `test_class`."""
|
||||
mapping = [(task, get_model_for_pipeline_test(test_class, task)) for task in pipeline_test_mapping]
|
||||
mapping = sorted([(task, model) for task, model in mapping if model is not None], key=lambda x: x[0])
|
||||
|
||||
return dict(mapping)
|
||||
|
||||
|
||||
def get_pipeline_model_mapping_string(test_class):
|
||||
"""Get `pipeline_model_mapping` for `test_class` as a string (to be added to the test file).
|
||||
|
||||
This will be a 1-line string. After this is added to a test file, `make style` will format it beautifully.
|
||||
"""
|
||||
default_value = "{}"
|
||||
mapping = get_pipeline_model_mapping(test_class)
|
||||
if len(mapping) == 0:
|
||||
return ""
|
||||
|
||||
texts = []
|
||||
for task, model_classes in mapping.items():
|
||||
if isinstance(model_classes, (tuple, list)):
|
||||
# A list/tuple of model classes
|
||||
value = "(" + ", ".join([x.__name__ for x in model_classes]) + ")"
|
||||
else:
|
||||
# A single model class
|
||||
value = model_classes.__name__
|
||||
texts.append(f'"{task}": {value}')
|
||||
text = "{" + ", ".join(texts) + "}"
|
||||
text = f"pipeline_model_mapping = {text} if is_torch_available() else {default_value}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def is_valid_test_class(test_class):
|
||||
"""Restrict to `XXXModelTesterMixin` and should be a subclass of `unittest.TestCase`."""
|
||||
if not issubclass(test_class, unittest.TestCase):
|
||||
return False
|
||||
return "ModelTesterMixin" in [x.__name__ for x in test_class.__bases__]
|
||||
|
||||
|
||||
def find_test_class(test_file):
|
||||
"""Find a test class in `test_file` to which we will add `pipeline_model_mapping`."""
|
||||
test_classes = [x for x in get_test_classes(test_file) if is_valid_test_class(x)]
|
||||
|
||||
target_test_class = None
|
||||
for test_class in test_classes:
|
||||
# If a test class has defined `pipeline_model_mapping`, let's take it
|
||||
if getattr(test_class, "pipeline_model_mapping", None) is not None:
|
||||
target_test_class = test_class
|
||||
break
|
||||
# Take the test class with the shortest name (just a heuristic)
|
||||
if target_test_class is None and len(test_classes) > 0:
|
||||
target_test_class = min(test_classes, key=lambda x: (len(x.__name__), x.__name__))
|
||||
|
||||
return target_test_class
|
||||
|
||||
|
||||
def find_block_ending(lines, start_idx, indent_level):
|
||||
end_idx = start_idx
|
||||
for idx, line in enumerate(lines[start_idx:]):
|
||||
indent = len(line) - len(line.lstrip())
|
||||
if idx == 0 or indent > indent_level or (indent == indent_level and line.strip() == ")"):
|
||||
end_idx = start_idx + idx
|
||||
elif idx > 0 and indent <= indent_level:
|
||||
# Outside the definition block of `pipeline_model_mapping`
|
||||
break
|
||||
|
||||
return end_idx
|
||||
|
||||
|
||||
def add_pipeline_model_mapping(test_class, overwrite=False):
|
||||
"""Add `pipeline_model_mapping` to `test_class`."""
|
||||
if getattr(test_class, "pipeline_model_mapping", None) is not None:
|
||||
if not overwrite:
|
||||
return "", -1
|
||||
|
||||
line_to_add = get_pipeline_model_mapping_string(test_class)
|
||||
if len(line_to_add) == 0:
|
||||
return "", -1
|
||||
line_to_add = line_to_add + "\n"
|
||||
|
||||
# The code defined the class `test_class`
|
||||
class_lines, class_start_line_no = inspect.getsourcelines(test_class)
|
||||
# `inspect` gives the code for an object, including decorator(s) if any.
|
||||
# We (only) need the exact line of the class definition.
|
||||
for idx, line in enumerate(class_lines):
|
||||
if line.lstrip().startswith("class "):
|
||||
class_lines = class_lines[idx:]
|
||||
class_start_line_no += idx
|
||||
break
|
||||
class_end_line_no = class_start_line_no + len(class_lines) - 1
|
||||
|
||||
# The index in `class_lines` that starts the definition of `all_model_classes`, `all_generative_model_classes` or
|
||||
# `pipeline_model_mapping`. This assumes they are defined in such order, and we take the start index of the last
|
||||
# block that appears in a `test_class`.
|
||||
start_idx = None
|
||||
# The indent level of the line at `class_lines[start_idx]` (if defined)
|
||||
indent_level = 0
|
||||
# To record if `pipeline_model_mapping` is found in `test_class`.
|
||||
def_line = None
|
||||
for idx, line in enumerate(class_lines):
|
||||
if line.strip().startswith("all_model_classes = "):
|
||||
indent_level = len(line) - len(line.lstrip())
|
||||
start_idx = idx
|
||||
elif line.strip().startswith("all_generative_model_classes = "):
|
||||
indent_level = len(line) - len(line.lstrip())
|
||||
start_idx = idx
|
||||
elif line.strip().startswith("pipeline_model_mapping = "):
|
||||
indent_level = len(line) - len(line.lstrip())
|
||||
start_idx = idx
|
||||
def_line = line
|
||||
break
|
||||
|
||||
if start_idx is None:
|
||||
return "", -1
|
||||
# Find the ending index (inclusive) of the above found block.
|
||||
end_idx = find_block_ending(class_lines, start_idx, indent_level)
|
||||
|
||||
# Extract `is_xxx_available()` from existing blocks: some models require specific libraries like `timm` and use
|
||||
# `is_timm_available()` instead of `is_torch_available()`.
|
||||
# Keep leading and trailing whitespaces
|
||||
r = re.compile(r"\s(is_\S+?_available\(\))\s")
|
||||
for line in class_lines[start_idx : end_idx + 1]:
|
||||
backend_condition = r.search(line)
|
||||
if backend_condition is not None:
|
||||
# replace the leading and trailing whitespaces to the space character " ".
|
||||
target = " " + backend_condition[0][1:-1] + " "
|
||||
line_to_add = r.sub(target, line_to_add)
|
||||
break
|
||||
|
||||
if def_line is None:
|
||||
# `pipeline_model_mapping` is not defined. The target index is set to the ending index (inclusive) of
|
||||
# `all_model_classes` or `all_generative_model_classes`.
|
||||
target_idx = end_idx
|
||||
else:
|
||||
# `pipeline_model_mapping` is defined. The target index is set to be one **BEFORE** its start index.
|
||||
target_idx = start_idx - 1
|
||||
# mark the lines of the currently existing `pipeline_model_mapping` to be removed.
|
||||
for idx in range(start_idx, end_idx + 1):
|
||||
# These lines are going to be removed before writing to the test file.
|
||||
class_lines[idx] = None # noqa
|
||||
|
||||
# Make sure the test class is a subclass of `PipelineTesterMixin`.
|
||||
parent_classes = [x.__name__ for x in test_class.__bases__]
|
||||
if "PipelineTesterMixin" not in parent_classes:
|
||||
# Put `PipelineTesterMixin` just before `unittest.TestCase`
|
||||
_parent_classes = [x for x in parent_classes if x != "TestCase"] + ["PipelineTesterMixin"]
|
||||
if "TestCase" in parent_classes:
|
||||
# Here we **assume** the original string is always with `unittest.TestCase`.
|
||||
_parent_classes.append("unittest.TestCase")
|
||||
parent_classes = ", ".join(_parent_classes)
|
||||
for idx, line in enumerate(class_lines):
|
||||
# Find the ending of the declaration of `test_class`
|
||||
if line.strip().endswith("):"):
|
||||
# mark the lines of the declaration of `test_class` to be removed
|
||||
for _idx in range(idx + 1):
|
||||
class_lines[_idx] = None # noqa
|
||||
break
|
||||
# Add the new, one-line, class declaration for `test_class`
|
||||
class_lines[0] = f"class {test_class.__name__}({parent_classes}):\n"
|
||||
|
||||
# Add indentation
|
||||
line_to_add = " " * indent_level + line_to_add
|
||||
# Insert `pipeline_model_mapping` to `class_lines`.
|
||||
# (The line at `target_idx` should be kept by definition!)
|
||||
class_lines = class_lines[: target_idx + 1] + [line_to_add] + class_lines[target_idx + 1 :]
|
||||
# Remove the lines that are marked to be removed
|
||||
class_lines = [x for x in class_lines if x is not None]
|
||||
|
||||
# Move from test class to module (in order to write to the test file)
|
||||
module_lines = inspect.getsourcelines(inspect.getmodule(test_class))[0]
|
||||
# Be careful with the 1-off between line numbers and array indices
|
||||
module_lines = module_lines[: class_start_line_no - 1] + class_lines + module_lines[class_end_line_no:]
|
||||
code = "".join(module_lines)
|
||||
|
||||
moddule_file = inspect.getsourcefile(test_class)
|
||||
with open(moddule_file, "w", encoding="UTF-8", newline="\n") as fp:
|
||||
fp.write(code)
|
||||
|
||||
return line_to_add
|
||||
|
||||
|
||||
def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False):
|
||||
"""Add `pipeline_model_mapping` to `test_file`."""
|
||||
test_class = find_test_class(test_file)
|
||||
if test_class:
|
||||
add_pipeline_model_mapping(test_class, overwrite=overwrite)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--test_file", type=str, help="A path to the test file, starting with the repository's `tests` directory."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--all",
|
||||
action="store_true",
|
||||
help="If to check and modify all test files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
action="store_true",
|
||||
help="If to overwrite a test class if it has already defined `pipeline_model_mapping`.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.all and not args.test_file:
|
||||
raise ValueError("Please specify either `test_file` or pass `--all` to check/modify all test files.")
|
||||
elif args.all and args.test_file:
|
||||
raise ValueError("Only one of `--test_file` and `--all` could be specified.")
|
||||
|
||||
test_files = []
|
||||
if args.test_file:
|
||||
test_files = [args.test_file]
|
||||
else:
|
||||
pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
|
||||
for test_file in glob.glob(pattern):
|
||||
test_files.append(test_file)
|
||||
|
||||
for test_file in test_files:
|
||||
if test_file in TEST_FILE_TO_IGNORE:
|
||||
print(f"[SKIPPED] {test_file} is skipped as it is in `TEST_FILE_TO_IGNORE` in the file {__file__}.")
|
||||
continue
|
||||
add_pipeline_model_mapping_to_test_file(test_file, overwrite=args.overwrite)
|
||||
80
utils/aggregate_failure_reports.py
Normal file
80
utils/aggregate_failure_reports.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Aggregate multiple failure report JSON files into a single file.
|
||||
|
||||
This script reads all JSON files from a directory and combines them
|
||||
into a single JSON array.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def aggregate_failures(input_dir, output_file):
|
||||
"""
|
||||
Aggregate failure reports from multiple JSON files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing failure report JSON files
|
||||
output_file: Path to output aggregated JSON file
|
||||
|
||||
Returns:
|
||||
Number of failures aggregated
|
||||
"""
|
||||
failures = []
|
||||
input_path = Path(input_dir)
|
||||
|
||||
if input_path.exists() and input_path.is_dir():
|
||||
for failure_file in input_path.glob("*.json"):
|
||||
try:
|
||||
with open(failure_file) as f:
|
||||
failure_data = json.load(f)
|
||||
failures.append(failure_data)
|
||||
except Exception as e:
|
||||
print(f"Error reading {failure_file}: {e}", file=sys.stderr)
|
||||
|
||||
# Write aggregated failures
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(failures, f, indent=2)
|
||||
|
||||
print(f"Aggregated {len(failures)} failure(s) from {input_dir} to {output_file}")
|
||||
return len(failures)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Aggregate failure report JSON files")
|
||||
parser.add_argument(
|
||||
"--input-dir",
|
||||
required=True,
|
||||
help="Directory containing failure report JSON files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=True,
|
||||
help="Output file path for aggregated JSON",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
aggregate_failures(args.input_dir, args.output)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
309
utils/check_auto.py
Normal file
309
utils/check_auto.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import difflib
|
||||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections import Counter, OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from sort_auto_mappings import sort_auto_mapping
|
||||
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES as COMPLETE_CONFIG_MAPPING_NAMES
|
||||
from transformers.models.auto.image_processing_auto import MISSING_IMAGE_PROCESSOR_MAPPING_NAMES
|
||||
from transformers.models.auto.video_processing_auto import MISSING_VIDEO_PROCESSOR_MAPPING_NAMES
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "auto_mappings",
|
||||
"label": "Generate auto mappings",
|
||||
"cache_globs": [],
|
||||
"check_args": [],
|
||||
"fix_args": ["--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
AUTO_GENERATED_HADER = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from existing config files and their `model_type`s. Do NOT edit this file
|
||||
# manually as any edits will be overwritten by auto-generation of the file. If any change should be done,
|
||||
# please add the correct `cls.model_type` in your config class and run `python utils/check_auto.py --fix_and_overwrite`.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
|
||||
# Some keys are duplicated due to incorrect naming at model shipping and BC
|
||||
IGNORE_DUPLICATE_CONFIG = ["GPT2Config", "EvollaConfig", "MLCDVisionConfig"]
|
||||
|
||||
|
||||
def build_config_mapping_names() -> tuple[dict, dict]:
|
||||
model_type_map = OrderedDict()
|
||||
special_mappings = OrderedDict()
|
||||
# Track which model_types were resolved by a "natural" match (model_type == module_name)
|
||||
# so a later non-natural match (e.g. MaskFormerDetrConfig with model_type="detr" inside
|
||||
# models/maskformer/) does not silently overwrite the canonical class.
|
||||
natural_types: set[str] = set()
|
||||
|
||||
# `glob.glob` is filesystem-order dependent — sort to make the output deterministic.
|
||||
all_files = sorted(glob.glob("src/transformers/models/**/configuration_*.py", recursive=True))
|
||||
for config_path in all_files:
|
||||
module_name = config_path.split("/")[-2]
|
||||
with open(config_path, "r") as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef) and any(
|
||||
base.id == "PreTrainedConfig" for base in node.bases if isinstance(base, ast.Name)
|
||||
):
|
||||
config_cls_name = node.name
|
||||
model_type = None
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if model_types := [
|
||||
stmt.value.value
|
||||
for target in stmt.targets
|
||||
if isinstance(target, ast.Name) and target.id == "model_type"
|
||||
]:
|
||||
model_type = model_types[0]
|
||||
break
|
||||
elif isinstance(stmt, ast.AnnAssign):
|
||||
if stmt.target.id == "model_type":
|
||||
model_type = stmt.value.value
|
||||
break
|
||||
|
||||
if not model_type:
|
||||
continue
|
||||
|
||||
is_natural = model_type == module_name
|
||||
# If we already recorded a natural match for this model_type, don't let a
|
||||
# non-natural one overwrite it — the natural class is the canonical owner.
|
||||
if model_type in natural_types and not is_natural:
|
||||
continue
|
||||
|
||||
model_type_map[model_type] = config_cls_name
|
||||
if is_natural:
|
||||
natural_types.add(model_type)
|
||||
special_mappings.pop(model_type, None)
|
||||
else:
|
||||
special_mappings[model_type] = module_name
|
||||
|
||||
return model_type_map, special_mappings
|
||||
|
||||
|
||||
def build_image_processor_mapping(
|
||||
config_mapping: dict[str, str],
|
||||
) -> OrderedDict[str, dict[str, str | None]]:
|
||||
processor_mapping = OrderedDict()
|
||||
for model_type in config_mapping:
|
||||
module = model_type.replace("-", "_")
|
||||
fast_processor_name = slow_processor_name = None
|
||||
if os.path.exists(f"src/transformers/models/{module}/image_processing_pil_{module}.py"):
|
||||
with open(f"src/transformers/models/{module}/image_processing_pil_{module}.py", "r") as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef) and any(
|
||||
base.id == "PilBackend" for base in node.bases if isinstance(base, ast.Name)
|
||||
):
|
||||
slow_processor_name = node.name
|
||||
|
||||
if os.path.exists(f"src/transformers/models/{module}/image_processing_{module}.py"):
|
||||
with open(f"src/transformers/models/{module}/image_processing_{module}.py", "r") as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef) and any(
|
||||
base.id == "TorchvisionBackend" for base in node.bases if isinstance(base, ast.Name)
|
||||
):
|
||||
fast_processor_name = node.name
|
||||
|
||||
if slow_processor_name is not None or fast_processor_name is not None:
|
||||
processor_mapping[model_type] = {
|
||||
**({"pil": slow_processor_name} if slow_processor_name else {}),
|
||||
**({"torchvision": fast_processor_name} if fast_processor_name else {}),
|
||||
}
|
||||
|
||||
return processor_mapping
|
||||
|
||||
|
||||
def build_video_processor_mapping(
|
||||
config_mapping: dict[str, str],
|
||||
) -> OrderedDict[str, dict[str, str | None]]:
|
||||
processor_mapping = OrderedDict()
|
||||
for model_type in config_mapping:
|
||||
module = model_type.replace("-", "_")
|
||||
video_processor_name = None
|
||||
|
||||
if os.path.exists(f"src/transformers/models/{module}/video_processing_{module}.py"):
|
||||
with open(f"src/transformers/models/{module}/video_processing_{module}.py", "r") as f:
|
||||
content = f.read()
|
||||
|
||||
tree = ast.parse(content)
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.ClassDef) and any(
|
||||
base.id == "BaseVideoProcessor" for base in node.bases if isinstance(base, ast.Name)
|
||||
):
|
||||
video_processor_name = node.name
|
||||
|
||||
if video_processor_name is not None:
|
||||
processor_mapping[model_type] = video_processor_name
|
||||
|
||||
return processor_mapping
|
||||
|
||||
|
||||
def run_ruff_and_sort(file: str):
|
||||
"""Run `ruff` linter and formatter on `file`, as in `make style` and sort the mappings order"""
|
||||
sort_auto_mapping(file, overwrite=True)
|
||||
subprocess.run(["ruff", "check", file, "--fix"], stdout=subprocess.DEVNULL)
|
||||
subprocess.run(["ruff", "format", file], stdout=subprocess.DEVNULL)
|
||||
|
||||
|
||||
def format_dict_value(v):
|
||||
if isinstance(v, str):
|
||||
return f'"{v}"'
|
||||
elif isinstance(v, dict):
|
||||
items = ", ".join(f'"{k}": {format_dict_value(val)}' for k, val in v.items())
|
||||
return "{" + items + "}"
|
||||
elif isinstance(v, list):
|
||||
items = ", ".join(format_dict_value(x) for x in v)
|
||||
return "[" + items + "]"
|
||||
else:
|
||||
return repr(v)
|
||||
|
||||
|
||||
def format_ordered_dict(name: str, data: OrderedDict):
|
||||
lines = []
|
||||
|
||||
lines.append(f"{name} = OrderedDict(")
|
||||
lines.append(f"{' ' * 4}[")
|
||||
|
||||
for k, v in data.items():
|
||||
lines.append(f'{" " * 8}("{k}", {format_dict_value(v)}),')
|
||||
|
||||
lines.append(f"{' ' * 4}]")
|
||||
lines.append(")\n\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def check_duplicates(mapping_for_special_models: dict[str, Any], auto_mapping: dict[str, Any]):
|
||||
if intersections := (set(mapping_for_special_models.keys()) & set(auto_mapping.keys())):
|
||||
raise ValueError(
|
||||
"You have manually duplicated a model-type that is present in `auto_mappings.py`. "
|
||||
f"Please, delete the entries for {intersections} if they are identical to auto-generated dict, "
|
||||
"or use consistent naming across model files so that the names match."
|
||||
)
|
||||
|
||||
|
||||
def main(overwrite: bool):
|
||||
filename = "src/transformers/models/auto/auto_mappings.py"
|
||||
|
||||
# 1. Read existing file content if available
|
||||
old_content = ""
|
||||
if os.path.exists(filename):
|
||||
old_content = open(filename, "r").read()
|
||||
|
||||
# 2. Generate new config mapping dicts by parsing all model-config classes
|
||||
config_mapping, special_mapping = build_config_mapping_names()
|
||||
image_processor_mapping = build_image_processor_mapping(config_mapping=config_mapping)
|
||||
video_processor_mapping = build_video_processor_mapping(config_mapping=config_mapping)
|
||||
|
||||
# Make sure users aren't duplicating the same keys manually
|
||||
check_duplicates(MISSING_IMAGE_PROCESSOR_MAPPING_NAMES, image_processor_mapping)
|
||||
check_duplicates(MISSING_VIDEO_PROCESSOR_MAPPING_NAMES, video_processor_mapping)
|
||||
|
||||
# The config mapping has to be one-to-one for correct `AutoConfig.from_pretrained()` because `LazyMapping`
|
||||
# reverts keys/values and creates a dict from it. Duplicate values will be overwritten by whatever comes at last
|
||||
duplicate_keys = [n for n, c in Counter(COMPLETE_CONFIG_MAPPING_NAMES.keys()).items() if c > 1]
|
||||
if duplicate_keys:
|
||||
raise ValueError(
|
||||
f"Keys in `CONFIG_MAPPING_NAMES` contain duplicates = {duplicate_keys}. "
|
||||
"The mapping has to be one-to-one to ensure correct `AutoConfig` functionality!"
|
||||
)
|
||||
|
||||
duplicate_values = [
|
||||
n
|
||||
for n, c in Counter(COMPLETE_CONFIG_MAPPING_NAMES.values()).items()
|
||||
if c > 1 and n not in IGNORE_DUPLICATE_CONFIG
|
||||
]
|
||||
if duplicate_values:
|
||||
raise ValueError(
|
||||
f"Values in `CONFIG_MAPPING_NAMES` contain duplicates = {duplicate_values}. "
|
||||
"The mapping has to be one-to-one to ensure correct `AutoConfig` functionality!"
|
||||
)
|
||||
|
||||
new_mappings = {
|
||||
"CONFIG_MAPPING_NAMES": config_mapping,
|
||||
"SPECIAL_MODEL_TYPE_TO_MODULE_NAME": special_mapping,
|
||||
"IMAGE_PROCESSOR_MAPPING_NAMES": image_processor_mapping,
|
||||
"VIDEO_PROCESSOR_MAPPING_NAMES": video_processor_mapping,
|
||||
}
|
||||
new_content = AUTO_GENERATED_HADER + "\nfrom collections import OrderedDict\n\n"
|
||||
for k, v in new_mappings.items():
|
||||
new_content += format_ordered_dict(name=k, data=v)
|
||||
|
||||
# 3. If the new auto-generate content is different, overwrite it
|
||||
# Dirty hack to sort and apply ruff to the file content, for easier matching
|
||||
with tempfile.TemporaryDirectory() as temp_folder:
|
||||
temp_filename = os.path.join(temp_folder, "temp.py")
|
||||
with open(temp_filename, "w") as temp_file:
|
||||
temp_file.write(new_content)
|
||||
|
||||
run_ruff_and_sort(temp_filename)
|
||||
new_content = open(temp_filename, "r").read()
|
||||
|
||||
if old_content != new_content:
|
||||
if not overwrite:
|
||||
diff = "".join(
|
||||
difflib.unified_diff(
|
||||
old_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
fromfile=f"{filename} (on disk)",
|
||||
tofile=f"{filename} (regenerated)",
|
||||
n=3,
|
||||
)
|
||||
)
|
||||
raise Exception(
|
||||
"Generated auto-mapping is not consistent with the contents of `models/auto/auto_mappings.py`.\n"
|
||||
"Run `make fix-repo` or `python utils/check_auto.py --fix_and_overwrite` to fix them.\n\n"
|
||||
f"Diff (on disk → regenerated):\n{diff}"
|
||||
)
|
||||
else:
|
||||
with open(filename, "w") as f:
|
||||
f.write(new_content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
main(overwrite=args.fix_and_overwrite)
|
||||
429
utils/check_bad_commit.py
Normal file
429
utils/check_bad_commit.py
Normal file
@@ -0,0 +1,429 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
|
||||
import git
|
||||
import requests
|
||||
|
||||
|
||||
def create_script(target_test):
|
||||
"""Create a python script to be run by `git bisect run` to determine if `target_test` passes or fails.
|
||||
If a test is not found in a commit, the script with exit code `0` (i.e. `Success`).
|
||||
|
||||
Args:
|
||||
target_test (`str`): The test to check.
|
||||
|
||||
Returns:
|
||||
`str`: The script to be run by `git bisect run`.
|
||||
"""
|
||||
|
||||
script = f"""
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
_ = subprocess.run(
|
||||
["python3", "-m", "pip", "install", "-e", "."],
|
||||
capture_output = True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
["python3", "-m", "pytest", "-v", "--flake-finder", "--flake-runs=4", "-rfEp", f"{target_test}"],
|
||||
capture_output = True,
|
||||
text=True,
|
||||
)
|
||||
print(result.stdout)
|
||||
|
||||
if f"FAILED {target_test}" in result.stdout:
|
||||
print("test failed")
|
||||
exit(1)
|
||||
elif result.returncode != 0:
|
||||
if "ERROR: file or directory not found: " in result.stderr:
|
||||
print("test file or directory not found in this commit")
|
||||
# git bisect treats exit code 125 as `test not found`. But this causes it not be able to make the conclusion
|
||||
# if a test is added between the `good commit` (exclusive) and `bad commit` (inclusive) (in git bisect terminology).
|
||||
# So we return 0 here in order to allow the process being able to identify the first commit that fails the test.
|
||||
exit(0)
|
||||
elif "ERROR: not found: " in result.stderr:
|
||||
print("test not found in this commit")
|
||||
exit(0)
|
||||
else:
|
||||
print(f"pytest gets unknown error: {{result.stderr}}")
|
||||
exit(1)
|
||||
|
||||
print(f"pytest runs successfully.")
|
||||
exit(0)
|
||||
"""
|
||||
|
||||
with open("target_script.py", "w") as fp:
|
||||
fp.write(script.strip())
|
||||
|
||||
|
||||
def is_bad_commit(target_test, commit):
|
||||
repo = git.Repo(".") # or specify path to your repo
|
||||
|
||||
# Save the current HEAD reference
|
||||
original_head = repo.head.commit
|
||||
|
||||
# Checkout to the commit
|
||||
repo.git.checkout(commit)
|
||||
|
||||
create_script(target_test=target_test)
|
||||
|
||||
result = subprocess.run(
|
||||
["python3", "target_script.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# Restore to original commit
|
||||
repo.git.checkout(original_head)
|
||||
|
||||
n_passed = 0
|
||||
o = re.findall(r"====.* (\d+) passed", result.stdout)
|
||||
if len(o) > 0:
|
||||
n_passed = int(o[0])
|
||||
|
||||
n_failed = 0
|
||||
o = re.findall(r"====.* (\d+) failed", result.stdout)
|
||||
if len(o) > 0:
|
||||
n_failed = int(o[0])
|
||||
|
||||
error_message = ""
|
||||
if n_failed > 0:
|
||||
match = re.search(r"^(FAILED .+ - .+)$", result.stdout, re.MULTILINE)
|
||||
error_message = match.group(1).strip() if match else "Cannot retrieve error message."
|
||||
|
||||
return result.returncode != 0, n_failed, n_passed, error_message
|
||||
|
||||
|
||||
def find_bad_commit(target_test, start_commit, end_commit):
|
||||
"""Find (backward) the earliest commit between `start_commit` (inclusive) and `end_commit` (exclusive) at which `target_test` fails.
|
||||
|
||||
Args:
|
||||
target_test (`str`): The test to check.
|
||||
start_commit (`str`): The latest commit (inclusive).
|
||||
end_commit (`str`): The earliest commit (exclusive).
|
||||
|
||||
Returns:
|
||||
`dict`: A dict containing the info about the earliest commit at which `target_test` fails.
|
||||
"""
|
||||
result = {
|
||||
"bad_commit": None,
|
||||
"status": None,
|
||||
"failure_at_workflow_commit": None,
|
||||
"failure_at_base_commit": None,
|
||||
"failure_at_bad_commit": None,
|
||||
}
|
||||
|
||||
is_pr_ci = os.environ.get("GITHUB_EVENT_NAME") in ["issue_comment", "pull_request"]
|
||||
|
||||
# For PR comment CI, we "assume" all tests at `end_commit` passed, so any failing test during a PR CI run is
|
||||
# "a new failing test", and we can perform more detailed checks with this script.
|
||||
# For "a failing tes at start_commit", we check the test against `end_commit` (run multiple times):
|
||||
# - if all passing at end_commit: an actual new failing test at start_commit
|
||||
# - if all failing at end_commit: get the failure message and compare it against the one from start_commit:
|
||||
# - same failure message: not a new failing test --> don't report it
|
||||
# - different failure message: kind of a new failing test --> need to report it
|
||||
# - if both failing and passing at end_commit: mark it as flaky
|
||||
|
||||
# check if `end_commit` fails the test
|
||||
failed_before, n_failed, n_passed, failure_at_base_commit = is_bad_commit(target_test, end_commit)
|
||||
# We only need one failure to conclude the test is flaky on the previous run with `end_commit`.
|
||||
# However, when running on CI, we need at least one failure and one pass to conclude.
|
||||
is_flaky_at_end_commit = ((not is_pr_ci) and n_failed > 0) or (is_pr_ci and n_failed > 0 and n_passed > 0)
|
||||
# `n_passed == 0` itself is not enough, as the test may not exist in the codebase at `end_commit`.
|
||||
is_failing_at_end_commit = failed_before and n_passed == 0
|
||||
|
||||
if is_flaky_at_end_commit:
|
||||
result["status"] = (
|
||||
f"flaky: test both passed and failed during the check of the current run on the previous commit: {end_commit}"
|
||||
)
|
||||
return result
|
||||
|
||||
elif (not is_pr_ci) and is_failing_at_end_commit:
|
||||
result["status"] = (
|
||||
f"flaky: test passed in the previous run (commit: {end_commit}) but failed (on the same commit) during the check of the current run."
|
||||
)
|
||||
return result
|
||||
|
||||
# if there is no new commit (e.g. 2 different CI runs on the same commit):
|
||||
# - failed once on `start_commit` but passed on `end_commit`, which are the same commit --> flaky (or something change externally) --> don't report
|
||||
if start_commit == end_commit:
|
||||
result["status"] = (
|
||||
f"flaky: test fails on the current CI run but passed in the previous run which is running on the same commit {end_commit}."
|
||||
)
|
||||
return result
|
||||
|
||||
# Now, we are (almost) sure `target_test` is not failing at `end_commit`. (For a PR CI, it may fail at `end_commit`)
|
||||
# Check if `start_commit` fails the test.
|
||||
# **IMPORTANT** we only need one pass to conclude the test is flaky on the current run with `start_commit`!
|
||||
_, n_failed, n_passed, failure_at_workflow_commit = is_bad_commit(target_test, start_commit)
|
||||
if n_passed > 0:
|
||||
# failed on CI run, but not reproducible here --> don't report
|
||||
result["status"] = (
|
||||
f"flaky: test fails on the current CI run (commit: {start_commit}) but passes during the check."
|
||||
)
|
||||
return result
|
||||
|
||||
# The test fails on `start_commit`, and
|
||||
# - if the CI is run on PR: this block checks if the test also failed on `start_commit`.
|
||||
# - otherwise: the test passed on `end_commit` --> an actual new failing test, this block is skipped.
|
||||
|
||||
# TODO: A helper method to handle this and other possible error messages in a clean and centralized way.
|
||||
failure_at_workflow_commit_processed = failure_at_workflow_commit
|
||||
failure_at_base_commit_processed = failure_at_base_commit
|
||||
if "torch.OutOfMemoryError: CUDA out of memory" in failure_at_workflow_commit_processed:
|
||||
failure_at_workflow_commit_processed = "torch.OutOfMemoryError: CUDA out of memory"
|
||||
if "torch.OutOfMemoryError: CUDA out of memory" in failure_at_base_commit_processed:
|
||||
failure_at_base_commit_processed = "torch.OutOfMemoryError: CUDA out of memory"
|
||||
|
||||
different_failures = failure_at_workflow_commit_processed != failure_at_base_commit_processed
|
||||
|
||||
if is_pr_ci and failure_at_base_commit != "" and different_failures:
|
||||
result["bad_commit"] = start_commit
|
||||
result["status"] = (
|
||||
f"test fails both on the current commit ({start_commit}) and the previous commit ({end_commit}), but with DIFFERENT error message!"
|
||||
)
|
||||
result["failure_at_workflow_commit"] = failure_at_workflow_commit
|
||||
result["failure_at_base_commit"] = failure_at_base_commit
|
||||
result["failure_at_bad_commit"] = failure_at_workflow_commit
|
||||
return result
|
||||
# Fail on both commits but with the same error message ==> don't include
|
||||
elif is_pr_ci and not different_failures:
|
||||
result["bad_commit"] = None
|
||||
result["status"] = (
|
||||
f"test fails both on the current commit ({start_commit}) and the previous commit ({end_commit}) with the SAME error message!"
|
||||
)
|
||||
result["failure_at_workflow_commit"] = failure_at_workflow_commit
|
||||
result["failure_at_base_commit"] = failure_at_base_commit
|
||||
result["failure_at_bad_commit"] = failure_at_workflow_commit
|
||||
return result
|
||||
|
||||
# The test fails on `start_commit` but passed on `end_commit`.
|
||||
create_script(target_test=target_test)
|
||||
|
||||
bash = f"""
|
||||
git bisect reset
|
||||
git bisect start --first-parent {start_commit} {end_commit}
|
||||
git bisect run python3 target_script.py
|
||||
"""
|
||||
|
||||
with open("run_git_bisect.sh", "w") as fp:
|
||||
fp.write(bash.strip())
|
||||
|
||||
bash_result = subprocess.run(
|
||||
["bash", "run_git_bisect.sh"],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
print(bash_result.stdout)
|
||||
|
||||
# This happens if running the script gives exit code < 0 or other issues
|
||||
if "error: bisect run failed" in bash_result.stderr:
|
||||
error_msg = f"Error when running git bisect:\nbash error: {bash_result.stderr}\nbash output:\n{bash_result.stdout}\nset `bad_commit` to `None`."
|
||||
print(error_msg)
|
||||
result["status"] = "git bisect failed"
|
||||
return result
|
||||
|
||||
pattern = r"(.+) is the first bad commit"
|
||||
commits = re.findall(pattern, bash_result.stdout)
|
||||
|
||||
bad_commit = None
|
||||
failure_at_bad_commit = ""
|
||||
if len(commits) > 0:
|
||||
bad_commit = commits[0]
|
||||
_, _, _, failure_at_bad_commit = is_bad_commit(target_test, bad_commit)
|
||||
|
||||
print(f"Between `start_commit` {start_commit} and `end_commit` {end_commit}")
|
||||
print(f"bad_commit: {bad_commit}\n")
|
||||
|
||||
result["bad_commit"] = bad_commit
|
||||
result["status"] = "git bisect found the bad commit."
|
||||
result["failure_at_workflow_commit"] = failure_at_workflow_commit
|
||||
result["failure_at_base_commit"] = failure_at_base_commit
|
||||
result["failure_at_bad_commit"] = failure_at_bad_commit
|
||||
return result
|
||||
|
||||
|
||||
def get_commit_info(commit, pr_number=None, github_token=None):
|
||||
"""Get information for a commit via `api.github.com`."""
|
||||
if commit is None:
|
||||
return {"commit": None, "pr_number": None, "author": None, "merged_by": None, "parent": None}
|
||||
|
||||
author = None
|
||||
merged_author = None
|
||||
|
||||
headers = (
|
||||
{"Accept": "application/vnd.github+json", "Authorization": f"Bearer {github_token}"} if github_token else {}
|
||||
)
|
||||
|
||||
# Use PR number from environment if not provided
|
||||
if pr_number is None:
|
||||
pr_number = os.environ.get("pr_number")
|
||||
|
||||
# First, get commit info to check if it's a merge commit
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit}"
|
||||
commit_info = requests.get(url, headers=headers).json()
|
||||
|
||||
commit_to_query = commit
|
||||
|
||||
# Check if this is a merge commit created by GitHub
|
||||
if commit_info.get("parents") and len(commit_info["parents"]) > 1:
|
||||
commit_message = commit_info.get("commit", {}).get("message", "")
|
||||
# Parse message like "Merge 1ac46bed... into 5a67f0a7..."
|
||||
import re
|
||||
|
||||
match = re.match(r"^Merge ([a-f0-9]{40}) into ([a-f0-9]{40})", commit_message)
|
||||
if match:
|
||||
# Use the first SHA (the PR commit)
|
||||
commit_to_query = match.group(1)
|
||||
|
||||
# If no PR number yet, try to discover it from the commit.
|
||||
# The API can return an error dict (e.g. rate limit) instead of a list, so guard with isinstance.
|
||||
if not pr_number:
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit_to_query}/pulls"
|
||||
pr_info_for_commit = requests.get(url, headers=headers).json()
|
||||
if isinstance(pr_info_for_commit, list) and len(pr_info_for_commit) > 0:
|
||||
pr_number = pr_info_for_commit[0].get("number")
|
||||
|
||||
# If we have a PR number, get author and merged_by info.
|
||||
# Use .get() throughout: on rate-limit/403 the API returns an error dict, not the expected PR object.
|
||||
if pr_number:
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/pulls/{pr_number}"
|
||||
pr_for_commit = requests.get(url, headers=headers).json()
|
||||
author = pr_for_commit.get("user", {}).get("login")
|
||||
merged_by = pr_for_commit.get("merged_by")
|
||||
if merged_by is not None:
|
||||
merged_author = merged_by.get("login")
|
||||
|
||||
parents = commit_info.get("parents", [])
|
||||
parent = parents[0]["sha"] if parents else None
|
||||
if author is None:
|
||||
author = (commit_info.get("author") or {}).get("login")
|
||||
|
||||
return {"commit": commit, "pr_number": pr_number, "author": author, "merged_by": merged_author, "parent": parent}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--start_commit", type=str, required=True, help="The latest commit hash to check.")
|
||||
parser.add_argument("--end_commit", type=str, required=True, help="The earliest commit hash to check.")
|
||||
parser.add_argument("--test", type=str, help="The test to check.")
|
||||
parser.add_argument("--file", type=str, help="The report file.")
|
||||
parser.add_argument("--output_file", type=str, required=True, help="The path of the output file.")
|
||||
parser.add_argument(
|
||||
"--github_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="GitHub token to avoid API rate limits. Falls back to GITHUB_TOKEN env var.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.github_token is None:
|
||||
args.github_token = os.environ.get("GITHUB_TOKEN")
|
||||
|
||||
run_idx = os.environ.get("run_idx")
|
||||
n_runners = os.environ.get("n_runners")
|
||||
|
||||
print(f"start_commit: {args.start_commit}")
|
||||
print(f"end_commit: {args.end_commit}")
|
||||
|
||||
# Cache commit info to avoid redundant API calls and reduce rate limit pressure.
|
||||
commit_info_cache = {}
|
||||
|
||||
if len({args.test is None, args.file is None}) != 2:
|
||||
raise ValueError("Exactly one argument `test` or `file` must be specified.")
|
||||
|
||||
if args.test is not None:
|
||||
commit, status = find_bad_commit(
|
||||
target_test=args.test, start_commit=args.start_commit, end_commit=args.end_commit
|
||||
)
|
||||
with open(args.output_file, "w", encoding="UTF-8") as fp:
|
||||
fp.write(f"{args.test}\n{commit}\n{status}")
|
||||
elif os.path.isfile(args.file):
|
||||
with open(args.file, "r", encoding="UTF-8") as fp:
|
||||
reports = json.load(fp)
|
||||
|
||||
model_with_failures = []
|
||||
|
||||
for model in reports:
|
||||
# We change the format of "new_failures.json" in PR #XXXXX, let's handle both formats for a few weeks.
|
||||
if "failures" in reports[model]:
|
||||
if "job_link" in reports[model]:
|
||||
for device, device_failures in reports[model]["failures"].items():
|
||||
if device in reports[model]["job_link"]:
|
||||
for failure in device_failures:
|
||||
failure["job_link"] = reports[model]["job_link"][device]
|
||||
del reports[model]["job_link"]
|
||||
reports[model] = reports[model]["failures"]
|
||||
|
||||
# TODO: make this script able to deal with both `single-gpu` and `multi-gpu` via a new argument.
|
||||
reports[model].pop("multi-gpu", None)
|
||||
failed_tests = reports[model].get("single-gpu", [])
|
||||
|
||||
model_with_failures.extend([(model, test) for test in failed_tests])
|
||||
|
||||
if run_idx is not None:
|
||||
run_idx = int(run_idx)
|
||||
n_runners = int(n_runners)
|
||||
|
||||
num_failed_tests_to_run = len(model_with_failures) // n_runners
|
||||
|
||||
start_idx = num_failed_tests_to_run * run_idx
|
||||
end_idx = num_failed_tests_to_run * (run_idx + 1) if run_idx < n_runners - 1 else len(model_with_failures)
|
||||
|
||||
model_with_failures_to_check = model_with_failures[start_idx:end_idx]
|
||||
model_with_failures = model_with_failures_to_check
|
||||
|
||||
failed_tests_with_bad_commits = defaultdict(list)
|
||||
for model, failure in model_with_failures:
|
||||
test = failure["line"]
|
||||
bad_commit_info = find_bad_commit(
|
||||
target_test=test, start_commit=args.start_commit, end_commit=args.end_commit
|
||||
)
|
||||
info = {"test": test}
|
||||
info.update(bad_commit_info)
|
||||
|
||||
bad_commit = bad_commit_info["bad_commit"]
|
||||
|
||||
if bad_commit in commit_info_cache:
|
||||
commit_info = commit_info_cache[bad_commit]
|
||||
else:
|
||||
commit_info = get_commit_info(bad_commit, github_token=args.github_token)
|
||||
commit_info_cache[bad_commit] = commit_info
|
||||
|
||||
commit_info_copied = copy.deepcopy(commit_info)
|
||||
commit_info_copied.pop("commit")
|
||||
commit_info_copied.update({"workflow_commit": args.start_commit, "base_commit": args.end_commit})
|
||||
info.update(commit_info_copied)
|
||||
# put failure message toward the end
|
||||
info = {k: v for k, v in info.items() if not k.startswith(("failure_at_", "job_link"))} | {
|
||||
k: v for k, v in info.items() if k.startswith(("failure_at_", "job_link"))
|
||||
}
|
||||
|
||||
failed_tests_with_bad_commits[model].append(info)
|
||||
|
||||
reports = {model: {"single-gpu": tests} for model, tests in failed_tests_with_bad_commits.items() if tests}
|
||||
|
||||
with open(args.output_file, "w", encoding="UTF-8") as fp:
|
||||
json.dump(reports, fp, ensure_ascii=False, indent=4)
|
||||
399
utils/check_config_attributes.py
Normal file
399
utils/check_config_attributes.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
|
||||
from transformers.configuration_utils import PreTrainedConfig
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "config_attributes",
|
||||
"label": "Config attributes",
|
||||
# Approximate: iterates CONFIG_MAPPING at runtime and also reads modeling_*.py files
|
||||
# in each config's directory via os.listdir(). Deprecated models are skipped.
|
||||
"cache_globs": ["src/transformers/models/**/configuration_*.py", "src/transformers/models/**/modeling_*.py"],
|
||||
"check_args": [],
|
||||
"fix_args": None,
|
||||
}
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_config_docstrings.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
|
||||
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
|
||||
# Usually of small list of allowed attrs, but can be True to allow all
|
||||
SPECIAL_CASES_TO_ALLOW = {
|
||||
"Gemma4UnifiedAudioConfig": ["audio_embed_dim"], # Used as meta data for other attributes/properties
|
||||
"Gemma4UnifiedVisionConfig": [
|
||||
"patch_size",
|
||||
"pooling_kernel_size",
|
||||
], # Used as meta data for other attributes/properties
|
||||
"MiniCPMV4_6Config": ["drop_vision_last_layer"],
|
||||
"OpenAIPrivacyFilterConfig": ["classifier_dropout", "output_router_logits", "router_aux_loss_coef"],
|
||||
"HYV3Config": ["output_router_logits"],
|
||||
"NougatConfig": ["decoder", "encoder"],
|
||||
"PI0Config": ["vlm_projection_dim"],
|
||||
"EuroBertConfig": ["is_causal"], # not used directly, allows causal-bidirectional switch
|
||||
"Ernie4_5_VL_MoeConfig": ["args"], # BC Alias
|
||||
"Ernie4_5_VL_MoeTextConfig": ["args"], # BC Alias
|
||||
"Ernie4_5_VL_MoeVisionConfig": ["args"], # BC Alias
|
||||
"ExaoneMoeConfig": ["first_k_dense_replace"], # BC for other frameworks
|
||||
"AfmoeConfig": ["global_attn_every_n_layers", "rope_scaling"],
|
||||
"LagunaConfig": ["moe_apply_router_weight_on_input"],
|
||||
"xLSTMConfig": ["add_out_norm", "chunkwise_kernel", "sequence_kernel", "step_kernel"],
|
||||
"Lfm2Config": ["full_attn_idxs"],
|
||||
"DiaConfig": ["delay_pattern"],
|
||||
"BambaConfig": ["attn_layer_indices"],
|
||||
"Dots1Config": ["max_window_layers"],
|
||||
"JambaConfig": ["attn_layer_offset", "attn_layer_period", "expert_layer_offset", "expert_layer_period"],
|
||||
"JetMoeConfig": ["output_router_logits"],
|
||||
"Phi3Config": ["embd_pdrop"],
|
||||
"EncodecConfig": ["overlap"],
|
||||
"XcodecConfig": ["sample_rate", "audio_channels"],
|
||||
"RecurrentGemmaConfig": ["block_types", "attention_window_size"],
|
||||
"MambaConfig": ["expand"],
|
||||
"FalconMambaConfig": ["expand"],
|
||||
"FSMTConfig": ["langs", "common_kwargs", "early_stopping", "length_penalty", "max_length", "num_beams"],
|
||||
"GPTNeoConfig": ["attention_types"],
|
||||
"BlenderbotConfig": ["encoder_no_repeat_ngram_size"],
|
||||
"EsmConfig": ["is_folding_model"],
|
||||
"Mask2FormerConfig": ["ignore_value"],
|
||||
"OneFormerConfig": ["ignore_value", "norm"],
|
||||
"T5Config": ["feed_forward_proj"],
|
||||
"MT5Config": ["feed_forward_proj", "tokenizer_class"],
|
||||
"UMT5Config": ["feed_forward_proj", "tokenizer_class"],
|
||||
"LongT5Config": ["feed_forward_proj"],
|
||||
"Pop2PianoConfig": ["feed_forward_proj"],
|
||||
"BioGptConfig": ["layer_norm_eps"],
|
||||
"GLPNConfig": ["layer_norm_eps"],
|
||||
"SegformerConfig": ["layer_norm_eps"],
|
||||
"CvtConfig": ["layer_norm_eps"],
|
||||
"PerceiverConfig": ["layer_norm_eps"],
|
||||
"InformerConfig": ["num_static_real_features", "num_time_features"],
|
||||
"TimeSeriesTransformerConfig": ["num_static_real_features", "num_time_features"],
|
||||
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
|
||||
"SamVisionConfig": ["mlp_ratio"],
|
||||
"DeepseekOcr2SamVisionConfig": ["mlp_ratio"],
|
||||
"Sam3VisionConfig": ["backbone_feature_sizes"],
|
||||
"SamHQVisionConfig": ["mlp_ratio"],
|
||||
"ClapAudioConfig": ["num_classes"],
|
||||
"ClvpDecoderConfig": ["add_cross_attention"],
|
||||
"SpeechT5HifiGanConfig": ["sampling_rate"],
|
||||
"UdopConfig": ["feed_forward_proj"],
|
||||
"ZambaConfig": ["attn_layer_offset", "attn_layer_period"],
|
||||
"MllamaVisionConfig": ["supported_aspect_ratios"],
|
||||
"LEDConfig": ["classifier_dropout"],
|
||||
"GPTNeoXConfig": ["rotary_emb_base"],
|
||||
"ShieldGemma2Config": ["mm_tokens_per_image", "vision_config"],
|
||||
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
||||
"ModernBertConfig": ["local_attention", "reference_compile"],
|
||||
"ModernBertDecoderConfig": ["global_attn_every_n_layers", "local_attention", "local_rope_theta"],
|
||||
"SmolLM3Config": ["no_rope_layer_interval"],
|
||||
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"],
|
||||
"HiggsAudioV2Config": ["audio_bos_token", "audio_stream_bos_id", "audio_stream_eos_id"],
|
||||
"HiggsAudioV2TokenizerConfig": ["downsample_factor"],
|
||||
"Cohere2MoeConfig": ["rope_scaling", "sliding_window_pattern"],
|
||||
"CsmConfig": ["tie_codebooks_embeddings"],
|
||||
"DeepseekV2Config": ["norm_topk_prob"],
|
||||
"DeepseekV4Config": [
|
||||
# All BC / config-compat surface that the modeling code never reads but
|
||||
# checkpoints in the wild expose (so we keep accepting them in `__init__`):
|
||||
# `attention_bias` — V4 has no bias on any linear; kept for parity with V3 configs.
|
||||
# `n_shared_experts` — V4 always builds exactly one shared MLP; the count
|
||||
# isn't read because there's no loop over shared experts.
|
||||
# `norm_topk_prob` — V3 router knob; V4's `DeepseekV4TopKRouter` always normalises.
|
||||
# `num_key_value_heads` — V4 is shared-KV MQA (always 1); not read at runtime.
|
||||
# `num_nextn_predict_layers` — MTP layer count from upstream checkpoints; the
|
||||
# MTP head isn't instantiated by transformers' V4 implementation.
|
||||
# `router_jitter_noise` — inherited from Mixtral; V4 routers don't apply jitter.
|
||||
"attention_bias",
|
||||
"n_shared_experts",
|
||||
"norm_topk_prob",
|
||||
"num_key_value_heads",
|
||||
"num_nextn_predict_layers",
|
||||
"router_jitter_noise",
|
||||
],
|
||||
"EsmFoldConfig": ["esm_ablate_pairwise", "esm_ablate_sequence", "esm_input_dropout", "esm_type"],
|
||||
"TrunkConfig": ["cpu_grad_checkpoint", "layer_drop"],
|
||||
"SeamlessM4TConfig": True,
|
||||
"SeamlessM4Tv2Config": True,
|
||||
"ConditionalDetrConfig": True,
|
||||
"DabDetrConfig": True,
|
||||
"SwitchTransformersConfig": True,
|
||||
"MaskFormerDetrConfig": True,
|
||||
"DetrConfig": True,
|
||||
"DFineConfig": True,
|
||||
"Deimv2Config": True, # Mixed encoder variants (hybrid/lite) + DFine inheritance
|
||||
"GroundingDinoConfig": True,
|
||||
"MMGroundingDinoConfig": True,
|
||||
"RTDetrConfig": True,
|
||||
"RTDetrV2Config": True,
|
||||
"YolosConfig": True,
|
||||
"Llama4TextConfig": True,
|
||||
"DPRConfig": True,
|
||||
"FuyuConfig": True,
|
||||
"LayoutXLMConfig": True,
|
||||
"CLIPSegConfig": True,
|
||||
"DeformableDetrConfig": True,
|
||||
"DinatConfig": True,
|
||||
"DonutSwinConfig": True,
|
||||
"FastSpeech2ConformerConfig": True,
|
||||
"LayoutLMv2Config": True,
|
||||
"MaskFormerSwinConfig": True,
|
||||
"MptConfig": True,
|
||||
"MptAttentionConfig": True,
|
||||
"RagConfig": True,
|
||||
"SpeechT5Config": True,
|
||||
"SwinConfig": True,
|
||||
"Swin2SRConfig": True,
|
||||
"Swinv2Config": True,
|
||||
"TableTransformerConfig": True,
|
||||
"TapasConfig": True,
|
||||
"UniSpeechConfig": True,
|
||||
"UniSpeechSatConfig": True,
|
||||
"WavLMConfig": True,
|
||||
"WhisperConfig": True,
|
||||
"JukeboxPriorConfig": True,
|
||||
"Pix2StructTextConfig": True,
|
||||
"IdeficsConfig": True,
|
||||
"IdeficsVisionConfig": True,
|
||||
"IdeficsPerceiverConfig": True,
|
||||
"GptOssConfig": True,
|
||||
"LwDetrConfig": True,
|
||||
"NemotronHConfig": True,
|
||||
# RfDetr config attributes only used in loss code
|
||||
"RfDetrConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"class_loss_coefficient",
|
||||
"dice_loss_coefficient",
|
||||
"eos_coefficient",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
"mask_class_loss_coefficient",
|
||||
"mask_dice_loss_coefficient",
|
||||
"mask_loss_coefficient",
|
||||
"mask_point_sample_ratio",
|
||||
],
|
||||
# Internally uses Got Ocr2 so no need to use in the modeling code as we remap in auto instead
|
||||
"PPChart2TableConfig": True,
|
||||
"PPChart2TableVisionConfig": True,
|
||||
"GlmgaConfig": ["vision_config"],
|
||||
"Sapiens2Config": [
|
||||
"num_first_full_attention_layers", # builder attr consumed in __post_init__ to compute num_key_value_heads_per_layer
|
||||
"num_key_value_attention_heads", # builder attr consumed in __post_init__ to compute num_key_value_heads_per_layer
|
||||
"num_last_full_attention_layers", # builder attr consumed in __post_init__ to compute num_key_value_heads_per_layer
|
||||
"flip_pairs", # used externally for post-processing keypoints, not in forward pass
|
||||
],
|
||||
}
|
||||
|
||||
# Common and important attributes, even if they do not always appear in the modeling files (can be a regex pattern)
|
||||
ATTRIBUTES_TO_ALLOW = (
|
||||
# Attr in base `PreTrainedConfig`
|
||||
"transformers_version",
|
||||
"architectures",
|
||||
"chunk_size_feed_forward",
|
||||
"dtype",
|
||||
"id2label",
|
||||
"label2id",
|
||||
"problem_type",
|
||||
"tokenizer_class",
|
||||
"is_encoder_decoder",
|
||||
"output_hidden_states",
|
||||
"return_dict",
|
||||
# Inits related
|
||||
"initializer_range",
|
||||
"init_std",
|
||||
"initializer_factor",
|
||||
"tie_word_embeddings",
|
||||
# Special tokens
|
||||
"bos_index",
|
||||
"eos_index",
|
||||
"pad_index",
|
||||
"unk_index",
|
||||
"mask_index",
|
||||
r".+_token_id",
|
||||
r".+_token_index",
|
||||
# Processors
|
||||
"image_seq_length",
|
||||
"video_seq_length",
|
||||
"image_size",
|
||||
"text_config", # may appear as `get_text_config()`
|
||||
"use_cache",
|
||||
"out_features",
|
||||
"out_indices",
|
||||
"sampling_rate",
|
||||
# backbone related arguments passed to load_backbone
|
||||
"use_pretrained_backbone",
|
||||
"backbone",
|
||||
"backbone_config",
|
||||
"use_timm_backbone",
|
||||
"backbone_kwargs",
|
||||
# rope attributes may not appear directly in the modeling but are used
|
||||
"rope_theta",
|
||||
"partial_rotary_factor",
|
||||
"max_position_embeddings",
|
||||
"pretraining_tp",
|
||||
"use_sliding_window",
|
||||
"max_window_layers",
|
||||
# vision attributes that may be used indirectly via merge_with_config_defaults
|
||||
"vision_feature_layer",
|
||||
"vision_feature_select_strategy",
|
||||
"vision_aspect_ratio",
|
||||
)
|
||||
|
||||
|
||||
def check_attribute_being_used(config_class, attributes, default_value, source_strings):
|
||||
"""Check if any name in `attributes` is used in one of the strings in `source_strings`
|
||||
|
||||
Args:
|
||||
config_class (`type`):
|
||||
The configuration class for which the arguments in its `__init__` will be checked.
|
||||
attributes (`List[str]`):
|
||||
The name of an argument (or attribute) and its variant names if any.
|
||||
default_value (`Any`):
|
||||
A default value for the attribute in `attributes` assigned in the `__init__` of `config_class`.
|
||||
source_strings (`List[str]`):
|
||||
The python source code strings in the same modeling directory where `config_class` is defined. The file
|
||||
containing the definition of `config_class` should be excluded.
|
||||
"""
|
||||
# If we can find the attribute used, then it's all good
|
||||
for attribute in attributes:
|
||||
for modeling_source in source_strings:
|
||||
# check if we can find `config.xxx`, `getattr(config, "xxx", ...)` or `getattr(self.config, "xxx", ...)`
|
||||
if (
|
||||
f"config.{attribute}" in modeling_source
|
||||
or f'getattr(config, "{attribute}"' in modeling_source
|
||||
or f'getattr(self.config, "{attribute}"' in modeling_source
|
||||
or (
|
||||
"TextConfig" in config_class.__name__
|
||||
and f"config.get_text_config().{attribute}" in modeling_source
|
||||
)
|
||||
):
|
||||
return True
|
||||
# Deal with multi-line cases
|
||||
elif (
|
||||
re.search(
|
||||
rf'getattr[ \t\v\n\r\f]*\([ \t\v\n\r\f]*(self\.)?config,[ \t\v\n\r\f]*"{attribute}"',
|
||||
modeling_source,
|
||||
)
|
||||
is not None
|
||||
):
|
||||
return True
|
||||
|
||||
# Special cases to be allowed even if not found as used
|
||||
for attribute in attributes:
|
||||
# Allow if the default value in the configuration class is different from the one in `PreTrainedConfig`
|
||||
if (attribute == "is_encoder_decoder" and default_value is True) or attribute == "tie_word_embeddings":
|
||||
return True
|
||||
# General exceptions for all models
|
||||
elif any(re.search(exception, attribute) for exception in ATTRIBUTES_TO_ALLOW):
|
||||
return True
|
||||
# Model-specific exceptions
|
||||
elif config_class.__name__ in SPECIAL_CASES_TO_ALLOW:
|
||||
model_exceptions = SPECIAL_CASES_TO_ALLOW[config_class.__name__]
|
||||
# Can be true to allow all attributes, or a list of specific allowed attributes
|
||||
if (isinstance(model_exceptions, bool) and model_exceptions) or attribute in model_exceptions:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_config_attributes_being_used(config_class):
|
||||
"""Check the arguments in `__init__` of `config_class` are used in the modeling files in the same directory
|
||||
|
||||
Args:
|
||||
config_class (`type`):
|
||||
The configuration class for which the arguments in its `__init__` will be checked.
|
||||
"""
|
||||
# Get the parameters in `__init__` of the configuration class, and the default values if any
|
||||
signature = dict(inspect.signature(config_class.__init__).parameters)
|
||||
parameter_names = [x for x in list(signature.keys()) if x not in ["self", "kwargs"]]
|
||||
parameter_defaults = [signature[param].default for param in parameter_names]
|
||||
|
||||
# If `attribute_map` exists, an attribute can have different names to be used in the modeling files, and as long
|
||||
# as one variant is used, the test should pass
|
||||
reversed_attribute_map = {}
|
||||
if len(config_class.attribute_map) > 0:
|
||||
reversed_attribute_map = {v: k for k, v in config_class.attribute_map.items()}
|
||||
|
||||
# Get the path to modeling source files
|
||||
config_source_file = inspect.getsourcefile(config_class)
|
||||
model_dir = os.path.dirname(config_source_file)
|
||||
modeling_paths = [os.path.join(model_dir, fn) for fn in os.listdir(model_dir) if fn.startswith("modeling_")]
|
||||
|
||||
# Get the source code strings
|
||||
modeling_sources = []
|
||||
for path in modeling_paths:
|
||||
if os.path.isfile(path):
|
||||
with open(path, encoding="utf8") as fp:
|
||||
modeling_sources.append(fp.read())
|
||||
|
||||
unused_attributes = []
|
||||
for config_param, default_value in zip(parameter_names, parameter_defaults):
|
||||
# `attributes` here is all the variant names for `config_param`
|
||||
attributes = [config_param]
|
||||
# some configuration classes have non-empty `attribute_map`, and both names could be used in the
|
||||
# corresponding modeling files. As long as one of them appears, it is fine.
|
||||
if config_param in reversed_attribute_map:
|
||||
attributes.append(reversed_attribute_map[config_param])
|
||||
|
||||
if not check_attribute_being_used(config_class, attributes, default_value, modeling_sources):
|
||||
unused_attributes.append(attributes[0])
|
||||
|
||||
return sorted(unused_attributes)
|
||||
|
||||
|
||||
def check_config_attributes():
|
||||
"""Check the arguments in `__init__` of all configuration classes are used in python files"""
|
||||
configs_with_unused_attributes = {}
|
||||
for _config_class in list(CONFIG_MAPPING.values()):
|
||||
# Skip deprecated models
|
||||
if "models.deprecated" in _config_class.__module__:
|
||||
continue
|
||||
# Some config classes are not in `CONFIG_MAPPING` (e.g. `CLIPVisionConfig`, `Blip2VisionConfig`, etc.)
|
||||
config_classes_in_module = [
|
||||
cls
|
||||
for name, cls in inspect.getmembers(
|
||||
inspect.getmodule(_config_class),
|
||||
lambda x: inspect.isclass(x)
|
||||
and issubclass(x, PreTrainedConfig)
|
||||
and inspect.getmodule(x) == inspect.getmodule(_config_class),
|
||||
)
|
||||
]
|
||||
for config_class in config_classes_in_module:
|
||||
unused_attributes = check_config_attributes_being_used(config_class)
|
||||
if len(unused_attributes) > 0:
|
||||
configs_with_unused_attributes[config_class.__name__] = unused_attributes
|
||||
|
||||
if len(configs_with_unused_attributes) > 0:
|
||||
error = "The following configuration classes contain unused attributes in the corresponding modeling files:\n"
|
||||
for name, attributes in configs_with_unused_attributes.items():
|
||||
error += f"{name}: {attributes}\n"
|
||||
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_config_attributes()
|
||||
95
utils/check_config_docstrings.py
Normal file
95
utils/check_config_docstrings.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "config_docstrings",
|
||||
"label": "Config docstrings",
|
||||
# Approximate: iterates CONFIG_MAPPING at runtime via inspect.getsource(), not cache globs.
|
||||
# Only configs registered in CONFIG_MAPPING are checked; deprecated models are skipped.
|
||||
"cache_globs": ["src/transformers/models/**/configuration_*.py"],
|
||||
"check_args": [],
|
||||
"fix_args": None,
|
||||
}
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_config_docstrings.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
|
||||
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
|
||||
# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
|
||||
# For example, `[google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)`
|
||||
_re_checkpoint = re.compile(r"""(?s)@auto_docstring\(.*?checkpoint\s*=\s*["']([^"']+)["']""")
|
||||
|
||||
|
||||
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||
"DecisionTransformerConfig",
|
||||
"EncoderDecoderConfig",
|
||||
"MusicgenConfig",
|
||||
"RagConfig",
|
||||
"SpeechEncoderDecoderConfig",
|
||||
"TimmBackboneConfig",
|
||||
"TimmWrapperConfig",
|
||||
"VisionEncoderDecoderConfig",
|
||||
"VisionTextDualEncoderConfig",
|
||||
"GraniteConfig",
|
||||
"GraniteMoeConfig",
|
||||
"GraniteMoeHybridConfig",
|
||||
"Qwen3MoeConfig",
|
||||
"GraniteSpeechConfig",
|
||||
}
|
||||
|
||||
|
||||
def get_checkpoint_from_config_class(config_class):
|
||||
# source code of `config_class`
|
||||
config_source = inspect.getsource(config_class)
|
||||
checkpoints = _re_checkpoint.findall(config_source)
|
||||
return checkpoints[0] if checkpoints else None
|
||||
|
||||
|
||||
def check_config_docstrings_have_checkpoints():
|
||||
configs_without_checkpoint = []
|
||||
|
||||
for config_class in list(CONFIG_MAPPING.values()):
|
||||
# Skip deprecated models
|
||||
if "models.deprecated" in config_class.__module__:
|
||||
continue
|
||||
checkpoint = get_checkpoint_from_config_class(config_class)
|
||||
|
||||
name = config_class.__name__
|
||||
if checkpoint is None and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
|
||||
configs_without_checkpoint.append(name)
|
||||
|
||||
if len(configs_without_checkpoint) > 0:
|
||||
message = "\n".join(sorted(configs_without_checkpoint))
|
||||
raise ValueError(
|
||||
f"The following configurations don't contain any valid checkpoint:\n{message}\n\n"
|
||||
"The requirement is to include a link pointing to one of the models of this architecture in the "
|
||||
"docstring of the config classes listed above. The link should be passed to an `auto_docstring`"
|
||||
"decorator as follows `@auto_docstring(checkpoint='myorg/mymodel')."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_config_docstrings_have_checkpoints()
|
||||
1055
utils/check_copies.py
Normal file
1055
utils/check_copies.py
Normal file
File diff suppressed because it is too large
Load Diff
179
utils/check_doc_toc.py
Normal file
179
utils/check_doc_toc.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script is responsible for ensuring that all model docs are part of the `_toctree.yml` and cleaning the model
|
||||
section of the table of content by removing duplicates and sorting the entries in alphabetical order.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check that the table of content is properly sorted (used in `make check-repo`):
|
||||
|
||||
```bash
|
||||
python utils/check_doc_toc.py
|
||||
```
|
||||
|
||||
Auto-sort the table of content if it is not properly sorted (used in `make fix-repo`):
|
||||
|
||||
```bash
|
||||
python utils/check_doc_toc.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "doc_toc",
|
||||
"label": "Documentation table of contents",
|
||||
# Also reads docs/source/en/_toctree.yml; the .md glob catches new/renamed doc files.
|
||||
"cache_globs": ["docs/**/*.md", "docs/source/en/_toctree.yml"],
|
||||
"check_args": [],
|
||||
"fix_args": ["--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
ROOT = os.path.dirname(os.path.dirname(__file__))
|
||||
TOCTREE_PATH = os.path.join(ROOT, "docs", "source", "en", "_toctree.yml")
|
||||
DOC_PATH = os.path.join(ROOT, "docs", "source", "en", "model_doc")
|
||||
|
||||
|
||||
def clean_model_doc_toc(model_doc: list[dict]) -> list[dict]:
|
||||
"""
|
||||
Cleans a section of the table of content of the model documentation (one specific modality) by removing duplicates
|
||||
and sorting models alphabetically.
|
||||
|
||||
Args:
|
||||
model_doc (`List[dict]`):
|
||||
The list of dictionaries extracted from the `_toctree.yml` file for this specific modality.
|
||||
|
||||
Returns:
|
||||
`List[dict]`: List of dictionaries like the input, but cleaned up and sorted.
|
||||
"""
|
||||
counts = defaultdict(int)
|
||||
for doc in model_doc:
|
||||
counts[doc["local"]] += 1
|
||||
duplicates = [key for key, value in counts.items() if value > 1]
|
||||
|
||||
new_doc = []
|
||||
for duplicate_key in duplicates:
|
||||
titles = list({doc["title"] for doc in model_doc if doc["local"] == duplicate_key})
|
||||
if len(titles) > 1:
|
||||
raise ValueError(
|
||||
f"{duplicate_key} is present several times in the documentation table of content at "
|
||||
"`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
|
||||
"others."
|
||||
)
|
||||
# Only add this once
|
||||
new_doc.append({"local": duplicate_key, "title": titles[0]})
|
||||
|
||||
# Add none duplicate-keys
|
||||
new_doc.extend([doc for doc in model_doc if counts[doc["local"]] == 1])
|
||||
|
||||
# Sort
|
||||
return sorted(new_doc, key=lambda s: s["title"].lower())
|
||||
|
||||
|
||||
def ensure_all_models_in_toctree(model_doc: list[dict]):
|
||||
"""Make sure that all models in `model_doc` folder are also part of the `_toctree.yml`. Raise if it's not
|
||||
the case."""
|
||||
all_documented_models = {model_doc_file.removesuffix(".md") for model_doc_file in os.listdir(DOC_PATH)} - {"auto"}
|
||||
all_models_in_toctree = {
|
||||
model_entry["local"].removeprefix("model_doc/") for section in model_doc for model_entry in section["sections"]
|
||||
}
|
||||
|
||||
# everything alright
|
||||
if all_documented_models == all_models_in_toctree:
|
||||
return
|
||||
|
||||
documented_but_not_in_toctree = all_documented_models - all_models_in_toctree
|
||||
in_toctree_but_not_documented = all_models_in_toctree - all_documented_models
|
||||
|
||||
error_msg = ""
|
||||
if len(documented_but_not_in_toctree) > 0:
|
||||
error_msg += (
|
||||
f"{documented_but_not_in_toctree} appear(s) inside the folder `model_doc`, but not in the `_toctree.yml`. "
|
||||
"Please add it/them in their corresponding section inside the `_toctree.yml`."
|
||||
)
|
||||
if len(in_toctree_but_not_documented) > 0:
|
||||
if len(error_msg) > 0:
|
||||
error_msg += "\n"
|
||||
error_msg += (
|
||||
f"{in_toctree_but_not_documented} appear(s) in the `_toctree.yml`, but not inside the folder `model_doc`. "
|
||||
"Please add a corresponding `model.md` in `model_doc`."
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def check_model_doc(overwrite: bool = False):
|
||||
"""
|
||||
Check that the content of the table of content in `_toctree.yml` is up-to-date (i.e. it contains all models) and
|
||||
clean (no duplicates and sorted for the model API doc) and potentially auto-cleans it.
|
||||
|
||||
Args:
|
||||
overwrite (`bool`, *optional*, defaults to `False`):
|
||||
Whether to just check if the TOC is clean or to auto-clean it (when `overwrite=True`).
|
||||
"""
|
||||
with open(TOCTREE_PATH, encoding="utf-8") as f:
|
||||
content = yaml.safe_load(f.read())
|
||||
|
||||
# Get to the API doc
|
||||
api_idx = 0
|
||||
while content[api_idx]["title"] != "API":
|
||||
api_idx += 1
|
||||
api_doc = content[api_idx]["sections"]
|
||||
|
||||
# Then to the model doc
|
||||
model_idx = 0
|
||||
while api_doc[model_idx]["title"] != "Models":
|
||||
model_idx += 1
|
||||
|
||||
model_doc = api_doc[model_idx]["sections"]
|
||||
|
||||
# Make sure the toctree contains all models
|
||||
ensure_all_models_in_toctree(model_doc)
|
||||
|
||||
# Extract the modalities and clean them one by one.
|
||||
modalities_docs = [(idx, section) for idx, section in enumerate(model_doc) if "sections" in section]
|
||||
diff = False
|
||||
for idx, modality_doc in modalities_docs:
|
||||
old_modality_doc = modality_doc["sections"]
|
||||
new_modality_doc = clean_model_doc_toc(old_modality_doc)
|
||||
|
||||
if old_modality_doc != new_modality_doc:
|
||||
diff = True
|
||||
if overwrite:
|
||||
model_doc[idx]["sections"] = new_modality_doc
|
||||
|
||||
if diff:
|
||||
if overwrite:
|
||||
api_doc[model_idx]["sections"] = model_doc
|
||||
content[api_idx]["sections"] = api_doc
|
||||
with open(TOCTREE_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(yaml.dump(content, allow_unicode=True))
|
||||
else:
|
||||
raise ValueError(
|
||||
"The model doc part of the table of content is not properly sorted, run `make fix-repo` to fix this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_model_doc(args.fix_and_overwrite)
|
||||
2196
utils/check_docstrings.py
Normal file
2196
utils/check_docstrings.py
Normal file
File diff suppressed because it is too large
Load Diff
100
utils/check_doctest_list.py
Normal file
100
utils/check_doctest_list.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script is responsible for cleaning the list of doctests by making sure the entries all exist and are in
|
||||
alphabetical order.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check that the doctest list is properly sorted and all files exist (used in `make check-repo`):
|
||||
|
||||
```bash
|
||||
python utils/check_doctest_list.py
|
||||
```
|
||||
|
||||
Auto-sort the doctest list if it is not properly sorted (used in `make fix-repo`):
|
||||
|
||||
```bash
|
||||
python utils/check_doctest_list.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "doctest_list",
|
||||
"label": "Doctest list",
|
||||
# Over-approximation: the checker validates that paths in .txt list files exist and are
|
||||
# sorted. The broad globs ensure cache invalidation when source files are added/removed.
|
||||
"cache_globs": [
|
||||
"utils/not_doctested.txt",
|
||||
"utils/slow_documentation_tests.txt",
|
||||
"src/transformers/**/*.py",
|
||||
"docs/**/*.md",
|
||||
],
|
||||
"check_args": [],
|
||||
"fix_args": ["--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_doctest_list.py
|
||||
REPO_PATH = "."
|
||||
DOCTEST_FILE_PATHS = ["not_doctested.txt", "slow_documentation_tests.txt"]
|
||||
|
||||
|
||||
def clean_doctest_list(doctest_file: str, overwrite: bool = False):
|
||||
"""
|
||||
Cleans the doctest in a given file.
|
||||
|
||||
Args:
|
||||
doctest_file (`str`):
|
||||
The path to the doctest file to check or clean.
|
||||
overwrite (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to fix problems. If `False`, will error when the file is not clean.
|
||||
"""
|
||||
non_existent_paths = []
|
||||
all_paths = []
|
||||
with open(doctest_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip().split(" ")[0]
|
||||
path = os.path.join(REPO_PATH, line)
|
||||
if not (os.path.isfile(path) or os.path.isdir(path)):
|
||||
non_existent_paths.append(line)
|
||||
all_paths.append(line)
|
||||
|
||||
if len(non_existent_paths) > 0:
|
||||
non_existent_paths = "\n".join([f"- {f}" for f in non_existent_paths])
|
||||
raise ValueError(f"`{doctest_file}` contains non-existent paths:\n{non_existent_paths}")
|
||||
|
||||
sorted_paths = sorted(all_paths)
|
||||
if all_paths != sorted_paths:
|
||||
if not overwrite:
|
||||
raise ValueError(
|
||||
f"Files in `{doctest_file}` are not in alphabetical order, run `make fix-repo` to fix "
|
||||
"this automatically."
|
||||
)
|
||||
with open(doctest_file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(sorted_paths) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
for doctest_file in DOCTEST_FILE_PATHS:
|
||||
doctest_file = os.path.join(REPO_PATH, "utils", doctest_file)
|
||||
clean_doctest_list(doctest_file, args.fix_and_overwrite)
|
||||
265
utils/check_dummies.py
Normal file
265
utils/check_dummies.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script is responsible for making sure the dummies in utils/dummies_xxx.py are up to date with the main init.
|
||||
|
||||
Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
|
||||
have the necessary extra libs installed. Those objects will then raise helpful error message whenever the user tries
|
||||
to access one of their methods.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check that the dummy files are up to date (used in `make check-repo`):
|
||||
|
||||
```bash
|
||||
python utils/check_dummies.py
|
||||
```
|
||||
|
||||
Update the dummy files if needed (used in `make fix-repo`):
|
||||
|
||||
```bash
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "dummies",
|
||||
"label": "Dummy objects",
|
||||
# Over-approximation: only reads __init__.py and utils/dummy_*_objects.py, but any
|
||||
# new public object added anywhere could require a dummy update.
|
||||
"cache_globs": ["src/transformers/**/*.py"],
|
||||
"check_args": [],
|
||||
"fix_args": ["--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_dummies.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
# Matches from xxx import bla
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
# Matches if not is_xxx_available()
|
||||
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")
|
||||
|
||||
|
||||
# Template for the dummy objects.
|
||||
DUMMY_CONSTANT = """
|
||||
{0} = None
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_CLASS = """
|
||||
class {0}(metaclass=DummyObject):
|
||||
_backends = {1}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_backends({0}, {1})
|
||||
"""
|
||||
|
||||
|
||||
def find_backend(line: str) -> str | None:
|
||||
"""
|
||||
Find one (or multiple) backend in a code line of the init.
|
||||
|
||||
Args:
|
||||
line (`str`): A code line in an init file.
|
||||
|
||||
Returns:
|
||||
Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line
|
||||
contains `if is_xxx_available() and `is_yyy_available()`) returns all backends joined on `_and_` (so
|
||||
`xxx_and_yyy` for instance).
|
||||
"""
|
||||
if _re_test_backend.search(line) is None:
|
||||
return None
|
||||
backends = [b[0] for b in _re_backend.findall(line)]
|
||||
backends.sort()
|
||||
return "_and_".join(backends)
|
||||
|
||||
|
||||
def read_init() -> dict[str, list[str]]:
|
||||
"""
|
||||
Read the init and extract backend-specific objects.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend.
|
||||
"""
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Get to the point we do the actual imports for type checking
|
||||
line_index = 0
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
line_index += 1
|
||||
|
||||
backend_specific_objects = {}
|
||||
# Go through the end of the file
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
if backend is not None:
|
||||
while not lines[line_index].startswith(" else:"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_single_line_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
# Single-line imports
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 12):
|
||||
# Multiple-line imports (with 3 indent level)
|
||||
objects.append(line[12:-2])
|
||||
line_index += 1
|
||||
|
||||
backend_specific_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
return backend_specific_objects
|
||||
|
||||
|
||||
def create_dummy_object(name: str, backend_name: str) -> str:
|
||||
"""
|
||||
Create the code for a dummy object.
|
||||
|
||||
Args:
|
||||
name (`str`): The name of the object.
|
||||
backend_name (`str`): The name of the backend required for that object.
|
||||
|
||||
Returns:
|
||||
`str`: The code of the dummy object.
|
||||
"""
|
||||
if name.isupper():
|
||||
return DUMMY_CONSTANT.format(name)
|
||||
elif name.islower():
|
||||
return DUMMY_FUNCTION.format(name, backend_name)
|
||||
else:
|
||||
return DUMMY_CLASS.format(name, backend_name)
|
||||
|
||||
|
||||
def create_dummy_files(backend_specific_objects: dict[str, list[str]] | None = None) -> dict[str, str]:
|
||||
"""
|
||||
Create the content of the dummy files.
|
||||
|
||||
Args:
|
||||
backend_specific_objects (`Dict[str, List[str]]`, *optional*):
|
||||
The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling
|
||||
`read_init()`.
|
||||
|
||||
Returns:
|
||||
`Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file.
|
||||
"""
|
||||
if backend_specific_objects is None:
|
||||
backend_specific_objects = read_init()
|
||||
|
||||
dummy_files = {}
|
||||
|
||||
for backend, objects in backend_specific_objects.items():
|
||||
backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
|
||||
dummy_file = "# This file is autogenerated by the command `make fix-repo`, do not edit.\n"
|
||||
dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
|
||||
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
|
||||
dummy_files[backend] = dummy_file
|
||||
|
||||
return dummy_files
|
||||
|
||||
|
||||
def check_dummies(overwrite: bool = False):
|
||||
"""
|
||||
Check if the dummy files are up to date and maybe `overwrite` with the right content.
|
||||
|
||||
Args:
|
||||
overwrite (`bool`, *optional*, default to `False`):
|
||||
Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
|
||||
when `overwrite=False`.
|
||||
"""
|
||||
dummy_files = create_dummy_files()
|
||||
# For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
|
||||
short_names = {"torch": "pt"}
|
||||
|
||||
# Locate actual dummy modules and read their content.
|
||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||
dummy_file_paths = {
|
||||
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") for backend in dummy_files
|
||||
}
|
||||
|
||||
actual_dummies = {}
|
||||
for backend, file_path in dummy_file_paths.items():
|
||||
if os.path.isfile(file_path):
|
||||
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_dummies[backend] = f.read()
|
||||
else:
|
||||
actual_dummies[backend] = ""
|
||||
|
||||
# Compare actual with what they should be.
|
||||
for backend in dummy_files:
|
||||
if dummy_files[backend] != actual_dummies[backend]:
|
||||
if overwrite:
|
||||
print(
|
||||
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
"__init__ has new objects."
|
||||
)
|
||||
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(dummy_files[backend])
|
||||
else:
|
||||
# Temporary fix to help people identify which objects introduced are not correctly protected.
|
||||
found = False
|
||||
for _actual, _dummy in zip(
|
||||
actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
|
||||
):
|
||||
if _actual != _dummy:
|
||||
actual_broken = _actual
|
||||
dummy_broken = _dummy
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
print("A transient error was found with the dummies, please investigate.")
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in "
|
||||
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
|
||||
f" It is likely the following objects are responsible, see these excerpts: \n"
|
||||
f"---------------------------------- Actual -------------------------------------\n"
|
||||
f" \n {actual_broken} \n"
|
||||
f"---------------------------------- Dummy -------------------------------------\n"
|
||||
f" \n {dummy_broken} \n"
|
||||
"Run `make fix-repo` to fix this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_dummies(args.fix_and_overwrite)
|
||||
253
utils/check_import_complexity.py
Normal file
253
utils/check_import_complexity.py
Normal file
@@ -0,0 +1,253 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Check that `import transformers` does not pull in too many modules.
|
||||
|
||||
Traces the full import tree triggered by `import transformers` using a custom
|
||||
``importlib.abc.MetaPathFinder`` and counts every module that gets loaded.
|
||||
If the count exceeds ``MAX_IMPORT_COUNT`` the check fails, signalling a
|
||||
potential regression in import speed.
|
||||
|
||||
Usage:
|
||||
python utils/check_import_complexity.py # CI check mode
|
||||
python utils/check_import_complexity.py --display # show the full import tree
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import importlib.abc
|
||||
import sys
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
|
||||
MAX_IMPORT_COUNT = 1000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import-tree data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportNode:
|
||||
name: str
|
||||
children: list[ImportNode] = field(default_factory=list)
|
||||
|
||||
|
||||
class LoaderProxy(importlib.abc.Loader):
|
||||
"""Wrap a real loader to track the import stack during exec_module."""
|
||||
|
||||
def __init__(self, wrapped: Any, tracer: ImportTreeTracer, fullname: str):
|
||||
self._wrapped = wrapped
|
||||
self._tracer = tracer
|
||||
self._fullname = fullname
|
||||
|
||||
def create_module(self, spec):
|
||||
if hasattr(self._wrapped, "create_module"):
|
||||
return self._wrapped.create_module(spec)
|
||||
return None
|
||||
|
||||
def exec_module(self, module: ModuleType) -> None:
|
||||
self._tracer.push(self._fullname)
|
||||
try:
|
||||
if hasattr(self._wrapped, "exec_module"):
|
||||
self._wrapped.exec_module(module)
|
||||
elif hasattr(self._wrapped, "load_module"):
|
||||
self._wrapped.load_module(self._fullname)
|
||||
else:
|
||||
raise ImportError(f"Loader for {self._fullname!r} has neither exec_module nor load_module")
|
||||
finally:
|
||||
self._tracer.pop()
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._wrapped, name)
|
||||
|
||||
|
||||
class ImportTreeFinder(importlib.abc.MetaPathFinder):
|
||||
"""Intercept imports to build a parent/child tree of loaded modules."""
|
||||
|
||||
def __init__(self, tracer: ImportTreeTracer, original_meta_path: list[Any]):
|
||||
self._tracer = tracer
|
||||
self._original = list(original_meta_path)
|
||||
|
||||
def find_spec(self, fullname: str, path=None, target=None):
|
||||
if self._tracer.is_seen(fullname):
|
||||
return None
|
||||
|
||||
for finder in self._original:
|
||||
try:
|
||||
spec = finder.find_spec(fullname, path, target) if hasattr(finder, "find_spec") else None
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if spec is None:
|
||||
continue
|
||||
|
||||
self._tracer.record(fullname)
|
||||
|
||||
if spec.loader is not None:
|
||||
spec.loader = LoaderProxy(spec.loader, self._tracer, fullname)
|
||||
|
||||
return spec
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ImportTreeTracer:
|
||||
def __init__(self) -> None:
|
||||
self._local = threading.local()
|
||||
self._nodes: dict[str, ImportNode] = {}
|
||||
self._roots: list[ImportNode] = []
|
||||
self._seen: set[str] = set()
|
||||
|
||||
def _stack(self) -> list[str]:
|
||||
stack = getattr(self._local, "stack", None)
|
||||
if stack is None:
|
||||
stack = []
|
||||
self._local.stack = stack
|
||||
return stack
|
||||
|
||||
def is_seen(self, fullname: str) -> bool:
|
||||
return fullname in self._seen
|
||||
|
||||
def _get_or_create(self, fullname: str) -> ImportNode:
|
||||
if fullname not in self._nodes:
|
||||
self._nodes[fullname] = ImportNode(name=fullname)
|
||||
return self._nodes[fullname]
|
||||
|
||||
def record(self, fullname: str) -> None:
|
||||
if fullname in self._seen:
|
||||
return
|
||||
self._seen.add(fullname)
|
||||
node = self._get_or_create(fullname)
|
||||
stack = self._stack()
|
||||
if stack:
|
||||
parent = self._get_or_create(stack[-1])
|
||||
if all(c.name != fullname for c in parent.children):
|
||||
parent.children.append(node)
|
||||
else:
|
||||
if all(r.name != fullname for r in self._roots):
|
||||
self._roots.append(node)
|
||||
|
||||
def push(self, fullname: str) -> None:
|
||||
self._stack().append(fullname)
|
||||
|
||||
def pop(self) -> None:
|
||||
stack = self._stack()
|
||||
if stack:
|
||||
stack.pop()
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._seen)
|
||||
|
||||
@property
|
||||
def roots(self) -> list[ImportNode]:
|
||||
return self._roots
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tracing entry-point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def trace_import(target: str) -> ImportTreeTracer:
|
||||
tracer = ImportTreeTracer()
|
||||
original_meta_path = list(sys.meta_path)
|
||||
finder = ImportTreeFinder(tracer, original_meta_path)
|
||||
sys.meta_path.insert(0, finder)
|
||||
try:
|
||||
importlib.import_module(target)
|
||||
finally:
|
||||
try:
|
||||
sys.meta_path.remove(finder)
|
||||
except ValueError:
|
||||
pass
|
||||
return tracer
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Display helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_tree(nodes: list[ImportNode]) -> str:
|
||||
lines: list[str] = []
|
||||
|
||||
def _walk(node: ImportNode, prefix: str, is_last: bool) -> None:
|
||||
connector = "└── " if is_last else "├── "
|
||||
lines.append(f"{prefix}{connector}{node.name}")
|
||||
child_prefix = prefix + (" " if is_last else "│ ")
|
||||
for i, child in enumerate(node.children):
|
||||
_walk(child, child_prefix, i == len(node.children) - 1)
|
||||
|
||||
for i, root in enumerate(nodes):
|
||||
_walk(root, "", i == len(nodes) - 1)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Check import complexity for `import transformers`.")
|
||||
parser.add_argument(
|
||||
"--display",
|
||||
action="store_true",
|
||||
help="Display the full import tree (for debugging regressions).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-count",
|
||||
type=int,
|
||||
default=MAX_IMPORT_COUNT,
|
||||
help=f"Maximum allowed number of imported modules (default: {MAX_IMPORT_COUNT}).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
tracer = trace_import("transformers")
|
||||
except Exception as exc:
|
||||
print(f"ERROR: `import transformers` failed: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if args.display:
|
||||
print(format_tree(tracer.roots))
|
||||
print()
|
||||
print(f"Total modules imported: {tracer.count}")
|
||||
return 0
|
||||
|
||||
if tracer.count > args.max_count:
|
||||
print(
|
||||
f"Import complexity regression: `import transformers` triggered {tracer.count} module imports "
|
||||
f"(maximum allowed: {args.max_count}).\n"
|
||||
f"\n"
|
||||
f"Run the following command to display the full import tree and identify the cause:\n"
|
||||
f"\n"
|
||||
f" python utils/check_import_complexity.py --display\n"
|
||||
)
|
||||
return 1
|
||||
|
||||
print(f"Import complexity OK: {tracer.count} modules (max {args.max_count})")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
359
utils/check_inits.py
Normal file
359
utils/check_inits.py
Normal file
@@ -0,0 +1,359 @@
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the
|
||||
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
|
||||
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
|
||||
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
|
||||
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this
|
||||
script is to check the objects defined in both halves are the same.
|
||||
|
||||
This also checks the main init properly references all submodules, even if it doesn't import anything from them: every
|
||||
submodule should be defined as a key of `_import_structure`, with an empty list as value potentially, or the submodule
|
||||
won't be importable.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/check_inits.py
|
||||
```
|
||||
|
||||
for a check that will error in case of inconsistencies (used by `make check-repo`).
|
||||
|
||||
There is no auto-fix possible here sadly :-(
|
||||
"""
|
||||
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "inits",
|
||||
"label": "Init files",
|
||||
"cache_globs": ["src/transformers/**/__init__.py"],
|
||||
"check_args": [],
|
||||
"fix_args": None,
|
||||
}
|
||||
|
||||
# Path is set with the intent you should run this script from the root of the repo.
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
# Catches a one-line _import_struct = {xxx}
|
||||
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
|
||||
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
|
||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||
# Catches a line if not is_foo_available
|
||||
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
|
||||
# Catches a line _import_struct["bla"].append("foo")
|
||||
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
||||
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
|
||||
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
|
||||
# Catches a line with an object between quotes and a comma: "MyModel",
|
||||
_re_quote_object = re.compile(r'^\s+"([^"]+)",')
|
||||
# Catches a line with objects between brackets only: ["foo", "bar"],
|
||||
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
|
||||
# Catches a line with from foo import bar, bla, boo
|
||||
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
# Catches a line with try:
|
||||
_re_try = re.compile(r"^\s*try:")
|
||||
# Catches a line with else:
|
||||
_re_else = re.compile(r"^\s*else:")
|
||||
|
||||
|
||||
def find_backend(line: str) -> str | None:
|
||||
"""
|
||||
Find one (or multiple) backend in a code line of the init.
|
||||
|
||||
Args:
|
||||
line (`str`): A code line of the main init.
|
||||
|
||||
Returns:
|
||||
Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line
|
||||
contains `if is_xxx_available() and `is_yyy_available()`) returns all backends joined on `_and_` (so
|
||||
`xxx_and_yyy` for instance).
|
||||
"""
|
||||
if _re_test_backend.search(line) is None:
|
||||
return None
|
||||
backends = [b[0] for b in _re_backend.findall(line)]
|
||||
backends.sort()
|
||||
return "_and_".join(backends)
|
||||
|
||||
|
||||
def parse_init(init_file) -> tuple[dict[str, list[str]], dict[str, list[str]]] | None:
|
||||
"""
|
||||
Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
|
||||
defined.
|
||||
|
||||
Args:
|
||||
init_file (`str`): Path to the init file to inspect.
|
||||
|
||||
Returns:
|
||||
`Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends to list of
|
||||
imported objects, one for the `_import_structure` part of the init and one for the `TYPE_CHECKING` part of the
|
||||
init. Returns `None` if the init is not a custom init.
|
||||
"""
|
||||
with open(init_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Get the to `_import_structure` definition.
|
||||
line_index = 0
|
||||
while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
|
||||
line_index += 1
|
||||
|
||||
# If this is a traditional init, just return.
|
||||
if line_index >= len(lines):
|
||||
return None
|
||||
|
||||
# First grab the objects without a specific backend in _import_structure
|
||||
objects = []
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
|
||||
line = lines[line_index]
|
||||
# If we have everything on a single line, let's deal with it.
|
||||
if _re_one_line_import_struct.search(line):
|
||||
content = _re_one_line_import_struct.search(line).groups()[0]
|
||||
imports = re.findall(r"\[([^\]]+)\]", content)
|
||||
for imp in imports:
|
||||
objects.extend([obj[1:-1] for obj in imp.split(", ")])
|
||||
line_index += 1
|
||||
continue
|
||||
single_line_import_search = _re_import_struct_key_value.search(line)
|
||||
if single_line_import_search is not None:
|
||||
imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
|
||||
objects.extend(imports)
|
||||
elif line.startswith(" " * 8 + '"'):
|
||||
objects.append(line[9:-3])
|
||||
line_index += 1
|
||||
|
||||
# Those are stored with the key "none".
|
||||
import_dict_objects = {"none": objects}
|
||||
|
||||
# Let's continue with backend-specific objects in _import_structure
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
# If the line is an if not is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
# Check if the backend declaration is inside a try block:
|
||||
if _re_try.search(lines[line_index - 1]) is None:
|
||||
backend = None
|
||||
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Scroll until we hit the else block of try-except-else
|
||||
while _re_else.search(lines[line_index]) is None:
|
||||
line_index += 1
|
||||
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
|
||||
line = lines[line_index]
|
||||
if _re_import_struct_add_one.search(line) is not None:
|
||||
objects.append(_re_import_struct_add_one.search(line).groups()[0])
|
||||
elif _re_import_struct_add_many.search(line) is not None:
|
||||
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
|
||||
imports = [obj[1:-1] for obj in imports if len(obj) > 0]
|
||||
objects.extend(imports)
|
||||
elif _re_between_brackets.search(line) is not None:
|
||||
imports = _re_between_brackets.search(line).groups()[0].split(", ")
|
||||
imports = [obj[1:-1] for obj in imports if len(obj) > 0]
|
||||
objects.extend(imports)
|
||||
elif _re_quote_object.search(line) is not None:
|
||||
objects.append(_re_quote_object.search(line).groups()[0])
|
||||
elif line.startswith(" " * 8 + '"'):
|
||||
objects.append(line[9:-3])
|
||||
elif line.startswith(" " * 12 + '"'):
|
||||
objects.append(line[13:-3])
|
||||
line_index += 1
|
||||
|
||||
import_dict_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
# At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
|
||||
objects = []
|
||||
while (
|
||||
line_index < len(lines)
|
||||
and find_backend(lines[line_index]) is None
|
||||
and not lines[line_index].startswith("else")
|
||||
):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 8):
|
||||
objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
type_hint_objects = {"none": objects}
|
||||
|
||||
# Let's continue with backend-specific objects
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backend_available, we grab all objects associated.
|
||||
backend = find_backend(lines[line_index])
|
||||
# Check if the backend declaration is inside a try block:
|
||||
if _re_try.search(lines[line_index - 1]) is None:
|
||||
backend = None
|
||||
|
||||
if backend is not None:
|
||||
line_index += 1
|
||||
|
||||
# Scroll until we hit the else block of try-except-else
|
||||
while _re_else.search(lines[line_index]) is None:
|
||||
line_index += 1
|
||||
|
||||
line_index += 1
|
||||
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 12):
|
||||
objects.append(line[12:-2])
|
||||
line_index += 1
|
||||
|
||||
type_hint_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
return import_dict_objects, type_hint_objects
|
||||
|
||||
|
||||
def analyze_results(import_dict_objects: dict[str, list[str]], type_hint_objects: dict[str, list[str]]) -> list[str]:
|
||||
"""
|
||||
Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
|
||||
|
||||
Args:
|
||||
import_dict_objects (`Dict[str, List[str]]`):
|
||||
A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
|
||||
list of imported objects.
|
||||
type_hint_objects (`Dict[str, List[str]]`):
|
||||
A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
|
||||
list of imported objects.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of errors corresponding to mismatches.
|
||||
"""
|
||||
|
||||
def find_duplicates(seq):
|
||||
return [k for k, v in collections.Counter(seq).items() if v > 1]
|
||||
|
||||
# If one backend is missing from the other part of the init, error early.
|
||||
if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
|
||||
return ["Both sides of the init do not have the same backends!"]
|
||||
|
||||
errors = []
|
||||
# Find all errors.
|
||||
for key in import_dict_objects:
|
||||
# Duplicate imports in any half.
|
||||
duplicate_imports = find_duplicates(import_dict_objects[key])
|
||||
if duplicate_imports:
|
||||
errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
|
||||
duplicate_type_hints = find_duplicates(type_hint_objects[key])
|
||||
if duplicate_type_hints:
|
||||
errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")
|
||||
|
||||
# Missing imports in either part of the init.
|
||||
if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])):
|
||||
name = "base imports" if key == "none" else f"{key} backend"
|
||||
errors.append(f"Differences for {name}:")
|
||||
for a in type_hint_objects[key]:
|
||||
if a not in import_dict_objects[key]:
|
||||
errors.append(f" {a} in TYPE_HINT but not in _import_structure.")
|
||||
for a in import_dict_objects[key]:
|
||||
if a not in type_hint_objects[key]:
|
||||
errors.append(f" {a} in _import_structure but not in TYPE_HINT.")
|
||||
return errors
|
||||
|
||||
|
||||
def get_transformers_submodules() -> list[str]:
|
||||
"""
|
||||
Returns the list of Transformers submodules.
|
||||
"""
|
||||
submodules = []
|
||||
for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||
for folder in directories:
|
||||
# Ignore private modules
|
||||
if folder.startswith("_"):
|
||||
directories.remove(folder)
|
||||
continue
|
||||
# Ignore leftovers from branches (empty folders apart from pycache)
|
||||
if len(list((Path(path) / folder).glob("*.py"))) == 0:
|
||||
continue
|
||||
short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
|
||||
submodule = short_path.replace(os.path.sep, ".")
|
||||
submodules.append(submodule)
|
||||
for fname in files:
|
||||
if fname == "__init__.py":
|
||||
continue
|
||||
short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
|
||||
submodule = short_path.replace(".py", "").replace(os.path.sep, ".")
|
||||
if len(submodule.split(".")) == 1:
|
||||
submodules.append(submodule)
|
||||
return submodules
|
||||
|
||||
|
||||
IGNORE_SUBMODULES = [
|
||||
"convert_pytorch_checkpoint_to_tf2",
|
||||
"models.esm.openfold_utils",
|
||||
"safetensors_conversion",
|
||||
"modeling_gguf_pytorch_utils",
|
||||
"kernels.falcon_mamba",
|
||||
"kernels",
|
||||
]
|
||||
|
||||
|
||||
def check_submodules():
|
||||
"""
|
||||
Check all submodules of Transformers are properly registered in the main init. Error otherwise.
|
||||
"""
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
|
||||
import_structure_keys = set(transformers._import_structure.keys())
|
||||
# This contains all the base keys of the _import_structure object defined in the init, but if the user is missing
|
||||
# some optional dependencies, they may not have all of them. Thus we read the init to read all additions and
|
||||
# (potentiall re-) add them.
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r") as f:
|
||||
init_content = f.read()
|
||||
import_structure_keys.update(set(re.findall(r"import_structure\[\"([^\"]*)\"\]", init_content)))
|
||||
|
||||
module_not_registered = [
|
||||
module
|
||||
for module in get_transformers_submodules()
|
||||
if module not in IGNORE_SUBMODULES and module not in import_structure_keys
|
||||
]
|
||||
|
||||
if len(module_not_registered) > 0:
|
||||
list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
|
||||
raise ValueError(
|
||||
"The following submodules are not properly registered in the main init of Transformers:\n"
|
||||
f"{list_of_modules}\n"
|
||||
"Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This entire files needs an overhaul
|
||||
pass
|
||||
58
utils/check_model_tester.py
Normal file
58
utils/check_model_tester.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
from get_test_info import get_tester_classes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
failures = []
|
||||
|
||||
pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
|
||||
test_files = glob.glob(pattern)
|
||||
|
||||
for test_file in test_files:
|
||||
tester_classes = get_tester_classes(test_file)
|
||||
for tester_class in tester_classes:
|
||||
# A few tester classes don't have `parent` parameter in `__init__`.
|
||||
# TODO: deal this better
|
||||
try:
|
||||
tester = tester_class(parent=None)
|
||||
except Exception:
|
||||
continue
|
||||
if hasattr(tester, "get_config"):
|
||||
config = tester.get_config()
|
||||
for k, v in config.to_dict().items():
|
||||
if isinstance(v, int):
|
||||
target = None
|
||||
if k == "vocab_size":
|
||||
target = 100
|
||||
elif k == "max_position_embeddings":
|
||||
target = 128
|
||||
elif k in ["hidden_size", "d_model"]:
|
||||
target = 40
|
||||
elif k == ["num_layers", "num_hidden_layers", "num_encoder_layers", "num_decoder_layers"]:
|
||||
target = 5
|
||||
if target is not None and v > target:
|
||||
failures.append(
|
||||
f"{tester_class.__name__} will produce a `config` of type `{config.__class__.__name__}`"
|
||||
f' with config["{k}"] = {v} which is too large for testing! Set its value to be smaller'
|
||||
f" than {target}."
|
||||
)
|
||||
|
||||
if len(failures) > 0:
|
||||
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
124
utils/check_modeling_rules_doc.py
Normal file
124
utils/check_modeling_rules_doc.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Keep `## Rules reference` section of docs/source/en/modeling_rules.md in sync
|
||||
with the rules defined in utils/rules.toml via the installed mlinter package.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check everything is up to date (used in ``make check-repo``):
|
||||
|
||||
```bash
|
||||
python utils/check_modeling_rules_doc.py
|
||||
```
|
||||
|
||||
Auto-regenerate if out of date (used in ``make fix-repo``):
|
||||
|
||||
```bash
|
||||
python utils/check_modeling_rules_doc.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "modeling_rules_doc",
|
||||
"label": "Modeling rules documentation",
|
||||
# Depends on utils/rules.toml plus the installed `mlinter` package output,
|
||||
# which cannot be fully expressed as repo cache globs for the checker cache.
|
||||
"cache_globs": None,
|
||||
"check_args": ["--rules-toml", "utils/rules.toml"],
|
||||
"fix_args": ["--rules-toml", "utils/rules.toml", "--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
DOC_PATH = ROOT / "docs" / "source" / "en" / "modeling_rules.md"
|
||||
RULES_TOML_PATH = ROOT / "utils" / "rules.toml"
|
||||
|
||||
BEGIN_MARKER = "<!-- BEGIN RULES REFERENCE -->"
|
||||
END_MARKER = "<!-- END RULES REFERENCE -->"
|
||||
|
||||
|
||||
def _require_mlinter():
|
||||
try:
|
||||
import mlinter
|
||||
from mlinter import mlinter as mlinter_impl
|
||||
except ModuleNotFoundError as error:
|
||||
raise ModuleNotFoundError(
|
||||
"This script requires the standalone `transformers-mlinter` package. "
|
||||
'Install the repo quality dependencies with `pip install -e ".[quality]"` and retry.'
|
||||
) from error
|
||||
|
||||
return mlinter, mlinter_impl
|
||||
|
||||
|
||||
def _resolve_path(path: Path) -> Path:
|
||||
return path if path.is_absolute() else ROOT / path
|
||||
|
||||
|
||||
def generate_rules_reference(rule_specs_path: Path = RULES_TOML_PATH) -> str:
|
||||
mlinter, mlinter_impl = _require_mlinter()
|
||||
# Reuse mlinter's registry-switching helper so docs rendering reflects the repo-local rule file.
|
||||
with mlinter_impl._using_rule_specs(_resolve_path(rule_specs_path)):
|
||||
return mlinter.render_rules_reference()
|
||||
|
||||
|
||||
def check_modeling_rules_doc(overwrite: bool = False, rule_specs_path: Path = RULES_TOML_PATH):
|
||||
with DOC_PATH.open(encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
begin_idx = content.find(BEGIN_MARKER)
|
||||
end_idx = content.find(END_MARKER)
|
||||
if begin_idx == -1 or end_idx == -1:
|
||||
raise ValueError(
|
||||
f"Could not find {BEGIN_MARKER} and {END_MARKER} markers in {DOC_PATH}. "
|
||||
"These markers delimit the auto-generated rules reference section."
|
||||
)
|
||||
|
||||
after_begin = begin_idx + len(BEGIN_MARKER)
|
||||
expected = "\n\n" + generate_rules_reference(rule_specs_path) + "\n"
|
||||
current = content[after_begin:end_idx]
|
||||
|
||||
if current == expected:
|
||||
return
|
||||
|
||||
if overwrite:
|
||||
new_content = content[:after_begin] + expected + content[end_idx:]
|
||||
with DOC_PATH.open("w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
print(f"Updated rules reference in {DOC_PATH}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"The rules reference section in docs/source/en/modeling_rules.md is out of sync "
|
||||
"with utils/rules.toml. Run `make fix-repo` to regenerate it."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--rules-toml",
|
||||
type=Path,
|
||||
default=RULES_TOML_PATH,
|
||||
help="Path to a rules TOML file. Defaults to utils/rules.toml.",
|
||||
)
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
check_modeling_rules_doc(args.fix_and_overwrite, args.rules_toml)
|
||||
except ModuleNotFoundError as error:
|
||||
raise SystemExit(str(error)) from error
|
||||
60
utils/check_modeling_structure.py
Normal file
60
utils/check_modeling_structure.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Thin local entrypoint for the external mlinter package."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "modeling_structure",
|
||||
"label": "Modeling file structure",
|
||||
"cache_globs": [
|
||||
"src/transformers/models/**/modeling_*.py",
|
||||
"src/transformers/models/**/modular_*.py",
|
||||
"src/transformers/models/**/configuration_*.py",
|
||||
],
|
||||
"check_args": ["--rules-toml", "utils/rules.toml"],
|
||||
"fix_args": None,
|
||||
}
|
||||
|
||||
RULES_TOML_PATH = Path(__file__).resolve().with_name("rules.toml")
|
||||
|
||||
|
||||
def _require_mlinter():
|
||||
try:
|
||||
import mlinter
|
||||
except ModuleNotFoundError as error:
|
||||
raise ModuleNotFoundError(
|
||||
"This script requires the standalone `transformers-mlinter` package. "
|
||||
'Install the repo quality dependencies with `pip install -e ".[quality]"` and retry.'
|
||||
) from error
|
||||
|
||||
return mlinter
|
||||
|
||||
|
||||
def _add_default_rules_toml(argv: list[str]) -> list[str]:
|
||||
if any(arg == "--rules-toml" or arg.startswith("--rules-toml=") for arg in argv[1:]):
|
||||
return argv
|
||||
|
||||
return [argv[0], "--rules-toml", str(RULES_TOML_PATH), *argv[1:]]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
sys.argv = _add_default_rules_toml(sys.argv)
|
||||
raise SystemExit(_require_mlinter().main())
|
||||
except ModuleNotFoundError as error:
|
||||
raise SystemExit(str(error)) from error
|
||||
284
utils/check_modular_conversion.py
Normal file
284
utils/check_modular_conversion.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import argparse
|
||||
import difflib
|
||||
import glob
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from functools import partial
|
||||
|
||||
from create_dependency_mapping import find_priority_list
|
||||
|
||||
# Console for rich printing
|
||||
from modular_model_converter import convert_modular_file, run_ruff
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "modular_conversion",
|
||||
"label": "Modular file conversions",
|
||||
# Globs the modular sources; also reads generated modeling_*.py at runtime for diffing.
|
||||
"cache_globs": ["src/transformers/models/**/modular_*.py", "src/transformers/models/**/modeling_*.py"],
|
||||
"check_args": [],
|
||||
"fix_args": ["--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
console = Console()
|
||||
|
||||
BACKUP_EXT = ".modular_backup"
|
||||
|
||||
|
||||
def process_file(
|
||||
modular_file_path,
|
||||
generated_modeling_content,
|
||||
file_type="modeling_",
|
||||
show_diff=True,
|
||||
):
|
||||
file_name_prefix = file_type.split(".*")[0]
|
||||
file_name_suffix = file_type.split(".*")[-1] if ".*" in file_type else ""
|
||||
file_path = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(".py", f"{file_name_suffix}.py")
|
||||
# Read the actual modeling file
|
||||
with open(file_path, "r", encoding="utf-8") as modeling_file:
|
||||
content = modeling_file.read()
|
||||
diff = difflib.unified_diff(
|
||||
generated_modeling_content[file_type].splitlines(),
|
||||
content.splitlines(),
|
||||
fromfile=f"{file_path}_generated",
|
||||
tofile=f"{file_path}",
|
||||
lineterm="",
|
||||
)
|
||||
diff_list = list(diff)
|
||||
# Check for differences
|
||||
if diff_list:
|
||||
# first save the copy of the original file, to be able to restore it later
|
||||
shutil.copy(file_path, file_path + BACKUP_EXT)
|
||||
# we always save the generated content, to be able to update dependant files
|
||||
with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
|
||||
modeling_file.write(generated_modeling_content[file_type])
|
||||
if not show_diff:
|
||||
console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
|
||||
if show_diff:
|
||||
console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
|
||||
diff_text = "\n".join(diff_list)
|
||||
syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
|
||||
console.print(syntax)
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def convert_and_run_ruff(modular_file_path: str) -> dict[str, str]:
|
||||
"""From a modular file, convert it and return all the contents of the file as string.
|
||||
We need this function, because `ruff` needs the final filename to apply all rules correctly, so to get the
|
||||
output as a string, we need to save a temporary file with similar name, run ruff, and re-read the temporary file"""
|
||||
# Generate the expected modeling content
|
||||
generated_modeling_content = convert_modular_file(modular_file_path)
|
||||
# Temporary save the files with similar names to run `ruff` correctly, then re-read the result after linting/formatting
|
||||
for file_type in generated_modeling_content:
|
||||
file_name_prefix = file_type.split(".*")[0]
|
||||
file_name_suffix = file_type.split(".*")[-1] if ".*" in file_type else ""
|
||||
temp_file_name = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(
|
||||
".py", f"_temp_pattern__{file_name_suffix}.py"
|
||||
)
|
||||
# Write the file only temporarily
|
||||
with open(temp_file_name, "w") as f:
|
||||
f.write(generated_modeling_content[file_type])
|
||||
# Run ruff on the new file (with similar name pattern as the original one)
|
||||
run_ruff(temp_file_name)
|
||||
with open(temp_file_name, "r") as f:
|
||||
generated_modeling_content[file_type] = f.read()
|
||||
# delete file
|
||||
os.remove(temp_file_name)
|
||||
|
||||
return generated_modeling_content
|
||||
|
||||
|
||||
def compare_files(modular_file_path, show_diff=True):
|
||||
# Generate the expected modeling content
|
||||
generated_modeling_content = convert_and_run_ruff(modular_file_path)
|
||||
diff = 0
|
||||
for file_type in generated_modeling_content:
|
||||
diff += process_file(modular_file_path, generated_modeling_content, file_type, show_diff)
|
||||
return diff
|
||||
|
||||
|
||||
# Changes to any of these files can alter the generated output for every modular model,
|
||||
# so touching them must force a full re-check (see `converter_changed_in_diff`).
|
||||
CONVERTER_FILES = {
|
||||
"utils/modular_model_converter.py",
|
||||
"utils/create_dependency_mapping.py",
|
||||
}
|
||||
|
||||
|
||||
def _get_modified_files():
|
||||
fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
|
||||
return (
|
||||
subprocess.check_output(f"git diff --diff-filter=d --name-only {fork_point_sha}".split())
|
||||
.decode("utf-8")
|
||||
.split()
|
||||
)
|
||||
|
||||
|
||||
def get_models_in_diff():
|
||||
"""
|
||||
Finds all models that have been modified in the diff.
|
||||
|
||||
Returns:
|
||||
A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).
|
||||
"""
|
||||
modified_files = _get_modified_files()
|
||||
|
||||
# Matches both modelling files and tests
|
||||
relevant_modified_files = [x for x in modified_files if "/models/" in x and x.endswith(".py")]
|
||||
model_names = set()
|
||||
for file_path in relevant_modified_files:
|
||||
model_name = file_path.split("/")[-2]
|
||||
model_names.add(model_name)
|
||||
return model_names
|
||||
|
||||
|
||||
def converter_changed_in_diff():
|
||||
"""Whether the diff touches a file that can change conversion output for every model."""
|
||||
return any(f in CONVERTER_FILES for f in _get_modified_files())
|
||||
|
||||
|
||||
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
|
||||
"""
|
||||
Returns whether it is guaranteed to have no differences between the modular file and the modeling file.
|
||||
|
||||
Model is in the diff -> not guaranteed to have no differences
|
||||
Dependency is in the diff -> not guaranteed to have no differences
|
||||
Otherwise -> guaranteed to have no differences
|
||||
|
||||
Args:
|
||||
modular_file_path: The path to the modular file.
|
||||
dependencies: A dictionary containing the dependencies of each modular file.
|
||||
models_in_diff: A set containing the names of the models that have been modified.
|
||||
|
||||
Returns:
|
||||
A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
|
||||
"""
|
||||
model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
|
||||
if model_name in models_in_diff:
|
||||
return False
|
||||
for dep in dependencies[modular_file_path]:
|
||||
# two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
|
||||
dependency_model_name = dep.split(".")[-2]
|
||||
if dependency_model_name in models_in_diff:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
|
||||
parser.add_argument(
|
||||
"--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
|
||||
)
|
||||
parser.add_argument("--check_all", action="store_true", help="Check all files, not just the ones in the diff.")
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="The number of workers to run. Default is -1, which means the number of CPU cores.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.files == ["all"]:
|
||||
args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
|
||||
|
||||
if args.num_workers == -1:
|
||||
args.num_workers = multiprocessing.cpu_count()
|
||||
|
||||
# Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
|
||||
# are not in the diff, then there it is guaranteed to have no differences. If no models are in the diff, then this
|
||||
# script will do nothing.
|
||||
current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
|
||||
if current_branch == "main":
|
||||
console.print(
|
||||
"[bold red]You are developing on the main branch. We cannot identify the list of changed files and will have to check all files. This may take a while.[/bold red]"
|
||||
)
|
||||
models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
|
||||
elif converter_changed_in_diff():
|
||||
# The converter (or its dependency-mapping helper) is in the diff: its output can shift
|
||||
# for any model, so restrict-by-diff would miss regressions. Force a full check.
|
||||
console.print("[bold yellow]Converter change detected in diff; checking all modular files.[/bold yellow]")
|
||||
args.check_all = True
|
||||
models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
|
||||
else:
|
||||
models_in_diff = get_models_in_diff()
|
||||
if not models_in_diff and not args.check_all:
|
||||
exit(0)
|
||||
|
||||
non_matching_files = []
|
||||
ordered_files, dependencies = find_priority_list(args.files)
|
||||
flat_ordered_files = [item for sublist in ordered_files for item in sublist]
|
||||
|
||||
# ordered_files is a *sorted* list of lists of filepaths
|
||||
# - files from the first list do NOT depend on other files
|
||||
# - files in the second list depend on files from the first list
|
||||
# - files in the third list depend on files from the second and (optionally) the first list
|
||||
# - ... and so on
|
||||
# files (models) within the same list are *independent* of each other;
|
||||
# we start applying modular conversion to each list in parallel, starting from the first list
|
||||
try:
|
||||
for dependency_level_files in ordered_files:
|
||||
# Filter files guaranteed no diff
|
||||
files_to_check = []
|
||||
for file_path in dependency_level_files:
|
||||
if args.check_all or not guaranteed_no_diff(file_path, dependencies, models_in_diff):
|
||||
files_to_check.append(file_path)
|
||||
|
||||
if not files_to_check:
|
||||
continue
|
||||
|
||||
# Process files with diff
|
||||
num_workers = min(args.num_workers, len(files_to_check))
|
||||
with multiprocessing.Pool(num_workers) as p:
|
||||
try:
|
||||
is_changed_flags = p.map(
|
||||
partial(compare_files, show_diff=not args.fix_and_overwrite),
|
||||
files_to_check,
|
||||
)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[bold red]Failed to convert one or more files in batch: {files_to_check}[/bold red]"
|
||||
)
|
||||
console.print(f"[bold red]Error: {e}[/bold red]")
|
||||
# Try to process files individually to identify which one failed
|
||||
is_changed_flags = []
|
||||
for file_path in files_to_check:
|
||||
try:
|
||||
result = compare_files(file_path, show_diff=not args.fix_and_overwrite)
|
||||
is_changed_flags.append(result)
|
||||
except Exception as individual_error:
|
||||
console.print(f"[bold red]Failed to convert {file_path}: {individual_error}[/bold red]")
|
||||
is_changed_flags.append(0) # Mark as no change to continue processing
|
||||
|
||||
# Collect changed files and their original paths
|
||||
for is_changed, file_path in zip(is_changed_flags, files_to_check):
|
||||
if is_changed:
|
||||
non_matching_files.append(file_path)
|
||||
|
||||
# Update changed models, after each round of conversions
|
||||
# (save model folder name)
|
||||
models_in_diff.add(file_path.split("/")[-2])
|
||||
|
||||
finally:
|
||||
# Restore overwritten files by modular (if needed)
|
||||
backup_files = glob.glob("**/*" + BACKUP_EXT, recursive=True)
|
||||
for backup_file_path in backup_files:
|
||||
overwritten_path = backup_file_path.replace(BACKUP_EXT, "")
|
||||
if not args.fix_and_overwrite and os.path.exists(overwritten_path):
|
||||
shutil.copy(backup_file_path, overwritten_path)
|
||||
os.remove(backup_file_path)
|
||||
|
||||
if non_matching_files and not args.fix_and_overwrite:
|
||||
diff_models = set(file_path.split("/")[-2] for file_path in non_matching_files) # noqa
|
||||
models_str = "\n - " + "\n - ".join(sorted(diff_models))
|
||||
raise ValueError(f"Some diff and their modeling code did not match. Models in diff:{models_str}")
|
||||
101
utils/check_pipeline_typing.py
Normal file
101
utils/check_pipeline_typing.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import re
|
||||
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "pipeline_typing",
|
||||
"label": "Pipeline type hints",
|
||||
"cache_globs": ["src/transformers/pipelines/__init__.py"],
|
||||
"check_args": [],
|
||||
"fix_args": ["--fix_and_overwrite"],
|
||||
}
|
||||
|
||||
HEADER = """
|
||||
# fmt: off
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# The part of the file below was automatically generated from the code.
|
||||
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
|
||||
# of the file. If any change should be done, please apply the changes to the `pipeline` function
|
||||
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from typing import Literal, overload
|
||||
|
||||
|
||||
"""
|
||||
|
||||
FOOTER = """
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# The part of the file above was automatically generated from the code.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# fmt: on
|
||||
"""
|
||||
|
||||
TASK_PATTERN = "task: str | None = None"
|
||||
|
||||
|
||||
def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
|
||||
with open(pipeline_file_path, "r") as file:
|
||||
content = file.read()
|
||||
|
||||
# extract generated code in between <generated-code> and </generated-code>
|
||||
current_generated_code = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL).group(1)
|
||||
content_without_generated_code = content.replace(current_generated_code, "")
|
||||
|
||||
# extract pipeline signature in between `def pipeline` and `-> Pipeline`
|
||||
pipeline_signature = re.search(r"def pipeline(.*) -> Pipeline:", content_without_generated_code, re.DOTALL).group(
|
||||
1
|
||||
)
|
||||
pipeline_signature = pipeline_signature.replace("(\n ", "(") # start of the signature
|
||||
pipeline_signature = pipeline_signature.replace(",\n ", ", ") # intermediate arguments
|
||||
pipeline_signature = pipeline_signature.replace(",\n)", ")") # end of the signature
|
||||
|
||||
# collect and sort available pipelines
|
||||
pipelines = [(f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()]
|
||||
pipelines = sorted(pipelines, key=lambda x: x[0])
|
||||
pipelines.insert(0, (None, Pipeline))
|
||||
|
||||
# generate new `pipeline` signatures
|
||||
new_generated_code = ""
|
||||
for task, pipeline_class in pipelines:
|
||||
if TASK_PATTERN not in pipeline_signature:
|
||||
raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")
|
||||
pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
|
||||
new_pipeline_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
|
||||
new_generated_code += f"@overload\ndef pipeline{new_pipeline_signature} -> {pipeline_type}: ...\n"
|
||||
|
||||
new_generated_code = HEADER + new_generated_code + FOOTER
|
||||
new_generated_code = new_generated_code.rstrip("\n") + "\n"
|
||||
|
||||
if new_generated_code != current_generated_code and fix_and_overwrite:
|
||||
print(f"Updating {pipeline_file_path}...")
|
||||
wrapped_current_generated_code = "# <generated-code>" + current_generated_code + "# </generated-code>"
|
||||
wrapped_new_generated_code = "# <generated-code>" + new_generated_code + "# </generated-code>"
|
||||
content = content.replace(wrapped_current_generated_code, wrapped_new_generated_code)
|
||||
|
||||
# write content to file
|
||||
with open(pipeline_file_path, "w") as file:
|
||||
file.write(content)
|
||||
|
||||
elif new_generated_code != current_generated_code and not fix_and_overwrite:
|
||||
message = (
|
||||
f"Found inconsistencies in {pipeline_file_path}. "
|
||||
"Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
parser.add_argument(
|
||||
"--pipeline_file_path",
|
||||
type=str,
|
||||
default="src/transformers/pipelines/__init__.py",
|
||||
help="Path to the pipeline file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.pipeline_file_path, args.fix_and_overwrite)
|
||||
1507
utils/check_repo.py
Normal file
1507
utils/check_repo.py
Normal file
File diff suppressed because it is too large
Load Diff
57
utils/check_self_hosted_runner.py
Normal file
57
utils/check_self_hosted_runner.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
|
||||
|
||||
def get_runner_status(target_runners, token):
|
||||
offline_runners = []
|
||||
|
||||
cmd = [
|
||||
"curl",
|
||||
"-H",
|
||||
"Accept: application/vnd.github+json",
|
||||
"-H",
|
||||
f"Authorization: Bearer {token}",
|
||||
"https://api.github.com/repos/huggingface/transformers/actions/runners",
|
||||
]
|
||||
|
||||
output = subprocess.run(cmd, check=False, shell=True, stdout=subprocess.PIPE)
|
||||
o = output.stdout.decode("utf-8")
|
||||
status = json.loads(o)
|
||||
|
||||
runners = status["runners"]
|
||||
for runner in runners:
|
||||
if runner["name"] in target_runners:
|
||||
if runner["status"] == "offline":
|
||||
offline_runners.append(runner)
|
||||
|
||||
# save the result so we can report them on Slack
|
||||
with open("offline_runners.txt", "w") as fp:
|
||||
fp.write(json.dumps(offline_runners))
|
||||
|
||||
if len(offline_runners) > 0:
|
||||
failed = "\n".join([x["name"] for x in offline_runners])
|
||||
raise ValueError(f"The following runners are offline:\n{failed}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def list_str(values):
|
||||
return values.split(",")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--target_runners",
|
||||
default=None,
|
||||
type=list_str,
|
||||
required=True,
|
||||
help="Comma-separated list of runners to check status.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token", default=None, type=str, required=True, help="A token that has actions:read permission."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
get_runner_status(args.target_runners, args.token)
|
||||
71
utils/check_types.py
Normal file
71
utils/check_types.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Run ty type checking on specified directories.
|
||||
|
||||
Usage:
|
||||
python utils/check_types.py src/transformers/utils src/transformers/generation
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "types",
|
||||
"label": "Type annotations",
|
||||
# For contributors:
|
||||
# - `check_args` below are the exact roots passed to `ty check`.
|
||||
# - `cache_globs` here are only used by `utils/checkers.py` to decide when a
|
||||
# previously clean `types` run can be reused from cache.
|
||||
# ty follows imports *beyond* the checked roots, so the cache key must cover every source file
|
||||
# that could change a result -- not just the explicitly-checked paths. We hash the whole package
|
||||
# (plus the standalone .circleci target) so any source edit busts the cache and forces a
|
||||
# re-check. Otherwise a cached pass could silently hide a newly-introduced error in a
|
||||
# transitively-imported module that ty pulls in but that isn't one of the checked roots.
|
||||
"cache_globs": [
|
||||
"src/transformers/**/*.py",
|
||||
".circleci/create_circleci_config.py",
|
||||
],
|
||||
"check_args": [
|
||||
"src/transformers/_typing.py",
|
||||
"src/transformers/cli",
|
||||
"src/transformers/modeling_utils.py",
|
||||
"src/transformers/utils",
|
||||
"src/transformers/generation",
|
||||
"src/transformers/pipelines/__init__.py",
|
||||
"src/transformers/pipelines/feature_extraction.py",
|
||||
"src/transformers/pipelines/image_feature_extraction.py",
|
||||
"src/transformers/pipelines/video_classification.py",
|
||||
"src/transformers/quantizers",
|
||||
".circleci/create_circleci_config.py",
|
||||
"src/transformers/dependency_versions_table.py",
|
||||
"src/transformers/dependency_versions_check.py",
|
||||
"src/transformers/conversion_mapping.py",
|
||||
"src/transformers/time_series_utils.py",
|
||||
"src/transformers/debug_utils.py",
|
||||
"src/transformers/hyperparameter_search.py",
|
||||
"src/transformers/pytorch_utils.py",
|
||||
"src/transformers/file_utils.py",
|
||||
"src/transformers/trainer_jit_checkpoint.py",
|
||||
"src/transformers/trainer_optimizer.py",
|
||||
],
|
||||
"fix_args": None,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print(f"Usage: {sys.argv[0]} <directory> [<directory> ...]")
|
||||
sys.exit(1)
|
||||
|
||||
directories = sys.argv[1:]
|
||||
print(f"Running ty check on: {', '.join(directories)}")
|
||||
# `--error-on-warning` makes ty exit non-zero on warning-level diagnostics (e.g.
|
||||
# possibly-missing-attribute), not just errors. Without it, warnings print but ty exits 0, so
|
||||
# `make typing` and CI both pass and the issue is never caught before commit.
|
||||
result = subprocess.run(
|
||||
["ty", "check", "--respect-ignore-files", "--error-on-warning", "--exclude", "**/*_pb*", *directories],
|
||||
)
|
||||
sys.exit(result.returncode)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
641
utils/checkers.py
Normal file
641
utils/checkers.py
Normal file
@@ -0,0 +1,641 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unified runner for check/fix scripts.
|
||||
|
||||
Usage:
|
||||
python utils/checkers.py copies,modular_conversion,doc_toc
|
||||
python utils/checkers.py copies,modular_conversion,doc_toc --fix
|
||||
python utils/checkers.py copies,doc_toc --keep-going
|
||||
python utils/checkers.py all
|
||||
python utils/checkers.py all --fix
|
||||
|
||||
Plugin system
|
||||
-------------
|
||||
Each checker module declares a ``CHECKER_CONFIG`` dict (extracted via ``ast.literal_eval``,
|
||||
no import needed — this keeps discovery fast and avoids executing checker code at scan time).
|
||||
See any ``check_*.py`` file for the schema.
|
||||
|
||||
Cache semantics of ``cache_globs``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
``cache_globs`` lists the file patterns whose content is hashed to decide whether a checker
|
||||
can be skipped. **Not all globs are exact reflections of the checker's runtime behaviour.**
|
||||
|
||||
* Some checkers introspect the live ``transformers`` module (``check_repo``,
|
||||
``check_config_docstrings``, ``check_config_attributes``, ``update_metadata``), so their
|
||||
globs are necessarily *approximations* of the true dependency set.
|
||||
* Some checkers over-approximate (``check_dummies``, ``check_doctest_list``): any change
|
||||
inside the broad glob forces a re-run even if the checker wouldn't look at that file.
|
||||
This is safe—just less cache-efficient.
|
||||
* Some checkers rely on external state (network, git history, installed packages) that
|
||||
cannot be captured by cache globs at all (``add_dates``, ``imports``).
|
||||
|
||||
Each ``CHECKER_CONFIG`` that is an approximation has an inline comment explaining the
|
||||
gap. For contributors: ``check_args`` control what a checker runs on, while
|
||||
``cache_globs`` only control when the cache is invalidated. When in doubt, use
|
||||
``--no-cache`` to force a full run.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import hashlib
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
UTILS_DIR = Path(__file__).parent
|
||||
REPO_ROOT = UTILS_DIR.parent
|
||||
CACHE_PATH = UTILS_DIR / ".checkers_cache.json"
|
||||
|
||||
# Required keys in each module's CHECKER_CONFIG dict.
|
||||
_CHECKER_CONFIG_KEYS = {"name", "label", "cache_globs", "check_args", "fix_args"}
|
||||
|
||||
|
||||
def _discover_checkers() -> tuple[dict, dict]:
|
||||
"""Scan utils/*.py for CHECKER_CONFIG dicts using AST (no imports).
|
||||
|
||||
Each checker module may define a top-level ``CHECKER_CONFIG`` dict with
|
||||
keys: name, label, cache_globs, check_args, fix_args.
|
||||
|
||||
Returns (checkers_dict, cache_globs_dict) matching the shapes of
|
||||
the old CHECKERS and CHECKER_CACHE_GLOBS registries.
|
||||
"""
|
||||
checkers = {}
|
||||
cache_globs = {}
|
||||
|
||||
for py_file in sorted(UTILS_DIR.glob("*.py")):
|
||||
if py_file.name == Path(__file__).name:
|
||||
continue
|
||||
|
||||
try:
|
||||
tree = ast.parse(py_file.read_text(encoding="utf-8"), filename=str(py_file))
|
||||
except SyntaxError:
|
||||
continue
|
||||
|
||||
config = None
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if (
|
||||
isinstance(node, ast.Assign)
|
||||
and len(node.targets) == 1
|
||||
and isinstance(node.targets[0], ast.Name)
|
||||
and node.targets[0].id == "CHECKER_CONFIG"
|
||||
):
|
||||
try:
|
||||
config = ast.literal_eval(node.value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
break
|
||||
|
||||
if config is None:
|
||||
continue
|
||||
|
||||
missing = _CHECKER_CONFIG_KEYS - set(config)
|
||||
if missing:
|
||||
warnings.warn(
|
||||
f"CHECKER_CONFIG in {py_file.name} is missing keys: {', '.join(sorted(missing))}. Skipping.",
|
||||
stacklevel=1,
|
||||
)
|
||||
continue
|
||||
|
||||
name = config["name"]
|
||||
if name in checkers:
|
||||
warnings.warn(
|
||||
f"Duplicate checker name {name!r} in {py_file.name}, already defined by {checkers[name][1]}",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
checkers[name] = (
|
||||
config["label"],
|
||||
py_file.name,
|
||||
config["check_args"],
|
||||
config["fix_args"],
|
||||
)
|
||||
if config["cache_globs"] is not None:
|
||||
cache_globs[name] = config["cache_globs"]
|
||||
|
||||
return checkers, cache_globs
|
||||
|
||||
|
||||
# Inline checkers have no separate script file; they use custom runner functions below.
|
||||
# fix_args=[] marks a checker as fix-capable (its custom runner handles --fix internally);
|
||||
# fix_args=None marks a check-only entry that `make fix-repo` should silently skip.
|
||||
_INLINE_CHECKERS = {
|
||||
"deps_table": ("Dependency versions table", None, None, []),
|
||||
"imports": ("Public imports", None, None, None),
|
||||
"import_complexity": ("Import complexity", "check_import_complexity.py", [], None),
|
||||
"ruff_check": ("Ruff linting", None, None, []),
|
||||
"ruff_format": ("Ruff formatting", None, None, []),
|
||||
}
|
||||
|
||||
_INLINE_CACHE_GLOBS = {
|
||||
# Also generates/checks src/transformers/dependency_versions_table.py.
|
||||
"deps_table": ["setup.py", "pyproject.toml", "src/transformers/dependency_versions_table.py"],
|
||||
# Approximate: runs `from transformers import *` at runtime; depends on the full
|
||||
# Python environment, not just these files. Broad globs used as a safe upper bound.
|
||||
"imports": ["src/transformers/**/__init__.py", "src/transformers/**/*.py"],
|
||||
# Approximate: ruff applies its own ignore rules from pyproject.toml at runtime.
|
||||
"ruff_check": [
|
||||
"examples/**/*.py",
|
||||
"tests/**/*.py",
|
||||
"src/**/*.py",
|
||||
"utils/**/*.py",
|
||||
"scripts/**/*.py",
|
||||
".circleci/create_circleci_config.py",
|
||||
"benchmark/**/*.py",
|
||||
"benchmark_v2/**/*.py",
|
||||
"setup.py",
|
||||
"conftest.py",
|
||||
],
|
||||
"ruff_format": [
|
||||
"examples/**/*.py",
|
||||
"tests/**/*.py",
|
||||
"src/**/*.py",
|
||||
"utils/**/*.py",
|
||||
"scripts/**/*.py",
|
||||
".circleci/create_circleci_config.py",
|
||||
"benchmark/**/*.py",
|
||||
"benchmark_v2/**/*.py",
|
||||
"setup.py",
|
||||
"conftest.py",
|
||||
],
|
||||
}
|
||||
|
||||
# Build the registries: discovered modules + inline custom runners.
|
||||
_discovered_checkers, _discovered_cache_globs = _discover_checkers()
|
||||
|
||||
CHECKERS = {**_discovered_checkers, **_INLINE_CHECKERS}
|
||||
CHECKER_CACHE_GLOBS = {**_discovered_cache_globs, **_INLINE_CACHE_GLOBS}
|
||||
|
||||
|
||||
def get_checker_cache_globs(checker_name: str) -> list[str] | None:
|
||||
"""Return the cache inputs for a checker, including its implementation files."""
|
||||
globs = CHECKER_CACHE_GLOBS.get(checker_name)
|
||||
if globs is None:
|
||||
return None
|
||||
|
||||
cache_globs = [*globs, str(Path("utils") / Path(__file__).name)]
|
||||
script = CHECKERS[checker_name][1]
|
||||
if script is not None:
|
||||
cache_globs.append(str(Path("utils") / script))
|
||||
return cache_globs
|
||||
|
||||
|
||||
class CheckerCache:
|
||||
"""Disk-backed cache that tracks file content hashes per checker.
|
||||
|
||||
For each checker that declares cache globs in CHECKER_CACHE_GLOBS, we compute
|
||||
a single digest over all matching files. If the digest matches the stored
|
||||
value from the last clean (rc == 0) run, the checker can be skipped.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path | None = None):
|
||||
self._path = CACHE_PATH if path is None else path
|
||||
self._data = self._load()
|
||||
|
||||
def _load(self) -> dict:
|
||||
try:
|
||||
return json.loads(self._path.read_text(encoding="utf-8"))
|
||||
except (FileNotFoundError, json.JSONDecodeError, OSError):
|
||||
return {}
|
||||
|
||||
def save(self) -> None:
|
||||
try:
|
||||
self._path.write_text(json.dumps(self._data, sort_keys=True, indent=2) + "\n", encoding="utf-8")
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _digest_files(globs: list[str]) -> str:
|
||||
"""Compute a single SHA-256 over sorted file paths + contents."""
|
||||
h = hashlib.sha256()
|
||||
paths = set()
|
||||
for pattern in globs:
|
||||
paths.update(REPO_ROOT.glob(pattern))
|
||||
for p in sorted(paths):
|
||||
if p.is_file():
|
||||
h.update(str(p.relative_to(REPO_ROOT)).encode())
|
||||
h.update(p.read_bytes())
|
||||
return h.hexdigest()
|
||||
|
||||
def is_current(self, checker_name: str) -> bool:
|
||||
"""Return True if the checker's files haven't changed since last clean run."""
|
||||
globs = get_checker_cache_globs(checker_name)
|
||||
if globs is None:
|
||||
return False
|
||||
return self._data.get(checker_name) == self._digest_files(globs)
|
||||
|
||||
def update(self, checker_name: str) -> None:
|
||||
"""Record current digest for a checker (call after a clean run)."""
|
||||
globs = get_checker_cache_globs(checker_name)
|
||||
if globs is None:
|
||||
return
|
||||
self._data[checker_name] = self._digest_files(globs)
|
||||
|
||||
def invalidate(self, checker_name: str) -> None:
|
||||
"""Remove a checker from the cache (call after a failed run)."""
|
||||
self._data.pop(checker_name, None)
|
||||
|
||||
|
||||
def _file_md5(path):
|
||||
return hashlib.md5(path.read_bytes()).hexdigest()
|
||||
|
||||
|
||||
# ANSI helpers
|
||||
ORANGE = "\033[38;5;214m"
|
||||
GREEN = "\033[32m"
|
||||
RED = "\033[31m"
|
||||
RESET = "\033[0m"
|
||||
SPINNER_CHARS = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
|
||||
|
||||
|
||||
def format_elapsed(seconds: float) -> str:
|
||||
"""Format a duration for status output."""
|
||||
if seconds >= 60:
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
return f"{int(minutes)}m{seconds:05.2f}s"
|
||||
return f"{seconds:.2f}s"
|
||||
|
||||
|
||||
class SlidingWindow:
|
||||
"""Displays a spinning title + sliding window of the last N output lines in a TTY."""
|
||||
|
||||
def __init__(self, label, max_lines=10):
|
||||
self.label = label
|
||||
self.max_lines = max_lines
|
||||
self.lines = deque(maxlen=max_lines)
|
||||
self.displayed = 0 # number of output lines currently on screen
|
||||
self.term_width = shutil.get_terminal_size().columns
|
||||
self._spinner = itertools.cycle(SPINNER_CHARS)
|
||||
self._stop = threading.Event()
|
||||
self._lock = threading.Lock()
|
||||
# Print initial title line (will be overwritten by spinner)
|
||||
print(f"{ORANGE}{next(self._spinner)} {label}{RESET}")
|
||||
self._title_on_screen = True
|
||||
self._thread = threading.Thread(target=self._spin, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def _spin(self):
|
||||
while not self._stop.is_set():
|
||||
self._stop.wait(0.08)
|
||||
if self._stop.is_set():
|
||||
break
|
||||
with self._lock:
|
||||
self._redraw()
|
||||
|
||||
def _redraw(self):
|
||||
"""Clear output lines + title, redraw everything."""
|
||||
# Move up over output lines + title line
|
||||
for _ in range(self.displayed + (1 if self._title_on_screen else 0)):
|
||||
sys.stdout.write("\033[A\033[2K")
|
||||
self.displayed = 0
|
||||
# Redraw title with next spinner frame
|
||||
print(f"{ORANGE}{next(self._spinner)} {self.label}{RESET}")
|
||||
self._title_on_screen = True
|
||||
# Redraw output lines
|
||||
for line in self.lines:
|
||||
print(line)
|
||||
self.displayed = len(self.lines)
|
||||
sys.stdout.flush()
|
||||
|
||||
def add_line(self, line):
|
||||
with self._lock:
|
||||
self.lines.append(line.rstrip()[: self.term_width])
|
||||
self._redraw()
|
||||
|
||||
def finish(self, success, elapsed=None, show_lines=True):
|
||||
"""Stop spinner and print final status title."""
|
||||
self._stop.set()
|
||||
self._thread.join()
|
||||
with self._lock:
|
||||
# Clear output lines + title
|
||||
for _ in range(self.displayed + (1 if self._title_on_screen else 0)):
|
||||
sys.stdout.write("\033[A\033[2K")
|
||||
self._title_on_screen = False
|
||||
self.displayed = 0
|
||||
# Print final title with status
|
||||
suffix = f" ({format_elapsed(elapsed)})" if elapsed is not None else ""
|
||||
if success:
|
||||
print(f"{GREEN}✓ {self.label}{suffix}{RESET}")
|
||||
else:
|
||||
print(f"{RED}✗ {self.label}{suffix}{RESET}")
|
||||
# Reprint output lines when we want to preserve the tail summary.
|
||||
if show_lines:
|
||||
for line in self.lines:
|
||||
print(line)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _print_output(output: str) -> None:
|
||||
"""Print captured output without truncation."""
|
||||
if not output:
|
||||
return
|
||||
|
||||
print(output, end="" if output.endswith("\n") else "\n", flush=True)
|
||||
|
||||
|
||||
def _run_cmd(cmd, line_callback=None):
|
||||
"""Run a command, capturing output. Returns (returncode, output)."""
|
||||
if line_callback is None:
|
||||
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
return result.returncode, result.stdout.decode("utf-8", errors="replace")
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env)
|
||||
output_lines = []
|
||||
for raw_line in proc.stdout:
|
||||
line = raw_line.decode("utf-8", errors="replace")
|
||||
output_lines.append(line)
|
||||
line_callback(line)
|
||||
proc.wait()
|
||||
return proc.returncode, "".join(output_lines)
|
||||
|
||||
|
||||
def run_deps_table_checker(fix=False, line_callback=None):
|
||||
"""Check or fix the dependency versions table."""
|
||||
deps_table = REPO_ROOT / "src" / "transformers" / "dependency_versions_table.py"
|
||||
setup_py = REPO_ROOT / "setup.py"
|
||||
cmd = [sys.executable, str(setup_py), "deps_table_update"]
|
||||
|
||||
if fix:
|
||||
return _run_cmd(cmd, line_callback=line_callback)
|
||||
|
||||
before = _file_md5(deps_table)
|
||||
rc, output = _run_cmd(cmd, line_callback=line_callback)
|
||||
if rc != 0:
|
||||
return rc, output
|
||||
after = _file_md5(deps_table)
|
||||
if before != after:
|
||||
msg = (
|
||||
"Error: the version dependency table is outdated.\n"
|
||||
"Please run 'make fix-repo' and commit the changes. This requires Python 3.10.\n"
|
||||
)
|
||||
return 1, output + msg
|
||||
return 0, output
|
||||
|
||||
|
||||
def run_imports_checker(fix=False, line_callback=None):
|
||||
"""Check that all public imports work."""
|
||||
rc, output = _run_cmd([sys.executable, "-c", "from transformers import *"], line_callback=line_callback)
|
||||
if rc != 0:
|
||||
return rc, output + "Import failed, this means you introduced unprotected imports!\n"
|
||||
return 0, output
|
||||
|
||||
|
||||
RUFF_TARGETS = [
|
||||
"examples",
|
||||
"tests",
|
||||
"src",
|
||||
"utils",
|
||||
"scripts",
|
||||
".circleci/create_circleci_config.py",
|
||||
"benchmark",
|
||||
"benchmark_v2",
|
||||
"setup.py",
|
||||
"conftest.py",
|
||||
]
|
||||
|
||||
|
||||
def run_ruff_check(fix=False, line_callback=None):
|
||||
"""Run ruff linting."""
|
||||
cmd = ["ruff", "check", *RUFF_TARGETS]
|
||||
if fix:
|
||||
cmd += ["--fix", "--exclude", ""]
|
||||
return _run_cmd(cmd, line_callback=line_callback)
|
||||
|
||||
|
||||
def run_ruff_format(fix=False, line_callback=None):
|
||||
"""Run ruff formatting."""
|
||||
cmd = ["ruff", "format", *RUFF_TARGETS]
|
||||
if not fix:
|
||||
cmd += ["--check"]
|
||||
else:
|
||||
cmd += ["--exclude", ""]
|
||||
return _run_cmd(cmd, line_callback=line_callback)
|
||||
|
||||
|
||||
CUSTOM_RUNNERS = {
|
||||
"deps_table": run_deps_table_checker,
|
||||
"imports": run_imports_checker,
|
||||
"ruff_check": run_ruff_check,
|
||||
"ruff_format": run_ruff_format,
|
||||
}
|
||||
|
||||
|
||||
def get_checker_command(name, fix=False):
|
||||
"""Return a shell-friendly command string for a checker."""
|
||||
if name == "deps_table":
|
||||
return "python setup.py deps_table_update"
|
||||
if name == "imports":
|
||||
return 'python -c "from transformers import *"'
|
||||
if name == "ruff_check":
|
||||
cmd = ["ruff", "check", *RUFF_TARGETS]
|
||||
if fix:
|
||||
cmd += ["--fix", "--exclude", ""]
|
||||
return " ".join(cmd)
|
||||
if name == "ruff_format":
|
||||
cmd = ["ruff", "format", *RUFF_TARGETS]
|
||||
if not fix:
|
||||
cmd += ["--check"]
|
||||
else:
|
||||
cmd += ["--exclude", ""]
|
||||
return " ".join(cmd)
|
||||
|
||||
_, script, check_args, fix_args = CHECKERS[name]
|
||||
if fix and fix_args is None:
|
||||
return None
|
||||
args = fix_args if fix else check_args
|
||||
return " ".join(["python", f"utils/{script}"] + args)
|
||||
|
||||
|
||||
def run_checker(name, fix=False, line_callback=None):
|
||||
if name in CUSTOM_RUNNERS:
|
||||
return CUSTOM_RUNNERS[name](fix=fix, line_callback=line_callback)
|
||||
|
||||
_, script, check_args, fix_args = CHECKERS[name]
|
||||
script_path = UTILS_DIR / script
|
||||
|
||||
if fix and fix_args is None:
|
||||
return 0, "skipped (no fix mode)"
|
||||
|
||||
cmd = [sys.executable, str(script_path)]
|
||||
cmd += fix_args if fix else check_args
|
||||
|
||||
return _run_cmd(cmd, line_callback=line_callback)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run check/fix scripts.")
|
||||
parser.add_argument(
|
||||
"checkers",
|
||||
nargs="+",
|
||||
help='Comma-separated checker names, or "all". Use --list to see available checkers.',
|
||||
)
|
||||
parser.add_argument("--fix", action="store_true", help="Run in fix mode instead of check mode.")
|
||||
parser.add_argument(
|
||||
"--keep-going", action="store_true", help="Run all checkers even if some fail (report failures at the end)."
|
||||
)
|
||||
parser.add_argument("--list", action="store_true", help="List available checkers and exit.")
|
||||
parser.add_argument("--no-cache", action="store_true", help="Ignore the disk cache and re-run every checker.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list:
|
||||
for name, entry in sorted(CHECKERS.items()):
|
||||
label, script, _, fix_args = entry
|
||||
fixable = "fixable" if fix_args is not None else "check-only"
|
||||
script_display = script or "custom"
|
||||
print(f" {name:25s} {label:35s} ({script_display}, {fixable})")
|
||||
return
|
||||
|
||||
# Join all positional args (shell line continuations may split them) and parse checker names
|
||||
raw = " ".join(args.checkers)
|
||||
if raw.strip() == "all":
|
||||
names = list(CHECKERS.keys())
|
||||
else:
|
||||
names = [n.strip() for n in raw.split(",") if n.strip()]
|
||||
|
||||
unknown = [n for n in names if n not in CHECKERS]
|
||||
if unknown:
|
||||
print(f"Unknown checkers: {', '.join(unknown)}")
|
||||
print(f"Available: {', '.join(sorted(CHECKERS.keys()))}")
|
||||
sys.exit(1)
|
||||
|
||||
# In --fix mode, drop checkers that have no fix capability (fix_args is None) so
|
||||
# they don't print bogus "(0.00s)" lines or inflate the final pass count. Print
|
||||
# one transparency line listing what we're skipping.
|
||||
if args.fix:
|
||||
not_fixable = [n for n in names if CHECKERS[n][3] is None]
|
||||
if not_fixable:
|
||||
names = [n for n in names if CHECKERS[n][3] is not None]
|
||||
print(
|
||||
f"Skipping {len(not_fixable)} check-only checker(s) in fix mode: {', '.join(not_fixable)}\n",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
is_ci = os.environ.get("GITHUB_ACTIONS") == "true" or os.environ.get("CIRCLECI") == "true"
|
||||
is_tty = sys.stdout.isatty() and not is_ci
|
||||
|
||||
if not is_tty and hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(line_buffering=True)
|
||||
|
||||
use_cache = not args.no_cache and not args.fix
|
||||
cache = CheckerCache() if use_cache else None
|
||||
|
||||
failures = []
|
||||
skipped = 0
|
||||
total_start = time.perf_counter()
|
||||
for name in names:
|
||||
label = CHECKERS[name][0]
|
||||
|
||||
# Skip if all relevant files are unchanged since last clean run
|
||||
if cache is not None and cache.is_current(name):
|
||||
skipped += 1
|
||||
if is_tty:
|
||||
print(f"{GREEN}✓ {label} (cached){RESET}\n")
|
||||
else:
|
||||
print(f"{label} (cached)\n", flush=True)
|
||||
continue
|
||||
|
||||
cmd_str = get_checker_command(name, fix=args.fix)
|
||||
checker_start = time.perf_counter()
|
||||
|
||||
if is_tty:
|
||||
window = SlidingWindow(label, max_lines=10)
|
||||
if cmd_str:
|
||||
window.add_line(f"$ {cmd_str}")
|
||||
rc, output = run_checker(name, fix=args.fix, line_callback=window.add_line)
|
||||
elapsed = time.perf_counter() - checker_start
|
||||
window.finish(success=(rc == 0), elapsed=elapsed, show_lines=(rc == 0))
|
||||
if rc != 0:
|
||||
print()
|
||||
_print_output(output)
|
||||
print()
|
||||
if rc == 0 and cache is not None:
|
||||
cache.update(name)
|
||||
elif rc != 0:
|
||||
if cache is not None:
|
||||
cache.invalidate(name)
|
||||
failures.append(name)
|
||||
if not args.keep_going:
|
||||
if cache is not None:
|
||||
cache.save()
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"{label}", flush=True)
|
||||
if cmd_str:
|
||||
print(f"$ {cmd_str}", flush=True)
|
||||
if is_ci:
|
||||
streamed_output = []
|
||||
|
||||
def print_line(line):
|
||||
streamed_output.append(line)
|
||||
print(line, end="", flush=True)
|
||||
|
||||
rc, output = run_checker(name, fix=args.fix, line_callback=print_line)
|
||||
if rc != 0 and output:
|
||||
streamed_text = "".join(streamed_output)
|
||||
if output.startswith(streamed_text):
|
||||
_print_output(output[len(streamed_text) :])
|
||||
elif output != streamed_text:
|
||||
_print_output(output)
|
||||
else:
|
||||
rc, output = run_checker(name, fix=args.fix)
|
||||
if rc == 0:
|
||||
tail = output.splitlines()[-10:]
|
||||
if tail:
|
||||
print("\n".join(tail), flush=True)
|
||||
else:
|
||||
_print_output(output)
|
||||
elapsed = time.perf_counter() - checker_start
|
||||
status = "OK" if rc == 0 else "FAILED"
|
||||
print(f"{status} ({format_elapsed(elapsed)})", flush=True)
|
||||
print(flush=True)
|
||||
if rc == 0 and cache is not None:
|
||||
cache.update(name)
|
||||
elif rc != 0:
|
||||
if cache is not None:
|
||||
cache.invalidate(name)
|
||||
failures.append(name)
|
||||
if not args.keep_going:
|
||||
if cache is not None:
|
||||
cache.save()
|
||||
sys.exit(1)
|
||||
|
||||
if cache is not None:
|
||||
cache.save()
|
||||
|
||||
if failures:
|
||||
print(f"\n{len(failures)} failed: {', '.join(failures)}", flush=True)
|
||||
sys.exit(1)
|
||||
|
||||
total_elapsed = format_elapsed(time.perf_counter() - total_start)
|
||||
passed = len(names) - skipped
|
||||
if skipped:
|
||||
print(f"\nAll {len(names)} checks passed in {total_elapsed} ({passed} ran, {skipped} cached).", flush=True)
|
||||
else:
|
||||
print(f"\nAll {len(names)} checks passed in {total_elapsed}.", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
217
utils/collated_reports.py
Normal file
217
utils/collated_reports.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
DEFAULT_GPU_NAMES = ["mi300", "mi325", "mi355", "h100", "a10"]
|
||||
|
||||
|
||||
def simplify_gpu_name(gpu_name: str, simplified_names: list[str]) -> str:
|
||||
matches = []
|
||||
for simplified_name in simplified_names:
|
||||
if simplified_name in gpu_name:
|
||||
matches.append(simplified_name)
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
return gpu_name
|
||||
|
||||
|
||||
def parse_short_summary_line(line: str) -> tuple[str | None, int]:
|
||||
if line.startswith("PASSED"):
|
||||
return "passed", 1
|
||||
if line.startswith("FAILED"):
|
||||
return "failed", 1
|
||||
if line.startswith("SKIPPED"):
|
||||
line = line.split("[", maxsplit=1)[1]
|
||||
line = line.split("]", maxsplit=1)[0]
|
||||
return "skipped", int(line)
|
||||
if line.startswith("ERROR"):
|
||||
return "error", 1
|
||||
return None, 0
|
||||
|
||||
|
||||
def validate_path(p: str) -> Path:
|
||||
# Validate path and apply glob pattern if provided
|
||||
path = Path(p)
|
||||
assert path.is_dir(), f"Path {path} is not a directory"
|
||||
return path
|
||||
|
||||
|
||||
def get_gpu_name(gpu_name: str | None) -> str:
|
||||
# Get GPU name if available
|
||||
if gpu_name is None:
|
||||
try:
|
||||
import torch
|
||||
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
except Exception as e:
|
||||
print(f"Failed to get GPU name with {e}")
|
||||
gpu_name = "unknown"
|
||||
else:
|
||||
gpu_name = gpu_name.replace(" ", "_").lower()
|
||||
gpu_name = simplify_gpu_name(gpu_name, DEFAULT_GPU_NAMES)
|
||||
|
||||
return gpu_name
|
||||
|
||||
|
||||
def get_commit_hash(commit_hash: str | None) -> str:
|
||||
# Get commit hash if available
|
||||
if commit_hash is None:
|
||||
try:
|
||||
commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
|
||||
except Exception as e:
|
||||
print(f"Failed to get commit hash with {e}")
|
||||
commit_hash = "unknown"
|
||||
|
||||
return commit_hash[:7]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
path: Path
|
||||
machine_type: str
|
||||
gpu_name: str
|
||||
commit_hash: str
|
||||
job: str | None
|
||||
report_repo_id: str | None
|
||||
|
||||
|
||||
def get_arguments(args: argparse.Namespace) -> Args:
|
||||
path = validate_path(args.path)
|
||||
machine_type = args.machine_type
|
||||
gpu_name = get_gpu_name(args.gpu_name)
|
||||
commit_hash = get_commit_hash(args.commit_hash)
|
||||
job = args.job
|
||||
report_repo_id = args.report_repo_id
|
||||
return Args(path, machine_type, gpu_name, commit_hash, job, report_repo_id)
|
||||
|
||||
|
||||
def upload_collated_report(job: str, report_repo_id: str, filename: str):
|
||||
# Alternatively we can check for the existence of the collated_reports file and upload in notification_service.py
|
||||
import os
|
||||
|
||||
from get_previous_daily_ci import get_last_daily_ci_run
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
|
||||
# if it is not a scheduled run, upload the reports to a subfolder under `report_repo_folder`
|
||||
report_repo_subfolder = ""
|
||||
if os.getenv("GITHUB_EVENT_NAME") != "schedule":
|
||||
report_repo_subfolder = f"{os.getenv('GITHUB_RUN_NUMBER')}-{os.getenv('GITHUB_RUN_ID')}"
|
||||
report_repo_subfolder = f"runs/{report_repo_subfolder}"
|
||||
|
||||
workflow_run = get_last_daily_ci_run(
|
||||
token=os.environ["ACCESS_REPO_INFO_TOKEN"], workflow_run_id=os.getenv("GITHUB_RUN_ID")
|
||||
)
|
||||
workflow_run_created_time = workflow_run["created_at"]
|
||||
report_repo_folder = workflow_run_created_time.split("T")[0]
|
||||
|
||||
if report_repo_subfolder:
|
||||
report_repo_folder = f"{report_repo_folder}/{report_repo_subfolder}"
|
||||
|
||||
api.upload_file(
|
||||
path_or_fileobj=f"{filename}",
|
||||
path_in_repo=f"{report_repo_folder}/ci_results_{job}/{filename}",
|
||||
repo_id=report_repo_id,
|
||||
repo_type="dataset",
|
||||
token=os.getenv("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN"),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Post process models test reports.")
|
||||
parser.add_argument("--path", "-p", help="Path to the reports folder")
|
||||
parser.add_argument(
|
||||
"--machine-type", "-m", help="Process single or multi GPU results", choices=["single-gpu", "multi-gpu"]
|
||||
)
|
||||
parser.add_argument("--gpu-name", "-g", help="GPU name", default=None)
|
||||
parser.add_argument("--commit-hash", "-c", help="Commit hash", default=None)
|
||||
parser.add_argument("--job", "-j", help="Optional job name required for uploading reports", default=None)
|
||||
parser.add_argument(
|
||||
"--report-repo-id", "-r", help="Optional report repository ID required for uploading reports", default=None
|
||||
)
|
||||
args = get_arguments(parser.parse_args())
|
||||
|
||||
# Initialize accumulators for collated report
|
||||
total_status_count = {
|
||||
"passed": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"error": 0,
|
||||
None: 0,
|
||||
}
|
||||
collated_report_buffer = []
|
||||
|
||||
path = args.path
|
||||
machine_type = args.machine_type
|
||||
gpu_name = args.gpu_name
|
||||
commit_hash = args.commit_hash
|
||||
job = args.job
|
||||
report_repo_id = args.report_repo_id
|
||||
|
||||
# Loop through model directories and create collated reports
|
||||
for model_dir in sorted(path.iterdir()):
|
||||
if not model_dir.name.startswith(machine_type):
|
||||
continue
|
||||
|
||||
# Create a new entry for the model
|
||||
model_name = model_dir.name.split("models_")[-1].removesuffix("_test_reports")
|
||||
report = {"model": model_name, "results": []}
|
||||
results = []
|
||||
|
||||
# Read short summary
|
||||
with open(model_dir / "summary_short.txt", "r") as f:
|
||||
short_summary_lines = f.readlines()
|
||||
|
||||
# Parse short summary
|
||||
for line in short_summary_lines[1:]:
|
||||
status, count = parse_short_summary_line(line)
|
||||
total_status_count[status] += count
|
||||
if status:
|
||||
result = {
|
||||
"status": status,
|
||||
"test": line.split(status.upper(), maxsplit=1)[1].strip(),
|
||||
"count": count,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
# Add short summaries to report
|
||||
report["results"] = results
|
||||
|
||||
collated_report_buffer.append(report)
|
||||
|
||||
filename = f"collated_reports_{machine_type}_{commit_hash}.json"
|
||||
# Write collated report
|
||||
with open(filename, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"gpu_name": gpu_name,
|
||||
"machine_type": machine_type,
|
||||
"commit_hash": commit_hash,
|
||||
"total_status_count": total_status_count,
|
||||
"results": collated_report_buffer,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
# Upload collated report
|
||||
if job and report_repo_id:
|
||||
upload_collated_report(job, report_repo_id, filename)
|
||||
91
utils/compare_test_runs.py
Normal file
91
utils/compare_test_runs.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def normalize_test_line(line):
|
||||
line = line.strip()
|
||||
|
||||
# Normalize SKIPPED/XFAIL/etc with path:line and reason
|
||||
match = re.match(r"^(SKIPPED|XFAIL|XPASS|EXPECTEDFAIL)\s+\[?\d*\]?\s*(\S+:\d+)", line)
|
||||
if match:
|
||||
status, location = match.groups()
|
||||
return f"{status} {location}"
|
||||
|
||||
# Normalize ERROR/FAILED lines with optional message
|
||||
if line.startswith("ERROR") or line.startswith("FAILED"):
|
||||
return re.split(r"\s+-\s+", line)[0].strip()
|
||||
|
||||
return line
|
||||
|
||||
|
||||
def parse_summary_file(file_path):
|
||||
test_set = set()
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
in_summary = False
|
||||
for line in f:
|
||||
if line.strip().startswith("==="):
|
||||
in_summary = not in_summary
|
||||
continue
|
||||
if in_summary:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
normalized = normalize_test_line(stripped)
|
||||
test_set.add(normalized)
|
||||
return test_set
|
||||
|
||||
|
||||
def compare_job_sets(job_set1, job_set2):
|
||||
all_job_names = sorted(set(job_set1) | set(job_set2))
|
||||
report_lines = []
|
||||
|
||||
for job_name in all_job_names:
|
||||
file1 = job_set1.get(job_name)
|
||||
file2 = job_set2.get(job_name)
|
||||
|
||||
tests1 = parse_summary_file(file1) if file1 else set()
|
||||
tests2 = parse_summary_file(file2) if file2 else set()
|
||||
|
||||
added = tests2 - tests1
|
||||
removed = tests1 - tests2
|
||||
|
||||
if added or removed:
|
||||
report_lines.append(f"=== Diff for job: {job_name} ===")
|
||||
if removed:
|
||||
report_lines.append("--- Absent in current run:")
|
||||
for test in sorted(removed):
|
||||
report_lines.append(f" - {test}")
|
||||
if added:
|
||||
report_lines.append("+++ Appeared in current run:")
|
||||
for test in sorted(added):
|
||||
report_lines.append(f" + {test}")
|
||||
report_lines.append("") # blank line
|
||||
|
||||
return "\n".join(report_lines) if report_lines else "No differences found."
|
||||
|
||||
|
||||
# Example usage:
|
||||
# job_set_1 = {
|
||||
# "albert": "prev/multi-gpu_run_models_gpu_models/albert_test_reports/summary_short.txt",
|
||||
# "bloom": "prev/multi-gpu_run_models_gpu_models/bloom_test_reports/summary_short.txt",
|
||||
# }
|
||||
|
||||
# job_set_2 = {
|
||||
# "albert": "curr/multi-gpu_run_models_gpu_models/albert_test_reports/summary_short.txt",
|
||||
# "bloom": "curr/multi-gpu_run_models_gpu_models/bloom_test_reports/summary_short.txt",
|
||||
# }
|
||||
|
||||
# report = compare_job_sets(job_set_1, job_set_2)
|
||||
# print(report)
|
||||
116
utils/create_dependency_mapping.py
Normal file
116
utils/create_dependency_mapping.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import ast
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
# Function to perform topological sorting
|
||||
def topological_sort(dependencies: dict) -> list[list[str]]:
|
||||
"""Given the dependencies graph, construct a sorted list of list of modular files.
|
||||
|
||||
Examples:
|
||||
|
||||
The returned list of lists might be:
|
||||
[
|
||||
["../modular_mistral.py", "../modular_gemma.py"], # level 0
|
||||
["../modular_llama4.py", "../modular_gemma2.py"], # level 1
|
||||
["../modular_glm4.py"], # level 2
|
||||
]
|
||||
which means mistral and gemma do not depend on any other modular models, while llama4 and gemma2
|
||||
depend on the models in the first list, and glm4 depends on the models in the second and (optionally) in the first list.
|
||||
"""
|
||||
|
||||
# Nodes are the name of the models to convert (we only add those to the graph)
|
||||
nodes = {node.rsplit("modular_", 1)[1].replace(".py", "") for node in dependencies}
|
||||
# This will be a graph from models to convert, to models to convert that should be converted before (as they are a dependency)
|
||||
graph = {}
|
||||
name_mapping = {}
|
||||
for node, deps in dependencies.items():
|
||||
node_name = node.rsplit("modular_", 1)[1].replace(".py", "")
|
||||
dep_names = {dep.split(".")[-2] for dep in deps}
|
||||
dependencies = {dep for dep in dep_names if dep in nodes and dep != node_name}
|
||||
graph[node_name] = dependencies
|
||||
name_mapping[node_name] = node
|
||||
|
||||
sorting_list = []
|
||||
while len(graph) > 0:
|
||||
# Find the nodes with 0 out-degree
|
||||
leaf_nodes = {node for node in graph if len(graph[node]) == 0}
|
||||
# Add them to the list as next level
|
||||
sorting_list.append([name_mapping[node] for node in leaf_nodes])
|
||||
# Remove the leaves from the graph (and from the deps of other nodes)
|
||||
graph = {node: deps - leaf_nodes for node, deps in graph.items() if node not in leaf_nodes}
|
||||
|
||||
return sorting_list
|
||||
|
||||
|
||||
# All the model file types that may be imported in modular files
|
||||
ALL_FILE_TYPES = (
|
||||
"modeling",
|
||||
"configuration",
|
||||
"tokenization",
|
||||
"processing",
|
||||
"image_processing",
|
||||
"video_processing",
|
||||
"feature_extraction",
|
||||
)
|
||||
|
||||
|
||||
def is_model_import(module: str | None) -> bool:
|
||||
"""Check whether `module` is a model import or not."""
|
||||
# Happens for fully relative import, i.e. `from ... import initialization as init`
|
||||
if module is None:
|
||||
return False
|
||||
patterns = "|".join(ALL_FILE_TYPES)
|
||||
regex = rf"(\w+)\.(?:{patterns})_(\w+)"
|
||||
match_object = re.search(regex, module)
|
||||
if match_object is not None:
|
||||
model_name = match_object.group(1)
|
||||
if model_name in match_object.group(2) and model_name != "auto":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def extract_model_imports_from_file(file_path):
|
||||
"""From a python file `file_path`, extract the model-specific imports (the imports related to any model file in
|
||||
Transformers)"""
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
tree = ast.parse(file.read(), filename=file_path)
|
||||
imports = set()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
if is_model_import(node.module):
|
||||
imports.add(node.module)
|
||||
return imports
|
||||
|
||||
|
||||
def find_priority_list(modular_files: list[str]) -> tuple[list[list[str]], dict[str, set]]:
|
||||
"""
|
||||
Given a list of modular files, sorts them by topological order. Modular models that DON'T depend on other modular
|
||||
models will be lower in the topological order.
|
||||
|
||||
Args:
|
||||
modular_files (`list[str]`):
|
||||
List of paths to the modular files.
|
||||
|
||||
Returns:
|
||||
A tuple `ordered_files` and `dependencies`.
|
||||
|
||||
`ordered_file` is a list of lists consisting of the models at each level of the dependency graph. For example,
|
||||
it might be:
|
||||
[
|
||||
["../modular_mistral.py", "../modular_gemma.py"], # level 0
|
||||
["../modular_llama4.py", "../modular_gemma2.py"], # level 1
|
||||
["../modular_glm4.py"], # level 2
|
||||
]
|
||||
which means mistral and gemma do not depend on any other modular models, while llama4 and gemma2 depend on the
|
||||
models in the first list, and glm4 depends on the models in the second and (optionally) in the first list.
|
||||
|
||||
`dependencies` is a dictionary mapping each modular file to the models on which it relies (the models that are
|
||||
imported in order to use inheritance).
|
||||
"""
|
||||
dependencies = defaultdict(set)
|
||||
for file_path in modular_files:
|
||||
dependencies[file_path].update(extract_model_imports_from_file(file_path))
|
||||
ordered_files = topological_sort(dependencies)
|
||||
return ordered_files, dependencies
|
||||
1857
utils/create_dummy_models.py
Normal file
1857
utils/create_dummy_models.py
Normal file
File diff suppressed because it is too large
Load Diff
338
utils/custom_init_isort.py
Normal file
338
utils/custom_init_isort.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that sorts the imports in the custom inits of Transformers. Transformers uses init files that delay the
|
||||
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
|
||||
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
|
||||
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
|
||||
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
|
||||
properly sort the second half which looks like traditionl imports, the goal of this script is to sort the first half.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/custom_init_isort.py
|
||||
```
|
||||
|
||||
which will auto-sort the imports (used in `make style`).
|
||||
|
||||
For a check only (as used in `make check-repo`) run:
|
||||
|
||||
```bash
|
||||
python utils/custom_init_isort.py --check_only
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "init_isort",
|
||||
"label": "Import ordering",
|
||||
"cache_globs": ["src/transformers/**/__init__.py"],
|
||||
"check_args": ["--check_only"],
|
||||
"fix_args": [],
|
||||
}
|
||||
|
||||
# Path is defined with the intent you should run this script from the root of the repo.
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
# Pattern that looks at the indentation in a line.
|
||||
_re_indent = re.compile(r"^(\s*)\S")
|
||||
# Pattern that matches `"key":" and puts `key` in group 0.
|
||||
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
|
||||
# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
|
||||
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
|
||||
# Pattern that matches `"key",` and puts `key` in group 0.
|
||||
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
|
||||
# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
|
||||
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
|
||||
|
||||
|
||||
def get_indent(line: str) -> str:
|
||||
"""Returns the indent in given line (as string)."""
|
||||
search = _re_indent.search(line)
|
||||
return "" if search is None else search.groups()[0]
|
||||
|
||||
|
||||
def split_code_in_indented_blocks(
|
||||
code: str, indent_level: str = "", start_prompt: str | None = None, end_prompt: str | None = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split some code into its indented blocks, starting at a given level.
|
||||
|
||||
Args:
|
||||
code (`str`): The code to split.
|
||||
indent_level (`str`): The indent level (as string) to use for identifying the blocks to split.
|
||||
start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is.
|
||||
end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is.
|
||||
|
||||
Warning:
|
||||
The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input `code`
|
||||
can thus be retrieved by joining the result.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of blocks.
|
||||
"""
|
||||
# Let's split the code into lines and move to start_index.
|
||||
index = 0
|
||||
lines = code.split("\n")
|
||||
if start_prompt is not None:
|
||||
while not lines[index].startswith(start_prompt):
|
||||
index += 1
|
||||
blocks = ["\n".join(lines[:index])]
|
||||
else:
|
||||
blocks = []
|
||||
|
||||
# This variable contains the block treated at a given time.
|
||||
current_block = [lines[index]]
|
||||
index += 1
|
||||
# We split into blocks until we get to the `end_prompt` (or the end of the file).
|
||||
while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
|
||||
# We have a non-empty line with the proper indent -> start of a new block
|
||||
if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
|
||||
# Store the current block in the result and rest. There are two cases: the line is part of the block (like
|
||||
# a closing parenthesis) or not.
|
||||
if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
|
||||
# Line is part of the current block
|
||||
current_block.append(lines[index])
|
||||
blocks.append("\n".join(current_block))
|
||||
if index < len(lines) - 1:
|
||||
current_block = [lines[index + 1]]
|
||||
index += 1
|
||||
else:
|
||||
current_block = []
|
||||
else:
|
||||
# Line is not part of the current block
|
||||
blocks.append("\n".join(current_block))
|
||||
current_block = [lines[index]]
|
||||
else:
|
||||
# Just add the line to the current block
|
||||
current_block.append(lines[index])
|
||||
index += 1
|
||||
|
||||
# Adds current block if it's nonempty.
|
||||
if len(current_block) > 0:
|
||||
blocks.append("\n".join(current_block))
|
||||
|
||||
# Add final block after end_prompt if provided.
|
||||
if end_prompt is not None and index < len(lines):
|
||||
blocks.append("\n".join(lines[index:]))
|
||||
|
||||
return blocks
|
||||
|
||||
|
||||
def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
|
||||
"""
|
||||
Wraps a key function (as used in a sort) to lowercase and ignore underscores.
|
||||
"""
|
||||
|
||||
def _inner(x):
|
||||
return key(x).lower().replace("_", "")
|
||||
|
||||
return _inner
|
||||
|
||||
|
||||
def sort_objects(objects: list[Any], key: Callable[[Any], str] | None = None) -> list[Any]:
|
||||
"""
|
||||
Sort a list of objects following the rules of isort (all uppercased first, camel-cased second and lower-cased
|
||||
last).
|
||||
|
||||
Args:
|
||||
objects (`List[Any]`):
|
||||
The list of objects to sort.
|
||||
key (`Callable[[Any], str]`, *optional*):
|
||||
A function taking an object as input and returning a string, used to sort them by alphabetical order.
|
||||
If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string).
|
||||
|
||||
Returns:
|
||||
`List[Any]`: The sorted list with the same elements as in the inputs
|
||||
"""
|
||||
|
||||
# If no key is provided, we use a noop.
|
||||
def noop(x):
|
||||
return x
|
||||
|
||||
if key is None:
|
||||
key = noop
|
||||
# Constants are all uppercase, they go first.
|
||||
constants = [obj for obj in objects if key(obj).isupper()]
|
||||
# Classes are not all uppercase but start with a capital, they go second.
|
||||
classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
|
||||
# Functions begin with a lowercase, they go last.
|
||||
functions = [obj for obj in objects if not key(obj)[0].isupper()]
|
||||
|
||||
# Then we sort each group.
|
||||
key1 = ignore_underscore_and_lowercase(key)
|
||||
return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
|
||||
|
||||
|
||||
def sort_objects_in_import(import_statement: str) -> str:
|
||||
"""
|
||||
Sorts the imports in a single import statement.
|
||||
|
||||
Args:
|
||||
import_statement (`str`): The import statement in which to sort the imports.
|
||||
|
||||
Returns:
|
||||
`str`: The same as the input, but with objects properly sorted.
|
||||
"""
|
||||
|
||||
# This inner function sort imports between [ ].
|
||||
def _replace(match):
|
||||
imports = match.groups()[0]
|
||||
# If there is one import only, nothing to do.
|
||||
if "," not in imports:
|
||||
return f"[{imports}]"
|
||||
keys = [part.strip().replace('"', "") for part in imports.split(",")]
|
||||
# We will have a final empty element if the line finished with a comma.
|
||||
if len(keys[-1]) == 0:
|
||||
keys = keys[:-1]
|
||||
return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
|
||||
|
||||
lines = import_statement.split("\n")
|
||||
if len(lines) > 3:
|
||||
# Here we have to sort internal imports that are on several lines (one per name):
|
||||
# key: [
|
||||
# "object1",
|
||||
# "object2",
|
||||
# ...
|
||||
# ]
|
||||
|
||||
# We may have to ignore one or two lines on each side.
|
||||
idx = 2 if lines[1].strip() == "[" else 1
|
||||
keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
|
||||
sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
|
||||
sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
|
||||
return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
|
||||
elif len(lines) == 3:
|
||||
# Here we have to sort internal imports that are on one separate line:
|
||||
# key: [
|
||||
# "object1", "object2", ...
|
||||
# ]
|
||||
if _re_bracket_content.search(lines[1]) is not None:
|
||||
lines[1] = _re_bracket_content.sub(_replace, lines[1])
|
||||
else:
|
||||
keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
|
||||
# We will have a final empty element if the line finished with a comma.
|
||||
if len(keys[-1]) == 0:
|
||||
keys = keys[:-1]
|
||||
lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
|
||||
return "\n".join(lines)
|
||||
else:
|
||||
# Finally we have to deal with imports fitting on one line
|
||||
import_statement = _re_bracket_content.sub(_replace, import_statement)
|
||||
return import_statement
|
||||
|
||||
|
||||
def sort_imports(file: str, check_only: bool = True):
|
||||
"""
|
||||
Sort the imports defined in the `_import_structure` of a given init.
|
||||
|
||||
Args:
|
||||
file (`str`): The path to the init to check/fix.
|
||||
check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.
|
||||
"""
|
||||
with open(file, encoding="utf-8") as f:
|
||||
code = f.read()
|
||||
|
||||
# If the file is not a custom init, there is nothing to do.
|
||||
if "_import_structure = {" not in code:
|
||||
return
|
||||
|
||||
# Blocks of indent level 0
|
||||
main_blocks = split_code_in_indented_blocks(
|
||||
code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
|
||||
)
|
||||
|
||||
# We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
|
||||
for block_idx in range(1, len(main_blocks) - 1):
|
||||
# Check if the block contains some `_import_structure`s thingy to sort.
|
||||
block = main_blocks[block_idx]
|
||||
block_lines = block.split("\n")
|
||||
|
||||
# Get to the start of the imports.
|
||||
line_idx = 0
|
||||
while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
|
||||
# Skip dummy import blocks
|
||||
if "import dummy" in block_lines[line_idx]:
|
||||
line_idx = len(block_lines)
|
||||
else:
|
||||
line_idx += 1
|
||||
if line_idx >= len(block_lines):
|
||||
continue
|
||||
|
||||
# Ignore beginning and last line: they don't contain anything.
|
||||
internal_block_code = "\n".join(block_lines[line_idx:-1])
|
||||
indent = get_indent(block_lines[1])
|
||||
# Slit the internal block into blocks of indent level 1.
|
||||
internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
|
||||
# We have two categories of import key: list or _import_structure[key].append/extend
|
||||
pattern = _re_direct_key if "_import_structure = {" in block_lines[0] else _re_indirect_key
|
||||
# Grab the keys, but there is a trap: some lines are empty or just comments.
|
||||
keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
|
||||
# We only sort the lines with a key.
|
||||
keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
|
||||
sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
|
||||
|
||||
# We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
|
||||
count = 0
|
||||
reorderded_blocks = []
|
||||
for i in range(len(internal_blocks)):
|
||||
if keys[i] is None:
|
||||
reorderded_blocks.append(internal_blocks[i])
|
||||
else:
|
||||
block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
|
||||
reorderded_blocks.append(block)
|
||||
count += 1
|
||||
|
||||
# And we put our main block back together with its first and last line.
|
||||
main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])
|
||||
|
||||
if code != "\n".join(main_blocks):
|
||||
if check_only:
|
||||
return True
|
||||
else:
|
||||
print(f"Overwriting {file}.")
|
||||
with open(file, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(main_blocks))
|
||||
|
||||
|
||||
def sort_imports_in_all_inits(check_only=True):
|
||||
"""
|
||||
Sort the imports defined in the `_import_structure` of all inits in the repo.
|
||||
|
||||
Args:
|
||||
check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.
|
||||
"""
|
||||
failures = []
|
||||
for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
|
||||
if "__init__.py" in files:
|
||||
result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
|
||||
if result:
|
||||
failures = [os.path.join(root, "__init__.py")]
|
||||
if len(failures) > 0:
|
||||
raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
|
||||
args = parser.parse_args()
|
||||
|
||||
sort_imports_in_all_inits(check_only=args.check_only)
|
||||
377
utils/deprecate_models.py
Normal file
377
utils/deprecate_models.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Script which deprecates a list of given models
|
||||
|
||||
Example usage:
|
||||
python utils/deprecate_models.py --models bert distilbert
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from custom_init_isort import sort_imports_in_all_inits
|
||||
from git import Repo
|
||||
from packaging import version
|
||||
|
||||
from transformers import CONFIG_MAPPING, logging
|
||||
from transformers import __version__ as current_version
|
||||
|
||||
|
||||
REPO_PATH = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
|
||||
repo = Repo(REPO_PATH)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_last_stable_minor_release():
|
||||
# Get the last stable release of transformers
|
||||
url = "https://pypi.org/pypi/transformers/json"
|
||||
release_data = httpx.get(url).json()
|
||||
|
||||
# Find the last stable release of transformers (version below current version)
|
||||
major_version, minor_version, patch_version, _ = current_version.split(".")
|
||||
last_major_minor = f"{major_version}.{int(minor_version) - 1}"
|
||||
last_stable_minor_releases = [
|
||||
release for release in release_data["releases"] if release.startswith(last_major_minor)
|
||||
]
|
||||
last_stable_release = max(last_stable_minor_releases, key=version.parse)
|
||||
|
||||
return last_stable_release
|
||||
|
||||
|
||||
def build_tip_message(last_stable_release):
|
||||
return (
|
||||
"""
|
||||
<Tip warning={true}>
|
||||
|
||||
This model is in maintenance mode only, we don't accept any new PRs changing its code.
|
||||
"""
|
||||
+ f"""If you run into any issues running this model, please reinstall the last version that supported this model: v{last_stable_release}.
|
||||
You can do so by running the following command: `pip install -U transformers=={last_stable_release}`.
|
||||
|
||||
</Tip>"""
|
||||
)
|
||||
|
||||
|
||||
def insert_tip_to_model_doc(model_doc_path, tip_message):
|
||||
tip_message_lines = tip_message.split("\n")
|
||||
|
||||
with open(model_doc_path, "r") as f:
|
||||
model_doc = f.read()
|
||||
|
||||
# Add the tip message to the model doc page directly underneath the title
|
||||
lines = model_doc.split("\n")
|
||||
|
||||
new_model_lines = []
|
||||
for line in lines:
|
||||
if line.startswith("# "):
|
||||
new_model_lines.append(line)
|
||||
new_model_lines.extend(tip_message_lines)
|
||||
else:
|
||||
new_model_lines.append(line)
|
||||
|
||||
with open(model_doc_path, "w") as f:
|
||||
f.write("\n".join(new_model_lines))
|
||||
|
||||
|
||||
def get_model_doc_path(model: str) -> tuple[str | None, str | None]:
|
||||
# Possible variants of the model name in the model doc path
|
||||
model_names = [model, model.replace("_", "-"), model.replace("_", "")]
|
||||
|
||||
model_doc_paths = [REPO_PATH / f"docs/source/en/model_doc/{model_name}.md" for model_name in model_names]
|
||||
|
||||
for model_doc_path, model_name in zip(model_doc_paths, model_names):
|
||||
if os.path.exists(model_doc_path):
|
||||
return model_doc_path, model_name
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def extract_model_info(model):
|
||||
model_info = {}
|
||||
model_doc_path, model_doc_name = get_model_doc_path(model)
|
||||
model_path = REPO_PATH / f"src/transformers/models/{model}"
|
||||
|
||||
if model_doc_path is None:
|
||||
print(f"Model doc path does not exist for {model}")
|
||||
return None
|
||||
model_info["model_doc_path"] = model_doc_path
|
||||
model_info["model_doc_name"] = model_doc_name
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"Model path does not exist for {model}")
|
||||
return None
|
||||
model_info["model_path"] = model_path
|
||||
|
||||
return model_info
|
||||
|
||||
|
||||
def update_relative_imports(filename, model):
|
||||
with open(filename, "r") as f:
|
||||
filelines = f.read()
|
||||
|
||||
new_file_lines = []
|
||||
for line in filelines.split("\n"):
|
||||
if line.startswith("from .."):
|
||||
new_file_lines.append(line.replace("from ..", "from ..."))
|
||||
else:
|
||||
new_file_lines.append(line)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def remove_copied_from_statements(model):
|
||||
model_path = REPO_PATH / f"src/transformers/models/{model}"
|
||||
for file in os.listdir(model_path):
|
||||
if file == "__pycache__":
|
||||
continue
|
||||
file_path = model_path / file
|
||||
with open(file_path, "r") as f:
|
||||
file_lines = f.read()
|
||||
|
||||
new_file_lines = []
|
||||
for line in file_lines.split("\n"):
|
||||
if "# Copied from" in line:
|
||||
continue
|
||||
new_file_lines.append(line)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def move_model_files_to_deprecated(model):
|
||||
model_path = REPO_PATH / f"src/transformers/models/{model}"
|
||||
deprecated_model_path = REPO_PATH / f"src/transformers/models/deprecated/{model}"
|
||||
|
||||
if not os.path.exists(deprecated_model_path):
|
||||
os.makedirs(deprecated_model_path)
|
||||
|
||||
for file in os.listdir(model_path):
|
||||
if file == "__pycache__":
|
||||
continue
|
||||
repo.git.mv(f"{model_path}/{file}", f"{deprecated_model_path}/{file}")
|
||||
|
||||
# For deprecated files, we then need to update the relative imports
|
||||
update_relative_imports(f"{deprecated_model_path}/{file}", model)
|
||||
|
||||
|
||||
def delete_model_tests(model):
|
||||
tests_path = REPO_PATH / f"tests/models/{model}"
|
||||
|
||||
if os.path.exists(tests_path):
|
||||
repo.git.rm("-r", tests_path)
|
||||
|
||||
|
||||
def get_line_indent(s):
|
||||
return len(s) - len(s.lstrip())
|
||||
|
||||
|
||||
def update_main_init_file(models):
|
||||
"""
|
||||
Replace all instances of model.model_name with model.deprecated.model_name in the __init__.py file
|
||||
|
||||
Args:
|
||||
models (List[str]): The models to mark as deprecated
|
||||
"""
|
||||
filename = REPO_PATH / "src/transformers/__init__.py"
|
||||
with open(filename, "r") as f:
|
||||
init_file = f.read()
|
||||
|
||||
# 1. For each model, find all the instances of model.model_name and replace with model.deprecated.model_name
|
||||
for model in models:
|
||||
init_file = init_file.replace(f'models.{model}"', f'models.deprecated.{model}"')
|
||||
init_file = init_file.replace(f"models.{model} import", f"models.deprecated.{model} import")
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write(init_file)
|
||||
|
||||
# 2. Resort the imports
|
||||
sort_imports_in_all_inits(check_only=False)
|
||||
|
||||
|
||||
def remove_model_references_from_file(filename, models, condition):
|
||||
"""
|
||||
Remove all references to the given models from the given file
|
||||
|
||||
Args:
|
||||
filename (str): The file to remove the references from
|
||||
models (List[str]): The models to remove
|
||||
condition (Callable): A function that takes the line and model and returns True if the line should be removed
|
||||
"""
|
||||
filename = REPO_PATH / filename
|
||||
with open(filename, "r") as f:
|
||||
init_file = f.read()
|
||||
|
||||
new_file_lines = []
|
||||
for i, line in enumerate(init_file.split("\n")):
|
||||
if any(condition(line, model) for model in models):
|
||||
continue
|
||||
new_file_lines.append(line)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def remove_model_config_classes_from_config_check(model_config_classes):
|
||||
"""
|
||||
Remove the deprecated model config classes from the check_config_attributes.py file
|
||||
|
||||
Args:
|
||||
model_config_classes (List[str]): The model config classes to remove e.g. ["BertConfig", "DistilBertConfig"]
|
||||
"""
|
||||
filename = REPO_PATH / "utils/check_config_attributes.py"
|
||||
with open(filename, "r") as f:
|
||||
check_config_attributes = f.read()
|
||||
|
||||
# Keep track as we have to delete comment above too
|
||||
in_special_cases_to_allow = False
|
||||
in_indent = False
|
||||
new_file_lines = []
|
||||
|
||||
for line in check_config_attributes.split("\n"):
|
||||
indent = get_line_indent(line)
|
||||
if (line.strip() == "SPECIAL_CASES_TO_ALLOW = {") or (line.strip() == "SPECIAL_CASES_TO_ALLOW.update("):
|
||||
in_special_cases_to_allow = True
|
||||
|
||||
elif in_special_cases_to_allow and indent == 0 and line.strip() in ("}", ")"):
|
||||
in_special_cases_to_allow = False
|
||||
|
||||
if in_indent:
|
||||
if line.strip().endswith(("]", "],")):
|
||||
in_indent = False
|
||||
continue
|
||||
|
||||
if in_special_cases_to_allow and any(
|
||||
model_config_class in line for model_config_class in model_config_classes
|
||||
):
|
||||
# Remove comments above the model config class to remove
|
||||
while new_file_lines[-1].strip().startswith("#"):
|
||||
new_file_lines.pop()
|
||||
|
||||
if line.strip().endswith("["):
|
||||
in_indent = True
|
||||
|
||||
continue
|
||||
|
||||
elif any(model_config_class in line for model_config_class in model_config_classes):
|
||||
continue
|
||||
|
||||
new_file_lines.append(line)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def add_models_to_deprecated_models_in_config_auto(models):
|
||||
"""
|
||||
Add the models to the DEPRECATED_MODELS list in configuration_auto.py and sorts the list
|
||||
to be in alphabetical order.
|
||||
"""
|
||||
filepath = REPO_PATH / "src/transformers/models/auto/configuration_auto.py"
|
||||
with open(filepath, "r") as f:
|
||||
config_auto = f.read()
|
||||
|
||||
new_file_lines = []
|
||||
deprecated_models_list = []
|
||||
in_deprecated_models = False
|
||||
for line in config_auto.split("\n"):
|
||||
if line.strip() == "DEPRECATED_MODELS = [":
|
||||
in_deprecated_models = True
|
||||
new_file_lines.append(line)
|
||||
elif in_deprecated_models and line.strip() == "]":
|
||||
in_deprecated_models = False
|
||||
# Add the new models to deprecated models list
|
||||
deprecated_models_list.extend([f' "{model}", ' for model in models])
|
||||
# Sort so they're in alphabetical order in the file
|
||||
deprecated_models_list = sorted(deprecated_models_list)
|
||||
new_file_lines.extend(deprecated_models_list)
|
||||
# Make sure we still have the closing bracket
|
||||
new_file_lines.append(line)
|
||||
elif in_deprecated_models:
|
||||
deprecated_models_list.append(line)
|
||||
else:
|
||||
new_file_lines.append(line)
|
||||
|
||||
with open(filepath, "w") as f:
|
||||
f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def deprecate_models(models):
|
||||
# Get model info
|
||||
skipped_models = []
|
||||
models_info = defaultdict(dict)
|
||||
for model in models:
|
||||
single_model_info = extract_model_info(model)
|
||||
if single_model_info is None:
|
||||
skipped_models.append(model)
|
||||
else:
|
||||
models_info[model] = single_model_info
|
||||
|
||||
model_config_classes = []
|
||||
for model, model_info in models_info.items():
|
||||
if model in CONFIG_MAPPING:
|
||||
model_config_classes.append(CONFIG_MAPPING[model].__name__)
|
||||
elif model_info["model_doc_name"] in CONFIG_MAPPING:
|
||||
model_config_classes.append(CONFIG_MAPPING[model_info["model_doc_name"]].__name__)
|
||||
else:
|
||||
skipped_models.append(model)
|
||||
print(f"Model config class not found for model: {model}")
|
||||
|
||||
# Filter out skipped models
|
||||
models = [model for model in models if model not in skipped_models]
|
||||
|
||||
if skipped_models:
|
||||
print(f"Skipped models: {skipped_models} as the model doc or model path could not be found.")
|
||||
print(f"Models to deprecate: {models}")
|
||||
|
||||
# Remove model config classes from config check
|
||||
print("Removing model config classes from config checks")
|
||||
remove_model_config_classes_from_config_check(model_config_classes)
|
||||
|
||||
tip_message = build_tip_message(get_last_stable_minor_release())
|
||||
|
||||
for model, model_info in models_info.items():
|
||||
print(f"Processing model: {model}")
|
||||
# Add the tip message to the model doc page directly underneath the title
|
||||
print("Adding tip message to model doc page")
|
||||
insert_tip_to_model_doc(model_info["model_doc_path"], tip_message)
|
||||
|
||||
# Remove #Copied from statements from model's files
|
||||
print("Removing #Copied from statements from model's files")
|
||||
remove_copied_from_statements(model)
|
||||
|
||||
# Move the model file to deprecated: src/transformers/models/model -> src/transformers/models/deprecated/model
|
||||
print("Moving model files to deprecated for model")
|
||||
move_model_files_to_deprecated(model)
|
||||
|
||||
# Delete the model tests: tests/models/model
|
||||
print("Deleting model tests")
|
||||
delete_model_tests(model)
|
||||
|
||||
# # We do the following with all models passed at once to avoid having to re-write the file multiple times
|
||||
print("Updating __init__.py file to point to the deprecated models")
|
||||
update_main_init_file(models)
|
||||
|
||||
# Remove model references from other files
|
||||
print("Removing model references from other files")
|
||||
remove_model_references_from_file(
|
||||
"src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",")
|
||||
)
|
||||
remove_model_references_from_file(
|
||||
"utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line
|
||||
)
|
||||
remove_model_references_from_file("utils/not_doctested.txt", models, lambda line, model: "/" + model + "/" in line)
|
||||
|
||||
# Add models to DEPRECATED_MODELS in the configuration_auto.py
|
||||
print("Adding models to DEPRECATED_MODELS in configuration_auto.py")
|
||||
add_models_to_deprecated_models_in_config_auto(models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--models", nargs="+", help="List of models to deprecate")
|
||||
args = parser.parse_args()
|
||||
deprecate_models(args.models)
|
||||
160
utils/download_glue_data.py
Normal file
160
utils/download_glue_data.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Script for downloading all GLUE data.
|
||||
Original source: https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
|
||||
|
||||
Note: for legal reasons, we are unable to host MRPC.
|
||||
You can either use the version hosted by the SentEval team, which is already tokenized,
|
||||
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
|
||||
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
|
||||
You should then rename and place specific files in a folder (see below for an example).
|
||||
|
||||
mkdir MRPC
|
||||
cabextract MSRParaphraseCorpus.msi -d MRPC
|
||||
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
|
||||
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
|
||||
rm MRPC/_*
|
||||
rm MSRParaphraseCorpus.msi
|
||||
|
||||
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
|
||||
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
|
||||
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
|
||||
TASK2PATH = {
|
||||
"CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",
|
||||
"SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",
|
||||
"MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",
|
||||
"QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5",
|
||||
"STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",
|
||||
"MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
|
||||
"SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",
|
||||
"QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",
|
||||
"RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",
|
||||
"WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",
|
||||
"diagnostic": "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",
|
||||
}
|
||||
|
||||
MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
|
||||
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
|
||||
|
||||
|
||||
def download_and_extract(task, data_dir):
|
||||
print(f"Downloading and extracting {task}...")
|
||||
data_file = f"{task}.zip"
|
||||
urllib.request.urlretrieve(TASK2PATH[task], data_file)
|
||||
with zipfile.ZipFile(data_file) as zip_ref:
|
||||
zip_ref.extractall(data_dir)
|
||||
os.remove(data_file)
|
||||
print("\tCompleted!")
|
||||
|
||||
|
||||
def format_mrpc(data_dir, path_to_data):
|
||||
print("Processing MRPC...")
|
||||
mrpc_dir = os.path.join(data_dir, "MRPC")
|
||||
if not os.path.isdir(mrpc_dir):
|
||||
os.mkdir(mrpc_dir)
|
||||
if path_to_data:
|
||||
mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
|
||||
mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
|
||||
else:
|
||||
print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
|
||||
mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
|
||||
mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
|
||||
urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
|
||||
urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
|
||||
if not os.path.isfile(mrpc_train_file):
|
||||
raise ValueError(f"Train data not found at {mrpc_train_file}")
|
||||
if not os.path.isfile(mrpc_test_file):
|
||||
raise ValueError(f"Test data not found at {mrpc_test_file}")
|
||||
urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
|
||||
|
||||
dev_ids = []
|
||||
with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
|
||||
for row in ids_fh:
|
||||
dev_ids.append(row.strip().split("\t"))
|
||||
|
||||
with (
|
||||
open(mrpc_train_file, encoding="utf8") as data_fh,
|
||||
open(os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8") as train_fh,
|
||||
open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh,
|
||||
):
|
||||
header = data_fh.readline()
|
||||
train_fh.write(header)
|
||||
dev_fh.write(header)
|
||||
for row in data_fh:
|
||||
label, id1, id2, s1, s2 = row.strip().split("\t")
|
||||
if [id1, id2] in dev_ids:
|
||||
dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
|
||||
else:
|
||||
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
|
||||
|
||||
with (
|
||||
open(mrpc_test_file, encoding="utf8") as data_fh,
|
||||
open(os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8") as test_fh,
|
||||
):
|
||||
header = data_fh.readline()
|
||||
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
|
||||
for idx, row in enumerate(data_fh):
|
||||
label, id1, id2, s1, s2 = row.strip().split("\t")
|
||||
test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
|
||||
print("\tCompleted!")
|
||||
|
||||
|
||||
def download_diagnostic(data_dir):
|
||||
print("Downloading and extracting diagnostic...")
|
||||
if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
|
||||
os.mkdir(os.path.join(data_dir, "diagnostic"))
|
||||
data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
|
||||
urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
|
||||
print("\tCompleted!")
|
||||
return
|
||||
|
||||
|
||||
def get_tasks(task_names):
|
||||
task_names = task_names.split(",")
|
||||
if "all" in task_names:
|
||||
tasks = TASKS
|
||||
else:
|
||||
tasks = []
|
||||
for task_name in task_names:
|
||||
if task_name not in TASKS:
|
||||
raise ValueError(f"Task {task_name} not found!")
|
||||
tasks.append(task_name)
|
||||
return tasks
|
||||
|
||||
|
||||
def main(arguments):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--data_dir", help="directory to save data to", type=str, default="glue_data")
|
||||
parser.add_argument(
|
||||
"--tasks", help="tasks to download data for as a comma separated string", type=str, default="all"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_to_mrpc",
|
||||
help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt",
|
||||
type=str,
|
||||
default="",
|
||||
)
|
||||
args = parser.parse_args(arguments)
|
||||
|
||||
if not os.path.isdir(args.data_dir):
|
||||
os.mkdir(args.data_dir)
|
||||
tasks = get_tasks(args.tasks)
|
||||
|
||||
for task in tasks:
|
||||
if task == "MRPC":
|
||||
format_mrpc(args.data_dir, args.path_to_mrpc)
|
||||
elif task == "diagnostic":
|
||||
download_diagnostic(args.data_dir)
|
||||
else:
|
||||
download_and_extract(task, args.data_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
69
utils/extract_metadata.py
Executable file
69
utils/extract_metadata.py
Executable file
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Extract metadata from setup.py for CI testing.
|
||||
|
||||
Usage:
|
||||
python utils/extract_metadata.py extras # List all extras (one per line)
|
||||
python utils/extract_metadata.py python-versions # Output JSON array of Python versions
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def get_setup_module() -> ModuleType:
|
||||
"""Import and return the setup module."""
|
||||
repo_root: Path = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(repo_root))
|
||||
import setup
|
||||
|
||||
return setup
|
||||
|
||||
|
||||
def extract_extras() -> None:
|
||||
"""Print all extras in definition order (one per line)."""
|
||||
setup: ModuleType = get_setup_module()
|
||||
for extra in setup.extras.keys():
|
||||
print(extra)
|
||||
|
||||
|
||||
def extract_python_versions() -> None:
|
||||
"""Print supported Python versions as a JSON array."""
|
||||
setup: ModuleType = get_setup_module()
|
||||
min_ver: int
|
||||
max_ver: int
|
||||
min_ver, max_ver = setup.SUPPORTED_PYTHON_VERSIONS
|
||||
versions: list[str] = [f"3.{v}" for v in range(min_ver, max_ver + 1)]
|
||||
print(json.dumps(versions))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python utils/extract_metadata.py {extras|python-versions}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
command: str = sys.argv[1]
|
||||
|
||||
if command == "extras":
|
||||
extract_extras()
|
||||
elif command == "python-versions":
|
||||
extract_python_versions()
|
||||
else:
|
||||
print(f"Unknown command: {command}", file=sys.stderr)
|
||||
print("Usage: python utils/extract_metadata.py {extras|python-versions}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
30
utils/extract_pr_number_from_circleci.py
Normal file
30
utils/extract_pr_number_from_circleci.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Used by `.github/workflows/trigger_circleci.yml` to get the pull request number in CircleCI job runs."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pr_number = ""
|
||||
|
||||
pr = os.environ.get("CIRCLE_PULL_REQUEST", "")
|
||||
if len(pr) > 0:
|
||||
pr_number = pr.split("/")[-1]
|
||||
if pr_number == "":
|
||||
pr = os.environ.get("CIRCLE_BRANCH", "")
|
||||
if pr.startswith("pull/"):
|
||||
pr_number = "".join(pr.split("/")[1:2])
|
||||
|
||||
print(pr_number)
|
||||
134
utils/extract_warnings.py
Normal file
134
utils/extract_warnings.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
|
||||
from get_ci_error_statistics import download_artifact, get_artifacts_links
|
||||
|
||||
from transformers import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def extract_warnings_from_single_artifact(artifact_path, targets):
|
||||
"""Extract warnings from a downloaded artifact (in .zip format)"""
|
||||
selected_warnings = set()
|
||||
buffer = []
|
||||
|
||||
def parse_line(fp):
|
||||
for line in fp:
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("UTF-8")
|
||||
if "warnings summary (final)" in line:
|
||||
continue
|
||||
# This means we are outside the body of a warning
|
||||
elif not line.startswith(" "):
|
||||
# process a single warning and move it to `selected_warnings`.
|
||||
if len(buffer) > 0:
|
||||
warning = "\n".join(buffer)
|
||||
# Only keep the warnings specified in `targets`
|
||||
if any(f": {x}: " in warning for x in targets):
|
||||
selected_warnings.add(warning)
|
||||
buffer.clear()
|
||||
continue
|
||||
else:
|
||||
line = line.strip()
|
||||
buffer.append(line)
|
||||
|
||||
if from_gh:
|
||||
for filename in os.listdir(artifact_path):
|
||||
file_path = os.path.join(artifact_path, filename)
|
||||
if not os.path.isdir(file_path):
|
||||
# read the file
|
||||
if filename != "warnings.txt":
|
||||
continue
|
||||
with open(file_path) as fp:
|
||||
parse_line(fp)
|
||||
else:
|
||||
try:
|
||||
with zipfile.ZipFile(artifact_path) as z:
|
||||
for filename in z.namelist():
|
||||
if not os.path.isdir(filename):
|
||||
# read the file
|
||||
if filename != "warnings.txt":
|
||||
continue
|
||||
with z.open(filename) as fp:
|
||||
parse_line(fp)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"{artifact_path} is either an invalid zip file or something else wrong. This file is skipped."
|
||||
)
|
||||
|
||||
return selected_warnings
|
||||
|
||||
|
||||
def extract_warnings(artifact_dir, targets):
|
||||
"""Extract warnings from all artifact files"""
|
||||
|
||||
selected_warnings = set()
|
||||
|
||||
paths = [os.path.join(artifact_dir, p) for p in os.listdir(artifact_dir) if (p.endswith(".zip") or from_gh)]
|
||||
for p in paths:
|
||||
selected_warnings.update(extract_warnings_from_single_artifact(p, targets))
|
||||
|
||||
return selected_warnings
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def list_str(values):
|
||||
return values.split(",")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the downloaded artifacts and other result files.",
|
||||
)
|
||||
parser.add_argument("--token", default=None, type=str, help="A token that has actions:read permission.")
|
||||
# optional parameters
|
||||
parser.add_argument(
|
||||
"--targets",
|
||||
default="DeprecationWarning,UserWarning,FutureWarning",
|
||||
type=list_str,
|
||||
help="Comma-separated list of target warning(s) which we want to extract.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from_gh",
|
||||
action="store_true",
|
||||
help="If running from a GitHub action workflow and collecting warnings from its artifacts.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
from_gh = args.from_gh
|
||||
if from_gh:
|
||||
# The artifacts have to be downloaded using `actions/download-artifact@v4`
|
||||
pass
|
||||
else:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# get download links
|
||||
artifacts = get_artifacts_links(args.workflow_run_id, token=args.token)
|
||||
with open(os.path.join(args.output_dir, "artifacts.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(artifacts, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
# download artifacts
|
||||
for idx, (name, url) in enumerate(artifacts.items()):
|
||||
print(name)
|
||||
print(url)
|
||||
print("=" * 80)
|
||||
download_artifact(name, url, args.output_dir, args.token)
|
||||
# Be gentle to GitHub
|
||||
time.sleep(1)
|
||||
|
||||
# extract warnings from artifacts
|
||||
selected_warnings = extract_warnings(args.output_dir, args.targets)
|
||||
selected_warnings = sorted(selected_warnings)
|
||||
with open(os.path.join(args.output_dir, "selected_warnings.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(selected_warnings, fp, ensure_ascii=False, indent=4)
|
||||
382
utils/fetch_hub_objects_for_ci.py
Normal file
382
utils/fetch_hub_objects_for_ci.py
Normal file
@@ -0,0 +1,382 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script downloads files from the HuggingFace Hub to be used for CI tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
# Ensure we always download from the public HuggingFace Hub, not the CI staging endpoint.
|
||||
# huggingface_hub reads HUGGINGFACE_CO_STAGING at import time and hardcodes hub-ci.huggingface.co.
|
||||
_staging_mode = os.environ.pop("HUGGINGFACE_CO_STAGING", None)
|
||||
|
||||
import httpx # noqa: E402
|
||||
from huggingface_hub import hf_hub_download, snapshot_download # noqa: E402
|
||||
|
||||
from transformers.testing_utils import _run_pipeline_tests, _run_staging # noqa: E402
|
||||
from transformers.utils.import_utils import is_mistral_common_available # noqa: E402
|
||||
|
||||
|
||||
# ruff: enable[E402]
|
||||
|
||||
# Restore so transformers.testing_utils._run_staging can still read it.
|
||||
if _staging_mode is not None:
|
||||
os.environ["HUGGINGFACE_CO_STAGING"] = _staging_mode
|
||||
|
||||
|
||||
URLS_FOR_TESTING_DATA = [
|
||||
# TODO: copy those to our hf-internal-testing dataset and fix all tests using them
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_formula_rec_001.png",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/ocr_demo2.jpg",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/doc_test.jpg",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/chart_parsing_02.png",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout_demo.jpg",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/img_rot180_demo.jpg",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_001.png",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png",
|
||||
"https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg",
|
||||
# Don't use the original COCO URLs anymore. Replace with images from https://huggingface.co/datasets/hf-internal-testing/fixtures-coco below
|
||||
"http://images.cocodataset.org/val2017/000000000139.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000285.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000632.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000724.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000776.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000785.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000802.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000872.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000001000.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000077595.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000136466.jpg",
|
||||
"https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
|
||||
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
"https://llava-vl.github.io/static/images/view.jpg",
|
||||
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/f2641_0_throatclearing.wav",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/glass-breaking-151256.mp3",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/images_test/resolve/main/picsum_237_200x300.jpg",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
"https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png",
|
||||
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/two_dogs.jpg",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/tiny_video.mp4",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/tiny_video.mp4",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
|
||||
# we should rely on this single dataset for our tests
|
||||
"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/coco_annotations.txt",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/coco_panoptic_annotations.txt",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/coco_panoptic/000000039769.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000139.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000285.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000632.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000724.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000776.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000785.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000802.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000000872.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000001000.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000004016.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000039769.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000039769.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000077595.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-coco/resolve/main/val2017/000000136466.jpg",
|
||||
]
|
||||
|
||||
|
||||
def url_to_local_path(url, return_url_if_not_found=True):
|
||||
filename = url.split("/")[-1]
|
||||
|
||||
if not os.path.exists(filename) and return_url_if_not_found:
|
||||
return url
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def parse_hf_url(url):
|
||||
"""
|
||||
Parse a HuggingFace Hub URL into components for hf_hub_download.
|
||||
|
||||
Returns dict with (repo_id, filename, repo_type, revision) or None if not a HF URL.
|
||||
"""
|
||||
pattern = r"https://huggingface\.co/(datasets/)?([^/]+/[^/]+)/resolve/([^/]+)/(.+)"
|
||||
match = re.match(pattern, url)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
is_dataset = match.group(1) is not None
|
||||
revision = match.group(3)
|
||||
return {
|
||||
"repo_id": match.group(2),
|
||||
"filename": match.group(4),
|
||||
"repo_type": "dataset" if is_dataset else "model",
|
||||
"revision": revision if revision != "main" else None,
|
||||
}
|
||||
|
||||
|
||||
def validate_downloaded_content(filepath):
|
||||
with open(filepath, "rb") as f:
|
||||
header = f.read(32)
|
||||
|
||||
for bad_sig in [b"<!doctype", b"<html", b'{"error', b'{"message']:
|
||||
if header.lower().startswith(bad_sig):
|
||||
raise ValueError(
|
||||
f"Downloaded file appears to be an HTML error page, not a valid media file. "
|
||||
f"This may indicate rate limiting. File starts with: {header[:200]!r}"
|
||||
)
|
||||
|
||||
file_size = os.path.getsize(filepath)
|
||||
if file_size < 100:
|
||||
raise ValueError(f"Downloaded file is suspiciously small ({file_size} bytes).")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def download_test_file(url):
|
||||
"""
|
||||
Download a URL to a local file, using hf_hub_download for HF URLs.
|
||||
|
||||
For HuggingFace URLs, uses hf_hub_download which handles authentication
|
||||
automatically via the HF_TOKEN environment variable.
|
||||
|
||||
Returns the local filename.
|
||||
"""
|
||||
filename = url.split("/")[-1]
|
||||
|
||||
# Skip if file already exists
|
||||
if os.path.exists(filename):
|
||||
print(f"File already exists: {filename}")
|
||||
return filename
|
||||
|
||||
# Check if this is a HuggingFace URL
|
||||
hf_parts = parse_hf_url(url)
|
||||
|
||||
if hf_parts:
|
||||
# Use hf_hub_download for HF URLs - handles auth automatically via HF_TOKEN env var
|
||||
print(f"Downloading {filename} from HuggingFace Hub...")
|
||||
try:
|
||||
hf_hub_download(**hf_parts, local_dir=".")
|
||||
print(f"Successfully downloaded: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Error downloading {filename} from HuggingFace Hub: {e}")
|
||||
raise
|
||||
else:
|
||||
# Use httpx for non-HF URLs (COCO, Britannica, etc.)
|
||||
import time
|
||||
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
print(f"Downloading {filename} from {url}")
|
||||
with open(filename, "wb") as f:
|
||||
with httpx.stream("GET", url, follow_redirects=True) as resp:
|
||||
resp.raise_for_status()
|
||||
f.writelines(resp.iter_bytes(chunk_size=8192))
|
||||
|
||||
validate_downloaded_content(filename)
|
||||
print(f"Successfully downloaded: {filename}")
|
||||
break
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait = 2 ** (attempt + 1)
|
||||
print(f"Attempt {attempt + 1} failed for {filename}: {e}. Retrying in {wait}s...")
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
time.sleep(wait)
|
||||
else:
|
||||
raise
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if _run_pipeline_tests:
|
||||
import datasets
|
||||
|
||||
_ = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
_ = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", split="test", revision="refs/pr/1")
|
||||
_ = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")
|
||||
|
||||
hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
|
||||
hf_hub_download(repo_id="hf-internal-testing/bool-masked-pos", filename="bool_masked_pos.pt")
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/fixtures_docvqa",
|
||||
filename="nougat_pdf.png",
|
||||
repo_type="dataset",
|
||||
revision="ec57bf8c8b1653a209c13f6e9ee66b12df0fc2db",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset"
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/spaghetti-video",
|
||||
filename="eating_spaghetti_32_frames.npy",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/spaghetti-video",
|
||||
filename="eating_spaghetti_8_frames.npy",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
|
||||
)
|
||||
hf_hub_download(repo_id="huggyllama/llama-7b", filename="tokenizer.model")
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint", filename="sample_audio.flac", repo_type="dataset"
|
||||
)
|
||||
hf_hub_download(repo_id="nielsr/example-pdf", repo_type="dataset", filename="example_pdf.png")
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/test-image",
|
||||
filename="llava_1_6_input_ids.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/test-image",
|
||||
filename="llava_1_6_pixel_values.pt",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
|
||||
hf_hub_download(
|
||||
repo_id="raushan-testing-hf/images_test",
|
||||
filename="emu3_image.npy",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(repo_id="raushan-testing-hf/images_test", filename="llava_v1_5_radar.jpg", repo_type="dataset")
|
||||
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
|
||||
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset")
|
||||
hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="video_demo_2.npy", repo_type="dataset")
|
||||
hf_hub_download(
|
||||
repo_id="shumingh/perception_lm_test_images",
|
||||
filename="14496_0.PNG",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="shumingh/perception_lm_test_videos",
|
||||
filename="GUWR5TyiY-M_000012_000022.mp4",
|
||||
repo_type="dataset",
|
||||
)
|
||||
repo_id = "nielsr/image-segmentation-toy-data"
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="instance_segmentation_image_1.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="instance_segmentation_image_2.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="instance_segmentation_annotation_1.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="instance_segmentation_annotation_2.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="semantic_segmentation_annotation_1.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="semantic_segmentation_annotation_2.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="semantic_segmentation_image_1.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/image-segmentation-toy-data",
|
||||
filename="semantic_segmentation_image_2.png",
|
||||
repo_type="dataset",
|
||||
)
|
||||
hf_hub_download("shi-labs/oneformer_demo", "ade20k_panoptic.json", repo_type="dataset")
|
||||
|
||||
hf_hub_download(
|
||||
repo_id="nielsr/audio-spectogram-transformer-checkpoint", filename="sample_audio.flac", repo_type="dataset"
|
||||
)
|
||||
|
||||
# Need to specify the username on the endpoint `hub-ci`, otherwise we get
|
||||
# `fatal: could not read Username for 'https://hub-ci.huggingface.co': Success`
|
||||
# But this repo. is never used in a test decorated by `is_staging_test`.
|
||||
if not _run_staging:
|
||||
if not os.path.isdir("tiny-random-custom-architecture"):
|
||||
snapshot_download(
|
||||
"hf-internal-testing/tiny-random-custom-architecture",
|
||||
local_dir="tiny-random-custom-architecture",
|
||||
)
|
||||
|
||||
# For `tests/test_tokenization_mistral_common.py:TestMistralCommonBackend`, which eventually calls
|
||||
# `mistral_common.tokens.tokenizers.utils.download_tokenizer_from_hf_hub` which (probably) doesn't have the cache.
|
||||
# For `revision=None`, see https://github.com/huggingface/transformers/pull/40623
|
||||
if is_mistral_common_available():
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.utils import list_local_hf_repo_files
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.tokenization_mistral_common import MistralCommonBackend
|
||||
|
||||
repo_id = "hf-internal-testing/namespace-mistralai-repo_name-Mistral-Small-3.1-24B-Instruct-2503"
|
||||
|
||||
# determine if we already have this downloaded
|
||||
local_files_only = len(list_local_hf_repo_files(repo_id, revision=None)) > 0
|
||||
|
||||
# This will go the path `transformers/tokenization_mistral_common.py::MistralCommonBackend::from_pretrained --> mistral_common.tokens.tokenizers.utils.download_tokenizer_from_hf_hub`.
|
||||
# No idea at all why we need the statement below again (`MistralCommonBackend.from_pretrained`).
|
||||
AutoTokenizer.from_pretrained(
|
||||
repo_id, tokenizer_type="mistral", local_files_only=local_files_only, revision=None
|
||||
)
|
||||
|
||||
_ = MistralCommonBackend.from_pretrained(
|
||||
repo_id,
|
||||
local_files_only=local_files_only,
|
||||
# This is a hack as `list_local_hf_repo_files` from `mistral_common` has a bug
|
||||
# TODO: Discuss with `mistral-common` maintainers: after a fix being done there, remove this `revision` hack
|
||||
revision=None,
|
||||
)
|
||||
|
||||
MistralTokenizer.from_hf_hub(repo_id, local_files_only=local_files_only)
|
||||
|
||||
repo_id = "mistralai/Voxtral-Mini-3B-2507"
|
||||
local_files_only = len(list_local_hf_repo_files(repo_id, revision=None)) > 0
|
||||
|
||||
AutoTokenizer.from_pretrained(repo_id, local_files_only=local_files_only, revision=None)
|
||||
MistralTokenizer.from_hf_hub(repo_id, local_files_only=local_files_only)
|
||||
|
||||
# Download files from URLs to local directory
|
||||
for url in URLS_FOR_TESTING_DATA:
|
||||
download_test_file(url)
|
||||
116
utils/format_extras_slack_message.py
Normal file
116
utils/format_extras_slack_message.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Format extras smoke test results for Slack notification.
|
||||
|
||||
This script reads failure reports from a JSON file and outputs environment
|
||||
variables for GitHub Actions to use in Slack notifications.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def format_slack_message(failures_file, workflow_url, output_file=None):
|
||||
"""
|
||||
Format extras smoke test results into Slack message components.
|
||||
|
||||
Args:
|
||||
failures_file: Path to JSON file containing failure reports
|
||||
workflow_url: URL to the GitHub Actions workflow run
|
||||
output_file: Optional path to output file (defaults to GITHUB_ENV)
|
||||
|
||||
Returns:
|
||||
Dictionary with title, message, and workflow_url
|
||||
"""
|
||||
# Read failures
|
||||
with open(failures_file) as f:
|
||||
failures = json.load(f)
|
||||
|
||||
if not failures:
|
||||
# Success case
|
||||
title = "Extras Smoke Test - All tests passed"
|
||||
message = "All extras installed successfully across all Python versions."
|
||||
else:
|
||||
# Failure case - group by Python version
|
||||
failures_by_python = {}
|
||||
for failure in failures:
|
||||
py_ver = failure.get("python_version", "unknown")
|
||||
extra = failure.get("extra", "unknown")
|
||||
|
||||
if py_ver not in failures_by_python:
|
||||
failures_by_python[py_ver] = []
|
||||
failures_by_python[py_ver].append(extra)
|
||||
|
||||
title = f"Extras Smoke Test Failed - {len(failures)} failure(s)"
|
||||
|
||||
# Build failure details
|
||||
details = []
|
||||
for py_ver in sorted(failures_by_python.keys()):
|
||||
extras = failures_by_python[py_ver]
|
||||
extras_list = "\n".join([f"• `{extra}`" for extra in sorted(extras)])
|
||||
details.append(f"*Python {py_ver}*\n{extras_list}")
|
||||
|
||||
message = "\n\n".join(details)
|
||||
|
||||
# Determine output destination
|
||||
if output_file is None:
|
||||
output_file = os.environ.get("GITHUB_ENV")
|
||||
if not output_file:
|
||||
print("Error: GITHUB_ENV not set and no output file specified", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Write environment variables
|
||||
with open(output_file, "a") as f:
|
||||
f.write(f"SLACK_TITLE={title}\n")
|
||||
f.write(f"SLACK_WORKFLOW_URL={workflow_url}\n")
|
||||
# Use heredoc for multiline message
|
||||
f.write("SLACK_MESSAGE<<EOF\n")
|
||||
f.write(f"{message}\n")
|
||||
f.write("EOF\n")
|
||||
|
||||
return {"title": title, "message": message, "workflow_url": workflow_url}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Format extras smoke test results for Slack")
|
||||
parser.add_argument(
|
||||
"--failures",
|
||||
required=True,
|
||||
help="Path to JSON file containing failure reports",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workflow-url",
|
||||
required=True,
|
||||
help="URL to the GitHub Actions workflow run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
help="Output file path (defaults to GITHUB_ENV)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result = format_slack_message(args.failures, args.workflow_url, args.output)
|
||||
print(f"Formatted Slack message: {result['title']}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
309
utils/get_ci_error_statistics.py
Normal file
309
utils/get_ci_error_statistics.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import zipfile
|
||||
from collections import Counter
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_jobs(workflow_run_id, token=None):
|
||||
"""Extract jobs in a GitHub Actions workflow run"""
|
||||
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
|
||||
result = requests.get(url, headers=headers).json()
|
||||
jobs = []
|
||||
|
||||
try:
|
||||
jobs.extend(result["jobs"])
|
||||
pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100)
|
||||
|
||||
for i in range(pages_to_iterate_over):
|
||||
result = requests.get(url + f"&page={i + 2}", headers=headers).json()
|
||||
jobs.extend(result["jobs"])
|
||||
|
||||
return jobs
|
||||
except Exception:
|
||||
print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_job_links(workflow_run_id, token=None):
|
||||
"""Extract job names and their job links in a GitHub Actions workflow run"""
|
||||
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
|
||||
result = requests.get(url, headers=headers).json()
|
||||
job_links = {}
|
||||
|
||||
try:
|
||||
job_links.update({job["name"]: job["html_url"] for job in result["jobs"]})
|
||||
pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100)
|
||||
|
||||
for i in range(pages_to_iterate_over):
|
||||
result = requests.get(url + f"&page={i + 2}", headers=headers).json()
|
||||
job_links.update({job["name"]: job["html_url"] for job in result["jobs"]})
|
||||
|
||||
return job_links
|
||||
except Exception:
|
||||
print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def get_artifacts_links(workflow_run_id, token=None):
|
||||
"""Get all artifact links from a workflow run"""
|
||||
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
url = (
|
||||
f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/artifacts?per_page=100"
|
||||
)
|
||||
result = requests.get(url, headers=headers).json()
|
||||
artifacts = {}
|
||||
|
||||
try:
|
||||
artifacts.update({artifact["name"]: artifact["archive_download_url"] for artifact in result["artifacts"]})
|
||||
pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100)
|
||||
|
||||
for i in range(pages_to_iterate_over):
|
||||
result = requests.get(url + f"&page={i + 2}", headers=headers).json()
|
||||
artifacts.update({artifact["name"]: artifact["archive_download_url"] for artifact in result["artifacts"]})
|
||||
|
||||
return artifacts
|
||||
except Exception:
|
||||
print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def download_artifact(artifact_name, artifact_url, output_dir, token):
|
||||
"""Download a GitHub Action artifact from a URL.
|
||||
|
||||
The URL is of the form `https://api.github.com/repos/huggingface/transformers/actions/artifacts/{ARTIFACT_ID}/zip`,
|
||||
but it can't be used to download directly. We need to get a redirect URL first.
|
||||
See https://docs.github.com/en/rest/actions/artifacts#download-an-artifact
|
||||
"""
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
result = requests.get(artifact_url, headers=headers, allow_redirects=False)
|
||||
download_url = result.headers["Location"]
|
||||
response = requests.get(download_url, allow_redirects=True)
|
||||
file_path = os.path.join(output_dir, f"{artifact_name}.zip")
|
||||
with open(file_path, "wb") as fp:
|
||||
fp.write(response.content)
|
||||
|
||||
|
||||
def get_errors_from_single_artifact(artifact_zip_path, job_links=None):
|
||||
"""Extract errors from a downloaded artifact (in .zip format)"""
|
||||
errors = []
|
||||
failed_tests = []
|
||||
job_name = None
|
||||
|
||||
with zipfile.ZipFile(artifact_zip_path) as z:
|
||||
for filename in z.namelist():
|
||||
if not os.path.isdir(filename):
|
||||
# read the file
|
||||
if filename in ["failures_line.txt", "summary_short.txt", "job_name.txt"]:
|
||||
with z.open(filename) as f:
|
||||
for line in f:
|
||||
line = line.decode("UTF-8").strip()
|
||||
if filename == "failures_line.txt":
|
||||
try:
|
||||
# `error_line` is the place where `error` occurs
|
||||
error_line = line[: line.index(": ")]
|
||||
error = line[line.index(": ") + len(": ") :]
|
||||
errors.append([error_line, error])
|
||||
except Exception:
|
||||
# skip un-related lines that don't match the expected format
|
||||
logger.debug(f"Skipping unrelated line: {line}")
|
||||
elif filename == "summary_short.txt" and line.startswith("FAILED "):
|
||||
# `test` is the test method that failed
|
||||
test = line[len("FAILED ") :]
|
||||
failed_tests.append(test)
|
||||
elif filename == "job_name.txt":
|
||||
job_name = line
|
||||
|
||||
if len(errors) != len(failed_tests):
|
||||
raise ValueError(
|
||||
f"`errors` and `failed_tests` should have the same number of elements. Got {len(errors)} for `errors` "
|
||||
f"and {len(failed_tests)} for `failed_tests` instead. The test reports in {artifact_zip_path} have some"
|
||||
" problem."
|
||||
)
|
||||
|
||||
job_link = None
|
||||
if job_name and job_links:
|
||||
job_link = job_links.get(job_name, None)
|
||||
|
||||
# A list with elements of the form (line of error, error, failed test)
|
||||
result = [x + [y] + [job_link] for x, y in zip(errors, failed_tests)]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_all_errors(artifact_dir, job_links=None):
|
||||
"""Extract errors from all artifact files"""
|
||||
|
||||
errors = []
|
||||
|
||||
paths = [os.path.join(artifact_dir, p) for p in os.listdir(artifact_dir) if p.endswith(".zip")]
|
||||
for p in paths:
|
||||
errors.extend(get_errors_from_single_artifact(p, job_links=job_links))
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def reduce_by_error(logs, error_filter=None):
|
||||
"""count each error"""
|
||||
|
||||
counter = Counter()
|
||||
counter.update([x[1] for x in logs])
|
||||
counts = counter.most_common()
|
||||
r = {}
|
||||
for error, count in counts:
|
||||
if error_filter is None or error not in error_filter:
|
||||
r[error] = {"count": count, "failed_tests": [(x[2], x[0]) for x in logs if x[1] == error]}
|
||||
|
||||
r = dict(sorted(r.items(), key=lambda item: item[1]["count"], reverse=True))
|
||||
return r
|
||||
|
||||
|
||||
def get_model(test):
|
||||
"""Get the model name from a test method"""
|
||||
test = test.split("::")[0]
|
||||
if test.startswith("tests/models/"):
|
||||
test = test.split("/")[2]
|
||||
else:
|
||||
test = None
|
||||
|
||||
return test
|
||||
|
||||
|
||||
def reduce_by_model(logs, error_filter=None):
|
||||
"""count each error per model"""
|
||||
|
||||
logs = [(x[0], x[1], get_model(x[2])) for x in logs]
|
||||
logs = [x for x in logs if x[2] is not None]
|
||||
tests = {x[2] for x in logs}
|
||||
|
||||
r = {}
|
||||
for test in tests:
|
||||
counter = Counter()
|
||||
# count by errors in `test`
|
||||
counter.update([x[1] for x in logs if x[2] == test])
|
||||
counts = counter.most_common()
|
||||
error_counts = {error: count for error, count in counts if (error_filter is None or error not in error_filter)}
|
||||
n_errors = sum(error_counts.values())
|
||||
if n_errors > 0:
|
||||
r[test] = {"count": n_errors, "errors": error_counts}
|
||||
|
||||
r = dict(sorted(r.items(), key=lambda item: item[1]["count"], reverse=True))
|
||||
return r
|
||||
|
||||
|
||||
def make_github_table(reduced_by_error):
|
||||
header = "| no. | error | status |"
|
||||
sep = "|-:|:-|:-|"
|
||||
lines = [header, sep]
|
||||
for error in reduced_by_error:
|
||||
count = reduced_by_error[error]["count"]
|
||||
line = f"| {count} | {error[:100]} | |"
|
||||
lines.append(line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def make_github_table_per_model(reduced_by_model):
|
||||
header = "| model | no. of errors | major error | count |"
|
||||
sep = "|-:|-:|-:|-:|"
|
||||
lines = [header, sep]
|
||||
for model in reduced_by_model:
|
||||
count = reduced_by_model[model]["count"]
|
||||
error, _count = list(reduced_by_model[model]["errors"].items())[0]
|
||||
line = f"| {model} | {count} | {error[:60]} | {_count} |"
|
||||
lines.append(line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the downloaded artifacts and other result files.",
|
||||
)
|
||||
parser.add_argument("--token", default=None, type=str, help="A token that has actions:read permission.")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
_job_links = get_job_links(args.workflow_run_id, token=args.token)
|
||||
job_links = {}
|
||||
# To deal with `workflow_call` event, where a job name is the combination of the job names in the caller and callee.
|
||||
# For example, `PyTorch 1.11 / Model tests (models/albert, single-gpu)`.
|
||||
if _job_links:
|
||||
for k, v in _job_links.items():
|
||||
# This is how GitHub actions combine job names.
|
||||
if " / " in k:
|
||||
index = k.find(" / ")
|
||||
k = k[index + len(" / ") :]
|
||||
job_links[k] = v
|
||||
with open(os.path.join(args.output_dir, "job_links.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(job_links, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
artifacts = get_artifacts_links(args.workflow_run_id, token=args.token)
|
||||
with open(os.path.join(args.output_dir, "artifacts.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(artifacts, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
for idx, (name, url) in enumerate(artifacts.items()):
|
||||
download_artifact(name, url, args.output_dir, args.token)
|
||||
# Be gentle to GitHub
|
||||
time.sleep(1)
|
||||
|
||||
errors = get_all_errors(args.output_dir, job_links=job_links)
|
||||
|
||||
# `e[1]` is the error
|
||||
counter = Counter()
|
||||
counter.update([e[1] for e in errors])
|
||||
|
||||
# print the top 30 most common test errors
|
||||
most_common = counter.most_common(30)
|
||||
for item in most_common:
|
||||
print(item)
|
||||
|
||||
with open(os.path.join(args.output_dir, "errors.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(errors, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
reduced_by_error = reduce_by_error(errors)
|
||||
reduced_by_model = reduce_by_model(errors)
|
||||
|
||||
s1 = make_github_table(reduced_by_error)
|
||||
s2 = make_github_table_per_model(reduced_by_model)
|
||||
|
||||
with open(os.path.join(args.output_dir, "reduced_by_error.txt"), "w", encoding="UTF-8") as fp:
|
||||
fp.write(s1)
|
||||
with open(os.path.join(args.output_dir, "reduced_by_model.txt"), "w", encoding="UTF-8") as fp:
|
||||
fp.write(s2)
|
||||
71
utils/get_github_job_time.py
Normal file
71
utils/get_github_job_time.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import argparse
|
||||
import math
|
||||
import traceback
|
||||
|
||||
import dateutil.parser as date_parser
|
||||
import requests
|
||||
|
||||
|
||||
def extract_time_from_single_job(job):
|
||||
"""Extract time info from a single job in a GitHub Actions workflow run"""
|
||||
|
||||
job_info = {}
|
||||
|
||||
start = job["started_at"]
|
||||
end = job["completed_at"]
|
||||
|
||||
start_datetime = date_parser.parse(start)
|
||||
end_datetime = date_parser.parse(end)
|
||||
|
||||
duration_in_min = round((end_datetime - start_datetime).total_seconds() / 60.0)
|
||||
|
||||
job_info["started_at"] = start
|
||||
job_info["completed_at"] = end
|
||||
job_info["duration"] = duration_in_min
|
||||
|
||||
return job_info
|
||||
|
||||
|
||||
def get_job_time(workflow_run_id, token=None):
|
||||
"""Extract time info for all jobs in a GitHub Actions workflow run"""
|
||||
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
|
||||
result = requests.get(url, headers=headers).json()
|
||||
job_time = {}
|
||||
|
||||
try:
|
||||
job_time.update({job["name"]: extract_time_from_single_job(job) for job in result["jobs"]})
|
||||
pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100)
|
||||
|
||||
for i in range(pages_to_iterate_over):
|
||||
result = requests.get(url + f"&page={i + 2}", headers=headers).json()
|
||||
job_time.update({job["name"]: extract_time_from_single_job(job) for job in result["jobs"]})
|
||||
|
||||
return job_time
|
||||
except Exception:
|
||||
print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r"""
|
||||
Example:
|
||||
|
||||
python get_github_job_time.py --workflow_run_id 2945609517
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
|
||||
args = parser.parse_args()
|
||||
|
||||
job_time = get_job_time(args.workflow_run_id)
|
||||
job_time = dict(sorted(job_time.items(), key=lambda item: item[1]["duration"], reverse=True))
|
||||
|
||||
for k, v in job_time.items():
|
||||
print(f"{k}: {v['duration']}")
|
||||
35
utils/get_modified_files.py
Normal file
35
utils/get_modified_files.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.:
|
||||
# python ./utils/get_modified_files.py utils src tests examples
|
||||
#
|
||||
# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered
|
||||
# since the output of this script is fed into Makefile commands it doesn't print a newline after the results
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
|
||||
modified_files = (
|
||||
subprocess.check_output(f"git diff --diff-filter=d --name-only {fork_point_sha}".split()).decode("utf-8").split()
|
||||
)
|
||||
|
||||
joined_dirs = "|".join(sys.argv[1:])
|
||||
regex = re.compile(rf"^({joined_dirs}).*?\.py$")
|
||||
|
||||
relevant_modified_files = [x for x in modified_files if regex.match(x)]
|
||||
print(" ".join(relevant_modified_files), end="")
|
||||
133
utils/get_pr_run_slow_jobs.py
Normal file
133
utils/get_pr_run_slow_jobs.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import string
|
||||
|
||||
|
||||
MAX_NUM_JOBS_TO_SUGGEST = 16
|
||||
|
||||
|
||||
def get_jobs_to_run():
|
||||
# The file `pr_files.txt` contains the information about the files changed in a pull request, and it is prepared by
|
||||
# the caller (using GitHub api).
|
||||
# We can also use the following api to get the information if we don't have them before calling this script.
|
||||
# url = f"https://api.github.com/repos/huggingface/transformers/pulls/PULL_NUMBER/files?ref={pr_sha}"
|
||||
with open("pr_files.txt") as fp:
|
||||
pr_files = json.load(fp)
|
||||
pr_files = [{k: v for k, v in item.items() if k in ["filename", "status"]} for item in pr_files]
|
||||
pr_files = [item["filename"] for item in pr_files if item["status"] in ["added", "modified"]]
|
||||
|
||||
# models or quantizers
|
||||
re_1 = re.compile(r"src/transformers/(models/.*)/modeling_.*\.py")
|
||||
re_2 = re.compile(r"src/transformers/(quantizers/quantizer_.*)\.py")
|
||||
|
||||
# tests for models or quantizers
|
||||
re_3 = re.compile(r"tests/(models/.*)/test_.*\.py")
|
||||
re_4 = re.compile(r"tests/(quantization/.*)/test_.*\.py")
|
||||
|
||||
# files in a model directory but not necessary a modeling file
|
||||
re_5 = re.compile(r"src/transformers/(models/.*)/.*\.py")
|
||||
|
||||
regexes = [re_1, re_2, re_3, re_4, re_5]
|
||||
|
||||
jobs_to_run = []
|
||||
for pr_file in pr_files:
|
||||
for regex in regexes:
|
||||
matched = regex.findall(pr_file)
|
||||
if len(matched) > 0:
|
||||
item = matched[0]
|
||||
item = item.replace("quantizers/quantizer_", "quantization/")
|
||||
# TODO: for files in `quantizers`, the processed item above may not exist. Try using a fuzzy matching
|
||||
if item in repo_content:
|
||||
jobs_to_run.append(item)
|
||||
break
|
||||
jobs_to_run = sorted(set(jobs_to_run))
|
||||
|
||||
return jobs_to_run
|
||||
|
||||
|
||||
def parse_message(message: str) -> str:
|
||||
"""
|
||||
Parses a GitHub pull request's comment to find the models specified in it to run slow CI.
|
||||
|
||||
Args:
|
||||
message (`str`): The body of a GitHub pull request's comment.
|
||||
|
||||
Returns:
|
||||
`str`: The substring in `message` after `run-slow`, run_slow` or run slow`. If no such prefix is found, the
|
||||
empty string is returned.
|
||||
"""
|
||||
if message is None:
|
||||
return ""
|
||||
|
||||
message = message.strip().lower()
|
||||
|
||||
# run-slow: model_1, model_2, quantization_1, quantization_2
|
||||
if not message.startswith(("run-slow", "run_slow", "run slow")):
|
||||
return ""
|
||||
message = message[len("run slow") :]
|
||||
# remove leading `:`
|
||||
while message.strip().startswith(":"):
|
||||
message = message.strip()[1:]
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def get_jobs(message: str):
|
||||
models = parse_message(message)
|
||||
return models.replace(",", " ").split()
|
||||
|
||||
|
||||
def check_name(model_name: str):
|
||||
allowed = string.ascii_letters + string.digits + "_"
|
||||
return not (model_name.startswith("_") or model_name.endswith("_")) and all(c in allowed for c in model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--message", type=str, default="", help="The content of a comment.")
|
||||
parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The files are prepared by the caller (using GitHub api).
|
||||
# We can also use the following api to get the information if we don't have them before calling this script.
|
||||
# url = f"https://api.github.com/repos/OWNER/REPO/contents/PATH?ref={pr_sha}"
|
||||
# (we avoid to checkout the repository using `actions/checkout` to reduce the run time, but mostly to avoid the potential security issue as much as possible)
|
||||
repo_content = []
|
||||
for filename in ["tests_dir.txt", "tests_models_dir.txt", "tests_quantization_dir.txt"]:
|
||||
with open(filename) as fp:
|
||||
data = json.load(fp)
|
||||
data = [item["path"][len("tests/") :] for item in data if item["type"] == "dir"]
|
||||
repo_content.extend(data)
|
||||
|
||||
# These don't have the prefix `models/` or `quantization/`, so we need to add them.
|
||||
if args.message:
|
||||
specified_jobs = get_jobs(args.message)
|
||||
specified_jobs = [job for job in specified_jobs if check_name(job)]
|
||||
|
||||
# Add prefix (`models/` or `quantization`)
|
||||
jobs_to_run = []
|
||||
for job in specified_jobs:
|
||||
if not args.quantization:
|
||||
if f"models/{job}" in repo_content:
|
||||
jobs_to_run.append(f"models/{job}")
|
||||
elif job in repo_content and job != "quantization":
|
||||
jobs_to_run.append(job)
|
||||
elif f"quantization/{job}" in repo_content:
|
||||
jobs_to_run.append(f"quantization/{job}")
|
||||
|
||||
print(sorted(set(jobs_to_run)))
|
||||
|
||||
else:
|
||||
# Compute (from the added/modified files) the directories under `tests/`, `tests/models/` and `tests/quantization`to run tests.
|
||||
# These are already with the prefix `models/` or `quantization/`, so we don't need to add them.
|
||||
jobs_to_run = get_jobs_to_run()
|
||||
jobs_to_run = [x.replace("models/", "").replace("quantization/", "") for x in jobs_to_run]
|
||||
jobs_to_run = [job for job in jobs_to_run if check_name(job)]
|
||||
|
||||
if len(jobs_to_run) > MAX_NUM_JOBS_TO_SUGGEST:
|
||||
jobs_to_run = jobs_to_run[:MAX_NUM_JOBS_TO_SUGGEST]
|
||||
|
||||
suggestion = f"{', '.join(jobs_to_run)}"
|
||||
|
||||
print(suggestion)
|
||||
159
utils/get_previous_daily_ci.py
Normal file
159
utils/get_previous_daily_ci.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
from get_ci_error_statistics import download_artifact, get_artifacts_links
|
||||
|
||||
|
||||
def get_daily_ci_runs(token, num_runs=7, workflow_id=None):
|
||||
"""Get the workflow runs of the scheduled (daily) CI.
|
||||
|
||||
This only selects the runs triggered by the `schedule` event on the `main` branch.
|
||||
"""
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
# The id of a workflow (not of a workflow run).
|
||||
# From a given workflow run (where we have workflow run id), we can get the workflow id by going to
|
||||
# https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}
|
||||
# and check the `workflow_id` key.
|
||||
|
||||
if not workflow_id:
|
||||
workflow_run_id = os.environ["GITHUB_RUN_ID"]
|
||||
workflow_run = requests.get(
|
||||
f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}", headers=headers
|
||||
).json()
|
||||
workflow_id = workflow_run["workflow_id"]
|
||||
|
||||
url = f"https://api.github.com/repos/huggingface/transformers/actions/workflows/{workflow_id}/runs"
|
||||
# On `main` branch + event being `schedule` + not returning PRs + only `num_runs` results
|
||||
url += f"?branch=main&exclude_pull_requests=true&per_page={num_runs}"
|
||||
|
||||
result = requests.get(f"{url}&event=schedule", headers=headers).json()
|
||||
workflow_runs = result["workflow_runs"]
|
||||
if len(workflow_runs) == 0:
|
||||
result = requests.get(f"{url}&event=workflow_run", headers=headers).json()
|
||||
workflow_runs = result["workflow_runs"]
|
||||
|
||||
return workflow_runs
|
||||
|
||||
|
||||
def get_last_daily_ci_run(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
|
||||
"""Get the last completed workflow run id of the scheduled (daily) CI."""
|
||||
headers = None
|
||||
if token is not None:
|
||||
headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}
|
||||
|
||||
workflow_run = None
|
||||
if workflow_run_id is not None and workflow_run_id != "":
|
||||
workflow_run = requests.get(
|
||||
f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}", headers=headers
|
||||
).json()
|
||||
return workflow_run
|
||||
|
||||
workflow_runs = get_daily_ci_runs(token, workflow_id=workflow_id)
|
||||
for run in workflow_runs:
|
||||
if commit_sha in [None, ""] and run["status"] == "completed":
|
||||
workflow_run = run
|
||||
break
|
||||
# if `commit_sha` is specified, return the latest completed run with `workflow_run["head_sha"]` matching the specified sha.
|
||||
elif commit_sha not in [None, ""] and run["head_sha"] == commit_sha and run["status"] == "completed":
|
||||
workflow_run = run
|
||||
break
|
||||
|
||||
return workflow_run
|
||||
|
||||
|
||||
def get_last_daily_ci_workflow_run_id(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
|
||||
"""Get the last completed workflow run id of the scheduled (daily) CI."""
|
||||
if workflow_run_id is not None and workflow_run_id != "":
|
||||
return workflow_run_id
|
||||
|
||||
workflow_run = get_last_daily_ci_run(token, workflow_id=workflow_id, commit_sha=commit_sha)
|
||||
workflow_run_id = None
|
||||
if workflow_run is not None:
|
||||
workflow_run_id = workflow_run["id"]
|
||||
|
||||
return workflow_run_id
|
||||
|
||||
|
||||
def get_last_daily_ci_run_commit(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
|
||||
"""Get the commit sha of the last completed scheduled daily CI workflow run."""
|
||||
workflow_run = get_last_daily_ci_run(
|
||||
token, workflow_run_id=workflow_run_id, workflow_id=workflow_id, commit_sha=commit_sha
|
||||
)
|
||||
workflow_run_head_sha = None
|
||||
if workflow_run is not None:
|
||||
workflow_run_head_sha = workflow_run["head_sha"]
|
||||
|
||||
return workflow_run_head_sha
|
||||
|
||||
|
||||
def get_last_daily_ci_artifacts(
|
||||
output_dir,
|
||||
token,
|
||||
workflow_run_id=None,
|
||||
workflow_id=None,
|
||||
commit_sha=None,
|
||||
artifact_names=None,
|
||||
):
|
||||
"""Get the artifacts of last completed workflow run id of the scheduled (daily) CI."""
|
||||
workflow_run_id = get_last_daily_ci_workflow_run_id(
|
||||
token, workflow_run_id=workflow_run_id, workflow_id=workflow_id, commit_sha=commit_sha
|
||||
)
|
||||
if workflow_run_id is not None:
|
||||
artifacts_links = get_artifacts_links(workflow_run_id=workflow_run_id, token=token)
|
||||
|
||||
if artifact_names is None:
|
||||
artifact_names = artifacts_links.keys()
|
||||
|
||||
downloaded_artifact_names = []
|
||||
for artifact_name in artifact_names:
|
||||
if artifact_name in artifacts_links:
|
||||
artifact_url = artifacts_links[artifact_name]
|
||||
download_artifact(
|
||||
artifact_name=artifact_name, artifact_url=artifact_url, output_dir=output_dir, token=token
|
||||
)
|
||||
downloaded_artifact_names.append(artifact_name)
|
||||
|
||||
return downloaded_artifact_names
|
||||
|
||||
|
||||
def get_last_daily_ci_reports(
|
||||
output_dir,
|
||||
token,
|
||||
workflow_run_id=None,
|
||||
workflow_id=None,
|
||||
commit_sha=None,
|
||||
artifact_names=None,
|
||||
):
|
||||
"""Get the artifacts' content of the last completed workflow run id of the scheduled (daily) CI."""
|
||||
downloaded_artifact_names = get_last_daily_ci_artifacts(
|
||||
output_dir,
|
||||
token,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_id=workflow_id,
|
||||
commit_sha=commit_sha,
|
||||
artifact_names=artifact_names,
|
||||
)
|
||||
|
||||
results = {}
|
||||
for artifact_name in downloaded_artifact_names:
|
||||
artifact_zip_path = os.path.join(output_dir, f"{artifact_name}.zip")
|
||||
if os.path.isfile(artifact_zip_path):
|
||||
target_dir = os.path.join(output_dir, artifact_name)
|
||||
with zipfile.ZipFile(artifact_zip_path) as z:
|
||||
z.extractall(target_dir)
|
||||
|
||||
results[artifact_name] = {}
|
||||
filename = os.listdir(target_dir)
|
||||
for filename in filename:
|
||||
file_path = os.path.join(target_dir, filename)
|
||||
if not os.path.isdir(file_path):
|
||||
# read the file
|
||||
with open(file_path) as fp:
|
||||
content = fp.read()
|
||||
results[artifact_name][filename] = content
|
||||
|
||||
return results
|
||||
217
utils/get_test_info.py
Normal file
217
utils/get_test_info.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
# This is required to make the module import works (when the python process is running from the root of the repo)
|
||||
sys.path.append(".")
|
||||
|
||||
|
||||
r"""
|
||||
The argument `test_file` in this file refers to a model test file. This should be a string of the from
|
||||
`tests/models/*/test_modeling_*.py`.
|
||||
"""
|
||||
|
||||
|
||||
def get_module_path(test_file):
|
||||
"""Return the module path of a model test file."""
|
||||
components = test_file.split(os.path.sep)
|
||||
if components[0:2] != ["tests", "models"]:
|
||||
raise ValueError(
|
||||
"`test_file` should start with `tests/models/` (with `/` being the OS specific path separator). Got "
|
||||
f"{test_file} instead."
|
||||
)
|
||||
test_fn = components[-1]
|
||||
if not test_fn.endswith("py"):
|
||||
raise ValueError(f"`test_file` should be a python file. Got {test_fn} instead.")
|
||||
if not test_fn.startswith("test_modeling_"):
|
||||
raise ValueError(
|
||||
f"`test_file` should point to a file name of the form `test_modeling_*.py`. Got {test_fn} instead."
|
||||
)
|
||||
|
||||
components = components[:-1] + [test_fn.replace(".py", "")]
|
||||
test_module_path = ".".join(components)
|
||||
|
||||
return test_module_path
|
||||
|
||||
|
||||
def get_test_module(test_file):
|
||||
"""Get the module of a model test file."""
|
||||
test_module_path = get_module_path(test_file)
|
||||
try:
|
||||
test_module = importlib.import_module(test_module_path)
|
||||
except AttributeError as exc:
|
||||
# e.g. if you have a `tests` folder in `site-packages`, created by another package, when trying to import
|
||||
# `tests.models...`
|
||||
raise ValueError(
|
||||
f"Could not import module {test_module_path}. Confirm that you don't have a package with the same root "
|
||||
"name installed or in your environment's `site-packages`."
|
||||
) from exc
|
||||
|
||||
return test_module
|
||||
|
||||
|
||||
def get_tester_classes(test_file):
|
||||
"""Get all classes in a model test file whose names ends with `ModelTester`."""
|
||||
tester_classes = []
|
||||
test_module = get_test_module(test_file)
|
||||
for attr in dir(test_module):
|
||||
if attr.endswith("ModelTester"):
|
||||
tester_classes.append(getattr(test_module, attr))
|
||||
|
||||
# sort with class names
|
||||
return sorted(tester_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_test_classes(test_file):
|
||||
"""Get all [test] classes in a model test file with attribute `all_model_classes` that are non-empty.
|
||||
|
||||
These are usually the (model) test classes containing the (non-slow) tests to run and are subclasses of
|
||||
`ModelTesterMixin`, as well as a subclass of `unittest.TestCase`. Exceptions include `RagTestMixin` (and its subclasses).
|
||||
"""
|
||||
test_classes = []
|
||||
test_module = get_test_module(test_file)
|
||||
for attr in dir(test_module):
|
||||
attr_value = getattr(test_module, attr)
|
||||
|
||||
# Look for the test classes (subclass of `unittest.TestCase`) with `all_model_classes` attribute.
|
||||
# This also excludes `ModelTesterMixin` and `CausalLMModelTest`.
|
||||
if isinstance(attr_value, type) and issubclass(attr_value, unittest.TestCase):
|
||||
model_classes = getattr(attr_value, "all_model_classes", [])
|
||||
# `CausalLMModelTest` (subclass of `ModelTesterMixin`) has `all_model_classes` as a class attribute with
|
||||
# the value being `None`. For a real test class of `CausalLMModelTest`, the value is only set during `setUp`.
|
||||
if model_classes is None:
|
||||
test_instance = attr_value()
|
||||
test_instance.setUp()
|
||||
model_classes = getattr(test_instance, "all_model_classes", [])
|
||||
if len(model_classes) > 0:
|
||||
test_classes.append(attr_value)
|
||||
|
||||
# sort with class names
|
||||
return sorted(test_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_model_classes(test_file):
|
||||
"""Get all model classes that appear in `all_model_classes` attributes in a model test file."""
|
||||
test_classes = get_test_classes(test_file)
|
||||
model_classes = set()
|
||||
for test_class in test_classes:
|
||||
all_model_classes = test_class.all_model_classes
|
||||
if all_model_classes is None:
|
||||
test_instance = test_class()
|
||||
test_instance.setUp()
|
||||
all_model_classes = test_instance.all_model_classes
|
||||
model_classes.update(all_model_classes)
|
||||
|
||||
# sort with class names
|
||||
return sorted(model_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_model_tester_from_test_class(test_class):
|
||||
"""Get the model tester class of a model test class."""
|
||||
test = test_class()
|
||||
if hasattr(test, "setUp"):
|
||||
test.setUp()
|
||||
|
||||
model_tester = None
|
||||
if hasattr(test, "model_tester"):
|
||||
# `ModelTesterMixin` has this attribute default to `None`. Let's skip this case.
|
||||
if test.model_tester is not None:
|
||||
model_tester = test.model_tester.__class__
|
||||
|
||||
return model_tester
|
||||
|
||||
|
||||
def get_test_classes_for_model(test_file, model_class):
|
||||
"""Get all [test] classes in `test_file` that have `model_class` in their `all_model_classes`."""
|
||||
test_classes = get_test_classes(test_file)
|
||||
|
||||
target_test_classes = []
|
||||
|
||||
for test_class in test_classes:
|
||||
all_model_classes = test_class.all_model_classes
|
||||
if all_model_classes is None:
|
||||
test_instance = test_class()
|
||||
test_instance.setUp()
|
||||
all_model_classes = test_instance.all_model_classes
|
||||
|
||||
if model_class in all_model_classes:
|
||||
target_test_classes.append(test_class)
|
||||
|
||||
# sort with class names
|
||||
return sorted(target_test_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_tester_classes_for_model(test_file, model_class):
|
||||
"""Get all model tester classes in `test_file` that are associated to `model_class`."""
|
||||
test_classes = get_test_classes_for_model(test_file, model_class)
|
||||
|
||||
tester_classes = []
|
||||
for test_class in test_classes:
|
||||
tester_class = get_model_tester_from_test_class(test_class)
|
||||
if tester_class is not None:
|
||||
tester_classes.append(tester_class)
|
||||
|
||||
# sort with class names
|
||||
return sorted(tester_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_test_to_tester_mapping(test_file):
|
||||
"""Get a mapping from [test] classes to model tester classes in `test_file`.
|
||||
|
||||
This uses `get_test_classes` which may return classes that are NOT subclasses of `unittest.TestCase`.
|
||||
"""
|
||||
test_classes = get_test_classes(test_file)
|
||||
test_tester_mapping = {test_class: get_model_tester_from_test_class(test_class) for test_class in test_classes}
|
||||
return test_tester_mapping
|
||||
|
||||
|
||||
def get_model_to_test_mapping(test_file):
|
||||
"""Get a mapping from model classes to test classes in `test_file`."""
|
||||
model_classes = get_model_classes(test_file)
|
||||
model_test_mapping = {
|
||||
model_class: get_test_classes_for_model(test_file, model_class) for model_class in model_classes
|
||||
}
|
||||
return model_test_mapping
|
||||
|
||||
|
||||
def get_model_to_tester_mapping(test_file):
|
||||
"""Get a mapping from model classes to model tester classes in `test_file`."""
|
||||
model_classes = get_model_classes(test_file)
|
||||
model_to_tester_mapping = {
|
||||
model_class: get_tester_classes_for_model(test_file, model_class) for model_class in model_classes
|
||||
}
|
||||
return model_to_tester_mapping
|
||||
|
||||
|
||||
def to_json(o):
|
||||
"""Make the information succinct and easy to read.
|
||||
|
||||
Avoid the full class representation like `<class 'transformers.models.bert.modeling_bert.BertForMaskedLM'>` when
|
||||
displaying the results. Instead, we use class name (`BertForMaskedLM`) for the readability.
|
||||
"""
|
||||
if isinstance(o, str):
|
||||
return o
|
||||
elif isinstance(o, type):
|
||||
return o.__name__
|
||||
elif isinstance(o, (list, tuple)):
|
||||
return [to_json(x) for x in o]
|
||||
elif isinstance(o, dict):
|
||||
return {to_json(k): to_json(v) for k, v in o.items()}
|
||||
else:
|
||||
return o
|
||||
270
utils/get_test_reports.py
Normal file
270
utils/get_test_reports.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This util provides a way to manually run the tests of the transformers repo as they would be run by the CI.
|
||||
It was mainly used for models tests, so if you find features missing for another suite, do not hesitate to open a PR.
|
||||
|
||||
Functionnalities:
|
||||
- Running specific test suite (models, tokenizers, etc.)
|
||||
- Parallel execution across multiple processes (each has to be launched separately with different `--processes` argument)
|
||||
- GPU/CPU test filtering and slow tests filter
|
||||
- Temporary cache management for isolated test runs
|
||||
- Resume functionality for interrupted test runs
|
||||
- Important models subset testing
|
||||
|
||||
Example usages are below.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from .important_files import IMPORTANT_MODELS
|
||||
|
||||
|
||||
def is_valid_test_dir(path: Path) -> bool:
|
||||
"""Check if a given path represents a valid test dir: the path must point to a dir, not start with '__' or '.'"""
|
||||
return path.is_dir() and not path.name.startswith("__") and not path.name.startswith(".")
|
||||
|
||||
|
||||
def run_pytest(
|
||||
suite: str, subdir: Path, root_test_dir: Path, machine_type: str, dry_run: bool, tmp_cache: str, cpu_tests: bool
|
||||
) -> None:
|
||||
"""
|
||||
Execute pytest on a specific test directory with configured options:
|
||||
- suite (str): name of the test suite being run (e.g., 'models', 'tokenizers')
|
||||
- subdir (Path): the specific directory containing tests to run
|
||||
- root_test_dir (Path): the root directory of all tests, used for relative paths
|
||||
- machine_type (str): type of machine/environment (e.g., 'cpu', 'single-gpu', 'multi-gpu')
|
||||
- dry_run (bool): if True, only print the command without executing it
|
||||
- tmp_cache (str): prefix for temporary cache directory. If empty, no temp cache is used
|
||||
- cpu_tests (bool): if True, include CPU-only tests; if False, exclude non-device tests
|
||||
"""
|
||||
relative_path = subdir.relative_to(root_test_dir)
|
||||
report_name = f"{machine_type}_{suite}_{relative_path}_test_reports"
|
||||
print(f"Suite: {suite} | Running on: {relative_path}")
|
||||
|
||||
cmd = ["python3", "-m", "pytest", "-rsfE", "-v", f"--make-reports={report_name}", str(subdir)]
|
||||
if not cpu_tests:
|
||||
cmd = cmd + ["-m", "not not_device_test"]
|
||||
|
||||
ctx_manager = tempfile.TemporaryDirectory(prefix=tmp_cache) if tmp_cache else contextlib.nullcontext()
|
||||
with ctx_manager as tmp_dir:
|
||||
env = os.environ.copy()
|
||||
if tmp_cache:
|
||||
env["HUGGINGFACE_HUB_CACHE"] = tmp_dir
|
||||
|
||||
print(f"Using temporary cache located at {tmp_dir = }")
|
||||
|
||||
print("Command:", " ".join(cmd))
|
||||
if not dry_run:
|
||||
subprocess.run(cmd, check=False, env=env)
|
||||
|
||||
|
||||
def handle_suite(
|
||||
suite: str,
|
||||
test_root: Path,
|
||||
machine_type: str,
|
||||
dry_run: bool,
|
||||
tmp_cache: str = "",
|
||||
resume_at: str | None = None,
|
||||
only_in: list[str] | None = None,
|
||||
cpu_tests: bool = False,
|
||||
process_id: int = 1,
|
||||
total_processes: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Handle execution of a complete test suite with advanced filtering and process distribution.
|
||||
Args:
|
||||
- suite (str): Name of the test suite to run (corresponds to a directory under test_root).
|
||||
- test_root (Path): Root directory containing all test suites.
|
||||
- machine_type (str): Machine/environment type for report naming and identification.
|
||||
- dry_run (bool): If True, only print commands without executing them.
|
||||
- tmp_cache (str, optional): Prefix for temporary cache directories. If empty, no temp cache is used.
|
||||
- resume_at (str, optional): Resume execution starting from this subdirectory name.
|
||||
Useful for restarting interrupted test runs. Defaults to None (run from the beginning).
|
||||
- only_in (list[str], optional): Only run tests in these specific subdirectories.
|
||||
Can include special values like IMPORTANT_MODELS. Defaults to None (run all tests).
|
||||
- cpu_tests (bool, optional): Whether to include CPU-only tests. Defaults to False.
|
||||
- process_id (int, optional): Current process ID for parallel execution (1-indexed). Defaults to 1.
|
||||
- total_processes (int, optional): Total number of parallel processes. Defaults to 1.
|
||||
"""
|
||||
# Check path to suite
|
||||
full_path = test_root / suite
|
||||
if not full_path.exists():
|
||||
print(f"Test folder does not exist: {full_path}")
|
||||
return
|
||||
|
||||
# Establish the list of subdir to go through
|
||||
subdirs = sorted(full_path.iterdir())
|
||||
subdirs = [s for s in subdirs if is_valid_test_dir(s)]
|
||||
if resume_at is not None:
|
||||
subdirs = [s for s in subdirs if s.name >= resume_at]
|
||||
if only_in is not None:
|
||||
subdirs = [s for s in subdirs if s.name in only_in]
|
||||
if subdirs and total_processes > 1:
|
||||
# This interleaves the subdirs / files. For instance for subdirs = [A, B, C, D, E] and 2 processes:
|
||||
# - script launcehd with `--processes 0 2` will run A, C, E
|
||||
# - script launcehd with `--processes 1 2` will run B, D
|
||||
subdirs = subdirs[process_id::total_processes]
|
||||
|
||||
# If the subdir list is not empty, go through each
|
||||
if subdirs:
|
||||
for subdir in subdirs:
|
||||
run_pytest(suite, subdir, test_root, machine_type, dry_run, tmp_cache, cpu_tests)
|
||||
# Otherwise, launch pytest from the full path
|
||||
else:
|
||||
run_pytest(suite, full_path, test_root, machine_type, dry_run, tmp_cache, cpu_tests)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Command-line interface for running test suite with comprehensive reporting. Check handle_suite for more details.
|
||||
|
||||
Command-line Arguments:
|
||||
folder: Path to the root test directory (required)
|
||||
--suite: Test suite name to run (default: "models")
|
||||
--cpu-tests: Include CPU-only tests in addition to device tests
|
||||
--run-slow: Execute slow tests instead of skipping them
|
||||
--resume-at: Resume execution from a specific subdirectory
|
||||
--only-in: Run tests only in specified subdirectories (supports IMPORTANT_MODELS)
|
||||
--processes: Process distribution as "process_id total_processes"
|
||||
--dry-run: Print commands without executing them
|
||||
--tmp-cache: Use temporary cache directories for isolated runs
|
||||
--machine-type: Override automatic machine type detection
|
||||
|
||||
Machine Type Detection:
|
||||
- 'cpu': No CUDA available
|
||||
- 'single-gpu': CUDA available with 1 GPU
|
||||
- 'multi-gpu': CUDA available with multiple GPUs
|
||||
|
||||
Process Distribution:
|
||||
Use --processes to split work across multiple parallel processes:
|
||||
--processes 0 4 # This is process 0 of 4 total processes
|
||||
--processes 1 4 # This is process 1 of 4 total processes
|
||||
...
|
||||
|
||||
Usage Examples:
|
||||
# Basic model testing
|
||||
python3 -m utils.get_test_reports tests/ --suite models
|
||||
|
||||
# Run slow tests for important models only
|
||||
python3 -m utils.get_test_reports tests/ --suite models --run-slow --only-in IMPORTANT_MODELS
|
||||
|
||||
# Parallel execution across 4 processes, second process to launch (processes are 0-indexed)
|
||||
python3 -m utils.get_test_reports tests/ --suite models --processes 1 4
|
||||
|
||||
# Resume interrupted run from 'bert' subdirectory with a tmp cache
|
||||
python3 -m utils.get_test_reports tests/ --suite models --resume-at bert --tmp-cache /tmp/
|
||||
|
||||
# Run specific models with CPU tests
|
||||
python3 -m utils.get_test_reports tests/ --suite models --only-in bert gpt2 --cpu-tests
|
||||
|
||||
# Run slow tests for only important models with a tmp cache
|
||||
python3 -m utils.get_test_reports tests/ --suite models --run-slow --only-in IMPORTANT_MODELS --tmp-cache /tmp/
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("folder", help="Path to test root folder (e.g., ./tests)")
|
||||
|
||||
# Choose which tests to run (broad picture)
|
||||
parser.add_argument("--suite", type=str, default="models", help="Test suit to run")
|
||||
parser.add_argument("--cpu-tests", action="store_true", help="Also runs non-device tests")
|
||||
parser.add_argument("--run-slow", action="store_true", help="Run slow tests instead of skipping them")
|
||||
parser.add_argument("--collect-outputs", action="store_true", help="Collect outputs of the tests")
|
||||
|
||||
# Fine-grain control over the tests to run
|
||||
parser.add_argument("--resume-at", type=str, default=None, help="Resume at a specific subdir / file in the suite")
|
||||
parser.add_argument(
|
||||
"--only-in",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Only run tests in the given subdirs / file. Use IMPORTANT_MODELS to run only the important models tests.",
|
||||
)
|
||||
|
||||
# How to run the test suite: is the work divided among processes, do a try run, use temp cache?
|
||||
parser.add_argument(
|
||||
"--processes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="Inform each CI process as to the work to do: format as `process_id total_processes`. "
|
||||
"In order to run with multiple (eg. 3) processes, you need to run the script multiple times (eg. 3 times).",
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="Only print commands without running them")
|
||||
parser.add_argument("--tmp-cache", type=str, help="Change HUGGINGFACE_HUB_CACHE to a tmp dir for each test")
|
||||
|
||||
# This is a purely decorative argument, but it can be useful to distinguish between runs
|
||||
parser.add_argument(
|
||||
"--machine-type", type=str, default="", help="Machine type, automatically inferred if not provided"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle run slow
|
||||
if args.run_slow:
|
||||
os.environ["RUN_SLOW"] = "yes"
|
||||
print("[WARNING] Running slow tests.")
|
||||
else:
|
||||
print("[WARNING] Skipping slow tests.")
|
||||
|
||||
# Handle multiple CI processes
|
||||
if args.processes is None:
|
||||
process_id, total_processes = 1, 1
|
||||
elif len(args.processes) == 2:
|
||||
process_id, total_processes = args.processes
|
||||
else:
|
||||
raise ValueError(f"Invalid processes argument: {args.processes}")
|
||||
|
||||
# Assert test root exists
|
||||
test_root = Path(args.folder).resolve()
|
||||
if not test_root.exists():
|
||||
print(f"Root test folder not found: {test_root}")
|
||||
exit(1)
|
||||
|
||||
# Handle collection of outputs
|
||||
if args.collect_outputs:
|
||||
os.environ["PATCH_TESTING_METHODS_TO_COLLECT_OUTPUTS"] = "yes"
|
||||
reports_dir = test_root.parent / "reports"
|
||||
os.environ["_PATCHED_TESTING_METHODS_OUTPUT_DIR"] = str(reports_dir)
|
||||
|
||||
# Infer machine type if not provided
|
||||
if args.machine_type == "":
|
||||
if not torch.cuda.is_available():
|
||||
machine_type = "cpu"
|
||||
else:
|
||||
machine_type = "multi-gpu" if torch.cuda.device_count() > 1 else "single-gpu"
|
||||
else:
|
||||
machine_type = args.machine_type
|
||||
|
||||
# Reduce the scope for models if necessary
|
||||
only_in = args.only_in if args.only_in else None
|
||||
if only_in == ["IMPORTANT_MODELS"]:
|
||||
only_in = IMPORTANT_MODELS
|
||||
|
||||
# Launch suite
|
||||
handle_suite(
|
||||
suite=args.suite,
|
||||
test_root=test_root,
|
||||
machine_type=machine_type,
|
||||
dry_run=args.dry_run,
|
||||
tmp_cache=args.tmp_cache,
|
||||
resume_at=args.resume_at,
|
||||
only_in=only_in,
|
||||
cpu_tests=args.cpu_tests,
|
||||
process_id=process_id,
|
||||
total_processes=total_processes,
|
||||
)
|
||||
30
utils/important_files.py
Normal file
30
utils/important_files.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# List here the models to always test.
|
||||
IMPORTANT_MODELS = [
|
||||
"auto",
|
||||
"bert",
|
||||
"gpt2",
|
||||
"t5",
|
||||
"modernbert",
|
||||
"vit",
|
||||
"clip",
|
||||
"detr",
|
||||
"table_transformer",
|
||||
"got_ocr2",
|
||||
"whisper",
|
||||
"wav2vec2",
|
||||
"qwen2_audio",
|
||||
"speech_t5",
|
||||
"csm",
|
||||
"llama",
|
||||
"gemma3",
|
||||
"qwen2",
|
||||
"mistral3",
|
||||
"qwen2_5_vl",
|
||||
"llava",
|
||||
"smolvlm",
|
||||
"internvl",
|
||||
"gemma3n",
|
||||
"gpt_oss",
|
||||
"qwen2_5_omni",
|
||||
"pi0",
|
||||
]
|
||||
4
utils/important_models.txt
Normal file
4
utils/important_models.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
models/llama
|
||||
models/mistral
|
||||
models/mixtral
|
||||
models/gemma
|
||||
338
utils/models_to_deprecate.py
Normal file
338
utils/models_to_deprecate.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last
|
||||
commit.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from git import Repo
|
||||
from huggingface_hub import HfApi
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES, DEPRECATED_MODELS
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
||||
repo = Repo(PATH_TO_REPO)
|
||||
|
||||
|
||||
# Used when the folder name on the hub does not match the folder name in `transformers/models`
|
||||
# format = {folder name in `transformers/models`: expected tag on the hub}
|
||||
MODEL_FOLDER_NAME_TO_TAG_MAPPING = {
|
||||
"audio_spectrogram_transformer": "audio-spectrogram-transformer",
|
||||
"bert_generation": "bert-generation",
|
||||
"blenderbot_small": "blenderbot-small",
|
||||
"blip_2": "blip-2",
|
||||
"dab_detr": "dab-detr",
|
||||
"data2vec": "data2vec-audio", # actually, the base model is never used as a tag, but the sub models are
|
||||
"deberta_v2": "deberta-v2",
|
||||
"donut": "donut-swin",
|
||||
"encoder_decoder": "encoder-decoder",
|
||||
"grounding_dino": "grounding-dino",
|
||||
"kosmos2": "kosmos-2",
|
||||
"kosmos2_5": "kosmos-2.5",
|
||||
"megatron_bert": "megatron-bert",
|
||||
"mgp_str": "mgp-str",
|
||||
"mm_grounding_dino": "mm-grounding-dino",
|
||||
"modernbert_decoder": "modernbert-decoder",
|
||||
"nllb_moe": "nllb-moe",
|
||||
"omdet_turbo": "omdet-turbo",
|
||||
"openai": "openai-gpt",
|
||||
"roberta_prelayernorm": "roberta-prelayernorm",
|
||||
"sew_d": "sew-d",
|
||||
"speech_encoder_decoder": "speech-encoder-decoder",
|
||||
"table_transformer": "table-transformer",
|
||||
"unispeech_sat": "unispeech-sat",
|
||||
"vision_encoder_decoder": "vision-encoder-decoder",
|
||||
"vision_text_dual_encoder": "vision-text-dual-encoder",
|
||||
"wav2vec2_bert": "wav2vec2-bert",
|
||||
"wav2vec2_conformer": "wav2vec2-conformer",
|
||||
"x_clip": "xclip",
|
||||
"xlm_roberta": "xlm-roberta",
|
||||
"xlm_roberta_xl": "xlm-roberta-xl",
|
||||
}
|
||||
|
||||
# Used on model architectures with multiple tags on the hub (e.g. on VLMs, we often support a text-only model).
|
||||
# Applied after the model folder name mapping. format = {base model tag: [extra tags]}
|
||||
EXTRA_TAGS_MAPPING = {
|
||||
"aimv2": ["aimv2_vision_model"],
|
||||
"aria": ["aria_text"],
|
||||
"bart": ["barthez", "bartpho"],
|
||||
"bert": ["bert-japanese", "bertweet", "herbert", "phobert"],
|
||||
"beit": ["dit"],
|
||||
"blip-2": ["blip_2_qformer"],
|
||||
"chinese_clip": ["chinese_clip_vision_model"],
|
||||
"clip": ["clip_text_model", "clip_vision_model"],
|
||||
"data2vec-audio": ["data2vec-text", "data2vec-vision"],
|
||||
"depth_anything": ["depth_anything_v2"],
|
||||
"donut-swin": ["nougat"],
|
||||
"edgetam": ["edgetam_vision_model"],
|
||||
"fastspeech2_conformer": ["fastspeech2_conformer_with_hifigan"],
|
||||
"gemma3": ["gemma3_text"],
|
||||
"gemma3n": ["gemma3n_audio", "gemma3n_text", "gemma3n_vision"],
|
||||
"gpt2": ["cpm", "dialogpt", "gpt-sw3", "megatron_gpt2"],
|
||||
"glm4v_moe": ["glm4v_moe_text", "glm4v_moe_vision"],
|
||||
"glm4_image": ["glm4_image_text", "glm4_image_vision"],
|
||||
"glm4v": ["glm4v_text", "glm4v_vision"],
|
||||
"idefics3": ["idefics3_vision"],
|
||||
"internvl": ["internvl_vision"],
|
||||
"layoutlmv2": ["layoutxlm"],
|
||||
"llama": ["code_llama", "falcon3", "llama2", "llama3"],
|
||||
"llama4": ["llama4_text"],
|
||||
"llava_next": ["granitevision"],
|
||||
"luke": ["mluke"],
|
||||
"m2m_100": ["nllb"],
|
||||
"maskformer": ["maskformer-swin"],
|
||||
"mbart": ["mbart50"],
|
||||
"parakeet": ["parakeet_ctc", "parakeet_encoder"],
|
||||
"lasr": ["lasr_ctc", "lasr_encoder"],
|
||||
"perception_lm": ["perception_encoder"],
|
||||
"pix2struct": ["deplot", "matcha"],
|
||||
"qwen2_5_vl": ["qwen2_5_vl_text"],
|
||||
"qwen2_audio": ["qwen2_audio_encoder"],
|
||||
"qwen2_vl": ["qwen2_vl_text"],
|
||||
"qwen3_vl_moe": ["qwen3_vl_moe_text"],
|
||||
"qwen3_vl": ["qwen3_vl_text"],
|
||||
"qwen3_5": ["qwen3_5text"],
|
||||
"qwen3_5_moe": ["qwen3_5_moe_text"],
|
||||
"rt_detr": ["rt_detr_resnet"],
|
||||
"sam2": ["sam2_hiera_det_model", "sam2_vision_model"],
|
||||
"sam": ["sam_hq_vision_model", "sam_vision_model"],
|
||||
"siglip2": ["siglip2_vision_model"],
|
||||
"siglip": ["siglip_vision_model"],
|
||||
"smolvlm": ["smolvlm_vision"],
|
||||
"t5": ["byt5", "flan-t5", "flan-ul2", "madlad-400", "myt5", "t5v1.1", "ul2"],
|
||||
"voxtral": ["voxtral_encoder"],
|
||||
"wav2vec2": ["mms", "wav2vec2_phoneme", "xls_r", "xlsr_wav2vec2"],
|
||||
"xlm-roberta": ["xlm-v"],
|
||||
}
|
||||
|
||||
# Similar to `DEPRECATED_MODELS`, but containing the tags when the model tag does not match the model folder name :'(
|
||||
DEPRECATED_MODELS_TAGS = {"gptsan-japanese", "open-llama", "transfo-xl", "xlm-prophetnet"}
|
||||
|
||||
|
||||
class HubModelLister:
|
||||
"""
|
||||
Utility for getting models from the hub based on tags. Handles errors without crashing the script.
|
||||
"""
|
||||
|
||||
def __init__(self, tags):
|
||||
self.tags = tags
|
||||
self.model_list = api.list_models(filter=tags)
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
yield from self.model_list
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return
|
||||
|
||||
|
||||
def _extract_commit_hash(commits):
|
||||
for commit in commits:
|
||||
if commit.startswith("commit "):
|
||||
return commit.split(" ")[1]
|
||||
return ""
|
||||
|
||||
|
||||
def get_list_of_repo_model_paths(models_dir):
|
||||
# Get list of all models in the library
|
||||
models = glob.glob(os.path.join(models_dir, "*/modeling_*.py"))
|
||||
|
||||
# Get list of all deprecated models in the library
|
||||
deprecated_models = glob.glob(os.path.join(models_dir, "deprecated", "*"))
|
||||
# For each deprecated model, remove the deprecated models from the list of all models as well as the symlink path
|
||||
for deprecated_model in deprecated_models:
|
||||
deprecated_model_name = "/" + deprecated_model.split("/")[-1] + "/"
|
||||
models = [model for model in models if deprecated_model_name not in model]
|
||||
# Remove deprecated models
|
||||
models = [model for model in models if "/deprecated" not in model]
|
||||
# Remove auto
|
||||
models = [model for model in models if "/auto/" not in model]
|
||||
return models
|
||||
|
||||
|
||||
def get_list_of_models_to_deprecate(
|
||||
thresh_num_downloads=5_000,
|
||||
thresh_date=None,
|
||||
use_cache=False,
|
||||
save_model_info=False,
|
||||
max_num_models=-1,
|
||||
):
|
||||
if thresh_date is None:
|
||||
thresh_date = datetime.now(timezone.utc).replace(year=datetime.now(timezone.utc).year - 1)
|
||||
else:
|
||||
thresh_date = datetime.strptime(thresh_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||
|
||||
models_dir = PATH_TO_REPO / "src/transformers/models"
|
||||
model_paths = get_list_of_repo_model_paths(models_dir=models_dir)
|
||||
|
||||
if use_cache and os.path.exists("models_info.json"):
|
||||
with open("models_info.json", "r") as f:
|
||||
models_info = json.load(f)
|
||||
# Convert datetimes back to datetime objects
|
||||
for model, info in models_info.items():
|
||||
info["first_commit_datetime"] = datetime.fromisoformat(info["first_commit_datetime"])
|
||||
|
||||
else:
|
||||
print("Building a dictionary of basic model info...")
|
||||
models_info = defaultdict(dict)
|
||||
for i, model_path in enumerate(tqdm(sorted(model_paths))):
|
||||
if max_num_models != -1 and i > max_num_models:
|
||||
break
|
||||
model = model_path.split("/")[-2]
|
||||
if model in models_info:
|
||||
continue
|
||||
commits = repo.git.log("--diff-filter=A", "--", model_path).split("\n")
|
||||
commit_hash = _extract_commit_hash(commits)
|
||||
commit_obj = repo.commit(commit_hash)
|
||||
committed_datetime = commit_obj.committed_datetime
|
||||
models_info[model]["commit_hash"] = commit_hash
|
||||
models_info[model]["first_commit_datetime"] = committed_datetime
|
||||
models_info[model]["model_path"] = model_path
|
||||
models_info[model]["downloads"] = 0
|
||||
models_info[model]["tags"] = [model]
|
||||
|
||||
# The keys in the dictionary above are the model folder names. In some cases, the model tag on the hub does not
|
||||
# match the model folder name. We replace the key and append the expected tag.
|
||||
for folder_name, expected_tag in MODEL_FOLDER_NAME_TO_TAG_MAPPING.items():
|
||||
if folder_name in models_info:
|
||||
models_info[expected_tag] = models_info[folder_name]
|
||||
models_info[expected_tag]["tags"] = [expected_tag]
|
||||
del models_info[folder_name]
|
||||
|
||||
# Some models have multiple tags on the hub. We add the expected tag to the list of tags.
|
||||
for model_name, extra_tags in EXTRA_TAGS_MAPPING.items():
|
||||
if model_name in models_info:
|
||||
models_info[model_name]["tags"].extend(extra_tags)
|
||||
|
||||
# Sanity check for the case with all models: the model tags must match the keys in the CONFIG_MAPPING_NAMES
|
||||
# (= actual model tags on the hub)
|
||||
if max_num_models == -1:
|
||||
all_model_tags = set()
|
||||
for model_name in models_info:
|
||||
all_model_tags.update(models_info[model_name]["tags"])
|
||||
|
||||
non_deprecated_model_tags = (
|
||||
set(CONFIG_MAPPING_NAMES.keys()) - set(DEPRECATED_MODELS_TAGS) - set(DEPRECATED_MODELS)
|
||||
)
|
||||
if all_model_tags != non_deprecated_model_tags:
|
||||
raise ValueError(
|
||||
"The tags of the `models_info` dictionary must match the keys in the `CONFIG_MAPPING_NAMES`!"
|
||||
"\nMissing tags in `model_info`: "
|
||||
+ str(sorted(non_deprecated_model_tags - all_model_tags))
|
||||
+ "\nExtra tags in `model_info`: "
|
||||
+ str(sorted(all_model_tags - non_deprecated_model_tags))
|
||||
+ "\n\nYou need to update one or more of the following: `CONFIG_MAPPING_NAMES`, "
|
||||
"`EXTRA_TAGS_MAPPING` or `DEPRECATED_MODELS_TAGS`."
|
||||
)
|
||||
|
||||
# Filter out models which were added less than a year ago
|
||||
models_info = {
|
||||
model: info for model, info in models_info.items() if info["first_commit_datetime"] < thresh_date
|
||||
}
|
||||
|
||||
# We make successive calls to the hub, filtering based on the model tags
|
||||
print("Making calls to the hub to find models below the threshold number of downloads...")
|
||||
num_models = len(models_info)
|
||||
for i, (model, model_info) in enumerate(models_info.items()):
|
||||
print(f"{i + 1}/{num_models}: getting hub downloads for model='{model}' (tags={model_info['tags']})")
|
||||
for model_tag in model_info["tags"]:
|
||||
if model_info["downloads"] > thresh_num_downloads:
|
||||
break
|
||||
model_list = HubModelLister(tags=model_tag)
|
||||
for hub_model in model_list:
|
||||
if hub_model.private:
|
||||
continue
|
||||
model_info["downloads"] += hub_model.downloads
|
||||
# No need to make further hub calls, it's above the set threshold
|
||||
if model_info["downloads"] > thresh_num_downloads:
|
||||
break
|
||||
|
||||
if save_model_info and not (use_cache and os.path.exists("models_info.json")):
|
||||
# Make datetimes serializable
|
||||
for model, info in models_info.items():
|
||||
info["first_commit_datetime"] = info["first_commit_datetime"].isoformat()
|
||||
with open("models_info.json", "w") as f:
|
||||
json.dump(models_info, f, indent=4)
|
||||
|
||||
print("\nFinding models to deprecate:")
|
||||
n_models_to_deprecate = 0
|
||||
models_to_deprecate = {}
|
||||
for model, info in models_info.items():
|
||||
n_downloads = info["downloads"]
|
||||
if n_downloads < thresh_num_downloads:
|
||||
n_models_to_deprecate += 1
|
||||
models_to_deprecate[model] = info
|
||||
print(f"\nModel: {model}")
|
||||
print(f"Downloads: {n_downloads}")
|
||||
print(f"Date: {info['first_commit_datetime']}")
|
||||
|
||||
# sort models to deprecate by downloads (lowest downloads first)
|
||||
models_to_deprecate = sorted(models_to_deprecate.items(), key=lambda x: x[1]["downloads"])
|
||||
|
||||
print("\nModels to deprecate: ", "\n" + "\n".join([model[0] for model in models_to_deprecate]))
|
||||
print(f"\nNumber of models to deprecate: {n_models_to_deprecate}")
|
||||
print("Before deprecating make sure to verify the models, including if they're used as a module in other models.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--save_model_info", action="store_true", help="Save the retrieved model info to a json file.")
|
||||
parser.add_argument(
|
||||
"--use_cache", action="store_true", help="Use the cached model info instead of calling the hub."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--thresh_num_downloads",
|
||||
type=int,
|
||||
default=5_000,
|
||||
help=(
|
||||
"Threshold number of downloads below which a model should be deprecated. Default is 5,000. If you are "
|
||||
"considering a sweep and using a cache, set this to the highest number of the sweep."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--thresh_date",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from "
|
||||
"today."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_num_models",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum number of models architectures to consider. -1 means all models. Useful for testing.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
models_to_deprecate = get_list_of_models_to_deprecate(
|
||||
thresh_num_downloads=args.thresh_num_downloads,
|
||||
thresh_date=args.thresh_date,
|
||||
use_cache=args.use_cache,
|
||||
save_model_info=args.save_model_info,
|
||||
max_num_models=args.max_num_models,
|
||||
)
|
||||
184
utils/modular_integrations.py
Normal file
184
utils/modular_integrations.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import os
|
||||
|
||||
import libcst as cst
|
||||
|
||||
|
||||
# Files from external libraries that should not be tracked
|
||||
# E.g. for habana, we don't want to track the dependencies from `modeling_all_models.py` as it is not part of the transformers library
|
||||
EXCLUDED_EXTERNAL_FILES = {
|
||||
"habana": [{"name": "modeling_all_models", "type": "modeling"}],
|
||||
}
|
||||
|
||||
|
||||
def convert_relative_import_to_absolute(
|
||||
import_node: cst.ImportFrom,
|
||||
file_path: str,
|
||||
package_name: str | None = "transformers",
|
||||
) -> cst.ImportFrom:
|
||||
"""
|
||||
Convert a relative libcst.ImportFrom node into an absolute one,
|
||||
using the file path and package name.
|
||||
|
||||
Args:
|
||||
import_node: A relative import node (e.g. `from ..utils import helper`)
|
||||
file_path: Path to the file containing the import (can be absolute or relative)
|
||||
package_name: The top-level package name (e.g. 'myproject')
|
||||
|
||||
Returns:
|
||||
A new ImportFrom node with the absolute import path
|
||||
"""
|
||||
if not (import_node.relative and len(import_node.relative) > 0):
|
||||
return import_node # Already absolute
|
||||
|
||||
file_path = os.path.abspath(file_path)
|
||||
rel_level = len(import_node.relative)
|
||||
|
||||
# Strip file extension and split into parts
|
||||
file_path_no_ext = file_path.removesuffix(".py")
|
||||
file_parts = file_path_no_ext.split(os.path.sep)
|
||||
|
||||
# Ensure the file path includes the package name
|
||||
if package_name not in file_parts:
|
||||
raise ValueError(f"Package name '{package_name}' not found in file path '{file_path}'")
|
||||
|
||||
# Slice file_parts starting from the package name
|
||||
pkg_index = file_parts.index(package_name)
|
||||
module_parts = file_parts[pkg_index + 1 :] # e.g. ['module', 'submodule', 'foo']
|
||||
if len(module_parts) < rel_level:
|
||||
raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")
|
||||
|
||||
base_parts = module_parts[:-rel_level]
|
||||
|
||||
# Flatten the module being imported (if any)
|
||||
def flatten_module(module: cst.BaseExpression | None) -> list[str]:
|
||||
if not module:
|
||||
return []
|
||||
if isinstance(module, cst.Name):
|
||||
return [module.value]
|
||||
elif isinstance(module, cst.Attribute):
|
||||
parts = []
|
||||
while isinstance(module, cst.Attribute):
|
||||
parts.insert(0, module.attr.value)
|
||||
module = module.value
|
||||
if isinstance(module, cst.Name):
|
||||
parts.insert(0, module.value)
|
||||
return parts
|
||||
return []
|
||||
|
||||
import_parts = flatten_module(import_node.module)
|
||||
|
||||
# Combine to get the full absolute import path
|
||||
full_parts = [package_name] + base_parts + import_parts
|
||||
|
||||
# Handle special case where the import comes from a namespace package (e.g. optimum with `optimum.habana`, `optimum.intel` instead of `src.optimum`)
|
||||
if package_name != "transformers" and file_parts[pkg_index - 1] != "src":
|
||||
full_parts = [file_parts[pkg_index - 1]] + full_parts
|
||||
|
||||
# Build the dotted module path
|
||||
dotted_module: cst.BaseExpression | None = None
|
||||
for part in full_parts:
|
||||
name = cst.Name(part)
|
||||
dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)
|
||||
|
||||
# Return a new ImportFrom node with absolute import
|
||||
return import_node.with_changes(module=dotted_module, relative=[])
|
||||
|
||||
|
||||
def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
|
||||
"""
|
||||
Convert an absolute import to a relative one if it belongs to `package_name`.
|
||||
|
||||
Parameters:
|
||||
- node: The ImportFrom node to possibly transform.
|
||||
- file_path: Absolute path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
|
||||
- package_name: The top-level package name (e.g., 'mypackage').
|
||||
|
||||
Returns:
|
||||
- A possibly modified ImportFrom node.
|
||||
"""
|
||||
if import_node.relative:
|
||||
return import_node # Already relative import
|
||||
|
||||
# Extract module name string from ImportFrom
|
||||
def get_module_name(module):
|
||||
if isinstance(module, cst.Name):
|
||||
return module.value, [module.value]
|
||||
elif isinstance(module, cst.Attribute):
|
||||
parts = []
|
||||
while isinstance(module, cst.Attribute):
|
||||
parts.append(module.attr.value)
|
||||
module = module.value
|
||||
if isinstance(module, cst.Name):
|
||||
parts.append(module.value)
|
||||
parts.reverse()
|
||||
return ".".join(parts), parts
|
||||
return "", None
|
||||
|
||||
module_name, submodule_list = get_module_name(import_node.module)
|
||||
|
||||
# Check if it's from the target package
|
||||
if (
|
||||
not (module_name.startswith(package_name + ".") or module_name.startswith("optimum." + package_name + "."))
|
||||
and module_name != package_name
|
||||
):
|
||||
return import_node # Not from target package
|
||||
|
||||
# Locate the package root inside the file path
|
||||
norm_file_path = os.path.normpath(file_path)
|
||||
parts = norm_file_path.split(os.sep)
|
||||
|
||||
try:
|
||||
pkg_index = parts.index(package_name)
|
||||
except ValueError:
|
||||
# Package name not found in path — assume we can't resolve relative depth
|
||||
return import_node
|
||||
|
||||
# Depth is how many directories after the package name before the current file
|
||||
depth = len(parts) - pkg_index - 1 # exclude the .py file itself
|
||||
for i, submodule in enumerate(parts[pkg_index + 1 :]):
|
||||
if submodule == submodule_list[2 + i]:
|
||||
depth -= 1
|
||||
else:
|
||||
break
|
||||
|
||||
# Create the correct number of dots
|
||||
relative = [cst.Dot()] * depth if depth > 0 else [cst.Dot()]
|
||||
|
||||
# Strip package prefix from import module path
|
||||
if module_name.startswith("optimum." + package_name + "."):
|
||||
stripped_name = module_name[len("optimum." + package_name) :].lstrip(".")
|
||||
else:
|
||||
stripped_name = module_name[len(package_name) :].lstrip(".")
|
||||
|
||||
# Build new module node
|
||||
if stripped_name == "":
|
||||
new_module = None
|
||||
else:
|
||||
name_parts = stripped_name.split(".")[i:]
|
||||
new_module = cst.Name(name_parts[0])
|
||||
for part in name_parts[1:]:
|
||||
new_module = cst.Attribute(value=new_module, attr=cst.Name(part))
|
||||
|
||||
return import_node.with_changes(module=new_module, relative=relative)
|
||||
|
||||
|
||||
class AbsoluteImportTransformer(cst.CSTTransformer):
|
||||
def __init__(self, relative_path: str, source_library: str):
|
||||
super().__init__()
|
||||
self.relative_path = relative_path
|
||||
self.source_library = source_library
|
||||
|
||||
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
|
||||
return convert_relative_import_to_absolute(
|
||||
import_node=updated_node, file_path=self.relative_path, package_name=self.source_library
|
||||
)
|
||||
|
||||
|
||||
class RelativeImportTransformer(cst.CSTTransformer):
|
||||
def __init__(self, relative_path: str, source_library: str):
|
||||
super().__init__()
|
||||
self.relative_path = relative_path
|
||||
self.source_library = source_library
|
||||
|
||||
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
|
||||
return convert_to_relative_import(updated_node, self.relative_path, self.source_library)
|
||||
2121
utils/modular_model_converter.py
Normal file
2121
utils/modular_model_converter.py
Normal file
File diff suppressed because it is too large
Load Diff
913
utils/modular_model_detector.py
Normal file
913
utils/modular_model_detector.py
Normal file
@@ -0,0 +1,913 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# 🔴🔴🔴 THIS IS AN INTERNAL TOOL. It WILL interact with the hub and use significant local compute resources. Use at your own risk.
|
||||
|
||||
|
||||
"""
|
||||
Modular model detector: utilities for detecting code similarities between model implementations.
|
||||
|
||||
This module provides tools to analyze and detect similarities between different model implementations
|
||||
in the transformers library. It uses both embedding-based and token-based (Jaccard) similarity metrics
|
||||
to identify similar code patterns across different model definitions.
|
||||
|
||||
Its function is to identify which models can be _modular_-ized, meaning, which already existing classes are
|
||||
present in the codebase and look very similar to the one we have.
|
||||
|
||||
Two scores are computed, one is a code embedding, and the other is a simple Jaccard bag-of-tokens index for overlap
|
||||
of token sets. A score of 1.00 means the code is identical.
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
cd transformers
|
||||
|
||||
# Use directly the util, it will download the index embedding from the hub. It will require some RAM/VRAM.
|
||||
|
||||
>>> python utils/modular_model_detector.py --modeling-file my_new_beit3_modeling_file.py
|
||||
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.62it/s]
|
||||
encoding 21 query definitions with Qwen/Qwen3-Embedding-4B (device=cuda, batch=16, max_length=4096)
|
||||
stuff.py::Beit3ImageTextMatchingOutput:
|
||||
embedding:
|
||||
blip_2::Blip2ImageTextMatchingModelOutput (0.9994)
|
||||
chinese_clip::ChineseCLIPOutput (0.9818)
|
||||
owlvit::OwlViTOutput (0.9818)
|
||||
aimv2::Aimv2Output (0.9818)
|
||||
blip::BlipOutput (0.9818)
|
||||
jaccard:
|
||||
owlv2::Owlv2Output (0.9667)
|
||||
metaclip_2::MetaClip2Output (0.9667)
|
||||
altclip::AltCLIPOutput (0.9667)
|
||||
owlvit::OwlViTOutput (0.9667)
|
||||
blip::BlipOutput (0.9667)
|
||||
intersection:
|
||||
blip::BlipOutput
|
||||
owlvit::OwlViTOutput
|
||||
|
||||
stuff.py::Beit3MLP:
|
||||
embedding:
|
||||
efficientloftr::EfficientLoFTRMLP (0.9718)
|
||||
seggpt::SegGptMlp (0.9650)
|
||||
mgp_str::MgpstrMlp (0.9646)
|
||||
vitpose_backbone::VitPoseBackboneMLP (0.9640)
|
||||
granitemoeshared::GraniteMoeSharedMLP (0.9633)
|
||||
jaccard:
|
||||
chinese_clip::ChineseCLIPTextSelfOutput (0.5294)
|
||||
convbert::ConvBertSelfOutput (0.5294)
|
||||
bert::BertSelfOutput (0.5294)
|
||||
roformer::RoFormerSelfOutput (0.5294)
|
||||
layoutlmv3::LayoutLMv3SelfOutput (0.5294)
|
||||
intersection:
|
||||
|
||||
stuff.py::Beit3FeedForwardNetwork:
|
||||
embedding:
|
||||
prophetnet::ProphetNetFeedForward (0.9766)
|
||||
dab_detr::DabDetrDecoderLayerFFN (0.9730)
|
||||
kosmos2::Kosmos2TextFFN (0.9697)
|
||||
kosmos2_5::Kosmos2_5TextFFN (0.9697)
|
||||
parakeet::ParakeetEncoderFeedForward (0.9678)
|
||||
jaccard:
|
||||
groupvit::GroupViTMLP (0.4898)
|
||||
convbert::ConvBertOutput (0.4600)
|
||||
chinese_clip::ChineseCLIPTextOutput (0.4565)
|
||||
bert::BertOutput (0.4565)
|
||||
roformer::RoFormerOutput (0.4565)
|
||||
intersection:
|
||||
|
||||
|
||||
|
||||
```
|
||||
|
||||
|
||||
# If you wish to build the index first, you can run
|
||||
|
||||
python utils/modular_model_detector.py --build
|
||||
|
||||
# You can also change the embedding model for a larger/smaller one.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub import logging as huggingface_hub_logging
|
||||
from safetensors.numpy import load_file as safetensors_load
|
||||
from safetensors.numpy import save_file as safetensors_save
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from transformers.utils import enable_tf32
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
|
||||
# ANSI color codes for CLI output styling
|
||||
ANSI_RESET = "\033[0m"
|
||||
ANSI_BOLD = "\033[1m"
|
||||
ANSI_HEADER = "\033[1;36m"
|
||||
ANSI_SECTION = "\033[1;35m"
|
||||
ANSI_ROW = "\033[0;37m"
|
||||
ANSI_HIGHLIGHT_TOP = "\033[1;32m"
|
||||
ANSI_HIGHLIGHT_OLD = "\033[1;33m"
|
||||
ANSI_HIGHLIGHT_CANDIDATE = "\033[1;34m"
|
||||
|
||||
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
MODELS_ROOT = Path("src/transformers/models")
|
||||
EMBEDDINGS_PATH = "embeddings.safetensors"
|
||||
INDEX_MAP_PATH = "code_index_map.json"
|
||||
TOKENS_PATH = "code_index_tokens.json"
|
||||
HUB_DATASET_DEFAULT = "hf-internal-testing/transformers_code_embeddings"
|
||||
|
||||
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
|
||||
BATCH_SIZE = 16
|
||||
MAX_LENGTH = 4096
|
||||
|
||||
|
||||
def _normalize(string: str | None) -> str:
|
||||
"""
|
||||
Normalize a string by removing all non-alphanumeric characters and converting to lowercase.
|
||||
|
||||
Args:
|
||||
string (`str` or `None`): The string to normalize.
|
||||
|
||||
Returns:
|
||||
`str`: The normalized string, or empty string if input is None.
|
||||
"""
|
||||
return re.sub(r"[^a-z0-9]+", "", string.lower()) if string else ""
|
||||
|
||||
|
||||
def _strip_source_for_tokens(code: str) -> str:
|
||||
"""
|
||||
Strip docstrings, comments, and import statements from source code.
|
||||
|
||||
Args:
|
||||
code (`str`): The source code to strip.
|
||||
|
||||
Returns:
|
||||
`str`: The stripped source code.
|
||||
"""
|
||||
code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)
|
||||
code = re.sub(r"#.*", "", code)
|
||||
return "\n".join(line for line in code.splitlines() if not re.match(r"\s*(from|import)\s+", line))
|
||||
|
||||
|
||||
def _tokenize(code: str) -> set[str]:
|
||||
"""
|
||||
Extract all Python identifiers from source code.
|
||||
|
||||
Args:
|
||||
code (`str`): The source code to tokenize.
|
||||
|
||||
Returns:
|
||||
`set[str]`: A set of all identifiers found in the code.
|
||||
"""
|
||||
return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code))
|
||||
|
||||
|
||||
def _leading_symbol_prefix(name: str) -> str:
|
||||
"""
|
||||
Extract the leading prefix from a symbol name (e.g., 'Llama' from 'LlamaAttention').
|
||||
|
||||
Args:
|
||||
name (`str`): The symbol name to extract prefix from.
|
||||
|
||||
Returns:
|
||||
`str`: The leading prefix, or empty string if no match.
|
||||
"""
|
||||
match = re.match(r"^([A-Z][a-z0-9]+)", name) or re.match(r"^([A-Za-z0-9]+)", name)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
|
||||
def _sanitize_for_embedding(code: str, model_hint: str | None, symbol_hint: str | None) -> str:
|
||||
"""
|
||||
Sanitize code for embedding by replacing model-specific identifiers with generic placeholder.
|
||||
|
||||
Args:
|
||||
code (`str`): The source code to sanitize.
|
||||
model_hint (`str` or `None`): Hint about the model name (e.g., 'llama').
|
||||
symbol_hint (`str` or `None`): Hint about the symbol name (e.g., 'LlamaAttention').
|
||||
|
||||
Returns:
|
||||
`str`: The sanitized code with model-specific identifiers replaced by 'Model'.
|
||||
"""
|
||||
base = _strip_source_for_tokens(code)
|
||||
variants = set()
|
||||
if model_hint:
|
||||
variants.add(model_hint)
|
||||
variants.add(model_hint.replace("_", ""))
|
||||
variants.add(re.sub(r"\d+", "", model_hint))
|
||||
if symbol_hint:
|
||||
prefix = _leading_symbol_prefix(symbol_hint)
|
||||
if prefix:
|
||||
variants.add(prefix)
|
||||
variants.add(prefix.replace("_", ""))
|
||||
variants.add(re.sub(r"\d+", "", prefix))
|
||||
variants |= {variant.lower() for variant in list(variants)}
|
||||
sanitized = base
|
||||
for variant in sorted({x for x in variants if len(x) >= 3}, key=len, reverse=True):
|
||||
sanitized = re.sub(re.escape(variant), "Model", sanitized, flags=re.IGNORECASE)
|
||||
return sanitized
|
||||
|
||||
|
||||
class CodeSimilarityAnalyzer:
|
||||
"""
|
||||
Analyzer for detecting code similarities between model implementations.
|
||||
|
||||
This class uses embedding-based and token-based similarity metrics to identify similar
|
||||
code patterns across different model definitions in the transformers library.
|
||||
|
||||
Args:
|
||||
hub_dataset (`str`): The Hub dataset repository ID containing the code embeddings index.
|
||||
"""
|
||||
|
||||
def __init__(self, hub_dataset: str):
|
||||
for name in ("huggingface_hub", "httpx", "urllib3", "transformers"):
|
||||
logging.getLogger(name).setLevel(logging.ERROR)
|
||||
huggingface_hub_logging.set_verbosity_error()
|
||||
transformers_logging.set_verbosity_error()
|
||||
enable_tf32(True)
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
self.models_root = MODELS_ROOT
|
||||
self.hub_dataset = hub_dataset
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
|
||||
self.model = AutoModel.from_pretrained(EMBEDDING_MODEL, torch_dtype="auto", device_map="auto").eval()
|
||||
|
||||
self.device = self.model.device
|
||||
self.index_dir: Path | None = None
|
||||
|
||||
# ---------- HUB IO ----------
|
||||
|
||||
def _resolve_index_path(self, filename: str) -> Path:
|
||||
if self.index_dir is None:
|
||||
return Path(filename)
|
||||
return self.index_dir / filename
|
||||
|
||||
def ensure_local_index(self) -> None:
|
||||
"""Ensure index files are available locally, preferring Hub cache snapshots."""
|
||||
if self.index_dir is not None and all(
|
||||
(self.index_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)
|
||||
):
|
||||
return
|
||||
|
||||
workspace_dir = Path.cwd()
|
||||
if all((workspace_dir / fname).exists() for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH)):
|
||||
self.index_dir = workspace_dir
|
||||
return
|
||||
|
||||
logging.info(f"downloading index from hub cache: {self.hub_dataset}")
|
||||
snapshot_path = snapshot_download(repo_id=self.hub_dataset, repo_type="dataset")
|
||||
snapshot_dir = Path(snapshot_path)
|
||||
missing = [
|
||||
fname for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH) if not (snapshot_dir / fname).exists()
|
||||
]
|
||||
if missing:
|
||||
raise FileNotFoundError("Missing expected files in Hub snapshot: " + ", ".join(missing))
|
||||
self.index_dir = snapshot_dir
|
||||
|
||||
def push_index_to_hub(self) -> None:
|
||||
"""Upload index files to the Hub dataset repository."""
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=self.hub_dataset, repo_type="dataset", exist_ok=True)
|
||||
for fname in (EMBEDDINGS_PATH, INDEX_MAP_PATH, TOKENS_PATH):
|
||||
logging.info(f"pushing {fname} -> {self.hub_dataset}")
|
||||
api.upload_file(
|
||||
path_or_fileobj=fname,
|
||||
path_in_repo=os.path.basename(fname),
|
||||
repo_id=self.hub_dataset,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
# ---------- parsing & encoding ----------
|
||||
|
||||
def _extract_definitions(
|
||||
self, file_path: Path, relative_to: Path | None = None, model_hint: str | None = None
|
||||
) -> tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]:
|
||||
"""
|
||||
Extract class and function definitions from a Python file.
|
||||
|
||||
Args:
|
||||
file_path (`Path`): Path to the Python file to parse.
|
||||
relative_to (`Path` or `None`): Base path for computing relative identifiers.
|
||||
model_hint (`str` or `None`): Model name hint for sanitization.
|
||||
|
||||
Returns:
|
||||
`tuple[dict[str, str], dict[str, str], dict[str, list[str]], dict[str, str]]`: A tuple containing:
|
||||
- definitions_raw: Mapping of identifiers to raw source code
|
||||
- definitions_sanitized: Mapping of identifiers to sanitized source code
|
||||
- definitions_tokens: Mapping of identifiers to sorted token lists
|
||||
- definitions_kind: Mapping of identifiers to either "class" or "function"
|
||||
"""
|
||||
definitions_raw = {}
|
||||
definitions_sanitized = {}
|
||||
definitions_tokens = {}
|
||||
definitions_kind = {}
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
lines = source.splitlines()
|
||||
tree = ast.parse(source)
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
segment = ast.get_source_segment(source, node)
|
||||
if segment is None and hasattr(node, "lineno") and hasattr(node, "end_lineno"):
|
||||
start = max(0, node.lineno - 1)
|
||||
end = node.end_lineno
|
||||
segment = "\n".join(lines[start:end])
|
||||
if segment:
|
||||
identifier = (
|
||||
f"{file_path.relative_to(relative_to)}:{node.name}"
|
||||
if relative_to
|
||||
else f"{file_path.name}:{node.name}"
|
||||
)
|
||||
definitions_raw[identifier] = segment
|
||||
sanitized = _sanitize_for_embedding(segment, model_hint, node.name)
|
||||
definitions_sanitized[identifier] = sanitized
|
||||
definitions_tokens[identifier] = sorted(_tokenize(sanitized))
|
||||
if isinstance(node, ast.ClassDef):
|
||||
definitions_kind[identifier] = "class"
|
||||
else:
|
||||
definitions_kind[identifier] = "function"
|
||||
return definitions_raw, definitions_sanitized, definitions_tokens, definitions_kind
|
||||
|
||||
def _infer_model_from_relative_path(self, relative_path: Path) -> str | None:
|
||||
try:
|
||||
relative = relative_path.resolve().relative_to(self.models_root.resolve())
|
||||
return relative.parts[0]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _infer_query_model_name(self, modeling_file: Path) -> str | None:
|
||||
model = self._infer_model_from_relative_path(modeling_file)
|
||||
if model:
|
||||
return model
|
||||
stem = modeling_file.stem
|
||||
if stem.startswith("modeling_") and len(stem) > len("modeling_"):
|
||||
return stem[len("modeling_") :]
|
||||
return None
|
||||
|
||||
def _encode_batch(self, texts: list[str]) -> np.ndarray:
|
||||
"""
|
||||
Encode a batch of texts into normalized embeddings.
|
||||
|
||||
Args:
|
||||
texts (`list[str]`): List of text strings to encode.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Normalized embeddings as a float32 numpy array.
|
||||
"""
|
||||
encoded = self.tokenizer(texts, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
|
||||
encoded = {key: value.to(self.device) for key, value in encoded.items()}
|
||||
with (
|
||||
torch.autocast(device_type=self.device.type, dtype=self.dtype)
|
||||
if self.device.type == "cuda"
|
||||
else torch.no_grad()
|
||||
):
|
||||
output = self.model(**encoded)
|
||||
if hasattr(output, "last_hidden_state"):
|
||||
embeddings = output.last_hidden_state
|
||||
mask = encoded["attention_mask"].unsqueeze(-1)
|
||||
embeddings = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-9)
|
||||
elif hasattr(output, "pooler_output"):
|
||||
embeddings = output.pooler_output
|
||||
else:
|
||||
embeddings = output[0].mean(dim=1)
|
||||
embeddings = torch.nn.functional.normalize(embeddings.float(), p=2, dim=1)
|
||||
return embeddings.cpu().numpy().astype("float32")
|
||||
|
||||
def encode(self, texts: list[str]) -> np.ndarray:
|
||||
"""
|
||||
Encode a list of texts into embeddings, processing in batches.
|
||||
|
||||
Args:
|
||||
texts (`list[str]`): List of text strings to encode.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: Stacked embeddings for all texts.
|
||||
"""
|
||||
output = []
|
||||
for i in tqdm(range(0, len(texts), BATCH_SIZE), desc="encode", leave=False):
|
||||
output.append(self._encode_batch(texts[i : i + BATCH_SIZE]))
|
||||
if self.device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
return np.vstack(output) if output else np.zeros((0, 0), dtype="float32")
|
||||
|
||||
# ---------- build & search ----------
|
||||
|
||||
def build_index(self) -> None:
|
||||
"""Build the code similarity index from all modeling files and save to disk."""
|
||||
logging.info("collecting files")
|
||||
files = list(self.models_root.rglob("modeling_*.py"))
|
||||
logging.info(f"parsing {len(files)} files")
|
||||
|
||||
identifiers = []
|
||||
sanitized_sources = []
|
||||
tokens_map = {}
|
||||
|
||||
for file_path in tqdm(files, desc="parse", leave=False):
|
||||
model_hint = self._infer_model_from_relative_path(file_path)
|
||||
(
|
||||
_,
|
||||
definitions_sanitized,
|
||||
definitions_tokens,
|
||||
_,
|
||||
) = self._extract_definitions(file_path, self.models_root, model_hint)
|
||||
for identifier in definitions_sanitized.keys():
|
||||
identifiers.append(identifier)
|
||||
sanitized_sources.append(definitions_sanitized[identifier])
|
||||
tokens_map[identifier] = definitions_tokens[identifier]
|
||||
|
||||
logging.info(
|
||||
f"encoding {len(sanitized_sources)} definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})"
|
||||
)
|
||||
embeddings = self.encode(sanitized_sources)
|
||||
safetensors_save({"embeddings": embeddings}, EMBEDDINGS_PATH)
|
||||
with open(INDEX_MAP_PATH, "w", encoding="utf-8") as file:
|
||||
json.dump({int(i): identifiers[i] for i in range(len(identifiers))}, file)
|
||||
with open(TOKENS_PATH, "w", encoding="utf-8") as file:
|
||||
json.dump(tokens_map, file)
|
||||
|
||||
self.index_dir = Path.cwd()
|
||||
|
||||
def _topk_embedding(
|
||||
self,
|
||||
query_embedding_row: np.ndarray,
|
||||
base_embeddings: np.ndarray,
|
||||
identifier_map: dict[int, str],
|
||||
self_model_normalized: str,
|
||||
self_name: str,
|
||||
k: int,
|
||||
) -> list[tuple[str, float]]:
|
||||
similarities = query_embedding_row @ base_embeddings.T
|
||||
indices = np.argpartition(-similarities, k + 32)[: k + 32]
|
||||
indices = indices[np.argsort(-similarities[indices])]
|
||||
output = []
|
||||
for match_id in indices:
|
||||
identifier = identifier_map[int(match_id)]
|
||||
parent_relative_path, match_name = identifier.split(":", 1)
|
||||
parent_model = Path(parent_relative_path).parts[0]
|
||||
if match_name == self_name:
|
||||
continue
|
||||
if self_model_normalized and _normalize(parent_model) == self_model_normalized:
|
||||
continue
|
||||
output.append((identifier, float(similarities[match_id])))
|
||||
if len(output) >= k:
|
||||
break
|
||||
return output
|
||||
|
||||
def _topk_jaccard(
|
||||
self,
|
||||
query_tokens: set[str],
|
||||
identifiers: list[str],
|
||||
tokens_map: dict[str, list[str]],
|
||||
self_model_normalized: str,
|
||||
self_name: str,
|
||||
k: int,
|
||||
) -> list[tuple[str, float]]:
|
||||
"""
|
||||
Find top-k most similar definitions using Jaccard similarity on token sets.
|
||||
|
||||
Args:
|
||||
query_tokens (`set[str]`): Set of tokens from the query definition.
|
||||
identifiers (`list[str]`): List of all definition identifiers in the index.
|
||||
tokens_map (`dict[str, list[str]]`): Mapping of identifiers to their token lists.
|
||||
self_model_normalized (`str`): Normalized name of the query model to exclude.
|
||||
self_name (`str`): Name of the query definition to exclude.
|
||||
k (`int`): Number of top results to return.
|
||||
|
||||
Returns:
|
||||
`list[tuple[str, float]]`: List of (identifier, score) tuples.
|
||||
"""
|
||||
scores = []
|
||||
for identifier in identifiers:
|
||||
parent_relative_path, match_name = identifier.split(":", 1)
|
||||
parent_model = Path(parent_relative_path).parts[0]
|
||||
if match_name == self_name:
|
||||
continue
|
||||
if self_model_normalized and _normalize(parent_model) == self_model_normalized:
|
||||
continue
|
||||
tokens = set(tokens_map.get(identifier, []))
|
||||
if not tokens or not query_tokens:
|
||||
continue
|
||||
score = len(query_tokens & tokens) / len(query_tokens | tokens)
|
||||
if score > 0:
|
||||
scores.append((identifier, score))
|
||||
scores.sort(key=lambda x: x[1], reverse=True)
|
||||
return scores[:k]
|
||||
|
||||
def analyze_file(
|
||||
self, modeling_file: Path, top_k_per_item: int = 5, allow_hub_fallback: bool = True, use_jaccard=False
|
||||
) -> dict[str, dict[str, list]]:
|
||||
"""
|
||||
Analyze a modeling file and find similar code definitions in the index.
|
||||
|
||||
Args:
|
||||
modeling_file (`Path`): Path to the modeling file to analyze.
|
||||
top_k_per_item (`int`, *optional*, defaults to 5): Number of top matches to return per definition.
|
||||
allow_hub_fallback (`bool`, *optional*, defaults to `True`): Whether to download index from Hub if not found locally.
|
||||
|
||||
Returns:
|
||||
`dict[str, dict[str, list]]`: Dictionary mapping definition names to their similarity results.
|
||||
Each result contains 'embedding', 'jaccard', and 'intersection' keys.
|
||||
"""
|
||||
if allow_hub_fallback:
|
||||
self.ensure_local_index()
|
||||
|
||||
base = safetensors_load(str(self._resolve_index_path(EMBEDDINGS_PATH)))
|
||||
base_embeddings = base["embeddings"]
|
||||
with open(self._resolve_index_path(INDEX_MAP_PATH), "r", encoding="utf-8") as file:
|
||||
identifier_map = {int(key): value for key, value in json.load(file).items()}
|
||||
identifiers = [identifier_map[i] for i in range(len(identifier_map))]
|
||||
with open(self._resolve_index_path(TOKENS_PATH), "r", encoding="utf-8") as file:
|
||||
tokens_map = json.load(file)
|
||||
|
||||
self_model = self._infer_query_model_name(modeling_file)
|
||||
definitions_raw, definitions_sanitized, _, definitions_kind = self._extract_definitions(
|
||||
modeling_file, None, self_model
|
||||
)
|
||||
query_identifiers = list(definitions_raw.keys())
|
||||
query_sources_sanitized = [definitions_sanitized[key] for key in query_identifiers]
|
||||
query_tokens_list = [set(_tokenize(source)) for source in query_sources_sanitized]
|
||||
self_model_normalized = _normalize(self_model)
|
||||
|
||||
logging.info(
|
||||
f"encoding {len(query_sources_sanitized)} query definitions with {EMBEDDING_MODEL} (device={self.device.type}, batch={BATCH_SIZE}, max_length={MAX_LENGTH})"
|
||||
)
|
||||
query_embeddings = self.encode(query_sources_sanitized)
|
||||
|
||||
output = {}
|
||||
for i, query_identifier in enumerate(query_identifiers):
|
||||
query_name = query_identifier.split(":")[-1]
|
||||
embedding_top = self._topk_embedding(
|
||||
query_embeddings[i], base_embeddings, identifier_map, self_model_normalized, query_name, top_k_per_item
|
||||
)
|
||||
embedding_set = {identifier for identifier, _ in embedding_top}
|
||||
kind = definitions_kind.get(query_identifier, "function")
|
||||
entry = {"kind": kind, "embedding": embedding_top}
|
||||
if use_jaccard:
|
||||
jaccard_top = self._topk_jaccard(
|
||||
query_tokens_list[i], identifiers, tokens_map, self_model_normalized, query_name, top_k_per_item
|
||||
)
|
||||
jaccard_set = {identifier for identifier, _ in jaccard_top}
|
||||
intersection = set(embedding_set & jaccard_set)
|
||||
|
||||
entry.update({"jaccard": jaccard_top, "intersection": intersection})
|
||||
output[query_name] = entry
|
||||
return output
|
||||
|
||||
|
||||
_RELEASE_RE = re.compile(
|
||||
r"(?:^|[\*_`\s>])(?:this|the)\s+model\s+was\s+released\s+on\s+(\d{4}-\d{2}-\d{2})\b", re.IGNORECASE
|
||||
)
|
||||
|
||||
|
||||
def build_date_data() -> dict[str, str]:
|
||||
"""
|
||||
Scan Markdown files in `root_dir` and build {model_id: date_released}.
|
||||
|
||||
- model_id is the filename without extension (e.g., "llama" for "llama.md")
|
||||
- date_released is the first YYYY-MM-DD matched after "...was released on ..."
|
||||
- Ignores non-*.md files and directories.
|
||||
|
||||
Returns:
|
||||
dict[str, str]: mapping of model_id -> ISO date string (YYYY-MM-DD).
|
||||
Files without a match are simply omitted.
|
||||
"""
|
||||
|
||||
root_dir = transformers.__file__.split("src/transformers")[0]
|
||||
root = Path(root_dir).joinpath("docs/source/en/model_doc")
|
||||
result: dict[str, str] = {}
|
||||
|
||||
for md_path in root.glob("*.md"):
|
||||
try:
|
||||
text = md_path.read_text(encoding="utf-8", errors="ignore")
|
||||
except Exception:
|
||||
# Skip unreadable files quietly
|
||||
logging.info(f"Failed to read md for {md_path}")
|
||||
|
||||
m = _RELEASE_RE.search(text)
|
||||
if m:
|
||||
model_id = md_path.stem # e.g., "llama" from "llama.md"
|
||||
result[model_id] = m.group(1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _format_table(headers: list[str], rows: list[tuple[str, ...] | None], row_styles: list[str] | None = None) -> str:
|
||||
if not rows:
|
||||
return f"{ANSI_ROW}(no matches){ANSI_RESET}"
|
||||
|
||||
widths = [len(header) for header in headers]
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
for idx, cell in enumerate(row):
|
||||
widths[idx] = max(widths[idx], len(cell))
|
||||
|
||||
header_line = " | ".join(header.ljust(widths[idx]) for idx, header in enumerate(headers))
|
||||
divider = "-+-".join("-" * widths[idx] for idx in range(len(headers)))
|
||||
total_width = sum(widths) + 3 * (len(headers) - 1)
|
||||
|
||||
styled_rows = []
|
||||
style_idx = 0
|
||||
for row in rows:
|
||||
if row is None:
|
||||
styled_rows.append(f"{ANSI_SECTION}{'-' * total_width}{ANSI_RESET}")
|
||||
continue
|
||||
|
||||
line = " | ".join(cell.ljust(widths[col_idx]) for col_idx, cell in enumerate(row))
|
||||
style = ANSI_ROW
|
||||
if row_styles and style_idx < len(row_styles) and row_styles[style_idx]:
|
||||
style = row_styles[style_idx]
|
||||
styled_rows.append(f"{style}{line}{ANSI_RESET}")
|
||||
style_idx += 1
|
||||
|
||||
return "\n".join([f"{ANSI_SECTION}{header_line}{ANSI_RESET}", divider] + styled_rows)
|
||||
|
||||
|
||||
def _parse_release_date(value: str) -> datetime | None:
|
||||
"""Return a datetime parsed from YYYY-MM-DD strings, otherwise None."""
|
||||
try:
|
||||
return datetime.strptime(value, "%Y-%m-%d")
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def _load_definition_line_map(relative_path: str) -> dict[str, int]:
|
||||
"""Return {definition_name: line_number} for top-level definitions in the given file."""
|
||||
file_path = MODELS_ROOT / relative_path
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
except (FileNotFoundError, OSError):
|
||||
return {} # gracefully keep going
|
||||
|
||||
try:
|
||||
tree = ast.parse(source)
|
||||
except SyntaxError:
|
||||
return {}
|
||||
|
||||
line_map: dict[str, int] = {}
|
||||
for node in ast.iter_child_nodes(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
line_map[node.name] = getattr(node, "lineno", None) or 1
|
||||
elif isinstance(node, ast.Assign):
|
||||
continue
|
||||
return line_map
|
||||
|
||||
|
||||
def _resolve_definition_location(relative_path: str, definition: str) -> tuple[str, str]:
|
||||
"""Return full path and formatted line number string for the given definition."""
|
||||
full_path = MODELS_ROOT / relative_path
|
||||
line = _load_definition_line_map(relative_path).get(definition)
|
||||
line_str = str(line) if line is not None else "?"
|
||||
return str(full_path), line_str
|
||||
|
||||
|
||||
def _colorize_heading(text: str) -> str:
|
||||
return f"{ANSI_HEADER}{ANSI_BOLD}{text}{ANSI_RESET}"
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for the modular model detector."""
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
parser = argparse.ArgumentParser(prog="hf-code-sim")
|
||||
parser.add_argument("--build", action="store_true")
|
||||
parser.add_argument("--modeling-file", type=str, help='You can just specify "vits" if you are lazy like me.')
|
||||
parser.add_argument(
|
||||
"--push-new-index", action="store_true", help="After --build, push index files to a Hub dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub-dataset", type=str, default=HUB_DATASET_DEFAULT, help="Hub dataset repo id to pull/push the index."
|
||||
)
|
||||
parser.add_argument("--use_jaccard", type=bool, default=False, help="Whether or not to use jaccard index")
|
||||
args = parser.parse_args()
|
||||
|
||||
analyzer = CodeSimilarityAnalyzer(hub_dataset=args.hub_dataset)
|
||||
|
||||
if args.build:
|
||||
analyzer.build_index()
|
||||
if args.push_new_index:
|
||||
analyzer.push_index_to_hub()
|
||||
return
|
||||
|
||||
if not args.modeling_file:
|
||||
raise SystemExit("Provide --modeling-file or use --build")
|
||||
|
||||
dates = build_date_data()
|
||||
modeling_file = args.modeling_file
|
||||
if os.sep not in modeling_file:
|
||||
modeling_file = os.path.join("src", "transformers", "models", modeling_file, f"modeling_{modeling_file}.py")
|
||||
|
||||
results = analyzer.analyze_file(
|
||||
Path(modeling_file), top_k_per_item=5, allow_hub_fallback=True, use_jaccard=args.use_jaccard
|
||||
)
|
||||
modeling_filename = Path(modeling_file).name
|
||||
release_key = modeling_filename.split("modeling_")[-1][:-3]
|
||||
release_date = dates.get(release_key, "unknown release date")
|
||||
|
||||
aggregate_scores: dict[str, float] = {}
|
||||
for data in results.values():
|
||||
for identifier, score in data.get("embedding", []):
|
||||
try:
|
||||
relative_path, _ = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
aggregate_scores[relative_path] = aggregate_scores.get(relative_path, 0.0) + score
|
||||
|
||||
best_candidate_path: str | None = None
|
||||
if aggregate_scores:
|
||||
best_candidate_path = max(aggregate_scores.items(), key=lambda item: item[1])[0]
|
||||
best_model = Path(best_candidate_path).parts[0] if Path(best_candidate_path).parts else "?"
|
||||
best_release = dates.get(best_model, "unknown release date")
|
||||
logging.info(
|
||||
f"{ANSI_HIGHLIGHT_CANDIDATE}Closest overall candidate: {MODELS_ROOT / best_candidate_path}"
|
||||
f" (release: {best_release}, total score: {aggregate_scores[best_candidate_path]:.4f}){ANSI_RESET}"
|
||||
)
|
||||
|
||||
grouped: dict[str, list[tuple[str, dict]]] = {"class": [], "function": []}
|
||||
for query_name, data in results.items():
|
||||
kind = data.get("kind", "function")
|
||||
grouped.setdefault(kind, []).append((query_name, data))
|
||||
|
||||
section_titles = [("class", "Classes"), ("function", "Functions")]
|
||||
legend_shown = False
|
||||
for kind, title in section_titles:
|
||||
entries = grouped.get(kind, [])
|
||||
if not entries:
|
||||
continue
|
||||
|
||||
metrics_present: set[str] = set()
|
||||
for _, data in entries:
|
||||
if data.get("embedding"):
|
||||
metrics_present.add("embedding")
|
||||
if args.use_jaccard:
|
||||
if data.get("jaccard"):
|
||||
metrics_present.add("jaccard")
|
||||
if data.get("intersection"):
|
||||
metrics_present.add("intersection")
|
||||
|
||||
include_metric_column = bool(metrics_present - {"embedding"})
|
||||
headers = ["Symbol", "Path", "Score", "Release"]
|
||||
if include_metric_column:
|
||||
headers = ["Symbol", "Metric", "Path", "Score", "Release"]
|
||||
|
||||
table_rows: list[tuple[str, ...] | None] = []
|
||||
row_styles: list[str] = []
|
||||
has_metric_rows = False
|
||||
|
||||
logging.info(_colorize_heading(title))
|
||||
|
||||
for query_name, data in entries:
|
||||
if table_rows:
|
||||
table_rows.append(None)
|
||||
|
||||
symbol_label = query_name
|
||||
if release_date:
|
||||
symbol_label = f"{symbol_label}"
|
||||
|
||||
symbol_row = (symbol_label,) + ("",) * (len(headers) - 1)
|
||||
table_rows.append(symbol_row)
|
||||
row_styles.append(ANSI_BOLD)
|
||||
|
||||
embedding_details: list[tuple[str, str, str, float, str]] = []
|
||||
embedding_style_indices: list[int] = []
|
||||
|
||||
for identifier, score in data.get("embedding", []):
|
||||
try:
|
||||
relative_path, match_name = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
|
||||
match_release = dates.get(model_id, "unknown release date")
|
||||
full_path, line = _resolve_definition_location(relative_path, match_name)
|
||||
display_path = f"{full_path}:{line} ({match_name})"
|
||||
|
||||
if include_metric_column:
|
||||
row = ("", "embedding", display_path, f"{score:.4f}", match_release)
|
||||
else:
|
||||
row = ("", display_path, f"{score:.4f}", match_release)
|
||||
|
||||
table_rows.append(row)
|
||||
row_styles.append(ANSI_ROW)
|
||||
embedding_style_indices.append(len(row_styles) - 1)
|
||||
embedding_details.append((relative_path, model_id, match_name, score, match_release))
|
||||
has_metric_rows = True
|
||||
|
||||
if embedding_details:
|
||||
highest_score = None
|
||||
highest_idx = None
|
||||
for idx, (_, _, _, score, _) in enumerate(embedding_details):
|
||||
if highest_score is None or score > highest_score:
|
||||
highest_score = score
|
||||
highest_idx = idx
|
||||
|
||||
if highest_idx is not None:
|
||||
row_styles[embedding_style_indices[highest_idx]] = ANSI_HIGHLIGHT_TOP
|
||||
|
||||
if highest_score is not None:
|
||||
oldest_idx = None
|
||||
oldest_date = None
|
||||
for idx, (_, model_id, _, score, release_value) in enumerate(embedding_details):
|
||||
if highest_score - score > 0.1:
|
||||
continue
|
||||
parsed = _parse_release_date(release_value)
|
||||
if parsed is None:
|
||||
continue
|
||||
if oldest_date is None or parsed < oldest_date:
|
||||
oldest_date = parsed
|
||||
oldest_idx = idx
|
||||
if (
|
||||
oldest_idx is not None
|
||||
and row_styles[embedding_style_indices[oldest_idx]] != ANSI_HIGHLIGHT_TOP
|
||||
):
|
||||
row_styles[embedding_style_indices[oldest_idx]] = ANSI_HIGHLIGHT_OLD
|
||||
|
||||
if best_candidate_path is not None:
|
||||
for idx, (relative_path, _, _, _, _) in enumerate(embedding_details):
|
||||
style_position = embedding_style_indices[idx]
|
||||
if row_styles[style_position] != ANSI_ROW:
|
||||
continue
|
||||
if relative_path == best_candidate_path:
|
||||
row_styles[style_position] = ANSI_HIGHLIGHT_CANDIDATE
|
||||
|
||||
if args.use_jaccard:
|
||||
for identifier, score in data.get("jaccard", []):
|
||||
try:
|
||||
relative_path, match_name = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
|
||||
match_release = dates.get(model_id, "unknown release date")
|
||||
full_path, line = _resolve_definition_location(relative_path, match_name)
|
||||
display_path = f"{full_path}:{line} ({match_name})"
|
||||
|
||||
if include_metric_column:
|
||||
row = ("", "jaccard", display_path, f"{score:.4f}", match_release)
|
||||
else:
|
||||
row = ("", display_path, f"{score:.4f}", match_release)
|
||||
|
||||
table_rows.append(row)
|
||||
row_styles.append(ANSI_ROW)
|
||||
has_metric_rows = True
|
||||
if best_candidate_path == relative_path:
|
||||
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
|
||||
|
||||
for identifier in sorted(data.get("intersection", [])):
|
||||
try:
|
||||
relative_path, match_name = identifier.split(":", 1)
|
||||
except ValueError:
|
||||
continue
|
||||
model_id = Path(relative_path).parts[0] if Path(relative_path).parts else "?"
|
||||
match_release = dates.get(model_id, "unknown release date")
|
||||
full_path, line = _resolve_definition_location(relative_path, match_name)
|
||||
display_path = f"{full_path}:{line} ({match_name})"
|
||||
|
||||
if include_metric_column:
|
||||
row = ("", "intersection", display_path, "--", match_release)
|
||||
else:
|
||||
row = ("", display_path, "--", match_release)
|
||||
|
||||
table_rows.append(row)
|
||||
row_styles.append(ANSI_ROW)
|
||||
has_metric_rows = True
|
||||
if best_candidate_path == relative_path:
|
||||
row_styles[-1] = ANSI_HIGHLIGHT_CANDIDATE
|
||||
|
||||
if table_rows:
|
||||
if not legend_shown and has_metric_rows:
|
||||
logging.info(
|
||||
"Legend: "
|
||||
f"{ANSI_HIGHLIGHT_TOP}highest match{ANSI_RESET}, "
|
||||
f"{ANSI_HIGHLIGHT_OLD}oldest within 0.1{ANSI_RESET}, "
|
||||
f"{ANSI_HIGHLIGHT_CANDIDATE}closest overall candidate{ANSI_RESET}"
|
||||
)
|
||||
legend_shown = True
|
||||
|
||||
logging.info(_format_table(headers, table_rows, row_styles))
|
||||
logging.info("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
750
utils/not_doctested.txt
Normal file
750
utils/not_doctested.txt
Normal file
@@ -0,0 +1,750 @@
|
||||
docs/source/en/_config.py
|
||||
docs/source/en/accelerate.md
|
||||
docs/source/en/add_new_model.md
|
||||
docs/source/en/add_new_pipeline.md
|
||||
docs/source/en/community.md
|
||||
docs/source/en/contributing.md
|
||||
docs/source/en/custom_models.md
|
||||
docs/source/en/debugging.md
|
||||
docs/source/en/fast_tokenizers.md
|
||||
docs/source/en/glossary.md
|
||||
docs/source/en/hpo_train.md
|
||||
docs/source/en/index.md
|
||||
docs/source/en/installation.md
|
||||
docs/source/en/internal/audio_utils.md
|
||||
docs/source/en/internal/file_utils.md
|
||||
docs/source/en/internal/image_processing_utils.md
|
||||
docs/source/en/internal/modeling_utils.md
|
||||
docs/source/en/internal/pipelines_utils.md
|
||||
docs/source/en/internal/time_series_utils.md
|
||||
docs/source/en/internal/tokenization_utils.md
|
||||
docs/source/en/internal/trainer_utils.md
|
||||
docs/source/en/llm_tutorial.md
|
||||
docs/source/en/main_classes/callback.md
|
||||
docs/source/en/main_classes/configuration.md
|
||||
docs/source/en/main_classes/data_collator.md
|
||||
docs/source/en/main_classes/deepspeed.md
|
||||
docs/source/en/main_classes/feature_extractor.md
|
||||
docs/source/en/main_classes/image_processor.md
|
||||
docs/source/en/main_classes/logging.md
|
||||
docs/source/en/main_classes/model.md
|
||||
docs/source/en/main_classes/optimizer_schedules.md
|
||||
docs/source/en/main_classes/output.md
|
||||
docs/source/en/main_classes/pipelines.md
|
||||
docs/source/en/main_classes/processors.md
|
||||
docs/source/en/main_classes/quantization.md
|
||||
docs/source/en/main_classes/tokenizer.md
|
||||
docs/source/en/main_classes/trainer.md
|
||||
docs/source/en/model_doc/albert.md
|
||||
docs/source/en/model_doc/align.md
|
||||
docs/source/en/model_doc/altclip.md
|
||||
docs/source/en/model_doc/audio-spectrogram-transformer.md
|
||||
docs/source/en/model_doc/auto.md
|
||||
docs/source/en/model_doc/autoformer.md
|
||||
docs/source/en/model_doc/bark.md
|
||||
docs/source/en/model_doc/bart.md
|
||||
docs/source/en/model_doc/barthez.md
|
||||
docs/source/en/model_doc/bartpho.md
|
||||
docs/source/en/model_doc/beit.md
|
||||
docs/source/en/model_doc/bert-generation.md
|
||||
docs/source/en/model_doc/bert-japanese.md
|
||||
docs/source/en/model_doc/bert.md
|
||||
docs/source/en/model_doc/bertweet.md
|
||||
docs/source/en/model_doc/big_bird.md
|
||||
docs/source/en/model_doc/bigbird_pegasus.md
|
||||
docs/source/en/model_doc/biogpt.md
|
||||
docs/source/en/model_doc/bit.md
|
||||
docs/source/en/model_doc/blenderbot-small.md
|
||||
docs/source/en/model_doc/blenderbot.md
|
||||
docs/source/en/model_doc/blip-2.md
|
||||
docs/source/en/model_doc/blip.md
|
||||
docs/source/en/model_doc/bloom.md
|
||||
docs/source/en/model_doc/bridgetower.md
|
||||
docs/source/en/model_doc/camembert.md
|
||||
docs/source/en/model_doc/canine.md
|
||||
docs/source/en/model_doc/chinese_clip.md
|
||||
docs/source/en/model_doc/clap.md
|
||||
docs/source/en/model_doc/clip.md
|
||||
docs/source/en/model_doc/clipseg.md
|
||||
docs/source/en/model_doc/codegen.md
|
||||
docs/source/en/model_doc/conditional_detr.md
|
||||
docs/source/en/model_doc/convbert.md
|
||||
docs/source/en/model_doc/convnext.md
|
||||
docs/source/en/model_doc/convnextv2.md
|
||||
docs/source/en/model_doc/cpm.md
|
||||
docs/source/en/model_doc/cpmant.md
|
||||
docs/source/en/model_doc/ctrl.md
|
||||
docs/source/en/model_doc/cvt.md
|
||||
docs/source/en/model_doc/data2vec.md
|
||||
docs/source/en/model_doc/deberta-v2.md
|
||||
docs/source/en/model_doc/deberta.md
|
||||
docs/source/en/model_doc/decision_transformer.md
|
||||
docs/source/en/model_doc/deformable_detr.md
|
||||
docs/source/en/model_doc/deit.md
|
||||
docs/source/en/model_doc/deplot.md
|
||||
docs/source/en/model_doc/detr.md
|
||||
docs/source/en/model_doc/dialogpt.md
|
||||
docs/source/en/model_doc/dinat.md
|
||||
docs/source/en/model_doc/dinov2.md
|
||||
docs/source/en/model_doc/distilbert.md
|
||||
docs/source/en/model_doc/dit.md
|
||||
docs/source/en/model_doc/dpr.md
|
||||
docs/source/en/model_doc/efficientnet.md
|
||||
docs/source/en/model_doc/electra.md
|
||||
docs/source/en/model_doc/encodec.md
|
||||
docs/source/en/model_doc/ernie.md
|
||||
docs/source/en/model_doc/esm.md
|
||||
docs/source/en/model_doc/flan-t5.md
|
||||
docs/source/en/model_doc/flan-ul2.md
|
||||
docs/source/en/model_doc/flaubert.md
|
||||
docs/source/en/model_doc/flava.md
|
||||
docs/source/en/model_doc/fnet.md
|
||||
docs/source/en/model_doc/focalnet.md
|
||||
docs/source/en/model_doc/fsmt.md
|
||||
docs/source/en/model_doc/funnel.md
|
||||
docs/source/en/model_doc/git.md
|
||||
docs/source/en/model_doc/glpn.md
|
||||
docs/source/en/model_doc/gpt-sw3.md
|
||||
docs/source/en/model_doc/gpt2.md
|
||||
docs/source/en/model_doc/gpt_bigcode.md
|
||||
docs/source/en/model_doc/gpt_neo.md
|
||||
docs/source/en/model_doc/gpt_neox.md
|
||||
docs/source/en/model_doc/gpt_neox_japanese.md
|
||||
docs/source/en/model_doc/gptj.md
|
||||
docs/source/en/model_doc/groupvit.md
|
||||
docs/source/en/model_doc/herbert.md
|
||||
docs/source/en/model_doc/hubert.md
|
||||
docs/source/en/model_doc/ibert.md
|
||||
docs/source/en/model_doc/idefics.md
|
||||
docs/source/en/model_doc/imagegpt.md
|
||||
docs/source/en/model_doc/informer.md
|
||||
docs/source/en/model_doc/instructblip.md
|
||||
docs/source/en/model_doc/layoutlm.md
|
||||
docs/source/en/model_doc/layoutlmv2.md
|
||||
docs/source/en/model_doc/layoutlmv3.md
|
||||
docs/source/en/model_doc/layoutxlm.md
|
||||
docs/source/en/model_doc/led.md
|
||||
docs/source/en/model_doc/levit.md
|
||||
docs/source/en/model_doc/lilt.md
|
||||
docs/source/en/model_doc/llama.md
|
||||
docs/source/en/model_doc/llama2.md
|
||||
docs/source/en/model_doc/llava.md
|
||||
docs/source/en/model_doc/llava_next.md
|
||||
docs/source/en/model_doc/longformer.md
|
||||
docs/source/en/model_doc/longt5.md
|
||||
docs/source/en/model_doc/luke.md
|
||||
docs/source/en/model_doc/lxmert.md
|
||||
docs/source/en/model_doc/m2m_100.md
|
||||
docs/source/en/model_doc/madlad-400.md
|
||||
docs/source/en/model_doc/marian.md
|
||||
docs/source/en/model_doc/mask2former.md
|
||||
docs/source/en/model_doc/maskformer.md
|
||||
docs/source/en/model_doc/matcha.md
|
||||
docs/source/en/model_doc/mbart.md
|
||||
docs/source/en/model_doc/megatron-bert.md
|
||||
docs/source/en/model_doc/megatron_gpt2.md
|
||||
docs/source/en/model_doc/mgp-str.md
|
||||
docs/source/en/model_doc/mistral.md
|
||||
docs/source/en/model_doc/mixtral.md
|
||||
docs/source/en/model_doc/mluke.md
|
||||
docs/source/en/model_doc/mms.md
|
||||
docs/source/en/model_doc/mobilebert.md
|
||||
docs/source/en/model_doc/mobilenet_v1.md
|
||||
docs/source/en/model_doc/mobilenet_v2.md
|
||||
docs/source/en/model_doc/mobilevit.md
|
||||
docs/source/en/model_doc/mobilevitv2.md
|
||||
docs/source/en/model_doc/mpnet.md
|
||||
docs/source/en/model_doc/mpt.md
|
||||
docs/source/en/model_doc/mra.md
|
||||
docs/source/en/model_doc/mt5.md
|
||||
docs/source/en/model_doc/musicgen.md
|
||||
docs/source/en/model_doc/musicgen_melody.md
|
||||
docs/source/en/model_doc/mvp.md
|
||||
docs/source/en/model_doc/nllb-moe.md
|
||||
docs/source/en/model_doc/nllb.md
|
||||
docs/source/en/model_doc/nystromformer.md
|
||||
docs/source/en/model_doc/oneformer.md
|
||||
docs/source/en/model_doc/openai-gpt.md
|
||||
docs/source/en/model_doc/opt.md
|
||||
docs/source/en/model_doc/owlvit.md
|
||||
docs/source/en/model_doc/pegasus.md
|
||||
docs/source/en/model_doc/pegasus_x.md
|
||||
docs/source/en/model_doc/perceiver.md
|
||||
docs/source/en/model_doc/phobert.md
|
||||
docs/source/en/model_doc/pix2struct.md
|
||||
docs/source/en/model_doc/plbart.md
|
||||
docs/source/en/model_doc/poolformer.md
|
||||
docs/source/en/model_doc/pop2piano.md
|
||||
docs/source/en/model_doc/prophetnet.md
|
||||
docs/source/en/model_doc/pvt.md
|
||||
docs/source/en/model_doc/qwen2.md
|
||||
docs/source/en/model_doc/qwen2_moe.md
|
||||
docs/source/en/model_doc/rag.md
|
||||
docs/source/en/model_doc/reformer.md
|
||||
docs/source/en/model_doc/regnet.md
|
||||
docs/source/en/model_doc/rembert.md
|
||||
docs/source/en/model_doc/resnet.md
|
||||
docs/source/en/model_doc/roberta-prelayernorm.md
|
||||
docs/source/en/model_doc/roberta.md
|
||||
docs/source/en/model_doc/roc_bert.md
|
||||
docs/source/en/model_doc/roformer.md
|
||||
docs/source/en/model_doc/rwkv.md
|
||||
docs/source/en/model_doc/sam.md
|
||||
docs/source/en/model_doc/sam_hq.md
|
||||
docs/source/en/model_doc/segformer.md
|
||||
docs/source/en/model_doc/sew-d.md
|
||||
docs/source/en/model_doc/sew.md
|
||||
docs/source/en/model_doc/speech-encoder-decoder.md
|
||||
docs/source/en/model_doc/speecht5.md
|
||||
docs/source/en/model_doc/splinter.md
|
||||
docs/source/en/model_doc/squeezebert.md
|
||||
docs/source/en/model_doc/swiftformer.md
|
||||
docs/source/en/model_doc/swin.md
|
||||
docs/source/en/model_doc/swin2sr.md
|
||||
docs/source/en/model_doc/swinv2.md
|
||||
docs/source/en/model_doc/table-transformer.md
|
||||
docs/source/en/model_doc/tapas.md
|
||||
docs/source/en/model_doc/time_series_transformer.md
|
||||
docs/source/en/model_doc/timesformer.md
|
||||
docs/source/en/model_doc/trocr.md
|
||||
docs/source/en/model_doc/ul2.md
|
||||
docs/source/en/model_doc/umt5.md
|
||||
docs/source/en/model_doc/unispeech-sat.md
|
||||
docs/source/en/model_doc/unispeech.md
|
||||
docs/source/en/model_doc/upernet.md
|
||||
docs/source/en/model_doc/videomae.md
|
||||
docs/source/en/model_doc/vilt.md
|
||||
docs/source/en/model_doc/vipllava.md
|
||||
docs/source/en/model_doc/vision-encoder-decoder.md
|
||||
docs/source/en/model_doc/vision-text-dual-encoder.md
|
||||
docs/source/en/model_doc/visual_bert.md
|
||||
docs/source/en/model_doc/vit.md
|
||||
docs/source/en/model_doc/vit_mae.md
|
||||
docs/source/en/model_doc/vit_msn.md
|
||||
docs/source/en/model_doc/vivit.md
|
||||
docs/source/en/model_doc/wav2vec2-conformer.md
|
||||
docs/source/en/model_doc/wav2vec2.md
|
||||
docs/source/en/model_doc/wav2vec2_phoneme.md
|
||||
docs/source/en/model_doc/wavlm.md
|
||||
docs/source/en/model_doc/whisper.md
|
||||
docs/source/en/model_doc/xclip.md
|
||||
docs/source/en/model_doc/xglm.md
|
||||
docs/source/en/model_doc/xlm-roberta-xl.md
|
||||
docs/source/en/model_doc/xlm-roberta.md
|
||||
docs/source/en/model_doc/xlm-v.md
|
||||
docs/source/en/model_doc/xlm.md
|
||||
docs/source/en/model_doc/xlnet.md
|
||||
docs/source/en/model_doc/xls_r.md
|
||||
docs/source/en/model_doc/xlsr_wav2vec2.md
|
||||
docs/source/en/model_doc/xmod.md
|
||||
docs/source/en/model_doc/yolos.md
|
||||
docs/source/en/model_doc/yoso.md
|
||||
docs/source/en/model_memory_anatomy.md
|
||||
docs/source/en/model_sharing.md
|
||||
docs/source/en/notebooks.md
|
||||
docs/source/en/peft.md
|
||||
docs/source/en/perf_hardware.md
|
||||
docs/source/en/perf_torch_compile.md
|
||||
docs/source/en/perf_train_cpu.md
|
||||
docs/source/en/perf_train_gpu_many.md
|
||||
docs/source/en/perf_train_special.md
|
||||
docs/source/en/perplexity.md
|
||||
docs/source/en/philosophy.md
|
||||
docs/source/en/pipeline_webserver.md
|
||||
docs/source/en/pr_checks.md
|
||||
docs/source/en/run_scripts.md
|
||||
docs/source/en/serialization.md
|
||||
docs/source/en/tasks/asr.md
|
||||
docs/source/en/tasks/audio_classification.md
|
||||
docs/source/en/tasks/document_question_answering.md
|
||||
docs/source/en/tasks/idefics.md
|
||||
docs/source/en/tasks/image_captioning.md
|
||||
docs/source/en/tasks/image_classification.md
|
||||
docs/source/en/tasks/language_modeling.md
|
||||
docs/source/en/tasks/masked_language_modeling.md
|
||||
docs/source/en/tasks/monocular_depth_estimation.md
|
||||
docs/source/en/tasks/multiple_choice.md
|
||||
docs/source/en/tasks/object_detection.md
|
||||
docs/source/en/tasks/question_answering.md
|
||||
docs/source/en/tasks/semantic_segmentation.md
|
||||
docs/source/en/tasks/sequence_classification.md
|
||||
docs/source/en/tasks/summarization.md
|
||||
docs/source/en/tasks/text-to-speech.md
|
||||
docs/source/en/tasks/token_classification.md
|
||||
docs/source/en/tasks/translation.md
|
||||
docs/source/en/tasks/video_classification.md
|
||||
docs/source/en/tasks/visual_question_answering.md
|
||||
docs/source/en/tasks/zero_shot_image_classification.md
|
||||
docs/source/en/tasks/zero_shot_object_detection.md
|
||||
docs/source/en/tokenizer_summary.md
|
||||
docs/source/en/training.md
|
||||
docs/source/en/troubleshooting.md
|
||||
src/transformers/activations.py
|
||||
src/transformers/audio_utils.py
|
||||
src/transformers/cli/add_new_model_like.py
|
||||
src/transformers/cli/chat.py
|
||||
src/transformers/cli/download.py
|
||||
src/transformers/cli/serve.py
|
||||
src/transformers/cli/system.py
|
||||
src/transformers/cli/transformers.py
|
||||
src/transformers/configuration_utils.py
|
||||
src/transformers/convert_slow_tokenizer.py
|
||||
src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py
|
||||
src/transformers/data/data_collator.py
|
||||
src/transformers/data/datasets/glue.py
|
||||
src/transformers/data/datasets/squad.py
|
||||
src/transformers/data/metrics/squad_metrics.py
|
||||
src/transformers/data/processors/glue.py
|
||||
src/transformers/data/processors/squad.py
|
||||
src/transformers/data/processors/utils.py
|
||||
src/transformers/data/processors/xnli.py
|
||||
src/transformers/debug_utils.py
|
||||
src/transformers/dependency_versions_check.py
|
||||
src/transformers/dependency_versions_table.py
|
||||
src/transformers/dynamic_module_utils.py
|
||||
src/transformers/feature_extraction_sequence_utils.py
|
||||
src/transformers/feature_extraction_utils.py
|
||||
src/transformers/file_utils.py
|
||||
src/transformers/hf_argparser.py
|
||||
src/transformers/hyperparameter_search.py
|
||||
src/transformers/image_processing_utils.py
|
||||
src/transformers/image_transforms.py
|
||||
src/transformers/image_utils.py
|
||||
src/transformers/integrations/bitsandbytes.py
|
||||
src/transformers/integrations/deepspeed.py
|
||||
src/transformers/integrations/integration_utils.py
|
||||
src/transformers/integrations/peft.py
|
||||
src/transformers/modelcard.py
|
||||
src/transformers/modeling_outputs.py
|
||||
src/transformers/modeling_utils.py
|
||||
src/transformers/models/align/configuration_align.py
|
||||
src/transformers/models/align/modeling_align.py
|
||||
src/transformers/models/altclip/configuration_altclip.py
|
||||
src/transformers/models/altclip/modeling_altclip.py
|
||||
src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
|
||||
src/transformers/models/audio_spectrogram_transformer/convert_audio_spectrogram_transformer_original_to_pytorch.py
|
||||
src/transformers/models/auto/auto_factory.py
|
||||
src/transformers/models/auto/configuration_auto.py
|
||||
src/transformers/models/auto/modeling_auto.py
|
||||
src/transformers/models/autoformer/configuration_autoformer.py
|
||||
src/transformers/models/autoformer/modeling_autoformer.py
|
||||
src/transformers/models/bark/convert_suno_to_hf.py
|
||||
src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
|
||||
src/transformers/models/bert_generation/modeling_bert_generation.py
|
||||
src/transformers/models/biogpt/configuration_biogpt.py
|
||||
src/transformers/models/biogpt/convert_biogpt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/biogpt/modeling_biogpt.py
|
||||
src/transformers/models/bit/configuration_bit.py
|
||||
src/transformers/models/bit/convert_bit_to_pytorch.py
|
||||
src/transformers/models/bit/modeling_bit.py
|
||||
src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/blip/configuration_blip.py
|
||||
src/transformers/models/blip/convert_blip_original_pytorch_to_hf.py
|
||||
src/transformers/models/blip/modeling_blip_text.py
|
||||
src/transformers/models/blip_2/configuration_blip_2.py
|
||||
src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py
|
||||
src/transformers/models/blip_2/modeling_blip_2.py
|
||||
src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/bloom/modeling_bloom.py
|
||||
src/transformers/models/bridgetower/configuration_bridgetower.py
|
||||
src/transformers/models/bridgetower/modeling_bridgetower.py
|
||||
src/transformers/models/bros/convert_bros_to_pytorch.py
|
||||
src/transformers/models/camembert/modeling_camembert.py
|
||||
src/transformers/models/chinese_clip/configuration_chinese_clip.py
|
||||
src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py
|
||||
src/transformers/models/chinese_clip/modeling_chinese_clip.py
|
||||
src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
|
||||
src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py
|
||||
src/transformers/models/clip/modeling_clip.py
|
||||
src/transformers/models/clipseg/configuration_clipseg.py
|
||||
src/transformers/models/clipseg/convert_clipseg_original_pytorch_to_hf.py
|
||||
src/transformers/models/codegen/modeling_codegen.py
|
||||
src/transformers/models/conditional_detr/convert_conditional_detr_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/convbert/modeling_convbert.py
|
||||
src/transformers/models/convnext/convert_convnext_to_pytorch.py
|
||||
src/transformers/models/convnextv2/configuration_convnextv2.py
|
||||
src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py
|
||||
src/transformers/models/convnextv2/modeling_convnextv2.py
|
||||
src/transformers/models/cpmant/configuration_cpmant.py
|
||||
src/transformers/models/cpmant/modeling_cpmant.py
|
||||
src/transformers/models/cpmant/tokenization_cpmant.py
|
||||
src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/modeling_data2vec_text.py
|
||||
src/transformers/models/decision_transformer/modeling_decision_transformer.py
|
||||
src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py
|
||||
src/transformers/models/deit/convert_deit_timm_to_pytorch.py
|
||||
src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/detr/convert_detr_to_pytorch.py
|
||||
src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/dinov2/configuration_dinov2.py
|
||||
src/transformers/models/dinov2/convert_dinov2_to_hf.py
|
||||
src/transformers/models/dinov2/modeling_dinov2.py
|
||||
src/transformers/models/distilbert/modeling_distilbert.py
|
||||
src/transformers/models/dit/convert_dit_unilm_to_pytorch.py
|
||||
src/transformers/models/donut/configuration_donut_swin.py
|
||||
src/transformers/models/donut/convert_donut_to_pytorch.py
|
||||
src/transformers/models/donut/modeling_donut_swin.py
|
||||
src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/dpr/modeling_dpr.py
|
||||
src/transformers/models/dpt/configuration_dpt.py
|
||||
src/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py
|
||||
src/transformers/models/dpt/convert_dpt_to_pytorch.py
|
||||
src/transformers/models/efficientnet/configuration_efficientnet.py
|
||||
src/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py
|
||||
src/transformers/models/efficientnet/modeling_efficientnet.py
|
||||
src/transformers/models/encodec/configuration_encodec.py
|
||||
src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
|
||||
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
|
||||
src/transformers/models/ernie/modeling_ernie.py
|
||||
src/transformers/models/esm/configuration_esm.py
|
||||
src/transformers/models/esm/convert_esm.py
|
||||
src/transformers/models/esm/modeling_esm.py
|
||||
src/transformers/models/esm/modeling_esmfold.py
|
||||
src/transformers/models/esm/openfold_utils/chunk_utils.py
|
||||
src/transformers/models/esm/openfold_utils/data_transforms.py
|
||||
src/transformers/models/esm/openfold_utils/feats.py
|
||||
src/transformers/models/esm/openfold_utils/loss.py
|
||||
src/transformers/models/esm/openfold_utils/protein.py
|
||||
src/transformers/models/esm/openfold_utils/residue_constants.py
|
||||
src/transformers/models/esm/openfold_utils/rigid_utils.py
|
||||
src/transformers/models/esm/openfold_utils/tensor_utils.py
|
||||
src/transformers/models/falcon/configuration_falcon.py
|
||||
src/transformers/models/falcon/modeling_falcon.py
|
||||
src/transformers/models/flaubert/configuration_flaubert.py
|
||||
src/transformers/models/flaubert/modeling_flaubert.py
|
||||
src/transformers/models/flava/convert_dalle_to_flava_codebook.py
|
||||
src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
|
||||
src/transformers/models/flava/modeling_flava.py
|
||||
src/transformers/models/fnet/modeling_fnet.py
|
||||
src/transformers/models/focalnet/configuration_focalnet.py
|
||||
src/transformers/models/focalnet/convert_focalnet_to_hf_format.py
|
||||
src/transformers/models/focalnet/modeling_focalnet.py
|
||||
src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/fsmt/modeling_fsmt.py
|
||||
src/transformers/models/funnel/configuration_funnel.py
|
||||
src/transformers/models/funnel/modeling_funnel.py
|
||||
src/transformers/models/fuyu/convert_fuyu_model_weights_to_hf.py
|
||||
src/transformers/models/gemma/configuration_gemma.py
|
||||
src/transformers/models/gemma/convert_gemma_weights_to_hf.py
|
||||
src/transformers/models/gemma/modeling_gemma.py
|
||||
src/transformers/models/git/configuration_git.py
|
||||
src/transformers/models/git/convert_git_to_pytorch.py
|
||||
src/transformers/models/glpn/configuration_glpn.py
|
||||
src/transformers/models/glpn/convert_glpn_to_pytorch.py
|
||||
src/transformers/models/gpt2/CONVERSION.md
|
||||
src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
|
||||
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
|
||||
src/transformers/models/gpt_neo/modeling_gpt_neo.py
|
||||
src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
|
||||
src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py
|
||||
src/transformers/models/gptj/configuration_gptj.py
|
||||
src/transformers/models/groupvit/configuration_groupvit.py
|
||||
src/transformers/models/groupvit/convert_groupvit_nvlab_to_hf.py
|
||||
src/transformers/models/hubert/configuration_hubert.py
|
||||
src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/ibert/configuration_ibert.py
|
||||
src/transformers/models/ibert/modeling_ibert.py
|
||||
src/transformers/models/ibert/quant_modules.py
|
||||
src/transformers/models/idefics/configuration_idefics.py
|
||||
src/transformers/models/idefics/image_processing_idefics.py
|
||||
src/transformers/models/idefics/modeling_idefics.py
|
||||
src/transformers/models/idefics/perceiver.py
|
||||
src/transformers/models/idefics/processing_idefics.py
|
||||
src/transformers/models/idefics/vision.py
|
||||
src/transformers/models/informer/configuration_informer.py
|
||||
src/transformers/models/informer/modeling_informer.py
|
||||
src/transformers/models/instructblip/configuration_instructblip.py
|
||||
src/transformers/models/instructblip/convert_instructblip_original_to_pytorch.py
|
||||
src/transformers/models/instructblip/modeling_instructblip.py
|
||||
src/transformers/models/instructblip/processing_instructblip.py
|
||||
src/transformers/models/jamba/configuration_jamba.py
|
||||
src/transformers/models/jamba/modeling_jamba.py
|
||||
src/transformers/models/kosmos2/convert_kosmos2_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/led/configuration_led.py
|
||||
src/transformers/models/led/modeling_led.py
|
||||
src/transformers/models/levit/convert_levit_timm_to_pytorch.py
|
||||
src/transformers/models/levit/modeling_levit.py
|
||||
src/transformers/models/lilt/configuration_lilt.py
|
||||
src/transformers/models/llama/configuration_llama.py
|
||||
src/transformers/models/llama/convert_llama_weights_to_hf.py
|
||||
src/transformers/models/llama/modeling_llama.py
|
||||
src/transformers/models/llava/configuration_llava.py
|
||||
src/transformers/models/llava/modeling_llava.py
|
||||
src/transformers/models/llava_next/configuration_llava_next.py
|
||||
src/transformers/models/llava_next/modeling_llava_next.py
|
||||
src/transformers/models/longformer/configuration_longformer.py
|
||||
src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py
|
||||
src/transformers/models/longt5/configuration_longt5.py
|
||||
src/transformers/models/luke/configuration_luke.py
|
||||
src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/luke/modeling_luke.py
|
||||
src/transformers/models/lxmert/configuration_lxmert.py
|
||||
src/transformers/models/lxmert/modeling_lxmert.py
|
||||
src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/m2m_100/modeling_m2m_100.py
|
||||
src/transformers/models/marian/configuration_marian.py
|
||||
src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py
|
||||
src/transformers/models/marian/convert_marian_to_pytorch.py
|
||||
src/transformers/models/markuplm/configuration_markuplm.py
|
||||
src/transformers/models/markuplm/feature_extraction_markuplm.py
|
||||
src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/maskformer/configuration_maskformer_swin.py
|
||||
src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/maskformer/convert_maskformer_resnet_to_pytorch.py
|
||||
src/transformers/models/maskformer/convert_maskformer_swin_to_pytorch.py
|
||||
src/transformers/models/maskformer/modeling_maskformer_swin.py
|
||||
src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
|
||||
src/transformers/models/megatron_bert/modeling_megatron_bert.py
|
||||
src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py
|
||||
src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
|
||||
src/transformers/models/mgp_str/configuration_mgp_str.py
|
||||
src/transformers/models/mgp_str/modeling_mgp_str.py
|
||||
src/transformers/models/mistral/configuration_mistral.py
|
||||
src/transformers/models/mistral/modeling_mistral.py
|
||||
src/transformers/models/mixtral/configuration_mixtral.py
|
||||
src/transformers/models/mixtral/modeling_mixtral.py
|
||||
src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py
|
||||
src/transformers/models/mobilenet_v2/configuration_mobilenet_v2.py
|
||||
src/transformers/models/mobilevit/configuration_mobilevit.py
|
||||
src/transformers/models/mobilevit/convert_mlcvnets_to_pytorch.py
|
||||
src/transformers/models/mobilevitv2/convert_mlcvnets_to_pytorch.py
|
||||
src/transformers/models/mpnet/configuration_mpnet.py
|
||||
src/transformers/models/mpnet/modeling_mpnet.py
|
||||
src/transformers/models/mpt/configuration_mpt.py
|
||||
src/transformers/models/mpt/modeling_mpt.py
|
||||
src/transformers/models/mra/configuration_mra.py
|
||||
src/transformers/models/mra/convert_mra_pytorch_to_pytorch.py
|
||||
src/transformers/models/mra/modeling_mra.py
|
||||
src/transformers/models/mt5/configuration_mt5.py
|
||||
src/transformers/models/mt5/modeling_mt5.py
|
||||
src/transformers/models/musicgen/convert_musicgen_transformers.py
|
||||
src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py
|
||||
src/transformers/models/mvp/modeling_mvp.py
|
||||
src/transformers/models/nllb_moe/configuration_nllb_moe.py
|
||||
src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/nllb_moe/modeling_nllb_moe.py
|
||||
src/transformers/models/nougat/convert_nougat_to_hf.py
|
||||
src/transformers/models/nystromformer/configuration_nystromformer.py
|
||||
src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/nystromformer/modeling_nystromformer.py
|
||||
src/transformers/models/oneformer/convert_to_hf_oneformer.py
|
||||
src/transformers/models/openai/modeling_openai.py
|
||||
src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/owlvit/configuration_owlvit.py
|
||||
src/transformers/models/pegasus_x/modeling_pegasus_x.py
|
||||
src/transformers/models/perceiver/configuration_perceiver.py
|
||||
src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py
|
||||
src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py
|
||||
src/transformers/models/persimmon/modeling_persimmon.py
|
||||
src/transformers/models/pix2struct/configuration_pix2struct.py
|
||||
src/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py
|
||||
src/transformers/models/pix2struct/image_processing_pix2struct.py
|
||||
src/transformers/models/pix2struct/processing_pix2struct.py
|
||||
src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py
|
||||
src/transformers/models/poolformer/convert_poolformer_original_to_pytorch.py
|
||||
src/transformers/models/pop2piano/convert_pop2piano_weights_to_hf.py
|
||||
src/transformers/models/pop2piano/feature_extraction_pop2piano.py
|
||||
src/transformers/models/pop2piano/processing_pop2piano.py
|
||||
src/transformers/models/pop2piano/tokenization_pop2piano.py
|
||||
src/transformers/models/prophetnet/configuration_prophetnet.py
|
||||
src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/prophetnet/modeling_prophetnet.py
|
||||
src/transformers/models/pvt/configuration_pvt.py
|
||||
src/transformers/models/pvt/convert_pvt_to_pytorch.py
|
||||
src/transformers/models/pvt/image_processing_pvt.py
|
||||
src/transformers/models/pvt/modeling_pvt.py
|
||||
src/transformers/models/qwen2/configuration_qwen2.py
|
||||
src/transformers/models/qwen2/modeling_qwen2.py
|
||||
src/transformers/models/qwen2/tokenization_qwen2.py
|
||||
src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
|
||||
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
|
||||
src/transformers/models/rag/configuration_rag.py
|
||||
src/transformers/models/rag/modeling_rag.py
|
||||
src/transformers/models/rag/retrieval_rag.py
|
||||
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
|
||||
src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
|
||||
src/transformers/models/regnet/configuration_regnet.py
|
||||
src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
|
||||
src/transformers/models/regnet/convert_regnet_to_pytorch.py
|
||||
src/transformers/models/rembert/configuration_rembert.py
|
||||
src/transformers/models/rembert/modeling_rembert.py
|
||||
src/transformers/models/resnet/convert_resnet_to_pytorch.py
|
||||
src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/roc_bert/configuration_roc_bert.py
|
||||
src/transformers/models/roformer/modeling_roformer.py
|
||||
src/transformers/models/rwkv/configuration_rwkv.py
|
||||
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
|
||||
src/transformers/models/rwkv/modeling_rwkv.py
|
||||
src/transformers/models/sam/configuration_sam.py
|
||||
src/transformers/models/sam/convert_sam_to_hf.py
|
||||
src/transformers/models/sam/image_processing_sam.py
|
||||
src/transformers/models/sam/modeling_sam.py
|
||||
src/transformers/models/sam/processing_sam.py
|
||||
src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py
|
||||
src/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py
|
||||
src/transformers/models/segformer/configuration_segformer.py
|
||||
src/transformers/models/segformer/convert_segformer_original_to_pytorch.py
|
||||
src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
|
||||
src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
|
||||
src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
|
||||
src/transformers/models/speecht5/configuration_speecht5.py
|
||||
src/transformers/models/speecht5/convert_hifigan.py
|
||||
src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/speecht5/number_normalizer.py
|
||||
src/transformers/models/splinter/configuration_splinter.py
|
||||
src/transformers/models/splinter/modeling_splinter.py
|
||||
src/transformers/models/squeezebert/modeling_squeezebert.py
|
||||
src/transformers/models/stablelm/modeling_stablelm.py
|
||||
src/transformers/models/starcoder2/modeling_starcoder2.py
|
||||
src/transformers/models/swiftformer/configuration_swiftformer.py
|
||||
src/transformers/models/swiftformer/convert_swiftformer_original_to_hf.py
|
||||
src/transformers/models/swiftformer/modeling_swiftformer.py
|
||||
src/transformers/models/swin/convert_swin_simmim_to_pytorch.py
|
||||
src/transformers/models/swin/convert_swin_timm_to_pytorch.py
|
||||
src/transformers/models/swin2sr/configuration_swin2sr.py
|
||||
src/transformers/models/swin2sr/convert_swin2sr_original_to_pytorch.py
|
||||
src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py
|
||||
src/transformers/models/swinv2/modeling_swinv2.py
|
||||
src/transformers/models/switch_transformers/configuration_switch_transformers.py
|
||||
src/transformers/models/switch_transformers/convert_big_switch.py
|
||||
src/transformers/models/switch_transformers/modeling_switch_transformers.py
|
||||
src/transformers/models/t5/configuration_t5.py
|
||||
src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
|
||||
src/transformers/models/t5/modeling_t5.py
|
||||
src/transformers/models/table_transformer/configuration_table_transformer.py
|
||||
src/transformers/models/table_transformer/convert_table_transformer_to_hf.py
|
||||
src/transformers/models/table_transformer/convert_table_transformer_to_hf_no_timm.py
|
||||
src/transformers/models/tapas/configuration_tapas.py
|
||||
src/transformers/models/tapas/modeling_tapas.py
|
||||
src/transformers/models/timesformer/convert_timesformer_to_pytorch.py
|
||||
src/transformers/models/timm_backbone/configuration_timm_backbone.py
|
||||
src/transformers/models/timm_backbone/modeling_timm_backbone.py
|
||||
src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py
|
||||
src/transformers/models/umt5/configuration_umt5.py
|
||||
src/transformers/models/umt5/convert_umt5_checkpoint_to_pytorch.py
|
||||
src/transformers/models/umt5/modeling_umt5.py
|
||||
src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
|
||||
src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/upernet/configuration_upernet.py
|
||||
src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py
|
||||
src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py
|
||||
src/transformers/models/videomae/configuration_videomae.py
|
||||
src/transformers/models/videomae/convert_videomae_to_pytorch.py
|
||||
src/transformers/models/vilt/configuration_vilt.py
|
||||
src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
|
||||
src/transformers/models/vipllava/configuration_vipllava.py
|
||||
src/transformers/models/vipllava/modeling_vipllava.py
|
||||
src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
|
||||
src/transformers/models/visual_bert/convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/visual_bert/modeling_visual_bert.py
|
||||
src/transformers/models/vit/convert_dino_to_pytorch.py
|
||||
src/transformers/models/vit/convert_vit_timm_to_pytorch.py
|
||||
src/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py
|
||||
src/transformers/models/vit_msn/configuration_vit_msn.py
|
||||
src/transformers/models/vit_msn/convert_msn_to_pytorch.py
|
||||
src/transformers/models/vivit/configuration_vivit.py
|
||||
src/transformers/models/vivit/image_processing_vivit.py
|
||||
src/transformers/models/vivit/modeling_vivit.py
|
||||
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
|
||||
src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/whisper/convert_openai_to_hf.py
|
||||
src/transformers/models/whisper/english_normalizer.py
|
||||
src/transformers/models/x_clip/configuration_x_clip.py
|
||||
src/transformers/models/x_clip/convert_x_clip_original_pytorch_to_hf.py
|
||||
src/transformers/models/xglm/configuration_xglm.py
|
||||
src/transformers/models/xglm/convert_xglm_original_ckpt_to_trfms.py
|
||||
src/transformers/models/xglm/modeling_xglm.py
|
||||
src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/xlm/modeling_xlm.py
|
||||
src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
|
||||
src/transformers/models/xlm_roberta_xl/convert_xlm_roberta_xl_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
|
||||
src/transformers/models/xlnet/modeling_xlnet.py
|
||||
src/transformers/models/xmod/convert_xmod_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/yolos/convert_yolos_to_pytorch.py
|
||||
src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py
|
||||
src/transformers/models/yoso/modeling_yoso.py
|
||||
src/transformers/models/zamba/configuration_zamba.py
|
||||
src/transformers/models/zamba/modeling_zamba.py
|
||||
src/transformers/optimization.py
|
||||
src/transformers/pipelines/audio_classification.py
|
||||
src/transformers/pipelines/audio_utils.py
|
||||
src/transformers/pipelines/automatic_speech_recognition.py
|
||||
src/transformers/pipelines/base.py
|
||||
src/transformers/pipelines/depth_estimation.py
|
||||
src/transformers/pipelines/document_question_answering.py
|
||||
src/transformers/pipelines/feature_extraction.py
|
||||
src/transformers/pipelines/fill_mask.py
|
||||
src/transformers/pipelines/image_classification.py
|
||||
src/transformers/pipelines/image_segmentation.py
|
||||
src/transformers/pipelines/mask_generation.py
|
||||
src/transformers/pipelines/object_detection.py
|
||||
src/transformers/pipelines/pt_utils.py
|
||||
src/transformers/pipelines/table_question_answering.py
|
||||
src/transformers/pipelines/text_classification.py
|
||||
src/transformers/pipelines/token_classification.py
|
||||
src/transformers/pipelines/video_classification.py
|
||||
src/transformers/pipelines/zero_shot_audio_classification.py
|
||||
src/transformers/pipelines/zero_shot_classification.py
|
||||
src/transformers/pipelines/zero_shot_image_classification.py
|
||||
src/transformers/pipelines/zero_shot_object_detection.py
|
||||
src/transformers/processing_utils.py
|
||||
src/transformers/pytorch_utils.py
|
||||
src/transformers/quantizers/auto.py
|
||||
src/transformers/quantizers/base.py
|
||||
src/transformers/quantizers/quantizer_awq.py
|
||||
src/transformers/quantizers/quantizer_bnb_4bit.py
|
||||
src/transformers/quantizers/quantizer_bnb_8bit.py
|
||||
src/transformers/quantizers/quantizer_gptq.py
|
||||
src/transformers/quantizers/quantizers_utils.py
|
||||
src/transformers/testing_utils.py
|
||||
src/transformers/time_series_utils.py
|
||||
src/transformers/tokenization_python.py
|
||||
src/transformers/tokenization_utils_base.py
|
||||
src/transformers/trainer.py
|
||||
src/transformers/trainer_callback.py
|
||||
src/transformers/trainer_pt_utils.py
|
||||
src/transformers/trainer_seq2seq.py
|
||||
src/transformers/trainer_utils.py
|
||||
src/transformers/training_args.py
|
||||
src/transformers/training_args_seq2seq.py
|
||||
src/transformers/utils/backbone_utils.py
|
||||
src/transformers/utils/constants.py
|
||||
src/transformers/utils/doc.py
|
||||
src/transformers/utils/dummy_detectron2_objects.py
|
||||
src/transformers/utils/dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects.py
|
||||
src/transformers/utils/dummy_music_objects.py
|
||||
src/transformers/utils/dummy_pt_objects.py
|
||||
src/transformers/utils/dummy_sentencepiece_and_tokenizers_objects.py
|
||||
src/transformers/utils/dummy_speech_objects.py
|
||||
src/transformers/utils/dummy_tokenizers_objects.py
|
||||
src/transformers/utils/dummy_vision_objects.py
|
||||
src/transformers/utils/generic.py
|
||||
src/transformers/utils/hp_naming.py
|
||||
src/transformers/utils/hub.py
|
||||
src/transformers/utils/import_utils.py
|
||||
src/transformers/utils/logging.py
|
||||
src/transformers/utils/notebook.py
|
||||
src/transformers/utils/peft_utils.py
|
||||
src/transformers/utils/quantization_config.py
|
||||
src/transformers/utils/sentencepiece_model_pb2.py
|
||||
src/transformers/utils/sentencepiece_model_pb2_new.py
|
||||
src/transformers/utils/versions.py
|
||||
1598
utils/notification_service.py
Normal file
1598
utils/notification_service.py
Normal file
File diff suppressed because it is too large
Load Diff
384
utils/notification_service_doc_tests.py
Normal file
384
utils/notification_service_doc_tests.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
from get_ci_error_statistics import get_jobs
|
||||
from slack_sdk import WebClient
|
||||
|
||||
|
||||
client = WebClient(token=os.environ["CI_SLACK_BOT_TOKEN"])
|
||||
|
||||
|
||||
def handle_test_results(test_results):
|
||||
expressions = test_results.split(" ")
|
||||
|
||||
failed = 0
|
||||
success = 0
|
||||
|
||||
# When the output is short enough, the output is surrounded by = signs: "== OUTPUT =="
|
||||
# When it is too long, those signs are not present.
|
||||
time_spent = expressions[-2] if "=" in expressions[-1] else expressions[-1]
|
||||
|
||||
for i, expression in enumerate(expressions):
|
||||
if "failed" in expression:
|
||||
failed += int(expressions[i - 1])
|
||||
if "passed" in expression:
|
||||
success += int(expressions[i - 1])
|
||||
|
||||
return failed, success, time_spent
|
||||
|
||||
|
||||
def extract_first_line_failure(failures_short_lines):
|
||||
failures = {}
|
||||
file = None
|
||||
in_error = False
|
||||
for line in failures_short_lines.split("\n"):
|
||||
if re.search(r"_ \[doctest\]", line):
|
||||
in_error = True
|
||||
file = line.split(" ")[2]
|
||||
elif in_error and not line.split(" ")[0].isdigit():
|
||||
failures[file] = line
|
||||
in_error = False
|
||||
|
||||
return failures
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, title: str, doc_test_results: dict):
|
||||
self.title = title
|
||||
|
||||
self.n_success = sum(job_result["n_success"] for job_result in doc_test_results.values())
|
||||
self.n_failures = sum(job_result["n_failures"] for job_result in doc_test_results.values())
|
||||
self.n_tests = self.n_success + self.n_failures
|
||||
|
||||
# Failures and success of the modeling tests
|
||||
self.doc_test_results = doc_test_results
|
||||
|
||||
@property
|
||||
def time(self) -> str:
|
||||
all_results = [*self.doc_test_results.values()]
|
||||
time_spent = [r["time_spent"].split(", ")[0] for r in all_results if len(r["time_spent"])]
|
||||
total_secs = 0
|
||||
|
||||
for timings in time_spent:
|
||||
time_parts = timings.split(":")
|
||||
|
||||
# Time can be formatted as xx:xx:xx, as .xx, or as x.xx if the time spent was less than a minute.
|
||||
if len(time_parts) == 1:
|
||||
time_parts = [0, 0, time_parts[0]]
|
||||
|
||||
hours, minutes, seconds = int(time_parts[0]), int(time_parts[1]), float(time_parts[2])
|
||||
total_secs += hours * 3600 + minutes * 60 + seconds
|
||||
|
||||
hours, minutes, seconds = total_secs // 3600, (total_secs % 3600) // 60, total_secs % 60
|
||||
return f"{int(hours)}h{int(minutes)}m{int(seconds)}s"
|
||||
|
||||
@property
|
||||
def header(self) -> dict:
|
||||
return {"type": "header", "text": {"type": "plain_text", "text": self.title}}
|
||||
|
||||
@property
|
||||
def no_failures(self) -> dict:
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": f"🌞 There were no failures: all {self.n_tests} tests passed. The suite ran in {self.time}.",
|
||||
"emoji": True,
|
||||
},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def failures(self) -> dict:
|
||||
return {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": (
|
||||
f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
|
||||
f" {self.time}."
|
||||
),
|
||||
"emoji": True,
|
||||
},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def category_failures(self) -> list[dict]:
|
||||
failure_blocks = []
|
||||
|
||||
MAX_ERROR_TEXT = 3000 - len("The following examples had failures:\n\n\n\n") - len("[Truncated]\n")
|
||||
line_length = 40
|
||||
category_failures = {k: v["failed"] for k, v in doc_test_results.items() if isinstance(v, dict)}
|
||||
|
||||
def single_category_failures(category, failures):
|
||||
text = ""
|
||||
if len(failures) == 0:
|
||||
return ""
|
||||
text += f"*{category} failures*:".ljust(line_length // 2).rjust(line_length // 2) + "\n"
|
||||
|
||||
for idx, failure in enumerate(failures):
|
||||
new_text = text + f"`{failure}`\n"
|
||||
if len(new_text) > MAX_ERROR_TEXT:
|
||||
text = text + "[Truncated]\n"
|
||||
break
|
||||
text = new_text
|
||||
|
||||
return text
|
||||
|
||||
for category, failures in category_failures.items():
|
||||
report = single_category_failures(category, failures)
|
||||
if len(report) == 0:
|
||||
continue
|
||||
block = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": f"The following examples had failures:\n\n\n{report}\n",
|
||||
},
|
||||
}
|
||||
failure_blocks.append(block)
|
||||
|
||||
return failure_blocks
|
||||
|
||||
@property
|
||||
def payload(self) -> str:
|
||||
blocks = [self.header]
|
||||
|
||||
if self.n_failures > 0:
|
||||
blocks.append(self.failures)
|
||||
|
||||
if self.n_failures > 0:
|
||||
blocks.extend(self.category_failures)
|
||||
|
||||
if self.n_failures == 0:
|
||||
blocks.append(self.no_failures)
|
||||
|
||||
return json.dumps(blocks)
|
||||
|
||||
@staticmethod
|
||||
def error_out():
|
||||
payload = [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "There was an issue running the tests.",
|
||||
},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
print("Sending the following payload")
|
||||
print(json.dumps({"blocks": json.loads(payload)}))
|
||||
|
||||
client.chat_postMessage(
|
||||
channel=SLACK_REPORT_CHANNEL_ID,
|
||||
text="There was an issue running the tests.",
|
||||
blocks=payload,
|
||||
)
|
||||
|
||||
def post(self):
|
||||
print("Sending the following payload")
|
||||
print(json.dumps({"blocks": json.loads(self.payload)}))
|
||||
|
||||
text = f"{self.n_failures} failures out of {self.n_tests} tests," if self.n_failures else "All tests passed."
|
||||
|
||||
self.thread_ts = client.chat_postMessage(
|
||||
channel=SLACK_REPORT_CHANNEL_ID,
|
||||
blocks=self.payload,
|
||||
text=text,
|
||||
)
|
||||
|
||||
def get_reply_blocks(self, job_name, job_link, failures, text):
|
||||
# `text` must be less than 3001 characters in Slack SDK
|
||||
# keep some room for adding "[Truncated]" when necessary
|
||||
MAX_ERROR_TEXT = 3000 - len("[Truncated]")
|
||||
|
||||
failure_text = ""
|
||||
for key, value in failures.items():
|
||||
new_text = failure_text + f"*{key}*\n_{value}_\n\n"
|
||||
if len(new_text) > MAX_ERROR_TEXT:
|
||||
# `failure_text` here has length <= 3000
|
||||
failure_text = failure_text + "[Truncated]"
|
||||
break
|
||||
# `failure_text` here has length <= MAX_ERROR_TEXT
|
||||
failure_text = new_text
|
||||
|
||||
title = job_name
|
||||
content = {"type": "section", "text": {"type": "mrkdwn", "text": text}}
|
||||
|
||||
if job_link is not None:
|
||||
content["accessory"] = {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "GitHub Action job", "emoji": True},
|
||||
"url": job_link,
|
||||
}
|
||||
|
||||
return [
|
||||
{"type": "header", "text": {"type": "plain_text", "text": title, "emoji": True}},
|
||||
content,
|
||||
{"type": "section", "text": {"type": "mrkdwn", "text": failure_text}},
|
||||
]
|
||||
|
||||
def post_reply(self):
|
||||
if self.thread_ts is None:
|
||||
raise ValueError("Can only post reply if a post has been made.")
|
||||
|
||||
sorted_dict = sorted(self.doc_test_results.items(), key=lambda t: t[0])
|
||||
for job_name, job_result in sorted_dict:
|
||||
if len(job_result["failures"]) > 0:
|
||||
text = f"*Num failures* :{len(job_result['failed'])} \n"
|
||||
failures = job_result["failures"]
|
||||
blocks = self.get_reply_blocks(job_name, job_result["job_link"], failures, text=text)
|
||||
|
||||
print("Sending the following reply")
|
||||
print(json.dumps({"blocks": blocks}))
|
||||
|
||||
client.chat_postMessage(
|
||||
channel=SLACK_REPORT_CHANNEL_ID,
|
||||
text=f"Results for {job_name}",
|
||||
blocks=blocks,
|
||||
thread_ts=self.thread_ts["ts"],
|
||||
)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def retrieve_artifact(name: str):
|
||||
_artifact = {}
|
||||
|
||||
if os.path.exists(name):
|
||||
files = os.listdir(name)
|
||||
for file in files:
|
||||
try:
|
||||
with open(os.path.join(name, file), encoding="utf-8") as f:
|
||||
_artifact[file.split(".")[0]] = f.read()
|
||||
except UnicodeDecodeError as e:
|
||||
raise ValueError(f"Could not open {os.path.join(name, file)}.") from e
|
||||
|
||||
return _artifact
|
||||
|
||||
|
||||
def retrieve_available_artifacts():
|
||||
class Artifact:
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.paths = []
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def add_path(self, path: str):
|
||||
self.paths.append({"name": self.name, "path": path})
|
||||
|
||||
_available_artifacts: dict[str, Artifact] = {}
|
||||
|
||||
directories = filter(os.path.isdir, os.listdir())
|
||||
for directory in directories:
|
||||
artifact_name = directory
|
||||
if artifact_name not in _available_artifacts:
|
||||
_available_artifacts[artifact_name] = Artifact(artifact_name)
|
||||
|
||||
_available_artifacts[artifact_name].add_path(directory)
|
||||
|
||||
return _available_artifacts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SLACK_REPORT_CHANNEL_ID = os.environ["SLACK_REPORT_CHANNEL"]
|
||||
|
||||
github_actions_jobs = get_jobs(
|
||||
workflow_run_id=os.environ["GITHUB_RUN_ID"], token=os.environ["ACCESS_REPO_INFO_TOKEN"]
|
||||
)
|
||||
|
||||
artifact_name_to_job_map = {}
|
||||
for job in github_actions_jobs:
|
||||
for step in job["steps"]:
|
||||
if step["name"].startswith("Test suite reports artifacts: "):
|
||||
artifact_name = step["name"][len("Test suite reports artifacts: ") :]
|
||||
artifact_name_to_job_map[artifact_name] = job
|
||||
break
|
||||
|
||||
available_artifacts = retrieve_available_artifacts()
|
||||
|
||||
doc_test_results = {}
|
||||
# `artifact_key` is the artifact path
|
||||
for artifact_obj in available_artifacts.values():
|
||||
artifact_path = artifact_obj.paths[0]
|
||||
if not artifact_path["path"].startswith("doc_tests_gpu_test_reports_"):
|
||||
continue
|
||||
|
||||
# change "_" back to "/" (to show the job name as path)
|
||||
job_name = artifact_path["path"].replace("doc_tests_gpu_test_reports_", "").replace("_", "/")
|
||||
|
||||
# This dict (for each job) will contain all the information relative to each doc test job, in particular:
|
||||
# - failed: list of failed tests
|
||||
# - failures: dict in the format 'test': 'error_message'
|
||||
job_result = {}
|
||||
doc_test_results[job_name] = job_result
|
||||
|
||||
job = artifact_name_to_job_map[artifact_path["path"]]
|
||||
job_result["job_link"] = job["html_url"]
|
||||
job_result["category"] = "Python Examples" if job_name.startswith("src/") else "MD Examples"
|
||||
|
||||
artifact = retrieve_artifact(artifact_path["path"])
|
||||
if "stats" in artifact:
|
||||
failed, success, time_spent = handle_test_results(artifact["stats"])
|
||||
job_result["n_failures"] = failed
|
||||
job_result["n_success"] = success
|
||||
job_result["time_spent"] = time_spent[1:-1] + ", "
|
||||
job_result["failed"] = []
|
||||
job_result["failures"] = {}
|
||||
|
||||
all_failures = extract_first_line_failure(artifact["failures_short"])
|
||||
for line in artifact["summary_short"].split("\n"):
|
||||
if re.search("FAILED", line):
|
||||
line = line.replace("FAILED ", "")
|
||||
line = line.split()[0].replace("\n", "")
|
||||
|
||||
if "::" in line:
|
||||
file_path, test = line.split("::")
|
||||
else:
|
||||
file_path, test = line, line
|
||||
|
||||
job_result["failed"].append(test)
|
||||
failure = all_failures.get(test, "N/A")
|
||||
job_result["failures"][test] = failure
|
||||
|
||||
# Save and to be uploaded as artifact
|
||||
os.makedirs("doc_test_results", exist_ok=True)
|
||||
with open("doc_test_results/doc_test_results.json", "w", encoding="UTF-8") as fp:
|
||||
json.dump(doc_test_results, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
message = Message("[INFO] Results of the doc tests.", doc_test_results)
|
||||
message.post()
|
||||
message.post_reply()
|
||||
155
utils/patch_helper.py
Normal file
155
utils/patch_helper.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This should help you prepare a patch, automatically extracting the commits to cherry-pick
|
||||
in chronological order to avoid merge conflicts. An equivalent way to do this is to use
|
||||
`git log --pretty=oneline HEAD...v4.41.0` and grep.
|
||||
|
||||
Potential TODO: automatically cherry-picks them.
|
||||
|
||||
Pass in a list of PR:
|
||||
`python utils/patch_helper.py --prs 31108 31054 31008 31010 31004`
|
||||
will produce the following:
|
||||
```bash
|
||||
Skipping invalid version tag: list
|
||||
Skipping invalid version tag: localattn1
|
||||
Git cherry-pick commands to run:
|
||||
git cherry-pick 03935d300d60110bb86edb49d2315089cfb19789 #2024-05-24 11:00:59+02:00
|
||||
git cherry-pick bdb9106f247fca48a71eb384be25dbbd29b065a8 #2024-05-24 19:02:55+02:00
|
||||
git cherry-pick 84c4b72ee99e8e65a8a5754a5f9d6265b45cf67e #2024-05-27 10:34:14+02:00
|
||||
git cherry-pick 936ab7bae5e040ec58994cb722dd587b9ab26581 #2024-05-28 11:56:05+02:00
|
||||
git cherry-pick 0bef4a273825d2cfc52ddfe62ba486ee61cc116f #2024-05-29 13:33:26+01:00
|
||||
```
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
LABEL = "for patch" # Replace with your label
|
||||
REPO = "huggingface/transformers" # Optional if already in correct repo
|
||||
|
||||
|
||||
def get_release_branch_name():
|
||||
"""Derive branch name from transformers version."""
|
||||
major, minor, *_ = transformers.__version__.split(".")
|
||||
major = int(major)
|
||||
minor = int(minor)
|
||||
|
||||
if minor == 0:
|
||||
# Handle major version rollback, e.g., from 5.0 to 4.latest (if ever needed)
|
||||
major -= 1
|
||||
# You'll need logic to determine the last minor of the previous major version
|
||||
raise ValueError("Minor version is 0; need logic to find previous major version's last minor")
|
||||
|
||||
return f"v{major}.{minor}-release"
|
||||
|
||||
|
||||
def checkout_branch(branch):
|
||||
"""Checkout the target branch."""
|
||||
try:
|
||||
subprocess.run(["git", "checkout", branch], check=True)
|
||||
print(f"[SUCCESS] Checked out branch: {branch}")
|
||||
except subprocess.CalledProcessError:
|
||||
print(f"[FAIL] Failed to checkout branch: {branch}. Does it exist?")
|
||||
exit(1)
|
||||
|
||||
|
||||
def get_prs_by_label(label):
|
||||
"""Call gh CLI to get PRs with a specific label."""
|
||||
cmd = [
|
||||
"gh",
|
||||
"pr",
|
||||
"list",
|
||||
"--label",
|
||||
label,
|
||||
"--state",
|
||||
"all",
|
||||
"--json",
|
||||
"number,title,mergeCommit,url",
|
||||
"--limit",
|
||||
"100",
|
||||
]
|
||||
result = subprocess.run(cmd, check=False, capture_output=True, text=True)
|
||||
result.check_returncode()
|
||||
prs = json.loads(result.stdout)
|
||||
for pr in prs:
|
||||
is_merged = pr.get("mergeCommit", {})
|
||||
if is_merged:
|
||||
pr["oid"] = is_merged.get("oid")
|
||||
return prs
|
||||
|
||||
|
||||
def get_commit_timestamp(commit_sha):
|
||||
"""Get UNIX timestamp of a commit using git."""
|
||||
result = subprocess.run(
|
||||
["git", "show", "-s", "--format=%ct", commit_sha], check=False, capture_output=True, text=True
|
||||
)
|
||||
result.check_returncode()
|
||||
return int(result.stdout.strip())
|
||||
|
||||
|
||||
def cherry_pick_commit(sha):
|
||||
"""Cherry-pick a given commit SHA."""
|
||||
try:
|
||||
subprocess.run(["git", "cherry-pick", sha], check=True)
|
||||
print(f"[SUCCESS] Cherry-picked commit {sha}")
|
||||
except subprocess.CalledProcessError:
|
||||
print(f"[WARNING] Failed to cherry-pick {sha}. Manual intervention required.")
|
||||
|
||||
|
||||
def commit_in_history(commit_sha, base_branch="HEAD"):
|
||||
"""Return True if commit is already part of base_branch history."""
|
||||
result = subprocess.run(
|
||||
["git", "merge-base", "--is-ancestor", commit_sha, base_branch],
|
||||
check=False,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def main(verbose=False):
|
||||
branch = get_release_branch_name()
|
||||
checkout_branch(branch)
|
||||
prs = get_prs_by_label(LABEL)
|
||||
# Attach commit timestamps
|
||||
for pr in prs:
|
||||
sha = pr.get("oid")
|
||||
if sha:
|
||||
pr["timestamp"] = get_commit_timestamp(sha)
|
||||
else:
|
||||
print("\n" + "=" * 80)
|
||||
print(f"[WARNING] PR #{pr['number']} ({sha}) is NOT in main!")
|
||||
print("[WARNING] A core maintainer must review this before cherry-picking.")
|
||||
print("=" * 80 + "\n")
|
||||
# Sort by commit timestamp (ascending)
|
||||
prs = [pr for pr in prs if pr.get("timestamp") is not None]
|
||||
prs.sort(key=lambda pr: pr["timestamp"])
|
||||
for pr in prs:
|
||||
sha = pr.get("oid")
|
||||
if sha:
|
||||
if commit_in_history(sha):
|
||||
if verbose:
|
||||
print(f"[INFO] PR #{pr['number']} ({pr['title']}) already in history. Skipping.")
|
||||
else:
|
||||
print(f"[INFO] PR #{pr['number']} ({pr['title']}) not in history. Cherry-picking...")
|
||||
cherry_pick_commit(sha)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
175
utils/pr_slow_ci_models.py
Normal file
175
utils/pr_slow_ci_models.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is used to get the models for which to run slow CI.
|
||||
|
||||
A new model added in a pull request will be included, as well as models specified in a GitHub pull request's comment
|
||||
with a prefix `run-slow`, `run_slow` or `run slow`. For example, the commit message `run_slow: bert, gpt2` will give
|
||||
`bert` and `gpt2`.
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python utils/pr_slow_ci_models.py
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
import string
|
||||
from pathlib import Path
|
||||
|
||||
from git import Repo
|
||||
|
||||
|
||||
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
def get_new_python_files_between_commits(base_commit: str, commits: list[str]) -> list[str]:
|
||||
"""
|
||||
Get the list of added python files between a base commit and one or several commits.
|
||||
|
||||
Args:
|
||||
repo (`git.Repo`):
|
||||
A git repository (for instance the Transformers repo).
|
||||
base_commit (`str`):
|
||||
The commit reference of where to compare for the diff. This is the current commit, not the branching point!
|
||||
commits (`List[str]`):
|
||||
The list of commits with which to compare the repo at `base_commit` (so the branching point).
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of python files added between a base commit and one or several commits.
|
||||
"""
|
||||
code_diff = []
|
||||
for commit in commits:
|
||||
for diff_obj in commit.diff(base_commit):
|
||||
# We always add new python files
|
||||
if diff_obj.change_type == "A" and diff_obj.b_path.endswith(".py"):
|
||||
code_diff.append(diff_obj.b_path)
|
||||
|
||||
return code_diff
|
||||
|
||||
|
||||
def get_new_python_files(diff_with_last_commit=False) -> list[str]:
|
||||
"""
|
||||
Return a list of python files that have been added between the current head and the main branch.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of python files added.
|
||||
"""
|
||||
repo = Repo(PATH_TO_REPO)
|
||||
|
||||
try:
|
||||
# For the cases where the main branch exists locally
|
||||
main = repo.refs.main
|
||||
except AttributeError:
|
||||
# On GitHub Actions runners, it doesn't have local main branch
|
||||
main = repo.remotes.origin.refs.main
|
||||
|
||||
if not diff_with_last_commit:
|
||||
print(f"main is at {main.commit}")
|
||||
print(f"Current head is at {repo.head.commit}")
|
||||
|
||||
commits = repo.merge_base(main, repo.head)
|
||||
for commit in commits:
|
||||
print(f"Branching commit: {commit}")
|
||||
else:
|
||||
print(f"main is at {main.commit}")
|
||||
commits = main.commit.parents
|
||||
for commit in commits:
|
||||
print(f"Parent commit: {commit}")
|
||||
|
||||
return get_new_python_files_between_commits(repo.head.commit, commits)
|
||||
|
||||
|
||||
def get_new_model(diff_with_last_commit=False):
|
||||
new_files = get_new_python_files(diff_with_last_commit)
|
||||
reg = re.compile(r"src/transformers/models/(.*)/modeling_.*\.py")
|
||||
|
||||
new_model = ""
|
||||
for x in new_files:
|
||||
find_new_model = reg.findall(x)
|
||||
if len(find_new_model) > 0:
|
||||
new_model = find_new_model[0]
|
||||
# It's unlikely we have 2 new modeling files in a pull request.
|
||||
break
|
||||
return new_model
|
||||
|
||||
|
||||
def parse_message(message: str) -> str:
|
||||
"""
|
||||
Parses a GitHub pull request's comment to find the models specified in it to run slow CI.
|
||||
|
||||
Args:
|
||||
message (`str`): The body of a GitHub pull request's comment.
|
||||
|
||||
Returns:
|
||||
`str`: The substring in `message` after `run-slow`, run_slow` or run slow`. If no such prefix is found, the
|
||||
empty string is returned.
|
||||
"""
|
||||
if message is None:
|
||||
return ""
|
||||
|
||||
message = message.strip().lower()
|
||||
|
||||
# run-slow: model_1, model_2
|
||||
if not message.startswith(("run-slow", "run_slow", "run slow")):
|
||||
return ""
|
||||
message = message[len("run slow") :]
|
||||
# remove leading `:`
|
||||
while message.strip().startswith(":"):
|
||||
message = message.strip()[1:]
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def get_models(message: str):
|
||||
models = parse_message(message)
|
||||
return models.replace(",", " ").split()
|
||||
|
||||
|
||||
def check_model_names(model_name: str):
|
||||
allowed = string.ascii_letters + string.digits + "_"
|
||||
return not (model_name.startswith("_") or model_name.endswith("_")) and all(c in allowed for c in model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--message", type=str, default="", help="The content of a comment.")
|
||||
parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
|
||||
args = parser.parse_args()
|
||||
|
||||
new_model = get_new_model()
|
||||
specified_models = get_models(args.message)
|
||||
models = ([] if new_model == "" else [new_model]) + specified_models
|
||||
# a guard for strange model names
|
||||
models = [model for model in models if check_model_names(model)]
|
||||
|
||||
# Add prefix
|
||||
final_list = []
|
||||
for model in models:
|
||||
if not args.quantization:
|
||||
if os.path.isdir(f"tests/models/{model}"):
|
||||
final_list.append(f"models/{model}")
|
||||
elif os.path.isdir(f"tests/{model}") and model != "quantization":
|
||||
final_list.append(model)
|
||||
elif os.path.isdir(f"tests/quantization/{model}"):
|
||||
final_list.append(f"quantization/{model}")
|
||||
|
||||
# Use `json.dumps` to get the double quotes instead of single quote, e.g. `["model/vit"]`.
|
||||
# (to avoid some shell expansion issues when this script is called from a Github Actions workflow)
|
||||
print(json.dumps(sorted(set(final_list))))
|
||||
73
utils/print_env.py
Normal file
73
utils/print_env.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# this script dumps information about the environment
|
||||
|
||||
import sys
|
||||
|
||||
import transformers
|
||||
from transformers import is_torch_hpu_available, is_torch_xpu_available
|
||||
|
||||
|
||||
print("Python version:", sys.version)
|
||||
print("transformers version:", transformers.__version__)
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
print("Torch version:", torch.__version__)
|
||||
accelerator = "NA"
|
||||
if torch.cuda.is_available():
|
||||
accelerator = "CUDA"
|
||||
elif is_torch_xpu_available():
|
||||
accelerator = "XPU"
|
||||
elif is_torch_hpu_available():
|
||||
accelerator = "HPU"
|
||||
|
||||
print("Torch accelerator:", accelerator)
|
||||
|
||||
if accelerator == "CUDA":
|
||||
print("Cuda version:", torch.version.cuda)
|
||||
print("CuDNN version:", torch.backends.cudnn.version())
|
||||
print("Number of GPUs available:", torch.cuda.device_count())
|
||||
print("NCCL version:", torch.cuda.nccl.version())
|
||||
elif accelerator == "XPU":
|
||||
print("SYCL version:", torch.version.xpu)
|
||||
print("Number of XPUs available:", torch.xpu.device_count())
|
||||
elif accelerator == "HPU":
|
||||
print("HPU version:", torch.__version__.split("+")[-1])
|
||||
print("Number of HPUs available:", torch.hpu.device_count())
|
||||
except ImportError:
|
||||
print("Torch version:", None)
|
||||
|
||||
try:
|
||||
import deepspeed
|
||||
|
||||
print("DeepSpeed version:", deepspeed.__version__)
|
||||
except ImportError:
|
||||
print("DeepSpeed version:", None)
|
||||
|
||||
|
||||
try:
|
||||
import torchcodec
|
||||
|
||||
versions = torchcodec._core.get_ffmpeg_library_versions()
|
||||
print("FFmpeg version:", versions["ffmpeg_version"])
|
||||
except ImportError:
|
||||
print("FFmpeg version:", None)
|
||||
except (AttributeError, KeyError, RuntimeError):
|
||||
print("Failed to get FFmpeg version")
|
||||
151
utils/process_bad_commit_report.py
Normal file
151
utils/process_bad_commit_report.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""An internal script to process `new_failures_with_bad_commit.json` produced by `utils/check_bad_commit.py`.
|
||||
|
||||
This is used by `.github/workflows/check_failed_model_tests.yml` to produce a slack report of the following form
|
||||
|
||||
```
|
||||
<{url}|New failed tests>
|
||||
{
|
||||
"GH_ydshieh": {
|
||||
"vit": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
from get_previous_daily_ci import get_last_daily_ci_run
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = HfApi()
|
||||
|
||||
job_name = os.environ.get("JOB_NAME")
|
||||
|
||||
# Upload to Hub and get the url
|
||||
# if it is not a scheduled run, upload the reports to a subfolder under `report_repo_folder`
|
||||
report_repo_subfolder = ""
|
||||
if os.getenv("GITHUB_EVENT_NAME") != "schedule":
|
||||
report_repo_subfolder = f"{os.getenv('GITHUB_RUN_NUMBER')}-{os.getenv('GITHUB_RUN_ID')}"
|
||||
report_repo_subfolder = f"runs/{report_repo_subfolder}"
|
||||
|
||||
workflow_run = get_last_daily_ci_run(
|
||||
token=os.environ["ACCESS_REPO_INFO_TOKEN"], workflow_run_id=os.getenv("GITHUB_RUN_ID")
|
||||
)
|
||||
workflow_run_created_time = workflow_run["created_at"]
|
||||
|
||||
report_repo_folder = workflow_run_created_time.split("T")[0]
|
||||
|
||||
if report_repo_subfolder:
|
||||
report_repo_folder = f"{report_repo_folder}/{report_repo_subfolder}"
|
||||
|
||||
report_repo_id = os.getenv("REPORT_REPO_ID")
|
||||
|
||||
with open("new_failures_with_bad_commit.json") as fp:
|
||||
data = json.load(fp)
|
||||
|
||||
with open(f"ci_results_{job_name}/job_links.json") as fp:
|
||||
job_links = json.load(fp)
|
||||
|
||||
# Update `new_failures_with_bad_commit.json` with job links information before uploading to Hub repository
|
||||
# - need to change `single-gpu` to `single` and same for `multi-gpu` to match the keys in `job_link`.
|
||||
for model, model_result in data.items():
|
||||
for device, failed_tests in model_result.items():
|
||||
for failed_test in failed_tests:
|
||||
key = model
|
||||
if list(job_links.keys()) == [job_name]:
|
||||
key = job_name
|
||||
failed_test["job_link"] = job_links[key][device.replace("-gpu", "")]
|
||||
|
||||
with open("new_failures_with_bad_commit.json", "w") as fp:
|
||||
json.dump(data, fp, indent=4, ensure_ascii=False)
|
||||
|
||||
commit_info = api.upload_file(
|
||||
path_or_fileobj="new_failures_with_bad_commit.json",
|
||||
path_in_repo=f"{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit.json",
|
||||
repo_id=report_repo_id,
|
||||
repo_type="dataset",
|
||||
token=os.environ.get("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN", None),
|
||||
)
|
||||
url = f"https://huggingface.co/datasets/{report_repo_id}/raw/{commit_info.oid}/{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit.json"
|
||||
|
||||
with open("new_failures_with_bad_commit_url.txt", "w") as fp:
|
||||
fp.write(url)
|
||||
|
||||
# TODO: extend
|
||||
team_members = [
|
||||
"ArthurZucker",
|
||||
"Cyrilvallez",
|
||||
"LysandreJik",
|
||||
"MekkCyber",
|
||||
"Rocketknight1",
|
||||
"SunMarc",
|
||||
"ebezzam",
|
||||
"eustlb",
|
||||
"gante",
|
||||
"itazap",
|
||||
"ivarflakstad",
|
||||
"molbap",
|
||||
"remi-or",
|
||||
"stevhliu",
|
||||
"vasqu",
|
||||
"ydshieh",
|
||||
"zucchini-nlp",
|
||||
"tarekziade",
|
||||
]
|
||||
|
||||
# Counting the number of failures grouped by authors
|
||||
new_data = {}
|
||||
for model, model_result in data.items():
|
||||
for device, failed_tests in model_result.items():
|
||||
for failed_test in failed_tests:
|
||||
author = failed_test["author"]
|
||||
|
||||
# If author is not a team member, and the PR is already merged: change to the one who merged the PR
|
||||
if author not in team_members and failed_test["merged_by"] is not None:
|
||||
author = failed_test["merged_by"]
|
||||
|
||||
if author not in new_data:
|
||||
new_data[author] = Counter()
|
||||
new_data[author].update([model])
|
||||
for author in new_data:
|
||||
new_data[author] = dict(new_data[author])
|
||||
|
||||
# Group by author
|
||||
new_data_full = {author: deepcopy(data) for author in new_data}
|
||||
for author, _data in new_data_full.items():
|
||||
for model, model_result in _data.items():
|
||||
for device, failed_tests in model_result.items():
|
||||
failed_tests = [
|
||||
x
|
||||
for x in failed_tests
|
||||
if x["author"] == author or (x["merged_by"] is not None and x["merged_by"] == author)
|
||||
]
|
||||
model_result[device] = failed_tests
|
||||
_data[model] = {k: v for k, v in model_result.items() if len(v) > 0}
|
||||
new_data_full[author] = {k: v for k, v in _data.items() if len(v) > 0}
|
||||
|
||||
with open("new_failures_with_bad_commit_grouped_by_authors.json", "w") as fp:
|
||||
json.dump(new_data_full, fp, ensure_ascii=False, indent=4)
|
||||
commit_info = api.upload_file(
|
||||
path_or_fileobj="new_failures_with_bad_commit_grouped_by_authors.json",
|
||||
path_in_repo=f"{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit_grouped_by_authors.json",
|
||||
repo_id=report_repo_id,
|
||||
repo_type="dataset",
|
||||
token=os.environ.get("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN", None),
|
||||
)
|
||||
url = f"https://huggingface.co/datasets/{report_repo_id}/raw/{commit_info.oid}/{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit_grouped_by_authors.json"
|
||||
|
||||
# Add `GH_` prefix as keyword mention
|
||||
output = {}
|
||||
for author, item in new_data.items():
|
||||
author = f"GH_{author}"
|
||||
output[author] = item
|
||||
|
||||
report = f"<{url}|New failed tests>\\n\\n"
|
||||
report += json.dumps(output, indent=4).replace('"', '\\"').replace("\n", "\\n")
|
||||
print(report)
|
||||
146
utils/process_circleci_workflow_test_reports.py
Normal file
146
utils/process_circleci_workflow_test_reports.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--workflow_id", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
r = requests.get(
|
||||
f"https://circleci.com/api/v2/workflow/{args.workflow_id}/job",
|
||||
headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")},
|
||||
)
|
||||
jobs = r.json()["items"]
|
||||
|
||||
os.makedirs("outputs", exist_ok=True)
|
||||
workflow_summary = {}
|
||||
failure_entries = []
|
||||
|
||||
for job in jobs:
|
||||
if job["name"].startswith(("tests_", "examples_", "pipelines_")):
|
||||
url = f"https://circleci.com/api/v2/project/{job['project_slug']}/{job['job_number']}/artifacts"
|
||||
r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
|
||||
job_artifacts = r.json()["items"]
|
||||
|
||||
os.makedirs(f"outputs/{job['name']}", exist_ok=True)
|
||||
|
||||
job_test_summaries = {}
|
||||
job_failure_lines = {}
|
||||
|
||||
for artifact in job_artifacts:
|
||||
url = artifact["url"]
|
||||
if artifact["path"].endswith("/summary_short.txt"):
|
||||
r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
|
||||
job_test_summaries[artifact["node_index"]] = r.text
|
||||
elif artifact["path"].endswith("/failures_line.txt"):
|
||||
r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
|
||||
job_failure_lines[artifact["node_index"]] = r.text
|
||||
|
||||
summary = {}
|
||||
for node_index, node_test_summary in job_test_summaries.items():
|
||||
for line in node_test_summary.splitlines():
|
||||
if line.startswith("PASSED "):
|
||||
summary[line[7:]] = "passed"
|
||||
elif line.startswith("FAILED "):
|
||||
summary[line[7:].split()[0]] = "failed"
|
||||
|
||||
summary = dict(sorted(summary.items(), key=lambda x: (x[1], x[0])))
|
||||
workflow_summary[job["name"]] = summary
|
||||
|
||||
with open(f"outputs/{job['name']}/test_summary.json", "w") as fp:
|
||||
json.dump(summary, fp, indent=4)
|
||||
|
||||
# Collect failure details
|
||||
for node_index, summary_text in job_test_summaries.items():
|
||||
failure_lines_list = [
|
||||
l.strip()
|
||||
for l in job_failure_lines.get(node_index, "").splitlines()
|
||||
if l.strip() and not l.strip().startswith(("=", "_", "short test summary")) and ": " in l
|
||||
]
|
||||
|
||||
failure_idx = 0
|
||||
for line in summary_text.splitlines():
|
||||
if line.startswith("FAILED ") and " - Failed: (subprocess)" not in line:
|
||||
test_name, _, short_error = line[7:].strip().partition(" - ")
|
||||
test_name = test_name.strip()
|
||||
parts = test_name.split("::", 1)[0].split("/")
|
||||
model_name = parts[2] if len(parts) >= 3 and test_name.startswith("tests/models/") else None
|
||||
full_error = (
|
||||
failure_lines_list[failure_idx] if failure_idx < len(failure_lines_list) else short_error
|
||||
)
|
||||
|
||||
failure_entries.append(
|
||||
{
|
||||
"job_name": job["name"],
|
||||
"test_name": test_name,
|
||||
"short_error": short_error,
|
||||
"error": full_error,
|
||||
"model_name": model_name,
|
||||
}
|
||||
)
|
||||
failure_idx += 1
|
||||
|
||||
# Build workflow summary
|
||||
new_workflow_summary = {}
|
||||
for job_name, job_summary in workflow_summary.items():
|
||||
for test, status in job_summary.items():
|
||||
new_workflow_summary.setdefault(test, {})[job_name] = status
|
||||
|
||||
new_workflow_summary = {
|
||||
test: dict(sorted(result.items())) for test, result in sorted(new_workflow_summary.items())
|
||||
}
|
||||
|
||||
with open("outputs/test_summary.json", "w") as fp:
|
||||
json.dump(new_workflow_summary, fp, indent=4)
|
||||
|
||||
# Aggregate failures by test and model
|
||||
by_test, by_model = {}, {}
|
||||
|
||||
for entry in failure_entries:
|
||||
# Normalize test name
|
||||
normalized = entry["test_name"].split("[", 1)[0]
|
||||
parts = normalized.split("::")
|
||||
normalized = "::".join(parts[:-1] + [re.sub(r"_\d{2,}.*$", "", parts[-1])])
|
||||
|
||||
by_test.setdefault(normalized, {"count": 0, "errors": Counter(), "jobs": set(), "variants": set()})
|
||||
by_test[normalized]["count"] += 1
|
||||
by_test[normalized]["errors"][entry["error"]] += 1
|
||||
by_test[normalized]["jobs"].add(entry["job_name"])
|
||||
by_test[normalized]["variants"].add(entry["test_name"])
|
||||
|
||||
if entry["model_name"]:
|
||||
by_model.setdefault(entry["model_name"], {"count": 0, "errors": Counter(), "tests": set()})
|
||||
by_model[entry["model_name"]]["count"] += 1
|
||||
by_model[entry["model_name"]]["errors"][entry["error"]] += 1
|
||||
by_model[entry["model_name"]]["tests"].add(entry["test_name"])
|
||||
|
||||
# Convert Counter and sets to dicts/lists for JSON serialization
|
||||
for info in by_test.values():
|
||||
info["errors"] = dict(info["errors"].most_common())
|
||||
info["jobs"] = sorted(info["jobs"])
|
||||
info["variants"] = sorted(info["variants"])
|
||||
for info in by_model.values():
|
||||
info["errors"] = dict(info["errors"].most_common())
|
||||
info["tests"] = sorted(info["tests"])
|
||||
|
||||
with open("outputs/failure_summary.json", "w") as fp:
|
||||
json.dump({"failures": failure_entries, "by_test": by_test, "by_model": by_model}, fp, indent=4)
|
||||
74
utils/process_test_artifacts.py
Normal file
74
utils/process_test_artifacts.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
This helper computes the "ideal" number of nodes to use in circle CI.
|
||||
For each job, we compute this parameter and pass it to the `generated_config.yaml`.
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
|
||||
MAX_PARALLEL_NODES = 8 # TODO create a mapping!
|
||||
AVERAGE_TESTS_PER_NODES = 5
|
||||
|
||||
|
||||
def count_lines(filepath):
|
||||
"""Count the number of lines in a file."""
|
||||
try:
|
||||
with open(filepath, "r") as f:
|
||||
return len(f.read().split("\n"))
|
||||
except FileNotFoundError:
|
||||
return 0
|
||||
|
||||
|
||||
def compute_parallel_nodes(line_count, max_tests_per_node=10):
|
||||
"""Compute the number of parallel nodes required."""
|
||||
num_nodes = math.ceil(line_count / AVERAGE_TESTS_PER_NODES)
|
||||
if line_count < 4:
|
||||
return 1
|
||||
return min(MAX_PARALLEL_NODES, num_nodes)
|
||||
|
||||
|
||||
def process_artifacts(input_file, output_file):
|
||||
# Read the JSON data from the input file
|
||||
with open(input_file, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Process items and build the new JSON structure
|
||||
transformed_data = {}
|
||||
for item in data.get("items", []):
|
||||
if "test_list" in item["path"]:
|
||||
key = os.path.splitext(os.path.basename(item["path"]))[0]
|
||||
transformed_data[key] = item["url"]
|
||||
parallel_key = key.split("_test")[0] + "_parallelism"
|
||||
file_path = os.path.join("test_preparation", f"{key}.txt")
|
||||
line_count = count_lines(file_path)
|
||||
transformed_data[parallel_key] = compute_parallel_nodes(line_count)
|
||||
|
||||
# Remove the "generated_config" key if it exists
|
||||
if "generated_config" in transformed_data:
|
||||
del transformed_data["generated_config"]
|
||||
|
||||
# Write the transformed data to the output file
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(transformed_data, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file = "test_preparation/artifacts.json"
|
||||
output_file = "test_preparation/transformed_artifacts.json"
|
||||
process_artifacts(input_file, output_file)
|
||||
227
utils/release.py
Normal file
227
utils/release.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that prepares the repository for releases (or patches) by updating all versions in the relevant places. It
|
||||
also performs some post-release cleanup, by updating the links in the main README to respective model doc pages (from
|
||||
main to stable).
|
||||
|
||||
To prepare for a release, use from the root of the repo on the release branch with:
|
||||
|
||||
```bash
|
||||
python release.py
|
||||
```
|
||||
|
||||
or use `make pre-release`.
|
||||
|
||||
To prepare for a patch release, use from the root of the repo on the release branch with:
|
||||
|
||||
```bash
|
||||
python release.py --patch
|
||||
```
|
||||
|
||||
or use `make pre-patch`.
|
||||
|
||||
To do the post-release cleanup, use from the root of the repo on the main branch with:
|
||||
|
||||
```bash
|
||||
python release.py --post_release
|
||||
```
|
||||
|
||||
or use `make post-release`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import packaging.version
|
||||
|
||||
|
||||
# All paths are defined with the intent that this script should be run from the root of the repo.
|
||||
PATH_TO_EXAMPLES = "examples/"
|
||||
PATH_TO_MODELS = "src/transformers/models"
|
||||
PATH_TO_UTILS = "utils"
|
||||
# This maps a type of file to the pattern to look for when searching where the version is defined, as well as the
|
||||
# template to follow when replacing it with the new version.
|
||||
REPLACE_PATTERNS = {
|
||||
"examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
|
||||
"init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
|
||||
"setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
|
||||
"uv_script_release": (
|
||||
re.compile(r'^# "transformers(\[.+\])?.*$', re.MULTILINE),
|
||||
r'# "transformers\g<1>==VERSION",',
|
||||
),
|
||||
"uv_script_dev": (
|
||||
re.compile(r'^# "transformers(\[.+\])?.*$', re.MULTILINE),
|
||||
r'# "transformers\g<1> @ git+https://github.com/huggingface/transformers.git",',
|
||||
),
|
||||
}
|
||||
# This maps a type of file to its path in Transformers
|
||||
REPLACE_FILES = {
|
||||
"init": "src/transformers/__init__.py",
|
||||
"setup": "setup.py",
|
||||
}
|
||||
README_FILE = "README.md"
|
||||
UV_SCRIPT_MARKER = "# /// script"
|
||||
|
||||
|
||||
def update_version_in_file(fname: str, version: str, file_type: str):
|
||||
"""
|
||||
Update the version of Transformers in one file.
|
||||
|
||||
Args:
|
||||
fname (`str`): The path to the file where we want to update the version.
|
||||
version (`str`): The new version to set in the file.
|
||||
file_type (`str`): The type of the file (should be a key in `REPLACE_PATTERNS`).
|
||||
"""
|
||||
with open(fname, "r", encoding="utf-8", newline="\n") as f:
|
||||
code = f.read()
|
||||
re_pattern, replace = REPLACE_PATTERNS[file_type]
|
||||
replace = replace.replace("VERSION", version)
|
||||
code = re_pattern.sub(replace, code)
|
||||
with open(fname, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(code)
|
||||
|
||||
|
||||
def update_version_in_examples(version: str, patch: bool = False):
|
||||
"""
|
||||
Update the version in all examples files.
|
||||
|
||||
Args:
|
||||
version (`str`): The new version to set in the examples.
|
||||
patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
|
||||
"""
|
||||
for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
|
||||
# Removing some of the folders with non-actively maintained examples from the walk
|
||||
if "legacy" in directories:
|
||||
directories.remove("legacy")
|
||||
for fname in fnames:
|
||||
if fname.endswith(".py"):
|
||||
if UV_SCRIPT_MARKER in Path(folder, fname).read_text():
|
||||
# Update the dependencies in UV scripts
|
||||
uv_script_file_type = "uv_script_dev" if ".dev" in version else "uv_script_release"
|
||||
update_version_in_file(os.path.join(folder, fname), version, file_type=uv_script_file_type)
|
||||
if not patch:
|
||||
# We don't update the version in the examples for patch releases.
|
||||
update_version_in_file(os.path.join(folder, fname), version, file_type="examples")
|
||||
|
||||
|
||||
def global_version_update(version: str, patch: bool = False):
|
||||
"""
|
||||
Update the version in all needed files.
|
||||
|
||||
Args:
|
||||
version (`str`): The new version to set everywhere.
|
||||
patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
|
||||
"""
|
||||
for pattern, fname in REPLACE_FILES.items():
|
||||
update_version_in_file(fname, version, pattern)
|
||||
# REMOVED AFTER v5! Uncomment to start updating the version of the examples again
|
||||
# update_version_in_examples(version, patch=patch)
|
||||
|
||||
|
||||
def remove_conversion_scripts():
|
||||
"""
|
||||
Delete the scripts that convert models from older, unsupported formats. We don't want to include these
|
||||
in release wheels because they often have to open insecure file types (pickle, Torch .bin models). This results in
|
||||
vulnerability scanners flagging us and can cause compliance issues for users with strict security policies.
|
||||
"""
|
||||
model_dir = Path(PATH_TO_MODELS)
|
||||
for conversion_script in list(model_dir.glob("**/convert*.py")):
|
||||
conversion_script.unlink()
|
||||
|
||||
|
||||
def remove_internal_utils():
|
||||
"""
|
||||
Delete internal utils that should not be included in releases for security reasons.
|
||||
"""
|
||||
(Path(PATH_TO_UTILS) / "modular_model_detector.py").unlink()
|
||||
|
||||
|
||||
def get_version() -> packaging.version.Version:
|
||||
"""
|
||||
Reads the current version in the main __init__.
|
||||
"""
|
||||
with open(REPLACE_FILES["init"], "r") as f:
|
||||
code = f.read()
|
||||
default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
|
||||
return packaging.version.parse(default_version)
|
||||
|
||||
|
||||
def pre_release_work(patch: bool = False):
|
||||
"""
|
||||
Do all the necessary pre-release steps:
|
||||
- figure out the next minor release version and ask confirmation
|
||||
- update the version everywhere
|
||||
- clean-up the model list in the main README
|
||||
|
||||
Args:
|
||||
patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
|
||||
"""
|
||||
# First let's get the default version: base version if we are in dev, bump minor otherwise.
|
||||
default_version = get_version()
|
||||
if patch and default_version.is_devrelease:
|
||||
raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
|
||||
if default_version.is_devrelease:
|
||||
default_version = default_version.base_version
|
||||
elif patch:
|
||||
default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
|
||||
else:
|
||||
default_version = f"{default_version.major}.{default_version.minor + 1}.0"
|
||||
|
||||
# Now let's ask nicely if we have found the right version.
|
||||
version = input(f"Which version are you releasing? [{default_version}]")
|
||||
if len(version) == 0:
|
||||
version = default_version
|
||||
|
||||
print(f"Updating version to {version}.")
|
||||
global_version_update(version, patch=patch)
|
||||
print("Deleting conversion and internal utils scripts.")
|
||||
remove_conversion_scripts()
|
||||
remove_internal_utils()
|
||||
|
||||
|
||||
def post_release_work():
|
||||
"""
|
||||
Do all the necessary post-release steps:
|
||||
- figure out the next dev version and ask confirmation
|
||||
- update the version everywhere
|
||||
- clean-up the model list in the main README
|
||||
"""
|
||||
# First let's get the current version
|
||||
current_version = get_version()
|
||||
dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
|
||||
current_version = current_version.base_version
|
||||
|
||||
# Check with the user we got that right.
|
||||
version = input(f"Which version are we developing now? [{dev_version}]")
|
||||
if len(version) == 0:
|
||||
version = dev_version
|
||||
|
||||
print(f"Updating version to {version}.")
|
||||
global_version_update(version)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
|
||||
parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
|
||||
args = parser.parse_args()
|
||||
if not args.post_release:
|
||||
pre_release_work(patch=args.patch)
|
||||
elif args.patch:
|
||||
print("Nothing to do after a patch :-)")
|
||||
else:
|
||||
post_release_work()
|
||||
251
utils/rules.toml
Normal file
251
utils/rules.toml
Normal file
@@ -0,0 +1,251 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file can carry repo-local rule overrides for faster iteration between
|
||||
# `transformers-mlinter` releases.
|
||||
# Keep it synced with the upstream package's rules.toml when possible so local
|
||||
# behavior does not drift from the published checker longer than necessary.
|
||||
|
||||
version = 1
|
||||
|
||||
[rules.TRF001]
|
||||
description = "Class-level config_class on <Model>PreTrainedModel should match <Model>Config naming."
|
||||
default_enabled = true
|
||||
allowlist_models = ["qwen3_omni_moe"]
|
||||
|
||||
[rules.TRF001.explanation]
|
||||
what_it_does = "Checks naming consistency between <Model>PreTrainedModel and config_class."
|
||||
why_bad = "Mismatched config_class can break loading, auto classes, and developer expectations."
|
||||
diff = '''
|
||||
class AcmePreTrainedModel(PreTrainedModel):
|
||||
- config_class = WileConfig
|
||||
+ config_class = AcmeConfig
|
||||
'''
|
||||
|
||||
[rules.TRF002]
|
||||
description = "base_model_prefix should be a non-empty canonical string when defined on PreTrainedModel classes."
|
||||
default_enabled = true
|
||||
allowlist_models = ["lighton_ocr"]
|
||||
|
||||
[rules.TRF002.explanation]
|
||||
what_it_does = "Checks that base_model_prefix, when set, is a non-empty, whitespace-free string literal."
|
||||
why_bad = "Invalid prefixes can break weight loading key mapping and base model access patterns."
|
||||
diff = '''
|
||||
class AcmePreTrainedModel(PreTrainedModel):
|
||||
- base_model_prefix = ""
|
||||
+ base_model_prefix = "model"
|
||||
'''
|
||||
|
||||
[rules.TRF003]
|
||||
description = "forward() should use capture_output/can_return_tuple decorators instead of manual return_dict branching."
|
||||
default_enabled = false
|
||||
allowlist_models = []
|
||||
|
||||
[rules.TRF003.explanation]
|
||||
what_it_does = "Detects forward methods that use the old 'if not return_dict: return (x,)' pattern."
|
||||
why_bad = "The old return_dict branching pattern is error-prone and verbose. Use the capture_output or can_return_tuple decorators instead."
|
||||
diff = '''
|
||||
-def forward(self, x, return_dict=None):
|
||||
- if not return_dict:
|
||||
- return (x,)
|
||||
- return AcmeModelOutput(last_hidden_state=x)
|
||||
+@can_return_tuple
|
||||
+def forward(self, x):
|
||||
+ return AcmeModelOutput(last_hidden_state=x)
|
||||
'''
|
||||
|
||||
[rules.TRF004]
|
||||
description = "Models must never override tie_weights. Use _tied_weights_keys instead."
|
||||
default_enabled = true
|
||||
allowlist_models = ["data2vec", "hubert", "sew", "sew_d", "unispeech", "unispeech_sat", "wav2vec2", "wav2vec2_conformer", "wavlm"]
|
||||
|
||||
[rules.TRF004.explanation]
|
||||
what_it_does = "Checks that no model class defines a tie_weights method."
|
||||
why_bad = "Overriding tie_weights leads to bad consequences for loading, device_map computation, and saving. Use _tied_weights_keys class attribute to declare tied weights instead."
|
||||
diff = '''
|
||||
-def tie_weights(self):
|
||||
- self.lm_head.weight = self.emb.weight
|
||||
+class AcmeForCausalLM(AcmePreTrainedModel):
|
||||
+ _tied_weights_keys = ["lm_head.weight"]
|
||||
'''
|
||||
|
||||
[rules.TRF005]
|
||||
description = "_no_split_modules, when defined, should be a list/tuple of non-empty strings."
|
||||
default_enabled = true
|
||||
allowlist_models = ["d_fine", "deformable_detr", "glm46v", "lw_detr", "pp_doclayout_v3", "rt_detr", "rt_detr_v2", "voxtral", "voxtral_realtime"]
|
||||
|
||||
[rules.TRF005.explanation]
|
||||
what_it_does = "Checks the shape of _no_split_modules when present."
|
||||
why_bad = "Malformed values can break device-map partitioning and sharding behavior."
|
||||
diff = '''
|
||||
-_no_split_modules = [SomeLayerClass, ""]
|
||||
+_no_split_modules = ["AcmeDecoderLayer", "AcmeAttention"]
|
||||
'''
|
||||
|
||||
[rules.TRF006]
|
||||
description = "forward with cache arguments should reference cache control/state variables consistently."
|
||||
default_enabled = true
|
||||
allowlist_models = ["chinese_clip", "evolla", "idefics2", "llama4"]
|
||||
|
||||
[rules.TRF006.explanation]
|
||||
what_it_does = "Checks forward signatures that expose cache arguments for usage of those arguments in method body."
|
||||
why_bad = "Unused cache arguments can indicate incomplete caching support and inconsistent API behavior."
|
||||
diff = '''
|
||||
def forward(self, x, past_key_values=None, use_cache=False):
|
||||
+ if use_cache:
|
||||
+ ...
|
||||
return x
|
||||
'''
|
||||
|
||||
[rules.TRF007]
|
||||
description = "self.post_init() in __init__ should remain at the end of initialization for PreTrainedModel classes."
|
||||
default_enabled = true
|
||||
allowlist_models = ["distilbert", "lxmert", "mt5", "pix2struct", "pop2piano", "switch_transformers", "t5"]
|
||||
|
||||
[rules.TRF007.explanation]
|
||||
what_it_does = "Checks for self attribute assignments after self.post_init() in __init__."
|
||||
why_bad = "Mutating model structure after post_init can bypass intended initialization/finalization logic."
|
||||
diff = '''
|
||||
def __init__(self, config):
|
||||
...
|
||||
- self.post_init()
|
||||
- self.proj = nn.Linear(...)
|
||||
+ self.proj = nn.Linear(...)
|
||||
+ self.post_init()
|
||||
'''
|
||||
|
||||
[rules.TRF008]
|
||||
description = "Doc decorators on PreTrainedModel classes should avoid empty add_start_docstrings usage."
|
||||
default_enabled = true
|
||||
|
||||
[rules.TRF008.explanation]
|
||||
what_it_does = "Checks add_start_docstrings usage on model classes for non-empty docstring arguments."
|
||||
why_bad = "Empty decorator usage produces unclear docs and weakens generated API documentation quality."
|
||||
diff = '''
|
||||
-@add_start_docstrings("")
|
||||
+@add_start_docstrings("The Acme model.")
|
||||
class AcmeModel(AcmePreTrainedModel):
|
||||
...
|
||||
'''
|
||||
|
||||
[rules.TRF009]
|
||||
description = "modeling_<name>.py should avoid importing implementation code from another model package."
|
||||
default_enabled = true
|
||||
allowlist_models = ["dpr", "maskformer", "sam3_video", "vision_text_dual_encoder"]
|
||||
|
||||
[rules.TRF009.explanation]
|
||||
what_it_does = "Checks modeling files for cross-model imports such as transformers.models.other_model.* or from ..other_model.* imports."
|
||||
why_bad = "Cross-model implementation imports violate the single-file policy and make model behavior harder to inspect and maintain."
|
||||
diff = '''
|
||||
-from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
+# Keep implementation local to this file.
|
||||
+# If reusing code, copy it with a # Copied from comment.
|
||||
'''
|
||||
|
||||
[rules.TRF010]
|
||||
description = "Direct config definitions must use @strict(accept_kwargs=True)."
|
||||
default_enabled = true
|
||||
allowlist_models = ["nemotron_h", "vibevoice_asr"]
|
||||
|
||||
[rules.TRF010.explanation]
|
||||
what_it_does = "Checks direct PreTrainedConfig/PretrainedConfig subclasses in configuration_*.py and modular_*.py for an explicit @strict(accept_kwargs=True) decorator."
|
||||
why_bad = "Without strict, new config classes miss the repo's runtime type-validation contract and drift from the dataclass-based config standard."
|
||||
diff = '''
|
||||
+@strict(accept_kwargs=True)
|
||||
class AcmeConfig(PreTrainedConfig):
|
||||
...
|
||||
'''
|
||||
|
||||
[rules.TRF011]
|
||||
description = "forward() must not access non-nn.Module attributes on submodules (breaks pipeline parallelism with Identity replacement)."
|
||||
default_enabled = true
|
||||
allowlist_models = []
|
||||
|
||||
[rules.TRF011.explanation]
|
||||
what_it_does = "In forward() methods of PreTrainedModel subclasses, checks for attribute accesses on submodules that would not exist on torch.nn.Identity. This includes attribute accesses on loop variables iterating over self.layers, and self.<submodule>.<attr> chains where <attr> is not a standard nn.Module attribute."
|
||||
why_bad = "Pipeline parallelism may replace any submodule with torch.nn.Identity. Accessing custom attributes (e.g. decoder_layer.attention_type) on a replaced module raises AttributeError at runtime. Per-layer metadata should be read from self.config instead."
|
||||
diff = '''
|
||||
def forward(self, ...):
|
||||
- for decoder_layer in self.layers:
|
||||
+ for i, decoder_layer in enumerate(self.layers):
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states,
|
||||
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
+ attention_mask=causal_mask_mapping[self.config.layer_types[i]],
|
||||
)
|
||||
'''
|
||||
|
||||
[rules.TRF012]
|
||||
description = "_init_weights must use init primitives, not in-place operations on module weights."
|
||||
default_enabled = true
|
||||
allowlist_models = []
|
||||
|
||||
[rules.TRF012.explanation]
|
||||
what_it_does = "Checks that _init_weights(self, module) does not use in-place operations (e.g. .normal_(), .zero_()) directly on module weights."
|
||||
why_bad = "We rely on internal flags set on parameters to track whether they need re-initialization. In-place ops bypass this mechanism. Use the `init` primitives instead."
|
||||
diff = '''
|
||||
+from transformers import initialization as init
|
||||
+
|
||||
def _init_weights(self, module):
|
||||
- module.weight.normal_(mean=0.0, std=0.02)
|
||||
+ init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
'''
|
||||
|
||||
[rules.TRF013]
|
||||
description = "PreTrainedModel __init__ must call self.post_init()."
|
||||
default_enabled = true
|
||||
allowlist_models = []
|
||||
|
||||
[rules.TRF013.explanation]
|
||||
what_it_does = "Checks that every PreTrainedModel subclass with an __init__ method calls self.post_init(). In modular files, calling super().__init__() is also accepted since it propagates post_init from the parent."
|
||||
why_bad = "post_init performs essential finalization (weight initialization, gradient checkpointing setup, etc.). Omitting it causes subtle runtime bugs."
|
||||
diff = '''
|
||||
class AcmeModel(AcmePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList(...)
|
||||
+ self.post_init()
|
||||
'''
|
||||
|
||||
[rules.TRF014]
|
||||
description = "`trust_remote_code` should never be used in native model integrations."
|
||||
default_enabled = true
|
||||
allowlist_models = []
|
||||
|
||||
[rules.TRF014.explanation]
|
||||
what_it_does = "Checks whether `trust_remote_code` is passed or used in code (e.g. as kwarg) within native model integration files."
|
||||
why_bad = "`trust_remote_code` allows arbitrary loading, including binaries, which should only be a power feature for users, not a standard use-case. Native integrations must not depend on it, as remote code cannot be reviewed or maintained within transformers."
|
||||
diff = '''
|
||||
class AcmeModel(AcmePreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
- self.model = AutoModel.from_pretrained(..., trust_remote_code=True)
|
||||
+ self.model = AutoModel.from_pretrained(...)
|
||||
'''
|
||||
|
||||
[rules.TRF015]
|
||||
description = "Models with non-empty _tied_weights_keys must have tie_word_embeddings in their Config."
|
||||
default_enabled = true
|
||||
allowlist_models = []
|
||||
|
||||
[rules.TRF015.explanation]
|
||||
what_it_does = "When a PreTrainedModel subclass defines _tied_weights_keys as a non-empty collection, checks that the corresponding configuration file declares a tie_word_embeddings field."
|
||||
why_bad = "Without tie_word_embeddings in the config, users cannot control weight tying behavior. The model ties weights unconditionally, breaking serialization round-trips and preventing fine-tuning with untied heads."
|
||||
diff = '''
|
||||
# configuration_foo.py
|
||||
@strict(accept_kwargs=True)
|
||||
class FooConfig(PreTrainedConfig):
|
||||
hidden_size: int = 768
|
||||
+ tie_word_embeddings: bool = True
|
||||
'''
|
||||
198
utils/scan_skipped_tests.py
Normal file
198
utils/scan_skipped_tests.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path().cwd()
|
||||
|
||||
COMMON_TEST_FILES: list[tuple[Path, str]] = [
|
||||
(REPO_ROOT / "tests/test_modeling_common.py", "common"),
|
||||
(REPO_ROOT / "tests/generation/test_utils.py", "GenerationMixin"),
|
||||
]
|
||||
|
||||
MODELS_DIR = REPO_ROOT / "tests/models"
|
||||
|
||||
|
||||
def get_common_tests(file_paths_with_origin: list[tuple[Path, str]]) -> dict[str, str]:
|
||||
"""Extract all common test function names (e.g., 'test_forward')."""
|
||||
tests_with_origin: dict[str, str] = {}
|
||||
for file_path, origin_tag in file_paths_with_origin:
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
for test_name in re.findall(r"^\s*def\s+(test_[A-Za-z0-9_]+)", content, re.MULTILINE):
|
||||
tests_with_origin[test_name] = origin_tag
|
||||
return tests_with_origin
|
||||
|
||||
|
||||
def get_models_and_test_files(models_dir: Path) -> tuple[list[str], list[Path]]:
|
||||
if not models_dir.is_dir():
|
||||
raise FileNotFoundError(f"Models directory not found at {models_dir}")
|
||||
test_files: list[Path] = sorted(models_dir.rglob("test_modeling_*.py"))
|
||||
model_names: list[str] = sorted({file_path.parent.name for file_path in test_files})
|
||||
return model_names, test_files
|
||||
|
||||
|
||||
def _extract_reason_from_decorators(decorators_block: str) -> str:
|
||||
"""Extracts the reason string from a decorator block, if any."""
|
||||
reason_match = re.search(r'reason\s*=\s*["\'](.*?)["\']', decorators_block)
|
||||
if reason_match:
|
||||
return reason_match.group(1)
|
||||
reason_match = re.search(r'\((?:.*?,\s*)?["\'](.*?)["\']\)', decorators_block)
|
||||
if reason_match:
|
||||
return reason_match.group(1)
|
||||
return decorators_block.strip().split("\n")[-1].strip()
|
||||
|
||||
|
||||
def extract_test_info(file_content: str) -> dict[str, tuple[str, str]]:
|
||||
"""
|
||||
Parse a test file once and return a mapping of test functions to their
|
||||
status and skip reason, e.g. {'test_forward': ('SKIPPED', 'too slow')}.
|
||||
"""
|
||||
result: dict[str, tuple[str, str]] = {}
|
||||
pattern = re.compile(r"((?:^\s*@.*?\n)*?)^\s*def\s+(test_[A-Za-z0-9_]+)\b", re.MULTILINE)
|
||||
for decorators_block, test_name in pattern.findall(file_content):
|
||||
if "skip" in decorators_block:
|
||||
result[test_name] = ("SKIPPED", _extract_reason_from_decorators(decorators_block))
|
||||
else:
|
||||
result[test_name] = ("RAN", "")
|
||||
return result
|
||||
|
||||
|
||||
def build_model_overrides(model_test_files: list[Path]) -> dict[str, dict[str, tuple[str, str]]]:
|
||||
"""Return *model_name → {test_name → (status, reason)}* mapping."""
|
||||
model_overrides: dict[str, dict[str, tuple[str, str]]] = {}
|
||||
for file_path in model_test_files:
|
||||
model_name = file_path.parent.name
|
||||
file_content = file_path.read_text(encoding="utf-8")
|
||||
model_overrides.setdefault(model_name, {}).update(extract_test_info(file_content))
|
||||
return model_overrides
|
||||
|
||||
|
||||
def save_json(obj: dict, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(json.dumps(obj, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def summarize_single_test(
|
||||
test_name: str,
|
||||
model_names: list[str],
|
||||
model_overrides: dict[str, dict[str, tuple[str, str]]],
|
||||
) -> dict[str, object]:
|
||||
"""Print a concise terminal summary for *test_name* and return the raw data."""
|
||||
models_ran, models_skipped, reasons_for_skipping = [], [], []
|
||||
for model_name in model_names:
|
||||
status, reason = model_overrides.get(model_name, {}).get(test_name, ("RAN", ""))
|
||||
if status == "SKIPPED":
|
||||
models_skipped.append(model_name)
|
||||
reasons_for_skipping.append(f"{model_name}: {reason}")
|
||||
else:
|
||||
models_ran.append(model_name)
|
||||
|
||||
total_models = len(model_names)
|
||||
skipped_ratio = len(models_skipped) / total_models if total_models else 0.0
|
||||
|
||||
print(f"\n== {test_name} ==")
|
||||
print(f"Ran : {len(models_ran)}/{total_models}")
|
||||
print(f"Skipped : {len(models_skipped)}/{total_models} ({skipped_ratio:.1%})")
|
||||
for reason_entry in reasons_for_skipping[:10]:
|
||||
print(f" - {reason_entry}")
|
||||
if len(reasons_for_skipping) > 10:
|
||||
print(" - ...")
|
||||
|
||||
return {
|
||||
"models_ran": sorted(models_ran),
|
||||
"models_skipped": sorted(models_skipped),
|
||||
"skipped_proportion": round(skipped_ratio, 4),
|
||||
"reasons_skipped": sorted(reasons_for_skipping),
|
||||
}
|
||||
|
||||
|
||||
def summarize_all_tests(
|
||||
tests_with_origin: dict[str, str],
|
||||
model_names: list[str],
|
||||
model_overrides: dict[str, dict[str, tuple[str, str]]],
|
||||
) -> dict[str, object]:
|
||||
"""Return aggregated data for every discovered common test."""
|
||||
results: dict[str, object] = {}
|
||||
total_models = len(model_names)
|
||||
test_names = list(tests_with_origin)
|
||||
|
||||
print(f"[INFO] Aggregating {len(test_names)} tests...")
|
||||
for index, test_fn in enumerate(test_names, 1):
|
||||
print(f" ({index}/{len(test_names)}) {test_fn}", end="\r")
|
||||
models_ran, models_skipped, reasons_for_skipping = [], [], []
|
||||
for model_name in model_names:
|
||||
status, reason = model_overrides.get(model_name, {}).get(test_fn, ("RAN", ""))
|
||||
if status == "SKIPPED":
|
||||
models_skipped.append(model_name)
|
||||
reasons_for_skipping.append(f"{model_name}: {reason}")
|
||||
else:
|
||||
models_ran.append(model_name)
|
||||
|
||||
skipped_ratio = len(models_skipped) / total_models if total_models else 0.0
|
||||
results[test_fn] = {
|
||||
"origin": tests_with_origin[test_fn],
|
||||
"models_ran": sorted(models_ran),
|
||||
"models_skipped": sorted(models_skipped),
|
||||
"skipped_proportion": round(skipped_ratio, 4),
|
||||
"reasons_skipped": sorted(reasons_for_skipping),
|
||||
}
|
||||
print("\n[INFO] Scan complete.")
|
||||
return results
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Scan model tests for overridden or skipped common or generate tests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=".",
|
||||
help="Directory for JSON output (default: %(default)s)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_method_name",
|
||||
help="Scan only this test method (single‑test mode)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
output_dir = Path(args.output_dir).expanduser()
|
||||
test_method_name = args.test_method_name
|
||||
|
||||
tests_with_origin = get_common_tests(COMMON_TEST_FILES)
|
||||
if test_method_name:
|
||||
tests_with_origin = {test_method_name: tests_with_origin.get(test_method_name, "unknown")}
|
||||
|
||||
model_names, model_test_files = get_models_and_test_files(MODELS_DIR)
|
||||
print(f"[INFO] Parsing {len(model_test_files)} model test files once each...")
|
||||
model_overrides = build_model_overrides(model_test_files)
|
||||
|
||||
if test_method_name:
|
||||
data = summarize_single_test(test_method_name, model_names, model_overrides)
|
||||
json_path = output_dir / f"scan_{test_method_name}.json"
|
||||
else:
|
||||
data = summarize_all_tests(tests_with_origin, model_names, model_overrides)
|
||||
json_path = output_dir / "all_tests_scan_result.json"
|
||||
save_json(data, json_path)
|
||||
print(f"\n[INFO] JSON saved to {json_path.resolve()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
26
utils/set_cuda_devices_for_ci.py
Normal file
26
utils/set_cuda_devices_for_ci.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""A simple script to set flexibly CUDA_VISIBLE_DEVICES in GitHub Actions CI workflow files."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--test_folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The test folder name of the model being tested. For example, `models/cohere`.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# `test_eager_matches_sdpa_generate` for `cohere` needs a lot of GPU memory!
|
||||
# This depends on the runners. At this moment we are targeting our AWS CI runners.
|
||||
if args.test_folder == "models/cohere":
|
||||
cuda_visible_devices = "0,1,2,3"
|
||||
elif "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
|
||||
else:
|
||||
cuda_visible_devices = "0"
|
||||
|
||||
print(cuda_visible_devices)
|
||||
15
utils/slow_documentation_tests.txt
Normal file
15
utils/slow_documentation_tests.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
docs/source/en/generation_strategies.md
|
||||
docs/source/en/model_doc/code_llama.md
|
||||
docs/source/en/model_doc/ctrl.md
|
||||
docs/source/en/model_doc/kosmos-2.md
|
||||
docs/source/en/model_doc/seamless_m4t.md
|
||||
docs/source/en/model_doc/seamless_m4t_v2.md
|
||||
docs/source/en/tasks/prompting.md
|
||||
docs/source/ja/model_doc/code_llama.md
|
||||
src/transformers/models/blip_2/modeling_blip_2.py
|
||||
src/transformers/models/ctrl/modeling_ctrl.py
|
||||
src/transformers/models/fuyu/modeling_fuyu.py
|
||||
src/transformers/models/idefics2/modeling_idefics2.py
|
||||
src/transformers/models/kosmos2/modeling_kosmos2.py
|
||||
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
|
||||
src/transformers/models/musicgen_melody/processing_musicgen_melody.py
|
||||
130
utils/sort_auto_mappings.py
Normal file
130
utils/sort_auto_mappings.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that sorts the names in the auto mappings defines in the auto modules in alphabetical order.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/sort_auto_mappings.py
|
||||
```
|
||||
|
||||
to auto-fix all the auto mappings (used in `make style`).
|
||||
|
||||
To only check if the mappings are properly sorted (as used in `make check-repo`), do:
|
||||
|
||||
```bash
|
||||
python utils/sort_auto_mappings.py --check_only
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "sort_auto_mappings",
|
||||
"label": "Sort auto mappings",
|
||||
"cache_globs": ["src/transformers/models/auto/*.py"],
|
||||
"check_args": ["--check_only"],
|
||||
"fix_args": [],
|
||||
}
|
||||
|
||||
# Path are set with the intent you should run this script from the root of the repo.
|
||||
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"
|
||||
|
||||
|
||||
# re pattern that matches XXX_MAPPING_NAMES or SPECIAL_MODEL_TYPE_TO_MODULE_NAMES
|
||||
_re_intro_mapping = re.compile(r"[A-Z_]+(_MAPPING|_MODEL_TYPE_TO_MODULE)(\s+|_[A-Z_]+\s+)=\s+OrderedDict(?!\(\*)")
|
||||
# re pattern that matches identifiers in mappings
|
||||
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')
|
||||
|
||||
|
||||
def sort_auto_mapping(fname: str, overwrite: bool = False) -> bool | None:
|
||||
"""
|
||||
Sort all auto mappings in a file.
|
||||
|
||||
Args:
|
||||
fname (`str`): The name of the file where we want to sort auto-mappings.
|
||||
overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.
|
||||
|
||||
Returns:
|
||||
`Optional[bool]`: Returns `None` if `overwrite=True`. Otherwise returns `True` if the file has an auto-mapping
|
||||
improperly sorted, `False` if the file is okay.
|
||||
"""
|
||||
with open(fname, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
line_idx = 0
|
||||
while line_idx < len(lines):
|
||||
if _re_intro_mapping.search(lines[line_idx]) is not None:
|
||||
# Start of a new mapping!
|
||||
indent = len(re.search(r"^(\s*)\S", lines[line_idx]).groups()[0]) + 8
|
||||
while not lines[line_idx].startswith(" " * indent + "("):
|
||||
new_lines.append(lines[line_idx])
|
||||
line_idx += 1
|
||||
|
||||
blocks = []
|
||||
while lines[line_idx].strip() != "]":
|
||||
# Blocks either fit in one line or not
|
||||
if lines[line_idx].strip() == "(":
|
||||
start_idx = line_idx
|
||||
while not lines[line_idx].startswith(" " * indent + ")"):
|
||||
line_idx += 1
|
||||
blocks.append("\n".join(lines[start_idx : line_idx + 1]))
|
||||
else:
|
||||
blocks.append(lines[line_idx])
|
||||
line_idx += 1
|
||||
|
||||
# Sort blocks by their identifiers
|
||||
blocks = sorted(blocks, key=lambda x: _re_identifier.search(x).groups()[0])
|
||||
new_lines += blocks
|
||||
else:
|
||||
new_lines.append(lines[line_idx])
|
||||
line_idx += 1
|
||||
|
||||
if overwrite:
|
||||
with open(fname, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(new_lines))
|
||||
else:
|
||||
return "\n".join(new_lines) != content
|
||||
|
||||
|
||||
def sort_all_auto_mappings(overwrite: bool = False):
|
||||
"""
|
||||
Sort all auto mappings in the library.
|
||||
|
||||
Args:
|
||||
overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.
|
||||
"""
|
||||
fnames = [os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")]
|
||||
diffs = [sort_auto_mapping(fname, overwrite=overwrite) for fname in fnames]
|
||||
|
||||
if not overwrite and any(diffs):
|
||||
failures = [f for f, d in zip(fnames, diffs) if d]
|
||||
raise ValueError(
|
||||
f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
|
||||
" this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
|
||||
args = parser.parse_args()
|
||||
|
||||
sort_all_auto_mappings(not args.check_only)
|
||||
98
utils/split_doctest_jobs.py
Normal file
98
utils/split_doctest_jobs.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is used to get the files against which we will run doc testing.
|
||||
This uses `tests_fetcher.get_all_doctest_files` then groups the test files by their directory paths.
|
||||
|
||||
The files in `docs/source/en/model_doc` or `docs/source/en/tasks` are **NOT** grouped together with other files in the
|
||||
same directory: the objective is to run doctest against them in independent GitHub Actions jobs.
|
||||
|
||||
Assume we are under `transformers` root directory:
|
||||
To get a map (dictionary) between directory (or file) paths and the corresponding files
|
||||
```bash
|
||||
python utils/split_doctest_jobs.py
|
||||
```
|
||||
or to get a list of lists of directory (or file) paths
|
||||
```bash
|
||||
python utils/split_doctest_jobs.py --only_return_keys --num_splits 4
|
||||
```
|
||||
(this is used to allow GitHub Actions to generate more than 256 jobs using matrix)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from tests_fetcher import get_all_doctest_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--only_return_keys",
|
||||
action="store_true",
|
||||
help="if to only return the keys (which is a list of list of files' directory or file paths).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_splits",
|
||||
type=int,
|
||||
default=1,
|
||||
help="the number of splits into which the (flat) list of directory/file paths will be split. This has effect only if `only_return_keys` is `True`.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
all_doctest_files = get_all_doctest_files()
|
||||
|
||||
raw_test_collection_map = defaultdict(list)
|
||||
|
||||
for file in all_doctest_files:
|
||||
file_dir = "/".join(Path(file).parents[0].parts)
|
||||
|
||||
# not to run files in `src/` for now as it is completely broken at this moment. See issues/39159 and
|
||||
# https://github.com/huggingface/transformers/actions/runs/15988670157
|
||||
# TODO (ydshieh): fix the error, ideally before 2025/09
|
||||
if file_dir.startswith("src/"):
|
||||
continue
|
||||
|
||||
raw_test_collection_map[file_dir].append(file)
|
||||
|
||||
refined_test_collection_map = {}
|
||||
for file_dir in raw_test_collection_map:
|
||||
if file_dir in ["docs/source/en/model_doc", "docs/source/en/tasks"]:
|
||||
for file in raw_test_collection_map[file_dir]:
|
||||
refined_test_collection_map[file] = file
|
||||
else:
|
||||
refined_test_collection_map[file_dir] = " ".join(sorted(raw_test_collection_map[file_dir]))
|
||||
|
||||
sorted_file_dirs = sorted(refined_test_collection_map.keys())
|
||||
|
||||
test_collection_map = {}
|
||||
for file_dir in sorted_file_dirs:
|
||||
test_collection_map[file_dir] = refined_test_collection_map[file_dir]
|
||||
|
||||
num_jobs = len(test_collection_map)
|
||||
num_jobs_per_splits = num_jobs // args.num_splits
|
||||
|
||||
file_directory_splits = []
|
||||
end = 0
|
||||
for idx in range(args.num_splits):
|
||||
start = end
|
||||
end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0)
|
||||
file_directory_splits.append(sorted_file_dirs[start:end])
|
||||
|
||||
if args.only_return_keys:
|
||||
print(file_directory_splits)
|
||||
else:
|
||||
print(dict(test_collection_map))
|
||||
88
utils/split_model_tests.py
Normal file
88
utils/split_model_tests.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is used to get the list of folders under `tests/models` and split the list into `NUM_SLICES` splits.
|
||||
The main use case is a GitHub Actions workflow file calling this script to get the (nested) list of folders allowing it
|
||||
to split the list of jobs to run into multiple slices each containing a smaller number of jobs. This way, we can bypass
|
||||
the maximum of 256 jobs in a matrix.
|
||||
|
||||
See the `setup` and `run_models_gpu` jobs defined in the workflow file `.github/workflows/self-scheduled.yml` for more
|
||||
details.
|
||||
|
||||
Usage:
|
||||
|
||||
This script is required to be run under `tests` folder of `transformers` root directory.
|
||||
|
||||
Assume we are under `transformers` root directory:
|
||||
```bash
|
||||
cd tests
|
||||
python ../utils/split_model_tests.py --num_splits 64
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--subdirs",
|
||||
type=str,
|
||||
default="",
|
||||
help="the list of pre-computed model names (directory names under `tests/models`) or directory names under `tests` (except `models`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_splits",
|
||||
type=int,
|
||||
default=1,
|
||||
help="the number of splits into which the (flat) list of folders will be split.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
tests = os.getcwd()
|
||||
model_tests = os.listdir(os.path.join(tests, "models"))
|
||||
d1 = sorted(filter(os.path.isdir, os.listdir(tests)))
|
||||
d2 = sorted(filter(os.path.isdir, [f"models/{x}" for x in model_tests]))
|
||||
d1.remove("models")
|
||||
d = d2 + d1
|
||||
|
||||
if args.subdirs != "":
|
||||
model_tests = ast.literal_eval(args.subdirs)
|
||||
# We handle both cases with and without prefix because `push-important-models.yml` returns the list without
|
||||
# the prefix (i.e. `models`) but `utils/pr_slow_ci_models.py` (called by `self-comment-ci.yml`) returns the
|
||||
# list with the prefix (`models`) and some directory names under `tests`.
|
||||
d = []
|
||||
for x in model_tests:
|
||||
if os.path.isdir(x):
|
||||
d.append(x)
|
||||
if os.path.isdir(f"models/{x}"):
|
||||
d.append(f"models/{x}")
|
||||
d = sorted(d)
|
||||
|
||||
num_jobs = len(d)
|
||||
num_jobs_per_splits = num_jobs // args.num_splits
|
||||
|
||||
model_splits = []
|
||||
end = 0
|
||||
for idx in range(args.num_splits):
|
||||
start = end
|
||||
end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0)
|
||||
# Only add the slice if it is not an empty list
|
||||
if len(d[start:end]) > 0:
|
||||
model_splits.append(d[start:end])
|
||||
|
||||
print(model_splits)
|
||||
0
utils/test_module/__init__.py
Normal file
0
utils/test_module/__init__.py
Normal file
9
utils/test_module/custom_configuration.py
Normal file
9
utils/test_module/custom_configuration.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from transformers import PreTrainedConfig
|
||||
|
||||
|
||||
class CustomConfig(PreTrainedConfig):
|
||||
model_type = "custom"
|
||||
|
||||
def __init__(self, attribute=1, **kwargs):
|
||||
self.attribute = attribute
|
||||
super().__init__(**kwargs)
|
||||
5
utils/test_module/custom_feature_extraction.py
Normal file
5
utils/test_module/custom_feature_extraction.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from transformers import Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
|
||||
pass
|
||||
5
utils/test_module/custom_image_processing.py
Normal file
5
utils/test_module/custom_image_processing.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
|
||||
class CustomImageProcessor(CLIPImageProcessor):
|
||||
pass
|
||||
20
utils/test_module/custom_modeling.py
Normal file
20
utils/test_module/custom_modeling.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .custom_configuration import CustomConfig
|
||||
|
||||
|
||||
class CustomModel(PreTrainedModel):
|
||||
config_class = CustomConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.post_init()
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
33
utils/test_module/custom_pipeline.py
Normal file
33
utils/test_module/custom_pipeline.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import numpy as np
|
||||
|
||||
from transformers import Pipeline
|
||||
|
||||
|
||||
def softmax(outputs):
|
||||
maxes = np.max(outputs, axis=-1, keepdims=True)
|
||||
shifted_exp = np.exp(outputs - maxes)
|
||||
return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
|
||||
|
||||
|
||||
class PairClassificationPipeline(Pipeline):
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
if "second_text" in kwargs:
|
||||
preprocess_kwargs["second_text"] = kwargs["second_text"]
|
||||
return preprocess_kwargs, {}, {}
|
||||
|
||||
def preprocess(self, text, second_text=None):
|
||||
return self.tokenizer(text, text_pair=second_text, return_tensors="pt")
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
return self.model(**model_inputs)
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
logits = model_outputs.logits[0].numpy()
|
||||
probabilities = softmax(logits)
|
||||
|
||||
best_class = np.argmax(probabilities)
|
||||
label = self.model.config.id2label[best_class]
|
||||
score = probabilities[best_class].item()
|
||||
logits = logits.tolist()
|
||||
return {"label": label, "score": score, "logits": logits}
|
||||
6
utils/test_module/custom_processing.py
Normal file
6
utils/test_module/custom_processing.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
|
||||
class CustomProcessor(ProcessorMixin):
|
||||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
5
utils/test_module/custom_tokenization.py
Normal file
5
utils/test_module/custom_tokenization.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
class CustomTokenizer(BertTokenizer):
|
||||
pass
|
||||
10
utils/test_module/custom_tokenization_fast.py
Normal file
10
utils/test_module/custom_tokenization_fast.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from transformers import BertTokenizerFast
|
||||
|
||||
from .custom_tokenization import CustomTokenizer
|
||||
|
||||
|
||||
class CustomTokenizerFast(BertTokenizerFast):
|
||||
slow_tokenizer_class = CustomTokenizer
|
||||
_auto_map = {
|
||||
"AutoTokenizer": ("custom_tokenization.CustomTokenizer", "custom_tokenization_fast.CustomTokenizerFast")
|
||||
}
|
||||
5
utils/test_module/custom_video_processing.py
Normal file
5
utils/test_module/custom_video_processing.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from transformers import LlavaOnevisionVideoProcessor
|
||||
|
||||
|
||||
class CustomVideoProcessor(LlavaOnevisionVideoProcessor):
|
||||
pass
|
||||
1189
utils/tests_fetcher.py
Normal file
1189
utils/tests_fetcher.py
Normal file
File diff suppressed because it is too large
Load Diff
360
utils/update_metadata.py
Executable file
360
utils/update_metadata.py
Executable file
@@ -0,0 +1,360 @@
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that updates the metadata of the Transformers library in the repository `huggingface/transformers-metadata`.
|
||||
|
||||
Usage for an update (as used by the GitHub action `update_metadata`):
|
||||
|
||||
```bash
|
||||
python utils/update_metadata.py --token <token> --commit_sha <commit_sha>
|
||||
```
|
||||
|
||||
Usage to check all pipelines are properly defined in the constant `PIPELINE_TAGS_AND_AUTO_MODELS` of this script, so
|
||||
that new pipelines are properly added as metadata (as used in `make check-repo`):
|
||||
|
||||
```bash
|
||||
python utils/update_metadata.py --check-only
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import pandas as pd
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import hf_hub_download, upload_folder
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
||||
CHECKER_CONFIG = {
|
||||
"name": "update_metadata",
|
||||
"label": "Model metadata",
|
||||
# Approximate: imports the transformers module and inspects pipeline/auto mappings
|
||||
# at runtime. Does not iterate over files matching these globs directly.
|
||||
"cache_globs": ["src/transformers/models/**/*.py", "docs/**/*.md"],
|
||||
"check_args": ["--check-only"],
|
||||
# No safe local "fix" mode: running without `--check-only` pushes to the
|
||||
# `huggingface/transformers-metadata` Hub dataset (requires an auth token).
|
||||
# `fix_args=None` makes `make fix-repo` skip this checker, like other check-only ones.
|
||||
"fix_args": None,
|
||||
}
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/update_metadata.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||
|
||||
|
||||
# Regexes that match model names
|
||||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
|
||||
|
||||
|
||||
# Fill this with tuples (pipeline_tag, model_mapping, auto_model)
|
||||
PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
("pretraining", "MODEL_FOR_PRETRAINING_MAPPING_NAMES", "AutoModelForPreTraining"),
|
||||
("feature-extraction", "MODEL_MAPPING_NAMES", "AutoModel"),
|
||||
("image-feature-extraction", "MODEL_FOR_IMAGE_MAPPING_NAMES", "AutoModel"),
|
||||
("audio-classification", "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForAudioClassification"),
|
||||
("text-generation", "MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "AutoModelForCausalLM"),
|
||||
("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
|
||||
("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
|
||||
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
|
||||
("any-to-any", "MODEL_FOR_MULTIMODAL_LM_MAPPING_NAMES", "AutoModelForMultimodalLM"),
|
||||
("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
|
||||
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
|
||||
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
|
||||
(
|
||||
"zero-shot-object-detection",
|
||||
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
|
||||
"AutoModelForZeroShotObjectDetection",
|
||||
),
|
||||
("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
|
||||
("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
|
||||
("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),
|
||||
("automatic-speech-recognition", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"),
|
||||
(
|
||||
"table-question-answering",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForTableQuestionAnswering",
|
||||
),
|
||||
("token-classification", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", "AutoModelForTokenClassification"),
|
||||
("multiple-choice", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES", "AutoModelForMultipleChoice"),
|
||||
(
|
||||
"next-sentence-prediction",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
|
||||
"AutoModelForNextSentencePrediction",
|
||||
),
|
||||
(
|
||||
"audio-frame-classification",
|
||||
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModelForAudioFrameClassification",
|
||||
),
|
||||
("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
|
||||
(
|
||||
"document-question-answering",
|
||||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
),
|
||||
(
|
||||
"visual-question-answering",
|
||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForVisualQuestionAnswering",
|
||||
),
|
||||
(
|
||||
"zero-shot-image-classification",
|
||||
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModelForZeroShotImageClassification",
|
||||
),
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||
("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
|
||||
("text-to-audio", "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", "AutoModelForTextToSpectrogram"),
|
||||
("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"),
|
||||
("keypoint-matching", "MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES", "AutoModelForKeypointMatching"),
|
||||
]
|
||||
|
||||
|
||||
def camel_case_split(identifier: str) -> list[str]:
|
||||
"""
|
||||
Split a camel-cased name into words.
|
||||
|
||||
Args:
|
||||
identifier (`str`): The camel-cased name to parse.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The list of words in the identifier (as separated by capital letters).
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> camel_case_split("CamelCasedClass")
|
||||
["Camel", "Cased", "Class"]
|
||||
```
|
||||
"""
|
||||
# Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
|
||||
matches = re.finditer(".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)", identifier)
|
||||
return [m.group(0) for m in matches]
|
||||
|
||||
|
||||
def get_frameworks_table() -> pd.DataFrame:
|
||||
"""
|
||||
Generates a dataframe containing the supported auto classes for each model type, using the content of the auto
|
||||
modules.
|
||||
"""
|
||||
# Dictionary model names to config.
|
||||
config_mapping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
|
||||
model_prefix_to_model_type = {
|
||||
config.replace("Config", ""): model_type for model_type, config in config_mapping_names.items()
|
||||
}
|
||||
|
||||
pt_models = collections.defaultdict(bool)
|
||||
|
||||
# Let's lookup through all transformers object (once) and find if models are supported by a given backend.
|
||||
for attr_name in dir(transformers_module):
|
||||
lookup_dict = None
|
||||
if _re_pt_models.match(attr_name) is not None:
|
||||
lookup_dict = pt_models
|
||||
attr_name = _re_pt_models.match(attr_name).groups()[0]
|
||||
|
||||
if lookup_dict is not None:
|
||||
while len(attr_name) > 0:
|
||||
if attr_name in model_prefix_to_model_type:
|
||||
lookup_dict[model_prefix_to_model_type[attr_name]] = True
|
||||
break
|
||||
# Try again after removing the last word in the name
|
||||
attr_name = "".join(camel_case_split(attr_name)[:-1])
|
||||
|
||||
all_models = set(pt_models.keys())
|
||||
all_models = list(all_models)
|
||||
all_models.sort()
|
||||
|
||||
data = {"model_type": all_models}
|
||||
data["pytorch"] = [pt_models[t] for t in all_models]
|
||||
|
||||
# Now let's find the right processing class for each model. In order we check if there is a Processor, then a
|
||||
# Tokenizer, then a FeatureExtractor, then an ImageProcessor
|
||||
processors = {}
|
||||
for t in all_models:
|
||||
if t in transformers_module.models.auto.processing_auto.PROCESSOR_MAPPING_NAMES:
|
||||
processors[t] = "AutoProcessor"
|
||||
elif t in transformers_module.models.auto.tokenization_auto.TOKENIZER_MAPPING_NAMES:
|
||||
processors[t] = "AutoTokenizer"
|
||||
elif t in transformers_module.models.auto.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES:
|
||||
processors[t] = "AutoImageProcessor"
|
||||
elif t in transformers_module.models.auto.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES:
|
||||
processors[t] = "AutoFeatureExtractor"
|
||||
else:
|
||||
# Default to AutoTokenizer if a model has nothing, for backward compatibility.
|
||||
processors[t] = "AutoTokenizer"
|
||||
|
||||
data["processor"] = [processors[t] for t in all_models]
|
||||
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
def update_pipeline_and_auto_class_table(table: dict[str, tuple[str, str]]) -> dict[str, tuple[str, str]]:
|
||||
"""
|
||||
Update the table mapping models to pipelines and auto classes without removing old keys if they don't exist anymore.
|
||||
|
||||
Args:
|
||||
table (`Dict[str, Tuple[str, str]]`):
|
||||
The existing table mapping model names to a tuple containing the pipeline tag and the auto-class name with
|
||||
which they should be used.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Tuple[str, str]]`: The updated table in the same format.
|
||||
"""
|
||||
module = transformers_module.models.auto.modeling_auto
|
||||
for pipeline_tag, model_mapping, cls in PIPELINE_TAGS_AND_AUTO_MODELS:
|
||||
if not hasattr(module, model_mapping):
|
||||
continue
|
||||
# First extract all model_names
|
||||
model_names = []
|
||||
for name in getattr(module, model_mapping).values():
|
||||
if isinstance(name, str):
|
||||
model_names.append(name)
|
||||
else:
|
||||
model_names.extend(list(name))
|
||||
|
||||
# Add pipeline tag and auto model class for those models
|
||||
table.update(dict.fromkeys(model_names, (pipeline_tag, cls)))
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def update_metadata(token: str, commit_sha: str):
|
||||
"""
|
||||
Update the metadata for the Transformers repo in `huggingface/transformers-metadata`.
|
||||
|
||||
Args:
|
||||
token (`str`): A valid token giving write access to `huggingface/transformers-metadata`.
|
||||
commit_sha (`str`): The commit SHA on Transformers corresponding to this update.
|
||||
"""
|
||||
frameworks_table = get_frameworks_table()
|
||||
frameworks_dataset = Dataset.from_pandas(frameworks_table)
|
||||
|
||||
resolved_tags_file = hf_hub_download(
|
||||
"huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
|
||||
)
|
||||
tags_dataset = Dataset.from_json(resolved_tags_file)
|
||||
table = {
|
||||
tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
|
||||
for i in range(len(tags_dataset))
|
||||
}
|
||||
table = update_pipeline_and_auto_class_table(table)
|
||||
|
||||
# Sort the model classes to avoid some nondeterministic updates to create false update commits.
|
||||
model_classes = sorted(table.keys())
|
||||
tags_table = pd.DataFrame(
|
||||
{
|
||||
"model_class": model_classes,
|
||||
"pipeline_tag": [table[m][0] for m in model_classes],
|
||||
"auto_class": [table[m][1] for m in model_classes],
|
||||
}
|
||||
)
|
||||
tags_dataset = Dataset.from_pandas(tags_table)
|
||||
|
||||
hub_frameworks_json = hf_hub_download(
|
||||
repo_id="huggingface/transformers-metadata",
|
||||
filename="frameworks.json",
|
||||
repo_type="dataset",
|
||||
token=token,
|
||||
)
|
||||
with open(hub_frameworks_json) as f:
|
||||
hub_frameworks_json = f.read()
|
||||
|
||||
hub_pipeline_tags_json = hf_hub_download(
|
||||
repo_id="huggingface/transformers-metadata",
|
||||
filename="pipeline_tags.json",
|
||||
repo_type="dataset",
|
||||
token=token,
|
||||
)
|
||||
with open(hub_pipeline_tags_json) as f:
|
||||
hub_pipeline_tags_json = f.read()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
|
||||
tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
|
||||
|
||||
with open(os.path.join(tmp_dir, "frameworks.json")) as f:
|
||||
frameworks_json = f.read()
|
||||
with open(os.path.join(tmp_dir, "pipeline_tags.json")) as f:
|
||||
pipeline_tags_json = f.read()
|
||||
|
||||
frameworks_equal = hub_frameworks_json == frameworks_json
|
||||
hub_pipeline_tags_equal = hub_pipeline_tags_json == pipeline_tags_json
|
||||
|
||||
if frameworks_equal and hub_pipeline_tags_equal:
|
||||
print("No updates on the Hub, not pushing the metadata files.")
|
||||
return
|
||||
|
||||
if commit_sha is not None:
|
||||
commit_message = (
|
||||
f"Update with commit {commit_sha}\n\nSee: "
|
||||
f"https://github.com/huggingface/transformers/commit/{commit_sha}"
|
||||
)
|
||||
else:
|
||||
commit_message = "Update"
|
||||
|
||||
upload_folder(
|
||||
repo_id="huggingface/transformers-metadata",
|
||||
folder_path=tmp_dir,
|
||||
repo_type="dataset",
|
||||
token=token,
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
|
||||
def check_pipeline_tags():
|
||||
"""
|
||||
Check all pipeline tags are properly defined in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant of this script.
|
||||
"""
|
||||
in_table = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
|
||||
pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
|
||||
missing = []
|
||||
for key in pipeline_tasks:
|
||||
if key not in in_table:
|
||||
model = pipeline_tasks[key]["pt"]
|
||||
if isinstance(model, (list, tuple)):
|
||||
model = model[0]
|
||||
model = model.__name__
|
||||
if model not in in_table.values():
|
||||
missing.append(key)
|
||||
|
||||
if len(missing) > 0:
|
||||
msg = ", ".join(missing)
|
||||
raise ValueError(
|
||||
"The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside "
|
||||
f"`utils/update_metadata.py`: {msg}. Please add them!"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
|
||||
parser.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
|
||||
parser.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.check_only:
|
||||
check_pipeline_tags()
|
||||
else:
|
||||
update_metadata(args.token, args.commit_sha)
|
||||
173
utils/update_tiny_models.py
Normal file
173
utils/update_tiny_models.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A script running `create_dummy_models.py` with a pre-defined set of arguments.
|
||||
|
||||
This file is intended to be used in a CI workflow file without the need of specifying arguments. It creates and uploads
|
||||
tiny models for all model classes (if their tiny versions are not on the Hub yet), as well as produces an updated
|
||||
version of `tests/utils/tiny_model_summary.json`. That updated file should be merged into the `main` branch of
|
||||
`transformers` so the pipeline testing will use the latest created/updated tiny models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
from create_dummy_models import COMPOSITE_MODELS, create_tiny_models
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import transformers
|
||||
from transformers import AutoFeatureExtractor, AutoImageProcessor, AutoTokenizer, logging
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_all_model_names():
|
||||
model_names = set()
|
||||
|
||||
module_name = "modeling_auto"
|
||||
module = getattr(transformers.models.auto, module_name, None)
|
||||
if module is not None:
|
||||
# all mappings in a single auto modeling file
|
||||
mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES") and x.startswith("MODEL_")]
|
||||
for name in mapping_names:
|
||||
mapping = getattr(module, name)
|
||||
if mapping is not None:
|
||||
for v in mapping.values():
|
||||
if isinstance(v, (list, tuple)):
|
||||
model_names.update(v)
|
||||
elif isinstance(v, str):
|
||||
model_names.add(v)
|
||||
|
||||
return sorted(model_names)
|
||||
|
||||
|
||||
def get_tiny_model_names_from_repo():
|
||||
with open("tests/utils/tiny_model_summary.json") as fp:
|
||||
tiny_model_info = json.load(fp)
|
||||
tiny_models_names = set()
|
||||
for model_base_name in tiny_model_info:
|
||||
tiny_models_names.update(tiny_model_info[model_base_name]["model_classes"])
|
||||
|
||||
return sorted(tiny_models_names)
|
||||
|
||||
|
||||
def get_tiny_model_summary_from_hub(output_path):
|
||||
api = HfApi()
|
||||
special_models = COMPOSITE_MODELS.values()
|
||||
|
||||
# All tiny model base names on Hub
|
||||
model_names = get_all_model_names()
|
||||
models = api.list_models(author="hf-internal-testing")
|
||||
_models = set()
|
||||
for x in models:
|
||||
model = x.id
|
||||
org, model = model.split("/")
|
||||
if not model.startswith("tiny-random-"):
|
||||
continue
|
||||
model = model.replace("tiny-random-", "")
|
||||
if not model[0].isupper():
|
||||
continue
|
||||
if model not in model_names and model not in special_models:
|
||||
continue
|
||||
_models.add(model)
|
||||
|
||||
models = sorted(_models)
|
||||
# All tiny model names on Hub
|
||||
summary = {}
|
||||
for model in models:
|
||||
repo_id = f"hf-internal-testing/tiny-random-{model}"
|
||||
model = model.split("-")[0]
|
||||
try:
|
||||
repo_info = api.repo_info(repo_id)
|
||||
content = {
|
||||
"tokenizer_classes": set(),
|
||||
"processor_classes": set(),
|
||||
"model_classes": set(),
|
||||
"sha": repo_info.sha,
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
try:
|
||||
time.sleep(1)
|
||||
tokenizer_fast = AutoTokenizer.from_pretrained(repo_id)
|
||||
content["tokenizer_classes"].add(tokenizer_fast.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load fast tokenizer for {repo_id}: {e}")
|
||||
try:
|
||||
time.sleep(1)
|
||||
tokenizer_slow = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
|
||||
content["tokenizer_classes"].add(tokenizer_slow.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load slow tokenizer for {repo_id}: {e}")
|
||||
try:
|
||||
time.sleep(1)
|
||||
img_p = AutoImageProcessor.from_pretrained(repo_id)
|
||||
content["processor_classes"].add(img_p.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load image processor for {repo_id}: {e}")
|
||||
try:
|
||||
time.sleep(1)
|
||||
feat_p = AutoFeatureExtractor.from_pretrained(repo_id)
|
||||
if not isinstance(feat_p, BaseImageProcessor):
|
||||
content["processor_classes"].add(feat_p.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load feature extractor for {repo_id}: {e}")
|
||||
try:
|
||||
time.sleep(1)
|
||||
model_class = getattr(transformers, model)
|
||||
m = model_class.from_pretrained(repo_id)
|
||||
content["model_classes"].add(m.__class__.__name__)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load model for {repo_id}: {e}")
|
||||
|
||||
content["tokenizer_classes"] = sorted(content["tokenizer_classes"])
|
||||
content["processor_classes"] = sorted(content["processor_classes"])
|
||||
content["model_classes"] = sorted(content["model_classes"])
|
||||
|
||||
summary[model] = content
|
||||
with open(os.path.join(output_path, "hub_tiny_model_summary.json"), "w") as fp:
|
||||
json.dump(summary, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# This has to be `spawn` to avoid hanging forever!
|
||||
multiprocessing.set_start_method("spawn")
|
||||
|
||||
output_path = "tiny_models"
|
||||
all = True
|
||||
model_types = None
|
||||
models_to_skip = get_tiny_model_names_from_repo()
|
||||
no_check = True
|
||||
upload = True
|
||||
organization = "hf-internal-testing"
|
||||
|
||||
create_tiny_models(
|
||||
output_path,
|
||||
all,
|
||||
model_types,
|
||||
models_to_skip,
|
||||
no_check,
|
||||
upload,
|
||||
organization,
|
||||
token=os.environ.get("TOKEN", None),
|
||||
num_workers=args.num_workers,
|
||||
)
|
||||
Reference in New Issue
Block a user