@@ -119,7 +119,7 @@ void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int stride, in
 
 // -------------- FP16 -> Q
 
-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void fp16_to_q_kv_paged_kernel
 (
     const half* __restrict__ k_in,
@@ -172,11 +172,14 @@ __global__ void fp16_to_q_kv_paged_kernel
     {
         int j = i + blockIdx.y * BLOCKSIZE_Q;
         if (j >= block_b) continue;
-        fp16_to_q<wbits>(t, in, out, scales, j, cal, dim);
+        if (kv)
+            fp16_to_q<wbits_v>(t, in, out, scales, j, cal, dim);
+        else
+            fp16_to_q<wbits_k>(t, in, out, scales, j, cal, dim);
     }
 }
 
-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void fp16_to_q_kv_kernel
 (
     const half* __restrict__ k_in,
@@ -193,13 +196,17 @@ __global__ void fp16_to_q_kv_kernel
 )
 {
     int t = threadIdx.x;
-    const half* in = blockIdx.z ? v_in : k_in;
-    unsigned char* out = blockIdx.z ? v_out : k_out;
-    half* scales = blockIdx.z ? v_scales : k_scales;
-    const half* cal = blockIdx.z ? cal_v : cal_k;
+    int kv = blockIdx.z & 1;
+    const half* in = kv ? v_in : k_in;
+    unsigned char* out = kv ? v_out : k_out;
+    half* scales = kv ? v_scales : k_scales;
+    const half* cal = kv ? cal_v : cal_k;
     int block_offset = (offset + blockIdx.y * stride + blockIdx.x * BLOCKSIZE_Q);
 
-    fp16_to_q<wbits>(t, in, out, scales, block_offset, cal, dim);
+    if (kv)
+        fp16_to_q<wbits_v>(t, in, out, scales, block_offset, cal, dim);
+    else
+        fp16_to_q<wbits_k>(t, in, out, scales, block_offset, cal, dim);
 }
 
 void array_fp16_to_q_kv_paged_cuda
@@ -229,7 +236,17 @@ void array_fp16_to_q_kv_paged_cuda
     gridDim.z = batch_size * 2;
 
     if (wbits == 4)
-        fp16_to_q_kv_paged_kernel<4><<<gridDim, blockDim>>>
+        fp16_to_q_kv_paged_kernel<4, 4><<<gridDim, blockDim>>>
+        (
+            k_in, k_out, k_scales,
+            v_in, v_out, v_scales,
+            cache_seqlens, block_table,
+            pages_per_seq, page_size,
+            dim, q_len,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        fp16_to_q_kv_paged_kernel<8, 4><<<gridDim, blockDim>>>
         (
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
@@ -239,7 +256,7 @@ void array_fp16_to_q_kv_paged_cuda
             cal_k, cal_v
         );
     else if (wbits == 8)
-        fp16_to_q_kv_paged_kernel<8><<<gridDim, blockDim>>>
+        fp16_to_q_kv_paged_kernel<8, 8><<<gridDim, blockDim>>>
         (
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
@@ -275,14 +292,21 @@ void array_fp16_to_q_kv_cuda
     gridDim.z = v_in ? 2 : 1;
 
     if (wbits == 4)
-        fp16_to_q_kv_kernel<4><<<gridDim, blockDim>>>(
+        fp16_to_q_kv_kernel<4, 4><<<gridDim, blockDim>>>(
+            k_in, k_out, k_scales,
+            v_in, v_out, v_scales,
+            dim, offset, stride,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        fp16_to_q_kv_kernel<8, 4><<<gridDim, blockDim>>>(
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
             dim, offset, stride,
             cal_k, cal_v
         );
     else if (wbits == 8)
-        fp16_to_q_kv_kernel<8><<<gridDim, blockDim>>>(
+        fp16_to_q_kv_kernel<8, 8><<<gridDim, blockDim>>>(
            k_in, k_out, k_scales,
            v_in, v_out, v_scales,
            dim, offset, stride,
@@ -292,7 +316,7 @@ void array_fp16_to_q_kv_cuda
 
 // --------------- Q -> FP16
 
-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void q_to_fp16_kv_paged_kernel
 (
     const unsigned char* __restrict__ k_in,
@@ -342,11 +366,14 @@ __global__ void q_to_fp16_kv_paged_kernel
     {
         int j = i + blockIdx.y * BLOCKSIZE_Q;
         if (j >= block_b) continue;
-        q_to_fp16<wbits>(t, in, scales, out, j, cal, dim);
+        if (kv)
+            q_to_fp16<wbits_v>(t, in, scales, out, j, cal, dim);
+        else
+            q_to_fp16<wbits_k>(t, in, scales, out, j, cal, dim);
     }
 }
 
-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void q_to_fp16_kv_kernel
 (
     const unsigned char* __restrict__ k_in,
@@ -363,13 +390,17 @@ __global__ void q_to_fp16_kv_kernel
 )
 {
     int t = threadIdx.x;
-    const unsigned char* in = blockIdx.z ? v_in : k_in;
-    const half* scales = blockIdx.z ? v_scales : k_scales;
-    half* out = blockIdx.z ? v_out : k_out;
-    const half* cal = blockIdx.z ? cal_v : cal_k;
+    int kv = blockIdx.z & 1;
+    const unsigned char* in = kv ? v_in : k_in;
+    const half* scales = kv ? v_scales : k_scales;
+    half* out = kv ? v_out : k_out;
+    const half* cal = kv ? cal_v : cal_k;
     int block_offset = (offset + blockIdx.y * stride + blockIdx.x * BLOCKSIZE_Q);
 
-    q_to_fp16<wbits>(t, in, scales, out, block_offset, cal, dim);
+    if (kv)
+        q_to_fp16<wbits_v>(t, in, scales, out, block_offset, cal, dim);
+    else
+        q_to_fp16<wbits_k>(t, in, scales, out, block_offset, cal, dim);
 }
 
 void array_q_to_fp16_kv_paged_cuda
@@ -398,7 +429,17 @@ void array_q_to_fp16_kv_paged_cuda
     gridDim.z = batch_size * 2;
 
     if (wbits == 4)
-        q_to_fp16_kv_paged_kernel<4><<<gridDim, blockDim>>>
+        q_to_fp16_kv_paged_kernel<4, 4><<<gridDim, blockDim>>>
+        (
+            k_in, k_scales, k_out,
+            v_in, v_scales, v_out,
+            cache_seqlens, block_table,
+            pages_per_seq, page_size,
+            dim,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        q_to_fp16_kv_paged_kernel<8, 4><<<gridDim, blockDim>>>
         (
             k_in, k_scales, k_out,
             v_in, v_scales, v_out,
@@ -408,7 +449,7 @@ void array_q_to_fp16_kv_paged_cuda
             cal_k, cal_v
         );
     else if (wbits == 8)
-        q_to_fp16_kv_paged_kernel<8><<<gridDim, blockDim>>>
+        q_to_fp16_kv_paged_kernel<8, 8><<<gridDim, blockDim>>>
         (
             k_in, k_scales, k_out,
             v_in, v_scales, v_out,
@@ -444,14 +485,21 @@ void array_q_to_fp16_kv_cuda
     gridDim.z = v_in ? 2 : 1;
 
     if (wbits == 4)
-        q_to_fp16_kv_kernel<4><<<gridDim, blockDim>>>(
+        q_to_fp16_kv_kernel<4, 4><<<gridDim, blockDim>>>(
+            k_in, k_scales, k_out,
+            v_in, v_scales, v_out,
+            dim, offset, stride,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        q_to_fp16_kv_kernel<8, 4><<<gridDim, blockDim>>>(
            k_in, k_scales, k_out,
            v_in, v_scales, v_out,
            dim, offset, stride,
            cal_k, cal_v
        );
    else if (wbits == 8)
-        q_to_fp16_kv_kernel<8><<<gridDim, blockDim>>>(
+        q_to_fp16_kv_kernel<8, 8><<<gridDim, blockDim>>>(
            k_in, k_scales, k_out,
            v_in, v_scales, v_out,
            dim, offset, stride,
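
The wbits == 6 branches added above dispatch a mixed-precision instantiation: 8-bit keys and 4-bit values (the label 6 presumably being the average of the two widths). That is why every template <int wbits> becomes template <int wbits_k, int wbits_v>, and why the kernels now branch on kv = blockIdx.z & 1 to pick a width per tensor. Below is a minimal standalone sketch of that dispatch pattern, compilable with nvcc on its own; quant_stub, kv_quant_kernel, and launch are hypothetical stand-ins for illustration, not names from this file.

// Sketch of the two-width template dispatch used by the kernels above.
// Assumption: a trivial stand-in quantizer replaces fp16_to_q / q_to_fp16.
#include <cuda_fp16.h>
#include <cstdio>

// Stand-in for fp16_to_q<wbits>: records the bit width chosen per element.
template <int wbits>
__device__ void quant_stub(const half* in, unsigned char* out, int dim)
{
    int t = threadIdx.x;
    if (t < dim) out[t] = (unsigned char) wbits;
}

// Kernel templated on independent K and V widths; the parity of blockIdx.z
// selects the K or V path, mirroring fp16_to_q_kv_kernel in the diff.
template <int wbits_k, int wbits_v>
__global__ void kv_quant_kernel
(
    const half* k_in, unsigned char* k_out,
    const half* v_in, unsigned char* v_out,
    int dim
)
{
    int kv = blockIdx.z & 1;
    const half* in = kv ? v_in : k_in;
    unsigned char* out = kv ? v_out : k_out;
    if (kv) quant_stub<wbits_v>(in, out, dim);
    else    quant_stub<wbits_k>(in, out, dim);
}

// Host-side dispatch: wbits == 6 selects the mixed K8/V4 instantiation.
void launch(int wbits, const half* k_in, unsigned char* k_out,
            const half* v_in, unsigned char* v_out, int dim)
{
    dim3 grid(1, 1, 2);  // z = 2: slice 0 quantizes K, slice 1 quantizes V
    dim3 block(dim);
    if (wbits == 4)
        kv_quant_kernel<4, 4><<<grid, block>>>(k_in, k_out, v_in, v_out, dim);
    else if (wbits == 6)
        kv_quant_kernel<8, 4><<<grid, block>>>(k_in, k_out, v_in, v_out, dim);
    else if (wbits == 8)
        kv_quant_kernel<8, 8><<<grid, block>>>(k_in, k_out, v_in, v_out, dim);
}

int main()
{
    const int dim = 8;
    half *k_in, *v_in;
    unsigned char *k_out, *v_out;
    cudaMalloc(&k_in, dim * sizeof(half));
    cudaMalloc(&v_in, dim * sizeof(half));
    cudaMalloc(&k_out, dim);
    cudaMalloc(&v_out, dim);
    launch(6, k_in, k_out, v_in, v_out, dim);
    unsigned char hk[dim], hv[dim];
    cudaMemcpy(hk, k_out, dim, cudaMemcpyDeviceToHost);
    cudaMemcpy(hv, v_out, dim, cudaMemcpyDeviceToHost);
    printf("K width: %d, V width: %d\n", hk[0], hv[0]);  // expect 8 and 4
    cudaFree(k_in); cudaFree(v_in); cudaFree(k_out); cudaFree(v_out);
    return 0;
}

Encoding the widths as template arguments rather than runtime parameters presumably keeps the pack/unpack loops inside fp16_to_q and q_to_fp16 specialized per width; the cost is one extra kernel instantiation per cache mode.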