Skip to content

Commit b8bc53b

Browse files
MengqingCaodtrifiro
authored andcommitted
[Bugfix] Fix triton import with local TritonPlaceholder (vllm-project#17446)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
1 parent 77f1266 commit b8bc53b

File tree

3 files changed

+32
-33
lines changed

3 files changed

+32
-33
lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from vllm.model_executor.layers.fused_moe.fused_moe import *
1717
from vllm.platforms import current_platform
18-
from vllm.transformers_utils.config import get_config
1918
from vllm.triton_utils import triton
2019
from vllm.utils import FlexibleArgumentParser
2120

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
# Import vLLM functions
1212
from vllm import _custom_ops as ops
1313
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
14-
per_token_group_quant_fp8,
15-
w8a8_block_fp8_matmul,
16-
)
14+
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
1715
from vllm.triton_utils import triton
1816

1917

vllm/triton_utils/importing.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,34 @@
1616
logger.info("Triton not installed or not compatible; certain GPU-related"
1717
" functions will not be available.")
1818

19-
class TritonPlaceholder(types.ModuleType):
20-
21-
def __init__(self):
22-
super().__init__("triton")
23-
self.jit = self._dummy_decorator("jit")
24-
self.autotune = self._dummy_decorator("autotune")
25-
self.heuristics = self._dummy_decorator("heuristics")
26-
self.language = TritonLanguagePlaceholder()
27-
logger.warning_once(
28-
"Triton is not installed. Using dummy decorators. "
29-
"Install it via `pip install triton` to enable kernel"
30-
"compilation.")
31-
32-
def _dummy_decorator(self, name):
33-
34-
def decorator(func=None, **kwargs):
35-
if func is None:
36-
return lambda f: f
37-
return func
38-
39-
return decorator
40-
41-
class TritonLanguagePlaceholder(types.ModuleType):
42-
43-
def __init__(self):
44-
super().__init__("triton.language")
45-
self.constexpr = None
46-
self.dtype = None
47-
self.int64 = None
19+
20+
class TritonPlaceholder(types.ModuleType):
21+
22+
def __init__(self):
23+
super().__init__("triton")
24+
self.jit = self._dummy_decorator("jit")
25+
self.autotune = self._dummy_decorator("autotune")
26+
self.heuristics = self._dummy_decorator("heuristics")
27+
self.language = TritonLanguagePlaceholder()
28+
logger.warning_once(
29+
"Triton is not installed. Using dummy decorators. "
30+
"Install it via `pip install triton` to enable kernel"
31+
" compilation.")
32+
33+
def _dummy_decorator(self, name):
34+
35+
def decorator(*args, **kwargs):
36+
if args and callable(args[0]):
37+
return args[0]
38+
return lambda f: f
39+
40+
return decorator
41+
42+
43+
class TritonLanguagePlaceholder(types.ModuleType):
44+
45+
def __init__(self):
46+
super().__init__("triton.language")
47+
self.constexpr = None
48+
self.dtype = None
49+
self.int64 = None

0 commit comments

Comments
 (0)