@@ -182,9 +182,14 @@ def construct_local_mask(
     query_padding_mask=None,
     key_padding_mask=None,
     device=None,
+    key_leftpad=None,
 ):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
     sk = (
         seqlen_k
         if key_padding_mask is None
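
A minimal standalone sketch of the column re-indexing this hunk adds (my own illustration, not part of the test file; `leftpad_col_idx` is a hypothetical helper name): key columns to the left of `key_leftpad` are mapped to index 2**32, far outside any causal or sliding window, so they can never be attended to.

import torch
from einops import rearrange, repeat

def leftpad_col_idx(seqlen_k, key_leftpad, device=None):
    # Re-base column indices so that position key_leftpad[b] becomes column 0
    # for batch element b; left-padded columns get the sentinel 2**32.
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
    col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
    return torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)

# Batch of 2, 6 key positions, second sequence left-padded by 2:
# row 0 -> [0, 1, 2, 3, 4, 5], row 1 -> [2**32, 2**32, 0, 1, 2, 3]
print(leftpad_col_idx(6, torch.tensor([0, 2])))
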
@@ -219,6 +224,7 @@ def attention_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     """
     Arguments:
@@ -268,6 +274,7 @@ def attention_ref(
             query_padding_mask,
             key_padding_mask,
             q.device,
+            key_leftpad=key_leftpad,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
@@ -306,6 +313,7 @@ def attention_kvpacked_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     return attention_ref(
         q,
@@ -321,6 +329,7 @@ def attention_kvpacked_ref(
         window_size=window_size,
         softcap=softcap,
         reorder_ops=reorder_ops,
+        key_leftpad=key_leftpad,
     )


@@ -1868,9 +1877,11 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("paged_kv_block_size", [None])
+@pytest.mark.parametrize("has_leftpad", [False, True])
+# @pytest.mark.parametrize("has_leftpad", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1898,6 +1909,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    has_leftpad,
     paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
@@ -1916,6 +1928,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if has_batch_idx and paged_kv_block_size is not None:
         pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1961,9 +1975,19 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
     if has_batch_idx:
         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
             :batch_size
@@ -2038,6 +2062,7 @@ def test_flash_attn_kvcache(
         rotary_sin=sin,
         cache_seqlens=cache_seqlens,
         cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
         block_table=block_table,
         causal=causal,
         window_size=window_size,
@@ -2066,6 +2091,7 @@ def test_flash_attn_kvcache(
         None,
         causal=causal,
         window_size=window_size,
+        key_leftpad=cache_leftpad,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -2080,6 +2106,7 @@ def test_flash_attn_kvcache(
         window_size=window_size,
         upcast=False,
         reorder_ops=True,
+        key_leftpad=cache_leftpad,
     )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")