
Commit 6e1f8b6

Authored by Minmin Sun <minmin.smm@alibaba-inc.com>
Implements the attention kernel with the vertical-and-slash sparse pattern described in Appendix C.4.2 of https://arxiv.org/abs/2407.02490, exposed as sparse_attn_func (#33)
* Add sparse attention with vertical and slash pattern
* update
* move sparse_attn to new files
* Refine
* remove registering as custom op
* address review comments
* remove window_size and useless code
* only keep hdim128
* add seqlen_q=1 in ut and remove useless code
* support batch_size > 1
* add interface sparse_attn_varlen_func
* remove useless code

Signed-off-by: Minmin Sun <minmin.smm@alibaba-inc.com>
Parent: 9dbad20

12 files changed: +1715 -1 lines
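
For orientation, here is a minimal usage sketch of the new sparse_attn_func. It is not a documented API: the shapes, the 64-wide tile size, and the "cover every key" pattern simply mirror the new tests in tests/test_vllm_flash_attn.py later in this diff, and the dtype, sequence lengths, and nnz_s value are illustrative choices.

# Sketch only: mirrors the layouts exercised by tests/test_vllm_flash_attn.py below.
# Requires a CUDA GPU and a vllm_flash_attn build that includes this commit.
import torch
from vllm_flash_attn import sparse_attn_func

torch.set_default_device("cuda")
batch, heads, head_dim = 1, 2, 128        # only hdim128 kernels are kept in this commit
seqlen_q, seqlen_k = 1023, 2049
block_M = block_N = 64                    # query/key tile sizes used by the tests
num_rows = (seqlen_q + block_M - 1) // block_M   # one metadata row per 64-row query block
nnz_s = 7                                        # "slash" KV blocks kept per query block
nnz_v = seqlen_k - nnz_s * block_N               # remaining keys kept as "vertical" columns

q = torch.randn(batch, seqlen_q, heads, head_dim, dtype=torch.float16)
k = torch.randn(batch, seqlen_k, heads, head_dim, dtype=torch.float16)
v = torch.randn_like(k)

# block_count/block_offset: how many 64-wide KV blocks each query block attends to, and the
# starting key index of each block. column_count/column_index: individually selected key columns.
block_count = torch.full((batch, heads, num_rows), nnz_s, dtype=torch.int32)
block_offset = (torch.arange(nnz_s, dtype=torch.int32) * block_N).expand(
    batch, heads, num_rows, nnz_s).contiguous()
column_count = torch.full((batch, heads, num_rows), nnz_v, dtype=torch.int32)
column_index = (nnz_s * block_N + torch.arange(nnz_v, dtype=torch.int32)).expand(
    batch, heads, num_rows, nnz_v).contiguous()

# This particular pattern covers every key, so the result should match dense attention.
out, lse = sparse_attn_func(q, k, v, block_count, block_offset, column_count, column_index,
                            return_softmax_lse=True)
print(out.shape, lse.shape)

The varlen variant, sparse_attn_varlen_func, takes the same four index tensors plus the usual cu_seqlens_q/cu_seqlens_k and max_seqlen arguments, as the second new test shows.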

csrc/flash_attn/flash_api.cpp

Lines changed: 463 additions & 0 deletions (large diff not rendered)

csrc/flash_attn/src/flash.h

Lines changed: 10 additions & 0 deletions
@@ -142,6 +142,15 @@ struct Flash_fwd_params : public Qkv_params {

     bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
     bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
+
+    // For sparse attention
+    const int* block_count;
+    const int* block_offset;
+    const int* column_count;
+    const int* column_index;
+    int NUM_ROWS;
+    int NNZ_S;
+    int NNZ_V;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -189,6 +198,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_sparse_(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_sparse_launch_template.h"
+
+template<>
+void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_sparse_launch_template.h"
+
+template<>
+void run_mha_fwd_sparse_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_sparse_hdim128<cutlass::bfloat16_t, false>(params, stream);
+}

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_sparse_launch_template.h"
+
+template<>
+void run_mha_fwd_sparse_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_sparse_hdim128<cutlass::half_t, true>(params, stream);
+}

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_sparse_launch_template.h"
+
+template<>
+void run_mha_fwd_sparse_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_sparse_hdim128<cutlass::half_t, false>(params, stream);
+}

csrc/flash_attn/src/flash_fwd_sparse_kernel.h

Lines changed: 685 additions & 0 deletions (large diff not rendered)

Lines changed: 125 additions & 0 deletions

@@ -0,0 +1,125 @@
+/******************************************************************************
+ * Copyright (c) 2024, PAI, Alibaba Cloud.
+ ******************************************************************************/
+
+#pragma once
+
+#include "flash_fwd_launch_template.h"
+#include "flash_fwd_sparse_kernel.h"
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_sparse_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        static_assert(!(Is_causal && Is_local)); // Enforce constraints
+        flash::compute_sparse_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
+void run_flash_sparse_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr size_t smem_size = Kernel_traits::kSmemSize;
+    // printf("smem_size = %d\n", smem_size);
+
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
+
+    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+    dim3 grid(num_m_block, params.b, params.h);
+    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+    const bool return_softmax = params.p_ptr != nullptr;
+    EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+        BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    constexpr bool IsEvenMNConst = false;
+                    constexpr bool Is_local = false;
+                    // Will only return softmax if dropout, to reduce compilation time.
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_sparse_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                    // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
+            });
+        });
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 32;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 64;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 96;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 128;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 160;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 192;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 224;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_sparse_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 256;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_sparse_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
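
For reference, the launch geometry in run_flash_sparse_fwd above is a plain ceiling division of the query length by the query tile, with one CTA per (query tile, batch, head). The sketch below reproduces that arithmetic in Python; it assumes kBlockM = 64, as in the Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T> instantiations used by the hdim launchers, and the example numbers are illustrative.

# Sketch of the grid computed by run_flash_sparse_fwd, assuming kBlockM = 64.
def sparse_fwd_grid(seqlen_q: int, batch: int, nheads: int, block_m: int = 64):
    num_m_block = (seqlen_q + block_m - 1) // block_m   # ceil(seqlen_q / kBlockM)
    return (num_m_block, batch, nheads)                 # dim3 grid(num_m_block, params.b, params.h)

print(sparse_fwd_grid(seqlen_q=1023, batch=2, nheads=4))  # (16, 2, 4): 16 query tiles per (batch, head)

The cudaFuncSetAttribute call in the launcher is the usual opt-in for dynamic shared memory above the 48 KB default.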

csrc/flash_attn/src/generate_kernels.py

Lines changed: 13 additions & 1 deletion
@@ -25,6 +25,14 @@
 }}
 """

+KERNEL_IMPL_TEMPLATE_FWD_SPARSE = """#include "flash_fwd_sparse_launch_template.h"
+
+template<>
+void run_mha_fwd_sparse_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
+    run_mha_fwd_sparse_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
+}}
+"""
+
 KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

 template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);
@@ -53,6 +61,10 @@ def template(self) -> str:
             return KERNEL_IMPL_TEMPLATE_FWD.format(
                 DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
             )
+        elif self.direction == "fwd_sparse":
+            return KERNEL_IMPL_TEMPLATE_FWD_SPARSE.format(
+                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal
+            )
         elif self.direction == "bwd":
             return KERNEL_IMPL_TEMPLATE_BWD.format(
                 DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
@@ -68,7 +80,7 @@ def filename(self) -> str:


 def get_all_kernels() -> List[Kernel]:
-    for direction in ["fwd", "fwd_split"]:
+    for direction in ["fwd", "fwd_split", "fwd_sparse"]:
         for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
             yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)
     for direction in ["bwd"]:

tests/test_vllm_flash_attn.py

Lines changed: 151 additions & 0 deletions
@@ -268,3 +268,154 @@ def test_varlen_with_paged_kv(
     )
     torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("seq_lens", [(1, 1), (1, 1024), (1, 2048), (1023, 2049), (1023, 1023), (32, 32), (65, 65), (129, 129)])
+@pytest.mark.parametrize("num_heads", [1, 2, 4])
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
+@torch.inference_mode()
+def test_sparse_attention(
+    batch_size: int,
+    seq_lens: Tuple[int, int],
+    num_heads: int,
+    head_size: int,
+    dtype: torch.dtype,
+    NNZ_S: int,
+) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    block_size_M = 64
+    block_size_N = 64
+    seqlen_q, seqlen_k = seq_lens
+    q = torch.randn(
+        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
+    )
+    k = torch.randn(
+        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
+    )
+    v = torch.randn(
+        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
+    )
+    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
+    if NNZ_S * block_size_N > seqlen_k:
+        return
+    NNZ_V = seqlen_k - NNZ_S * block_size_N
+    block_count = torch.tensor([NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS)
+    column_count = torch.tensor([NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS)
+    block_offset = torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
+    column_index = torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
+    from vllm_flash_attn import sparse_attn_func, flash_attn_func
+    out, lse = sparse_attn_func(
+        q,
+        k,
+        v,
+        block_count,
+        block_offset,
+        column_count,
+        column_index,
+        return_softmax_lse=True,
+    )
+
+    ref_out, ref_lse = flash_attn_func(
+        q,
+        k,
+        v,
+        return_softmax_lse=True,
+    )
+
+    torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(out - ref_out))}"
+    torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(lse - ref_lse))}"
+
+
+@pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
+                                      [(1024, 1328), (1, 2048)],
+                                      [(1025, 1328), (2, 2048)],
+                                      [(1025, 2049), (2, 1281)],
+                                      ])
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("dtype", DTYPES)
+@torch.inference_mode()
+def test_sparse_attention_varlen(
+    seq_lens: List[Tuple[int, int]],
+    head_size: int,
+    dtype: torch.dtype,
+) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    block_size_M = 64
+    block_size_N = 64
+    num_seqs = len(seq_lens)
+    query_lens = [x[0] for x in seq_lens]
+    kv_lens = [x[1] for x in seq_lens]
+    num_heads = 1
+    query = torch.randn(sum(query_lens),
+                        num_heads,
+                        head_size,
+                        dtype=dtype)
+    key = torch.randn(sum(kv_lens),
+                      num_heads,
+                      head_size,
+                      dtype=dtype)
+    value = torch.randn_like(key)
+    cu_query_lens = torch.tensor([0] + query_lens,
+                                 dtype=torch.int32).cumsum(dim=0,
+                                                           dtype=torch.int32)
+    cu_kv_lens = torch.tensor([0] + kv_lens,
+                              dtype=torch.int32).cumsum(dim=0,
+                                                        dtype=torch.int32)
+    max_query_len = max(query_lens)
+    max_kv_len = max(kv_lens)
+
+    NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
+    NNZ_S = 20
+    NNZ_V = 2048
+    batch_size = len(query_lens)
+
+    block_counts = []
+    column_counts = []
+    block_offsets = []
+    column_indices = []
+    for b in range(batch_size):
+        block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
+        columns = kv_lens[b] - NNZ_S * block_size_N
+        column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
+        block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
+        column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
+    block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
+    column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
+    block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
+    column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
+    from vllm_flash_attn import sparse_attn_varlen_func, flash_attn_varlen_func
+    out, lse = sparse_attn_varlen_func(
+        query,
+        key,
+        value,
+        block_count,
+        block_offset,
+        column_count,
+        column_index,
+        cu_seqlens_q=cu_query_lens,
+        cu_seqlens_k=cu_kv_lens,
+        max_seqlen_q=max_query_len,
+        max_seqlen_k=max_kv_len,
+        return_softmax_lse=True,
+    )
+
+    ref_out, ref_lse = flash_attn_varlen_func(
+        query,
+        key,
+        value,
+        cu_seqlens_q=cu_query_lens,
+        cu_seqlens_k=cu_kv_lens,
+        max_seqlen_q=max_query_len,
+        max_seqlen_k=max_kv_len,
+        return_softmax_lse=True,
+    )
+
+    torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(out - ref_out))}"
+    torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(lse - ref_lse))}"

vllm_flash_attn/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 # Use relative import to support build-from-source installation in vLLM
 from .flash_attn_interface import (
     flash_attn_func,
+    sparse_attn_func,
+    sparse_attn_varlen_func,
     flash_attn_varlen_func,
     flash_attn_with_kvcache,
 )
