Commit e76466d

[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338)

1 parent 5f0b993 commit e76466d

File tree

12 files changed: +568 -130 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
+  "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")
 
 if(VLLM_GPU_LANG STREQUAL "CUDA")

csrc/ops.h

Lines changed: 5 additions & 0 deletions
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_quick(torch::Tensor& out, torch::Tensor& input);
 
+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,

csrc/prepare_inputs/advance_step.cu

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+/*
+ * The goal of this GPU kernel is to advance input tensors on the GPU directly
+ * PR: https://github.com/vllm-project/vllm/pull/6338
+ * Current restrictions:
+ *     1. Specialized for DraftModelRunner
+ *     2. Supports flash_attn only
+ */
+
+#include "advance_step.cuh"
+
+namespace prepare_inputs {
+
+//
+template <int const num_threads>
+__global__ void advance_step_kernel(int num_seqs, int num_queries,
+                                    int block_size, long* input_tokens_ptr,
+                                    long const* sampled_token_ids_ptr,
+                                    long* input_positions_ptr,
+                                    int* seq_lens_ptr, long* slot_mapping_ptr,
+                                    int const* block_tables_ptr,
+                                    int64_t const block_tables_stride) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x >= num_query_blocks) {
+    return;
+  }
+
+  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+  if (cur_query_id >= num_queries) {
+    return;
+  }
+
+  // Update input_tokens
+  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+  int seq_len = seq_lens_ptr[cur_query_id];
+  int next_seq_len = seq_len + 1;
+  int next_input_pos = next_seq_len - 1;
+
+  // Update seq_lens
+  seq_lens_ptr[cur_query_id] = next_seq_len;
+  // Update input_positions
+  input_positions_ptr[cur_query_id] = next_input_pos;
+
+  int const* seq_block_tables_ptr =
+      block_tables_ptr + block_tables_stride * cur_query_id;
+
+  int block_index = next_input_pos / block_size;
+  int block_offset = next_input_pos % block_size;
+
+  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
+  // Update slot_mapping
+  slot_mapping_ptr[cur_query_id] = slot_num;
+}
+
+inline void verify_tensor(std::string const& name, torch::Tensor& t,
+                          int64_t const size_0, int64_t const size_1,
+                          c10::ScalarType const type) {
+  bool size_0_cond = true;
+  if (size_0 != -1) {
+    size_0_cond = t.size(0) == size_0;
+  }
+
+  bool size_1_cond = true;
+  if (size_1 != -1) {
+    size_1_cond = t.size(1) == size_1;
+  }
+
+  bool is_contiguous = t.is_contiguous();
+  bool same_type = t.dtype() == type;
+
+  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
+  if (!pass) {
+    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
+                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
+                " is not as expected: shape = [", size_0, ", ", size_1,
+                "], type = ", type);
+  }
+}
+
+void advance_step(int num_seqs, int num_queries, int block_size,
+                  torch::Tensor& input_tokens,       // type: long
+                  torch::Tensor& sampled_token_ids,  // type: long
+                  torch::Tensor& input_positions,    // type: long
+                  torch::Tensor& seq_lens,           // type: int
+                  torch::Tensor& slot_mapping,       // type: long
+                  torch::Tensor& block_tables) {     // type: int
+
+  if (logging) {
+    printf("advance_step:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+                at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
+      num_seqs, num_queries, block_size,
+      reinterpret_cast<long*>(input_tokens.data_ptr()),
+      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+      reinterpret_cast<long*>(input_positions.data_ptr()),
+      reinterpret_cast<int*>(seq_lens.data_ptr()),
+      reinterpret_cast<long*>(slot_mapping.data_ptr()),
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0));
+}
+
+}  // namespace prepare_inputs
+
+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
+  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
+                               sampled_token_ids, input_positions, seq_lens,
+                               slot_mapping, block_tables);
+}
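The kernel assigns one thread per query: the host-side launch spawns one block per SM with `max_threads` (256) threads each, blocks beyond `div_ceil(num_queries, num_threads)` exit immediately, and each remaining thread updates one query's entries in place. As a minimal sketch of those per-query updates (illustrative only, not part of the commit; the helper name is hypothetical and the shapes/dtypes follow the `verify_tensor` checks above), the kernel is equivalent to:

# Illustrative CPU reference for advance_step_kernel (hypothetical helper,
# not in the commit). Shapes/dtypes mirror the verify_tensor checks.
import torch

def advance_step_reference(num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,       # [num_seqs], long
                           sampled_token_ids: torch.Tensor,  # [num_queries, 1], long
                           input_positions: torch.Tensor,    # [num_seqs], long
                           seq_lens: torch.Tensor,           # [num_seqs], int
                           slot_mapping: torch.Tensor,       # [num_seqs], long
                           block_tables: torch.Tensor) -> None:  # [num_seqs, max_blocks], int
    for q in range(num_queries):
        # Feed the previously sampled token back in as the next input token.
        input_tokens[q] = sampled_token_ids[q, 0]
        next_seq_len = int(seq_lens[q]) + 1
        next_input_pos = next_seq_len - 1
        seq_lens[q] = next_seq_len
        input_positions[q] = next_input_pos
        # Translate the new position into a physical KV-cache slot
        # through this query's block table row.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        slot_mapping[q] = (int(block_tables[q, block_index]) * block_size
                           + block_offset)

Doing these updates in a single kernel keeps the draft model's multi-step loop on the GPU and avoids rebuilding the input tensors on the CPU between steps, which is the point of the PR title.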

csrc/prepare_inputs/advance_step.cuh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+namespace prepare_inputs {
+
+static constexpr int max_threads = 256;
+static constexpr bool logging = false;
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+}  // namespace prepare_inputs

csrc/torch_bindings.cpp

Lines changed: 4 additions & 0 deletions
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
 
+  // prepare_inputs advance_step
+  ops.def("advance_step", &advance_step);
+  ops.impl("advance_step", torch::kCUDA, &advance_step);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(

tests/spec_decode/e2e/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ def get_output_from_llm_generator(
     maybe_assert_ngram_worker(llm)
 
     outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+
     token_ids = [output.outputs[0].token_ids for output in outputs]
     tokens = [output.outputs[0].text for output in outputs]

tests/spec_decode/test_multi_step_worker.py

Lines changed: 48 additions & 0 deletions
@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
     assert proposals.proposal_lens.tolist() == [
         k for _ in range(expected_num_proposal_seqs - 1)
     ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
+
+
+@torch.inference_mode()
+def test_use_draft_model_runner_advance_step():
+    """Verify that draft model runner triggers advance step
+    when applicable.
+    """
+    seed = 100
+    model_name = 'JackFram/llama-68m'
+
+    k = 5
+    batch_size = 32
+    block_size = 32
+    num_gpu_blocks = 2048 // block_size
+    worker = create_worker(
+        MultiStepWorker,
+        model_name,
+        block_size,
+        num_gpu_blocks,
+        seed,
+        model_runner_cls=TP1DraftModelRunner,
+    )
+
+    # Mock "_gpu_advance_step" to raise an exception when called.
+    exception_secret = "artificial stop"
+    worker.model_runner._gpu_advance_step = MagicMock()
+    worker.model_runner._gpu_advance_step.side_effect = ValueError(
+        exception_secret)
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+
+    # Fallback (should not call) when num_steps=1.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        num_steps=1)
+    worker.execute_model(execute_model_req=execute_model_req)
+
+    # Expect exception if _gpu_advance_step is called.
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k,
+        num_steps=k)
+
+    with pytest.raises(ValueError, match=exception_secret):
+        worker.execute_model(execute_model_req=execute_model_req)
+    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
+    assert len(call_args_list) == 1

vllm/_custom_ops.py

Lines changed: 12 additions & 0 deletions
@@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
+def advance_step(num_seqs: int, num_queries: int, block_size: int,
+                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
+                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
+                 slot_mapping: torch.Tensor,
+                 block_tables: torch.Tensor) -> None:
+    """Advance a step on GPU for existing inputs for a multi-step runner"""
+    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
+                                     input_tokens, sampled_token_ids,
+                                     input_positions, seq_lens, slot_mapping,
+                                     block_tables)
+
+
 # quantization ops
 # awq
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
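The Python wrapper is a thin pass-through to the registered CUDA op. A hedged usage sketch follows; the values and tensor contents are illustrative only (not from the commit), while the shapes and dtypes follow the `verify_tensor` checks in `advance_step.cu`:

# Hypothetical example call; dimensions are made up for illustration.
import torch
from vllm import _custom_ops as ops

num_seqs, num_queries, block_size, max_blocks = 8, 8, 16, 32

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device="cuda")
sampled_token_ids = torch.zeros((num_queries, 1), dtype=torch.long, device="cuda")
input_positions = torch.zeros(num_seqs, dtype=torch.long, device="cuda")
seq_lens = torch.ones(num_seqs, dtype=torch.int32, device="cuda")
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device="cuda")
block_tables = torch.zeros((num_seqs, max_blocks), dtype=torch.int32, device="cuda")

# All six tensors are updated in place on the GPU; nothing is returned.
ops.advance_step(num_seqs, num_queries, block_size, input_tokens,
                 sampled_token_ids, input_positions, seq_lens, slot_mapping,
                 block_tables)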
