@@ -44,15 +44,15 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):


 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -80,14 +80,16 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ def forward(
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )


@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
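Taken together, these hunks thread a new keyword-only `out=` argument from the public API functions down to the CUDA kernels, so a caller can supply a preallocated output buffer instead of having the forward pass allocate one. Below is a minimal usage sketch, not part of the diff: it assumes the patched package is importable as `flash_attn`, and the shapes, dtype, and buffer-reuse pattern are illustrative assumptions.

import torch
from flash_attn import flash_attn_func

# q, k, v follow the usual flash_attn_func layout: (batch, seqlen, nheads, headdim),
# fp16/bf16 tensors on a CUDA device.
batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Preallocate the output once (e.g. to reuse it across iterations) and pass it via
# the new keyword-only `out=` argument; with this patch the forward pass writes the
# attention result into the provided buffer instead of allocating a fresh tensor.
out_buf = torch.empty_like(q)
out = flash_attn_func(q, k, v, causal=True, out=out_buf)
# With head_dim a multiple of 8 (as here), `out` is expected to alias `out_buf`,
# so no extra output allocation happens on this call.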