
Commit a14a552

add back flash_attn_func api (and support FA3)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent d4e0903 commit a14a552

File tree: 6 files changed (+227, -15 lines)


CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -231,7 +231,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
         # FLASHATTENTION_DISABLE_LOCAL
         FLASHATTENTION_DISABLE_PYBIND
         FLASHATTENTION_DISABLE_FP8 # TODO Enable FP8
-        FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size
     )
 elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0)
     message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.")

csrc/flash_attn/flash_api_torch_lib.cpp

Lines changed: 21 additions & 0 deletions
@@ -14,6 +14,21 @@ namespace FLASH_NAMESPACE {

 ////////////////////////////// From flash_api.cpp //////////////////////////////

+std::vector<at::Tensor>
+mha_fwd(at::Tensor &q,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        const at::Tensor &k,    // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        const at::Tensor &v,    // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+        std::optional<at::Tensor> &out_,           // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+        std::optional<at::Tensor> &alibi_slopes_,  // num_heads or batch_size x num_heads
+        const float p_dropout,
+        const float softmax_scale,
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        const float softcap,
+        const bool return_softmax,
+        std::optional<at::Generator> gen_);
+
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
                const at::Tensor &k,   // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
@@ -105,6 +120,12 @@ mha_varlen_fwd_sparse(at::Tensor &q, // total_q x num_heads x head_size, total_
  * Torch Library Registration
  */
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
+          "float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
+          "float softcap, bool return_softmax, Generator? gen)"
+          "-> Tensor[]");
+  ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
+
   ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
           "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, "
           "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "

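For orientation, the fwd op registered above is what the Python wrapper added later in this commit (vllm_flash_attn/flash_attn_interface.py, below) dispatches to. The following is a minimal sketch of calling the raw FA2 op directly, assuming the _vllm_fa2_C extension has been built and is loaded on package import; the argument order follows the registered schema, and the tensor shapes are illustrative only.

# Sketch only: direct call to the op registered above (FA2 build assumed);
# in practice the flash_attn_func wrapper added below is the intended entry point.
import torch
import vllm_flash_attn  # assumed to load the _vllm_fa2_C extension on import

q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 2, 64, dtype=torch.float16, device="cuda")

out, softmax_lse = torch.ops._vllm_fa2_C.fwd(
    q, k, v,
    None,                 # out
    None,                 # alibi_slopes
    0.0,                  # p_dropout
    q.shape[-1] ** -0.5,  # softmax_scale
    True,                 # is_causal
    -1, -1,               # window_size_left, window_size_right
    0.0,                  # softcap
    False,                # return_softmax
    None,                 # gen
)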
hopper/static_switch.h

Lines changed: 0 additions & 8 deletions
@@ -117,14 +117,6 @@
     constexpr static bool CONST_NAME = false;                               \
     return __VA_ARGS__();                                                   \
   }()
-#elif defined(FLASHATTENTION_VARLEN_ONLY)
-#define VARLEN_SWITCH(COND, CONST_NAME, ...)                                \
-  [&] {                                                                     \
-    TORCH_CHECK(COND, "This flash attention build only supports varlen "    \
-                      "(for build size reasons).");                         \
-    constexpr static bool CONST_NAME = true;                                \
-    return __VA_ARGS__();                                                   \
-  }()
 #else
 #define VARLEN_SWITCH BOOL_SWITCH
 #endif

tests/test_vllm_flash_attn.py

Lines changed: 96 additions & 6 deletions
@@ -11,9 +11,11 @@
 from einops import rearrange, repeat

 from vllm_flash_attn.flash_attn_interface import (
+    flash_attn_func,
     flash_attn_varlen_func,
     flash_attn_with_kvcache,
-    is_fa_version_supported
+    is_fa_version_supported,
+    fa_version_unsupported_reason
 )

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@@ -23,15 +25,49 @@
 # one value large enough to test overflow in index calculation.
 # one value small enough to test the schema op check
 NUM_BLOCKS = [32768, 2048]
-VERSIONS = \
-    ([2] if is_fa_version_supported(2) else []) + \
-    ([3] if is_fa_version_supported(3) else [])
+VERSIONS = [2, 3]
+
+
+def construct_local_mask(
+    seqlen_q,
+    seqlen_k,
+    window_size=(-1, -1),  # -1 means infinite window size
+    query_padding_mask=None,
+    key_padding_mask=None,
+    device=None,
+    key_leftpad=None,
+):
+    row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
+    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
+    if key_leftpad is not None:
+        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
+        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
+        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
+    sk = (
+        seqlen_k
+        if key_padding_mask is None
+        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    sq = (
+        seqlen_q
+        if query_padding_mask is None
+        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
+    )
+    if window_size[0] < 0:
+        return col_idx > row_idx + sk - sq + window_size[1]
+    else:
+        sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
+        return torch.logical_or(
+            col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
+            col_idx < row_idx + sk - sq - window_size[0],
+        )


 def ref_attn(
     q,
     k,
     v,
+    scale,
     query_padding_mask=None,
     key_padding_mask=None,
     attn_bias=None,
@@ -74,10 +110,11 @@ def ref_attn(
     k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
     v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
     d = q.shape[-1]
+    q *= scale
     if not reorder_ops:
-        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+        scores = torch.einsum("bthd,bshd->bhts", q, k)
     else:
-        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+        scores = torch.einsum("bthd,bshd->bhts", q, k)

     lse_ref = scores.logsumexp(dim=-1)

@@ -178,6 +215,59 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)


+@pytest.mark.parametrize("seq_len", [1, 10, 256, 533])
+@pytest.mark.parametrize("batch_size", [1, 7, 32])
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@pytest.mark.parametrize("fa_version", VERSIONS)
+@torch.inference_mode()
+def test_flash_attn(
+    seq_len: int,
+    batch_size: int,
+    num_heads: Tuple[int, int],
+    head_size: int,
+    dtype: torch.dtype,
+    soft_cap: Optional[float],
+    fa_version: int,
+) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    num_query_heads = num_heads[0]
+    num_kv_heads = num_heads[1]
+    assert num_query_heads % num_kv_heads == 0
+    scale = head_size**-0.5
+
+    query = torch.randn(
+        batch_size, seq_len, num_query_heads, head_size, dtype=dtype)
+    key = torch.randn(
+        batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)
+    value = torch.randn(
+        batch_size, seq_len, num_kv_heads, head_size, dtype=dtype)
+
+    output = flash_attn_func(
+        query,
+        key,
+        value,
+        softmax_scale=scale,
+        causal=True,
+        softcap=soft_cap if soft_cap is not None else 0,
+        fa_version=fa_version,
+    )
+
+    ref_output, _ = ref_attn(
+        q=query,
+        k=key,
+        v=value,
+        scale=scale,
+        causal=True,
+        softcap=soft_cap if soft_cap is not None else 0,
+    )
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
+        f"{torch.max(torch.abs(output - ref_output))}"
+
+
 @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)

vllm_flash_attn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@

 # Use relative import to support build-from-source installation in vLLM
 from .flash_attn_interface import (
+    flash_attn_func,
     flash_attn_varlen_func,
     flash_attn_with_kvcache,
     sparse_attn_func,

vllm_flash_attn/flash_attn_interface.py

Lines changed: 109 additions & 0 deletions
@@ -73,6 +73,115 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x


+def flash_attn_func(
+    q,
+    k,
+    v,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    *,
+    return_softmax_lse=False,
+    out=None,
+    fa_version: int = DEFAULT_FA_VERSION,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
+    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+    """
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+
+    if fa_version == 2:
+        q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+        out, softmax_lse = torch.ops._vllm_fa2_C.fwd(
+            q,
+            k,
+            v,
+            out,
+            alibi_slopes,
+            dropout_p,
+            softmax_scale,
+            causal,
+            window_size[0], window_size[1],
+            softcap,
+            return_softmax_lse and dropout_p > 0,
+            None,
+        )
+    elif fa_version == 3:
+        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
+            q, k, v,
+            None, None,  # k_new, v_new
+            out,
+            None, None,  # cu_seqlens_q, cu_seqlens_k
+            None,  # cu_seqlens_k_new
+            None, None,  # seqused_q, seqused_k
+            None, None,  # max_seqlen_q, max_seqlen_k
+            None,
+            alibi_slopes,
+            None,  # kv_batch_idx
+            None, None,  # rotary_cos, rotary_sin
+            None, None, None,  # q_descale, k_descale, v_descale
+            softmax_scale,
+            causal,
+            window_size[0], window_size[1],
+            0,  # sink_token_length
+            softcap,
+            True,  # rotary_interleaved
+            0,  # num_splits
+            None,  # pack_gqa
+            0,  # sm_margin
+        )
+
+    return (out, softmax_lse) if return_softmax_lse else out
+
+
 def flash_attn_varlen_func(
     q,
     k,
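With this change, flash_attn_func is importable from the package root again (see the __init__.py change above). Below is a minimal usage sketch; the GQA shapes follow the docstring's example layout (8 query heads sharing 2 KV heads, head_dim 64), and the concrete batch/sequence sizes plus the fa_version choice are assumptions about the local setup, not part of this commit.

import torch
from vllm_flash_attn import flash_attn_func

# Illustrative GQA shapes: (batch, seqlen, nheads, headdim) for q, fewer heads for k/v.
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 2, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 2, 64, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, causal=True)  # -> (2, 128, 8, 64)

# Optionally also return the log-sum-exp; fa_version=3 assumes an FA3-capable GPU.
out, lse = flash_attn_func(q, k, v, causal=True,
                           return_softmax_lse=True, fa_version=3)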
