Files
transformers/utils/check_import_complexity.py
陈赣 06f1fd69a6
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
first commit
2026-06-05 16:53:03 +08:00

254 lines
7.9 KiB
Python

#!/usr/bin/env python3
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check that `import transformers` does not pull in too many modules.

Traces the full import tree triggered by `import transformers` using a custom
``importlib.abc.MetaPathFinder`` and counts every module that gets loaded.
If the count exceeds ``MAX_IMPORT_COUNT`` the check fails, signalling a
potential regression in import speed.

Usage:

    python utils/check_import_complexity.py                  # CI check mode
    python utils/check_import_complexity.py --display        # show the full import tree
    python utils/check_import_complexity.py --max-count 1200 # override the allowed maximum
"""
from __future__ import annotations
import argparse
import importlib
import importlib.abc
import sys
import threading
from dataclasses import dataclass, field
from types import ModuleType
from typing import Any
# Default ceiling on the number of modules `import transformers` may load;
# used as the default for the `--max-count` command-line option.
MAX_IMPORT_COUNT = 1_000
# ---------------------------------------------------------------------------
# Import-tree data structures
# ---------------------------------------------------------------------------
@dataclass
class ImportNode:
    """One module in the traced import tree."""

    # Fully qualified module name, e.g. "package.submodule".
    name: str
    # Nodes attached beneath this one; every instance gets its own fresh list.
    children: list[ImportNode] = field(default_factory=list)
class LoaderProxy(importlib.abc.Loader):
    """Wrap a real loader to track the import stack during ``exec_module``.

    While the wrapped loader executes a module, that module's name sits on top
    of the tracer's stack. Attributes not defined on the proxy itself are
    forwarded to the wrapped loader.

    Args:
        wrapped: The loader taken from the module spec being proxied.
        tracer: Object exposing ``push(fullname)`` and ``pop()``.
        fullname: Fully qualified name of the module this loader loads.
    """

    def __init__(self, wrapped: Any, tracer: ImportTreeTracer, fullname: str):
        self._wrapped = wrapped
        self._tracer = tracer
        self._fullname = fullname

    def create_module(self, spec):
        """Delegate module creation to the wrapped loader when it supports it.

        Returns:
            Whatever the wrapped loader's ``create_module`` returns, or ``None``
            when the wrapped loader has no such method.
        """
        if hasattr(self._wrapped, "create_module"):
            return self._wrapped.create_module(spec)
        return None

    def exec_module(self, module: ModuleType) -> None:
        """Execute ``module`` through the wrapped loader while it is on the tracer stack.

        Raises:
            ImportError: If the wrapped loader offers neither ``exec_module``
                nor ``load_module``.
        """
        self._tracer.push(self._fullname)
        try:
            if hasattr(self._wrapped, "exec_module"):
                self._wrapped.exec_module(module)
            elif hasattr(self._wrapped, "load_module"):
                # Legacy loader API: it is handed the module name, not the module object.
                self._wrapped.load_module(self._fullname)
            else:
                raise ImportError(f"Loader for {self._fullname!r} has neither exec_module nor load_module")
        finally:
            # Always unwind, even when the module body raises.
            self._tracer.pop()

    def __getattr__(self, name: str) -> Any:
        """Forward lookups of attributes missing on the proxy to the wrapped loader.

        Raises:
            AttributeError: If the proxy has no ``_wrapped`` attribute yet, or
                the wrapped loader does not have ``name``.
        """
        # `__getattr__` only runs when normal lookup fails. If `_wrapped` itself
        # is missing (instance built via `__new__`, `copy`, unpickling, ...),
        # reading `self._wrapped` below would re-enter this method forever.
        if name == "_wrapped":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self._wrapped, name)
class ImportTreeFinder(importlib.abc.MetaPathFinder):
    """Meta-path finder that records each newly resolved module in a tracer.

    It does not locate modules itself: it asks a snapshot of the original
    finders and, for the first spec found, records the module name and wraps
    the spec's loader in a ``LoaderProxy``.
    """

    def __init__(self, tracer: ImportTreeTracer, original_meta_path: list[Any]):
        # Copy the list so later changes to the caller's list do not affect us.
        self._original = list(original_meta_path)
        self._tracer = tracer

    def find_spec(self, fullname: str, path=None, target=None):
        """Return the first spec found by the original finders, instrumented.

        Returns ``None`` when ``fullname`` was already recorded or when no
        original finder produces a spec.
        """
        if self._tracer.is_seen(fullname):
            return None

        for candidate in self._original:
            try:
                if hasattr(candidate, "find_spec"):
                    found = candidate.find_spec(fullname, path, target)
                else:
                    found = None
            except Exception:
                # Best effort: a finder that errors is skipped, not fatal.
                continue

            if found is not None:
                self._tracer.record(fullname)
                real_loader = found.loader
                if real_loader is not None:
                    found.loader = LoaderProxy(real_loader, self._tracer, fullname)
                return found

        return None
class ImportTreeTracer:
    """Collect the modules seen during a trace and arrange them into a tree.

    A per-thread stack holds the names of modules currently being executed;
    a newly recorded module is attached beneath the name on top of that stack,
    or becomes a root when the stack is empty.
    """

    def __init__(self) -> None:
        self._seen: set[str] = set()
        self._nodes: dict[str, ImportNode] = {}
        self._roots: list[ImportNode] = []
        # Thread-local storage: each thread gets its own execution stack.
        self._local = threading.local()

    def _stack(self) -> list[str]:
        """Return the calling thread's stack, creating it on first use."""
        try:
            frames = self._local.stack
        except AttributeError:
            frames = None
        if frames is None:
            frames = []
            self._local.stack = frames
        return frames

    def is_seen(self, fullname: str) -> bool:
        """Tell whether ``fullname`` has already been recorded."""
        return fullname in self._seen

    def _get_or_create(self, fullname: str) -> ImportNode:
        """Return the node registered for ``fullname``, creating it if absent."""
        try:
            return self._nodes[fullname]
        except KeyError:
            created = ImportNode(name=fullname)
            self._nodes[fullname] = created
            return created

    def record(self, fullname: str) -> None:
        """Register ``fullname`` once and attach its node to the tree."""
        if self.is_seen(fullname):
            return
        self._seen.add(fullname)

        node = self._get_or_create(fullname)
        active = self._stack()
        # Attach under the module currently executing on this thread, if any;
        # otherwise the module is a root of the tree.
        siblings = self._get_or_create(active[-1]).children if active else self._roots
        if not any(existing.name == fullname for existing in siblings):
            siblings.append(node)

    def push(self, fullname: str) -> None:
        """Mark ``fullname`` as the module currently executing on this thread."""
        self._stack().append(fullname)

    def pop(self) -> None:
        """Remove the top entry of this thread's stack; no-op when it is empty."""
        frames = self._stack()
        if not frames:
            return
        frames.pop()

    @property
    def count(self) -> int:
        """Number of distinct module names recorded so far."""
        return len(self._seen)

    @property
    def roots(self) -> list[ImportNode]:
        """Nodes recorded while the recording thread's stack was empty."""
        return self._roots
# ---------------------------------------------------------------------------
# Tracing entry-point
# ---------------------------------------------------------------------------
def trace_import(target: str) -> ImportTreeTracer:
    """Import ``target`` with a tracing finder installed and return the tracer.

    The finder is placed at the front of ``sys.meta_path`` for the duration of
    the import and removed afterwards, whether or not the import succeeds.
    Any exception raised by the import propagates to the caller.
    """
    tracer = ImportTreeTracer()
    # Snapshot the finders as they are before the tracing hook is installed.
    hook = ImportTreeFinder(tracer, list(sys.meta_path))
    sys.meta_path.insert(0, hook)
    try:
        importlib.import_module(target)
    finally:
        # The hook may already be gone if something else edited sys.meta_path.
        if hook in sys.meta_path:
            sys.meta_path.remove(hook)
    return tracer
# ---------------------------------------------------------------------------
# Display helpers
# ---------------------------------------------------------------------------
def format_tree(nodes: list[ImportNode]) -> str:
    """Render import-tree nodes as a box-drawing tree, one module per line.

    Args:
        nodes: Top-level nodes to render. Each node must expose ``name`` and
            ``children``; children are rendered recursively beneath it.

    Returns:
        The rendered lines joined with newlines, or an empty string when
        ``nodes`` is empty.
    """
    lines: list[str] = []

    def _walk(node: ImportNode, prefix: str, is_last: bool) -> None:
        connector = "└── " if is_last else "├── "
        lines.append(f"{prefix}{connector}{node.name}")
        # Each level is indented by four columns (the connector width). Under a
        # non-last node a vertical bar keeps the link to its later siblings
        # visible; under a last node plain spaces are enough.
        child_prefix = prefix + ("    " if is_last else "│   ")
        last_index = len(node.children) - 1
        for i, child in enumerate(node.children):
            _walk(child, child_prefix, i == last_index)

    last_root = len(nodes) - 1
    for i, root in enumerate(nodes):
        _walk(root, "", i == last_root)
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """Run the import-complexity check from the command line.

    Options:
        --display: print the import tree and the total, then exit with 0
            regardless of the limit.
        --max-count: maximum allowed number of imported modules
            (defaults to ``MAX_IMPORT_COUNT``).

    Returns:
        0 when the check passes or in display mode; 1 when importing
        ``transformers`` fails or the module count exceeds the limit.
    """
    cli = argparse.ArgumentParser(description="Check import complexity for `import transformers`.")
    cli.add_argument(
        "--display",
        action="store_true",
        help="Display the full import tree (for debugging regressions).",
    )
    cli.add_argument(
        "--max-count",
        type=int,
        default=MAX_IMPORT_COUNT,
        help=f"Maximum allowed number of imported modules (default: {MAX_IMPORT_COUNT}).",
    )
    options = cli.parse_args()

    try:
        tracer = trace_import("transformers")
    except Exception as exc:
        print(f"ERROR: `import transformers` failed: {exc}", file=sys.stderr)
        return 1

    total = tracer.count
    limit = options.max_count

    # Display mode is a debugging aid: it never fails the check.
    if options.display:
        print(format_tree(tracer.roots))
        print()
        print(f"Total modules imported: {total}")
        return 0

    if total <= limit:
        print(f"Import complexity OK: {total} modules (max {limit})")
        return 0

    print(
        f"Import complexity regression: `import transformers` triggered {total} module imports "
        f"(maximum allowed: {limit}).\n"
        f"\n"
        f"Run the following command to display the full import tree and identify the cause:\n"
        f"\n"
        f" python utils/check_import_complexity.py --display\n"
    )
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)