
Commit e0faa9a

update binding
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent 1012435 commit e0faa9a

4 files changed: +181 −10 lines

CMakeLists.txt

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.26)
 
-project(vllm_flash_attn LANGUAGES CXX)
+project(vllm_flash_attn LANGUAGES CXX CUDA)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_EXTENSIONS OFF)
 
@@ -213,7 +213,9 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     SRCS "${FA3_GEN_SRCS}"
     CUDA_ARCHS "${FA3_ARCHS}")
   set_gencode_flags_for_srcs(
-    SRCS "hopper/flash_fwd_combine.cu"
+    SRCS
+      hopper/flash_fwd_combine.cu
+      hopper/flash_prepare_scheduler.cu
     CUDA_ARCHS "${FA3_ARCHS}")
 endif()
 
@@ -223,6 +225,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     LANGUAGE ${VLLM_GPU_LANG}
     SOURCES
       hopper/flash_fwd_combine.cu
+      hopper/flash_prepare_scheduler.cu
      hopper/flash_api.cpp
      hopper/flash_api_torch_lib.cpp
      ${FA3_GEN_SRCS}

hopper/flash_api_torch_lib.cpp

Lines changed: 67 additions & 3 deletions
@@ -9,6 +9,14 @@
  * Externs for the flash_attn ops to be exposed as a pytorch library
  */
 
+// b: batch_size
+// b_k: batch_size_k
+// s_q: seqlen_q
+// s_k: seqlen_k
+// s_k_new: seqlen_k_new
+// h: num_heads
+// h_k: num_heads_k
+// d: head_size
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
         const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
@@ -37,12 +45,41 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         bool is_causal,
         int window_size_left,
         int window_size_right,
-        int sink_token_length,
         float const softcap,
         bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+        std::optional<at::Tensor> &scheduler_metadata_, // (b + 1)
         int num_splits,
         std::optional<bool> pack_gqa_,
-        int const sm_margin);
+        int const sm_margin
+);
+
+// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
+at::Tensor
+mha_fwd_get_scheduler_metadata(
+        int batch_size,
+        int max_seqlen_q,
+        int max_seqlen_k,
+        int num_heads,
+        int num_heads_k,
+        int headdim,
+        int headdim_v,
+        at::ScalarType qkv_dtype,
+        const at::Tensor &seqused_k, // b
+        std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
+        std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
+        std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
+        std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
+        std::optional<const at::Tensor> &leftpad_k_, // b
+        std::optional<int> page_size,
+        int max_seqlen_k_new, // 0 means we're not appending new KV
+        bool is_causal,
+        int window_size_left,
+        int window_size_right,
+        bool has_softcap,
+        int num_splits,
+        std::optional<bool> pack_gqa_,
+        int const sm_margin
+);
 
 /**
  * Torch Library Registration
@@ -74,13 +111,40 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
             " bool is_causal,"
             " int window_size_left,"
             " int window_size_right,"
-            " int sink_token_length,"
             " float softcap,"
             " bool is_rotary_interleaved,"
+            " Tensor? scheduler_metadata,"
             " int num_splits,"
             " bool? pack_gqa,"
             " int sm_margin) -> Tensor[]");
     ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
+
+    ops.def("get_scheduler_metadata("
+            " int batch_size,"
+            " int max_seqlen_q,"
+            " int max_seqlen_k,"
+            " int num_heads,"
+            " int num_heads_k,"
+            " int headdim,"
+            " int headdim_v,"
+            " ScalarType qkv_dtype,"
+            " Tensor seqused_k,"
+            " Tensor? cu_seqlens_q,"
+            " Tensor? cu_seqlens_k,"
+            " Tensor? cu_seqlens_k_new,"
+            " Tensor? seqused_q,"
+            " Tensor? leftpad_k,"
+            " int? page_size,"
+            " int max_seqlen_k_new," // 0 means we're not appending new KV
+            " bool is_causal,"
+            " int window_size_left,"
+            " int window_size_right,"
+            " bool has_softcap,"
+            " int num_splits,"
+            " bool? pack_gqa,"
+            " int sm_margin) -> Tensor");
+    ops.impl("get_scheduler_metadata", torch::kCUDA,
+             make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME);
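
For orientation (not part of the diff): a minimal sketch of calling the newly registered op directly through torch.ops, with made-up shapes and values. In practice the get_scheduler_metadata() wrapper added in vllm_flash_attn/flash_attn_interface.py below passes these arguments in the order defined by the schema above.

import torch

# Hypothetical decode-only batch: 4 sequences, 16 query heads, 2 KV heads, head size 128.
seqused_k = torch.full((4,), 1024, dtype=torch.int32, device="cuda")

metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata(
    4, 1, 1024,         # batch_size, max_seqlen_q, max_seqlen_k
    16, 2, 128, 128,    # num_heads, num_heads_k, headdim, headdim_v
    torch.bfloat16,     # qkv_dtype
    seqused_k,          # seqused_k (i.e. cache_seqlens)
    None, None, None,   # cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new
    None, None,         # seqused_q, leftpad_k
    None,               # page_size
    0,                  # max_seqlen_k_new (0 means no new KV is appended)
    True,               # is_causal
    -1, -1,             # window_size_left, window_size_right
    False,              # has_softcap
    0,                  # num_splits
    None,               # pack_gqa
    0,                  # sm_margin
)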

tests/test_vllm_flash_attn.py

Lines changed: 45 additions & 1 deletion
@@ -13,7 +13,8 @@
 from vllm_flash_attn.flash_attn_interface import (
     flash_attn_varlen_func,
     flash_attn_with_kvcache,
-    is_fa_version_supported
+    get_scheduler_metadata,
+    is_fa_version_supported,
 )
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
@@ -185,6 +186,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("aot_schedule", [True, False])
 @pytest.mark.parametrize("fa_version", VERSIONS)
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
@@ -195,6 +197,7 @@ def test_flash_attn_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    aot_schedule: bool,
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
@@ -221,6 +224,24 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
+    scheduler_metadata = None
+    if aot_schedule:
+        if fa_version == 2:
+            pytest.skip("AOT schedule is not supported in version 2")
+        scheduler_metadata = get_scheduler_metadata(
+            batch_size=num_seqs,
+            max_seqlen_q=1,
+            max_seqlen_k=max_kv_len,
+            num_heads_q=num_query_heads,
+            num_heads_kv=num_kv_heads,
+            headdim=head_size,
+            cache_seqlens=kv_lens_tensor,
+            qkv_dtype=dtype,
+            causal=True,
+            window_size=(-1, -1),
+            has_softcap=soft_cap is not None
+        )
+
     output = flash_attn_with_kvcache(
         query.unsqueeze(1),
         key_cache,
@@ -230,6 +251,7 @@ def test_flash_attn_with_paged_kv(
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
+        scheduler_metadata=scheduler_metadata,
         fa_version=fa_version,
     ).squeeze(1)
 
@@ -255,6 +277,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("aot_schedule", [True, False])
 @pytest.mark.parametrize("fa_version", VERSIONS)
 @torch.inference_mode()
 def test_varlen_with_paged_kv(
@@ -266,6 +289,7 @@ def test_varlen_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    aot_schedule: bool,
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
@@ -303,6 +327,25 @@ def test_varlen_with_paged_kv(
                                  num_blocks,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
+
+    scheduler_metadata = None
+    if aot_schedule:
+        if fa_version == 2:
+            pytest.skip("AOT schedule is not supported in version 2")
+        scheduler_metadata = get_scheduler_metadata(
+            batch_size=num_seqs,
+            max_seqlen_q=1,
+            max_seqlen_k=max_kv_len,
+            num_heads_q=num_query_heads,
+            num_heads_kv=num_kv_heads,
+            headdim=head_size,
+            cache_seqlens=seqused_k,
+            qkv_dtype=dtype,
+            causal=True,
+            window_size=(-1, -1),
+            has_softcap=soft_cap is not None
+        )
+
     output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
@@ -316,6 +359,7 @@ def test_varlen_with_paged_kv(
         window_size=window_size,
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
+        scheduler_metadata=scheduler_metadata,
         fa_version=fa_version
     )

vllm_flash_attn/flash_attn_interface.py

Lines changed: 64 additions & 4 deletions
@@ -73,6 +73,48 @@ def fa_version_unsupported_reason(fa_version: int, device = None) \
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
+# NOTE only used in FA3
+def get_scheduler_metadata(
+    batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
+    cache_seqlens: torch.Tensor,
+    qkv_dtype=torch.bfloat16,
+    headdim_v=None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    cache_leftpad: Optional[torch.Tensor] = None,
+    page_size: Optional[int] = None,
+    max_seqlen_k_new=0,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    has_softcap=False,
+    num_splits=0,  # Can be tuned for speed
+    pack_gqa=None,  # Can be tuned for speed
+    sm_margin=0,  # Can be tuned if some SMs are used for communication
+):
+    cache_seqlens = maybe_contiguous(cache_seqlens)
+    if headdim_v is None:
+        headdim_v = headdim
+    scheduler_metadata = torch.ops._vllm_fa3_C.get_scheduler_metadata(
+        batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
+        qkv_dtype,
+        cache_seqlens,
+        cu_seqlens_q,
+        None,  # cu_seqlens_k
+        cu_seqlens_k_new,
+        None,  # seqused_q
+        cache_leftpad,
+        page_size,
+        max_seqlen_k_new,
+        causal,
+        window_size[0], window_size[1],
+        has_softcap,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+
+    return scheduler_metadata
+
 
 def flash_attn_varlen_func(
     q,
@@ -95,10 +137,13 @@ def flash_attn_varlen_func(
     block_table=None,
     return_softmax_lse=False,
     out=None,
-    fa_version: int = DEFAULT_FA_VERSION,
+    # FA3 Only
+    scheduler_metadata=None,
     q_descale=None,
     k_descale=None,
     v_descale=None,
+    # Version selector
+    fa_version: int = DEFAULT_FA_VERSION,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -173,6 +218,12 @@ def flash_attn_varlen_func(
         dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)
 
     if fa_version == 2:
+        if scheduler_metadata is not None and q_descale is not None \
+                and k_descale is not None and v_descale is not None:
+            raise NotImplementedError(
+                "FA2 does not support scheduler_metadata, q_descale, "
+                "k_descale, v_descale"
+            )
         out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
             q, k, v,
             out,
@@ -216,9 +267,9 @@ def flash_attn_varlen_func(
             softmax_scale,
             causal,
             real_window_size[0], real_window_size[1],
-            0, # sink_token_length
             softcap,
             True, # rotary_interleaved
+            scheduler_metadata,
             0, # num_splits
             None, # pack_gqa
             0, # sm_margin
@@ -250,10 +301,13 @@ def flash_attn_with_kvcache(
     return_softmax_lse=False,
     *,
     out=None,
-    fa_version: int = DEFAULT_FA_VERSION,
+    # FA3 Only
+    scheduler_metadata=None,
     q_descale=None,
     k_descale=None,
     v_descale=None,
+    # Version selector
+    fa_version: int = DEFAULT_FA_VERSION,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -355,6 +409,12 @@ def flash_attn_with_kvcache(
     block_table = maybe_contiguous(block_table)
 
     if fa_version == 2:
+        if scheduler_metadata is not None and q_descale is not None \
+                and k_descale is not None and v_descale is not None:
+            raise NotImplementedError(
+                "FA2 does not support scheduler_metadata, q_descale, "
+                "k_descale, v_descale"
+            )
         out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
             q, k_cache, v_cache,
             k, v, # k_new, v_new
@@ -393,9 +453,9 @@ def flash_attn_with_kvcache(
             softmax_scale,
             causal,
             window_size[0], window_size[1],
-            0, # sink_token_length
             softcap,
             rotary_interleaved, # rotary_interleaved
+            scheduler_metadata,
             num_splits, # num_splits
             None, # pack_gqa
             0, # sm_margin

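End-to-end, the intended flow (a minimal sketch modeled on the new test in tests/test_vllm_flash_attn.py; the shapes, dtypes, and fa_version=3 here are illustrative assumptions and require the FA3 extension to be built) is to compute the scheduler metadata once ahead of time and hand it to the attention call:

import torch
from vllm_flash_attn.flash_attn_interface import (
    flash_attn_with_kvcache,
    get_scheduler_metadata,
)

torch.set_default_device("cuda")

# Illustrative decode-style setup: 4 sequences, paged KV cache of 16-token blocks.
num_seqs, num_heads_q, num_heads_kv, head_size = 4, 16, 2, 128
num_blocks, block_size, max_kv_len = 128, 16, 1024

query = torch.randn(num_seqs, 1, num_heads_q, head_size, dtype=torch.bfloat16)
key_cache = torch.randn(num_blocks, block_size, num_heads_kv, head_size,
                        dtype=torch.bfloat16)
value_cache = torch.randn_like(key_cache)
cache_seqlens = torch.randint(1, max_kv_len, (num_seqs,), dtype=torch.int32)
block_tables = torch.randint(0, num_blocks,
                             (num_seqs, max_kv_len // block_size),
                             dtype=torch.int32)

# Plan the FA3 split/schedule ahead of time, then reuse it in the kernel call.
scheduler_metadata = get_scheduler_metadata(
    batch_size=num_seqs,
    max_seqlen_q=1,
    max_seqlen_k=max_kv_len,
    num_heads_q=num_heads_q,
    num_heads_kv=num_heads_kv,
    headdim=head_size,
    cache_seqlens=cache_seqlens,
    qkv_dtype=torch.bfloat16,
    causal=True,
)

output = flash_attn_with_kvcache(
    query,
    key_cache,
    value_cache,
    causal=True,
    block_table=block_tables,
    cache_seqlens=cache_seqlens,
    scheduler_metadata=scheduler_metadata,
    fa_version=3,
)  # output: (num_seqs, 1, num_heads_q, head_size)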