@@ -9,6 +9,12 @@
                         ["ln_2"]]
 layer_keys_yi_norms = [["ln1", "input_layernorm"],
                        ["ln2", "post_attention_layernorm"]]
+layer_keys_gemma2_norms = [["input_layernorm"],
+                           ["post_attention_layernorm"],
+                           ["pre_feedforward_layernorm"],
+                           ["post_feedforward_layernorm"]]
+layer_keys_internlm2_norms = [["attention_norm"],
+                              ["ffn_norm"]]
 layer_keys_llama_attn = [["self_attn.q_proj"],
                          ["self_attn.k_proj"],
                          ["self_attn.v_proj"],
@@ -17,6 +23,10 @@
                          ["self_attn.c_attn", "self_attn.k_proj"],
                          ["self_attn.c_attn", "self_attn.v_proj"],
                          ["self_attn.o_proj"]]
+layer_keys_internlm2_attn = [["self_attn.wqkv", "self_attn.q_proj"],
+                             ["self_attn.wqkv", "self_attn.k_proj"],
+                             ["self_attn.wqkv", "self_attn.v_proj"],
+                             ["self_attn.o_proj"]]
 layer_keys_dbrx_attn = [["self_attn.Wqkv", "self_attn.q_proj"],
                         ["self_attn.Wqkv", "self_attn.k_proj"],
                         ["self_attn.Wqkv", "self_attn.v_proj"],
@@ -28,6 +38,9 @@
 layer_keys_llama_mlp = [["mlp.down_proj"],
                         ["mlp.gate_proj"],
                         ["mlp.up_proj"]]
+layer_keys_internlm2_mlp = [["feed_forward.w1"],
+                            ["feed_forward.w2"],
+                            ["feed_forward.w3"]]
 layer_keys_phi3_mlp = [["mlp.down_proj"],
                        ["mlp.gate_up_proj", "mlp.gate_proj"],
                        ["mlp.gate_up_proj", "mlp.up_proj"]]
@@ -76,6 +89,10 @@
                ("$h.", "model.layers."),
                ("$wte.", "model.embed_tokens."),
                ("$wpe.", "model.wpe.")]
+internlm2_keymap = [("$output.", "lm_head."),
+                    ("$model.tok_embeddings.", "model.embed_tokens."),
+                    (".attention.", ".self_attn."),
+                    (".wo.", ".o_proj.")]
 
 class RopeStyle(Enum):
     NONE = 0
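The keymap entries are `(pattern, replacement)` pairs used to translate checkpoint tensor names into the canonical Llama-style names used elsewhere; the `$` prefix on some patterns appears to anchor the match to the start of a key. A rough sketch of that remapping, with the anchoring behaviour treated as an assumption:

```python
def remap_key(key: str, keymap) -> str:
    # Rewrite a checkpoint tensor name using (pattern, replacement) pairs.
    # A leading "$" is assumed to mean "match only at the start of the key".
    for pattern, replacement in keymap:
        if pattern.startswith("$"):
            prefix = pattern[1:]
            if key.startswith(prefix):
                key = replacement + key[len(prefix):]
        else:
            key = key.replace(pattern, replacement)
    return key

# remap_key("model.layers.0.attention.wo.weight", internlm2_keymap)
#   -> "model.layers.0.self_attn.o_proj.weight"
```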
@@ -100,6 +117,18 @@ def __init__(self, arch_string, read_config):
         self.orig_weights_transposed = False
         self.logit_scale_basedim = False
 
+        self.norm_key_1_post = None
+        self.norm_key_2_post = None
+
+        self.swa = False
+        self.alternating_swa = False
+
+        self.eager_attn_only = False
+        self.clamp_hidden_states = False
+        self.residual_stream_fp32 = False
+
+        self.fused_qkv_altpack = False
+
         # Mistral
 
         if arch_string == "MistralForCausalLM":
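Among the new defaults, `alternating_swa` presumably marks architectures (such as Gemma2 below) in which decoder layers alternate between sliding-window and global attention. A toy sketch of the corresponding mask selection; which layer parity gets the window, and the window size itself, are assumptions here:

```python
import torch

def build_attn_mask(seq_len: int, layer_idx: int, window: int,
                    alternating_swa: bool) -> torch.Tensor:
    # Causal mask, optionally restricted to a sliding window on alternating layers.
    pos = torch.arange(seq_len)
    allowed = pos[None, :] <= pos[:, None]                 # causal
    if alternating_swa and layer_idx % 2 == 0:             # assumed parity
        allowed &= (pos[:, None] - pos[None, :]) < window  # sliding window
    return allowed  # True where attention is permitted
```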
@@ -305,6 +334,45 @@ def __init__(self, arch_string, read_config):
             self.mqa = False
             self.scale_attn_weights = False
 
+        # Gemma2
+
+        if arch_string == "Gemma2ForCausalLM":
+            arch_recognized = True
+            self.layer_keys += \
+                layer_keys_gemma2_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.expect_keys += \
+                expect_keys_gemma
+            self.norm_eps_key = "rms_norm_eps"
+            self.attention_bias_qkv = False
+            self.attention_bias_o = False
+            self.mlp_bias = False
+            self.mlp_gate = True
+            self.mlp_key_gate = ".mlp.gate_proj"
+            self.mlp_key_up = ".mlp.up_proj"
+            self.mlp_key_down = ".mlp.down_proj"
+            self.mlp_act_func = "gelu"
+            self.is_moe = False
+            self.norm = "rmsnorm"
+            self.lm_head_key = "model.embed_tokens"
+            self.normalize_embeddings = True
+            self.norm_key_1 = ".input_layernorm"
+            self.norm_key_1_post = ".post_attention_layernorm"
+            self.norm_key_2 = ".pre_feedforward_layernorm"
+            self.norm_key_2_post = ".post_feedforward_layernorm"
+            self.norm_constant_bias = 1
+            self.parallel_decoder_blocks = False
+            self.requires_bos = True
+            self.rope_style = RopeStyle.NEOX
+            self.keymap = None
+            self.fused_qkv_key = None
+            self.mqa = False
+            self.scale_attn_weights = False
+            self.pre_post_layernorm = True
+            self.alternating_swa = True
+            self.residual_stream_fp32 = True
+
         # StarCoder2
 
         if arch_string == "Starcoder2ForCausalLM":
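The four norm keys set for Gemma2 (`norm_key_1`, `norm_key_1_post`, `norm_key_2`, `norm_key_2_post`) describe a "sandwich" layout in which both the attention and MLP sub-blocks are wrapped in a pre-norm and a post-norm before the residual add. A schematic of that block ordering, not the library's actual forward pass:

```python
def gemma2_block_sketch(x, input_ln, attn, post_attn_ln,
                        pre_ffn_ln, mlp, post_ffn_ln):
    # Attention sub-block: pre-norm, attention, post-norm, residual add.
    h = post_attn_ln(attn(input_ln(x)))
    x = x + h
    # MLP sub-block: the same pre/post-norm wrapping around the feed-forward.
    h = post_ffn_ln(mlp(pre_ffn_ln(x)))
    x = x + h
    return x
```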
@@ -586,6 +654,41 @@ def __init__(self, arch_string, read_config):
             self.scale_attn_weights = False
             self.logit_scale_basedim = True
 
+        # InternLM2
+
+        if arch_string == "InternLM2ForCausalLM":
+            arch_recognized = True
+            self.layer_keys += \
+                layer_keys_internlm2_norms + \
+                layer_keys_internlm2_attn + \
+                layer_keys_internlm2_mlp
+            self.expect_keys += \
+                expect_keys_llama
+            self.norm_eps_key = "rms_norm_eps"
+            self.attention_bias_qkv = False
+            self.attention_bias_o = False
+            self.mlp_bias = False
+            self.mlp_gate = True
+            self.mlp_key_gate = ".feed_forward.w1"
+            self.mlp_key_up = ".feed_forward.w3"
+            self.mlp_key_down = ".feed_forward.w2"
+            self.mlp_act_func = "silu"
+            self.is_moe = False
+            self.norm = "rmsnorm"
+            self.lm_head_key = "lm_head"
+            self.normalize_embeddings = False
+            self.norm_key_1 = ".attention_norm"
+            self.norm_key_2 = ".ffn_norm"
+            self.norm_constant_bias = 0
+            self.parallel_decoder_blocks = False
+            self.requires_bos = False
+            self.rope_style = RopeStyle.NEOX
+            self.keymap = internlm2_keymap
+            self.fused_qkv_key = "wqkv"
+            self.fused_qkv_altpack = True
+            self.mqa = False
+            self.scale_attn_weights = False
+
         # Llama (default + fallback)
 
         if arch_string != "LlamaForCausalLM" and not arch_recognized:
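`fused_qkv_key = "wqkv"` together with the new `fused_qkv_altpack` flag suggests InternLM2 stores Q, K and V in a single tensor whose heads are packed group by group rather than as three contiguous blocks. A sketch of splitting such a weight, where the exact per-KV-group (q…q, k, v) layout is an assumption about the checkpoint format:

```python
import torch

def split_wqkv_grouped(wqkv: torch.Tensor, num_q_heads: int,
                       num_kv_heads: int, head_dim: int):
    # Assumed layout: output rows grouped per KV head as
    # [q_0 .. q_{r-1}, k, v], with r = num_q_heads // num_kv_heads.
    r = num_q_heads // num_kv_heads
    hidden = wqkv.shape[-1]
    w = wqkv.reshape(num_kv_heads, r + 2, head_dim, hidden)
    q = w[:, :r].reshape(num_q_heads * head_dim, hidden)
    k = w[:, r].reshape(num_kv_heads * head_dim, hidden)
    v = w[:, r + 1].reshape(num_kv_heads * head_dim, hidden)
    return q, k, v
```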
@@ -637,6 +740,11 @@ def __init__(self, arch_string, read_config):
             self.expect_keys.remove(["lm_head"])
             self.lm_head_key = "model.embed_tokens"
 
+        # Sanity checks
+
+        if self.residual_stream_fp32:
+            assert self.norm_key_1_post and self.norm_key_2_post, \
+                "FP32 residual stream only implemented for arch with post layernorms"
 
     def make_fused_mlp(self):
 