Commit 7de49aa

[torch.compile] hide slicing under custom op for inductor (#8384)
1 parent 42ffba1 commit 7de49aa

2 files changed: 74 additions and 35 deletions

tests/compile/test_full_graph.py

Lines changed: 3 additions & 1 deletion
@@ -16,5 +16,7 @@ def test_full_graph(model):
         "The future of AI is",
     ]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B",
+              enforce_eager=True,
+              load_format="dummy")
     llm.generate(prompts, sampling_params)
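
The two new constructor arguments keep the compile test cheap: `load_format="dummy"` initializes the model with random weights instead of downloading the real checkpoint, and `enforce_eager=True` skips CUDA graph capture. A hedged usage sketch equivalent to the updated test body (the single prompt stands in for the test's prompt list):

```python
from vllm import LLM, SamplingParams

# Dummy weights + eager mode keep this quick: no checkpoint download and no
# CUDA graph capture are needed to exercise the full-graph compile path.
prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="meta-llama/Meta-Llama-3-8B",
          enforce_eager=True,
          load_format="dummy")
llm.generate(prompts, sampling_params)
```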

vllm/attention/backends/flash_attn.py

Lines changed: 71 additions & 34 deletions
@@ -122,6 +122,40 @@ def _(
     return torch.empty_like(decode_query)
 
 
+@torch.library.custom_op("vllm::reshape_and_cache_flash",
+                         mutates_args=["kv_cache"])
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Inductor cannot deal with inplace operations on views.
+    See https://github.com/pytorch/pytorch/issues/131192
+    and https://github.com/pytorch/pytorch/issues/130174
+    This is a workaround to hide the view operation from the inductor.
+    """
+    return torch.ops._C_cache_ops.reshape_and_cache_flash(
+        key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
+        k_scale, v_scale)
+
+
+@reshape_and_cache_flash.register_fake  # type: ignore
+def _(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    pass
+
+
 class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
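
The pattern introduced above is not vLLM-specific, and a minimal standalone sketch may make it clearer. The example below is an illustration under stated assumptions (hypothetical op name `mylib::write_first_row`; PyTorch 2.4+ for `torch.library.custom_op`): the slicing that produces a view and the in-place write into it both live inside the custom op, so inductor only sees a single opaque call that mutates `buf`.

```python
import torch


# Hypothetical op, not part of vLLM: hides "slice a view, write in-place"
# from torch.compile/inductor, mirroring vllm::reshape_and_cache_flash.
@torch.library.custom_op("mylib::write_first_row", mutates_args=["buf"])
def write_first_row(buf: torch.Tensor, src: torch.Tensor) -> None:
    # buf[0] is a view of buf; copy_ mutates it in place.
    buf[0].copy_(src)


@write_first_row.register_fake
def _(buf: torch.Tensor, src: torch.Tensor) -> None:
    # Fake (meta) implementation: no data movement, nothing to return.
    pass


@torch.compile
def step(buf: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # Inductor sees an opaque custom op that mutates buf, not the view op.
    torch.ops.mylib.write_first_row(buf, src)
    return buf.sum()


if __name__ == "__main__":
    buf = torch.zeros(2, 4)
    src = torch.ones(4)
    print(step(buf, src))  # tensor(4.)
```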
@@ -653,11 +687,10 @@ def forward(
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            ops.reshape_and_cache_flash(
+            torch.ops.vllm.reshape_and_cache_flash(
                 key,
                 value,
-                key_cache,
-                value_cache,
+                kv_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
                 k_scale,
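
At this call site the change is complementary: the whole `kv_cache` tensor is handed to the custom op, and the `kv_cache[0]` / `kv_cache[1]` slicing now happens inside it, out of inductor's sight. A small illustration of why that slicing matters (plain PyTorch; the `[2, ...]`-stacked cache layout is assumed from the `kv_cache[0]`/`kv_cache[1]` indexing above):

```python
import torch

# Slicing a stacked KV cache yields views, not copies; writing into such a
# view in place is the pattern inductor has trouble with (see the PyTorch
# issues linked in the docstring above).
kv_cache = torch.zeros(2, 4, 8)        # assumed [key/value, block, ...] layout
key_cache, value_cache = kv_cache[0], kv_cache[1]
assert key_cache.data_ptr() == kv_cache.data_ptr()    # same storage, no copy
key_cache.fill_(1.0)                                   # in-place write through the view
assert kv_cache[0].sum().item() == key_cache.numel()   # mutation visible in kv_cache
```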
@@ -669,7 +702,6 @@ def forward(
         assert key.shape[0] == num_prefill_tokens + num_decode_tokens
         assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -680,14 +712,17 @@ def forward(
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
 
+        prefill_output: Optional[torch.Tensor] = None
+        decode_output: Optional[torch.Tensor] = None
+
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if (kv_cache is None or prefill_meta.block_tables is None
                     or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = torch.ops.vllm.flash_attn_varlen_func(
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -701,42 +736,44 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:
-                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
-                           q=query,
-                           k=key_cache,
-                           v=value_cache,
-                           cu_seqlens_q=prefill_meta.query_start_loc,
-                           max_seqlen_q=prefill_meta.max_query_len,
-                           cu_seqlens_k=prefill_meta.seq_start_loc,
-                           max_seqlen_k=max_seq_len,
-                           softmax_scale=self.scale,
-                           causal=True,
-                           alibi_slopes=self.alibi_slopes,
-                           block_table=prefill_meta.block_tables,
-                           softcap=self.logits_soft_cap,
-                       )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            output[
-                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
-                    decode_query.unsqueeze(1),
-                    key_cache,
-                    value_cache,
-                    block_table=decode_meta.block_tables,
-                    cache_seqlens=decode_meta.seq_lens_tensor,
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
                     softmax_scale=self.scale,
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
                     softcap=self.logits_soft_cap,
-                ).squeeze(1)
+                )
 
-        # Reshape the output tensor.
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            decode_output = torch.ops.vllm.flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
+                key_cache,
+                value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                softcap=self.logits_soft_cap,
+            ).squeeze(1)
+
+        if prefill_output is None:
+            assert decode_output is not None
+            return decode_output.view(num_decode_tokens, hidden_size)
+        if decode_output is None:
+            assert prefill_output is not None
+            return prefill_output.view(num_prefill_tokens, hidden_size)
+        output = torch.cat([prefill_output, decode_output], dim=0)
         return output.view(num_tokens, hidden_size)
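
The rest of the diff removes the preallocated `output` buffer that prefill and decode results used to be written into via slice assignment (`output[:num_prefill_tokens] = ...`), presumably because that is the same in-place-write-into-a-view pattern the custom op exists to avoid. Instead, the two results stay separate and are combined at the end. A compact sketch of that tail logic (the function name and standalone form are illustrative, not a vLLM helper):

```python
from typing import Optional

import torch


def combine_attention_outputs(prefill_output: Optional[torch.Tensor],
                              decode_output: Optional[torch.Tensor],
                              num_tokens: int,
                              hidden_size: int) -> torch.Tensor:
    """Mirror the tail of forward(): return whichever half exists, otherwise
    concatenate prefill and decode results along the token dimension."""
    if prefill_output is None:
        assert decode_output is not None
        return decode_output.view(-1, hidden_size)
    if decode_output is None:
        return prefill_output.view(-1, hidden_size)
    output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)
```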
