Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
102 lines
4.8 KiB
Python
102 lines
4.8 KiB
Python
import re
|
|
|
|
from transformers.pipelines import SUPPORTED_TASKS, Pipeline
|
|
|
|
|
|
# Metadata describing this check. Nothing in this file reads the dict, so it is presumably picked up
# by an external checker runner — TODO confirm against that runner before changing its shape.
#   - "cache_globs" lists the same path that `--pipeline_file_path` defaults to in the CLI below.
#   - "check_args" is empty: run without flags, `main` only raises when the file is out of date.
#   - "fix_args" carries the `--fix_and_overwrite` flag, which makes `main` rewrite the file.
CHECKER_CONFIG = {
    "name": "pipeline_typing",
    "label": "Pipeline type hints",
    "cache_globs": ["src/transformers/pipelines/__init__.py"],
    "check_args": [],
    "fix_args": ["--fix_and_overwrite"],
}
|
|
|
|
HEADER = """
|
|
# fmt: off
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# The part of the file below was automatically generated from the code.
|
|
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
|
|
# of the file. If any change should be done, please apply the changes to the `pipeline` function
|
|
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
|
|
from typing import Literal, overload
|
|
|
|
|
|
"""
|
|
|
|
FOOTER = """
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# The part of the file above was automatically generated from the code.
|
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
# fmt: on
|
|
"""
|
|
|
|
# Exact text of the `task` parameter expected inside the `pipeline` signature. `main` replaces it
# with a `task: Literal[...]` annotation for every overload and raises `ValueError` if it is absent.
TASK_PATTERN = "task: str | None = None"
|
|
|
|
|
|
def _flatten_signature(signature: str) -> str:
    """Collapse a multi-line parameter list (one argument per line) onto a single line."""
    # The indentation is matched as "any run of spaces/tabs" so the result does not depend on the
    # indentation width used in the pipeline file.
    signature = re.sub(r"\(\n[ \t]+", "(", signature)  # start of the signature
    signature = re.sub(r",\n[ \t]+", ", ", signature)  # intermediate arguments
    return signature.replace(",\n)", ")")  # end of the signature


def _build_overloads(pipeline_signature: str) -> str:
    """Return one `@overload` stub per supported task, preceded by the generic `Literal[None]` stub."""
    # Sort on the quoted task name, then put the task-less fallback (plain `Pipeline`) first.
    pipelines = sorted(
        ((f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()),
        key=lambda item: item[0],
    )
    pipelines.insert(0, (None, Pipeline))

    stubs = []
    for task, pipeline_class in pipelines:
        # `impl` may be given either as a class or directly as the class name.
        pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
        new_pipeline_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
        stubs.append(f"@overload\ndef pipeline{new_pipeline_signature} -> {pipeline_type}: ...\n")
    return "".join(stubs)


def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
    """
    Check, and optionally regenerate, the typed `pipeline` overloads stored in `pipeline_file_path`.

    The overloads live between the `# <generated-code>` and `# </generated-code>` markers. They are
    rebuilt from the signature of the real `pipeline` function found in the same file and from
    `SUPPORTED_TASKS`, then compared with what the file currently contains.

    Args:
        pipeline_file_path (`str`):
            Path of the file holding both the `pipeline` function and the generated region.
        fix_and_overwrite (`bool`, *optional*, defaults to `False`):
            If `True`, an out-of-date generated region is rewritten in place. If `False`, an
            out-of-date region raises instead.

    Raises:
        ValueError: If the markers or the `def pipeline(...) -> Pipeline:` signature cannot be found,
            if the signature does not contain `TASK_PATTERN`, or if the generated region is out of
            date and `fix_and_overwrite` is `False`.
    """
    # The generated banner contains emoji, so never rely on the platform default encoding.
    with open(pipeline_file_path, "r", encoding="utf-8") as file:
        content = file.read()

    # extract generated code in between <generated-code> and </generated-code>
    generated_match = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL)
    if generated_match is None:
        raise ValueError(
            f"Can't find the `# <generated-code>` / `# </generated-code>` markers in {pipeline_file_path}."
        )
    current_generated_code = generated_match.group(1)
    region_start, region_end = generated_match.span(1)
    # Cut the region out by position: a text-based `replace` would also delete any other occurrence
    # of the same text (e.g. every newline if the region is just "\n").
    content_without_generated_code = content[:region_start] + content[region_end:]

    # extract pipeline signature in between `def pipeline` and `-> Pipeline`
    signature_match = re.search(r"def pipeline(.*) -> Pipeline:", content_without_generated_code, re.DOTALL)
    if signature_match is None:
        raise ValueError(f"Can't find the `def pipeline(...) -> Pipeline:` signature in {pipeline_file_path}.")
    pipeline_signature = _flatten_signature(signature_match.group(1))
    if TASK_PATTERN not in pipeline_signature:
        raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")

    # generate new `pipeline` signatures
    new_generated_code = HEADER + _build_overloads(pipeline_signature) + FOOTER
    new_generated_code = new_generated_code.rstrip("\n") + "\n"

    if new_generated_code == current_generated_code:
        return

    if not fix_and_overwrite:
        message = (
            f"Found inconsistencies in {pipeline_file_path}. "
            "Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
        )
        raise ValueError(message)

    print(f"Updating {pipeline_file_path}...")
    content = content[:region_start] + new_generated_code + content[region_end:]

    # write content to file
    with open(pipeline_file_path, "w", encoding="utf-8") as file:
        file.write(content)
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both options map one-to-one onto the parameters of `main`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    path_option = {
        "type": str,
        "default": "src/transformers/pipelines/__init__.py",
        "help": "Path to the pipeline file.",
    }
    cli.add_argument("--pipeline_file_path", **path_option)

    options = cli.parse_args()
    main(pipeline_file_path=options.pipeline_file_path, fix_and_overwrite=options.fix_and_overwrite)
|