
Commit 7b8033b

Merge remote-tracking branch 'origin/skystream' into frankenrope

2 parents ccfd1ff + 7626690

5 files changed: 58 additions and 35 deletions

convert-hf-to-gguf.py

Lines changed: 6 additions & 0 deletions

@@ -1632,6 +1632,12 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         if (n_experts := self.hparams.get("num_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
+            logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
 
     _experts: list[dict[str, Tensor]] | None = None
 
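
For reference, a minimal sketch (illustrative values only, not taken from this commit) of the Hugging Face config fields the converter now picks up and the GGUF keys they map to:

# Hypothetical excerpt of a Qwen2MoE-style config.json, written as a Python dict.
hparams = {
    "num_experts": 60,                        # -> {arch}.expert_count (already handled)
    "moe_intermediate_size": 1408,            # -> {arch}.expert_feed_forward_length (new)
    "shared_expert_intermediate_size": 5632,  # -> {arch}.expert_shared_feed_forward_length (new)
}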

ggml-impl.h

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-#if defined(_WIN32)
+#if defined(_MSC_VER)
 
 #define m512bh(p) p
 #define m512i(p) p

gguf-py/gguf/constants.py

Lines changed: 16 additions & 15 deletions

@@ -33,21 +33,22 @@ class General:
         FILE_TYPE = "general.file_type"
 
     class LLM:
-        VOCAB_SIZE                 = "{arch}.vocab_size"
-        CONTEXT_LENGTH             = "{arch}.context_length"
-        EMBEDDING_LENGTH           = "{arch}.embedding_length"
-        BLOCK_COUNT                = "{arch}.block_count"
-        LEADING_DENSE_BLOCK_COUNT  = "{arch}.leading_dense_block_count"
-        FEED_FORWARD_LENGTH        = "{arch}.feed_forward_length"
-        EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
-        USE_PARALLEL_RESIDUAL      = "{arch}.use_parallel_residual"
-        TENSOR_DATA_LAYOUT         = "{arch}.tensor_data_layout"
-        EXPERT_COUNT               = "{arch}.expert_count"
-        EXPERT_USED_COUNT          = "{arch}.expert_used_count"
-        EXPERT_SHARED_COUNT        = "{arch}.expert_shared_count"
-        EXPERT_WEIGHTS_SCALE       = "{arch}.expert_weights_scale"
-        POOLING_TYPE               = "{arch}.pooling_type"
-        LOGIT_SCALE                = "{arch}.logit_scale"
+        VOCAB_SIZE                        = "{arch}.vocab_size"
+        CONTEXT_LENGTH                    = "{arch}.context_length"
+        EMBEDDING_LENGTH                  = "{arch}.embedding_length"
+        BLOCK_COUNT                       = "{arch}.block_count"
+        LEADING_DENSE_BLOCK_COUNT         = "{arch}.leading_dense_block_count"
+        FEED_FORWARD_LENGTH               = "{arch}.feed_forward_length"
+        EXPERT_FEED_FORWARD_LENGTH        = "{arch}.expert_feed_forward_length"
+        EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
+        USE_PARALLEL_RESIDUAL             = "{arch}.use_parallel_residual"
+        TENSOR_DATA_LAYOUT                = "{arch}.tensor_data_layout"
+        EXPERT_COUNT                      = "{arch}.expert_count"
+        EXPERT_USED_COUNT                 = "{arch}.expert_used_count"
+        EXPERT_SHARED_COUNT               = "{arch}.expert_shared_count"
+        EXPERT_WEIGHTS_SCALE              = "{arch}.expert_weights_scale"
+        POOLING_TYPE                      = "{arch}.pooling_type"
+        LOGIT_SCALE                       = "{arch}.logit_scale"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions

@@ -394,6 +394,9 @@ def add_feed_forward_length(self, length: int) -> None:
     def add_expert_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
+    def add_expert_shared_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
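
A rough usage sketch of the new writer method, assuming gguf-py is installed; the file name and sizes are placeholders and the remaining write steps are omitted:

import gguf

writer = gguf.GGUFWriter("model.gguf", "qwen2moe")     # placeholder path and arch
writer.add_expert_feed_forward_length(1408)            # existing key
writer.add_expert_shared_feed_forward_length(5632)     # key added in this commit
# ... header/KV/tensor writing and writer.close() omitted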

llama.cpp

Lines changed: 32 additions & 19 deletions

@@ -310,6 +310,7 @@ enum llm_kv {
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
+    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
@@ -388,21 +389,22 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_SOURCE_URL,             "general.source.url"                    },
     { LLM_KV_GENERAL_SOURCE_HF_REPO,         "general.source.huggingface.repository" },
 
-    { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                 },
-    { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"             },
-    { LLM_KV_EMBEDDING_LENGTH,                  "%s.embedding_length"           },
-    { LLM_KV_BLOCK_COUNT,                       "%s.block_count"                },
-    { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"  },
-    { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"        },
-    { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length" },
-    { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"      },
-    { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"         },
-    { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"               },
-    { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"          },
-    { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"        },
-    { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"       },
-    { LLM_KV_POOLING_TYPE ,                     "%s.pooling_type"               },
-    { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                },
+    { LLM_KV_VOCAB_SIZE,                               "%s.vocab_size"                        },
+    { LLM_KV_CONTEXT_LENGTH,                           "%s.context_length"                    },
+    { LLM_KV_EMBEDDING_LENGTH,                         "%s.embedding_length"                  },
+    { LLM_KV_BLOCK_COUNT,                              "%s.block_count"                       },
+    { LLM_KV_LEADING_DENSE_BLOCK_COUNT,                "%s.leading_dense_block_count"         },
+    { LLM_KV_FEED_FORWARD_LENGTH,                      "%s.feed_forward_length"               },
+    { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,               "%s.expert_feed_forward_length"        },
+    { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,        "%s.expert_shared_feed_forward_length" },
+    { LLM_KV_USE_PARALLEL_RESIDUAL,                    "%s.use_parallel_residual"             },
+    { LLM_KV_TENSOR_DATA_LAYOUT,                       "%s.tensor_data_layout"                },
+    { LLM_KV_EXPERT_COUNT,                             "%s.expert_count"                      },
+    { LLM_KV_EXPERT_USED_COUNT,                        "%s.expert_used_count"                 },
+    { LLM_KV_EXPERT_SHARED_COUNT,                      "%s.expert_shared_count"               },
+    { LLM_KV_EXPERT_WEIGHTS_SCALE,                     "%s.expert_weights_scale"              },
+    { LLM_KV_POOLING_TYPE ,                            "%s.pooling_type"                      },
+    { LLM_KV_LOGIT_SCALE,                              "%s.logit_scale"                       },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,           "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,        "%s.attention.head_count_kv" },
@@ -1878,6 +1880,7 @@ struct llama_hparams {
     uint32_t n_lora_q = 0;
     uint32_t n_lora_kv = 0;
     uint32_t n_ff_exp = 0;
+    uint32_t n_ff_shexp = 0;
    uint32_t n_expert_shared = 0;
    float    expert_weights_scale = 0.0;
@@ -1926,6 +1929,7 @@ struct llama_hparams {
         if (this->n_lora_q != other.n_lora_q) return true;
         if (this->n_lora_kv != other.n_lora_kv) return true;
         if (this->n_ff_exp != other.n_ff_exp) return true;
+        if (this->n_ff_shexp != other.n_ff_shexp) return true;
         if (this->n_expert_shared != other.n_expert_shared) return true;
 
         if (this->rope_finetuned != other.rope_finetuned) return true;
@@ -4386,6 +4390,9 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_QWEN2MOE:
             {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_A2_7B; break;
@@ -5202,6 +5209,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
     }
+
+    if (model.arch == LLM_ARCH_QWEN2MOE) {
+        LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+    }
 }
 
 // Returns false if cancelled by progress_callback
@@ -5995,16 +6007,17 @@ static bool llm_load_tensors(
                         GGML_ASSERT(hparams.n_expert_used > 0);
 
                         // MoE branch
-                        auto n_ff_exp = n_ff / hparams.n_expert_used;
+                        auto n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / hparams.n_expert_used;
                         layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
                         layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
                         layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
 
                         // Shared expert branch
+                        auto n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
                         layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
-                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,     n_ff});
-                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff,   n_embd});
-                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,     n_ff});
+                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp});
+                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd});
+                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp});
                     }
                 } break;
             case LLM_ARCH_PHI2:
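
The behavioural core of the llm_load_tensors change is the fallback: GGUF files converted before this commit lack the new keys, so the loader keeps the old shapes. A small Python sketch of that selection logic, mirroring the C++ above (names and values are illustrative, not part of the codebase):

def qwen2moe_ffn_widths(n_ff, n_expert_used, n_ff_exp=0, n_ff_shexp=0):
    # Prefer the widths read from the GGUF metadata; fall back to the old
    # heuristics (n_ff split across the used experts, shared expert = n_ff).
    expert_width = n_ff_exp if n_ff_exp else n_ff // n_expert_used
    shared_width = n_ff_shexp if n_ff_shexp else n_ff
    return expert_width, shared_width

# e.g. an older file with n_ff = 5632 and 4 used experts -> (1408, 5632)
print(qwen2moe_ffn_widths(5632, 4))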
