
Commit 783632d

prepare custom quants in CLI
1 parent 71a60fb commit 783632d

File tree

3 files changed: +238 -260 lines changed


Diff for: examples/quantize/quantize.cpp

+48
@@ -322,6 +322,54 @@ int main(int argc, char ** argv) {
             } else {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--attn-q-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.attn_q_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--attn-k-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.attn_k_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--attn-v-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.attn_v_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--attn-qkv-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.attn_qkv_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--attn-output-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.attn_output_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--ffn-gate-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.ffn_gate_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--ffn-down-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.ffn_down_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--ffn-up-type") == 0) {
+            if (arg_idx < argc-1) {
+                params.ffn_up_type = parse_ggml_type(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) {
             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) {
                 usage(argv[0]);
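
Taken together, these hunks add eight per-tensor-class overrides to the quantize tool's argument loop. A hypothetical invocation might look like the line below; the binary name assumes the current llama-quantize build target, the type values are illustrative names accepted by parse_ggml_type (they match ggml_type_name() strings), and, as in the existing loop, flags must precede the positional arguments:

./llama-quantize --attn-v-type q6_K --ffn-down-type q5_K model-f16.gguf model-custom.gguf Q4_K_M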

Diff for: include/llama.h

+9 -1
@@ -407,7 +407,15 @@ extern "C" {
         int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
         enum llama_ftype ftype;              // quantize to this llama_ftype
         enum ggml_type output_tensor_type;   // output tensor type
-        enum ggml_type token_embedding_type; // itoken embeddings tensor type
+        enum ggml_type token_embedding_type; // token embeddings tensor type
+        enum ggml_type attn_q_type;          // attention query tensor type
+        enum ggml_type attn_k_type;          // attention key tensor type
+        enum ggml_type attn_v_type;          // attention value tensor type
+        enum ggml_type attn_qkv_type;        // attention query-key-value tensor type
+        enum ggml_type attn_output_type;     // attention output tensor type
+        enum ggml_type ffn_gate_type;        // FFN gate type
+        enum ggml_type ffn_down_type;        // FFN down type
+        enum ggml_type ffn_up_type;          // FFN up type
         bool allow_requantize;               // allow quantizing non-f32/f16 tensors
         bool quantize_output_tensor;         // quantize output.weight
         bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
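
On the API side, the same overrides are plain fields on llama_model_quantize_params, so they can also be set programmatically. A minimal sketch follows, assuming llama_model_quantize_default_params() initializes the new fields to GGML_TYPE_COUNT ("unspecified") just as it already does for output_tensor_type; that initialization would live in the commit's third changed file, whose diff is not shown here.

// Sketch of programmatic use of the new per-tensor-class overrides.
// Assumption: the new fields default to GGML_TYPE_COUNT ("unset"),
// mirroring output_tensor_type, so only explicit overrides take effect.
#include "llama.h"

int main() {
    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.ftype         = LLAMA_FTYPE_MOSTLY_Q4_K_M;  // baseline quant mix
    params.attn_v_type   = GGML_TYPE_Q6_K;             // keep attention V at higher precision
    params.ffn_down_type = GGML_TYPE_Q5_K;             // illustrative override
    // returns 0 on success
    return llama_model_quantize("model-f16.gguf", "model-custom.gguf", &params);
}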
