Commit 2ede1a2 (parent: be633fb)

Added HPU support for Automatic Prefix Caching

Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai>

3 files changed: +146 / -56 lines

vllm/attention/backends/hpu_attn.py (20 additions, 12 deletions)

```diff
@@ -57,16 +57,16 @@ def get_kv_cache_shape(
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
-        HPUPagedAttention.copy_blocks(kv_caches, src_to_dists)
+        HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)
 
 
 @dataclass
@@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     is_prompt: bool
     attn_bias: Optional[torch.Tensor]
     seq_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor]
 
 
 class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
@@ -198,8 +199,7 @@ def forward(
         key_cache = None
         value_cache = None
         if attn_metadata.is_prompt and self.attn_type \
-            is not AttentionType.ENCODER_ONLY \
-            and attn_metadata.block_list is None:
+            is not AttentionType.ENCODER_ONLY:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -229,6 +229,9 @@ def forward(
                     attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1))
                     attn_bias.add_(position_bias)
 
+            block_list = attn_metadata.block_list if attn_metadata \
+                and attn_metadata.block_list is not None else None
+
             out = ops.prompt_attention(
                 impl=self.prefill_impl,
                 query=query.view(query_shape),
@@ -237,23 +240,25 @@ def forward(
                 is_causal=True,
                 attn_bias=attn_bias,
                 valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                **self.common_attention_args())
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache))
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                key_cache=key_cache,
-                value_cache=value_cache,
-                block_list=attn_metadata.block_list,
                 block_mapping=attn_metadata.block_mapping,
                 block_bias=attn_metadata.attn_bias,
                 block_groups=attn_metadata.block_groups,
-                **self.common_attention_args())
+                **self.common_attention_args(attn_metadata.block_list,
+                                             key_cache, value_cache))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
 
-    def common_attention_args(self):
+    def common_attention_args(self,
+                              block_list=None,
+                              key_cache=None,
+                              value_cache=None):
         fsdpa_op = self.fused_scaled_dot_product_attention.apply \
             if self.fused_scaled_dot_product_attention is not None else None
         return {
@@ -266,6 +271,9 @@ def common_attention_args(self):
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
+            'block_list': block_list,
+            'key_cache': key_cache,
+            'value_cache': value_cache,
         }
```
vllm/attention/ops/hpu_paged_attn.py (9 additions, 27 deletions)

```diff
@@ -5,7 +5,7 @@
 ###############################################################################
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from vllm_hpu_extension import cache_ops, ops
@@ -63,43 +63,25 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor,
     def forward_decode(**kwargs) -> torch.Tensor:
         return ops.flat_pa(**kwargs)
 
-    @staticmethod
-    def forward_prefix(
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        subquery_start_loc: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_query_len: int,
-        alibi_slopes: Optional[torch.Tensor],
-        sliding_window: Optional[int],
-    ) -> torch.Tensor:
-        raise NotImplementedError(
-            "forward_prefix is not implemented for HPUPagedAttention")
-
     @staticmethod
     def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
-        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts)
 
         src_value_cache = src_kv_cache[1]
         dst_value_cache = dst_kv_cache[1]
-        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts)
 
     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dsts: torch.Tensor,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
         value_caches = [kv_cache[1] for kv_cache in kv_caches]
-        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
```
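Both `swap_blocks` and `copy_blocks` now accept a `torch.Tensor` mapping (`src_to_dsts`) instead of Python dicts. The exact layout expected by `vllm_hpu_extension.cache_ops` is not shown in this diff; assuming the usual vLLM convention of an `(N, 2)` integer tensor of `[src_block, dst_block]` rows, a CPU-only reference of the same idea could look like this:

```python
# Illustrative CPU-only sketch of a tensor-driven block copy. The real
# semantics live in vllm_hpu_extension.cache_ops; here we assume an (N, 2)
# int64 tensor whose rows are [src_block, dst_block] pairs.
import torch


def swap_blocks_reference(src_cache: torch.Tensor, dst_cache: torch.Tensor,
                          src_to_dsts: torch.Tensor) -> None:
    # src_cache / dst_cache: (num_blocks, block_size, num_heads, head_dim)
    for src, dst in src_to_dsts.tolist():
        dst_cache[dst].copy_(src_cache[src])


src_cache = torch.arange(4 * 2 * 1 * 2, dtype=torch.float32).view(4, 2, 1, 2)
dst_cache = torch.zeros_like(src_cache)
mapping = torch.tensor([[0, 3], [2, 1]], dtype=torch.int64)

swap_blocks_reference(src_cache, dst_cache, mapping)
assert torch.equal(dst_cache[3], src_cache[0])
assert torch.equal(dst_cache[1], src_cache[2])
```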

vllm/worker/hpu_model_runner.py (117 additions, 17 deletions)

```diff
@@ -14,7 +14,7 @@
 import os
 import time
 from array import array
-from enum import IntEnum
+from enum import Enum, IntEnum
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
@@ -75,6 +75,12 @@
 DUMMY_TOKEN_ID = -1
 
 
+class PhaseType(Enum):
+    PREFILL = 'prefill'
+    PREFIX_PREFILL = 'prefix_prefill'
+    DECODE = 'decode'
+
+
 def subtuple(obj: object,
              typename: str,
              to_copy: List[str],
@@ -213,20 +219,40 @@ def _compile_region(self, model, name, module):
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
-        prefill_metadata = attn_metadata
-        if prefill_metadata is None or self.prefill_use_fusedsdpa:
+        if (attn_metadata is None
+                or (self.prefill_use_fusedsdpa \
+                    and attn_metadata.block_list is None)
+                or not attn_metadata.is_prompt):
             return attn_metadata
 
+        prefill_metadata = attn_metadata
+
         seq_lens_t = prefill_metadata.seq_lens_tensor
+        context_lens_t = prefill_metadata.context_lens_tensor
+        query_lens_t = seq_lens_t - context_lens_t
+
+        block_list = attn_metadata.block_list
+        max_context_len = (block_list.size(-1) //
+                           batch_size if block_list is not None else 0)
+        max_context_len = max_context_len * self.block_size
+        past_mask = torch.arange(0,
+                                 max_context_len,
+                                 dtype=torch.int32,
+                                 device=device)
+        past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
+            context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
+                batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))
+
         len_mask = (torch.arange(0, seq_len, device=device,
                                  dtype=torch.int32).view(1, seq_len).ge(
-                                     seq_lens_t.unsqueeze(-1)).view(
+                                     query_lens_t.unsqueeze(-1)).view(
                                          batch_size, 1, 1, seq_len))
         causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                             device=device,
                                             dtype=torch.bool),
                                  diagonal=1)
         mask = causal_mask.logical_or(len_mask)
+        mask = torch.concat((past_mask, mask), dim=-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
         attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
```
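The reworked `_set_attn_bias` is the core of the prefix-prefill path: the bias now covers the cached context columns (gated by `context_lens_tensor`) concatenated with a causal mask over the new query tokens (gated by `query_lens_t`). The toy-sized sketch below replays that construction on the CPU with made-up lengths, so the resulting `(batch, 1, seq_len, context + seq_len)` shape is easy to inspect:

```python
# Standalone sketch of the prefix-caching attention bias built in
# _set_attn_bias: each new (query) token may attend to the valid cached
# context plus a causal window over the new tokens. All sizes are toy
# numbers, not taken from a real run.
import math

import torch

batch_size, seq_len, block_size = 2, 4, 2          # padded query length = 4
seq_lens_t = torch.tensor([7, 5])                  # context + new tokens
context_lens_t = torch.tensor([4, 2])              # tokens served from cache
query_lens_t = seq_lens_t - context_lens_t         # -> [3, 3]

# One padded row of prefix blocks per sequence, as in prefix_block_list.
block_list = torch.tensor([11, 12, 21, 0])         # 2 blocks per sequence
max_context_len = (block_list.size(-1) // batch_size) * block_size  # -> 4

past_mask = torch.arange(0, max_context_len, dtype=torch.int32)
past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge(
    context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand(
        batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1))

len_mask = (torch.arange(0, seq_len, dtype=torch.int32).view(1, seq_len).ge(
    query_lens_t.unsqueeze(-1)).view(batch_size, 1, 1, seq_len))
causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len),
                                    dtype=torch.bool),
                         diagonal=1)
mask = causal_mask.logical_or(len_mask)
mask = torch.concat((past_mask, mask), dim=-1)     # (B, 1, S, C + S)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(
    mask, -math.inf)

print(attn_bias.shape)   # torch.Size([2, 1, 4, 8])
print(attn_bias[0, 0])   # columns 0..3: cached context, 4..7: new tokens
```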
```diff
@@ -517,6 +543,11 @@ def __init__(
                                            False, self.max_model_len)
         self.graphed_buckets: Set[Any] = set()
         self._set_gc_threshold()
+        if self.vllm_config.cache_config.enable_prefix_caching:
+            os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
+            assert os.environ.get(
+                "VLLM_CONTIGUOUS_PA",
+                "").lower() != "true", "Contiguous PA doesn't support APC"
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
 
         # For multi-step scheduling
```
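This guard defaults `VLLM_CONTIGUOUS_PA` to "False" whenever automatic prefix caching (APC) is requested and refuses to run if it has been forced to "true". A hypothetical end-to-end usage sketch on Gaudi follows; the model name is only an example and running it requires the HPU-enabled vLLM build:

```python
# Hypothetical usage sketch: enabling vLLM's automatic prefix caching (APC)
# on Gaudi, where this change also keeps contiguous paged attention off
# because the two features are incompatible (see the assert in the diff).
import os

from vllm import LLM, SamplingParams

# Any model path works here; this one is just an example.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          enable_prefix_caching=True)

# On HPU the model runner has already called:
#   os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False")
assert os.environ.get("VLLM_CONTIGUOUS_PA", "").lower() != "true"

# Requests that share a long prefix benefit from the cached blocks.
shared_prefix = "You are a helpful assistant. " * 50
prompts = [shared_prefix + q for q in ("What is APC?", "How do HPUs differ?")]
outputs = llm.generate(prompts, SamplingParams(max_tokens=32))
```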
```diff
@@ -702,6 +733,10 @@ def _prepare_prompt(
                     computed_block_nums) > 0 and self.sliding_window is None:
                 # Prefix is not supported with sliding_window
                 context_len = len(computed_block_nums) * self.block_size
+                if context_len == seq_len \
+                        and self.vllm_config.cache_config.enable_prefix_caching:
+                    # Fully cached prompt - compute only last token
+                    context_len = context_len - 1
                 prompt_tokens = prompt_tokens[context_len:]
                 prefix_block_tables.append(computed_block_nums)
             elif self.scheduler_config.chunked_prefill_enabled:
@@ -779,12 +814,33 @@ def _prepare_prompt(
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)
 
-            lora_index_mapping += [lora_id] * (max_prompt_len - context_len)
+            lora_index_mapping += [lora_id] * max_prompt_len
             lora_prompt_mapping.extend(
                 [lora_id] *
-                (max_prompt_len - context_len
+                (max_prompt_len
                  if seq_group_metadata.sampling_params.prompt_logprobs else 1))
 
+        if any(context_lens):
+            assert not self.scheduler_config.chunked_prefill_enabled
+            # prefix caching
+
+            max_num_block = max(len(bt) for bt in prefix_block_tables)
+            prefix_block_list = list(
+                itertools.chain.from_iterable(
+                    bt if len(bt) == max_num_block else bt +
+                    ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
+                    for bt in prefix_block_tables))
+
+            pad_len = len(prefix_block_list)
+            prefix_block_list = pad_list(prefix_block_list, pad_len,
+                                         _PAD_BLOCK_ID)
+
+            prefix_block_list_tensor = torch.tensor(prefix_block_list,
+                                                    dtype=torch.long,
+                                                    device=self.device)
+        else:
+            prefix_block_list_tensor = None
+
         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
                                             pad=0,
```
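With prefix caching active, `_prepare_prompt` packs the per-sequence computed block tables into one flat, right-padded list whose width (`block_list.size(-1) // batch_size`) later determines `max_context_len` in `_set_attn_bias`. A toy reconstruction of that packing; `_PAD_BLOCK_ID` is a stand-in constant here and the repo's `pad_list` helper is omitted:

```python
# Toy reconstruction of the prefix block-table packing done in
# _prepare_prompt: ragged per-sequence block tables are right-padded to the
# same width and flattened into one long tensor.
import itertools

import torch

_PAD_BLOCK_ID = 0  # stand-in for the runner's pad block id

prefix_block_tables = [[11, 12, 13], [21], [31, 32]]   # cached blocks per seq

max_num_block = max(len(bt) for bt in prefix_block_tables)
prefix_block_list = list(
    itertools.chain.from_iterable(
        bt if len(bt) == max_num_block else bt +
        ([_PAD_BLOCK_ID] * (max_num_block - len(bt)))
        for bt in prefix_block_tables))

prefix_block_list_tensor = torch.tensor(prefix_block_list, dtype=torch.long)
print(prefix_block_list_tensor.view(len(prefix_block_tables), max_num_block))
# tensor([[11, 12, 13],
#         [21,  0,  0],
#         [31, 32,  0]])
```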
```diff
@@ -807,18 +863,23 @@ def _prepare_prompt(
                                            dtype=torch.long,
                                            device=self.device)
 
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.long,
+                                           device=self.device)
+
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
-            block_list=None,
+            block_list=prefix_block_list_tensor,
             block_mapping=None,
             block_usage=None,
             block_indices=block_indices,
             block_offsets=block_offsets,
             block_groups=None,
             attn_bias=None,
             seq_lens_tensor=seq_lens_tensor,
+            context_lens_tensor=context_lens_tensor,
             num_prefills=real_num_seqs,
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
@@ -987,6 +1048,7 @@ def _prepare_decode(
             block_groups=block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
+            context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
@@ -1091,7 +1153,7 @@ def prepare_input_tensors(
             # FIXME: We need to adjust selected_token_indices to accommodate
             # for padding
             max_len = input_tokens.size(1)
-            paddings = [max_len - s for s in seq_lens]
+            paddings = [max_len - q for q in query_lens]
             paddings = [0] + paddings[:-1]
             paddings = list(itertools.accumulate(paddings))
             paddings_prompt_logprobs = []
@@ -1187,9 +1249,17 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash(123) != input_hash(321)
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
-            'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-            'block_offsets', 'block_groups'
+            'attn_bias',
+            'seq_lens_tensor',
+            'context_lens_tensor',
+            'block_list',
+            'block_mapping',
+            'block_usage',
+            'slot_mapping',
+            'is_prompt',
+            'block_indices',
+            'block_offsets',
+            'block_groups',
         ])
         return attention_metadata
 
@@ -1733,14 +1803,44 @@ def finish_measurements(self):
         from neural_compressor.torch.quantization import finalize_calibration
         finalize_calibration(self.model.model)
 
-    def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
-        cfg = (batch_size, seq_len, is_prompt)
+    def _num_blocks(self, attn_metadata):
+        if attn_metadata.block_list is None:
+            return 0
+        return attn_metadata.block_list.numel()
+
+    def _phase(self, attn_metadata):
+        phase_type: PhaseType
+        is_prompt = attn_metadata.is_prompt
+        is_prefix_prefill = is_prompt and attn_metadata.block_list is not None
+        if is_prompt and is_prefix_prefill:
+            phase_type = PhaseType.PREFIX_PREFILL
+        elif is_prompt and not is_prefix_prefill:
+            phase_type = PhaseType.PREFILL
+        elif not is_prompt:
+            phase_type = PhaseType.DECODE
+        else:
+            raise ValueError("Unrecognized pass type, likely due to malformed "
+                             "attention metadata")
+        return phase_type
+
+    def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode):
+        is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
+        cfg: Optional[tuple] = None
+        assert cfg is None, "Configs changed between 2D and 3D"
+        if is_prefix_caching:
+            phase = self._phase(attn_metadata)
+            num_blocks = self._num_blocks(attn_metadata)
+            cfg = (batch_size, seq_len, num_blocks, phase)
+        else:
+            phase = 'prompt' if attn_metadata.is_prompt else 'decode'
+            cfg = (batch_size, seq_len, phase)
         seen = cfg in self.seen_configs
         self.seen_configs.add(cfg)
         if not seen and not warmup_mode:
-            phase = 'prompt' if is_prompt else 'decode'
-            logger.warning("Configuration: (%s, %s, %s) was not warmed-up!",
-                           phase, batch_size, seq_len)
+            logger.warning("Configuration: %s was not warmed-up!",
+                           (phase.value, batch_size, seq_len,
                            num_blocks) if is_prefix_caching else
+                           (phase, batch_size, seq_len))
 
     def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
                          is_prompt: bool):
```
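Warm-up bookkeeping also changes: with prefix caching the configuration key grows from `(batch_size, seq_len, phase)` to `(batch_size, seq_len, num_blocks, phase)`, where the new `PhaseType.PREFIX_PREFILL` separates prompts that reuse cached blocks from plain prefills. A condensed sketch of that bucketing logic; the metadata stub below is hypothetical, while the real object is the trimmed HPU attention metadata:

```python
# Condensed sketch of the new warm-up bookkeeping with prefix caching.
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch


class PhaseType(Enum):
    PREFILL = 'prefill'
    PREFIX_PREFILL = 'prefix_prefill'
    DECODE = 'decode'


@dataclass
class AttnMetaStub:
    # Hypothetical stand-in for the trimmed HPU attention metadata.
    is_prompt: bool
    block_list: Optional[torch.Tensor]


def phase(meta: AttnMetaStub) -> PhaseType:
    if meta.is_prompt:
        return (PhaseType.PREFIX_PREFILL
                if meta.block_list is not None else PhaseType.PREFILL)
    return PhaseType.DECODE


def config_key(batch_size: int, seq_len: int, meta: AttnMetaStub,
               prefix_caching: bool) -> tuple:
    if prefix_caching:
        num_blocks = 0 if meta.block_list is None else meta.block_list.numel()
        return (batch_size, seq_len, num_blocks, phase(meta))
    return (batch_size, seq_len, 'prompt' if meta.is_prompt else 'decode')


seen = set()
meta = AttnMetaStub(is_prompt=True, block_list=torch.tensor([11, 12, 21, 0]))
cfg = config_key(batch_size=2, seq_len=4, meta=meta, prefix_caching=True)
if cfg not in seen:
    seen.add(cfg)
    print(f"Configuration {cfg} was not warmed-up!")
```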
```diff
@@ -1912,7 +2012,7 @@ def execute_model(
         batch_size = input_tokens.size(0)
         seq_len = self._seq_len(attn_metadata)
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
-        self._check_config(batch_size, seq_len, is_prompt, warmup_mode)
+        self._check_config(batch_size, seq_len, attn_metadata, warmup_mode)
 
         lora_mask: torch.Tensor = None
         lora_logits_mask: torch.Tensor = None
```
