
Commit a9b15c6

[torch.compile] use empty tensor instead of None for profiling (#8875)
1 parent 8df2dc3 commit a9b15c6

15 files changed · +84 −32 lines changed
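
This commit replaces `None` KV-cache placeholders with empty tensors across the attention backends and model runners, so that Dynamo passes the placeholder by reference instead of specializing the compiled graph on the Python value `None` during the memory-profiling run. As a quick orientation before the per-file diffs, here is a minimal sketch of the pattern they all apply; `toy_forward` and the shapes are illustrative only, not code from the commit:

import torch

# Hypothetical stand-in for a backend's forward(): the real backends branch
# the same way, then write key/value into the paged KV cache.
def toy_forward(query: torch.Tensor, kv_cache: torch.Tensor) -> torch.Tensor:
    if kv_cache.numel() > 0:   # previously: `if kv_cache is not None:`
        pass                   # cache update would happen here
    return query

# Profiling run: pass an empty tensor with shape [0] instead of None, so
# Dynamo guards on tensor metadata rather than on the value None.
placeholder = torch.tensor([], dtype=torch.float32)
out = toy_forward(torch.randn(2, 8), placeholder)
print(placeholder.numel())  # 0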

tests/kernels/test_encoder_decoder_attn.py

Lines changed: 6 additions & 2 deletions

@@ -136,7 +136,9 @@ class that Attention will automatically select when it is constructed.
     )
     if test_pt.num_blocks is None or test_pt.num_heads is None:
         # Caller does not require a KV cache
-        return TestResources(scale, attn_backend, attn, None)
+        return TestResources(
+            scale, attn_backend, attn,
+            torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))

     # Construct KV cache
     kv_cache = make_kv_cache(test_pt.num_blocks,
@@ -620,7 +622,9 @@ def _run_encoder_attention_test(
     return attn.forward(packed_qkv.query,
                         packed_qkv.key,
                         packed_qkv.value,
-                        None,
+                        torch.tensor([],
+                                     dtype=torch.float32,
+                                     device=packed_qkv.query.device),
                         attn_metadata,
                         attn_type=attn_type)

vllm/attention/backends/blocksparse_attn.py

Lines changed: 4 additions & 2 deletions

@@ -357,6 +357,8 @@ def forward(
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -373,7 +375,7 @@ def forward(
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

@@ -399,7 +401,7 @@ def forward(
             # When block_tables are not filled, it means q and k are the
             # prompt, and they have the same length.

-            assert kv_cache is None \
+            assert kv_cache.numel() == 0 \
                 or prefill_meta.block_tables is None \
                 or prefill_meta.block_tables.numel() == 0, \
                 "Does not support prefix-enabled attention."

vllm/attention/backends/flash_attn.py

Lines changed: 4 additions & 2 deletions

@@ -665,6 +665,8 @@ def forward(
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -685,7 +687,7 @@ def forward(
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]

@@ -722,7 +724,7 @@ def forward(

         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if (kv_cache is None or prefill_meta.block_tables is None
+            if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
                     or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the

vllm/attention/backends/flashinfer.py

Lines changed: 3 additions & 3 deletions

@@ -746,7 +746,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -770,7 +770,7 @@ def forward(
         if attn_metadata.num_decode_tokens > 0:
             assert attn_metadata.num_prefill_tokens == 0, (
                 "Chunked prefill is not supported with flashinfer yet.")
-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             # Use the same reshape and cache kernel as flash attention.
             ops.reshape_and_cache_flash(
                 key,
@@ -796,7 +796,7 @@ def forward(
             # when kv_cache is not provided.
             # This happens when vllm runs the profiling to
             # determine the number of blocks.
-            if kv_cache is None:
+            if kv_cache.numel() == 0:
                 output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
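
In the FlashInfer backend the empty placeholder also marks the profiling run: when kv_cache.numel() == 0 there is no cache to page into, so the forward falls back to a plain variable-length attention call, as in the last hunk above. A rough illustrative sketch of that branch shape (toy_flashinfer_forward and the SDPA fallback are stand-ins, not the real kernels):

import torch
import torch.nn.functional as F

def toy_flashinfer_forward(query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
                           kv_cache: torch.Tensor) -> torch.Tensor:
    # Hypothetical branch shape: profiling runs (empty cache) take a plain
    # attention path; real runs would use the paged FlashInfer kernels.
    if kv_cache.numel() == 0:
        return F.scaled_dot_product_attention(query, key, value)
    return query  # placeholder for the paged-attention path

q = k = v = torch.randn(1, 1, 4, 8)   # (batch, heads, seq_len, head_size)
out = toy_flashinfer_forward(q, k, v, torch.tensor([], dtype=torch.float32))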

vllm/attention/backends/ipex_attn.py

Lines changed: 6 additions & 3 deletions

@@ -167,7 +167,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -180,6 +180,8 @@ def forward(
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -196,7 +198,7 @@ def forward(
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = self.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
             ipex_ops.reshape_and_cache(
@@ -212,7 +214,8 @@ def forward(

         if attn_metadata.is_prompt:
             assert attn_metadata.seq_lens is not None
-            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+            if (kv_cache.numel() == 0
+                    or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,

vllm/attention/backends/pallas.py

Lines changed: 7 additions & 5 deletions

@@ -143,7 +143,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
+        kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -155,8 +155,10 @@ def forward(
             query: shape = [batch_size, seq_len, num_heads * head_size]
             key: shape = [batch_size, seq_len, num_kv_heads * head_size]
             value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            key_cache = [num_kv_heads, num_blocks, block_size, head_size]
-            value_cache = [num_kv_heads, num_blocks, block_size, head_size]
+            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
+            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
+                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
+                with shape [0] for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
@@ -173,7 +175,7 @@ def forward(
         value = value.view(batch_size, seq_len, self.num_kv_heads,
                            self.head_size)

-        if kv_cache[0] is not None:
+        if kv_cache[0].numel() > 0:
             slot_mapping = attn_metadata.slot_mapping
             key_cache, value_cache = kv_cache
             write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
@@ -205,7 +207,7 @@ def forward(
             output = output.permute(0, 2, 1, 3)
         else:
             # Decoding run.
-            assert kv_cache is not None
+            assert kv_cache[0].numel() > 0

             pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
             if self.megacore_mode == "batch" and batch_size % 2 != 0:
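
The Pallas/TPU backend stores the KV cache as a (key_cache, value_cache) tuple rather than a single tensor, so the profiling placeholder becomes a pair of empty tensors and the emptiness check moves to the first element, as shown above. A minimal illustrative sketch of that variant (the function and shapes are assumptions, not code from this commit):

from typing import Tuple

import torch

# Hypothetical mirror of the pallas.py branch above: write to the cache only
# when the first element of the tuple is non-empty.
def toy_tpu_forward(query: torch.Tensor,
                    kv_cache: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    if kv_cache[0].numel() > 0:   # previously: `if kv_cache[0] is not None:`
        pass                      # write_to_kv_cache(...) would run here
    return query

empty = torch.tensor([], dtype=torch.float32)
out = toy_tpu_forward(torch.randn(1, 4, 8), (empty, empty))

During the profiling run both elements of every layer's tuple are empty, so the cache write above is skipped.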

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 4 additions & 2 deletions

@@ -396,6 +396,8 @@ def forward(
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -412,7 +414,7 @@ def forward(
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

@@ -449,7 +451,7 @@ def forward(
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             assert prefill_meta.seq_lens is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.

vllm/attention/backends/torch_sdpa.py

Lines changed: 6 additions & 3 deletions

@@ -151,7 +151,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -164,6 +164,8 @@ def forward(
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
@@ -180,7 +182,7 @@ def forward(
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)

-        if kv_cache is not None:
+        if kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
             PagedAttention.write_to_paged_cache(key, value, key_cache,
@@ -191,7 +193,8 @@ def forward(

         if attn_metadata.is_prompt:
             assert attn_metadata.seq_lens is not None
-            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
+            if (kv_cache.numel() == 0
+                    or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,

vllm/attention/backends/xformers.py

Lines changed: 5 additions & 3 deletions

@@ -445,7 +445,7 @@ def forward(
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
-        kv_cache: Optional[torch.Tensor],
+        kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
         k_scale: float = 1.0,
         v_scale: float = 1.0,
@@ -489,6 +489,8 @@ def forward(
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+                NOTE: kv_cache will be an empty tensor with shape [0]
+                for profiling run.
             attn_metadata: Metadata for attention.
             attn_type: Select attention type, between encoder attention,
                        decoder self-attention, or encoder/decoder cross-
@@ -522,7 +524,7 @@ def forward(
         # which KV cache memory-mapping & which
         # seqlen datastructures we utilize

-        if (attn_type != AttentionType.ENCODER and kv_cache is not None):
+        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
@@ -588,7 +590,7 @@ def forward(

         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                 # normal attention.
                 # block tables are empty if the prompt does not have a cached
                 # prefix.

vllm/worker/embedding_model_runner.py

Lines changed: 7 additions & 1 deletion

@@ -97,7 +97,13 @@ def execute_model(
         model_executable = self.model

         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than by specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers

         execute_model_kwargs = {
             "input_ids":

vllm/worker/enc_dec_model_runner.py

Lines changed: 7 additions & 1 deletion

@@ -340,7 +340,13 @@ def profile_run(self) -> None:

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than by specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)

vllm/worker/model_runner.py

Lines changed: 7 additions & 1 deletion

@@ -1223,7 +1223,13 @@ def profile_run(self) -> None:

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than by specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
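
For the profiling path, the model runners build one empty placeholder per layer with list multiplication. The backends only check numel() and never write when the placeholder is empty, so it does not matter that list multiplication makes every layer alias the same tensor. A small illustrative sketch of that construction (num_layers and device are stand-ins for the runner's model_config and self.device):

import torch

num_layers = 4                   # stand-in; vLLM reads this from model_config
device = torch.device("cpu")     # stand-in; the runners use self.device

# List multiplication shares one empty placeholder across all layers, which
# is fine here because the profiling run only checks numel(), never writes.
kv_caches = [torch.tensor([], dtype=torch.float32, device=device)] * num_layers

assert all(kv.numel() == 0 for kv in kv_caches)
assert all(kv is kv_caches[0] for kv in kv_caches)   # same object, by design

The TPU worker below applies the same idea, but builds a (key, value) pair of empty tensors for each layer.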

vllm/worker/tpu_model_runner.py

Lines changed: 2 additions & 2 deletions

@@ -714,7 +714,7 @@ def forward(
         t: torch.Tensor,
         p: torch.Tensor,
         num_samples: int,
-        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

@@ -745,7 +745,7 @@ def forward(
         )

         # Skip this in memory profiling at initialization.
-        if kv_caches[0][0] is not None:
+        if kv_caches[0][0].numel() > 0:
             # index_copy_(slot_mapping) only works when the inserted dimension
             # is 0. However, the KV cache in the Pallas backend has the shape
             # [num_kv_heads, num_blocks, block_size, head_size]. To make it

vllm/worker/tpu_worker.py

Lines changed: 9 additions & 1 deletion

@@ -115,7 +115,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

-        kv_caches = [(None, None) for _ in range(num_layers)]
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than by specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [(torch.tensor([], dtype=torch.float32,
+                                   device=self.device),
+                      torch.tensor([], dtype=torch.float32,
+                                   device=self.device))
+                     for _ in range(num_layers)]
         self.model_runner._dummy_run(
             batch_size=1,
             seq_len=self.scheduler_config.max_num_batched_tokens,

vllm/worker/xpu_model_runner.py

Lines changed: 7 additions & 1 deletion

@@ -464,7 +464,13 @@ def profile_run(self) -> None:

         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        # use an empty tensor instead of `None` to force Dynamo to pass
+        # it by reference, rather than by specializing on the value `None`.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        kv_caches = [
+            torch.tensor([], dtype=torch.float32, device=self.device)
+        ] * num_layers
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
