
Commit e30cd11

Implements the dual-chunk-flash-attn backend for dual chunk attention, with sparse attention support.
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
1 parent: fe742ae

File tree

18 files changed: +2426 -32 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -231,6 +231,7 @@ set(VLLM_EXT_SRC
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
   "csrc/attention/merge_attn_states.cu"
+  "csrc/attention/vertical_slash_index.cu"
   "csrc/pos_encoding_kernels.cu"
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
     vllm-flash-attn
     GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-    GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
+    GIT_TAG e371cc41310682b806dc132b68546235288ffb8b
     GIT_PROGRESS TRUE
     # Don't share the vllm-flash-attn build between build types
     BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
csrc/attention/vertical_slash_index.cu

Lines changed: 382 additions & 0 deletions

@@ -0,0 +1,382 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <assert.h>

#include <cuda.h>

#include <torch/all.h>

__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
                               int64_t range_end, int64_t block_size,
                               int64_t input_block_count, int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) {
    return input_block_count;
  }
  if (range_end > kv_seqlen) {
    range_end = kv_seqlen;
  }
  int64_t current_block_count = input_block_count;
  for (int idx = range_start; idx < range_end; idx += block_size) {
    block_offset[current_block_count++] = idx;
  }
  return current_block_count;
}
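For orientation (my reading, not text from the commit): save_blocks appends the start offset of every block_size-wide block in [range_start, range_end) to block_offset, clipping range_end to the KV length and skipping ranges that begin past it, then returns the updated block count. A minimal host-side C++ sketch of the same enumeration, with illustrative numbers:

#include <cstdint>
#include <cstdio>
#include <vector>

// Host-side mirror of the device helper: collect the start offset of each
// block_size-wide block in [range_start, range_end), clipped to kv_seqlen.
static void save_blocks_host(std::vector<int64_t>& offsets, int64_t range_start,
                             int64_t range_end, int64_t block_size,
                             int64_t kv_seqlen) {
  if (range_start >= kv_seqlen) return;
  if (range_end > kv_seqlen) range_end = kv_seqlen;
  for (int64_t idx = range_start; idx < range_end; idx += block_size) {
    offsets.push_back(idx);
  }
}

int main() {
  std::vector<int64_t> offsets;
  // Blocks of 64 covering [128, 384) with a KV length of 300:
  save_blocks_host(offsets, 128, 384, 64, 300);
  for (int64_t off : offsets) std::printf("%lld ", (long long)off);  // 128 192 256
  std::printf("\n");
  return 0;
}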
__global__ void convert_vertical_slash_indexes_kernel(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  if (start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if no more slash
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // if slash_finished but there are vertical left, save current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
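Reading the kernel (an interpretation, not text from the commit): each thread handles one BLOCK_SIZE_M-tall row block of queries for a single (batch, head) pair. It merges the slash (diagonal) indices into contiguous KV ranges, which save_blocks expands into BLOCK_SIZE_N-wide block offsets, and routes vertical (column) indices not already covered by those ranges into column_index; the final counts land in block_count and column_count. The flattened output addressing a consumer would use, as a small self-contained C++ sketch:

#include <cstdint>

// For row block `block_idx_m` of head `head_idx` in batch `batch_idx`,
// the kernel writes:
//   block_count [row]              : number of KV block offsets produced
//   block_offset[row * NNZ_S ...]  : those offsets, padded up to NNZ_S
//   column_count[row]              : number of single-column indices
//   column_index[row * NNZ_V ...]  : those columns, padded up to NNZ_V
// where `row` is the flattened index computed exactly as in the kernel.
inline int64_t flattened_row(int64_t batch_idx, int64_t head_idx,
                             int64_t block_idx_m, int64_t n_heads,
                             int64_t n_rows) {
  return (batch_idx * n_heads + head_idx) * n_rows + block_idx_m;
}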
void convert_vertical_slash_indexes_64x64(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
      block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
      BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
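The launcher assigns one 64-thread block per group of 64 query row blocks: gridDim.x walks the heads, gridDim.y the batches, and gridDim.z the row-block groups, so each thread ends up owning exactly one (batch, head, row block) triple. A quick arithmetic sketch with assumed sizes (not from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  // Assumed example: 128K-token context, BLOCK_SIZE_M = 64, 32 heads, batch 1.
  const int64_t context_size = 131072, block_size_M = 64;
  const int64_t n_heads = 32, batch_size = 1, n_threads = 64;
  const int64_t n_rows = (context_size + block_size_M - 1) / block_size_M;  // 2048
  const int64_t grid_z = (n_rows + n_threads - 1) / n_threads;              // 32
  // grid = (32, 1, 32), 64 threads per block -> one thread per row block.
  std::printf("n_rows=%lld grid=(%lld, %lld, %lld) block=%lld\n",
              (long long)n_rows, (long long)n_heads, (long long)batch_size,
              (long long)grid_z, (long long)n_threads);
  return 0;
}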
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,   // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,  // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,      // [BATCH, ]
    torch::Tensor kv_seqlens,     // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal) {
  cudaSetDevice(q_seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  convert_vertical_slash_indexes_64x64(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
      column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
      num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
      causal);
}
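A hedged usage sketch of the host wrapper (not part of the commit; shapes, sizes, and the direct C++ call are illustrative assumptions -- in vLLM the function is presumably exposed through the project's torch bindings rather than called like this):

#include <torch/all.h>

// Signature from this file (normally provided via a shared ops header).
void convert_vertical_slash_indexes(
    torch::Tensor& block_count, torch::Tensor& block_offset,
    torch::Tensor& column_count, torch::Tensor& column_index,
    torch::Tensor q_seqlens, torch::Tensor kv_seqlens,
    torch::Tensor vertical_indexes, torch::Tensor slash_indexes,
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal);

void example_call() {
  // Illustrative sizes only.
  const int64_t batch = 1, n_heads = 32;
  const int64_t context_size = 8192, block_size_M = 64, block_size_N = 64;
  const int64_t nnz_v = 1024, nnz_s = 64;
  const int64_t num_rows = (context_size + block_size_M - 1) / block_size_M;

  auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
  // Inputs: per-sequence lengths and the vertical/slash indices selected by
  // the sparse-pattern estimation step (zeroed here for brevity).
  auto q_seqlens = torch::full({batch}, context_size, opts);
  auto kv_seqlens = torch::full({batch}, context_size, opts);
  auto vertical_indexes = torch::zeros({batch, n_heads, nnz_v}, opts);
  auto slash_indexes = torch::zeros({batch, n_heads, nnz_s}, opts);

  // Outputs: per row-block counts plus padded offset/index buffers.
  auto block_count = torch::zeros({batch, n_heads, num_rows}, opts);
  auto block_offset = torch::zeros({batch, n_heads, num_rows, nnz_s}, opts);
  auto column_count = torch::zeros({batch, n_heads, num_rows}, opts);
  auto column_index = torch::zeros({batch, n_heads, num_rows, nnz_v}, opts);

  convert_vertical_slash_indexes(block_count, block_offset, column_count,
                                 column_index, q_seqlens, kv_seqlens,
                                 vertical_indexes, slash_indexes, context_size,
                                 block_size_M, block_size_N,
                                 /*causal=*/true);
}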
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
    int64_t NNZ_V, int64_t NNZ_S,
    bool causal  // True for intra, False for succ
) {
  const int batch_idx = blockIdx.y;
  const int head_idx = blockIdx.x;
  const int group_idx = blockIdx.z;

  int64_t q_seqlen = q_seqlens[batch_idx];
  int64_t kv_seqlen = kv_seqlens[batch_idx];
  int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
  int64_t start_m = block_idx_m * BLOCK_SIZE_M;
  if (start_m >= q_seqlen) {
    return;
  }
  int64_t end_m = start_m + BLOCK_SIZE_M;
  vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
  slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
  int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
  block_count += row_offset;
  block_offset += row_offset * NNZ_S;
  column_count += row_offset;
  column_index += row_offset * NNZ_V;

  // MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S
  // above is buffer size, use to compute offset)
  NNZ_S = per_head_slash_topkv[head_idx];
  NNZ_V = per_head_vertical_topkv[head_idx];

  bool has_slash = true;
  int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
  int64_t s = 0, v = 0;
  int64_t v_idx = vertical_indexes[v++];
  int64_t s_idx = slash_indexes[s++];
  if (causal) {
    while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
    s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
  } else {
    while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
      s_idx = slash_indexes[s++];
    }
    if (s_idx > end_m + kv_seqlen) has_slash = false;
    s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
  }

  int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
  if (!has_slash) {
    if (causal) {
      range_start = (kv_seqlen - q_seqlen) + end_m;
      range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
    } else {
      range_start = kv_seqlen;
      range_end = kv_seqlen + BLOCK_SIZE_N;
    }
  }

  bool slash_finished = false;
  while (1) {
    if (v_idx < range_end) {
      if (v_idx < range_start) {
        column_index[tmp_col_cnt++] = v_idx;
      }
      if (v < NNZ_V) {
        v_idx = vertical_indexes[v++];
      } else {
        if (causal)
          v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
        else
          v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
      }
    } else {
      if ((s < NNZ_S && causal) ||
          (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
        if (causal)
          s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
                      BLOCK_SIZE_M);
        else
          s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
      } else {
        if (v == NNZ_V || (v_idx > range_start && causal)) {
          // add the last vertical if no more slash
          if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
            column_index[tmp_col_cnt++] = v_idx;
          }
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          break;
        } else {
          if (causal) {
            range_start = (kv_seqlen - q_seqlen) + end_m;
            range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
          } else {
            // if slash_finished but there are vertical left, save current
            // blocks
            tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                      BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
            range_start = kv_seqlen;
            range_end = kv_seqlen + BLOCK_SIZE_N;
          }
          slash_finished = true;
        }
      }
      if (!slash_finished) {
        if (s_idx > range_end + BLOCK_SIZE_M) {
          tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
                                    BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
          range_start = s_idx - BLOCK_SIZE_M;
          range_end = s_idx;
        } else if (s_idx > range_end) {
          range_end += BLOCK_SIZE_M;
        }
      }
    }
  }

  block_count[0] = tmp_blk_cnt;
  column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64_mergehead(
    const int* q_seqlens,         // [BATCH, ]
    const int* kv_seqlens,        // [BATCH, ]
    const int* vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int* per_head_vertical_topkv, int* per_head_slash_topkv,
    int* block_count,   // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,  // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
    int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
  const int N_THREADS = 64;
  const dim3 dimBlock(N_THREADS);
  const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
  convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
      q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
      per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
      column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
      NNZ_V, NNZ_S, causal);
}

void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,   // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,  // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,  // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,      // [BATCH, ]
    torch::Tensor kv_seqlens,     // [BATCH, ]
    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    torch::Tensor slash_indices_count, int64_t context_size,
    int64_t block_size_M, int64_t block_size_N, bool causal) {
  cudaSetDevice(q_seqlens.get_device());

  int batch_size = slash_indexes.size(0);
  int num_heads = slash_indexes.size(1);
  int nnz_slash = slash_indexes.size(2);
  int nnz_vertical = vertical_indexes.size(2);
  int num_rows = (context_size + block_size_M - 1) / block_size_M;

  convert_vertical_slash_indexes_64x64_mergehead(
      q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
      vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
      vertical_indices_count.data_ptr<int>(),
      slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
      block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
      column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
      block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}
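Compared with the plain variant, convert_vertical_slash_indexes_mergehead only adds two [N_HEADS]-shaped count tensors, so each head can keep its own top-k number of vertical and slash indices while all heads share one padded buffer. Continuing the earlier usage sketch with illustrative values (again an assumption, not code from the commit):

// Per-head top-k counts: head h consumes only the first
// slash_indices_count[h] / vertical_indices_count[h] entries of the shared
// NNZ_S / NNZ_V buffers; the remainder is padding.
// (Reuses opts, n_heads, and the tensors from the previous sketch.)
auto vertical_indices_count = torch::full({n_heads}, 1024, opts);
auto slash_indices_count = torch::full({n_heads}, 64, opts);

convert_vertical_slash_indexes_mergehead(
    block_count, block_offset, column_count, column_index, q_seqlens,
    kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
    slash_indices_count, context_size, block_size_M, block_size_N,
    /*causal=*/true);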
