
Commit 68c4df8

address review comments
1 parent 260da65 commit 68c4df8

File tree

4 files changed: +21 additions, -28 deletions


csrc/flash_attn/flash_api.cpp

Lines changed: 6 additions & 0 deletions
@@ -214,6 +214,10 @@ void set_params_fprop_sparse(Flash_fwd_params &params,
     TORCH_CHECK(column_index.size(2) == block_offset.size(2));
     TORCH_CHECK(column_count.size(2) == column_index.size(2));
     params.NUM_ROWS = block_count.size(2);
+    // params.NUM_ROWS should be equal to cdiv(seqlen_q, BLOCK_M), and BLOCK_M has to be 64 for now.
+    constexpr int BLOCK_M = 64;
+    int expected_num_rows = (seqlen_q + BLOCK_M - 1) / BLOCK_M;
+    TORCH_CHECK(params.NUM_ROWS == expected_num_rows);
     params.NNZ_S = block_offset.size(3);
     params.NNZ_V = column_index.size(3);
 }
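For context, the new TORCH_CHECK ties the NUM_ROWS dimension of the sparsity metadata to the query length. A minimal Python-side sketch of the same cdiv relationship (BLOCK_M = 64 is the fixed row-block size noted in the comment; the helper name is only illustrative):

# The third dimension of block_count / column_count must be ceil(seqlen_q / 64),
# which is what the new check enforces at the C++ boundary.
def expected_num_rows(seqlen_q: int, block_m: int = 64) -> int:
    return (seqlen_q + block_m - 1) // block_m  # cdiv(seqlen_q, BLOCK_M)

assert expected_num_rows(64) == 1
assert expected_num_rows(65) == 2
assert expected_num_rows(1024) == 16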
@@ -355,6 +359,8 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
                const bool return_softmax,
                c10::optional<at::Generator> gen_) {
 
+    TORCH_CHECK(window_size_left == -1, "sliding window is not supported in sparse_attn_func.");
+    TORCH_CHECK(window_size_right == -1, "sliding window is not supported in sparse_attn_func.");
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
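The two added TORCH_CHECKs reject sliding-window arguments before any kernel parameters are set up. A small sketch of how this surfaces on the Python side (a TORCH_CHECK failure is raised as a RuntimeError; the tiny shapes and the positional argument order here are assumptions for illustration, not part of this change):

import pytest
import torch
from vllm_flash_attn import sparse_attn_func

q = torch.randn(1, 64, 1, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 64, 1, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 64, 1, 64, device="cuda", dtype=torch.float16)
block_count = torch.ones(1, 1, 1, dtype=torch.int32, device="cuda")
block_offset = torch.zeros(1, 1, 1, 1, dtype=torch.int32, device="cuda")
column_count = torch.zeros(1, 1, 1, dtype=torch.int32, device="cuda")
column_index = torch.zeros(1, 1, 1, 1, dtype=torch.int32, device="cuda")

# Any window other than the default (-1, -1) now fails fast with the new message.
with pytest.raises(RuntimeError, match="sliding window is not supported"):
    sparse_attn_func(q, k, v, block_count, block_offset, column_count, column_index,
                     window_size=(128, 0))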

csrc/flash_attn/src/flash_fwd_sparse_kernel.h

Lines changed: 3 additions & 22 deletions
@@ -156,10 +156,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
     // PREDICATES
     //
 
-    // // Allocate predicate tensors for m and n
-    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
-    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
-
     // Construct identity layout for sQ and sK
     Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)
@@ -434,9 +430,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
     if (num_cols > 0) {
         auto* cols_ptr = params.column_index + ((bidb * params.h + bidh) * params.NUM_ROWS + m_block) * params.NNZ_V;
         // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-        // tKgKBlock.data() = tKgKBlockData + blks_ptr[0] * int64_t(params.k_row_stride);
-        // flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV,
-        //                                    binfo.actual_seqlen_k - blks_ptr[0]);
         #pragma unroll
         for (int m = 0; m < size<1>(tKgKToken); ++m) {
             if (Is_even_MN || get<0>(tKVcKV(0, m, 0)) < num_cols) {  // Is_even_MN
@@ -445,7 +438,7 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
                 for (int k = 0; k < size<2>(tKgKToken); ++k) {
                     if (Is_even_K || tKVpKV(k)) {
                         cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
-                    } else if (true) {  // Clear_OOB_K
+                    } else {  // Clear_OOB_K
                         cute::clear(tKsK(_, m, k));
                     }
                 }
@@ -463,7 +456,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
 
             // Advance gV
             if (n < num_cols_block - 1) {
-                // flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV);
                 #pragma unroll
                 for (int m = 0; m < size<1>(tVgVToken); ++m) {
                     if (true) {  // Is_even_MN
@@ -480,9 +472,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
                 }
             } else {
                 // Clear the smem tiles to account for predicated off loads
-                // flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
-                //     gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - start_n
-                // );
                 #pragma unroll
                 for (int m = 0; m < size<1>(tVgVToken); ++m) {
                     if (Is_even_MN || n * kBlockN + get<0>(tKVcKV(0, m, 0)) < num_cols) {  // Is_even_MN
@@ -491,11 +480,11 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
                         for (int k = 0; k < size<2>(tVgVToken); ++k) {
                             if (Is_even_K || tKVpKV(k)) {
                                 cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k));
-                            } else if (true) {  // Clear_OOB_K
+                            } else {  // Clear_OOB_K
                                 cute::clear(tVsV(_, m, k));
                             }
                         }
-                    } else if (true) {  // Clear_OOB_MN
+                    } else {  // Clear_OOB_MN
                         cute::clear(tVsV(_, m, _));
                     }
                 }
@@ -511,9 +500,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
                 flash::apply_softcap(acc_s, params.softcap);
             }
 
-            // mask.template apply_mask<Is_causal, Is_even_MN>(
-            //     acc_s, cols_ptr[n * kBlockN], m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
-            // );
             if (n >= num_cols_block - n_masking_steps) {
                 Tensor tensor = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
                 const int lane_id = threadIdx.x % 32;
@@ -546,8 +532,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
             flash::cp_async_wait<0>();
             __syncthreads();
             if (n < num_cols_block - 2) {
-                // tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index + 1] * int64_t(params.k_row_stride);
-                // flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV);
                 #pragma unroll
                 for (int m = 0; m < size<1>(tKgKToken); ++m) {
                     if (true) {  // Is_even_MN
@@ -567,9 +551,6 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
                 // isn't right and we get race conditions.
                 cute::cp_async_fence();
             } else if (n == num_cols_block - 2) {
-                // tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index + 1] * int64_t(params.k_row_stride);
-                // flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV,
-                //                                    binfo.actual_seqlen_k - blks_ptr[block_index + 1]);
                 #pragma unroll
                 for (int m = 0; m < size<1>(tKgKToken); ++m) {
                     if (Is_even_MN || (n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0)) < num_cols) {  // Is_even_MN

tests/test_vllm_flash_attn.py

Lines changed: 7 additions & 4 deletions
@@ -272,12 +272,14 @@ def test_varlen_with_paged_kv(
 @pytest.mark.parametrize("num_heads", [1, 2, 4])
 @pytest.mark.parametrize("head_size", [64, 128, 256])
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("NNZ_S", [1, 2, 3, 7, 15, 32])
 @torch.inference_mode()
 def test_sparse_attention(
     seq_lens: List[Tuple[int, int]],
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
+    NNZ_S: int,
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -295,12 +297,13 @@ def test_sparse_attention(
         batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
     )
     NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
-    NNZ_S = seqlen_k // block_size_M // 2
-    NNZ_V = seqlen_k - NNZ_S * block_size_M
+    if NNZ_S * block_size_N > seqlen_k:
+        return
+    NNZ_V = seqlen_k - NNZ_S * block_size_N
     block_count = torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS)
     column_count = torch.tensor([NNZ_V] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS)
-    block_offset = torch.tensor([[i * block_size_M for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
-    column_index = torch.tensor([[NNZ_S * block_size_M + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
+    block_offset = torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
+    column_index = torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
     from vllm_flash_attn import sparse_attn_func, flash_attn_func
     out, lse = sparse_attn_func(
         q,
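The parametrized NNZ_S only changes how the keys are split between the slash blocks and the vertical columns; together they still cover every key position, which is what lets the test compare sparse_attn_func against dense flash_attn_func. A standalone sketch of that coverage invariant (block_size_N = 64 and the other values are illustrative, not taken from the test file):

# Keys [0, NNZ_S * block_size_N) are covered by block_offset, the rest by column_index.
block_size_N = 64
seqlen_k = 512
NNZ_S = 3                                # one of the newly parametrized values
NNZ_V = seqlen_k - NNZ_S * block_size_N  # remaining keys, addressed one column at a time

block_keys = [b + i
              for b in range(0, NNZ_S * block_size_N, block_size_N)
              for i in range(block_size_N)]
column_keys = [NNZ_S * block_size_N + i for i in range(NNZ_V)]

# Blocks and columns partition the full key range, so the sparse result should
# match dense attention on the same inputs.
assert sorted(block_keys + column_keys) == list(range(seqlen_k))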

vllm_flash_attn/flash_attn_interface.py

Lines changed: 5 additions & 2 deletions
@@ -156,10 +156,12 @@ def sparse_attn_func(
     return_softmax_lse=False,
     out=None,
 ):
-    """Compute attention with virtical and slash sparsity patterns.
+    """Compute attention with vertical and slash sparsity patterns.
     Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
     block_count and block_offset for slash sparsity patterns, and
-    column_count and column_index for virtical sparsity patterns.
+    column_count and column_index for vertical sparsity patterns.
+    For more details please refer to MInference
+    (Paper: https://arxiv.org/abs/2407.02490, Code: https://github.com/microsoft/MInference).
 
     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
@@ -174,6 +176,7 @@ def sparse_attn_func(
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+            Sliding window is not supported for sparse_attn_func, so only (-1, -1) is valid.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
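For reference, a minimal call sketch that matches the shapes the docstring describes. The positional argument order and the return value are assumptions based on the test above; NUM_ROWS follows the new ceil(seqlen_q / 64) requirement and window_size stays at its (-1, -1) default:

import torch
from vllm_flash_attn import sparse_attn_func

batch, heads, seqlen, headdim = 1, 2, 128, 64
NUM_ROWS, NNZ_S, NNZ_V = 2, 1, 64        # illustrative sizes: ceil(128 / 64) = 2 rows

q = torch.randn(batch, seqlen, heads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, heads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, heads, headdim, device="cuda", dtype=torch.float16)

# Slash pattern: every row block attends to one 64-wide key block starting at offset 0.
block_count = torch.full((batch, heads, NUM_ROWS), NNZ_S, dtype=torch.int32, device="cuda")
block_offset = torch.zeros(batch, heads, NUM_ROWS, NNZ_S, dtype=torch.int32, device="cuda")
# Vertical pattern: every row block also attends to individual key columns 64..127.
column_count = torch.full((batch, heads, NUM_ROWS), NNZ_V, dtype=torch.int32, device="cuda")
column_index = torch.arange(64, 128, dtype=torch.int32, device="cuda").repeat(batch, heads, NUM_ROWS, 1)

out = sparse_attn_func(q, k, v, block_count, block_offset, column_count, column_index)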
