
Commit f00b84f

comaniac authored and jimpang committed
[Core] Refactor _prepare_model_input_tensors - take 2 (vllm-project#6164)
1 parent 63b9c10 commit f00b84f

File tree

12 files changed, +1050 −470 lines changed

tests/worker/test_model_input.py

Lines changed: 5 additions & 1 deletion
@@ -3,7 +3,7 @@
 
 import torch
 
-from vllm.attention import AttentionMetadata
+from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -26,6 +26,10 @@ def get_impl_cls():
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return AttentionMetadata
 
+    @staticmethod
+    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+        return AttentionMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

vllm/attention/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,12 +1,14 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadata)
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder)
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 
 __all__ = [
     "Attention",
     "AttentionBackend",
     "AttentionMetadata",
+    "AttentionMetadataBuilder",
     "Attention",
     "get_attn_backend",
 ]

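With the re-export above, the builder type is importable from the package root alongside the existing attention names. A minimal illustration; the alias on the last line is an example, not code from this commit:

from vllm.attention import (AttentionBackend, AttentionMetadata,
                            AttentionMetadataBuilder)

# AttentionMetadataBuilder is generic over the metadata type it produces,
# so it can appear in annotations of backend-agnostic helper code.
BuilderOfBaseMetadata = AttentionMetadataBuilder[AttentionMetadata]
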
vllm/attention/backends/abstract.py

Lines changed: 43 additions & 2 deletions
@@ -1,11 +1,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, fields
 from enum import Enum, auto
-from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
+                    Tuple, Type, TypeVar)
 
 import torch
 
+if TYPE_CHECKING:
+    from vllm.sequence import SequenceGroupMetadata
+    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
+
 
 class AttentionType(Enum):
     DECODER = auto()  # Decoder attention between previous layer Q/K/V
@@ -35,6 +39,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
 
+    @staticmethod
+    @abstractmethod
+    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
+        raise NotImplementedError
+
+    @classmethod
+    def make_metadata_builder(cls, *args,
+                              **kwargs) -> "AttentionMetadataBuilder":
+        return cls.get_builder_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
@@ -110,6 +124,33 @@ def asdict_zerocopy(self,
 T = TypeVar("T", bound=AttentionMetadata)
 
 
+class AttentionMetadataBuilder(ABC, Generic[T]):
+    """Abstract class for attention metadata builders."""
+
+    @abstractmethod
+    def __init__(self, input_builder) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
+                      token_lens: List[int], seq_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
+                      context_lens: List[int],
+                      curr_sliding_window_blocks: List[int],
+                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata and update
+        corresponding fields (in Python objects).
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
+              query_lens: List[int], cuda_graph_pad_size: int,
+              batch_size: int) -> T:
+        """Build attention metadata with on-device tensors."""
+        raise NotImplementedError
+
+
 class AttentionImpl(ABC, Generic[T]):
 
     @abstractmethod

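The AttentionMetadataBuilder ABC above separates cheap per-sequence-group bookkeeping (add_seq_group) from one-shot tensor materialization (build), and AttentionBackend.make_metadata_builder lets the model runner obtain a backend-specific builder without naming its class. A minimal sketch of a hypothetical backend wired into this interface; everything named Dummy* is an illustration, not code from this commit:

from typing import List, Type

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)


class DummyMetadataBuilder(AttentionMetadataBuilder[AttentionMetadata]):
    """Hypothetical builder: accumulates Python-level state per sequence
    group, then emits metadata once in build()."""

    def __init__(self, input_builder) -> None:
        self.input_builder = input_builder
        self.seq_lens: List[int] = []

    def add_seq_group(self, seq_group_metadata, token_lens, seq_lens,
                      curr_seq_lens, query_lens, context_lens,
                      curr_sliding_window_blocks, prefix_cache_hit,
                      chunked_prefill_enabled):
        # Cheap bookkeeping only; no tensors are created here.
        self.seq_lens.extend(seq_lens)

    def build(self, runner, seq_lens, query_lens, cuda_graph_pad_size,
              batch_size):
        # A real backend constructs and returns its AttentionMetadata
        # subclass from on-device tensors here.
        raise NotImplementedError


class DummyBackend(AttentionBackend):
    """Hypothetical backend; the other abstract methods (get_name,
    get_impl_cls, get_metadata_cls, get_kv_cache_shape, ...) are omitted."""

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return DummyMetadataBuilder


# The runner side only needs the classmethod added in this diff:
#     builder = DummyBackend.make_metadata_builder(input_builder)
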
vllm/attention/backends/blocksparse_attn.py

Lines changed: 11 additions & 0 deletions
@@ -5,6 +5,7 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonMetadataBuilder
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from vllm.attention.ops.paged_attn import PagedAttention
@@ -93,6 +94,10 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return BlocksparseFlashAttentionMetadata
 
+    @staticmethod
+    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
+        return BlocksparseFlashAttentionMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -244,6 +249,12 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
         return self._cached_decode_metadata
 
 
+class BlocksparseFlashAttentionMetadataBuilder(
+        CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
+
+    _metadata_cls = BlocksparseFlashAttentionMetadata
+
+
 class BlocksparseFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:

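BlocksparseFlashAttentionMetadataBuilder is a two-line subclass because the shared add_seq_group/build logic lives in CommonMetadataBuilder (added in vllm/attention/backends/utils.py by this commit, not reproduced in this excerpt). Assuming that shared builder, any backend whose metadata follows the common paged-attention layout can declare its builder the same way; a hypothetical sketch:

from dataclasses import dataclass

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.utils import CommonMetadataBuilder


@dataclass
class MyPagedMetadata(AttentionMetadata):
    """Hypothetical metadata type; a real one adds its tensor fields and the
    prefill_metadata/decode_metadata views, as the backends in this diff do."""


class MyPagedMetadataBuilder(CommonMetadataBuilder[MyPagedMetadata]):

    # The only backend-specific piece: which metadata class build() constructs.
    _metadata_cls = MyPagedMetadata
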
vllm/attention/backends/flash_attn.py

Lines changed: 181 additions & 2 deletions
@@ -1,13 +1,24 @@
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType)
+                                              AttentionMetadata,
+                                              AttentionMetadataBuilder,
+                                              AttentionType)
+from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                           compute_slot_mapping_start_idx,
+                                           is_block_tables_empty)
+from vllm.sequence import SequenceGroupMetadata
+from vllm.utils import make_tensor_with_pad
+
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import (GPUModelRunnerBase,
+                                          ModelInputForGPUBuilder)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return FlashAttentionMetadata
 
+    @staticmethod
+    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
+        return FlashAttentionMetadataBuilder
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -184,6 +199,170 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
         return self._cached_decode_metadata
 
 
+class FlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[FlashAttentionMetadata]):
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.slot_mapping: List[int] = []
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.block_tables: List[List[int]] = []
+        self.curr_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+
+        self.sliding_window = input_builder.sliding_window
+        self.block_size = input_builder.block_size
+        self.use_v2_block_manager = (
+            input_builder.scheduler_config.use_v2_block_manager)
+
+    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
+                      token_lens: List[int], seq_lens: List[int],
+                      curr_seq_lens: List[int], query_lens: List[int],
+                      context_lens: List[int],
+                      curr_sliding_window_blocks: List[int],
+                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        2. block table.
+        3. slot mapping.
+        """
+        is_prompt = seq_group_metadata.is_prompt
+        block_tables = seq_group_metadata.block_tables
+
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
+                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
+                 curr_seq_lens, query_lens, context_lens,
+                 curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.curr_seq_lens.append(curr_seq_len)
+
+            # Compute block table.
+            # TODO(sang): Combine chunked prefill and prefix caching by
+            # only allowing multiple of block_size chunk size.
+            # NOTE: This only works for oooooooxxx style attention.
+            block_table = []
+            if prefix_cache_hit:
+                # NOTE(woosuk): For flash-attn, the block table should
+                # include the entries for the incoming prefill tokens.
+                block_table = block_tables[seq_id]
+            elif ((chunked_prefill_enabled or not is_prompt)
+                  and block_tables is not None):
+                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+            self.block_tables.append(block_table)
+
+            # Compute slot mapping.
+            is_profile_run = is_block_tables_empty(block_tables)
+            start_idx = compute_slot_mapping_start_idx(
+                is_prompt, query_len, context_len, self.sliding_window,
+                self.use_v2_block_manager)
+            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                 seq_len, context_len, start_idx,
+                                 self.block_size,
+                                 seq_group_metadata.block_tables)
+
+    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
+              cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors."""
+        device = runner.device
+        use_captured_graph = cuda_graph_pad_size != -1
+
+        logits_soft_cap = getattr(runner.model_config.hf_config,
+                                  "attn_logit_softcapping", None)
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "Please use Flashinfer backend for models with logits_soft_cap"
+                " (i.e., Gemma-2). Otherwise, the output might be wrong."
+                " Set Flashinfer backend by "
+                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+            self.block_tables.extend([] * cuda_graph_pad_size)
+            num_decode_tokens = batch_size + cuda_graph_pad_size
+
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = runner.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(self.block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.tensor(input_block_tables, device=device)
+        else:
+            max_block_table_len = max(
+                len(block_table) for block_table in self.block_tables)
+            block_tables = make_tensor_with_pad(
+                self.block_tables,
+                max_len=max_block_table_len,
+                pad=0,
+                dtype=torch.int,
+                device=device,
+            )
+        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+
+        context_lens_tensor = torch.tensor(self.context_lens,
+                                           dtype=torch.int,
+                                           device=device)
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=device)
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=device)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        slot_mapping_tensor = torch.tensor(self.slot_mapping,
+                                           dtype=torch.long,
+                                           device=device)
+
+        return FlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping_tensor,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            seq_start_loc=seq_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
+        )
+
+
 class FlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:

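Taken together, the new interface turns the backend-specific half of _prepare_model_input_tensors into a three-step protocol: construct the builder from the model-input builder, feed it one scheduled sequence group at a time, then materialize device tensors once. A hedged sketch of that driver loop; the per_group_lens argument and the surrounding function are illustrations of the calling convention implied by the signatures above, not the actual code in vllm/worker/model_runner.py:

from vllm.attention.backends.flash_attn import FlashAttentionBackend


def prepare_attn_metadata(input_builder, runner, seq_group_metadata_list,
                          per_group_lens, seq_lens, query_lens,
                          cuda_graph_pad_size, batch_size):
    """per_group_lens: hypothetical list of dicts holding the per-sequence
    length lists (token_lens, seq_lens, curr_seq_lens, query_lens,
    context_lens, curr_sliding_window_blocks) for each sequence group."""
    # 1. Backend-agnostic construction via the classmethod added in this diff.
    builder = FlashAttentionBackend.make_metadata_builder(input_builder)

    # 2. Cheap Python-level bookkeeping, one call per scheduled group.
    for seq_group_metadata, lens in zip(seq_group_metadata_list,
                                        per_group_lens):
        builder.add_seq_group(seq_group_metadata,
                              lens["token_lens"], lens["seq_lens"],
                              lens["curr_seq_lens"], lens["query_lens"],
                              lens["context_lens"],
                              lens["curr_sliding_window_blocks"],
                              prefix_cache_hit=False,
                              chunked_prefill_enabled=False)

    # 3. One pass that builds the on-device FlashAttentionMetadata.
    return builder.build(runner, seq_lens, query_lens, cuda_graph_pad_size,
                         batch_size)

With cuda_graph_pad_size set to -1, build() skips the CUDA-graph padding path, mirroring the use_captured_graph flag in the diff above.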