@@ -12,13 +12,11 @@ namespace prepare_inputs {

//
template <int const num_threads>
- __global__ void advance_step_kernel(int num_seqs, int num_queries,
-                                     int block_size, long* input_tokens_ptr,
-                                     long const* sampled_token_ids_ptr,
-                                     long* input_positions_ptr,
-                                     int* seq_lens_ptr, long* slot_mapping_ptr,
-                                     int const* block_tables_ptr,
-                                     int64_t const block_tables_stride) {
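+ // One thread per query: copies the sampled token into input_tokens, bumps
+ // seq_lens and input_positions by one, and recomputes slot_mapping from the
+ // query's row of the block table.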
+ __global__ void advance_step_flashattn_kernel(
+     int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
+     long const* sampled_token_ids_ptr, long* input_positions_ptr,
+     int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
+     int64_t const block_tables_stride) {
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
@@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
  }
}

- void advance_step(int num_seqs, int num_queries, int block_size,
-                   torch::Tensor& input_tokens,       // type: long
-                   torch::Tensor& sampled_token_ids,  // type: long
-                   torch::Tensor& input_positions,    // type: long
-                   torch::Tensor& seq_lens,           // type: int
-                   torch::Tensor& slot_mapping,       // type: long
-                   torch::Tensor& block_tables) {     // type: int
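+ // Per-query advance for the FlashInfer backend: the same token/position/slot
+ // update as the FlashAttention kernel, plus FlashInfer's paged-KV metadata
+ // (last page length and the number of block-table entries in use).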
+ __global__ void advance_step_flashinfer_kernel(
+     int num_threads, int num_seqs, int num_queries, int block_size,
+     long* input_tokens_ptr, long const* sampled_token_ids_ptr,
+     long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
+     int const* block_tables_ptr, int64_t const block_tables_stride,
+     int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+   int num_query_blocks = div_ceil(num_queries, num_threads);
+
+   if (blockIdx.x < num_query_blocks) {
+     int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+     if (cur_query_id < num_queries) {
+       // Update input_tokens
+       input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+       int seq_len = seq_lens_ptr[cur_query_id];
+       int next_seq_len = seq_len + 1;
+       int next_input_pos = next_seq_len - 1;
+
+       // Update seq_lens
+       seq_lens_ptr[cur_query_id] = next_seq_len;
+       // Update input_positions
+       input_positions_ptr[cur_query_id] = next_input_pos;
+
+       int const* seq_block_tables_ptr =
+           block_tables_ptr + block_tables_stride * cur_query_id;
+
+       int block_index = next_input_pos / block_size;
+       int block_offset = next_input_pos % block_size;
+
+       // Update paged_kv_last_page_len
+       paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
+
+       int slot_num =
+           seq_block_tables_ptr[block_index] * block_size + block_offset;
+       // Update slot_mapping
+       slot_mapping_ptr[cur_query_id] = slot_num;
+       block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
+     }
+   }
+ }
+
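+ // Builds paged_kv_indptr: entry idx + 1 is the running total of
+ // block_table_bound[0..idx], i.e. each thread serially recomputes the prefix
+ // sum for its own query.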
+ __global__ void advance_step_flashinfer_indptr_kernel(
+     int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
+     int* block_table_bound_ptr) {
+   int idx = blockIdx.x * num_threads + threadIdx.x;
+
+   // Update paged_kv_indptr
+   if (idx < num_queries) {
+     int sum = 0;
+     for (int i = 0; i <= idx; ++i) {
+       sum += block_table_bound_ptr[i];
+     }
+     paged_kv_indptr_ptr[idx + 1] = sum;
+   }
+ }
+
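+ // Scatters each query's in-use block-table entries into FlashInfer's
+ // flattened paged_kv_indices array at the offsets given by paged_kv_indptr.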
+ __global__ void advance_step_flashinfer_indices_kernel(
+     int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
+     int64_t const block_tables_stride, int* paged_kv_indices_ptr,
+     int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
+   int idx = blockIdx.x * num_threads + threadIdx.x;
+   int row = idx / block_tables_stride;
+   int col = idx % block_tables_stride;
+
+   if (row < num_queries && col < block_table_bound_ptr[row]) {
+     paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
+         block_tables_ptr[row * block_tables_stride + col];
+   }
+   // if cudagraph, fill padded seqs with the last valid seq's indptr
+   if (num_queries < row && row <= num_seqs) {
+     paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
+   }
+ }
+
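+ // Host entry point for the FlashAttention backend: verifies the input
+ // tensors and launches advance_step_flashattn_kernel on the current stream.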
+ void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
+                             torch::Tensor& input_tokens,       // type: long
+                             torch::Tensor& sampled_token_ids,  // type: long
+                             torch::Tensor& input_positions,    // type: long
+                             torch::Tensor& seq_lens,           // type: int
+                             torch::Tensor& slot_mapping,       // type: long
+                             torch::Tensor& block_tables) {     // type: int

  if (logging) {
-     printf("advance_step:\n");
+     printf("advance_step_flashattn:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
@@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

-   advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
-       num_seqs, num_queries, block_size,
+   advance_step_flashattn_kernel<max_threads>
+       <<<blocks, max_threads, 0, stream>>>(
+           num_seqs, num_queries, block_size,
+           reinterpret_cast<long*>(input_tokens.data_ptr()),
+           reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+           reinterpret_cast<long*>(input_positions.data_ptr()),
+           reinterpret_cast<int*>(seq_lens.data_ptr()),
+           reinterpret_cast<long*>(slot_mapping.data_ptr()),
+           reinterpret_cast<int const*>(block_tables.data_ptr()),
+           block_tables.stride(0));
+ }
+
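+ // Host entry point for the FlashInfer backend: verifies the tensors
+ // (including the extra paged-KV buffers) and launches the three kernels
+ // above in sequence.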
+ void advance_step_flashinfer(
+     int num_seqs, int num_queries, int block_size,
+     torch::Tensor& input_tokens,            // type: long
+     torch::Tensor& sampled_token_ids,       // type: long
+     torch::Tensor& input_positions,         // type: long
+     torch::Tensor& seq_lens,                // type: int
+     torch::Tensor& slot_mapping,            // type: long
+     torch::Tensor& block_tables,            // type: int
+     torch::Tensor& paged_kv_indices,        // type: int
+     torch::Tensor& paged_kv_indptr,         // type: int
+     torch::Tensor& paged_kv_last_page_len,  // type: int
+     torch::Tensor& block_table_bound) {     // type: int
+
+   if (logging) {
+     printf("advance_step_flashinfer:\n");
+     printf("  num_seqs = %d\n", num_seqs);
+     printf("  num_queries = %d\n", num_queries);
+     printf("  block_size = %d\n", block_size);
+     printf("  block_tables.stride(0) = %d\n", block_tables.stride(0));
+   }
+   // Verify all tensors
+   verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+   // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+   // at::kLong);
+   verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+   verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+   verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+   verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+   verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
+   verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
+   verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
+                 at::kInt);
+
+   verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
+
+   int dev = sampled_token_ids.get_device();
+   cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+   int blocks;
+   int threads;
+   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+   cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
+   if (logging) {
+     printf("launching kernel with %d blocks\n", blocks);
+   }
+
+   // TODO(will): support arbitrary block_tables stride
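+   // advance_step_flashinfer_indices_kernel assigns one thread per
+   // (query row, block-table column) pair, so the grid must supply at least
+   // num_queries * block_tables.stride(0) threads.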
+   if ((blocks * threads) / block_tables.stride(0) < num_queries) {
+     TORCH_CHECK(false,
+                 "multi-step: not enough threads to map block_table to"
+                 " FlashInfer's paged_kv_indices on GPU. Try reducing the number"
+                 " of seqs,",
+                 " increasing the block size or take smaller steps.",
+                 " num_queries = ", num_queries,
+                 " block_tables.stride(0) = ", block_tables.stride(0),
+                 " blocks = ", blocks, " max_threads = ", threads);
+   }
+
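+   // The three launches run in order on the same stream: the first writes
+   // block_table_bound, the indptr kernel turns it into paged_kv_indptr, and
+   // the indices kernel then fills paged_kv_indices.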
+   advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
+       threads, num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
-       block_tables.stride(0));
+       block_tables.stride(0),
+       reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
+       reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+   advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
+       threads, num_seqs, num_queries,
+       reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+       reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+   advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
+       threads, num_seqs, num_queries,
+       reinterpret_cast<int const*>(block_tables.data_ptr()),
+       block_tables.stride(0),
+       reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
+       reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+       reinterpret_cast<int*>(block_table_bound.data_ptr()));
}

}  // namespace prepare_inputs

- void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                   torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                   torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                   torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
-   prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
-                                sampled_token_ids, input_positions, seq_lens,
-                                slot_mapping, block_tables);
+ void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                             int64_t block_size, torch::Tensor& input_tokens,
+                             torch::Tensor& sampled_token_ids,
+                             torch::Tensor& input_positions,
+                             torch::Tensor& seq_lens,
+                             torch::Tensor& slot_mapping,
+                             torch::Tensor& block_tables) {
+   prepare_inputs::advance_step_flashattn(
+       num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+       input_positions, seq_lens, slot_mapping, block_tables);
+ }
+
+ void advance_step_flashinfer(
+     int64_t num_seqs, int64_t num_queries, int64_t block_size,
+     torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+     torch::Tensor& input_positions, torch::Tensor& seq_lens,
+     torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
+   prepare_inputs::advance_step_flashinfer(
+       num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+       input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
+       paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}