Skip to content

Commit a699f63

Browse files
committed
Fixes
1 parent 96d66fd commit a699f63

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

mergekit/common.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ def config(self, trust_remote_code: bool = False) -> PretrainedConfig:
149149
res.architectures = [self.override_architecture]
150150
return res
151151

152-
def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex:
152+
def local_path(self, cache_dir: Optional[str] = None) -> str:
153153
assert self.lora is None
154154

155155
path = self.model.path
@@ -172,8 +172,10 @@ def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex:
172172
cache_dir=cache_dir,
173173
allow_patterns=patterns,
174174
)
175+
return path
175176

176-
return ShardedTensorIndex.from_disk(path)
177+
def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex:
178+
return ShardedTensorIndex.from_disk(self.local_path(cache_dir))
177179

178180
def lazy_loader(
179181
self, cache_dir: Optional[str] = None, lazy_unpickle: bool = True

mergekit/merge.py

+35 -19
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,6 @@ def run_merge(
138138
merge_config,
139139
out_path,
140140
files=arch_info.tagalong_files or [],
141-
trust_remote_code=options.trust_remote_code,
142141
)
143142

144143
if getattr(arch_info, "post_fill_parameters", False):
@@ -204,15 +203,16 @@ def _copy_tagalong_files(
204203
merge_config: MergeConfiguration,
205204
out_path: str,
206205
files: List[str],
207-
trust_remote_code: bool = False,
208206
):
209207
donor_model = merge_config.base_model or (merge_config.referenced_models()[0])
208+
donor_local_path = donor_model.local_path()
210209

211210
for file_name in files:
212-
if os.path.exists(os.path.join(donor_model.model.path, file_name)):
211+
fp = os.path.join(donor_local_path, file_name)
212+
if os.path.exists(fp):
213213
logger.info(f"Copying {file_name} from {donor_model}")
214214
shutil.copy(
215-
os.path.join(donor_model.model.path, file_name),
215+
fp,
216216
os.path.join(out_path, file_name),
217217
)
218218

@@ -223,15 +223,14 @@ def _copy_tokenizer(
223223
merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
224224
):
225225
donor_model = merge_config.base_model or (merge_config.referenced_models()[0])
226+
donor_local_path = donor_model.local_path()
226227

227228
if (
228229
(not merge_config.chat_template)
229-
and os.path.exists(
230-
os.path.join(donor_model.model.path, "tokenizer_config.json")
231-
)
230+
and os.path.exists(os.path.join(donor_local_path, "tokenizer_config.json"))
232231
and (
233-
os.path.exists(os.path.join(donor_model.model.path, "tokenizer.json"))
234-
or os.path.exists(os.path.join(donor_model.model.path, "tokenizer.model"))
232+
os.path.exists(os.path.join(donor_local_path, "tokenizer.json"))
233+
or os.path.exists(os.path.join(donor_local_path, "tokenizer.model"))
235234
)
236235
):
237236
logger.info(f"Copying tokenizer from {donor_model}")
@@ -244,9 +243,9 @@ def _copy_tokenizer(
244243
"added_tokens.json",
245244
"merges.txt",
246245
]:
247-
if os.path.exists(os.path.join(donor_model.model.path, file_name)):
246+
if os.path.exists(os.path.join(donor_local_path, file_name)):
248247
shutil.copy(
249-
os.path.join(donor_model.model.path, file_name),
248+
os.path.join(donor_local_path, file_name),
250249
os.path.join(out_path, file_name),
251250
)
252251

@@ -282,21 +281,38 @@ def _model_out_config(
282281
for module_name in arch_info.modules:
283282
if config.modules and module_name in config.modules:
284283
module_def = config.modules.get(module_name)
285-
module_layers[module_name] = sum(
286-
s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
287-
for s in module_def.slices
288-
)
284+
if module_def and module_def.slices:
285+
module_layers[module_name] = sum(
286+
[
287+
s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
288+
for s in module_def.slices
289+
]
290+
)
289291
elif config.slices:
290292
module_layers[module_name] = sum(
291-
s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
292-
for s in config.slices
293+
[
294+
s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
295+
for s in config.slices
296+
]
293297
)
294298

295299
if module_layers:
296300
for module_name in module_layers:
301+
if module_name not in arch_info.modules:
302+
logger.warning(
303+
f"Module {module_name} in config but not in architecture info"
304+
)
305+
continue
306+
module_info = arch_info.modules[module_name]
307+
cfg_key = module_info.architecture.num_layers_config_key()
308+
if not cfg_key:
309+
if module_layers[module_name] > 0:
310+
logger.warning(
311+
f"Module {module_name} has no configuration key for number of layers, "
312+
"but the number of layers is not zero."
313+
)
314+
continue
297315
try:
298-
module_info = arch_info.modules[module_name]
299-
cfg_key = module_info.architecture.num_layers_config_key()
300316
set_config_value(res, cfg_key, module_layers[module_name])
301317
except Exception as e:
302318
logger.warning(

0 commit comments

Comments
 (0)