
Commit c9d548f

remove window_size and useless code
1 parent 68c4df8 commit c9d548f

3 files changed: +21 -44 lines changed


csrc/flash_attn/flash_api.cpp

Lines changed: 3 additions & 15 deletions
@@ -181,8 +181,6 @@ void set_params_fprop_sparse(Flash_fwd_params &params,
     void *softmax_lse_d,
     float p_dropout,
     float softmax_scale,
-    int window_size_left,
-    int window_size_right,
     const float softcap,
     bool seqlenq_ngroups_swapped=false,
     const bool unpadded_lse=false) {
@@ -200,8 +198,8 @@ void set_params_fprop_sparse(Flash_fwd_params &params,
     softmax_lse_d,
     p_dropout,
     softmax_scale,
-    window_size_left,
-    window_size_right,
+    -1, // window_size_left
+    -1, // window_size_right
     softcap,
     seqlenq_ngroups_swapped,
     unpadded_lse
@@ -353,14 +351,10 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
     const double p_dropout,
     const double softmax_scale,
     bool is_causal,
-    int64_t window_size_left,
-    int64_t window_size_right,
     const double softcap,
     const bool return_softmax,
     c10::optional<at::Generator> gen_) {

-    TORCH_CHECK(window_size_left == -1, "sliding window is not supported in sparse_attn_func.");
-    TORCH_CHECK(window_size_right == -1, "sliding window is not supported in sparse_attn_func.");
     auto dprops = at::cuda::getCurrentDeviceProperties();
     // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -398,12 +392,8 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea

     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

-    if (window_size_left >= seqlen_k) { window_size_left = -1; }
-    if (window_size_right >= seqlen_k) { window_size_right = -1; }
-
     // causal=true is the same as causal=false in this case
     if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
-    if (is_causal) { window_size_right = 0; }

     // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
     // H/t Daniel Haziza
@@ -483,8 +473,6 @@ mha_fwd_sparse(at::Tensor &q, // batch_size x seqlen_q x num_heads x hea
     softmax_lse.data_ptr(),
     p_dropout,
     softmax_scale,
-    window_size_left,
-    window_size_right,
     softcap
     );

@@ -1290,7 +1278,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
     ops.def("fwd_sparse(Tensor! q, Tensor k, Tensor v, "
             "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
             "Tensor!? out, Tensor? alibi_slopes, "
-            "float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
+            "float p_dropout, float softmax_scale, bool is_causal, "
             "float softcap, bool return_softmax, Generator? gen)"
             "-> Tensor[]");
     ops.impl("fwd_sparse", torch::kCUDA, &mha_fwd_sparse);
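
Seen from Python, the net effect of the schema change above is simply two fewer arguments on the registered fwd_sparse op. A minimal check against a built extension (a sketch, assuming the compiled vllm_flash_attn_c library has already been loaded; _schema is a PyTorch internal attribute and may differ across versions):

import torch

# Assumes the extension has been imported so the custom op is registered via TORCH_LIBRARY.
schema = str(torch.ops.vllm_flash_attn_c.fwd_sparse.default._schema)
assert "window_size_left" not in schema and "window_size_right" not in schema
print(schema)  # ... float p_dropout, float softmax_scale, bool is_causal, float softcap, ...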

csrc/flash_attn/src/flash_fwd_sparse_kernel.h

Lines changed: 16 additions & 20 deletions
@@ -458,15 +458,13 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
     if (n < num_cols_block - 1) {
         #pragma unroll
         for (int m = 0; m < size<1>(tVgVToken); ++m) {
-            if (true) { // Is_even_MN
-                tVgVToken.data() = tVgVTokenData + cols_ptr[n * kBlockN + get<0>(tKVcKV(0, m, 0))] * int64_t(params.v_row_stride);
-                #pragma unroll
-                for (int k = 0; k < size<2>(tVgVToken); ++k) {
-                    if (Is_even_K || tKVpKV(k)) {
-                        cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k));
-                    } else if (true) { // Clear_OOB_K
-                        cute::clear(tVsV(_, m, k));
-                    }
+            tVgVToken.data() = tVgVTokenData + cols_ptr[n * kBlockN + get<0>(tKVcKV(0, m, 0))] * int64_t(params.v_row_stride);
+            #pragma unroll
+            for (int k = 0; k < size<2>(tVgVToken); ++k) {
+                if (Is_even_K || tKVpKV(k)) {
+                    cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k));
+                } else { // Clear_OOB_K
+                    cute::clear(tVsV(_, m, k));
                 }
             }
         }
@@ -534,16 +532,14 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
     if (n < num_cols_block - 2) {
         #pragma unroll
         for (int m = 0; m < size<1>(tKgKToken); ++m) {
-            if (true) { // Is_even_MN
-                int token_idx = cols_ptr[(n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0))];
-                tKgKToken.data() = tKgKTokenData + token_idx * int64_t(params.k_row_stride);
-                #pragma unroll
-                for (int k = 0; k < size<2>(tKgKToken); ++k) {
-                    if (Is_even_K || tKVpKV(k)) {
-                        cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
-                    } else if (true) { // Clear_OOB_K
-                        cute::clear(tKsK(_, m, k));
-                    }
+            int token_idx = cols_ptr[(n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0))];
+            tKgKToken.data() = tKgKTokenData + token_idx * int64_t(params.k_row_stride);
+            #pragma unroll
+            for (int k = 0; k < size<2>(tKgKToken); ++k) {
+                if (Is_even_K || tKVpKV(k)) {
+                    cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
+                } else { // Clear_OOB_K
+                    cute::clear(tKsK(_, m, k));
                 }
             }
         }
@@ -560,7 +556,7 @@ inline __device__ void sparse_attn_1rowblock(const Params &params, const int bid
         for (int k = 0; k < size<2>(tKgKToken); ++k) {
             if (Is_even_K || tKVpKV(k)) {
                 cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
-            } else if (true) { // Clear_OOB_K
+            } else { // Clear_OOB_K
                 cute::clear(tKsK(_, m, k));
             }
         }
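
Stripped of the always-true Is_even_MN and Clear_OOB_K guards, the loop's job is unchanged: for each column block it gathers the K/V rows named by cols_ptr (the flattened column index for this query row block) into the shared-memory tile. A rough PyTorch analogue of that row gather, with made-up shapes purely for illustration:

import torch

# Illustrative only: the kernel stages per-thread fragments of a kBlockN-row tile
# into shared memory, but the addressing is the same gather of rows by token index.
v = torch.randn(4096, 128)                 # (seqlen_k, head_dim) for one batch/head
cols = torch.tensor([3, 17, 42, 64, 100])  # slice of cols_ptr for column block n
v_tile = v[cols]                           # rows copied from gmem (tVgVToken) to smem (tVsV)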

vllm_flash_attn/flash_attn_interface.py

Lines changed: 2 additions & 9 deletions
@@ -46,7 +46,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     return 64

 def _sparse_attn_forward(
-    q, k, v, block_count, block_offset, column_count, column_index,dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, *, out=None
+    q, k, v, block_count, block_offset, column_count, column_index,dropout_p, softmax_scale, causal, softcap, alibi_slopes, return_softmax, *, out=None
 ):
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_sparse(
@@ -62,8 +62,6 @@ def _sparse_attn_forward(
         dropout_p,
         softmax_scale,
         causal,
-        window_size[0],
-        window_size[1],
         softcap,
         return_softmax,
         None,
@@ -147,7 +145,6 @@ def sparse_attn_func(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
-    window_size=(-1, -1), # -1 means infinite context window
     softcap=0.0, # 0.0 means deactivated
     alibi_slopes=None,
     deterministic=False,
@@ -160,8 +157,7 @@ def sparse_attn_func(
     Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
     block_count and block_offset for slash sparsity patterns, and
     column_count and column_index for vertical sparsity patterns.
-    For more details please refer to MInference
-    (Paper: https://arxiv.org/abs/2407.02490, Code: https://github.com/microsoft/MInference).
+    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

     Arguments:
         q: (batch_size, seqlen, nheads, headdim)
@@ -175,8 +171,6 @@ def sparse_attn_func(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-            Sliding window is not supported for sparse_attn_func, so only (-1, -1) is valid.
         alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
             (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
             is added to the attention score of query i and key j.
@@ -204,7 +198,6 @@ def sparse_attn_func(
         dropout_p,
         softmax_scale,
         causal=causal,
-        window_size=window_size,
         softcap=softcap,
         alibi_slopes=alibi_slopes,
         return_softmax=return_attn_probs and dropout_p > 0,
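
For callers, the visible change is that sparse_attn_func (and _sparse_attn_forward) no longer accept window_size; anything other than (-1, -1) was already rejected by the TORCH_CHECKs removed from mha_fwd_sparse. A hedged usage sketch, with q/k/v and the sparsity-index tensors assumed to be prepared as the docstring describes:

# Before this commit, window_size=(-1, -1) was the only accepted value;
# after it, passing window_size raises a TypeError.
out = sparse_attn_func(
    q, k, v,
    block_count, block_offset,    # slash sparsity pattern
    column_count, column_index,   # vertical sparsity pattern
    causal=True,
    softcap=0.0,
)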
