7 | 7 | from vllm import _custom_ops as ops
8 | 8 | from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
9 | 9 |                                                AttentionMetadata, AttentionType)
 | 10 | +from vllm.attention.ops.chunked_prefill_paged_decode import (
 | 11 | +    chunked_prefill_paged_decode)
 | 12 | +from vllm.attention.ops.paged_attn import PagedAttention
10 | 13 | from vllm.attention.ops.triton_unified_attention import unified_attention
11 | 14 | from vllm.logger import init_logger
12 | 15 | from vllm.platforms import current_platform
@@ -162,19 +165,40 @@ def forward(
162 | 165 |         # Whenever making a change in this method, please benchmark the
163 | 166 |         # performance to make sure it does not introduce any overhead.
164 | 167 |
 | 168 | +        num_queries_per_kv = query.shape[1] // key.shape[1]
 | 169 | +        use_prefill_decode_attn = (num_queries_per_kv &
 | 170 | +                                   (num_queries_per_kv - 1)) != 0
 | 171 | +
165 | 172 |         num_actual_tokens = attn_metadata.num_actual_tokens
166 | 173 |
167 |  | -        key_cache, value_cache = kv_cache.unbind(0)
168 |  | -        torch.ops._C_cache_ops.reshape_and_cache_flash(
169 |  | -            key,
170 |  | -            value,
171 |  | -            key_cache,
172 |  | -            value_cache,
173 |  | -            attn_metadata.slot_mapping,
174 |  | -            self.kv_cache_dtype,
175 |  | -            layer._k_scale,
176 |  | -            layer._v_scale,
177 |  | -        )
 | 174 | +        if use_prefill_decode_attn:
 | 175 | +            key_cache, value_cache = PagedAttention.split_kv_cache(
 | 176 | +                kv_cache, self.num_kv_heads, self.head_size)
 | 177 | +
 | 178 | +            # Reshape the input keys and values and store them in the cache.
 | 179 | +            PagedAttention.write_to_paged_cache(
 | 180 | +                key,
 | 181 | +                value,
 | 182 | +                key_cache,
 | 183 | +                value_cache,
 | 184 | +                attn_metadata.slot_mapping,
 | 185 | +                self.kv_cache_dtype,
 | 186 | +                layer._k_scale,
 | 187 | +                layer._v_scale,
 | 188 | +            )
 | 189 | +
 | 190 | +        else:
 | 191 | +            key_cache, value_cache = kv_cache.unbind(0)
 | 192 | +            torch.ops._C_cache_ops.reshape_and_cache_flash(
 | 193 | +                key,
 | 194 | +                value,
 | 195 | +                key_cache,
 | 196 | +                value_cache,
 | 197 | +                attn_metadata.slot_mapping,
 | 198 | +                self.kv_cache_dtype,
 | 199 | +                layer._k_scale,
 | 200 | +                layer._v_scale,
 | 201 | +            )
178 | 202 |
179 | 203 |         if self.kv_cache_dtype.startswith("fp8"):
180 | 204 |             key_cache = key_cache.view(self.fp8_dtype)
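The dispatch added above relies on a standard bit trick: `(n & (n - 1)) != 0` is true exactly when `n` is not a power of two, so `use_prefill_decode_attn` sends GQA shapes whose query-heads-per-KV-head ratio is not a power of two down the chunked-prefill/paged-decode path, while power-of-two ratios stay on the unified Triton kernel. A minimal sketch of the predicate, with a hypothetical helper name and illustrative head counts that are not part of the diff:

```python
# Hypothetical helper illustrating the dispatch predicate added above.
# `n & (n - 1)` clears the lowest set bit, so the result is non-zero
# exactly when n has more than one set bit, i.e. n is not a power of two.
def use_prefill_decode_attn(num_query_heads: int, num_kv_heads: int) -> bool:
    num_queries_per_kv = num_query_heads // num_kv_heads
    return (num_queries_per_kv & (num_queries_per_kv - 1)) != 0

assert not use_prefill_decode_attn(32, 8)  # 32 // 8 = 4, a power of two -> unified kernel
assert use_prefill_decode_attn(40, 8)      # 40 // 8 = 5, not a power of two -> prefill/decode path
```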
@@ -209,26 +233,47 @@ def forward(
209 | 233 |         max_seqlen_k = attn_metadata.max_seq_len
210 | 234 |         block_table = attn_metadata.block_table
211 | 235 |
212 |  | -        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
213 |  | -
214 |  | -        unified_attention(
215 |  | -            q=query[:num_actual_tokens],
216 |  | -            k=key_cache,
217 |  | -            v=value_cache,
218 |  | -            out=output[:num_actual_tokens],
219 |  | -            cu_seqlens_q=cu_seqlens_q,
220 |  | -            max_seqlen_q=max_seqlen_q,
221 |  | -            seqused_k=seqused_k,
222 |  | -            max_seqlen_k=max_seqlen_k,
223 |  | -            softmax_scale=self.scale,
224 |  | -            causal=True,
225 |  | -            alibi_slopes=self.alibi_slopes,
226 |  | -            window_size=self.sliding_window,
227 |  | -            block_table=block_table,
228 |  | -            softcap=self.logits_soft_cap,
229 |  | -            q_descale=None,  # Not supported
230 |  | -            k_descale=layer._k_scale.expand(descale_shape),
231 |  | -            v_descale=layer._v_scale.expand(descale_shape),
232 |  | -        )
 | 236 | +        if use_prefill_decode_attn:
 | 237 | +            # Compute attention and update output up to `num_actual_tokens`.
 | 238 | +            chunked_prefill_paged_decode(query=query[:num_actual_tokens],
 | 239 | +                                         key=key[:num_actual_tokens],
 | 240 | +                                         value=value[:num_actual_tokens],
 | 241 | +                                         output=output[:num_actual_tokens],
 | 242 | +                                         kv_cache_dtype=self.kv_cache_dtype,
 | 243 | +                                         key_cache=key_cache,
 | 244 | +                                         value_cache=value_cache,
 | 245 | +                                         block_table=block_table,
 | 246 | +                                         query_start_loc=cu_seqlens_q,
 | 247 | +                                         seq_lens=seqused_k,
 | 248 | +                                         max_seq_len=max_seqlen_k,
 | 249 | +                                         max_query_len=max_seqlen_q,
 | 250 | +                                         k_scale=layer._k_scale,
 | 251 | +                                         v_scale=layer._v_scale,
 | 252 | +                                         alibi_slopes=self.alibi_slopes,
 | 253 | +                                         sliding_window=self.sliding_window[0],
 | 254 | +                                         sm_scale=self.scale)
 | 255 | +
 | 256 | +        else:
 | 257 | +            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 | 258 | +
 | 259 | +            unified_attention(
 | 260 | +                q=query[:num_actual_tokens],
 | 261 | +                k=key_cache,
 | 262 | +                v=value_cache,
 | 263 | +                out=output[:num_actual_tokens],
 | 264 | +                cu_seqlens_q=cu_seqlens_q,
 | 265 | +                max_seqlen_q=max_seqlen_q,
 | 266 | +                seqused_k=seqused_k,
 | 267 | +                max_seqlen_k=max_seqlen_k,
 | 268 | +                softmax_scale=self.scale,
 | 269 | +                causal=True,
 | 270 | +                alibi_slopes=self.alibi_slopes,
 | 271 | +                window_size=self.sliding_window,
 | 272 | +                block_table=block_table,
 | 273 | +                softcap=self.logits_soft_cap,
 | 274 | +                q_descale=None,  # Not supported
 | 275 | +                k_descale=layer._k_scale.expand(descale_shape),
 | 276 | +                v_descale=layer._v_scale.expand(descale_shape),
 | 277 | +            )
233 | 278 |
234 | 279 |         return output
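In the retained unified-attention branch, `descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])` works out to (number of sequences in the batch, number of KV heads), and `layer._k_scale.expand(descale_shape)` / `layer._v_scale.expand(descale_shape)` broadcast the per-tensor FP8 scales to that shape as views, without copying data. A minimal sketch of that broadcast, with made-up sizes and a 1-element scale tensor assumed for illustration:

```python
import torch

# Assumed toy sizes: 3 sequences, 8 KV heads; the scale is assumed to be a
# 1-element per-tensor value, standing in for layer._k_scale in the diff.
cu_seqlens_q = torch.tensor([0, 5, 9, 16])   # num_seqs + 1 cumulative query offsets
num_kv_heads = 8
k_scale = torch.tensor([0.02])

descale_shape = (cu_seqlens_q.shape[0] - 1, num_kv_heads)   # (3, 8)
k_descale = k_scale.expand(descale_shape)                   # broadcast view, no copy

assert k_descale.shape == (3, 8)
assert k_descale[2, 7].item() == k_scale.item()
```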