Skip to content

Commit 582eb8f

Browse files
authored
Fix params.seqlen_k reference in the splitkv kernel to binfo.actual_seqlen_k (#18)
1 parent f9d2c10 commit 582eb8f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

csrc/flash_attn/src/flash_fwd_kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
531531
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
532532
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
533533

534-
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
534+
const int n_blocks_per_split = ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
535535
const int n_block_min = !Is_local
536536
? n_split_idx * n_blocks_per_split
537537
: std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);

0 commit comments

Comments
 (0)