Commit ac235a7

vllmellm, qli88, and hongxiayang authored and Mu Huai committed
[FEAT][ROCm]: Support AITER MLA on V1 Engine (vllm-project#17523)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 5725ea8 commit ac235a7

10 files changed: +269 -14 lines changed

docker/Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="7e1ed08"
+ARG AITER_BRANCH="5a77249"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base

tests/kernels/attention/test_attention_selector.py

Lines changed: 4 additions & 1 deletion
@@ -102,7 +102,10 @@ def test_env(
                                        block_size,
                                        False,
                                        use_mla=use_mla)
-            assert backend.get_name() == name
+            if use_v1 and name != "TRITON_MLA":
+                assert backend.get_name() == f"{name}_VLLM_V1"
+            else:
+                assert backend.get_name() == name
         else:
             with pytest.raises(ValueError) as exc_info:
                 get_attn_backend(16,

tests/kernels/attention/test_rocm_attention_selector.py

Lines changed: 4 additions & 2 deletions
@@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
                                    False, True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+        assert (backend.get_name() == "ROCM_AITER_MLA"
+                or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
 
         # If attention backend is None
         # If use_mla is true
@@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv("VLLM_ROCM_USE_AITER", "1")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False,
                                    False, True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+        assert (backend.get_name() == "ROCM_AITER_MLA"
+                or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1")
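
For reference, the assertions above mirror how the backend is forced outside of the test suite. A minimal standalone sketch, assuming get_attn_backend is importable from vllm.attention.selector as in this test file and that the override variable is VLLM_ATTENTION_BACKEND (the variable checked in vllm/engine/arg_utils.py below):

import os

import torch

# Set the override before backend selection happens.
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"
os.environ["VLLM_ROCM_USE_AITER"] = "1"  # enable AITER kernels on ROCm

from vllm.attention.selector import get_attn_backend

# Same arguments as the test: head_size=576, dtype=bfloat16, block_size=1,
# use_mla=True; the returned name depends on whether the V0 or V1 engine runs.
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True)
assert backend.get_name() in ("ROCM_AITER_MLA", "ROCM_AITER_MLA_VLLM_V1")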

vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 46 additions & 0 deletions
@@ -4,6 +4,9 @@
 
 import torch
 
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
+
 
 def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
                            max_block_per_batch: int,
@@ -30,6 +33,28 @@ def aiter_mla_decode_fwd(
     kv_last_page_lens: Optional[torch.Tensor] = None,
     logit_cap: float = 0.0,
 ):
+
+    torch.ops.vllm.rocm_aiter_mla_decode_fwd(q,
+                                             kv_buffer.view(
+                                                 -1, 1, 1, q.shape[-1]),
+                                             o,
+                                             kv_indptr,
+                                             kv_indices,
+                                             kv_last_page_lens,
+                                             sm_scale=sm_scale,
+                                             logit_cap=logit_cap)
+
+
+def mla_decode_fwd_impl(
+    q: torch.Tensor,
+    kv_buffer: torch.Tensor,
+    o: torch.Tensor,
+    kv_indptr: Optional[torch.Tensor] = None,
+    kv_indices: Optional[torch.Tensor] = None,
+    kv_last_page_lens: Optional[torch.Tensor] = None,
+    sm_scale: float = 1.0,
+    logit_cap: float = 0.0,
+) -> None:
     from aiter.mla import mla_decode_fwd
 
     mla_decode_fwd(q,
@@ -40,3 +65,24 @@ def aiter_mla_decode_fwd(
                    kv_last_page_lens,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)
+
+
+def mla_decode_fwd_fake(
+    q: torch.Tensor,
+    kv_buffer: torch.Tensor,
+    o: torch.Tensor,
+    kv_indptr: Optional[torch.Tensor] = None,
+    kv_indices: Optional[torch.Tensor] = None,
+    kv_last_page_lens: Optional[torch.Tensor] = None,
+    sm_scale: float = 1.0,
+    logit_cap: float = 0.0,
+) -> None:
+    pass
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
+                              op_func=mla_decode_fwd_impl,
+                              mutates_args=["o"],
+                              fake_impl=mla_decode_fwd_fake,
+                              tags=[torch.Tag.needs_fixed_stride_order])
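
The block above is the core mechanism of this change: the AITER decode kernel is wrapped in a PyTorch custom op, with a no-op fake implementation so the compiler can trace through it. A minimal illustrative sketch of the same registration pattern with a hypothetical toy op (scale_into is not part of this commit), assuming direct_register_custom_op accepts the keyword arguments used above and that tags may be omitted:

import torch

from vllm.utils import direct_register_custom_op


def scale_into_impl(x: torch.Tensor, out: torch.Tensor,
                    factor: float = 1.0) -> None:
    # Real implementation: writes the scaled input into the mutated output.
    out.copy_(x * factor)


def scale_into_fake(x: torch.Tensor, out: torch.Tensor,
                    factor: float = 1.0) -> None:
    # Fake implementation: no computation; it only describes that the op
    # returns nothing and mutates `out`.
    pass


direct_register_custom_op(op_name="scale_into",
                          op_func=scale_into_impl,
                          mutates_args=["out"],
                          fake_impl=scale_into_fake)

# The registered op is then reachable under the vllm namespace, exactly like
# torch.ops.vllm.rocm_aiter_mla_decode_fwd above:
#   torch.ops.vllm.scale_into(x, out, factor=2.0)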

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1319,6 +1319,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "FLASHMLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
+            "ROCM_AITER_MLA",
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
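
With "ROCM_AITER_MLA" in the list, an explicit backend override no longer disqualifies the V1 engine. A trimmed-down, hypothetical sketch of the oracle check (the full V1_BACKENDS list lives in _is_v1_supported_oracle):

# Abbreviated stand-in for the list above.
V1_BACKENDS = ["FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1", "ROCM_AITER_MLA"]

requested = "ROCM_AITER_MLA"  # e.g. taken from VLLM_ATTENTION_BACKEND
# Before this commit the name was missing, so this override forced a
# fallback away from V1; now the check passes.
assert requested in V1_BACKENDS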

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
         block_shape: List[int],
         smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-    return torch.empty_like(a1, dtype=torch.bf16)
+    return torch.empty_like(a1, dtype=hidden_states_dtype)
 
 
 def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
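
This one-line fix matters because a fake implementation must report the same dtype as the real kernel output, or the compiler propagates the wrong dtype downstream. A small sketch with toy tensors (the names mirror the signature above; the values are invented):

import torch

a1 = torch.empty(4, 128, dtype=torch.float16)   # toy activation buffer
hidden_states_dtype = torch.float16             # dtype the real kernel returns

fake_out = torch.empty_like(a1, dtype=hidden_states_dtype)
assert fake_out.dtype == torch.float16  # the old code hard-coded bfloat16 here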

vllm/platforms/interface.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,8 @@ class _Backend(enum.Enum):
     TRITON_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
-    ROCM_AITER_MLA = enum.auto()
+    ROCM_AITER_MLA = enum.auto()  # Supported by V1
+    ROCM_AITER_MLA_VLLM_V1 = enum.auto()
     TORCH_SDPA = enum.auto()
     FLASHINFER = enum.auto()
     TRITON_MLA = enum.auto()  # Supported by V1

vllm/platforms/rocm.py

Lines changed: 8 additions & 3 deletions
@@ -168,10 +168,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
                     f"does not support block size {block_size}.")
-        elif selected_backend == _Backend.ROCM_AITER_MLA:
+        elif selected_backend == _Backend.ROCM_AITER_MLA \
+                or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
             if block_size == 1:
-                logger.info("Using AITER MLA backend.")
-                return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
+                if use_v1:
+                    logger.info("Using AITER MLA backend on V1 engine.")
+                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
+                else:
+                    logger.info("Using AITER MLA backend")
+                    return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
             else:
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"

vllm/v1/attention/backends/mla/common.py

Lines changed: 6 additions & 5 deletions
@@ -496,11 +496,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 max_context_chunk = (self.chunked_prefill_workspace_size //
                                      num_prefills_with_context_cpu)
 
-                # align max_context_chunk to page_size by rounding down,
-                # currently the `gather_cache` kernel cannot handle
-                # `context_chunk_starts` that are not aligned to page_size
-                max_context_chunk = round_down(max_context_chunk,
-                                               self.page_size)
+                if self.aot_schedule:
+                    # align max_context_chunk to page_size by rounding down,
+                    # currently the `gather_cache` kernel cannot handle
+                    # `context_chunk_starts` that are not aligned to page_size
+                    max_context_chunk = round_down(max_context_chunk,
+                                                   self.page_size)
 
                 assert max_context_chunk > 0
                 num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
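
For intuition on the now-guarded alignment step, a small worked example, assuming round_down(x, m) floors x to the nearest multiple of m (which is what the comment describes) and that it is importable from vllm.utils alongside cdiv:

from vllm.utils import round_down  # assumed import location

page_size = 16
max_context_chunk = 1000
# Only applied when aot_schedule is enabled; the gather_cache kernel needs
# context_chunk_starts to land on page boundaries.
aligned = round_down(max_context_chunk, page_size)
print(aligned)  # 992 == 62 * 16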
vllm/v1/attention/backends/mla/rocm_aiter_mla.py (new file)

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+
+import vllm.envs as envs
+from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
+# yapf conflicts with isort for this docstring
+# yapf: disable
+from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                    MLACommonDecodeMetadata,
+                                                    MLACommonImpl,
+                                                    MLACommonMetadata,
+                                                    MLACommonMetadataBuilder)
+
+# yapf: enable
+
+
+def is_aiter_mla_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER \
+        and envs.VLLM_ROCM_USE_AITER_MLA
+
+
+class AiterMLABackend(MLACommonBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "ROCM_AITER_MLA_VLLM_V1"
+
+    @staticmethod
+    def get_impl_cls() -> type["AiterMLAImpl"]:
+        return AiterMLAImpl
+
+    @staticmethod
+    def get_metadata_cls() -> type["AiterMLAMetadata"]:
+        return AiterMLAMetadata
+
+    @staticmethod
+    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
+        return AiterMLAMetadataBuilder
+
+
+@dataclass
+class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
+    # The indptr of the paged kv cache, shape: [batch_size + 1]
+    paged_kv_indptr: Optional[torch.Tensor] = None
+    # The page indices of the paged kv cache
+    paged_kv_indices: Optional[torch.Tensor] = None
+    # The number of entries in the last page of each request in
+    # the paged kv cache, shape: [batch_size]
+    paged_kv_last_page_len: Optional[torch.Tensor] = None
+
+
+class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
+    pass
+
+
+class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
+
+    def __init__(self, runner):
+        super().__init__(runner)
+        max_model_len = self.runner.model_config.max_model_len
+        assert max_model_len == 32768,\
+            "AITER MLA requires max_model_len=32768"
+        assert self.runner.block_size == 1, "AITER MLA" \
+            "only supports block size 1."
+
+    def _get_paged_kv_tensors(
+            self, block_table: torch.Tensor,
+            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        page_size = self.runner.block_size
+        block_table_bounds = (seq_lens + page_size - 1) // page_size
+
+        mask = (torch.arange(block_table.size(1),
+                             dtype=block_table.dtype,
+                             device=block_table.device).unsqueeze(0)
+                < block_table_bounds.unsqueeze(1))
+        paged_kv_indices = block_table[mask]
+
+        paged_kv_indptr = torch.cat([
+            torch.zeros(1,
+                        dtype=block_table_bounds.dtype,
+                        device=block_table_bounds.device),
+            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
+        ])
+
+        paged_kv_last_page_len = seq_lens % page_size
+        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
+                                             page_size, paged_kv_last_page_len)
+        return (
+            paged_kv_indices,
+            paged_kv_indptr,
+            paged_kv_last_page_len,
+        )
+
+    def _build_decode(self, input_positions: torch.Tensor,
+                      block_table: torch.Tensor,
+                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
+
+        (
+            paged_kv_indices,
+            paged_kv_indptr,
+            paged_last_page_len,
+        ) = self._get_paged_kv_tensors(block_table, seq_lens)
+
+        attn_metadata = AiterMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            paged_kv_indptr=paged_kv_indptr,
+            paged_kv_indices=paged_kv_indices,
+            paged_kv_last_page_len=paged_last_page_len)
+
+        return attn_metadata
+
+
+class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
+
+    def __init__(
+            self,
+            num_heads: int,
+            head_size: int,
+            scale: float,
+            num_kv_heads: int,
+            alibi_slopes: Optional[list[float]],
+            sliding_window: Optional[int],
+            kv_cache_dtype: str,
+            blocksparse_params: Optional[dict[str, Any]],
+            logits_soft_cap: Optional[float],
+            attn_type: str,
+            # MLA Specific Arguments
+            **mla_args) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **mla_args)
+
+        unsupported_features = [
+            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
+        ]
+        if any(unsupported_features):
+            raise NotImplementedError(
+                "Aiter MLA does not support one of the following: "
+                "alibi_slopes, sliding_window, blocksparse_params, "
+                "logits_soft_cap")
+
+        from aiter import flash_attn_varlen_func
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        output = self.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            softmax_scale=softmax_scale,
+            return_lse=return_softmax_lse,
+            **kwargs,
+        )
+
+        return output
+
+    def _forward_decode(
+        self,
+        q_nope: torch.Tensor,
+        q_pe: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        attn_metadata: AiterMLAMetadata,
+    ) -> torch.Tensor:
+        assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
+
+        B = q_nope.shape[0]
+
+        q = torch.cat([q_nope, q_pe], dim=-1)
+        o = torch.zeros(B,
+                        self.num_heads,
+                        self.kv_lora_rank,
+                        dtype=q.dtype,
+                        device=q.device)
+
+        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
+
+        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.decode.paged_kv_indptr,
+                             attn_metadata.decode.paged_kv_indices,
+                             attn_metadata.decode.paged_kv_last_page_len)
+
+        return self._v_up_proj(o)
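
To make the metadata construction in _get_paged_kv_tensors concrete, here is a standalone recreation of its arithmetic with toy values and block size 1 (the only block size this backend accepts); the block table and sequence lengths are invented for illustration:

import torch

page_size = 1  # AITER MLA requires block size 1
seq_lens = torch.tensor([3, 2], dtype=torch.int32)
block_table = torch.tensor([[10, 11, 12, 0],
                            [20, 21, 0, 0]], dtype=torch.int32)

# With page_size == 1 each token occupies its own page.
block_table_bounds = (seq_lens + page_size - 1) // page_size    # [3, 2]

# Keep only the pages actually used by each request, flattened row-major.
mask = (torch.arange(block_table.size(1), dtype=block_table.dtype)
        .unsqueeze(0) < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]                            # [10, 11, 12, 20, 21]

# Prefix sums give each request's slice into paged_kv_indices.
paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=block_table_bounds.dtype),
    block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])                                                              # [0, 3, 5]

# A full last page reports page_size entries rather than zero.
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size,
                                     paged_kv_last_page_len)    # [1, 1]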
