
Commit c0a0e1d

Fix for draft KVQ discrepancies.
Noob hack because... ^^ And some text corrections.
1 parent a71c356 commit c0a0e1d

File tree

3 files changed (+19, -16 lines)


expose.h (+2, -2)

```diff
@@ -67,8 +67,8 @@ struct load_model_inputs
     const float tensor_split[tensor_split_max] = {};
     const int quant_k = 0;
     const int quant_v = 0;
-    const int draft_quant_k = 0;
-    const int draft_quant_v = 0;
+    const int draft_quant_k = -1;
+    const int draft_quant_v = -1;
     const bool quiet = false;
     const int debugmode = 0;
 };
```

gpttype_adapter.cpp (+3)

```diff
@@ -599,6 +599,9 @@ static void speculative_decoding_setup(std::string spec_model_filename, const ll
     draft_ctx_params.n_threads_batch = base_ctx_params.n_threads_batch;
     draft_ctx_params.flash_attn = base_ctx_params.flash_attn;
 
+    draft_quant_k=draft_quant_k-1;
+    draft_quant_v=draft_quant_v-1;
+
     if (draft_quant_k==-1)
     {
         draft_ctx_params.type_k = base_ctx_params.type_k;
```
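The two new lines are the "noob hack" from the commit message: the draft KVQ value handed across from the UI is the real mode shifted up by one, so the backend first undoes the shift and then treats the resulting -1 as "reuse the main context's KV cache types". A minimal Python sketch of that decoding, with hypothetical names standing in for the C++ above:

```python
def decode_draft_kvq(shifted_mode: int, main_type_k: int, main_type_v: int):
    """Hypothetical mirror of the C++ above: undo the UI's +1 offset.

    shifted_mode plays the role of load_model_inputs.draft_quant_k/v;
    main_type_k/v stand in for base_ctx_params.type_k/type_v.
    """
    mode = shifted_mode - 1   # draft_quant_k = draft_quant_k - 1
    if mode == -1:            # sentinel: inherit the main model's cache types
        return main_type_k, main_type_v
    return mode, mode         # otherwise the mode selects the draft quant type
```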

koboldcpp.py (+14, -14)

```diff
@@ -1715,22 +1715,22 @@ def load_model(model_filename):
     else:
         inputs.quant_k = inputs.quant_v = 0
 
-    if args.draft_quantkv==-1:
-        inputs.draft_quant_k = inputs.draft_quant_v = args.quantkv
-    elif args.draft_quantkv==0:
-        inputs.draft_quant_k = inputs.draft_quant_v = 0
-    elif args.draft_quantkv>0 and args.draft_quantkv<15:
+    # if args.draft_quantkv==-1:
+    #     inputs.draft_quant_k = inputs.draft_quant_v = args.quantkv
+    # elif args.draft_quantkv==1:
+    #     inputs.draft_quant_k = inputs.draft_quant_v = 1
+    if args.draft_quantkv>0 and args.draft_quantkv<16:
         inputs.draft_quant_k = inputs.draft_quant_v = args.draft_quantkv
         inputs.flash_attention = True
-    elif args.draft_quantkv>2 and args.draft_quantkv<8:
+    elif args.draft_quantkv>0 and args.draft_quantkv<9:
         inputs.use_contextshift = 0
-    elif args.draft_quantkv>7 and args.draft_quantkv<15:
+    elif args.draft_quantkv>8 and args.draft_quantkv<16:
         inputs.use_contextshift = 0
-    elif args.draft_quantkv==15:
-        inputs.draft_quant_k = inputs.draft_quant_v = 15
+    elif args.draft_quantkv==16:
+        inputs.draft_quant_k = inputs.draft_quant_v = 16
         inputs.flash_attention = False
         # inputs.use_contextshift = 0
-    elif args.draft_quantkv>15 and args.draft_quantkv<23:
+    elif args.draft_quantkv>16 and args.draft_quantkv<24:
         inputs.draft_quant_k = inputs.draft_quant_v = args.draft_quantkv
         inputs.flash_attention = False
     else:
```
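Taken at face value, the rewritten chain reads args.draft_quantkv on a scale one slot above the main quantkv scale: the backend's later "- 1" maps 1..15 back to main modes 0..14, 16 back to 15, and 17..23 back to 16..22. A hedged paraphrase of the branch actions, as a hypothetical helper:

```python
def classify_draft_kvq(v: int) -> str:
    """Hypothetical paraphrase of the ranges in the hunk above (v = args.draft_quantkv)."""
    if 0 < v < 16:
        return "draft_quant_k/v = v, FlashAttention forced on"
    elif v == 16:
        return "draft_quant_k/v = 16, FlashAttention forced off"
    elif 16 < v < 24:
        return "draft_quant_k/v = v, FlashAttention forced off"
    return "falls through to the else branch"
```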
```diff
@@ -3994,7 +3994,7 @@ def hide_tooltip(event):
     lowvram_var = ctk.IntVar()
     mmq_var = ctk.IntVar(value=1)
     quantkv_var = ctk.IntVar(value=0)
-    draft_quantkv_var = ctk.IntVar(value=-1)
+    draft_quantkv_var = ctk.IntVar(value=0)
     blas_threads_var = ctk.StringVar()
     blas_size_var = ctk.IntVar()
 
```
```diff
@@ -4544,7 +4544,7 @@ def togglerope(a,b,c):
     makecheckbox(tokens_tab, "Use FlashAttention", flashattention, 28, tooltiptxt="Enable flash attention for GGUF models.")
     # noqkvlabel = makelabel(tokens_tab,"Requirments Not Met",31,0,"Requires FlashAttention ENABLED and ContextShift DISABLED.")
     # noqkvlabel.configure(text_color="#ff5555")
-    qkvslider,qkvlabel,qkvtitle = makeslider(tokens_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 23, 30, set=0,tooltip="Enable quantization of KV cache (KVQ). Mode 0 (F16) is default. Modes 1-12 requires FlashAttention and disables ContextShift.\nModes 15-20 work without FA, for incompatible models. 0,13,14 can work with or without.")
+    qkvslider,qkvlabel,qkvtitle = makeslider(tokens_tab, "Quantize KV Cache:", quantkv_text, quantkv_var, 0, 22, 30, set=0,tooltip="Enable quantization of KV cache (KVQ). Mode 0 (F16) is default. Modes 1-12 requires FlashAttention and disables ContextShift.\nModes 15-22 work without FA, for incompatible models.")
 
     # load model
     makefileentry(tokens_tab, "Model:", "Select GGML or GGML Model File", model_var, 50, 576, onchoosefile=on_picked_model_file, filetypes=[("GGML bin or GGUF", ("*.bin","*.gguf"))] ,tooltiptxt="Select a GGUF or GGML model file on disk to be loaded.")
```
```diff
@@ -4569,7 +4569,7 @@ def togglerope(a,b,c):
     makelabelentry(model_tab, "Draft Amount: ", draftamount_var, 13, 50,padx=100,singleline=True,tooltip="How many tokens to draft per chunk before verifying results")
     makelabelentry(model_tab, "Splits: ", draftgpusplit_str_vars, 13, 50,padx=210,singleline=True,tooltip="Distribution of draft model layers. Leave blank to follow main model's gpu split. Only works if multi-gpu (All) selected in main model.", labelpadx=160)
     makelabelentry(model_tab, "Layers: ", draftgpulayers_var, 13, 50,padx=320,singleline=True,tooltip="How many layers to GPU offload for the draft model", labelpadx=270)
-    makeslider(model_tab, "Quantize Draft KV Cache:", draft_quantkv_text, draft_quantkv_var, 0, 23, 30, set=-1,tooltip="Enable quantization of Draft KV cache (D_KVQ). Mode 0 (F16) is default. Modes 1-12 requires FlashAttention and disables ContextShift.\nModes 15-20 work without FA, for incompatible models. 0,13,14 can work with or without.")
+    makeslider(model_tab, "Quantize Draft KV Cache:", draft_quantkv_text, draft_quantkv_var, 0, 23, 30, set=-1,tooltip="Enable quantization of Draft KV cache (D_KVQ). Mode -1 (same as main) is default. Mode 0 (F16) is FA and non-FA both. Modes 1-12 requires FlashAttention and disables ContextShift.\nModes 15-22 work without FA, for incompatible models.")
     makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 15,width=280,singlerow=True,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
     makefileentry(model_tab, "ChatCompletions Adapter:", "Select ChatCompletions Adapter File", chatcompletionsadapter_var, 24, width=250, filetypes=[("JSON Adapter", "*.json")], tooltiptxt="Select an optional ChatCompletions Adapter JSON file to force custom instruct tags.")
     def pickpremadetemplate():
```
```diff
@@ -4781,7 +4781,7 @@ def export_vars():
     args.nocertify = nocertifymode.get()==1
     args.nomodel = nomodel.get()==1
     args.quantkv = int(quantkv_values[int(quantkv_var.get())])
-    args.draft_quantkv = int(draft_quantkv_values[int(draft_quantkv_var.get())])
+    args.draft_quantkv = int(draft_quantkv_values[int(draft_quantkv_var.get()+1)])
 
     args.poslayeroffset = int(poslayeroffset_values[int(poslayeroffset_var.get())])
     args.neglayeroffset = int(neglayeroffset_values[int(neglayeroffset_var.get())])
```
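The +1 in export_vars is the UI half of the same shift. Assuming draft_quantkv_values gains a leading -1 ("same as main") slot that the slider itself never shows, reading one slot past the slider position exports the main-scale mode plus one, which the backend's subtraction then undoes; a toy round trip:

```python
# Toy illustration; the list contents are an assumption, not from the commit.
draft_quantkv_values = list(range(-1, 24))   # slot 0 = -1, i.e. "same as main"

def export_draft_quantkv(slider_pos: int) -> int:
    """What export_vars now does: read one slot past the slider position."""
    return int(draft_quantkv_values[slider_pos + 1])

# Round trip against the backend's 'draft_quant_k = draft_quant_k - 1':
assert export_draft_quantkv(0) - 1 == -1   # default slider -> inherit main KV quant
assert export_draft_quantkv(1) - 1 == 0    # slider 1 -> F16 draft cache
```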
