Commit ee659e3

[Bugfix][ROCm] Use chunked_prefill_paged_decode as fallback for V1 attention on ROCm (#18093)
Signed-off-by: kf <kuanfu.liu@embeddedllm.com>
1 parent 4e1c6a0 commit ee659e3
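
In short: the patch derives the grouped-query ratio (query heads per KV head) from the tensor shapes and, whenever that ratio is not a power of two, routes the V1 Triton backend through chunked_prefill_paged_decode instead of the unified Triton attention kernel. Below is a minimal sketch of that selection heuristic; the helper name and the example head counts are invented for illustration and are not part of the patch.

# Illustrative sketch only: the helper name and the head counts below are
# hypothetical; the patch performs this check inline in forward().
def needs_prefill_decode_fallback(num_query_heads: int, num_kv_heads: int) -> bool:
    num_queries_per_kv = num_query_heads // num_kv_heads
    # (n & (n - 1)) is zero only when n is a power of two (or zero), so a
    # non-zero result marks GQA ratios the patch routes to the fallback kernel.
    return (num_queries_per_kv & (num_queries_per_kv - 1)) != 0

for q_heads, kv_heads in [(32, 8), (32, 4), (40, 8), (28, 4)]:
    print(q_heads, kv_heads, needs_prefill_decode_fallback(q_heads, kv_heads))
# 32/8=4 and 32/4=8 stay on unified_attention (False);
# 40/8=5 and 28/4=7 fall back to chunked_prefill_paged_decode (True).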

File tree

1 file changed (+77, -32 lines)

vllm/v1/attention/backends/triton_attn.py

Lines changed: 77 additions & 32 deletions
@@ -7,6 +7,9 @@
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.ops.chunked_prefill_paged_decode import (
+    chunked_prefill_paged_decode)
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.attention.ops.triton_unified_attention import unified_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -162,19 +165,40 @@ def forward(
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.
 
+        num_queries_per_kv = query.shape[1] // key.shape[1]
+        use_prefill_decode_attn = (num_queries_per_kv &
+                                   (num_queries_per_kv - 1)) != 0
+
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if use_prefill_decode_attn:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+        else:
+            key_cache, value_cache = kv_cache.unbind(0)
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
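
The two cache-write paths above view kv_cache differently: the fallback path splits it with PagedAttention.split_kv_cache and fills it with write_to_paged_cache, while the original path keeps the stacked layout, splitting it with kv_cache.unbind(0) before reshape_and_cache_flash. The toy snippet below only illustrates the unbind behaviour the second path relies on; the shapes are invented and do not reflect vLLM's actual cache layout.

import torch

# Toy stacked cache: dim 0 separates K and V. The remaining dimensions are
# placeholders and do NOT match vLLM's real kv_cache layout.
kv_cache = torch.zeros(2, 4, 16, 8)

key_cache, value_cache = kv_cache.unbind(0)  # two views, no copy

# Writes through the views land in the stacked tensor, which is why
# reshape_and_cache_flash can fill kv_cache via key_cache / value_cache.
key_cache[0, 0, 0] = 1.0
print(kv_cache[0, 0, 0, 0].item())                  # 1.0
print(key_cache.data_ptr() == kv_cache.data_ptr())  # True: shared storage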
@@ -209,26 +233,47 @@ def forward(
         max_seqlen_k = attn_metadata.max_seq_len
         block_table = attn_metadata.block_table
 
-        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-
-        unified_attention(
-            q=query[:num_actual_tokens],
-            k=key_cache,
-            v=value_cache,
-            out=output[:num_actual_tokens],
-            cu_seqlens_q=cu_seqlens_q,
-            max_seqlen_q=max_seqlen_q,
-            seqused_k=seqused_k,
-            max_seqlen_k=max_seqlen_k,
-            softmax_scale=self.scale,
-            causal=True,
-            alibi_slopes=self.alibi_slopes,
-            window_size=self.sliding_window,
-            block_table=block_table,
-            softcap=self.logits_soft_cap,
-            q_descale=None,  # Not supported
-            k_descale=layer._k_scale.expand(descale_shape),
-            v_descale=layer._v_scale.expand(descale_shape),
-        )
+        if use_prefill_decode_attn:
+            # Compute attention and update output up to `num_actual_tokens`.
+            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
+                                         key=key[:num_actual_tokens],
+                                         value=value[:num_actual_tokens],
+                                         output=output[:num_actual_tokens],
+                                         kv_cache_dtype=self.kv_cache_dtype,
+                                         key_cache=key_cache,
+                                         value_cache=value_cache,
+                                         block_table=block_table,
+                                         query_start_loc=cu_seqlens_q,
+                                         seq_lens=seqused_k,
+                                         max_seq_len=max_seqlen_k,
+                                         max_query_len=max_seqlen_q,
+                                         k_scale=layer._k_scale,
+                                         v_scale=layer._v_scale,
+                                         alibi_slopes=self.alibi_slopes,
+                                         sliding_window=self.sliding_window[0],
+                                         sm_scale=self.scale)
+
+        else:
+            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+
+            unified_attention(
+                q=query[:num_actual_tokens],
+                k=key_cache,
+                v=value_cache,
+                out=output[:num_actual_tokens],
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                seqused_k=seqused_k,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.sliding_window,
+                block_table=block_table,
+                softcap=self.logits_soft_cap,
+                q_descale=None,  # Not supported
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
+            )
 
         return output
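
On the unified_attention path, descale_shape is (cu_seqlens_q.shape[0] - 1, key.shape[1]), i.e. one entry per sequence and per KV head, and layer._k_scale / layer._v_scale are expanded to that shape. Tensor.expand returns a broadcast view rather than copying data, so passing per-sequence, per-head descales stays cheap. A small stand-alone illustration with invented sizes:

import torch

num_seqs, num_kv_heads = 3, 8     # invented sizes, for illustration only
k_scale = torch.tensor([0.05])    # stands in for layer._k_scale

descale_shape = (num_seqs, num_kv_heads)
k_descale = k_scale.expand(descale_shape)  # broadcast view, no data copy

print(k_descale.shape)     # torch.Size([3, 8])
print(k_descale.stride())  # (0, 0): every element aliases the same scale value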
