
Commit b16c279: Expose out in python API (#2)

1 parent eee8e47
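The change threads a caller-provided output buffer through the Python API: each public function gains a keyword-only `out=None` parameter, and the internal wrappers forward it to the CUDA kernels in place of the previous hard-coded `None`. A minimal usage sketch follows; the top-level import path and the `(batch, seqlen, nheads, headdim)` layout are assumptions based on the upstream flash-attn API, not something this commit shows.

# Hedged sketch: reuse one preallocated output buffer across calls instead of
# letting the kernel allocate a fresh tensor each time (the out=None default
# preserves the old allocating behavior).
import torch
from vllm_flash_attn import flash_attn_func  # assumed top-level re-export

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

out = torch.empty_like(q)  # allocated once, reused on every call
attn = flash_attn_func(q, k, v, causal=True, out=out)
# attn is expected to alias `out` when shapes and dtypes are compatible.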


vllm_flash_attn/flash_attn_interface.py

Lines changed: 40 additions & 4 deletions
@@ -44,15 +44,15 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 
 
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -80,14 +80,16 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ def forward(
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )
 
 
@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
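The final hunk wires the same buffer into `flash_attn_with_kvcache`, the decode-time entry point. A hedged sketch of a single decode step, again assuming a top-level re-export and the usual `(batch, cache_seqlen, nheads_k, headdim)` cache layout; only the `out=` keyword comes from this commit, the other keywords follow the upstream flash-attn signature.

# Hedged sketch: one new query token per sequence attends against a
# preallocated KV cache, writing its output into a buffer that lives for
# the whole decode loop instead of allocating per step.
import torch
from vllm_flash_attn import flash_attn_with_kvcache  # assumed re-export

q = torch.randn(4, 1, 8, 64, device="cuda", dtype=torch.float16)
k_cache = torch.zeros(4, 2048, 8, 64, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((4,), 128, dtype=torch.int32, device="cuda")

out = torch.empty_like(q)  # reused across decode iterations
attn = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, out=out
)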
