Skip to content

Commit 5e1888e

Browse files
committed
Bring along preprocessor/processor config
1 parent 6125c51 commit 5e1888e

File tree

5 files changed

+236
-6
lines changed

5 files changed

+236
-6
lines changed

mergekit/_data/architectures/gemma3vl.json

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
"Gemma3ForConditionalGeneration"
55
],
66
"model_type": "gemma3",
7+
"tagalong_files": [
8+
"preprocessor_config.json",
9+
"processor_config.json"
10+
],
711
"modules": {
812
"text_decoder": {
913
"weight_prefix": "language_model.",
mergekit/_data/architectures/whisper.json

+196
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
{
2+
"kind": "modular",
3+
"architectures": [
4+
"WhisperForConditionalGeneration"
5+
],
6+
"model_type": "whisper",
7+
"tagalong_files": [
8+
"preprocessor_config.json",
9+
"normalizer.json"
10+
],
11+
"modules": {
12+
"decoder": {
13+
"weight_prefix": "model.decoder",
14+
"architecture": {
15+
"model_type": "",
16+
"architectures": [],
17+
"pre_weights": [
18+
{
19+
"name": "embed_tokens.weight",
20+
"is_embed": true
21+
},
22+
{
23+
"name": "embed_positions.weight"
24+
}
25+
],
26+
"num_layers_config_key": "decoder_layers",
27+
"layer_templates": {
28+
"weights": [
29+
{
30+
"name": "layers.${layer_index}.encoder_attn.k_proj.weight"
31+
},
32+
{
33+
"name": "layers.${layer_index}.encoder_attn.out_proj.bias"
34+
},
35+
{
36+
"name": "layers.${layer_index}.encoder_attn.out_proj.weight"
37+
},
38+
{
39+
"name": "layers.${layer_index}.encoder_attn.q_proj.bias"
40+
},
41+
{
42+
"name": "layers.${layer_index}.encoder_attn.q_proj.weight"
43+
},
44+
{
45+
"name": "layers.${layer_index}.encoder_attn.v_proj.bias"
46+
},
47+
{
48+
"name": "layers.${layer_index}.encoder_attn.v_proj.weight"
49+
},
50+
{
51+
"name": "layers.${layer_index}.encoder_attn_layer_norm.bias"
52+
},
53+
{
54+
"name": "layers.${layer_index}.encoder_attn_layer_norm.weight"
55+
},
56+
{
57+
"name": "layers.${layer_index}.fc1.bias"
58+
},
59+
{
60+
"name": "layers.${layer_index}.fc1.weight"
61+
},
62+
{
63+
"name": "layers.${layer_index}.fc2.bias"
64+
},
65+
{
66+
"name": "layers.${layer_index}.fc2.weight"
67+
},
68+
{
69+
"name": "layers.${layer_index}.final_layer_norm.bias"
70+
},
71+
{
72+
"name": "layers.${layer_index}.final_layer_norm.weight"
73+
},
74+
{
75+
"name": "layers.${layer_index}.self_attn.k_proj.weight"
76+
},
77+
{
78+
"name": "layers.${layer_index}.self_attn.out_proj.bias"
79+
},
80+
{
81+
"name": "layers.${layer_index}.self_attn.out_proj.weight"
82+
},
83+
{
84+
"name": "layers.${layer_index}.self_attn.q_proj.bias"
85+
},
86+
{
87+
"name": "layers.${layer_index}.self_attn.q_proj.weight"
88+
},
89+
{
90+
"name": "layers.${layer_index}.self_attn.v_proj.bias"
91+
},
92+
{
93+
"name": "layers.${layer_index}.self_attn.v_proj.weight"
94+
},
95+
{
96+
"name": "layers.${layer_index}.self_attn_layer_norm.bias"
97+
},
98+
{
99+
"name": "layers.${layer_index}.self_attn_layer_norm.weight"
100+
}
101+
]
102+
},
103+
"post_weights": [
104+
{
105+
"name": "layer_norm.bias"
106+
},
107+
{
108+
"name": "layer_norm.weight"
109+
}
110+
]
111+
}
112+
},
113+
"encoder": {
114+
"weight_prefix": "model.encoder.",
115+
"architecture": {
116+
"model_type": "",
117+
"architectures": [],
118+
"pre_weights": [
119+
{
120+
"name": "embed_positions.weight"
121+
},
122+
{
123+
"name": "conv1.bias"
124+
},
125+
{
126+
"name": "conv1.weight"
127+
},
128+
{
129+
"name": "conv2.bias"
130+
},
131+
{
132+
"name": "conv2.weight"
133+
}
134+
],
135+
"post_weights": [
136+
{
137+
"name": "layer_norm.bias"
138+
},
139+
{
140+
"name": "layer_norm.weight"
141+
}
142+
],
143+
"layer_templates": {
144+
"weights": [
145+
{
146+
"name": "layers.${layer_index}.fc1.bias"
147+
},
148+
{
149+
"name": "layers.${layer_index}.fc1.weight"
150+
},
151+
{
152+
"name": "layers.${layer_index}.fc2.bias"
153+
},
154+
{
155+
"name": "layers.${layer_index}.fc2.weight"
156+
},
157+
{
158+
"name": "layers.${layer_index}.final_layer_norm.bias"
159+
},
160+
{
161+
"name": "layers.${layer_index}.final_layer_norm.weight"
162+
},
163+
{
164+
"name": "layers.${layer_index}.self_attn.k_proj.weight"
165+
},
166+
{
167+
"name": "layers.${layer_index}.self_attn.out_proj.bias"
168+
},
169+
{
170+
"name": "layers.${layer_index}.self_attn.out_proj.weight"
171+
},
172+
{
173+
"name": "layers.${layer_index}.self_attn.q_proj.bias"
174+
},
175+
{
176+
"name": "layers.${layer_index}.self_attn.q_proj.weight"
177+
},
178+
{
179+
"name": "layers.${layer_index}.self_attn.v_proj.bias"
180+
},
181+
{
182+
"name": "layers.${layer_index}.self_attn.v_proj.weight"
183+
},
184+
{
185+
"name": "layers.${layer_index}.self_attn_layer_norm.bias"
186+
},
187+
{
188+
"name": "layers.${layer_index}.self_attn_layer_norm.weight"
189+
}
190+
]
191+
},
192+
"num_layers_config_key": "encoder_layers"
193+
}
194+
}
195+
}
196+
}

mergekit/architecture/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class ModelArchitecture(BaseModel, frozen=True):
127127
modules: Dict[str, ModuleDefinition]
128128
architectures: List[str]
129129
expected_model_type: str = Field(alias="model_type")
130+
tagalong_files: Optional[List[str]] = None
130131

131132
def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
132133
res = []

mergekit/architecture/json_definitions.py

+2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class JsonModularArchitectureDefinition(BaseModel, frozen=True):
104104
modules: Dict[str, JsonModuleDefinition]
105105
architectures: List[str]
106106
expected_model_type: str = Field(alias="model_type")
107+
tagalong_files: Optional[List[str]] = None
107108

108109

109110
class TemplateWithArithmetic(string.Template):
@@ -152,6 +153,7 @@ def _load_architecture_json(name: str) -> ModelArchitecture:
152153
},
153154
architectures=parsed.architectures,
154155
model_type=parsed.expected_model_type,
156+
tagalong_files=parsed.tagalong_files,
155157
)
156158
elif data.get("kind", "module") == "module":
157159
module = JsonModuleArchitecture(

mergekit/merge.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import os
88
import shutil
99
from collections import Counter
10-
from typing import Optional
10+
from typing import List, Optional
1111

1212
import tqdm
1313
import transformers
@@ -112,7 +112,11 @@ def run_merge(
112112
) as fp:
113113
fp.write(config_source)
114114

115-
if tokenizer is None:
115+
if tokenizer is not None:
116+
logger.info("Saving tokenizer")
117+
_set_chat_template(tokenizer, merge_config)
118+
tokenizer.save_pretrained(out_path, safe_serialization=True)
119+
else:
116120
if options.copy_tokenizer:
117121
try:
118122
_copy_tokenizer(
@@ -128,10 +132,12 @@ def run_merge(
128132
"Chat template specified but no tokenizer found. Chat template will not be saved."
129133
)
130134

131-
if tokenizer:
132-
logger.info("Saving tokenizer")
133-
_set_chat_template(tokenizer, merge_config)
134-
tokenizer.save_pretrained(out_path, safe_serialization=True)
135+
_copy_tagalong_files(
136+
merge_config,
137+
out_path,
138+
files=arch_info.tagalong_files or [],
139+
trust_remote_code=options.trust_remote_code,
140+
)
135141

136142
if getattr(arch_info, "post_fill_parameters", False):
137143
from mergekit.scripts.fill_missing_params import copy_and_fill_missing_params
@@ -192,6 +198,25 @@ def _set_chat_template(
192198
tokenizer.chat_template = chat_template
193199

194200

201+
def _copy_tagalong_files(
202+
merge_config: MergeConfiguration,
203+
out_path: str,
204+
files: List[str],
205+
trust_remote_code: bool = False,
206+
):
207+
donor_model = merge_config.base_model or (merge_config.referenced_models()[0])
208+
209+
for file_name in files:
210+
if os.path.exists(os.path.join(donor_model.model.path, file_name)):
211+
logger.info(f"Copying {file_name} from {donor_model}")
212+
shutil.copy(
213+
os.path.join(donor_model.model.path, file_name),
214+
os.path.join(out_path, file_name),
215+
)
216+
217+
return
218+
219+
195220
def _copy_tokenizer(
196221
merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
197222
):
@@ -214,6 +239,8 @@ def _copy_tokenizer(
214239
"special_tokens_map.json",
215240
"tokenizer.json",
216241
"tokenizer.model",
242+
"added_tokens.json",
243+
"merges.txt",
217244
]:
218245
if os.path.exists(os.path.join(donor_model.model.path, file_name)):
219246
shutil.copy(

0 commit comments

Comments
 (0)