Commit c0bac85

SolitaryThinker authored and LeiWang1999 committed
[bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetadata (vllm-project#8474)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 8785c7d commit c0bac85
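
For context, this commit lets vLLM's multi-step scheduler run with the ROCm flash-attention backend on AMD GPUs. As a rough usage sketch (not part of the commit), multi-step decoding is typically enabled through the engine's `num_scheduler_steps` argument; exact argument names and defaults can vary across vLLM versions:

```python
# Hedged usage sketch: enable multi-step decoding, which on AMD GPUs relies on
# the ROCm flash-attention backend touched by this commit.
# Assumes vLLM's offline LLM API and the num_scheduler_steps engine argument.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # any supported model
    num_scheduler_steps=8,       # run 8 decode steps per scheduler invocation
)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```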

File tree

2 files changed: +58, -2 lines changed


vllm/attention/backends/rocm_flash_attn.py

Lines changed: 57 additions & 1 deletion
@@ -1,6 +1,6 @@
 """Attention layer ROCm GPUs."""
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
 
@@ -15,6 +15,9 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
 logger = init_logger(__name__)
 
 _PARTITION_SIZE_ROCM = 512
@@ -180,6 +183,59 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             )
         return self._cached_decode_metadata
 
+    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+                     sampled_token_ids: Optional[torch.Tensor],
+                     block_size: int, num_seqs: int, num_queries: int):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # When using cudagraph, num_seqs is padded to the next captured
+        # batch size, but num_queries tracks the actual number of requests
+        # in the batch. For --enforce-eager mode, num_seqs == num_queries.
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+        assert self.slot_mapping.shape == (num_seqs, )
+
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+        assert self.max_decode_seq_len == max(self.seq_lens)
+
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+
+        assert self.block_tables is not None
+        assert self.block_tables.shape[0] == num_seqs
+
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to the captured cuda graph batch size.
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
+
+        ops.advance_step_flashattn(num_seqs=num_seqs,
+                                   num_queries=num_queries,
+                                   block_size=block_size,
+                                   input_tokens=model_input.input_tokens,
+                                   sampled_token_ids=sampled_token_ids,
+                                   input_positions=model_input.input_positions,
+                                   seq_lens=self.seq_lens_tensor,
+                                   slot_mapping=self.slot_mapping,
+                                   block_tables=self.block_tables)
+
 
 class ROCmFlashAttentionMetadataBuilder(
         CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
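
The new `advance_step` method validates the decode-only metadata and then delegates the actual in-place update to the fused `ops.advance_step_flashattn` kernel. As a rough, non-authoritative reference of what that update amounts to (assumed semantics, plain PyTorch indexing instead of the fused op, CUDA-graph padding ignored):

```python
import torch

def advance_step_reference(num_queries: int,
                           block_size: int,
                           input_tokens: torch.Tensor,       # [num_seqs]
                           sampled_token_ids: torch.Tensor,  # [num_queries, 1] (assumed shape)
                           input_positions: torch.Tensor,    # [num_seqs]
                           seq_lens: torch.Tensor,           # [num_seqs]
                           slot_mapping: torch.Tensor,       # [num_seqs]
                           block_tables: torch.Tensor) -> None:  # [num_seqs, max_blocks]
    """Assumed, simplified semantics of the fused advance-step op."""
    q = torch.arange(num_queries, device=input_tokens.device)
    # The token sampled at step t becomes the input token of step t+1.
    input_tokens[q] = sampled_token_ids[q, 0]
    # With 0-based positions, the next token sits at index old_seq_len.
    input_positions[q] = seq_lens[q]
    seq_lens[q] += 1
    # Map the new position to its KV-cache slot via the block table.
    pos = input_positions[q]
    slot_mapping[q] = (block_tables[q, pos // block_size] * block_size
                       + pos % block_size)
```

This mirrors the bookkeeping the assertions above guard: only decode requests, one query token per sequence, and tensors sized for `num_seqs` with the first `num_queries` entries holding real requests.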

vllm/worker/multi_step_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 
 logger = init_logger(__name__)
 
-MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
+MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
 
 
 def seq_output_builder():
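
The second change is a one-line whitelist update so the multi-step model runner accepts the ROCm backend. A hedged sketch of the kind of guard such a whitelist backs (hypothetical helper; the actual check in vLLM may be phrased differently):

```python
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]

def assert_multi_step_supported(backend_name: str) -> None:
    # Illustrative guard: multi-step needs a backend whose attention metadata
    # can be advanced in place between decode steps (see advance_step above).
    if backend_name not in MULTI_STEP_ATTENTION_BACKENDS:
        raise ValueError(
            f"Multi-step scheduling is not supported with the {backend_name!r} "
            f"attention backend; supported: {MULTI_STEP_ATTENTION_BACKENDS}")

assert_multi_step_supported("rocm-flash-attn")  # allowed after this commit
```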
