
Commit f3596fc: Add Q6 cache mode
1 parent f6abbba

6 files changed: +127 -42 lines changed

examples/chat.py (+6)

@@ -8,6 +8,7 @@
     ExLlamaV2Cache,
     ExLlamaV2Cache_8bit,
     ExLlamaV2Cache_Q4,
+    ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
     ExLlamaV2Tokenizer,
     model_init,
@@ -55,6 +56,7 @@

 parser.add_argument("-c8", "--cache_8bit", action = "store_true", help = "Use 8-bit (FP8) cache")
 parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
+parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
 parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")

 parser.add_argument("-ngram", "--ngram_decoding", action = "store_true", help = "Use n-gram speculative decoding")
@@ -130,6 +132,8 @@
         draft_cache = ExLlamaV2Cache_8bit(draft_model)
     elif args.cache_q4:
         draft_cache = ExLlamaV2Cache_Q4(draft_model)
+    elif args.cache_q6:
+        draft_cache = ExLlamaV2Cache_Q6(draft_model)
     elif args.cache_q8:
         draft_cache = ExLlamaV2Cache_Q8(draft_model)
     else:
@@ -141,6 +145,8 @@
     cache = ExLlamaV2Cache_8bit(model, lazy = not model.loaded)
 elif args.cache_q4:
     cache = ExLlamaV2Cache_Q4(model, lazy = not model.loaded)
+elif args.cache_q6:
+    cache = ExLlamaV2Cache_Q6(model, lazy = not model.loaded)
 elif args.cache_q8:
     cache = ExLlamaV2Cache_Q8(model, lazy = not model.loaded)
 else:
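
For context, the new -cq6 flag selects ExLlamaV2Cache_Q6 the same way -cq4 and -cq8 select their cache classes. Below is a minimal usage sketch outside of chat.py, assuming the usual config/prepare/load_autosplit flow from the examples; the model directory is a placeholder.

# Minimal sketch: selecting the new Q6 cache directly instead of via the -cq6 flag.
# The model directory below is a placeholder.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q6, ExLlamaV2Tokenizer

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"
config.prepare()

model = ExLlamaV2(config)

# lazy = True defers cache allocation until the model has been loaded and split,
# mirroring the lazy = not model.loaded pattern used in chat.py
cache = ExLlamaV2Cache_Q6(model, lazy = True)
model.load_autosplit(cache)

tokenizer = ExLlamaV2Tokenizer(config)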

exllamav2/__init__.py (+1)

@@ -4,6 +4,7 @@
 from exllamav2.cache import ExLlamaV2CacheBase
 from exllamav2.cache import ExLlamaV2Cache
 from exllamav2.cache import ExLlamaV2Cache_Q4
+from exllamav2.cache import ExLlamaV2Cache_Q6
 from exllamav2.cache import ExLlamaV2Cache_Q8
 from exllamav2.cache import ExLlamaV2Cache_8bit
 from exllamav2.config import ExLlamaV2Config

exllamav2/cache.py (+33 -15)

@@ -25,7 +25,8 @@ class ExLlamaV2CacheBase:
     head_dim: int

     dtype: torch.dtype
-    weights_per_element: int
+    weights_per_element_k: int
+    weights_per_element_v: int
     has_scales: bool


@@ -34,14 +35,16 @@ def __init__(self,
                  batch_size: int,
                  max_seq_len: int,
                  dtype: torch.dtype,
-                 weights_per_element: int,
+                 weights_per_element_k: int,
+                 weights_per_element_v: int,
                  has_scales: bool):

         self.model = model
         self.max_seq_len = max_seq_len if max_seq_len != -1 else self.model.config.max_seq_len
         self.batch_size = batch_size
         self.dtype = dtype
-        self.weights_per_element = weights_per_element
+        self.weights_per_element_k = weights_per_element_k
+        self.weights_per_element_v = weights_per_element_v
         self.has_scales = has_scales

         self.key_states = []
@@ -55,7 +58,8 @@ def __init__(self,

         self.current_seq_len = 0
         self.shape_basic = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim)
-        self.shape_w = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element)
+        self.shape_wk = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_k)
+        self.shape_wv = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // self.weights_per_element_v)
         self.shape_s = (self.batch_size, self.max_seq_len, self.num_key_value_heads, self.head_dim // 32)


@@ -74,8 +78,8 @@ def create_state_tensors(self,

         if copy_from is None:
             device = self.model.cache_map[i]
-            p_key_states = torch.zeros(self.shape_w, dtype = self.dtype, device = device).contiguous()
-            p_value_states = torch.zeros(self.shape_w, dtype = self.dtype, device = device).contiguous()
+            p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = device).contiguous()
+            p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = device).contiguous()
             if self.has_scales:
                 p_key_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous()
                 p_value_scales = torch.zeros(self.shape_s, dtype = torch.float16, device = device).contiguous()
@@ -115,8 +119,8 @@ def update_cache_tensors(self):
             self.key_states[k] = None
             self.value_states[k] = None

-            p_key_states = torch.zeros(self.shape_w, dtype = self.dtype, device = v).contiguous()
-            p_value_states = torch.zeros(self.shape_w, dtype = self.dtype, device = v).contiguous()
+            p_key_states = torch.zeros(self.shape_wk, dtype = self.dtype, device = v).contiguous()
+            p_value_states = torch.zeros(self.shape_wv, dtype = self.dtype, device = v).contiguous()
             self.key_states[k] = p_key_states
             self.value_states[k] = p_value_states
             if self.has_scales:
@@ -220,7 +224,7 @@ def __init__(self,
                  copy_from: ExLlamaV2Cache | None = None,
                  lazy: bool = False):

-        super().__init__(model, batch_size, max_seq_len, torch.half, 1, False)
+        super().__init__(model, batch_size, max_seq_len, torch.half, 1, 1, False)

         self.create_state_tensors(copy_from, lazy)

@@ -280,7 +284,7 @@ def __init__(self,
                  copy_from: ExLlamaV2Cache_8bit | None = None,
                  lazy: bool = False):

-        super().__init__(model, batch_size, max_seq_len, torch.uint8, 1, False)
+        super().__init__(model, batch_size, max_seq_len, torch.uint8, 1, 1, False)

         self.create_state_tensors(copy_from, lazy)

@@ -365,9 +369,10 @@ def __init__(self,
                  max_seq_len: int = -1,
                  copy_from: ExLlamaV2Cache_Q4 | None = None,
                  lazy: bool = False,
-                 weights_per_byte: int = -1):
+                 weights_per_byte_k: int = -1,
+                 weights_per_byte_v: int = -1):

-        super().__init__(model, batch_size, max_seq_len, torch.uint8, weights_per_byte, True)
+        super().__init__(model, batch_size, max_seq_len, torch.uint8, weights_per_byte_k, weights_per_byte_v, True)
         cfg = self.model.config

         self.create_state_tensors(copy_from, lazy)
@@ -607,18 +612,31 @@ def __init__(self,
                  copy_from: ExLlamaV2Cache_Q4 | None = None,
                  lazy: bool = False):

-        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 2)
+        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 2, 2)
         self.wbits = 4


+class ExLlamaV2Cache_Q6(ExLlamaV2Cache_Q):
+
+    def __init__(self,
+                 model: ExLlamaV2,
+                 batch_size: int = 1,
+                 max_seq_len: int = -1,
+                 copy_from: ExLlamaV2Cache_Q6 | None = None,
+                 lazy: bool = False):
+
+        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 1, 2)
+        self.wbits = 6
+
+
 class ExLlamaV2Cache_Q8(ExLlamaV2Cache_Q):

     def __init__(self,
                  model: ExLlamaV2,
                  batch_size: int = 1,
                  max_seq_len: int = -1,
-                 copy_from: ExLlamaV2Cache_Q4 | None = None,
+                 copy_from: ExLlamaV2Cache_Q8 | None = None,
                  lazy: bool = False):

-        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 1)
+        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 1, 1)
         self.wbits = 8
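
The arguments passed up to the ExLlamaV2Cache_Q constructor show how Q6 is realized: keys are stored at 8 bits per weight (weights_per_byte_k = 1) and values at 4 bits (weights_per_byte_v = 2), averaging roughly 6 bits per cached element, while Q4 and Q8 remain symmetric. Below is a rough per-token footprint sketch following the shape_wk / shape_wv / shape_s definitions above; the head count, head_dim and layer count are placeholder values, not tied to any particular model.

# Approximate KV cache bytes per token implied by shape_wk / shape_wv / shape_s.
# All model dimensions below are placeholders for illustration.
num_key_value_heads = 8
head_dim = 128
num_layers = 32

def bytes_per_token(weights_per_byte_k: int, weights_per_byte_v: int) -> int:
    k = num_key_value_heads * head_dim // weights_per_byte_k      # packed keys (uint8)
    v = num_key_value_heads * head_dim // weights_per_byte_v      # packed values (uint8)
    scales = 2 * num_key_value_heads * (head_dim // 32) * 2       # fp16 scales for K and V, one per 32-weight group
    return (k + v + scales) * num_layers

print(bytes_per_token(2, 2))   # Q4: 4-bit keys and values
print(bytes_per_token(1, 2))   # Q6: 8-bit keys, 4-bit values
print(bytes_per_token(1, 1))   # Q8: 8-bit keys and values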

exllamav2/exllamav2_ext/cuda/cache.cu (+72 -24)

@@ -119,7 +119,7 @@ void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int stride, in

 // -------------- FP16 -> Q

-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void fp16_to_q_kv_paged_kernel
 (
     const half* __restrict__ k_in,
@@ -172,11 +172,14 @@ __global__ void fp16_to_q_kv_paged_kernel
     {
         int j = i + blockIdx.y * BLOCKSIZE_Q;
         if (j >= block_b) continue;
-        fp16_to_q<wbits>(t, in, out, scales, j, cal, dim);
+        if (kv)
+            fp16_to_q<wbits_v>(t, in, out, scales, j, cal, dim);
+        else
+            fp16_to_q<wbits_k>(t, in, out, scales, j, cal, dim);
     }
 }

-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void fp16_to_q_kv_kernel
 (
     const half* __restrict__ k_in,
@@ -193,13 +196,17 @@ __global__ void fp16_to_q_kv_kernel
 )
 {
     int t = threadIdx.x;
-    const half* in = blockIdx.z ? v_in : k_in;
-    unsigned char* out = blockIdx.z ? v_out : k_out;
-    half* scales = blockIdx.z ? v_scales : k_scales;
-    const half* cal = blockIdx.z ? cal_v : cal_k;
+    int kv = blockIdx.z & 1;
+    const half* in = kv ? v_in : k_in;
+    unsigned char* out = kv ? v_out : k_out;
+    half* scales = kv ? v_scales : k_scales;
+    const half* cal = kv ? cal_v : cal_k;
     int block_offset = (offset + blockIdx.y * stride + blockIdx.x * BLOCKSIZE_Q);

-    fp16_to_q<wbits>(t, in, out, scales, block_offset, cal, dim);
+    if (kv)
+        fp16_to_q<wbits_v>(t, in, out, scales, block_offset, cal, dim);
+    else
+        fp16_to_q<wbits_k>(t, in, out, scales, block_offset, cal, dim);
 }

 void array_fp16_to_q_kv_paged_cuda
@@ -229,7 +236,17 @@ void array_fp16_to_q_kv_paged_cuda
     gridDim.z = batch_size * 2;

     if (wbits == 4)
-        fp16_to_q_kv_paged_kernel<4><<<gridDim, blockDim>>>
+        fp16_to_q_kv_paged_kernel<4, 4><<<gridDim, blockDim>>>
+        (
+            k_in, k_out, k_scales,
+            v_in, v_out, v_scales,
+            cache_seqlens, block_table,
+            pages_per_seq, page_size,
+            dim, q_len,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        fp16_to_q_kv_paged_kernel<8, 4><<<gridDim, blockDim>>>
         (
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
@@ -239,7 +256,7 @@ void array_fp16_to_q_kv_paged_cuda
             cal_k, cal_v
         );
     else if (wbits == 8)
-        fp16_to_q_kv_paged_kernel<8><<<gridDim, blockDim>>>
+        fp16_to_q_kv_paged_kernel<8, 8><<<gridDim, blockDim>>>
         (
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
@@ -275,14 +292,21 @@ void array_fp16_to_q_kv_cuda
     gridDim.z = v_in ? 2 : 1;

     if (wbits == 4)
-        fp16_to_q_kv_kernel<4><<<gridDim, blockDim>>>(
+        fp16_to_q_kv_kernel<4, 4><<<gridDim, blockDim>>>(
+            k_in, k_out, k_scales,
+            v_in, v_out, v_scales,
+            dim, offset, stride,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        fp16_to_q_kv_kernel<8, 4><<<gridDim, blockDim>>>(
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
             dim, offset, stride,
             cal_k, cal_v
         );
     else if (wbits == 8)
-        fp16_to_q_kv_kernel<8><<<gridDim, blockDim>>>(
+        fp16_to_q_kv_kernel<8, 8><<<gridDim, blockDim>>>(
             k_in, k_out, k_scales,
             v_in, v_out, v_scales,
             dim, offset, stride,
@@ -292,7 +316,7 @@ void array_fp16_to_q_kv_cuda

 // --------------- Q -> FP16

-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void q_to_fp16_kv_paged_kernel
 (
     const unsigned char* __restrict__ k_in,
@@ -342,11 +366,14 @@ __global__ void q_to_fp16_kv_paged_kernel
     {
         int j = i + blockIdx.y * BLOCKSIZE_Q;
         if (j >= block_b) continue;
-        q_to_fp16<wbits>(t, in, scales, out, j, cal, dim);
+        if (kv)
+            q_to_fp16<wbits_v>(t, in, scales, out, j, cal, dim);
+        else
+            q_to_fp16<wbits_k>(t, in, scales, out, j, cal, dim);
     }
 }

-template <int wbits>
+template <int wbits_k, int wbits_v>
 __global__ void q_to_fp16_kv_kernel
 (
     const unsigned char* __restrict__ k_in,
@@ -363,13 +390,17 @@ __global__ void q_to_fp16_kv_kernel
 )
 {
     int t = threadIdx.x;
-    const unsigned char* in = blockIdx.z ? v_in : k_in;
-    const half* scales = blockIdx.z ? v_scales : k_scales;
-    half* out = blockIdx.z ? v_out : k_out;
-    const half* cal = blockIdx.z ? cal_v : cal_k;
+    int kv = blockIdx.z & 1;
+    const unsigned char* in = kv ? v_in : k_in;
+    const half* scales = kv ? v_scales : k_scales;
+    half* out = kv ? v_out : k_out;
+    const half* cal = kv ? cal_v : cal_k;
     int block_offset = (offset + blockIdx.y * stride + blockIdx.x * BLOCKSIZE_Q);

-    q_to_fp16<wbits>(t, in, scales, out, block_offset, cal, dim);
+    if (kv)
+        q_to_fp16<wbits_v>(t, in, scales, out, block_offset, cal, dim);
+    else
+        q_to_fp16<wbits_k>(t, in, scales, out, block_offset, cal, dim);
 }

 void array_q_to_fp16_kv_paged_cuda
@@ -398,7 +429,17 @@ void array_q_to_fp16_kv_paged_cuda
     gridDim.z = batch_size * 2;

     if (wbits == 4)
-        q_to_fp16_kv_paged_kernel<4><<<gridDim, blockDim>>>
+        q_to_fp16_kv_paged_kernel<4, 4><<<gridDim, blockDim>>>
+        (
+            k_in, k_scales, k_out,
+            v_in, v_scales, v_out,
+            cache_seqlens, block_table,
+            pages_per_seq, page_size,
+            dim,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        q_to_fp16_kv_paged_kernel<8, 4><<<gridDim, blockDim>>>
         (
             k_in, k_scales, k_out,
             v_in, v_scales, v_out,
@@ -408,7 +449,7 @@ void array_q_to_fp16_kv_paged_cuda
             cal_k, cal_v
         );
     else if (wbits == 8)
-        q_to_fp16_kv_paged_kernel<8><<<gridDim, blockDim>>>
+        q_to_fp16_kv_paged_kernel<8, 8><<<gridDim, blockDim>>>
         (
             k_in, k_scales, k_out,
             v_in, v_scales, v_out,
@@ -444,14 +485,21 @@ void array_q_to_fp16_kv_cuda
     gridDim.z = v_in ? 2 : 1;

     if (wbits == 4)
-        q_to_fp16_kv_kernel<4><<<gridDim, blockDim>>>(
+        q_to_fp16_kv_kernel<4, 4><<<gridDim, blockDim>>>(
+            k_in, k_scales, k_out,
+            v_in, v_scales, v_out,
+            dim, offset, stride,
+            cal_k, cal_v
+        );
+    else if (wbits == 6)
+        q_to_fp16_kv_kernel<8, 4><<<gridDim, blockDim>>>(
             k_in, k_scales, k_out,
             v_in, v_scales, v_out,
             dim, offset, stride,
             cal_k, cal_v
         );
     else if (wbits == 8)
-        q_to_fp16_kv_kernel<8><<<gridDim, blockDim>>>(
+        q_to_fp16_kv_kernel<8, 8><<<gridDim, blockDim>>>(
             k_in, k_scales, k_out,
             v_in, v_scales, v_out,
             dim, offset, stride,
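
On the dispatch side, the extension keeps accepting a single wbits value and expands it into per-tensor template arguments: wbits == 6 selects the <8, 4> instantiations (8-bit keys, 4-bit values), while 4 and 8 stay symmetric, and kv = blockIdx.z & 1 routes each thread block to the key or value path. A small illustrative sketch of that mapping in Python follows; the dictionary and helper are not part of the extension's API.

# Illustrative mapping from the single wbits value passed by the Python cache
# classes to the (wbits_k, wbits_v) template arguments instantiated in cache.cu.
WBITS_TO_TEMPLATE_ARGS = {
    4: (4, 4),   # Q4: 4-bit keys, 4-bit values
    6: (8, 4),   # Q6: 8-bit keys, 4-bit values, ~6 bits per element on average
    8: (8, 8),   # Q8: 8-bit keys, 8-bit values
}

def template_args_for(wbits: int) -> tuple[int, int]:
    # Return the per-tensor bit widths the kernels are compiled with.
    return WBITS_TO_TEMPLATE_ARGS[wbits]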

exllamav2/exllamav2_ext/ext_cache.cpp (+2 -2)

@@ -107,7 +107,7 @@ void fp16_to_q_kv
     TORCH_CHECK_SHAPES(k_in, 0, v_in, 0, 1);
     TORCH_CHECK_SHAPES(k_in, 1, v_in, 1, 1);
     TORCH_CHECK_SHAPES(k_in, 2, v_in, 2, 1);
-    TORCH_CHECK_SHAPES(k_in, 3, v_in, 3, 1);
+    // TORCH_CHECK_SHAPES(k_in, 3, v_in, 3, 1);

     if (!cal_k.device().is_meta())
         TORCH_CHECK_SHAPES_OPT(cal_k, 0, k_in, 2, 1);
@@ -207,7 +207,7 @@ void q_to_fp16_kv
     TORCH_CHECK_SHAPES(k_in, 0, v_in, 0, 1);
     TORCH_CHECK_SHAPES(k_in, 1, v_in, 1, 1);
     TORCH_CHECK_SHAPES(k_in, 2, v_in, 2, 1);
-    TORCH_CHECK_SHAPES(k_in, 3, v_in, 3, 1);
+    // TORCH_CHECK_SHAPES(k_in, 3, v_in, 3, 1);

     if (!cal_k.device().is_meta())
         TORCH_CHECK_SHAPES_OPT(cal_k, 0, k_out, 2, 1);
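
The relaxed dim-3 checks are consistent with the asymmetric packing: under Q6 the packed key and value cache tensors handled by q_to_fp16_kv no longer share their last dimension, so equality on dim 3 between them cannot be asserted in general. A small illustrative check follows; the batch, sequence and head dimensions are placeholders.

# Illustrative shapes only (batch, seq_len, kv_heads, packed width), head_dim = 128.
# Under Q6, packed keys use one weight per byte and packed values two, so dim 3 differs.
import torch

k_packed = torch.zeros(1, 4096, 8, 128, dtype = torch.uint8)   # 8-bit keys
v_packed = torch.zeros(1, 4096, 8, 64, dtype = torch.uint8)    # 4-bit values
assert k_packed.shape[3] != v_packed.shape[3]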
