From 44bf2abb2940241cb97917aedc053b949cac3002 Mon Sep 17 00:00:00 2001 From: Tao He Date: Fri, 24 Jan 2025 02:10:31 +0800 Subject: [PATCH] Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support. Signed-off-by: Tao He --- CMakeLists.txt | 1 + csrc/attention/vertical_slash_index.cu | 401 +++++ csrc/ops.h | 25 + csrc/torch_bindings.cpp | 23 + examples/offline_inference/qwen_1m.py | 66 + vllm/_custom_ops.py | 95 ++ .../backends/dual_chunk_flash_attn.py | 1494 +++++++++++++++++ vllm/config.py | 19 + vllm/engine/arg_utils.py | 15 +- .../model_executor/layers/rotary_embedding.py | 204 ++- .../model_loader/weight_utils.py | 33 + vllm/model_executor/models/qwen2.py | 56 +- vllm/model_executor/models/qwen2_moe.py | 26 +- vllm/platforms/cuda.py | 4 + vllm/platforms/interface.py | 1 + vllm/utils.py | 1 + vllm/worker/model_runner.py | 12 + 17 files changed, 2444 insertions(+), 32 deletions(-) create mode 100644 csrc/attention/vertical_slash_index.cu create mode 100644 examples/offline_inference/qwen_1m.py create mode 100644 vllm/attention/backends/dual_chunk_flash_attn.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 270c480003e..fed6e11e5ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -230,6 +230,7 @@ set(VLLM_EXT_SRC "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu new file mode 100644 index 00000000000..c1b45b143f4 --- /dev/null +++ b/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
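+//
+// NOTE (descriptive comment added for clarity; not part of the original
+// MInference sources): this file provides the CUDA kernels that convert
+// per-head "vertical" and "slash" sparse-attention indices into the
+// block-sparse metadata (block_count, block_offset, column_count,
+// column_index) consumed by the DUAL_CHUNK_FLASH_ATTN backend added in
+// this patch.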
+ +#include + +#include + +#include + +__device__ int64_t save_blocks(int* block_offset, int64_t range_start, + int64_t range_end, int64_t block_size, + int64_t input_block_count, int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return input_block_count; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + int64_t current_block_count = input_block_count; + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[current_block_count++] = idx; + } + return current_block_count; +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && 
causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count, + block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, + BLOCK_SIZE_N, NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * This function builds the index of each row of blocks from vertical indices + * and slash indices. The vertical indices are treated as points, while the + * slash indices are converted as ranges. The output consists of the merged + * ranges and separate column indices, where the ranges are represented by + * block indices. + * + * The implementation is referenced from the original MInference repo: + * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu. 
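+ *
+ * Illustrative note (added for clarity, not taken from the referenced repo):
+ * block_offset holds the starting key offsets of the BLOCK_SIZE_N-sized key
+ * blocks produced from the slash ranges, column_index holds individual key
+ * positions coming from vertical indices that are not already covered by one
+ * of those blocks, and block_count / column_count record how many entries of
+ * the per-row buffers are valid.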
+ */ +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + block_count.data_ptr(), block_offset.data_ptr(), + column_count.data_ptr(), column_index.data_ptr(), batch_size, + num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash, + causal); +} + +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. 
(NNZ_V,NNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool 
causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, + per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * Like the above convert_vertical_slash_indexes, but with + * pre-computed vertical and slash counts. + */ +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, // [N_HEADS, ] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + vertical_indices_count.data_ptr(), + slash_indices_count.data_ptr(), block_count.data_ptr(), + block_offset.data_ptr(), column_count.data_ptr(), + column_index.data_ptr(), batch_size, num_heads, num_rows, + block_size_M, block_size_N, nnz_vertical, nnz_slash, causal); +} diff --git a/csrc/ops.h b/csrc/ops.h index 21c5a9e2974..7044b4588b8 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); + +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal); + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, int64_t context_size, + int64_t block_size_M, 
int64_t block_size_N, bool causal); #endif void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 1dbd11f5f2a..e9601f864b3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_output," " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); + + ops.def( + "convert_vertical_slash_indexes(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes", torch::kCUDA, + &convert_vertical_slash_indexes); + + ops.def( + "convert_vertical_slash_indexes_mergehead(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " Tensor vertical_indices_count, Tensor slash_indices_count, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, + &convert_vertical_slash_indexes_mergehead); #endif // Activation ops diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py new file mode 100644 index 00000000000..64a1f4c54b6 --- /dev/null +++ b/examples/offline_inference/qwen_1m.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from urllib.request import urlopen + +from vllm import LLM, SamplingParams + +os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" +os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + + +def load_prompt() -> str: + # Test cases with various lengths can be found at: + # + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt + + with urlopen( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com" + "/Qwen2.5-1M/test-data/600k.txt", + timeout=5) as response: + prompt = response.read().decode('utf-8') + return prompt + + +# Processing the prompt. +def process_requests(llm: LLM, prompts: list[str]) -> None: + # Create a sampling params object. + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.8, + top_k=20, + repetition_penalty=1.05, + detokenize=True, + max_tokens=256, + ) + # Generate texts from the prompts. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt_token_ids = output.prompt_token_ids + generated_text = output.outputs[0].text + print(f"Prompt length: {len(prompt_token_ids)}, " + f"Generated text: {generated_text!r}") + + +# Create an LLM. 
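+# Note (added comment): the settings below follow this example's Qwen2.5-1M
+# setup; tensor_parallel_size=4 and max_num_batched_tokens=131072 assume a
+# multi-GPU machine and may need to be adjusted to the hardware actually
+# available.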
+def initialize_engine() -> LLM: + llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M", + max_model_len=1048576, + tensor_parallel_size=4, + enforce_eager=True, + enable_chunked_prefill=True, + max_num_batched_tokens=131072) + return llm + + +def main(): + llm = initialize_engine() + prompt = load_prompt() + process_requests(llm, [prompt]) + + +if __name__ == '__main__': + main() diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 80f54974521..556d357b95a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor, prefix_lse, suffix_output, suffix_lse) +def convert_vertical_slash_indexes( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.zeros(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.zeros(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, context_size, + block_size_M, block_size_N, causal) + return block_count, block_offset, column_count, column_index + + +def convert_vertical_slash_indexes_mergehead( + q_seqlens: torch.Tensor, # [BATCH, ] + kv_seqlens: torch.Tensor, # [BATCH, ] + vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + # [N_HEADS] : different head use different number of indices + vertical_indices_count: torch.Tensor, + slash_indices_count: torch.Tensor, + context_size: int, + block_size_M: int, + block_size_N: int, + causal: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size = slash_indexes.size(0) + num_heads = slash_indexes.size(1) + nnz_slash = slash_indexes.size(2) + nnz_vertical = vertical_indexes.size(2) + num_rows = (context_size + block_size_M - 1) // block_size_M + + block_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + block_offset = torch.empty(batch_size, + num_heads, + num_rows, + nnz_slash, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_count = torch.empty(batch_size, + num_heads, + num_rows, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + column_index = torch.empty(batch_size, + num_heads, + num_rows, + nnz_vertical, + dtype=q_seqlens.dtype, + device=q_seqlens.device) + + torch.ops._C.convert_vertical_slash_indexes_mergehead( + block_count, block_offset, column_count, column_index, q_seqlens, + kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, causal) + return 
block_count, block_offset, column_count, column_index + + # pos encoding ops def rotary_embedding( positions: torch.Tensor, diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py new file mode 100644 index 00000000000..eceab1f1ac9 --- /dev/null +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -0,0 +1,1494 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Attention layer with Dual chunk flash attention and sparse attention. +""" +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionLayer, AttentionType +from vllm.attention.backends.flash_attn import (FlashAttentionBackend, + FlashAttentionImpl, + FlashAttentionMetadata, + FlashAttentionMetadataBuilder) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.utils import async_tensor_h2d +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache, sparse_attn_func) + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + +logger = init_logger(__name__) + + +class DualChunkFlashAttentionBackend(FlashAttentionBackend): + + accept_output_buffer: bool = False + + @staticmethod + def get_name() -> str: + return "DUAL_CHUNK_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: + return DualChunkFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: + return DualChunkFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: + return DualChunkFlashAttentionMetadataBuilder + + +@dataclass +class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): + # Block size of the paged kv cache. + block_size: int = 16 + + # Original max position embeddings. + original_max_position_embeddings: int = 0 + + # Chunk size + chunk_size: int = 8192 + + # Local size + local_size: int = 1024 + + # (batch_size,). The orig sequence length per sequence. + orig_seq_lens: Optional[List[int]] = None + + # orig_seq_lens stored as a tensor. + orig_seq_lens_tensor: Optional[torch.Tensor] = None + + # Length scaling factor + scaling_factor: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for intra attention. + seq_lens_intra: Optional[torch.Tensor] = None + + # Max sequence length for intra attention. + max_seq_len_intra: Optional[int] = None + + # (batch_size, num_blocks). Block table for intra attention. + block_tables_intra: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for succ attention. + seq_lens_succ: Optional[torch.Tensor] = None + + # Max sequence length for succ attention. + max_seq_len_succ: Optional[int] = None + + # (batch_size, num_blocks). Block table for succ attention. + block_tables_succ: Optional[torch.Tensor] = None + + # (batch_size,). Sequence lengths for inter attention. + seq_lens_inter: Optional[torch.Tensor] = None + + # Max sequence length for inter attention. 
+ max_seq_len_inter: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "DualChunkFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + prefill_metadata = super().prefill_metadata + if prefill_metadata is None: + return None + + prefill_metadata = DualChunkFlashAttentionMetadata( + **prefill_metadata.asdict_zerocopy()) + + prefill_metadata.orig_seq_lens = ( + None if self.orig_seq_lens is None else + self.orig_seq_lens[:self.num_prefills]) + prefill_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[:self.num_prefills]) + + if self.original_max_position_embeddings > 0: + assert prefill_metadata.orig_seq_lens_tensor is not None + prefill_metadata.scaling_factor = ( + 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / + self.original_max_position_embeddings) + + 1.0).clip(min=1) + + self._cached_prefill_metadata = prefill_metadata + return prefill_metadata + + @property + def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + + decode_metadata = super().decode_metadata + if decode_metadata is None: + return None + + decode_metadata = DualChunkFlashAttentionMetadata( + **decode_metadata.asdict_zerocopy()) + + decode_metadata.orig_seq_lens_tensor = ( + None if self.orig_seq_lens_tensor is None else + self.orig_seq_lens_tensor[self.num_prefills:]) + + assert decode_metadata.orig_seq_lens_tensor is not None + assert decode_metadata.block_tables is not None + + cache_seq_lens = decode_metadata.orig_seq_lens_tensor + chunk_len = self.chunk_size - self.local_size + chunk_num_curr = (cache_seq_lens - 1) // chunk_len + batch_size = decode_metadata.num_decode_tokens + + if self.original_max_position_embeddings > 0: + decode_metadata.scaling_factor = (0.1 * torch.log( + cache_seq_lens / self.original_max_position_embeddings) + + 1.0).clip(min=1) + + seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len + max_seq_len_intra = seq_lens_intra.max().item() + decode_metadata.seq_lens_intra = seq_lens_intra + decode_metadata.max_seq_len_intra = max_seq_len_intra + + block_tables_intra = torch.zeros( + batch_size, + (max_seq_len_intra - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + st = chunk_num_curr[i] * chunk_len // self.block_size + ed = min( + st + (max_seq_len_intra - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_intra[i, :ed - + st] = decode_metadata.block_tables[i, st:ed] + decode_metadata.block_tables_intra = block_tables_intra + + seq_lens_succ = (chunk_num_curr - + (chunk_num_curr - 1).clip(min=0)) * chunk_len + max_seq_len_succ = seq_lens_succ.max().item() + decode_metadata.seq_lens_succ = seq_lens_succ + decode_metadata.max_seq_len_succ = max_seq_len_succ + if max_seq_len_succ: + block_tables_succ = torch.zeros( + batch_size, + (max_seq_len_succ - 1) // self.block_size + 1, + dtype=decode_metadata.block_tables.dtype, + device=decode_metadata.block_tables.device, + ) + for i in range(batch_size): + start = ((chunk_num_curr[i] - 
1).clip(min=0) * chunk_len // + self.block_size) + end = min( + start + (max_seq_len_succ - 1) // self.block_size + 1, + (cache_seq_lens[i] - 1) // self.block_size + 1, + ) + block_tables_succ[ + i, :end - start] = decode_metadata.block_tables[i, + start:end] + decode_metadata.block_tables_succ = block_tables_succ + + seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len + max_seq_len_inter = seq_lens_inter.max().item() + decode_metadata.seq_lens_inter = seq_lens_inter + decode_metadata.max_seq_len_inter = max_seq_len_inter + + self._cached_decode_metadata = decode_metadata + return decode_metadata + + +class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def prepare(self): + super().prepare() + self.orig_seq_lens: List[int] = [] + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + super()._add_seq_group(inter_data, chunked_prefill_enabled, + prefix_cache_hit) + for prompt_len, seq_len in zip(inter_data.prompt_lens, + inter_data.seq_lens): + self.orig_seq_lens.append(max(prompt_len, seq_len)) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + attn_metadata = super().build(seq_lens, query_lens, + cuda_graph_pad_size, batch_size) + attn_metadata = DualChunkFlashAttentionMetadata( + **attn_metadata.asdict_zerocopy()) + + device = self.runner.device + attn_metadata.orig_seq_lens = self.orig_seq_lens + attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( + self.orig_seq_lens, torch.int, device, self.runner.pin_memory) + + attn_metadata.block_size = self.runner.block_size + dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, + "dual_chunk_attention_config", {}) + attn_metadata.original_max_position_embeddings = \ + dual_chunk_attn_config.get("original_max_position_embeddings", 0) + attn_metadata.chunk_size = dual_chunk_attn_config.get( + "chunk_size", 8192) + attn_metadata.local_size = dual_chunk_attn_config.get( + "local_size", 1024) + + return attn_metadata + + +class DualChunkFlashAttentionImpl(FlashAttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + The prompts might have different lengths, while the generation tokens + always have length 1. + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + layer_idx: int = -1, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + DualChunkFlashAttentionBackend.get_supported_head_sizes()) + + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + assert dual_chunk_attention_config is not None + self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) + self.local_size = dual_chunk_attention_config.get("local_size", 1024) + self.original_max_position_embeddings = dual_chunk_attention_config.get( + "original_max_position_embeddings", 0) + self.sparse_attention_config = dual_chunk_attention_config.get( + "sparse_attention_config", None) + if not self.sparse_attention_config: + logger.warning_once("Sparse attention will not be enabled as " + "sparse attention config is not provided.") + self.sparse_attention_enabled = dual_chunk_attention_config.get( + "sparse_attention_enabled", self.sparse_attention_config + is not None) + self.sparse_attention_threshold = dual_chunk_attention_config.get( + "sparse_attention_threshold", 32768) + self.sparse_attention_last_q = dual_chunk_attention_config.get( + "sparse_attention_last_q", 64) + self.layer_idx = layer_idx + self.dual_chunk_attention_config = dual_chunk_attention_config + + if self.sparse_attention_config: + self.sparse_attention_config = { + int(i): j + for i, j in self.sparse_attention_config[ + self.layer_idx].items() + } + start_head = self.num_heads * get_tensor_model_parallel_rank() + end_head = start_head + self.num_heads + self.sparse_attention_config = [ + self.sparse_attention_config[i] + for i in range(start_head, end_head) + ] + + if self.sparse_attention_enabled: + self.arange = torch.arange(self.sparse_attention_last_q, + device="cuda") + self.last_q_mask = (self.arange[None, None, :, None] + >= self.arange[None, None, None, :]) + + def forward( # type: ignore + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DualChunkFlashAttentionMetadata, + ) -> torch.Tensor: + """Forward pass with DualChunkFlashAttention. 
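+        Note: the incoming ``query`` packs five projections along its last
+        dimension (query, query_succ, query_inter and their two "critical"
+        variants); they are split apart at the start of this method.
+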
+ Args: + query: shape = [num_tokens, num_heads * head_size] + query_succ: shape = [num_tokens, num_heads * head_size] + query_inter: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + ( + query, + query_succ, + query_inter, + query_succ_critical, + query_inter_critical, + ) = torch.split(query, query.shape[-1] // 5, dim=-1) + + assert ( + query_succ is not None and query_inter is not None + ), "query_succ and query_inter are required in Dual Chunk Attention." + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + query_succ = query_succ.view(-1, self.num_heads, self.head_size) + query_inter = query_inter.view(-1, self.num_heads, self.head_size) + query_succ_critical = query_succ_critical.view(-1, self.num_heads, + self.head_size) + query_inter_critical = query_inter_critical.view( + -1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.original_max_position_embeddings > 0: + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.scaling_factor is not None + assert prefill_meta.query_start_loc is not None + assert prefill_meta.orig_seq_lens is not None + current_start = 0 + query_start_loc_cpu = prefill_meta.query_start_loc.cpu() + for i in range(len(prefill_meta.orig_seq_lens)): + current_end = (current_start + + (query_start_loc_cpu[i + 1] - + query_start_loc_cpu[i]).item()) + key[current_start:current_end].mul_( + prefill_meta.scaling_factor[i]) + current_start = current_end + assert current_end <= attn_metadata.num_prefill_tokens + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + key[attn_metadata.num_prefill_tokens:].mul_( + scaling_factor.unsqueeze(-1).unsqueeze(-1)) + + if kv_cache is not None and kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + decode_query_succ = query_succ[num_prefill_tokens:] + decode_query_inter = query_inter[num_prefill_tokens:] + + # QKV for prefill. 
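+        # (The first num_prefill_tokens rows of each tensor belong to the
+        # prefill requests; the decode rows were sliced off just above.)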
+ query = query[:num_prefill_tokens] + query_succ = query_succ[:num_prefill_tokens] + query_inter = query_inter[:num_prefill_tokens] + query_succ_critical = query_succ_critical[:num_prefill_tokens] + query_inter_critical = query_inter_critical[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache is None or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention, called during the profiling run. + out = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + assert prefill_meta.orig_seq_lens is not None + output[:num_prefill_tokens] = ( + self._dual_chunk_flash_attn_prefill( + q=query, + q_succ=query_succ, + q_inter=query_inter, + q_succ_critical=query_succ_critical, + q_inter_critical=query_inter_critical, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + orig_seq_lens=prefill_meta.orig_seq_lens, + scaling_factor=prefill_meta.scaling_factor, + softmax_scale=self.scale, + causal=True, + window_size=(-1, -1), + alibi_slopes=self.alibi_slopes, + block_table=prefill_meta.block_tables, + chunk_size=self.chunk_size, + local_size=self.local_size, + )) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output[num_prefill_tokens:] = ( + self._dual_chunk_flash_attn_decoding( + decode_query.unsqueeze(1), + decode_query_succ.unsqueeze(1), + decode_query_inter.unsqueeze(1), + key_cache, + value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + chunk_size=self.chunk_size, + local_size=self.local_size, + original_max_position_embeddings=self. + original_max_position_embeddings, + decode_meta=decode_meta, + ).squeeze(1)) + # Reshape the output tensor. 
+ return output.view(num_tokens, hidden_size) + + def _dual_chunk_flash_attn_prefill( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + orig_seq_lens: List[int], + scaling_factor: torch.Tensor, + softmax_scale: float, + causal: Optional[bool] = True, + window_size: Tuple[int, int] = (-1, -1), + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + chunk_size: int = 8192, + local_size: int = 1024, + ): + if alibi_slopes is not None: + raise ValueError( + "Dual Chunk Attention does not support alibi_slopes") + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + if window_size != (-1, -1): + raise ValueError( + "Dual Chunk Attention does not support window_size") + + cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() + cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() + all_outputs = [] + + for i in range(0, len(cu_seqlens_q_cpu) - 1): + qs = cu_seqlens_q_cpu[i] + qe = cu_seqlens_q_cpu[i:i + 2][-1] + ks = cu_seqlens_k_cpu[i] + ke = cu_seqlens_k_cpu[i:i + 2][-1] + + current_q = q[qs:qe] + current_q_succ = q_succ[qs:qe] + current_q_inter = q_inter[qs:qe] + current_q_succ_critical = q_succ_critical[qs:qe] + current_q_inter_critical = q_inter_critical[qs:qe] + + if block_table is None: + current_k = k[ks:ke] + current_v = v[ks:ke] + current_block_table = None + current_orig_seq_len = orig_seq_lens[i] + else: + current_block_table = block_table[i] + current_orig_seq_len = orig_seq_lens[i] + current_k = k + current_v = v + sparse_attn_enabled = (self.sparse_attention_enabled + and current_orig_seq_len + > self.sparse_attention_threshold) + + if current_q.shape[0] == 0: + continue + + if current_k.shape[0] == 0: + all_outputs.append( + torch.zeros( + (current_q.shape[0], current_q.shape[1], v.shape[2]), + device=q.device, + dtype=q.dtype, + )) + continue + + current_output = torch.empty_like(current_q) + group_size = int(current_q.size(-2) / current_k.size(-2)) + + if sparse_attn_enabled: + num_device_q_heads = current_q.size(-2) + heads_vertical_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + heads_slash_size = torch.empty(size=(num_device_q_heads, ), + dtype=torch.int32) + for head_id in range(current_q.size(-2)): + ( + ty, + vertical_size, + slash_size, + _, + ) = self.sparse_attention_config[head_id] + assert ty == "vertical_and_slash", "only support slash mode" + + if vertical_size == 30: + vertical_size += 100 + heads_vertical_size[head_id] = vertical_size + heads_slash_size[head_id] = slash_size + + current_output = self._dual_chunk_flash_attn_prefill_func( + current_q, # allheads + current_q_succ, + current_q_inter, + current_q_succ_critical, + current_q_inter_critical, + current_k, + current_v, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + heads_vertical_size=heads_vertical_size, + heads_slash_size=heads_slash_size, + group_size=group_size) + else: + for head_id in range(current_q.size(-2)): + # (seq_len, num_heads, head_size) + current_q_head = current_q[:, head_id, :].unsqueeze(1) + current_q_succ_head = \ + current_q_succ[:, head_id, :].unsqueeze(1) + current_q_inter_head = \ + current_q_inter[:, head_id, :].unsqueeze(1) + current_q_succ_head_critical = \ + current_q_succ_critical[:, head_id, :].unsqueeze(1) + current_q_inter_head_critical = \ + current_q_inter_critical[:, head_id, :].unsqueeze(1) + if block_table is not 
None: + current_k_head = current_k[..., head_id // + group_size, :].unsqueeze(2) + current_v_head = current_v[..., head_id // + group_size, :].unsqueeze(2) + + else: + current_k_head = current_k[:, head_id, :].unsqueeze(1) + current_v_head = current_v[:, head_id, :].unsqueeze(1) + + current_out = self._dual_chunk_flash_attn_prefill_func( + current_q_head, + current_q_succ_head, + current_q_inter_head, + current_q_succ_head_critical, + current_q_inter_head_critical, + current_k_head, + current_v_head, + current_block_table, + softmax_scale, + chunk_size, + local_size, + scaling_factor[i].item(), + ke - ks, + sparse_attn_enabled=sparse_attn_enabled, + ) + current_output[:, head_id:head_id + 1, :] = current_out + all_outputs.append(current_output) + return torch.cat(all_outputs, dim=0) + + def _dual_chunk_flash_attn_prefill_func( + self, + q, + q_succ, + q_inter, + q_succ_critical, + q_inter_critical, + k, + v, + block_table, + softmax_scale: float, + chunk_size: int, + local_size: int, + scaling_factor: float, + k_length: int, + sparse_attn_enabled: Optional[bool] = True, + heads_vertical_size=None, + heads_slash_size=None, + group_size=None, + ): + flash_results = [] + chunk_len = chunk_size - local_size + + if block_table is not None: + block_size = v.shape[1] + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + else: + block_size = 1 + + if self.original_max_position_embeddings > 0: + softmax_scale = softmax_scale * scaling_factor + + begin = k_length - q.shape[0] + while begin < k_length: + flash_per_chunk = [] + + prev_chunk_end_pos = (begin // chunk_len) * chunk_len + next_chunk_end_pos = prev_chunk_end_pos + chunk_len + end = min(next_chunk_end_pos, k_length) + qbegin = begin - (k_length - q.shape[0]) + qend = end - (k_length - q.shape[0]) + + qk_chunks = [] + q_states_intra = q[qbegin:qend] + # choose critical token + if block_table is not None: + block_tables_intra = _get_block(block_table, block_size, + prev_chunk_end_pos, end) + k_states_intra = k[block_tables_intra].view( + -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] + v_states_intra = v[block_tables_intra].view( + -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] + else: + block_tables_intra = None + k_states_intra = k[prev_chunk_end_pos:end] + v_states_intra = v[prev_chunk_end_pos:end] + + if sparse_attn_enabled: + last_q_size = min(qend - qbegin, self.sparse_attention_last_q) + _, num_device_k_heads, head_dim = k_states_intra.shape + k_states_intra = (k_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + v_states_intra = (v_states_intra.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, head_dim)) + qk_chunks.append( + (q_states_intra.transpose(0, 1)[:, -last_q_size:] * + softmax_scale) @ k_states_intra.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len >= 0: + q_states_succ = q_succ[qbegin:qend] + q_states_succ_critical = q_succ_critical[qbegin:qend] + if block_table is not None: + block_tables_succ = _get_block( + block_table, block_size, + prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) + k_states_succ = k[block_tables_succ].view( + -1, *k.shape[-2:])[:chunk_len] + v_states_succ = v[block_tables_succ].view( + -1, *v.shape[-2:])[:chunk_len] + else: + k_states_succ = k[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + v_states_succ = v[prev_chunk_end_pos - + chunk_len:prev_chunk_end_pos] + + if sparse_attn_enabled: + k_states_succ = 
(k_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_succ = (v_states_succ.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_succ_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_succ.permute(1, 2, 0)) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + q_states_inter = q_inter[qbegin:qend] + q_states_inter_critical = q_inter_critical[qbegin:qend] + if block_table is not None: + block_tables_inter = _get_block( + block_table, block_size, 0, + prev_chunk_end_pos - chunk_len) + k_states_inter = k[block_tables_inter].view( + -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + v_states_inter = v[block_tables_inter].view( + -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] + else: + k_states_inter = k[:prev_chunk_end_pos - chunk_len] + v_states_inter = v[:prev_chunk_end_pos - chunk_len] + + if sparse_attn_enabled: + k_states_inter = (k_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + v_states_inter = (v_states_inter.unsqueeze(2).repeat( + 1, 1, group_size, + 1).reshape(-1, num_device_k_heads * group_size, + head_dim)) + qk_chunks.append((q_states_inter_critical.transpose( + 0, 1)[:, -last_q_size:] * softmax_scale) + @ k_states_inter.permute(1, 2, 0)) + + if sparse_attn_enabled: + reversed_qk = qk_chunks[::-1] + qk = torch.cat(reversed_qk, dim=-1) + + qk[:, :, -last_q_size:] = torch.where( + self.last_q_mask[..., -last_q_size:, + -last_q_size:].to(qk.device), + qk[:, :, -last_q_size:], -torch.inf) + qk = F.softmax(qk, dim=-1, dtype=torch.float32) + + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + + # Avoid sorting by using the min/max ints to fill the indexer + # buffers. 
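+                # (The number of valid entries per head is tracked separately
+                # in the *_size / *_sizes buffers allocated below, so the
+                # padded tail entries are not treated as real indices
+                # downstream.)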
+ int32_max = torch.iinfo(torch.int32).max + int32_min = torch.iinfo(torch.int32).min + n_heads = qk.size()[0] + max_slash_topk = torch.max(heads_slash_size).item() + max_vertical_topk = torch.max(heads_vertical_size).item() + # store each head's slash topk, vertical topk + vertical = vertical.reshape((n_heads, -1)) + # prevent out of range when prompt size < max_vertical_topk + max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) + vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, + -1).indices + slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), + dtype=torch.int64, + device=qk.device) + for head_i in range(n_heads): + # (nqheads=1, lastq, k_len) + head_score = qk[head_i:head_i + 1, :, :] + slash_scores = _sum_all_diagonal_matrix(head_score) + if head_score.size(1) != 1: + # drop right up corner + slash_scores = slash_scores[..., :-last_q_size + 1] + slash_scores[..., -100:] = torch.inf + + head_slash_size = heads_slash_size[head_i] + head_slash_size = min(head_slash_size, vertical.size(-1)) + slash_topk = torch.topk(slash_scores, head_slash_size, + -1).indices + #(nheads, max_topk) + slash_topk_buffer[head_i, :head_slash_size] = slash_topk + + # reset heads topk + heads_slash_size[head_i] = head_slash_size + heads_vertical_size[head_i] = min( + heads_vertical_size[head_i], max_vertical_topk) + + # store + vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + succ_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + inter_vertical_buffer = torch.full( + (n_heads, max_vertical_topk), + int32_max, + dtype=torch.int64, + device=q.device) + inter_slash_buffer = torch.full((n_heads, max_slash_topk), + int32_min, + dtype=torch.int64, + device=q.device) + + vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_vertical_size_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), + dtype=torch.int32, + device=q.device) + + for head_i in range(n_heads): + vertical_topk = vertical_topk_buffer[ + head_i, :heads_vertical_size[head_i]] + # intra + intra_vertical_indices = vertical_topk[ + vertical_topk >= + prev_chunk_end_pos] - prev_chunk_end_pos + if intra_vertical_indices.nelement() == 0: + intra_vertical_indices = torch.cat([ + intra_vertical_indices, + torch.arange(0, + k_states_intra.size(0), + max(1, + k_states_intra.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + slash_topk = slash_topk_buffer[ + head_i, :heads_slash_size[head_i]] + intra_slash_indices = ( + (qk.size(-1) - 1) - + slash_topk[slash_topk >= prev_chunk_end_pos]) + # fill buffer + v_count = intra_vertical_indices.nelement() + s_count = intra_slash_indices.nelement() + vertical_size_buffer[head_i] = v_count + slash_sizes_buffer[head_i] = s_count + vertical_buffer[head_i, :v_count].copy_( + intra_vertical_indices) + 
slash_buffer[head_i, :s_count].copy_(intra_slash_indices) + # succ + if prev_chunk_end_pos - chunk_len >= 0: + succ_vertical_indices = vertical_topk[ + (vertical_topk < prev_chunk_end_pos) + & (vertical_topk >= prev_chunk_end_pos - + chunk_len)] - (prev_chunk_end_pos - chunk_len) + # TODO: support no vertical + if succ_vertical_indices.nelement() == 0: + succ_vertical_indices = torch.cat([ + succ_vertical_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + succ_slash_indices = ( + (prev_chunk_end_pos + (qend - qbegin) - 1) - + slash_topk[((slash_topk >= + (prev_chunk_end_pos - chunk_len)) & + (slash_topk < (prev_chunk_end_pos + + (qend - qbegin))))]) + if succ_slash_indices.nelement() == 0: + succ_slash_indices = torch.cat([ + succ_slash_indices, + torch.arange( + 0, + k_states_succ.size(0), + max(1, + k_states_succ.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = succ_vertical_indices.nelement() + s_count = succ_slash_indices.nelement() + succ_vertical_size_buffer[head_i] = v_count + succ_slash_sizes_buffer[head_i] = s_count + succ_vertical_buffer[head_i, :v_count].copy_( + succ_vertical_indices) + succ_slash_buffer[head_i, :s_count].copy_( + succ_slash_indices) + + if prev_chunk_end_pos - 2 * chunk_len >= 0: + inter_vertical_indices = vertical_topk[ + vertical_topk < prev_chunk_end_pos - chunk_len] + + if inter_vertical_indices.nelement() == 0: + inter_vertical_indices = torch.cat([ + inter_vertical_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + inter_slash_indices = ( + (prev_chunk_end_pos - chunk_len + + (qend - qbegin) - 1) - + slash_topk[slash_topk < (prev_chunk_end_pos - + chunk_len + + (qend - qbegin))]) + if inter_slash_indices.nelement() == 0: + inter_slash_indices = torch.cat([ + inter_slash_indices, + torch.arange( + 0, + k_states_inter.size(0), + max(1, + k_states_inter.size(0) / 5), + dtype=torch.int32, + device=intra_vertical_indices.device) + ]) + # fill buffer + v_count = inter_vertical_indices.nelement() + s_count = inter_slash_indices.nelement() + inter_vertical_size_buffer[head_i] = v_count + inter_slash_sizes_buffer[head_i] = s_count + inter_vertical_buffer[head_i, :v_count].copy_( + inter_vertical_indices) + inter_slash_buffer[head_i, :s_count].copy_( + inter_slash_indices) + else: + intra_vertical_indices, intra_slash_indices = None, None + succ_vertical_indices, succ_slash_indices = None, None + inter_vertical_indices, inter_slash_indices = None, None + + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=vertical_buffer, + slash_indices=slash_buffer, + vertical_indices_count=vertical_size_buffer, + slash_indices_count=slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_intra, + k_states_intra, + v_states_intra, + softmax_scale=softmax_scale, + causal=True, + block_table=block_table, + stage="intra", + vertical_indices=intra_vertical_indices, + slash_indices=intra_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len >= 0: + if 
sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_buffer, + slash_indices=succ_slash_buffer, + vertical_indices_count=succ_vertical_size_buffer, + slash_indices_count=succ_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_succ, + k_states_succ, + v_states_succ, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="succ", + vertical_indices=succ_vertical_indices, + slash_indices=succ_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + if prev_chunk_end_pos - chunk_len * 2 >= 0: + if sparse_attn_enabled: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_buffer, + slash_indices=inter_slash_buffer, + vertical_indices_count=inter_vertical_size_buffer, + slash_indices_count=inter_slash_sizes_buffer, + mergehead_softmax_scale=softmax_scale, + sparse_attn_enabled=sparse_attn_enabled) + else: + flash_result = self._do_flash_attn( + q_states_inter, + k_states_inter, + v_states_inter, + softmax_scale=softmax_scale, + causal=False, + block_table=block_table, + stage="inter", + vertical_indices=inter_vertical_indices, + slash_indices=inter_slash_indices, + sparse_attn_enabled=sparse_attn_enabled) + flash_per_chunk.append(flash_result) + + flash_results.append(flash_per_chunk) + begin = end + + attn_output = self._merge_attn_outputs(flash_results) + del flash_results + return attn_output + + def _do_flash_attn( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + softmax_scale: float, + causal: bool = True, + block_table: torch.Tensor = None, + max_seqlen_k: Optional[int] = None, + stage: str = "intra", + vertical_indices: Optional[torch.Tensor] = None, + slash_indices: Optional[torch.Tensor] = None, + vertical_indices_count: Optional[torch.Tensor] = None, + slash_indices_count: Optional[torch.Tensor] = None, + mergehead_softmax_scale: Optional[float] = None, + sparse_attn_enabled: Optional[bool] = False, + ): + if max_seqlen_k is None: + max_seqlen_k = key_states.shape[0] + + q_len = query_states.shape[0] + q_heads = query_states.shape[1] + h_dim = query_states.shape[-1] + + if sparse_attn_enabled: + assert slash_indices is not None + if stage == "intra": + assert causal + else: + assert not causal + + query_states = query_states.unsqueeze(0).transpose(1, 2) + key_states = key_states.unsqueeze(0).transpose(1, 2) + value_states = value_states.unsqueeze(0).transpose(1, 2) + + q = query_states + k = key_states + v = value_states + + if (vertical_indices_count is not None and \ + slash_indices_count is not None): + assert mergehead_softmax_scale is not None + + res, s_lse = _vertical_slash_sparse_attention( + q, + k, + v, + vertical_indices, + slash_indices, + mergehead_softmax_scale, + causal=causal, + stage=stage, + vertical_indices_count=vertical_indices_count, + slash_indices_count=slash_indices_count) + res = res.view(q_heads, q_len, + h_dim).transpose(0, 1) # (qlen,nhead,h_dim) + s_lse = s_lse.view( + q_heads, q_len, + 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) + else: + res, s_lse = 
_vertical_slash_sparse_attention(q, + k, + v, + vertical_indices, + slash_indices, + softmax_scale, + causal=causal, + stage=stage) + res = res.view(q_len, q_heads, h_dim) + s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() + return res, s_lse + + output, softmax_lse = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + softmax_scale=softmax_scale, + cu_seqlens_q=torch.tensor([0, query_states.shape[0]], + dtype=torch.int32, + device=query_states.device), + max_seqlen_q=query_states.shape[0], + cu_seqlens_k=torch.tensor([0, max_seqlen_k], + dtype=torch.int32, + device=query_states.device), + max_seqlen_k=max_seqlen_k, + causal=causal, + block_table=block_table.unsqueeze(0), + return_softmax_lse=True, + ) + softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, + 2).float() + return output, softmax_lse + + def _merge_attn_outputs( + self, + flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], + return_lse: Optional[bool] = False, + ) -> torch.Tensor: + attn_outputs_all = [] + logits_all = [] + + for flash_per_chunk in flash_results: + if len(flash_per_chunk) == 1: + attn_outputs_all.append(flash_per_chunk[0][0]) + if return_lse: + logits_all.append(flash_per_chunk[0][1]) + continue + + attn_outputs = torch.stack([ + flash_attn_output[0] for flash_attn_output in flash_per_chunk + ]) + logits = torch.stack([ + flash_attn_output[1] for flash_attn_output in flash_per_chunk + ]) + logits = logits.to(torch.float32) + + if return_lse: + max_val = torch.max(logits, dim=0).values + diff = torch.abs(logits[0] - logits[1]) + log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) + logits_all.append(log_sum_exp) + + max_logits = torch.max(logits, dim=0).values + stable_logits = logits - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) + attn_outputs_all.append(attn_outputs.sum(dim=0)) + + if return_lse: + return (torch.cat(attn_outputs_all, + dim=0), torch.cat(logits_all, dim=-1)) + else: + return torch.cat(attn_outputs_all, dim=0) + + def _dual_chunk_flash_attn_decoding( + self, + query: torch.Tensor, + query_succ: torch.Tensor, + query_inter: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + causal: bool, + alibi_slopes: Optional[torch.Tensor], + chunk_size: int, + local_size: int, + original_max_position_embeddings: int, + decode_meta: DualChunkFlashAttentionMetadata, + ): + if not causal: + raise ValueError( + "Dual Chunk Attention does not support causal=False") + + block_size = value_cache.shape[1] + chunk_len = chunk_size - local_size + if chunk_len % block_size != 0: + raise ValueError("chunk_len must be divisible by block_size.") + if original_max_position_embeddings > 0: + assert decode_meta.scaling_factor is not None + scaling_factor = decode_meta.scaling_factor + query = (query * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype + ) # possible for numerical issue, need to fused in the kernel + query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( + query.dtype) + outputs_list = [] + softmax_lses_list = [] + + # intra-attention + intra_output, intra_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query, + key_cache, + value_cache, + decode_meta.block_tables_intra, + 
decode_meta.seq_lens_intra, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(intra_output) + softmax_lses_list.append(intra_softmax_lse) + + # succ-attention + if decode_meta.max_seq_len_succ: + succ_output, succ_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_succ, + key_cache, + value_cache, + decode_meta.block_tables_succ, + decode_meta.seq_lens_succ, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(succ_output) + softmax_lses_list.append(succ_softmax_lse) + + # inter-attention + if decode_meta.max_seq_len_inter: + inter_output, inter_softmax_lse = ( + self._dual_chunk_flash_attn_decoding_with_exp_sums( + query_inter, + key_cache, + value_cache, + block_table[:, :decode_meta.max_seq_len_inter], + decode_meta.seq_lens_inter, + softmax_scale, + alibi_slopes, + causal=False, + )) + outputs_list.append(inter_output) + softmax_lses_list.append(inter_softmax_lse) + outputs = torch.stack(outputs_list, dim=0) + del outputs_list + softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) + del softmax_lses_list + max_logits = torch.max(softmax_lses, dim=0).values + stable_logits = softmax_lses - max_logits.unsqueeze(0) + lse_s = torch.exp(stable_logits).detach() + lse_sum = torch.sum(lse_s, dim=0) + lse_s /= lse_sum + outputs *= lse_s.unsqueeze(-1).transpose(2, 3) + return outputs.sum(0) + + def _dual_chunk_flash_attn_decoding_with_exp_sums( + self, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + softmax_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + ): + out, softmax_lse = flash_attn_with_kvcache( + q=query, + k_cache=key_cache, + v_cache=value_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + softmax_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + return_softmax_lse=True, + ) + mask = (cache_seqlens == 0) + out[mask] = 0 + softmax_lse[mask] = -float("inf") + return out, softmax_lse + + +def _vertical_slash_sparse_attention( + query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] + key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] + v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] + s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] + softmax_scale: float, + causal: bool = True, + stage: str = "intra", + block_size_M: int = 64, + block_size_N: int = 64, + vertical_indices_count: torch.Tensor = None, # [N_HEADS,] + slash_indices_count: torch.Tensor = None, +): + if stage == "intra": + assert causal + else: + assert not causal + + batch_size, num_heads, context_size, head_dim = query.shape + _, _, kv_seq_len, _ = key.shape + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) + key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) + value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + v_idx = v_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] + s_idx = s_idx.to(torch.int32).reshape( + (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] + q_seqlens = torch.tensor([context_size], + dtype=torch.int32, + device=query.device) + kv_seqlens = torch.tensor([kv_seq_len], + dtype=torch.int32, + device=query.device) + + if vertical_indices_count is not None and slash_indices_count is not None: + ( + block_count, + 
block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes_mergehead( + q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, + slash_indices_count, context_size, block_size_M, block_size_N, + causal) + else: + ( + block_count, + block_offset, + column_count, + column_index, + ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, + s_idx, context_size, + block_size_M, block_size_N, + causal) + + q = query.transpose(1, 2).contiguous() + k = key.transpose(1, 2).contiguous() + v = value.transpose(1, 2).contiguous() + out, lse = sparse_attn_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + causal=causal, + softmax_scale=softmax_scale, + return_softmax_lse=True, + ) + out = out.transpose(1, 2).contiguous() + softmax_lse = lse.reshape(*lse.shape, 1) + return (out[..., :context_size, :head_dim], + softmax_lse[..., :context_size, :]) + + +def _sum_all_diagonal_matrix(mat: torch.tensor): + h, n, m = mat.shape + # Zero matrix used for padding + zero_mat = torch.zeros((h, n, n), device=mat.device) + # pads the matrix on left and right + mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) + # Change the strides + mat_strided = mat_padded.as_strided((1, n, n + m), + (n * (2 * n + m), 2 * n + m + 1, 1)) + # Sums the resulting matrix's columns + sum_diags = torch.sum(mat_strided, 1) + return sum_diags[:, 1:] # drop left bottom corner + + +def _get_block(block_table: torch.Tensor, block_size: int, begin: int, + end: int): + begin_block = begin // block_size + end_block = (end - 1) // block_size + 1 + return block_table[begin_block:end_block] diff --git a/vllm/config.py b/vllm/config.py index ef0163eaff8..16c5cde7c5e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -928,6 +928,23 @@ def _verify_with_expert_parallelism(self) -> None: "Number of experts in the model must be greater than 0 " "when expert parallelism is enabled.") + def verify_dual_chunk_attention_config( + self, + load_config: "LoadConfig", + ) -> None: + if hasattr(self.hf_config, "dual_chunk_attention_config"): + # Try loading the sparse attention config + from vllm.model_executor.model_loader.weight_utils import ( + get_sparse_attention_config) + sparse_attn_config = get_sparse_attention_config(self, load_config) + if sparse_attn_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_config"] = sparse_attn_config + if "sparse_attention_enabled" not in \ + self.hf_config.dual_chunk_attention_config: + self.hf_config.dual_chunk_attention_config[ + "sparse_attention_enabled"] = True + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -4173,6 +4190,8 @@ def __post_init__(self): self.speculative_config, self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) + self.model_config.verify_dual_chunk_attention_config( + self.load_config) if self.cache_config is not None: self.cache_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0ff6a6fbbc1..f9a7d7dce43 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -36,8 +36,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext -from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build, - is_in_ray_actor) +from vllm.utils import 
(STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, + GiB_bytes, is_in_doc_build, is_in_ray_actor) # yapf: enable @@ -981,6 +981,17 @@ def create_engine_config( assert self.enable_chunked_prefill is not None + if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: + assert self.enforce_eager, ( + "Cuda graph is not supported with DualChunkFlashAttention. " + "To run the model in eager mode, set 'enforce_eager=True' " + "or use '--enforce-eager' in the CLI.") + assert current_platform.is_cuda(), ( + "DualChunkFlashAttention is only supported on CUDA platform.") + assert not use_v1, ( + "DualChunkFlashAttention is not supported on V1 engine. " + "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index f8392eb679d..2d634273eca 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -1486,6 +1486,184 @@ def omni_get_updates_use_audio_in_video( return updates +@CustomOp.register("dual_chunk_rotary_embedding") +class DualChunkRotaryEmbedding(CustomOp): + """Rotary positional embedding for Dual Chunk Attention.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + chunk_size: int, + local_size: int, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.chunk_size = chunk_size + self.local_size = local_size + self.dtype = dtype + self.device = torch.device(f"cuda:{torch.cuda.current_device()}") + (q_cache, qc_cache, k_cache, qc_no_clamp_cache, + q_inter_cache) = self._compute_cos_sin_cache() + + self.register_buffer("cos_sin_q_cache", q_cache, persistent=False) + self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False) + self.register_buffer("cos_sin_k_cache", k_cache, persistent=False) + self.register_buffer("cos_sin_qc_no_clamp_cache", + qc_no_clamp_cache, + persistent=False) + self.register_buffer("cos_sin_q_inter_cache", + q_inter_cache, + persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
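+        # Illustrative values (not from any model config): with rotary_dim=8
+        # and base=10000 this computes inv_freq = 10000**(-[0, 2, 4, 6] / 8)
+        # = [1.0, 0.1, 0.01, 0.001], i.e. one frequency per pair of rotary
+        # dimensions, exactly as in the standard RoPE formulation.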
+ inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + chunk_len = self.chunk_size - self.local_size + q_t = torch.arange(chunk_len, dtype=torch.float) + qc_t = (torch.arange(chunk_len, dtype=torch.float) + + chunk_len).clamp(max=self.chunk_size) + k_t = torch.arange(self.max_position_embeddings, + dtype=torch.float) % chunk_len + + # count from chunk_len, no clamp(self.chunk_size) restriction + qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len + # count from self.chunk_size for q_inter's rope + q_inter_t = torch.arange(chunk_len, + dtype=torch.float) + self.chunk_size + + q_freqs = torch.outer(q_t, inv_freq) + qc_freqs = torch.outer(qc_t, inv_freq) + k_freqs = torch.outer(k_t, inv_freq) + qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq) + q_inter_freqs = torch.outer(q_inter_t, inv_freq) + + q_cos = q_freqs.cos() + q_sin = q_freqs.sin() + qc_cos = qc_freqs.cos() + qc_sin = qc_freqs.sin() + k_cos = k_freqs.cos() + k_sin = k_freqs.sin() + + qc_no_clamp_cos = qc_no_clamp_freqs.cos() + qc_no_clamp_sin = qc_no_clamp_freqs.sin() + q_inter_cos = q_inter_freqs.cos() + q_inter_sin = q_inter_freqs.sin() + + q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype, + device=self.device) + qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), + dim=-1).to(dtype=self.dtype, + device=self.device) + return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + else: + query_pass = None + key_pass = None + + positions_with_offsets = (torch.add(positions, offsets) + if offsets is not None else positions) + key = self._apply_rotary_embedding( + self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass) + chunk_len = self.chunk_size - self.local_size + query = self._apply_rotary_embedding( + self.cos_sin_q_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_succ = self._apply_rotary_embedding( + self.cos_sin_qc_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter = self._apply_rotary_embedding( + self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1), + query_rot, query_pass) + query_succ_critical = self._apply_rotary_embedding( + self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + query_inter_critical = self._apply_rotary_embedding( + self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len], + query_rot, query_pass) + + # merge query into one tensor to simplify the interfaces + query = torch.cat(( + query, + query_succ, + query_inter, + 
query_succ_critical, + query_inter_critical, + ), + dim=-1) + return query, key + + def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass): + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin + + if self.rotary_dim < self.head_size: + hidden = torch.cat((hidden_rot, hidden_pass), dim=-1) + else: + hidden = hidden_rot + return hidden.flatten(-2).squeeze(0) + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + s += f", chunk_size={self.chunk_size}, local_size={self.local_size}" + return s + + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} @@ -1498,6 +1676,7 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -1510,14 +1689,35 @@ def get_rope( rope_scaling_args = tuple(rope_scaling_tuple.items()) else: rope_scaling_args = None + + if dual_chunk_attention_config is not None: + dual_chunk_attention_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in dual_chunk_attention_config.items() + if k != "sparse_attention_config" + } + dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) + else: + dual_chunk_attention_args = None + if partial_rotary_factor < 1.0: rotary_dim = int(rotary_dim * partial_rotary_factor) key = (head_size, rotary_dim, max_position, base, is_neox_style, - rope_scaling_args, dtype) + rope_scaling_args, dual_chunk_attention_args, dtype) if key in _ROPE_DICT: return _ROPE_DICT[key] - if not rope_scaling: + if dual_chunk_attention_config is not None: + extra_kwargs = { + k: v + for k, v in dual_chunk_attention_config.items() + if k in ("chunk_size", "local_size") + } + rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + **extra_kwargs) + elif not rope_scaling: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, is_neox_style, dtype) else: diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index beff33414ad..8f9d809022a 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig, return quant_cls.from_config(config) +def get_sparse_attention_config( + model_config: ModelConfig, + load_config: LoadConfig, + sparse_attention_config_filename: str = "sparse_attention_config.json", +) -> Dict[str, Any]: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. 
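+        # Only JSON files are fetched (allow_patterns="*.json"); if the
+        # sparse attention config file is absent, this helper returns {} and
+        # sparse attention simply stays unconfigured.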
+ with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + config_file = os.path.join(hf_folder, sparse_attention_config_filename) + if not os.path.exists(config_file): + return {} + + # Load the sparse attention config. + with open(config_file) as f: + config = json.load(f) + logger.info("Loaded sparse attention config from %s", config_file) + + return config + + def download_weights_from_hf( model_name_or_path: str, cache_dir: Optional[str], diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index f76f31c9fc8..b5850011e7f 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -23,7 +23,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Any, Iterable, Optional, Set, Tuple, Union import torch from torch import nn @@ -53,7 +53,7 @@ from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter, + extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -99,17 +99,20 @@ def forward(self, x): class Qwen2Attention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - max_position: int = 4096 * 32, - rope_theta: float = 10000, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - rope_scaling: Optional[Tuple] = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER) -> None: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + dual_chunk_attention_config: Optional[dict[str, + Any]] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -131,6 +134,7 @@ def __init__(self, self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -155,15 +159,21 @@ def __init__(self, max_position=max_position, base=self.rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=attn_type) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_type=attn_type, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else 
{}) def forward( self, @@ -192,6 +202,9 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) # By default, Qwen2 uses causal attention as it is a decoder-only model. # You can override the HF config with `is_causal=False` to enable @@ -213,6 +226,7 @@ def __init__( rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", attn_type=attn_type, + dual_chunk_attention_config=dual_chunk_attention_config, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 47d90919ed8..14f9f815894 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -175,6 +175,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -198,6 +199,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, @@ -221,14 +223,20 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + dual_chunk_attention_config=dual_chunk_attention_config, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + **{ + "layer_idx": extract_layer_index(prefix), + "dual_chunk_attention_config": dual_chunk_attention_config, + } if dual_chunk_attention_config else {}) def forward( self, @@ -256,6 +264,9 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + dual_chunk_attention_config = getattr(config, + "dual_chunk_attention_config", + None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = Qwen2MoeAttention( @@ -268,6 +279,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + dual_chunk_attention_config=dual_chunk_attention_config, ) # Note: Qwen/Qwen2-57B-A14B-Instruct does not have diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f116285870e..587f7c0c056 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -240,6 +240,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, elif selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") return "vllm.attention.backends.xformers.XFormersBackend" + elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: + logger.info("Using DualChunkFlashAttention backend.") + return ("vllm.attention.backends.dual_chunk_flash_attn." 
+ "DualChunkFlashAttentionBackend") elif selected_backend == _Backend.FLASH_ATTN: pass elif selected_backend: diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 68b90796ece..5b013d80a9f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -50,6 +50,7 @@ class _Backend(enum.Enum): PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() + DUAL_CHUNK_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() diff --git a/vllm/utils.py b/vllm/utils.py index 6779c5b3f8d..59635a25eb3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -153,6 +153,7 @@ STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" GB_bytes = 1_000_000_000 diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d96021cc688..8a294de45c8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -204,6 +204,7 @@ def simple_reinit(self): self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore + self.prompt_lens[0] = 0 # type: ignore self.query_lens[0] = 0 # type: ignore self.context_lens[0] = 0 # type: ignore self.curr_sliding_window_blocks[0] = 0 # type: ignore @@ -236,6 +237,8 @@ def __init__( # The original sequence length (before applying sliding window). # This is used to compute slot mapping. orig_seq_lens: Optional[List[int]] = None, + # This is used in the dual-chunk flash attention backend. + prompt_lens: Optional[List[int]] = None, # The query length. query_lens: Optional[List[int]] = None, # The number of tokens that are already computed. @@ -316,6 +319,12 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.orig_seq_lens[seq_id] = 0 + if prompt_lens: + self.prompt_lens = prompt_lens + else: + for seq_id in range(len(self.seq_ids)): + self.prompt_lens[seq_id] = 0 + if query_lens: self.query_lens = query_lens else: @@ -370,6 +379,7 @@ def __init__( self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] + self.prompt_lens = prompt_lens or [] self.query_lens = query_lens or [] self.context_lens = context_lens or [] self.curr_sliding_window_blocks = \ @@ -403,6 +413,7 @@ def __post_init__(self): self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs + self.prompt_lens = [0] * self.n_seqs self.query_lens = [0] * self.n_seqs self.context_lens = [0] * self.n_seqs self.curr_sliding_window_blocks = [0] * self.n_seqs @@ -552,6 +563,7 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len() inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.inputs_embeds = prompt_embeds