
Commit 40e534a

Implement cache_leftpad
Parent: 116b05f

6 files changed (+70, -12 lines)

csrc/flash_attn/flash_api.cpp (+21)
@@ -532,6 +532,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
@@ -731,6 +732,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                                       head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
     }
 
+    if (leftpad_k_.has_value()) {
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
+
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -1279,6 +1290,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                 c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
                 c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+                c10::optional<const at::Tensor> &leftpad_k_, // batch_size
                 c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                 c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                 c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
@@ -1469,6 +1481,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
         params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
     }
     params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+    if (leftpad_k_.has_value()) {
+        TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
+        auto leftpad_k = leftpad_k_.value();
+        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+        CHECK_DEVICE(leftpad_k);
+        CHECK_CONTIGUOUS(leftpad_k);
+        CHECK_SHAPE(leftpad_k, batch_size);
+        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
+    }
 
     if (rotary_cos_.has_value()) {
         TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
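Note on the two new guard blocks: leftpad_k must be a contiguous int32 CUDA tensor of shape (batch_size,), and it cannot yet be combined with paged KV. A rough Python-side restatement of those checks, purely for illustration (the helper name and the use of plain asserts are not part of the extension):

import torch

def _check_leftpad_k(leftpad_k: torch.Tensor, batch_size: int, paged_kv: bool) -> None:
    # Same conditions the TORCH_CHECK / CHECK_* macros enforce above.
    assert not paged_kv, "We don't support Paged KV and leftpad_k running at the same time yet"
    assert leftpad_k.dtype == torch.int32, "leftpad_k must have dtype int32"
    assert leftpad_k.is_cuda, "leftpad_k must live on a CUDA device"
    assert leftpad_k.is_contiguous(), "leftpad_k must be contiguous"
    assert tuple(leftpad_k.shape) == (batch_size,), "leftpad_k must have shape (batch_size,)"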

csrc/flash_attn/src/block_info.h (+5, -3)
@@ -18,8 +18,9 @@ struct BlockInfo {
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
-        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
+        , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
 
@@ -30,13 +31,14 @@ struct BlockInfo {
 
     template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
-        return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+        return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
     }
 
     const int sum_s_q;
     const int sum_s_k;
     const int actual_seqlen_q;
     // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int leftpad_k;
     const int seqlen_k_cache;
     const int actual_seqlen_k;
 };
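In words: leftpad_k[bidb] is read once per batch element, the cached length is counted from that offset (both seqlen_k_cache and, when seqused_k is given, actual_seqlen_k subtract it), and k_offset skips leftpad_k rows when computing the pointer into K. A plain-Python restatement of that arithmetic for one batch element, with illustrative argument names (pass seqlen_knew=0 when no new keys are appended):

def k_lengths_and_offset(raw_seqlen_k, seqused_k, seqlen_knew, leftpad_k,
                         sum_s_k, bidb, batch_stride, row_stride):
    # Cached tokens, not counting the left padding.
    seqlen_k_cache = raw_seqlen_k - leftpad_k
    # Total keys attended to: either seqused_k minus the padding, or the cache plus any appended keys.
    actual_seqlen_k = (seqused_k - leftpad_k) if seqused_k is not None else seqlen_k_cache + seqlen_knew
    # Offset into K: skip the first leftpad_k rows of this batch element.
    if sum_s_k == -1:   # fixed-size layout (batch, seqlen, ...)
        k_offset = bidb * batch_stride + leftpad_k * row_stride
    else:               # varlen layout, sequences packed back to back
        k_offset = (sum_s_k + leftpad_k) * row_stride
    return seqlen_k_cache, actual_seqlen_k, k_offset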

csrc/flash_attn/src/flash.h (+1)
@@ -76,6 +76,7 @@ struct Flash_fwd_params : public Qkv_params {
     // array of length b+1 holding starting offset of each sequence.
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
+    int * __restrict__ leftpad_k;
 
     // If provided, the actual length of each k sequence.
     int * __restrict__ seqused_k;

csrc/flash_attn/src/flash_fwd_kernel.h (+6, -4)
@@ -690,7 +690,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
-        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2);
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
                                   Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
                                   make_stride(params.rotary_dim / 2, _1{}));
@@ -711,9 +711,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
 
-        const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+        const index_t row_offset_knew = bidb * params.knew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
-        const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+        const index_t row_offset_vnew = bidb * params.vnew_batch_stride
             + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
         // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
         // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
@@ -791,7 +793,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
-        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+        const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
         Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
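Both rotary offsets gain a leftpad_k[bidb] term: since seqlen_k_cache and the K block indices now count positions after the left padding, the term appears to put the cos/sin lookups back at the position within the cache buffer (padding included). A plain-Python sketch of the two changed offset computations, with illustrative names:

def rotary_cossin_row_offsets(n_block_max, kBlockN, m_block, kBlockM,
                              seqlen_k_cache, leftpad_k, rotary_dim,
                              is_causal_or_local):
    # Offset used while rotating the new K/V being appended to the cache.
    off_append = ((n_block_max - 1) * kBlockN + leftpad_k) * (rotary_dim // 2)
    # Offset used while rotating Q against the cache (the else-branch above).
    off_q = (seqlen_k_cache + leftpad_k
             + (m_block * kBlockM if is_causal_or_local else 0)) * (rotary_dim // 2)
    return off_append, off_q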

flash_attn/flash_attn_interface.py (+7, -2)
@@ -81,7 +81,8 @@ def _flash_attn_varlen_forward(
     softcap,
     alibi_slopes,
     return_softmax,
-    block_table,
+    block_table=None,
+    leftpad_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,6 +94,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        leftpad_k,
         block_table,
         alibi_slopes,
         max_seqlen_q,
@@ -1150,6 +1152,7 @@ def flash_attn_with_kvcache(
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
+   cache_leftpad: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
@@ -1217,11 +1220,12 @@ def flash_attn_with_kvcache(
        rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
        cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
            KV cache.
-       block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
            If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
            If the indices are not distinct, and k and v are provided, the values updated in the cache
            might come from any of the duplicate indices.
+       cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
+       block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
@@ -1269,6 +1273,7 @@ def flash_attn_with_kvcache(
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
+       cache_leftpad,
        block_table,
        alibi_slopes,
        None,
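To see the new argument end to end, here is a hedged usage sketch of flash_attn_with_kvcache with cache_leftpad. Shapes follow the docstring above; the sizes and values are made up, and this assumes a build of flash_attn that includes this commit. For batch element 1, only cache positions 128..449, i.e. [cache_leftpad, cache_seqlens), take part in attention:

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
seqlen_q, seqlen_cache = 1, 512
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device=device, dtype=dtype)

# Both are int32 tensors of shape (batch,).
cache_seqlens = torch.tensor([300, 450], dtype=torch.int32, device=device)
cache_leftpad = torch.tensor([0, 128], dtype=torch.int32, device=device)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    cache_leftpad=cache_leftpad,  # if None, the cache is assumed to start at index 0
    causal=True,
)  # out: (batch, seqlen_q, nheads, headdim)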

tests/test_flash_attn.py (+30, -3)
@@ -182,9 +182,14 @@ def construct_local_mask(
     query_padding_mask=None,
     key_padding_mask=None,
     device=None,
+    key_leftpad=None,
 ):
     row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
     col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
     sk = (
         seqlen_k
         if key_padding_mask is None
@@ -219,6 +224,7 @@ def attention_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     """
     Arguments:
@@ -268,6 +274,7 @@ def attention_ref(
             query_padding_mask,
             key_padding_mask,
             q.device,
+            key_leftpad=key_leftpad,
         )
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
@@ -306,6 +313,7 @@ def attention_kvpacked_ref(
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
+    key_leftpad=None,
 ):
     return attention_ref(
         q,
@@ -321,6 +329,7 @@ def attention_kvpacked_ref(
         window_size=window_size,
         softcap=softcap,
         reorder_ops=reorder_ops,
+        key_leftpad=key_leftpad,
     )
 
 
@@ -1868,9 +1877,11 @@ def test_flash_attn_splitkv(
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("paged_kv_block_size", [None, 256])
 # @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
-# @pytest.mark.parametrize("paged_kv_block_size", [256])
-@pytest.mark.parametrize("has_batch_idx", [False, True])
-# @pytest.mark.parametrize("has_batch_idx", [False])
+# @pytest.mark.parametrize("paged_kv_block_size", [None])
+@pytest.mark.parametrize("has_leftpad", [False, True])
+# @pytest.mark.parametrize("has_leftpad", [True])
+# @pytest.mark.parametrize("has_batch_idx", [False, True])
+@pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
@@ -1898,6 +1909,7 @@ def test_flash_attn_kvcache(
     seqlen_k,
     d,
     has_batch_idx,
+    has_leftpad,
     paged_kv_block_size,
     rotary_fraction,
     rotary_interleaved,
@@ -1916,6 +1928,8 @@ def test_flash_attn_kvcache(
         pytest.skip()
     if has_batch_idx and paged_kv_block_size is not None:
         pytest.skip()
+    if has_leftpad and paged_kv_block_size is not None:
+        pytest.skip()
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
@@ -1961,9 +1975,19 @@ def test_flash_attn_kvcache(
         dtype=torch.int32,
         device=device,
     )
+    if has_leftpad:
+        cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device)
+                                   if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device)
+                                   for i in range(batch_size)])
+    else:
+        cache_leftpad = None
     arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0)
+    if has_leftpad:
+        key_padding_mask = torch.logical_and(
+            key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k)
+        )
     if has_batch_idx:
         cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[
             :batch_size
@@ -2038,6 +2062,7 @@ def test_flash_attn_kvcache(
         rotary_sin=sin,
         cache_seqlens=cache_seqlens,
         cache_batch_idx=cache_batch_idx,
+        cache_leftpad=cache_leftpad,
         block_table=block_table,
         causal=causal,
         window_size=window_size,
@@ -2066,6 +2091,7 @@ def test_flash_attn_kvcache(
         None,
         causal=causal,
         window_size=window_size,
+        key_leftpad=cache_leftpad,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -2080,6 +2106,7 @@ def test_flash_attn_kvcache(
         window_size=window_size,
         upcast=False,
         reorder_ops=True,
+        key_leftpad=cache_leftpad,
     )
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
