From 39925d68fed72945105aad4aaf32702f1ea8716a Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Tue, 6 May 2025 21:06:11 +0800 Subject: [PATCH 1/2] [MISC][pre-commit] add pre-commit check for triton import Signed-off-by: Mengqing Cao --- .pre-commit-config.yaml | 7 ++++ tools/check_triton_import.py | 68 ++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 tools/check_triton_import.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3dc06952c0d..66d937ef1be 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -125,6 +125,13 @@ repos: name: Update Dockerfile dependency graph entry: tools/update-dockerfile-graph.sh language: script + # forbid directly import triton + - id: forbid-direct-triton-import + name: "Forbid direct 'import triton'" + entry: python tools/check_triton_import.py + language: python + types: [python] + pass_filenames: false # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/tools/check_triton_import.py b/tools/check_triton_import.py new file mode 100644 index 00000000000..d2ab71f883f --- /dev/null +++ b/tools/check_triton_import.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +import re +import subprocess +import sys + +FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)") +ALLOWED_LINE = "from vllm.triton_utils import triton" + + +def is_forbidden_import(line: str) -> bool: + return bool(FORBIDDEN_IMPORT_RE.match( + line.strip())) and ALLOWED_LINE not in line.strip() + + +def parse_diff(diff: str) -> list[str]: + violations = [] + current_file = None + current_lineno = None + + for line in diff.splitlines(): + if line.startswith("+++ b/"): + current_file = line[6:] + elif line.startswith("@@"): + match = re.search(r"\+(\d+)", line) + if match: + current_lineno = int( + match.group(1)) - 1 # next "+ line" is here + elif line.startswith("+") and not line.startswith("++"): + current_lineno += 1 + code_line = line[1:] + if is_forbidden_import(code_line): + violations.append( + f"{current_file}:{current_lineno}: {code_line.strip()}") + return violations + + +def get_diff(diff_type: str) -> str: + if diff_type == "staged": + return subprocess.check_output( + ["git", "diff", "--cached", "--unified=0"], text=True) + elif diff_type == "unstaged": + return subprocess.check_output(["git", "diff", "--unified=0"], + text=True) + else: + raise ValueError(f"Unknown diff_type: {diff_type}") + + +def main(): + all_violations = [] + for diff_type in ["staged", "unstaged"]: + try: + diff_output = get_diff(diff_type) + violations = parse_diff(diff_output) + all_violations.extend(violations) + except subprocess.CalledProcessError as e: + print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr) + + if all_violations: + print("❌ Forbidden direct `import triton` detected." + " ➤ Use `from vllm.triton_utils import triton` instead.\n") + for v in all_violations: + print(f"❌ {v}") + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 59b1790daa79011a4764f91bff49b9ef9bc34209 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Thu, 8 May 2025 13:49:46 +0800 Subject: [PATCH 2/2] add more allowed import Signed-off-by: Mengqing Cao --- tools/check_triton_import.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tools/check_triton_import.py b/tools/check_triton_import.py index d2ab71f883f..d938ff1df59 100644 --- a/tools/check_triton_import.py +++ b/tools/check_triton_import.py @@ -4,12 +4,19 @@ import sys FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)") -ALLOWED_LINE = "from vllm.triton_utils import triton" + +# the way allowed to import triton +ALLOWED_LINES = { + "from vllm.triton_utils import triton", + "from vllm.triton_utils import tl", + "from vllm.triton_utils import tl, triton", +} def is_forbidden_import(line: str) -> bool: - return bool(FORBIDDEN_IMPORT_RE.match( - line.strip())) and ALLOWED_LINE not in line.strip() + stripped = line.strip() + return bool( + FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES def parse_diff(diff: str) -> list[str]: