
Commit fd5eb19

Multi-Module Architecture Handling (#534)
Refactor architecture handling to support models with multiple distinct sub-modules. Intended to support use cases like encoder-decoder and vision-language architectures. Initially adds support for:

* Gemma3 (w/ vision)
* Whisper
* T5

Also slightly speeds up merge planning by avoiding task reserialization.
1 parent 5dc8023 commit fd5eb19

35 files changed: +1899 -1738 lines changed
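The headline change is a new "modular" flavor of architecture definition: instead of a single flat list of weights, a model can be described as a set of named sub-modules (here a text decoder, a multi-modal projector, and a vision tower), each with its own weight prefix and its own per-module architecture block. The two new JSON files below show the single-module Gemma3 text definition and the modular Gemma3 vision-language definition. As a rough illustration only (a minimal sketch, not the project's actual loader code), such a definition could be inspected like this:

    import json

    def list_modules(path: str) -> None:
        """Print the sub-modules declared by a 'modular' architecture definition.
        Illustrative sketch only; the project's own loader is not reproduced here."""
        with open(path) as f:
            defn = json.load(f)
        if defn.get("kind") == "modular":
            for name, module in defn["modules"].items():
                arch = module["architecture"]
                print(f"{name}: weight_prefix={module['weight_prefix']!r}, "
                      f"model_type={arch['model_type']}")
        else:
            print(f"single-module definition, model_type={defn['model_type']}")

    # For the Gemma3 vision-language file below this would print roughly:
    #   text_decoder: weight_prefix='language_model.', model_type=gemma3_text
    #   multi_modal_projector: weight_prefix='multi_modal_projector.', model_type=gemma3_mmproj
    #   vision_tower: weight_prefix='vision_tower.vision_model.', model_type=siglip_vision_model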
+69
@@ -0,0 +1,69 @@
{
    "model_type": "gemma3_text",
    "architectures": [
        "Gemma3ForCausalLM"
    ],
    "pre_weights": [
        {
            "name": "model.embed_tokens.weight",
            "is_embed": true
        }
    ],
    "num_layers_config_key": "num_hidden_layers",
    "layer_templates": {
        "weights": [
            {
                "name": "model.layers.${layer_index}.input_layernorm.weight"
            },
            {
                "name": "model.layers.${layer_index}.self_attn.q_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.self_attn.q_norm.weight"
            },
            {
                "name": "model.layers.${layer_index}.self_attn.k_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.self_attn.k_norm.weight"
            },
            {
                "name": "model.layers.${layer_index}.self_attn.v_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.self_attn.o_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.post_attention_layernorm.weight"
            },
            {
                "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight"
            },
            {
                "name": "model.layers.${layer_index}.mlp.up_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.mlp.gate_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.mlp.down_proj.weight"
            },
            {
                "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight"
            }
        ]
    },
    "post_weights": [
        {
            "name": "model.norm.weight"
        },
        {
            "name": "lm_head.weight",
            "is_embed": true,
            "optional": true,
            "tied_names": [
                "model.embed_tokens.weight"
            ]
        }
    ]
}
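This first file is a conventional single-module definition: pre_weights, a layer_templates block repeated once per decoder layer, and post_weights, with the layer count read from the Hugging Face config field named by num_layers_config_key. The ${layer_index} placeholders use Python string.Template syntax; a minimal sketch of how such templates could be expanded into concrete tensor names (an illustration, not the project's actual expansion logic) is:

    from string import Template

    def expand_layer_weights(layer_templates: dict, num_layers: int) -> list[str]:
        """Expand ${layer_index} placeholders into concrete tensor names.
        num_layers would come from the HF config field named by
        num_layers_config_key ("num_hidden_layers" for this file). Sketch only."""
        names = []
        for i in range(num_layers):
            for entry in layer_templates["weights"]:
                names.append(Template(entry["name"]).substitute(layer_index=i))
        return names

    # e.g. expand_layer_weights(defn["layer_templates"], 2)[:2] ->
    #   ['model.layers.0.input_layernorm.weight',
    #    'model.layers.0.self_attn.q_proj.weight']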
+184
@@ -0,0 +1,184 @@
{
    "kind": "modular",
    "architectures": [
        "Gemma3ForConditionalGeneration"
    ],
    "model_type": "gemma3",
    "tagalong_files": [
        "preprocessor_config.json",
        "processor_config.json"
    ],
    "modules": {
        "text_decoder": {
            "weight_prefix": "language_model.",
            "architecture": {
                "model_type": "gemma3_text",
                "architectures": [
                    "Gemma3ForCausalLM"
                ],
                "pre_weights": [
                    {
                        "name": "model.embed_tokens.weight",
                        "is_embed": true
                    }
                ],
                "num_layers_config_key": "text_config.num_hidden_layers",
                "layer_templates": {
                    "weights": [
                        {
                            "name": "model.layers.${layer_index}.input_layernorm.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.self_attn.q_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.self_attn.q_norm.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.self_attn.k_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.self_attn.k_norm.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.self_attn.v_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.self_attn.o_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.post_attention_layernorm.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.mlp.up_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.mlp.gate_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.mlp.down_proj.weight"
                        },
                        {
                            "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight"
                        }
                    ]
                },
                "post_weights": [
                    {
                        "name": "model.norm.weight"
                    },
                    {
                        "name": "lm_head.weight",
                        "is_embed": true,
                        "optional": true,
                        "tied_names": [
                            "model.embed_tokens.weight"
                        ]
                    }
                ]
            }
        },
        "multi_modal_projector": {
            "weight_prefix": "multi_modal_projector.",
            "architecture": {
                "model_type": "gemma3_mmproj",
                "architectures": [],
                "pre_weights": [
                    {
                        "name": "mm_input_projection_weight"
                    },
                    {
                        "name": "mm_soft_emb_norm.weight"
                    }
                ],
                "post_weights": [],
                "layer_templates": {
                    "weights": []
                },
                "override_num_layers": 0
            }
        },
        "vision_tower": {
            "weight_prefix": "vision_tower.vision_model.",
            "architecture": {
                "model_type": "siglip_vision_model",
                "architectures": [],
                "pre_weights": [
                    {
                        "name": "embeddings.patch_embedding.bias"
                    },
                    {
                        "name": "embeddings.patch_embedding.weight"
                    },
                    {
                        "name": "embeddings.position_embedding.weight"
                    }
                ],
                "post_weights": [
                    {
                        "name": "post_layernorm.bias"
                    },
                    {
                        "name": "post_layernorm.weight"
                    }
                ],
                "layer_templates": {
                    "weights": [
                        {
                            "name": "encoder.layers.${layer_index}.layer_norm1.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.layer_norm1.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.layer_norm2.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.layer_norm2.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.mlp.fc1.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.mlp.fc1.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.mlp.fc2.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.mlp.fc2.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.k_proj.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.k_proj.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.out_proj.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.out_proj.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.q_proj.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.q_proj.weight"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.v_proj.bias"
                        },
                        {
                            "name": "encoder.layers.${layer_index}.self_attn.v_proj.weight"
                        }
                    ]
                },
                "num_layers_config_key": "vision_config.num_hidden_layers"
            }
        }
    }
}
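The modular definition above nests a complete per-module architecture under each entry in "modules", and the weight_prefix tells the merge planner where that module's tensors live in the full checkpoint. Assuming plain string concatenation (a sketch of the naming convention, not the project's exact code), a module-local name composes with its prefix like so:

    # Each module's tensors live under its weight_prefix in the full checkpoint.
    # Assuming plain concatenation (illustrative; not the project's exact code):
    prefix = "language_model."                          # text_decoder's weight_prefix
    local = "model.layers.0.self_attn.q_proj.weight"    # expanded from layer_templates
    full_name = prefix + local
    # -> "language_model.model.layers.0.self_attn.q_proj.weight"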
