Skip to content

Commit e77dc4b

Browse files
authored
[MISC][pre-commit] Add pre-commit check for triton import (#17716)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
1 parent 07458a5 commit e77dc4b

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ repos:
128128
name: Update Dockerfile dependency graph
129129
entry: tools/update-dockerfile-graph.sh
130130
language: script
131+
# forbid directly import triton
132+
- id: forbid-direct-triton-import
133+
name: "Forbid direct 'import triton'"
134+
entry: python tools/check_triton_import.py
135+
language: python
136+
types: [python]
137+
pass_filenames: false
131138
# Keep `suggestion` last
132139
- id: suggestion
133140
name: Suggestion

tools/check_triton_import.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
import re
3+
import subprocess
4+
import sys
5+
6+
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")
7+
8+
# the way allowed to import triton
9+
ALLOWED_LINES = {
10+
"from vllm.triton_utils import triton",
11+
"from vllm.triton_utils import tl",
12+
"from vllm.triton_utils import tl, triton",
13+
}
14+
15+
16+
def is_forbidden_import(line: str) -> bool:
17+
stripped = line.strip()
18+
return bool(
19+
FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
20+
21+
22+
def parse_diff(diff: str) -> list[str]:
23+
violations = []
24+
current_file = None
25+
current_lineno = None
26+
27+
for line in diff.splitlines():
28+
if line.startswith("+++ b/"):
29+
current_file = line[6:]
30+
elif line.startswith("@@"):
31+
match = re.search(r"\+(\d+)", line)
32+
if match:
33+
current_lineno = int(
34+
match.group(1)) - 1 # next "+ line" is here
35+
elif line.startswith("+") and not line.startswith("++"):
36+
current_lineno += 1
37+
code_line = line[1:]
38+
if is_forbidden_import(code_line):
39+
violations.append(
40+
f"{current_file}:{current_lineno}: {code_line.strip()}")
41+
return violations
42+
43+
44+
def get_diff(diff_type: str) -> str:
45+
if diff_type == "staged":
46+
return subprocess.check_output(
47+
["git", "diff", "--cached", "--unified=0"], text=True)
48+
elif diff_type == "unstaged":
49+
return subprocess.check_output(["git", "diff", "--unified=0"],
50+
text=True)
51+
else:
52+
raise ValueError(f"Unknown diff_type: {diff_type}")
53+
54+
55+
def main():
56+
all_violations = []
57+
for diff_type in ["staged", "unstaged"]:
58+
try:
59+
diff_output = get_diff(diff_type)
60+
violations = parse_diff(diff_output)
61+
all_violations.extend(violations)
62+
except subprocess.CalledProcessError as e:
63+
print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)
64+
65+
if all_violations:
66+
print("❌ Forbidden direct `import triton` detected."
67+
" ➤ Use `from vllm.triton_utils import triton` instead.\n")
68+
for v in all_violations:
69+
print(f"❌ {v}")
70+
return 1
71+
return 0
72+
73+
74+
if __name__ == "__main__":
75+
sys.exit(main())

0 commit comments

Comments
 (0)