
Commit bd0c782

SolitaryThinker authored and Alvant committed
[multi-step] add flashinfer backend (vllm-project#7928)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent d28f296 commit bd0c782

9 files changed: +371 additions, −84 deletions

csrc/ops.h

Lines changed: 15 additions & 4 deletions
@@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_quick(torch::Tensor& out, torch::Tensor& input);
 
-void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                            int64_t block_size, torch::Tensor& input_tokens,
+                            torch::Tensor& sampled_token_ids,
+                            torch::Tensor& input_positions,
+                            torch::Tensor& seq_lens,
+                            torch::Tensor& slot_mapping,
+                            torch::Tensor& block_tables);
+
+void advance_step_flashinfer(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+    torch::Tensor& input_positions, torch::Tensor& seq_lens,
+    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
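The header change splits the old backend-agnostic advance_step declaration into two backend-specific entry points: advance_step_flashattn keeps the original parameter list, while advance_step_flashinfer additionally takes the tensors FlashInfer uses to describe its paged KV cache (paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, and the per-sequence block_table_bounds).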

csrc/prepare_inputs/advance_step.cu

Lines changed: 200 additions & 25 deletions
@@ -12,13 +12,11 @@ namespace prepare_inputs {
 
 //
 template <int const num_threads>
-__global__ void advance_step_kernel(int num_seqs, int num_queries,
-                                    int block_size, long* input_tokens_ptr,
-                                    long const* sampled_token_ids_ptr,
-                                    long* input_positions_ptr,
-                                    int* seq_lens_ptr, long* slot_mapping_ptr,
-                                    int const* block_tables_ptr,
-                                    int64_t const block_tables_stride) {
+__global__ void advance_step_flashattn_kernel(
+    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
+    long const* sampled_token_ids_ptr, long* input_positions_ptr,
+    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
+    int64_t const block_tables_stride) {
   int num_query_blocks = div_ceil(num_queries, num_threads);
 
   if (blockIdx.x >= num_query_blocks) {
@@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
   }
 }
 
-void advance_step(int num_seqs, int num_queries, int block_size,
-                  torch::Tensor& input_tokens,       // type: long
-                  torch::Tensor& sampled_token_ids,  // type: long
-                  torch::Tensor& input_positions,    // type: long
-                  torch::Tensor& seq_lens,           // type: int
-                  torch::Tensor& slot_mapping,       // type: long
-                  torch::Tensor& block_tables) {     // type: int
+__global__ void advance_step_flashinfer_kernel(
+    int num_threads, int num_seqs, int num_queries, int block_size,
+    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
+    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
+    int const* block_tables_ptr, int64_t const block_tables_stride,
+    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x < num_query_blocks) {
+    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+    if (cur_query_id < num_queries) {
+      // Update input_tokens
+      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+      int seq_len = seq_lens_ptr[cur_query_id];
+      int next_seq_len = seq_len + 1;
+      int next_input_pos = next_seq_len - 1;
+
+      // Update seq_lens
+      seq_lens_ptr[cur_query_id] = next_seq_len;
+      // Update input_positions
+      input_positions_ptr[cur_query_id] = next_input_pos;
+
+      int const* seq_block_tables_ptr =
+          block_tables_ptr + block_tables_stride * cur_query_id;
+
+      int block_index = next_input_pos / block_size;
+      int block_offset = next_input_pos % block_size;
+
+      // Update paged_kv_last_page_len
+      paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
+
+      int slot_num =
+          seq_block_tables_ptr[block_index] * block_size + block_offset;
+      // Update slot_mapping
+      slot_mapping_ptr[cur_query_id] = slot_num;
+      block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
+    }
+  }
+}
+
+__global__ void advance_step_flashinfer_indptr_kernel(
+    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
+    int* block_table_bound_ptr) {
+  int idx = blockIdx.x * num_threads + threadIdx.x;
+
+  // Update paged_kv_indptr
+  if (idx < num_queries) {
+    int sum = 0;
+    for (int i = 0; i <= idx; ++i) {
+      sum += block_table_bound_ptr[i];
+    }
+    paged_kv_indptr_ptr[idx + 1] = sum;
+  }
+}
+
+__global__ void advance_step_flashinfer_indices_kernel(
+    int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
+    int64_t const block_tables_stride, int* paged_kv_indices_ptr,
+    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
+  int idx = blockIdx.x * num_threads + threadIdx.x;
+  int row = idx / block_tables_stride;
+  int col = idx % block_tables_stride;
+
+  if (row < num_queries && col < block_table_bound_ptr[row]) {
+    paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
+        block_tables_ptr[row * block_tables_stride + col];
+  }
+  // if cudagraph, fill padded seqs with the last valid seq's indptr
+  if (num_queries < row && row <= num_seqs) {
+    paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
+  }
+}
+
+void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
+                            torch::Tensor& input_tokens,  // type: long
+                            torch::Tensor& sampled_token_ids,  // type: long
+                            torch::Tensor& input_positions,  // type: long
+                            torch::Tensor& seq_lens,      // type: int
+                            torch::Tensor& slot_mapping,  // type: long
+                            torch::Tensor& block_tables) {  // type: int
 
   if (logging) {
-    printf("advance_step:\n");
+    printf("advance_step_flashattn:\n");
     printf("  num_seqs = %d\n", num_seqs);
     printf("  num_queries = %d\n", num_queries);
     printf("  block_size = %d\n", block_size);
@@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
   int blocks;
   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
 
-  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
-      num_seqs, num_queries, block_size,
+  advance_step_flashattn_kernel<max_threads>
+      <<<blocks, max_threads, 0, stream>>>(
+          num_seqs, num_queries, block_size,
+          reinterpret_cast<long*>(input_tokens.data_ptr()),
+          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+          reinterpret_cast<long*>(input_positions.data_ptr()),
+          reinterpret_cast<int*>(seq_lens.data_ptr()),
+          reinterpret_cast<long*>(slot_mapping.data_ptr()),
+          reinterpret_cast<int const*>(block_tables.data_ptr()),
+          block_tables.stride(0));
+}
+
+void advance_step_flashinfer(
+    int num_seqs, int num_queries, int block_size,
+    torch::Tensor& input_tokens,            // type: long
+    torch::Tensor& sampled_token_ids,       // type: long
+    torch::Tensor& input_positions,         // type: long
+    torch::Tensor& seq_lens,                // type: int
+    torch::Tensor& slot_mapping,            // type: long
+    torch::Tensor& block_tables,            // type: int
+    torch::Tensor& paged_kv_indices,        // type: int
+    torch::Tensor& paged_kv_indptr,         // type: int
+    torch::Tensor& paged_kv_last_page_len,  // type: int
+    torch::Tensor& block_table_bound) {     // type: int
+
+  if (logging) {
+    printf("advance_step_flashinfer:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+    printf("  block_tables.stride(0) = %d\n", block_tables.stride(0));
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+  // at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
+  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
+  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
+                at::kInt);
+
+  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  int threads;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
+  if (logging) {
+    printf("launching kernel with %d blocks\n", blocks);
+  }
+
+  // TODO(will): support arbitrary block_tables stride
+  if ((blocks * threads) / block_tables.stride(0) < num_queries) {
+    TORCH_CHECK(false,
+                "multi-step: not enough threads to map block_table to"
+                "FlashInfer's paged_kv_indices on GPU. Try reducing the number "
+                "of seqs,",
+                " increasing the block size or take smaller steps.",
+                " num_queries = ", num_queries,
+                " block_tables.stride(0) = ", block_tables.stride(0),
+                " blocks = ", blocks, " max_threads = ", threads);
+  }
+
+  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries, block_size,
       reinterpret_cast<long*>(input_tokens.data_ptr()),
       reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
       reinterpret_cast<long*>(input_positions.data_ptr()),
       reinterpret_cast<int*>(seq_lens.data_ptr()),
       reinterpret_cast<long*>(slot_mapping.data_ptr()),
       reinterpret_cast<int const*>(block_tables.data_ptr()),
-      block_tables.stride(0));
+      block_tables.stride(0),
+      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries,
+      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries,
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0),
+      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
+      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
 }
 
 }  // namespace prepare_inputs
 
-void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
-  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
-                               sampled_token_ids, input_positions, seq_lens,
-                               slot_mapping, block_tables);
+void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                            int64_t block_size, torch::Tensor& input_tokens,
+                            torch::Tensor& sampled_token_ids,
+                            torch::Tensor& input_positions,
+                            torch::Tensor& seq_lens,
+                            torch::Tensor& slot_mapping,
+                            torch::Tensor& block_tables) {
+  prepare_inputs::advance_step_flashattn(
+      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+      input_positions, seq_lens, slot_mapping, block_tables);
+}
+
+void advance_step_flashinfer(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+    torch::Tensor& input_positions, torch::Tensor& seq_lens,
+    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
  prepare_inputs::advance_step_flashinfer(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
 }
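Taken together, the three new kernels rebuild FlashInfer's paged-KV metadata on the GPU after every decode step: advance_step_flashinfer_kernel advances tokens, positions, and slot mappings and records how many KV blocks each sequence now occupies; advance_step_flashinfer_indptr_kernel turns those per-sequence counts into the paged_kv_indptr prefix-sum array; and advance_step_flashinfer_indices_kernel gathers the used prefix of each block-table row into the packed paged_kv_indices array. For orientation, here is a minimal NumPy sketch of the same bookkeeping, written for this note rather than taken from the commit (the CUDA-graph padding branch for rows between num_queries and num_seqs is omitted):

import numpy as np

def advance_step_flashinfer_reference(block_size, seq_lens, block_tables):
    """CPU sketch of the metadata the three CUDA kernels above produce.

    seq_lens:     (num_queries,) sequence lengths *before* the +1 advance
    block_tables: (num_queries, max_blocks) physical block ids per sequence
    """
    next_seq_lens = seq_lens + 1
    # block_table_bound in the kernels: KV blocks each sequence now occupies.
    block_table_bound = -(-next_seq_lens // block_size)  # ceil division

    # paged_kv_indptr: prefix sum over the bounds, with a leading 0
    # (the indptr kernel writes entries 1..num_queries and assumes [0] == 0).
    paged_kv_indptr = np.zeros(len(seq_lens) + 1, dtype=np.int32)
    paged_kv_indptr[1:] = np.cumsum(block_table_bound)

    # paged_kv_indices: the used prefix of each block-table row, packed
    # back-to-back (the indices kernel's gather).
    paged_kv_indices = np.concatenate(
        [row[:bound] for row, bound in zip(block_tables, block_table_bound)])

    # paged_kv_last_page_len: valid slots in each sequence's last page.
    last_pos = next_seq_lens - 1
    paged_kv_last_page_len = last_pos % block_size + 1

    return paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len

if __name__ == "__main__":
    seq_lens = np.array([10, 31, 32], dtype=np.int32)
    block_tables = np.arange(12, dtype=np.int32).reshape(3, 4)
    # indptr [0, 1, 3, 6]; indices [0, 4, 5, 8, 9, 10]; last_page_len [11, 16, 1]
    print(advance_step_flashinfer_reference(16, seq_lens, block_tables))

Note that the indptr kernel's per-thread loop over block_table_bound is a serial O(num_queries) prefix sum; the sketch gets the same result with np.cumsum.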

csrc/torch_bindings.cpp

Lines changed: 13 additions & 2 deletions
@@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // prepare_inputs advance_step
   ops.def(
-      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
       "Tensor! input_tokens, Tensor sampled_token_ids, "
      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
      "Tensor block_tables) -> ()");
-  ops.impl("advance_step", torch::kCUDA, &advance_step);
+  ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
+
+  ops.def(
+      "advance_step_flashinfer("
+      " int num_seqs, int num_queries, int block_size,"
+      " Tensor! input_tokens, Tensor sampled_token_ids,"
+      " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
+      " Tensor block_tables, Tensor! paged_kv_indices,"
+      " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
+      " Tensor! block_table_bounds"
+      ") -> ()");
+  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
 
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
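In these schema strings, Tensor! marks arguments the op may mutate in place, which is how torch.library communicates aliasing and mutation to the dispatcher. Once the compiled extension is loaded, both ops resolve through the torch.ops namespace; a quick smoke check, assuming vLLM's _C extension is built and importable as usual:

import torch
import vllm._C  # noqa: F401 -- importing the extension registers the custom ops

# Each lookup raises AttributeError if the registration above did not run.
print(torch.ops._C.advance_step_flashattn)
print(torch.ops._C.advance_step_flashinfer)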

tests/multi_step/test_correctness_async_llm.py

Lines changed: 9 additions & 3 deletions
@@ -1,9 +1,10 @@
 # Test the AsyncLLMEngine with multi-step-decoding
-
 from typing import List, Optional
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close
 from ..utils import (completions_with_server_args, get_client_text_generations,
                      get_client_text_logprob_generations)
@@ -33,8 +34,9 @@
 @pytest.mark.parametrize("eager_mode", [False, True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
-@pytest.mark.parametrize("num_logprobs", [None, 5])
-@pytest.mark.parametrize("is_async", [False, True])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("is_async", [True])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 @pytest.mark.asyncio
 async def test_multi_step(
     example_prompts,
@@ -46,6 +48,8 @@ async def test_multi_step(
     num_prompts: int,
     is_async: bool,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
     client/server environment.
@@ -71,6 +75,8 @@ async def test_multi_step(
       completions endpoint; `None` -> no logprobs
     """
 
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     prompts = example_prompts
     if len(prompts) < num_prompts:
         prompts = prompts * ((num_prompts // len(prompts)) + 1)
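The test matrix now pins the attention backend per parametrization instead of relying on whatever the environment selects. override_backend_env_variable is imported from tests.kernels.utils; a hypothetical reconstruction of its behavior (not the commit's code) amounts to:

# Assumed behavior of tests.kernels.utils.override_backend_env_variable:
# pin vLLM's attention-backend selection for the duration of one test.
def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)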

vllm/_custom_ops.py

Lines changed: 29 additions & 9 deletions
@@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
-def advance_step(num_seqs: int, num_queries: int, block_size: int,
-                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
-                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
-                 slot_mapping: torch.Tensor,
-                 block_tables: torch.Tensor) -> None:
+def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
+                           input_tokens: torch.Tensor,
+                           sampled_token_ids: torch.Tensor,
+                           input_positions: torch.Tensor,
+                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
+                           block_tables: torch.Tensor) -> None:
     """Advance a step on GPU for existing inputs for a multi-step runner"""
-    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
-                                     input_tokens, sampled_token_ids,
-                                     input_positions, seq_lens, slot_mapping,
-                                     block_tables)
+    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
+                                               block_size, input_tokens,
+                                               sampled_token_ids,
+                                               input_positions, seq_lens,
+                                               slot_mapping, block_tables)
+
+
+def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
+                            input_tokens: torch.Tensor,
+                            sampled_token_ids: torch.Tensor,
+                            input_positions: torch.Tensor,
+                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
+                            block_tables: torch.Tensor,
+                            paged_kv_indices: torch.Tensor,
+                            paged_kv_indptr: torch.Tensor,
+                            paged_kv_last_page_len: torch.Tensor,
+                            block_table_bound: torch.Tensor) -> None:
+
+    return torch.ops._C.advance_step_flashinfer(
+        num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+        input_positions, seq_lens, slot_mapping, block_tables,
+        paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
+        block_table_bound)
 
 
 # quantization ops
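For orientation, a hypothetical call to the new wrapper is sketched below. The dtypes and shapes follow the verify_tensor checks in advance_step.cu (most tensors have num_seqs rows; paged_kv_indptr has num_seqs + 1 entries), and all concrete sizes are made up for the illustration:

import torch
from vllm import _custom_ops as ops

num_seqs = num_queries = 4  # illustrative sizes only
block_size, max_blocks = 16, 8
dev = "cuda"

# Tensors the op advances in place (dtypes per the C++ comments above).
input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=dev)
sampled_token_ids = torch.zeros(num_queries, 1, dtype=torch.long, device=dev)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=dev)
seq_lens = torch.full((num_seqs,), 10, dtype=torch.int32, device=dev)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=dev)
block_tables = torch.arange(num_seqs * max_blocks, dtype=torch.int32,
                            device=dev).reshape(num_seqs, max_blocks)

# FlashInfer paged-KV metadata the op rebuilds.
paged_kv_indices = torch.zeros(num_seqs * max_blocks, dtype=torch.int32,
                               device=dev)
paged_kv_indptr = torch.zeros(num_seqs + 1, dtype=torch.int32, device=dev)
paged_kv_last_page_len = torch.zeros(num_seqs, dtype=torch.int32, device=dev)
block_table_bound = torch.zeros(num_seqs, dtype=torch.int32, device=dev)

ops.advance_step_flashinfer(
    num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
    input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
    paged_kv_indptr, paged_kv_last_page_len, block_table_bound)

After the call, input_tokens, input_positions, seq_lens, and slot_mapping hold the advanced step, and the four paged-KV tensors describe the updated cache layout for FlashInfer.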
