[Bugfix] Fix triton import with local TritonPlaceholder #17446

Merged: 3 commits, May 6, 2025
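
The diff below replaces direct import triton / import triton.language as tl statements in the touched benchmarks, tests, and kernels with imports routed through vllm.triton_utils, so these modules fall back to the local TritonPlaceholder / TritonLanguagePlaceholder when Triton is not installed (see tests/test_triton_utils.py below). A rough sketch of the resulting pattern; the HAS_TRITON guard here is illustrative only and not part of the diff:

```python
# Illustrative sketch of the import pattern applied across the touched files.

# Old: fails at import time when Triton is not installed.
#   import triton
#   import triton.language as tl

# New: real modules when Triton is available, placeholders otherwise.
from vllm.triton_utils import HAS_TRITON, tl, triton

if not HAS_TRITON:
    # triton / tl are TritonPlaceholder / TritonLanguagePlaceholder here;
    # @triton.jit-decorated functions can still be defined, but the kernels
    # should not be launched.
    pass
```
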
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_moe.py
@@ -10,12 +10,12 @@

import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()

2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_rmsnorm.py
@@ -4,11 +4,11 @@
from typing import Optional, Union

import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


class HuggingFaceRMSNorm(nn.Module):

@@ -6,13 +6,13 @@
# Import DeepGEMM functions
import deep_gemm
import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.triton_utils import triton


# Copied from

2 changes: 1 addition & 1 deletion tests/kernels/attention/test_flashmla.py
@@ -5,11 +5,11 @@

import pytest
import torch
import triton

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
from vllm.triton_utils import triton


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

92 changes: 92 additions & 0 deletions tests/test_triton_utils.py
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0

import sys
import types
from unittest import mock

from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
                                         TritonPlaceholder)


def test_triton_placeholder_is_module():
    triton = TritonPlaceholder()
    assert isinstance(triton, types.ModuleType)
    assert triton.__name__ == "triton"


def test_triton_language_placeholder_is_module():
    triton_language = TritonLanguagePlaceholder()
    assert isinstance(triton_language, types.ModuleType)
    assert triton_language.__name__ == "triton.language"


def test_triton_placeholder_decorators():
    triton = TritonPlaceholder()

    @triton.jit
    def foo(x):
        return x

    @triton.autotune
    def bar(x):
        return x

    @triton.heuristics
    def baz(x):
        return x

    assert foo(1) == 1
    assert bar(2) == 2
    assert baz(3) == 3


def test_triton_placeholder_decorators_with_args():
    triton = TritonPlaceholder()

    @triton.jit(debug=True)
    def foo(x):
        return x

    @triton.autotune(configs=[], key="x")
    def bar(x):
        return x

    @triton.heuristics(
        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
    def baz(x):
        return x

    assert foo(1) == 1
    assert bar(2) == 2
    assert baz(3) == 3


def test_triton_placeholder_language():
    lang = TritonLanguagePlaceholder()
    assert isinstance(lang, types.ModuleType)
    assert lang.__name__ == "triton.language"
    assert lang.constexpr is None
    assert lang.dtype is None
    assert lang.int64 is None


def test_triton_placeholder_language_from_parent():
    triton = TritonPlaceholder()
    lang = triton.language
    assert isinstance(lang, TritonLanguagePlaceholder)


def test_no_triton_fallback():
    # clear existing triton modules
    sys.modules.pop("triton", None)
    sys.modules.pop("triton.language", None)
    sys.modules.pop("vllm.triton_utils", None)
    sys.modules.pop("vllm.triton_utils.importing", None)

    # mock triton not being installed
    with mock.patch.dict(sys.modules, {"triton": None}):
        from vllm.triton_utils import HAS_TRITON, tl, triton
        assert HAS_TRITON is False
        assert triton.__class__.__name__ == "TritonPlaceholder"
        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
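
The placeholder implementation itself (vllm/triton_utils/importing.py) is not part of this diff. A minimal sketch that would satisfy the tests above, assuming only the names and behaviors the tests exercise (the actual vLLM code may differ), might look like:

```python
# Hypothetical sketch, not the actual vllm/triton_utils/importing.py.
import types


class TritonLanguagePlaceholder(types.ModuleType):
    """Stand-in for triton.language when Triton is not installed."""

    def __init__(self):
        super().__init__("triton.language")
        # Attributes commonly used in kernel signatures resolve to None.
        self.constexpr = None
        self.dtype = None
        self.int64 = None


class TritonPlaceholder(types.ModuleType):
    """Stand-in for the top-level triton module."""

    def __init__(self):
        super().__init__("triton")
        self.jit = self._dummy_decorator("jit")
        self.autotune = self._dummy_decorator("autotune")
        self.heuristics = self._dummy_decorator("heuristics")
        self.language = TritonLanguagePlaceholder()

    def _dummy_decorator(self, name):
        """Return a decorator that works both bare and with arguments."""

        def decorator(*args, **kwargs):
            if args and callable(args[0]):
                return args[0]  # bare form: @triton.jit
            return lambda f: f  # parameterized form: @triton.jit(debug=True)

        return decorator
```

Either form of the decorators returns the wrapped function unchanged, which is exactly what test_triton_placeholder_decorators and test_triton_placeholder_decorators_with_args assert.
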
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


def blocksparse_flash_attn_varlen_fwd(

3 changes: 2 additions & 1 deletion vllm/attention/ops/blocksparse_attention/utils.py
@@ -8,7 +8,8 @@

import numpy as np
import torch
import triton

from vllm.triton_utils import triton


class csr_matrix:

3 changes: 1 addition & 2 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -7,11 +7,10 @@
# - Thomas Parnell <tpa@zurich.ibm.com>

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.triton_utils import tl, triton

from .prefix_prefill import context_attention_fwd


3 changes: 1 addition & 2 deletions vllm/attention/ops/prefix_prefill.py
@@ -4,10 +4,9 @@
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

4 changes: 1 addition & 3 deletions vllm/attention/ops/triton_decode_attention.py
@@ -30,10 +30,8 @@

import logging

import triton
import triton.language as tl

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

is_hip_ = current_platform.is_rocm()


3 changes: 1 addition & 2 deletions vllm/attention/ops/triton_flash_attention.py
@@ -25,11 +25,10 @@
from typing import Optional

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']


4 changes: 2 additions & 2 deletions vllm/attention/ops/triton_merge_attn_states.py
@@ -2,8 +2,8 @@
from typing import Optional

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005

3 changes: 1 addition & 2 deletions vllm/lora/ops/triton_ops/kernel_utils.py
@@ -2,8 +2,7 @@
"""
Utilities for Punica kernel construction.
"""
import triton
import triton.language as tl
from vllm.triton_utils import tl, triton


@triton.jit

3 changes: 1 addition & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -6,8 +6,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

3 changes: 1 addition & 2 deletions vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -2,11 +2,10 @@
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton
from vllm.utils import round_up



4 changes: 2 additions & 2 deletions vllm/model_executor/layers/lightning_attn.py
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
from einops import rearrange

from vllm.triton_utils import tl, triton


@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,

4 changes: 1 addition & 3 deletions vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -4,13 +4,11 @@
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

import torch
import triton
import triton.language as tl
from packaging import version

from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON
from vllm.triton_utils import HAS_TRITON, tl, triton

TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                          >= version.parse("3.0.0"))

4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_bmm.py
@@ -8,8 +8,8 @@
import math

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


@triton.autotune(

4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -6,10 +6,10 @@
# ruff: noqa: E501,SIM102

import torch
import triton
import triton.language as tl
from packaging import version

from vllm.triton_utils import tl, triton

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')



4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
@@ -8,8 +8,8 @@
import math

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton

from .mamba_ssm import softplus


3 changes: 2 additions & 1 deletion vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -6,10 +6,11 @@
# ruff: noqa: E501

import torch
import triton
from einops import rearrange
from packaging import version

from vllm.triton_utils import triton

from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,

4 changes: 2 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
@@ -6,8 +6,8 @@
# ruff: noqa: E501

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


@triton.autotune(

4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/awq_triton.py
@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


@@ -3,8 +3,8 @@
from typing import Optional, Type

import torch
import triton
import triton.language as tl

from vllm.triton_utils import tl, triton


def is_weak_contiguous(x: torch.Tensor):

@@ -7,8 +7,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import triton
import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -17,6 +15,7 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)