Skip to content

Commit 86772ed

Browse files
committed
Attempt to compress code
1 parent 81b2c14 commit 86772ed

File tree

2 files changed

+21
-30
lines changed

2 files changed

+21
-30
lines changed

mergekit/merge.py

+16-19
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
import logging
1717
import os
18-
from typing import Optional
18+
from typing import Dict, List, Optional, Tuple
1919

2020
import tqdm
2121
import transformers
@@ -26,6 +26,7 @@
2626
get_architecture_info,
2727
)
2828
from mergekit.card import generate_card
29+
from mergekit.common import ModelReference
2930
from mergekit.config import MergeConfiguration
3031
from mergekit.graph import Executor
3132
from mergekit.io.tasks import LoaderCache
@@ -46,31 +47,27 @@ def run_merge(
4647
if not merge_config.models and not merge_config.slices:
4748
raise RuntimeError("No output requested")
4849

49-
## TODO: ----------- reconcile these steps -------------
50-
51-
model_arch_info = [
52-
get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
53-
for m in merge_config.referenced_models()
54-
]
55-
56-
arch_dict = {
57-
m.model.path: ConfiguredArchitectureInfo(
58-
info=get_architecture_info(
59-
m.config(trust_remote_code=options.trust_remote_code)
50+
model_arch_info: List[Tuple[ModelReference, ConfiguredArchitectureInfo]] = [
51+
(
52+
m,
53+
ConfiguredArchitectureInfo(
54+
info=get_architecture_info(
55+
m.config(trust_remote_code=options.trust_remote_code)
56+
),
57+
config=m.config(trust_remote_code=options.trust_remote_code),
6058
),
61-
config=m.config(trust_remote_code=options.trust_remote_code),
6259
)
6360
for m in merge_config.referenced_models()
64-
}
65-
66-
## ----------------------------------------------------
61+
]
6762

6863
if not options.allow_crimes:
69-
if not all(a == model_arch_info[0] for a in model_arch_info[1:]):
64+
if not all(
65+
a[1].info == model_arch_info[0][1].info for a in model_arch_info[1:]
66+
):
7067
raise RuntimeError(
7168
"Must specify --allow-crimes to attempt to mix different architectures"
7269
)
73-
arch_info = model_arch_info[0]
70+
arch_info: ArchitectureInfo = model_arch_info[0][1].info
7471

7572
# initialize loader cache and set options
7673
loader_cache = LoaderCache()
@@ -88,7 +85,7 @@ def run_merge(
8885
targets = MergePlanner(
8986
merge_config,
9087
arch_info,
91-
arch_dict,
88+
dict(model_arch_info),
9289
out_path=out_path,
9390
options=options,
9491
out_model_config=cfg_out,

mergekit/plan.py

+5-11
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@
4444
class MergePlanner:
4545
config: MergeConfiguration
4646
arch_info: ArchitectureInfo
47-
arch_dict: Dict[str, ConfiguredArchitectureInfo]
47+
arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo]
4848
clone_tensors: bool
4949
trust_remote_code: bool
5050
out_model_config: Any
@@ -60,7 +60,7 @@ def __init__(
6060
config: MergeConfiguration,
6161
arch_info: ArchitectureInfo,
6262
arch_dict: Dict[
63-
str, ConfiguredArchitectureInfo
63+
ModelReference, ConfiguredArchitectureInfo
6464
], # perhaps this should no longer be a disjoint step
6565
out_path: str,
6666
options: MergeOptions,
@@ -216,9 +216,7 @@ def plan_layer(
216216
config=self.out_model_config,
217217
)
218218
weights_in: List[List[WeightInfo]] = [
219-
self.arch_dict[s.model.model.path].layer_weights(
220-
index=s.layer_range[0] + layer_offset
221-
)
219+
self.arch_dict[s.model].layer_weights(index=s.layer_range[0] + layer_offset)
222220
for s in sources
223221
]
224222

@@ -268,9 +266,7 @@ def plan(self):
268266
models_ = [s.model for s in self.config.slices[0].sources]
269267
for weight_infos in zip(
270268
*[
271-
self.arch_dict[m.model.path].info.pre_weights(
272-
config=self.out_model_config
273-
)
269+
self.arch_dict[m].info.pre_weights(config=self.out_model_config)
274270
for m in models_
275271
]
276272
):
@@ -291,9 +287,7 @@ def plan(self):
291287
models_ = [s.model for s in self.config.slices[-1].sources]
292288
for weight_infos in zip(
293289
*[
294-
self.arch_dict[m.model.path].info.post_weights(
295-
config=self.out_model_config
296-
)
290+
self.arch_dict[m].info.post_weights(config=self.out_model_config)
297291
for m in models_
298292
]
299293
):

0 commit comments

Comments
 (0)