
Commit 60f7624

Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)
1 parent f6518b2 commit 60f7624

File tree: 17 files changed, +2444 −32 lines

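For orientation before the per-file diffs: the new backend is opt-in and, following the example script this commit adds (examples/offline_inference/qwen_1m.py, shown in full below), is selected through environment variables that must be set before the engine is constructed. A minimal sketch; the reduced max_model_len and single-GPU setup are assumptions to keep the sketch small, not what the committed example uses:

# Minimal opt-in sketch, mirroring the env-var setup from the committed
# example. Assumptions (not from the commit): max_model_len=131072 and a
# single GPU with enough memory; the real example uses max_model_len=1048576
# with tensor_parallel_size=4. Both env vars must be set before the LLM is
# built.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
          max_model_len=131072,
          enforce_eager=True)
outputs = llm.generate(["Explain dual chunk attention briefly."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)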

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
    "csrc/attention/paged_attention_v1.cu"
    "csrc/attention/paged_attention_v2.cu"
    "csrc/attention/merge_attn_states.cu"
+   "csrc/attention/vertical_slash_index.cu"
    "csrc/pos_encoding_kernels.cu"
    "csrc/activation_kernels.cu"
    "csrc/layernorm_kernels.cu"

csrc/attention/vertical_slash_index.cu

Lines changed: 401 additions & 0 deletions
Large diffs are not rendered by default.

csrc/ops.h

Lines changed: 25 additions & 0 deletions
@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse);
+
+void convert_vertical_slash_indexes(
+    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
+    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
+    torch::Tensor q_seqlens,         // [BATCH, ]
+    torch::Tensor kv_seqlens,        // [BATCH, ]
+    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
+    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
+    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
+    bool causal);
+
+void convert_vertical_slash_indexes_mergehead(
+    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
+    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
+    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
+    torch::Tensor q_seqlens,         // [BATCH, ]
+    torch::Tensor kv_seqlens,        // [BATCH, ]
+    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
+    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
+    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
+    torch::Tensor slash_indices_count, int64_t context_size,
+    int64_t block_size_M, int64_t block_size_N, bool causal);
#endif

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,

csrc/torch_bindings.cpp

Lines changed: 23 additions & 0 deletions
@@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      " Tensor suffix_output,"
      " Tensor suffix_lse) -> ()");
  ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
+
+  ops.def(
+      "convert_vertical_slash_indexes("
+      "   Tensor! block_count, Tensor! block_offset, "
+      "   Tensor! column_count, Tensor! column_index, "
+      "   Tensor q_seqlens, Tensor kv_seqlens, "
+      "   Tensor vertical_indexes, Tensor slash_indexes, "
+      "   int context_size, int block_size_M, int block_size_N, "
+      "   bool causal) -> ()");
+  ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
+           &convert_vertical_slash_indexes);
+
+  ops.def(
+      "convert_vertical_slash_indexes_mergehead("
+      "   Tensor! block_count, Tensor! block_offset, "
+      "   Tensor! column_count, Tensor! column_index, "
+      "   Tensor q_seqlens, Tensor kv_seqlens, "
+      "   Tensor vertical_indexes, Tensor slash_indexes, "
+      "   Tensor vertical_indices_count, Tensor slash_indices_count, "
+      "   int context_size, int block_size_M, int block_size_N, "
+      "   bool causal) -> ()");
+  ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
+           &convert_vertical_slash_indexes_mergehead);
#endif

  // Activation ops

examples/offline_inference/qwen_1m.py

Lines changed: 66 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
import os
from urllib.request import urlopen

from vllm import LLM, SamplingParams

os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"


def load_prompt() -> str:
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt

    with urlopen(
            "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
            "/Qwen2.5-1M/test-data/600k.txt",
            timeout=5) as response:
        prompt = response.read().decode('utf-8')
    return prompt


# Processing the prompt.
def process_requests(llm: LLM, prompts: list[str]) -> None:
    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        repetition_penalty=1.05,
        detokenize=True,
        max_tokens=256,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt_token_ids = output.prompt_token_ids
        generated_text = output.outputs[0].text
        print(f"Prompt length: {len(prompt_token_ids)}, "
              f"Generated text: {generated_text!r}")


# Create an LLM.
def initialize_engine() -> LLM:
    llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
              max_model_len=1048576,
              tensor_parallel_size=4,
              enforce_eager=True,
              enable_chunked_prefill=True,
              max_num_batched_tokens=131072)
    return llm


def main():
    llm = initialize_engine()
    prompt = load_prompt()
    process_requests(llm, [prompt])


if __name__ == '__main__':
    main()
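The committed script downloads the 600k-token test file and assumes four GPUs (tensor_parallel_size=4). For a quicker smoke test, a hypothetical variant of load_prompt() can fetch the smaller 64k test file that the script's own comments already list; this helper is not part of the committed example:

# Hypothetical smoke-test variant of load_prompt(): fetches the smaller 64k
# test file listed in the comments of the committed script instead of the
# 600k one, so the flow can be exercised more quickly.
from urllib.request import urlopen


def load_prompt_64k() -> str:
    with urlopen(
            "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
            "/Qwen2.5-1M/test-data/64k.txt",
            timeout=5) as response:
        return response.read().decode("utf-8")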

vllm/_custom_ops.py

Lines changed: 95 additions & 0 deletions
@@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
                      prefix_lse, suffix_output, suffix_lse)


+def convert_vertical_slash_indexes(
+        q_seqlens: torch.Tensor,  # [BATCH, ]
+        kv_seqlens: torch.Tensor,  # [BATCH, ]
+        vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
+        slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
+        context_size: int,
+        block_size_M: int,
+        block_size_N: int,
+        causal: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch_size = slash_indexes.size(0)
+    num_heads = slash_indexes.size(1)
+    nnz_slash = slash_indexes.size(2)
+    nnz_vertical = vertical_indexes.size(2)
+    num_rows = (context_size + block_size_M - 1) // block_size_M
+
+    block_count = torch.zeros(batch_size,
+                              num_heads,
+                              num_rows,
+                              dtype=q_seqlens.dtype,
+                              device=q_seqlens.device)
+    block_offset = torch.zeros(batch_size,
+                               num_heads,
+                               num_rows,
+                               nnz_slash,
+                               dtype=q_seqlens.dtype,
+                               device=q_seqlens.device)
+    column_count = torch.zeros(batch_size,
+                               num_heads,
+                               num_rows,
+                               dtype=q_seqlens.dtype,
+                               device=q_seqlens.device)
+    column_index = torch.zeros(batch_size,
+                               num_heads,
+                               num_rows,
+                               nnz_vertical,
+                               dtype=q_seqlens.dtype,
+                               device=q_seqlens.device)
+
+    torch.ops._C.convert_vertical_slash_indexes(
+        block_count, block_offset, column_count, column_index, q_seqlens,
+        kv_seqlens, vertical_indexes, slash_indexes, context_size,
+        block_size_M, block_size_N, causal)
+    return block_count, block_offset, column_count, column_index
+
+
+def convert_vertical_slash_indexes_mergehead(
+        q_seqlens: torch.Tensor,  # [BATCH, ]
+        kv_seqlens: torch.Tensor,  # [BATCH, ]
+        vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
+        slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
+        # [N_HEADS]: different heads may use different numbers of indices
+        vertical_indices_count: torch.Tensor,
+        slash_indices_count: torch.Tensor,
+        context_size: int,
+        block_size_M: int,
+        block_size_N: int,
+        causal: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch_size = slash_indexes.size(0)
+    num_heads = slash_indexes.size(1)
+    nnz_slash = slash_indexes.size(2)
+    nnz_vertical = vertical_indexes.size(2)
+    num_rows = (context_size + block_size_M - 1) // block_size_M
+
+    block_count = torch.empty(batch_size,
+                              num_heads,
+                              num_rows,
+                              dtype=q_seqlens.dtype,
+                              device=q_seqlens.device)
+    block_offset = torch.empty(batch_size,
+                               num_heads,
+                               num_rows,
+                               nnz_slash,
+                               dtype=q_seqlens.dtype,
+                               device=q_seqlens.device)
+    column_count = torch.empty(batch_size,
+                               num_heads,
+                               num_rows,
+                               dtype=q_seqlens.dtype,
+                               device=q_seqlens.device)
+    column_index = torch.empty(batch_size,
+                               num_heads,
+                               num_rows,
+                               nnz_vertical,
+                               dtype=q_seqlens.dtype,
+                               device=q_seqlens.device)
+
+    torch.ops._C.convert_vertical_slash_indexes_mergehead(
+        block_count, block_offset, column_count, column_index, q_seqlens,
+        kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
+        slash_indices_count, context_size, block_size_M, block_size_N, causal)
+    return block_count, block_offset, column_count, column_index
+
+
# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
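A hedged usage sketch for the new wrapper, assuming a CUDA build of vLLM with the updated _C extension. The int32 dtype and the placeholder index values are assumptions chosen only to illustrate argument and result shapes, not a meaningful sparsity pattern; the shape comments follow the wrapper's own allocation logic above.

# Hedged usage sketch (not from the commit): call the new wrapper with tiny
# placeholder tensors just to show argument and result shapes.
import torch

from vllm import _custom_ops as ops

batch, heads, nnz_v, nnz_s = 1, 2, 4, 4
context_size, block_m, block_n = 256, 64, 64

q_seqlens = torch.tensor([context_size], dtype=torch.int32, device="cuda")
kv_seqlens = torch.tensor([context_size], dtype=torch.int32, device="cuda")
# Placeholder indexes; a real caller would pass the selected vertical columns
# and slash diagonals per head.
vertical = torch.zeros(batch, heads, nnz_v, dtype=torch.int32, device="cuda")
slash = torch.zeros(batch, heads, nnz_s, dtype=torch.int32, device="cuda")

block_count, block_offset, column_count, column_index = (
    ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, vertical, slash,
                                       context_size, block_m, block_n))

# NUM_ROWS = ceil(context_size / block_size_M) = 4 row blocks here.
print(block_count.shape)   # torch.Size([1, 2, 4])
print(block_offset.shape)  # torch.Size([1, 2, 4, 4])
print(column_count.shape)  # torch.Size([1, 2, 4])
print(column_index.shape)  # torch.Size([1, 2, 4, 4])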
