Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844

Merged: 1 commit, May 13, 2025
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -230,6 +230,7 @@ set(VLLM_EXT_SRC
"csrc/attention/paged_attention_v1.cu"
"csrc/attention/paged_attention_v2.cu"
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
401 changes: 401 additions & 0 deletions csrc/attention/vertical_slash_index.cu
Review comment (Collaborator):

Could you add some comments describing what the functions in this file are doing? Comments describing what the blocks of code within convert_vertical_slash_indexes_kernel are doing would be helpful as well.

Large diffs are not rendered by default.
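Since the 401-line kernel file is collapsed on this page, a hedged summary of what the conversion routines appear to do, pieced together from the shape annotations in csrc/ops.h and the Python wrappers in vllm/_custom_ops.py further down; the block/column interpretation is an assumption, not text taken from the kernel source:

```python
# Hedged sketch, not part of the PR: interface of the index-conversion ops as
# inferred from the declarations elsewhere in this diff.
#
# Inputs, per batch element and attention head:
#   vertical_indexes [BATCH, N_HEADS, NNZ_V]  "vertical" KV column indices to keep
#   slash_indexes    [BATCH, N_HEADS, NNZ_S]  "slash" (diagonal) indices to keep
#
# Outputs, per block of block_size_M query rows:
#   block_count  [BATCH, N_HEADS, NUM_ROWS]         number of KV blocks to visit
#   block_offset [BATCH, N_HEADS, NUM_ROWS, NNZ_S]  start offsets of those KV blocks
#   column_count [BATCH, N_HEADS, NUM_ROWS]         number of single KV columns kept
#   column_index [BATCH, N_HEADS, NUM_ROWS, NNZ_V]  indices of those KV columns
#
# The *_mergehead variant additionally takes per-head counts
# (vertical_indices_count / slash_indices_count), so each head may use a
# different number of indices.
context_size, block_size_M = 1024, 64  # arbitrary example values
NUM_ROWS = (context_size + block_size_M - 1) // block_size_M  # ceil, as in the Python wrappers
print(NUM_ROWS)  # 16
```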

25 changes: 25 additions & 0 deletions csrc/ops.h
@@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);

void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal);

void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal);
#endif

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
23 changes: 23 additions & 0 deletions csrc/torch_bindings.cpp
@@ -77,6 +77,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);

ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes", torch::kCUDA,
&convert_vertical_slash_indexes);

ops.def(
"convert_vertical_slash_indexes_mergehead("
" Tensor! block_count, Tensor! block_offset, "
" Tensor! column_count, Tensor! column_index, "
" Tensor q_seqlens, Tensor q_seqlens, "
" Tensor vertical_indexes, Tensor slash_indexes, "
" Tensor vertical_indices_count, Tensor slash_indices_count, "
" int context_size, int block_size_M, int block_size_N, "
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead);
#endif

// Activation ops
66 changes: 66 additions & 0 deletions examples/offline_inference/qwen_1m.py
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
import os
from urllib.request import urlopen

from vllm import LLM, SamplingParams

os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"


def load_prompt() -> str:
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt

    with urlopen(
            "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
            "/Qwen2.5-1M/test-data/600k.txt",
            timeout=5) as response:
        prompt = response.read().decode('utf-8')
    return prompt


# Processing the prompt.
def process_requests(llm: LLM, prompts: list[str]) -> None:
    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        repetition_penalty=1.05,
        detokenize=True,
        max_tokens=256,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt_token_ids = output.prompt_token_ids
        generated_text = output.outputs[0].text
        print(f"Prompt length: {len(prompt_token_ids)}, "
              f"Generated text: {generated_text!r}")


# Create an LLM.
def initialize_engine() -> LLM:
    llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M",
              max_model_len=1048576,
              tensor_parallel_size=4,
              enforce_eager=True,
              enable_chunked_prefill=True,
              max_num_batched_tokens=131072)
    return llm


def main():
    llm = initialize_engine()
    prompt = load_prompt()
    process_requests(llm, [prompt])


if __name__ == '__main__':
    main()
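Assuming the script is saved at examples/offline_inference/qwen_1m.py as in this PR, it can be run directly with `python examples/offline_inference/qwen_1m.py`; note that tensor_parallel_size=4 expects four GPUs.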
95 changes: 95 additions & 0 deletions vllm/_custom_ops.py
@@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
prefix_lse, suffix_output, suffix_lse)


def convert_vertical_slash_indexes(
        q_seqlens: torch.Tensor,  # [BATCH, ]
        kv_seqlens: torch.Tensor,  # [BATCH, ]
        vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
        slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
        context_size: int,
        block_size_M: int,
        block_size_N: int,
        causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.zeros(batch_size,
                              num_heads,
                              num_rows,
                              dtype=q_seqlens.dtype,
                              device=q_seqlens.device)
    block_offset = torch.zeros(batch_size,
                               num_heads,
                               num_rows,
                               nnz_slash,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)
    column_count = torch.zeros(batch_size,
                               num_heads,
                               num_rows,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)
    column_index = torch.zeros(batch_size,
                               num_heads,
                               num_rows,
                               nnz_vertical,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)

    torch.ops._C.convert_vertical_slash_indexes(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, context_size,
        block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index


def convert_vertical_slash_indexes_mergehead(
        q_seqlens: torch.Tensor,  # [BATCH, ]
        kv_seqlens: torch.Tensor,  # [BATCH, ]
        vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
        slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
        # [N_HEADS]: different heads use different numbers of indices
        vertical_indices_count: torch.Tensor,
        slash_indices_count: torch.Tensor,
        context_size: int,
        block_size_M: int,
        block_size_N: int,
        causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.empty(batch_size,
                              num_heads,
                              num_rows,
                              dtype=q_seqlens.dtype,
                              device=q_seqlens.device)
    block_offset = torch.empty(batch_size,
                               num_heads,
                               num_rows,
                               nnz_slash,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)
    column_count = torch.empty(batch_size,
                               num_heads,
                               num_rows,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)
    column_index = torch.empty(batch_size,
                               num_heads,
                               num_rows,
                               nnz_vertical,
                               dtype=q_seqlens.dtype,
                               device=q_seqlens.device)

    torch.ops._C.convert_vertical_slash_indexes_mergehead(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
        slash_indices_count, context_size, block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index


# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
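Finally, to illustrate how the new Python wrapper is meant to be called, a minimal usage sketch; it is not part of the PR and assumes a CUDA build of vLLM that includes the new vertical_slash_index kernel, int32 index tensors, and arbitrary example sizes:

```python
# Hedged usage sketch for convert_vertical_slash_indexes; the sizes and the
# int32 dtype are assumptions for illustration, not values taken from the PR.
import torch

from vllm import _custom_ops as ops

batch, n_heads, nnz_v, nnz_s = 1, 2, 8, 4     # arbitrary example sizes
context_size, block_m, block_n = 256, 64, 64  # arbitrary example sizes
device = "cuda"

q_seqlens = torch.full((batch, ), context_size, dtype=torch.int32, device=device)
kv_seqlens = torch.full((batch, ), context_size, dtype=torch.int32, device=device)
# Per-head vertical (column) and slash (diagonal) indices to keep.
vertical_indexes = torch.arange(nnz_v, dtype=torch.int32,
                                device=device).repeat(batch, n_heads, 1)
slash_indexes = torch.arange(nnz_s, dtype=torch.int32,
                             device=device).repeat(batch, n_heads, 1)

block_count, block_offset, column_count, column_index = \
    ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, vertical_indexes,
                                       slash_indexes, context_size, block_m,
                                       block_n, causal=True)

# Output shapes follow the comments in csrc/ops.h.
print(block_count.shape, block_offset.shape, column_count.shape,
      column_index.shape)
```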