@@ -122,6 +122,40 @@ def _(
     return torch.empty_like(decode_query)
 
 
+@torch.library.custom_op("vllm::reshape_and_cache_flash",
+                         mutates_args=["kv_cache"])
+def reshape_and_cache_flash(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    """Inductor cannot deal with inplace operations on views.
+    See https://github.com/pytorch/pytorch/issues/131192
+    and https://github.com/pytorch/pytorch/issues/130174
+    This is a workaround to hide the view operation from the inductor.
+    """
+    return torch.ops._C_cache_ops.reshape_and_cache_flash(
+        key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
+        k_scale, v_scale)
+
+
+@reshape_and_cache_flash.register_fake  # type: ignore
+def _(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+) -> None:
+    pass
+
+
 class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
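The wrapper above follows the torch.library.custom_op / register_fake pattern: the real implementation is free to mutate kv_cache through views, while the fake implementation only describes the op's (empty) outputs so torch.compile can trace it without seeing the view mutation. Below is a minimal, self-contained sketch of that pattern, not vLLM code: it assumes PyTorch 2.4+ (where torch.library.custom_op exists), and the op name mylib::scatter_rows and its body are purely illustrative.

import torch


@torch.library.custom_op("mylib::scatter_rows", mutates_args=["cache"])
def scatter_rows(src: torch.Tensor, cache: torch.Tensor,
                 slots: torch.Tensor) -> None:
    # The body may freely take views of `cache`; callers (and the compiler)
    # only see an opaque op that mutates `cache`.
    cache.view(-1, src.shape[-1])[slots] = src


@scatter_rows.register_fake
def _(src: torch.Tensor, cache: torch.Tensor, slots: torch.Tensor) -> None:
    # The op returns nothing and only mutates `cache`, so the fake (meta)
    # implementation is a no-op.
    pass


if __name__ == "__main__":
    cache = torch.zeros(2, 4, 8)      # e.g. [num_blocks, block_size, head_dim]
    src = torch.randn(3, 8)           # three new rows to write
    slots = torch.tensor([0, 5, 6])   # flat slot indices into the cache
    torch.ops.mylib.scatter_rows(src, cache, slots)

Passing the whole kv_cache tensor and indexing kv_cache[0] / kv_cache[1] inside the op serves the same purpose in the diff: the view is created behind the custom-op boundary instead of in the traced graph.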
@@ -653,11 +687,10 @@ def forward(
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            ops.reshape_and_cache_flash(
+            torch.ops.vllm.reshape_and_cache_flash(
                 key,
                 value,
-                key_cache,
-                value_cache,
+                kv_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
                 k_scale,
@@ -669,7 +702,6 @@ def forward(
         assert key.shape[0] == num_prefill_tokens + num_decode_tokens
         assert value.shape[0] == num_prefill_tokens + num_decode_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -680,14 +712,17 @@ def forward(
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
 
+        prefill_output: Optional[torch.Tensor] = None
+        decode_output: Optional[torch.Tensor] = None
+
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if (kv_cache is None or prefill_meta.block_tables is None
                     or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = torch.ops.vllm.flash_attn_varlen_func(
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -701,42 +736,44 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:
-                       num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func(  # noqa
-                           q=query,
-                           k=key_cache,
-                           v=value_cache,
-                           cu_seqlens_q=prefill_meta.query_start_loc,
-                           max_seqlen_q=prefill_meta.max_query_len,
-                           cu_seqlens_k=prefill_meta.seq_start_loc,
-                           max_seqlen_k=max_seq_len,
-                           softmax_scale=self.scale,
-                           causal=True,
-                           alibi_slopes=self.alibi_slopes,
-                           block_table=prefill_meta.block_tables,
-                           softcap=self.logits_soft_cap,
-                       )
-
-        if decode_meta := attn_metadata.decode_metadata:
-            # Decoding run.
-            output[
-                num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
-                    decode_query.unsqueeze(1),
-                    key_cache,
-                    value_cache,
-                    block_table=decode_meta.block_tables,
-                    cache_seqlens=decode_meta.seq_lens_tensor,
+                prefill_output = torch.ops.vllm.flash_attn_varlen_func(  # noqa
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
                     softmax_scale=self.scale,
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
                     softcap=self.logits_soft_cap,
-                ).squeeze(1)
+                )
 
-        # Reshape the output tensor.
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            decode_output = torch.ops.vllm.flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
+                key_cache,
+                value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                softcap=self.logits_soft_cap,
+            ).squeeze(1)
+
+        if prefill_output is None:
+            assert decode_output is not None
+            return decode_output.view(num_decode_tokens, hidden_size)
+        if decode_output is None:
+            assert prefill_output is not None
+            return prefill_output.view(num_prefill_tokens, hidden_size)
+        output = torch.cat([prefill_output, decode_output], dim=0)
         return output.view(num_tokens, hidden_size)
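In the last hunk, the preallocated output = torch.empty_like(query) and the in-place slice writes are replaced by building prefill_output and decode_output separately and concatenating them, presumably to keep the forward pass free of writes into views of a larger buffer, in line with the inductor issues cited above. A rough standalone sketch of that assembly logic follows; the helper name assemble_output is hypothetical and the shapes are toy values.

from typing import Optional

import torch


def assemble_output(prefill_output: Optional[torch.Tensor],
                    decode_output: Optional[torch.Tensor],
                    hidden_size: int) -> torch.Tensor:
    # Only decode tokens this step.
    if prefill_output is None:
        assert decode_output is not None
        return decode_output.view(-1, hidden_size)
    # Only prefill tokens this step.
    if decode_output is None:
        return prefill_output.view(-1, hidden_size)
    # Mixed batch: concatenate functionally instead of writing into slices
    # of a preallocated buffer.
    return torch.cat([prefill_output, decode_output],
                     dim=0).view(-1, hidden_size)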