
Commit 6df39b3

Merge branch 'upstream-main' into mistral-nemo-support

2 parents: ef3b5ba + ecdb462


60 files changed, +1913 -716 lines

.buildkite/test-pipeline.yaml

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,8 @@ steps:
 - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
 - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
 - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
+- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
 - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
 - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
 - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
@@ -108,6 +110,7 @@ steps:
 # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
 # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
 - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
 - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
 - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

@@ -140,6 +143,7 @@ steps:
 # install tensorizer for tensorize_vllm_model.py
 - pip install awscli tensorizer
 - python3 offline_inference.py
+- python3 cpu_offload.py
 - python3 offline_inference_with_prefix.py
 - python3 llm_engine_example.py
 - python3 llava_example.py

.github/workflows/reminder_comment.yml

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ jobs:
 owner: context.repo.owner,
 repo: context.repo.repo,
 issue_number: context.issue.number,
-body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only trigger `fastcheck` CI to run, which consists only a small and essential subset of tests to quickly catch errors with the flexibility to run extra individual tests on top (you can do this by unblocking test steps in the Buildkite run). \n\nFull CI run is still required to merge this PR so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger full CI as well.\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
+body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it's required (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
 })
 env:
 GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/fp8/common.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/moe_align_block_size_kernels.cu"
+  "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")

 if(VLLM_GPU_LANG STREQUAL "CUDA")

csrc/ops.h

Lines changed: 12 additions & 3 deletions
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

 void gelu_quick(torch::Tensor& out, torch::Tensor& input);

+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
@@ -123,12 +128,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,

 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

-void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
-                             torch::Tensor& scale);
+void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                             torch::Tensor const& scale);

-void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
+void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scale);

+void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
+                                        torch::Tensor const& input,
+                                        torch::Tensor& scale);
+
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,

csrc/prepare_inputs/advance_step.cu

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+/*
+ * The goal of this GPU kernel is to advance input tensors on the GPU directly
+ * PR: https://github.com/vllm-project/vllm/pull/6338
+ * Current restrictions:
+ *     1. Specialized for DraftModelRunner
+ *     2. Supports flash_attn only
+ */
+
+#include "advance_step.cuh"
+
+namespace prepare_inputs {
+
+//
+template <int const num_threads>
+__global__ void advance_step_kernel(int num_seqs, int num_queries,
+                                    int block_size, long* input_tokens_ptr,
+                                    long const* sampled_token_ids_ptr,
+                                    long* input_positions_ptr,
+                                    int* seq_lens_ptr, long* slot_mapping_ptr,
+                                    int const* block_tables_ptr,
+                                    int64_t const block_tables_stride) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x >= num_query_blocks) {
+    return;
+  }
+
+  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+  if (cur_query_id >= num_queries) {
+    return;
+  }
+
+  // Update input_tokens
+  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+  int seq_len = seq_lens_ptr[cur_query_id];
+  int next_seq_len = seq_len + 1;
+  int next_input_pos = next_seq_len - 1;
+
+  // Update seq_lens
+  seq_lens_ptr[cur_query_id] = next_seq_len;
+  // Update input_positions
+  input_positions_ptr[cur_query_id] = next_input_pos;
+
+  int const* seq_block_tables_ptr =
+      block_tables_ptr + block_tables_stride * cur_query_id;
+
+  int block_index = next_input_pos / block_size;
+  int block_offset = next_input_pos % block_size;
+
+  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
+  // Update slot_mapping
+  slot_mapping_ptr[cur_query_id] = slot_num;
+}
+
+inline void verify_tensor(std::string const& name, torch::Tensor& t,
+                          int64_t const size_0, int64_t const size_1,
+                          c10::ScalarType const type) {
+  bool size_0_cond = true;
+  if (size_0 != -1) {
+    size_0_cond = t.size(0) == size_0;
+  }
+
+  bool size_1_cond = true;
+  if (size_1 != -1) {
+    size_1_cond = t.size(1) == size_1;
+  }
+
+  bool is_contiguous = t.is_contiguous();
+  bool same_type = t.dtype() == type;
+
+  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
+  if (!pass) {
+    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
+                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
+                " is not as expected: shape = [", size_0, ", ", size_1,
+                "], type = ", type);
+  }
+}
+
+void advance_step(int num_seqs, int num_queries, int block_size,
+                  torch::Tensor& input_tokens,       // type: long
+                  torch::Tensor& sampled_token_ids,  // type: long
+                  torch::Tensor& input_positions,    // type: long
+                  torch::Tensor& seq_lens,           // type: int
+                  torch::Tensor& slot_mapping,       // type: long
+                  torch::Tensor& block_tables) {     // type: int
+
+  if (logging) {
+    printf("advance_step:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+                at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+
+  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
+      num_seqs, num_queries, block_size,
+      reinterpret_cast<long*>(input_tokens.data_ptr()),
+      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+      reinterpret_cast<long*>(input_positions.data_ptr()),
+      reinterpret_cast<int*>(seq_lens.data_ptr()),
+      reinterpret_cast<long*>(slot_mapping.data_ptr()),
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0));
+}
+
+}  // namespace prepare_inputs
+
+void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
+                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
+                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
+  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
+                               sampled_token_ids, input_positions, seq_lens,
+                               slot_mapping, block_tables);
+}
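
For orientation, a minimal host-side caller sketch (not part of this commit): it assumes the sources above are compiled into a libtorch/CUDA program, the include path is illustrative, and the tensor values are made up; shapes and dtypes follow what verify_tensor enforces. In the commit itself the op is wired up for DraftModelRunner, per the header comment.

// Illustrative caller only (assumed setup, not from the commit). Requires the
// extension sources above to be compiled and linked with libtorch on a CUDA box.
#include <torch/torch.h>

#include "ops.h"  // assumed include path for the advance_step declaration

int main() {
  int64_t num_seqs = 4, num_queries = 4, block_size = 16, max_blocks = 8;
  auto dev = torch::Device(torch::kCUDA, 0);
  auto long_opts = torch::TensorOptions().dtype(torch::kLong).device(dev);
  auto int_opts = torch::TensorOptions().dtype(torch::kInt).device(dev);

  // Shapes/dtypes mirror verify_tensor: per-sequence tensors are [num_seqs],
  // sampled_token_ids is [num_queries, 1], block_tables is [num_seqs, blocks].
  auto input_tokens = torch::zeros({num_seqs}, long_opts);
  auto sampled_token_ids = torch::ones({num_queries, 1}, long_opts);
  auto input_positions = torch::zeros({num_seqs}, long_opts);
  auto seq_lens = torch::full({num_seqs}, 5, int_opts);
  auto slot_mapping = torch::zeros({num_seqs}, long_opts);
  auto block_tables = torch::zeros({num_seqs, max_blocks}, int_opts);

  // Copies each sampled token into input_tokens, bumps seq_lens, and
  // recomputes input_positions and slot_mapping entirely on the GPU.
  advance_step(num_seqs, num_queries, block_size, input_tokens,
               sampled_token_ids, input_positions, seq_lens, slot_mapping,
               block_tables);
  return 0;
}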

csrc/prepare_inputs/advance_step.cuh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+namespace prepare_inputs {
+
+static constexpr int max_threads = 256;
+static constexpr bool logging = false;
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+}  // namespace prepare_inputs
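
A quick worked example of the launch geometry and slot arithmetic that these constants feed into; all runtime numbers below are assumed for illustration and are not taken from the commit.

// Host-side illustration of div_ceil and the slot computation; the constant
// mirrors advance_step.cuh, the runtime numbers are assumed examples.
#include <cstdio>

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int max_threads = 256;  // same value as in advance_step.cuh
  int num_queries = 1000;           // assumed
  int block_size = 16;              // assumed KV-cache block size
  int next_input_pos = 37;          // assumed next token position
  int physical_block = 7;           // assumed block_tables entry

  // Only div_ceil(1000, 256) = 4 of the launched thread blocks do useful
  // work; the rest hit the early return in advance_step_kernel.
  printf("query blocks needed: %d\n", div_ceil(num_queries, max_threads));

  // Same arithmetic as the kernel:
  // slot = block_tables[pos / block_size] * block_size + pos % block_size
  int block_index = next_input_pos / block_size;              // 2
  int block_offset = next_input_pos % block_size;             // 5
  int slot_num = physical_block * block_size + block_offset;  // 117
  printf("block_index=%d block_offset=%d slot=%d\n", block_index, block_offset,
         slot_num);
  return 0;
}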
