15
15
16
16
import logging
17
17
import os
18
- from typing import Optional
18
+ from typing import Dict , List , Optional , Tuple
19
19
20
20
import tqdm
21
21
import transformers
26
26
get_architecture_info ,
27
27
)
28
28
from mergekit .card import generate_card
29
+ from mergekit .common import ModelReference
29
30
from mergekit .config import MergeConfiguration
30
31
from mergekit .graph import Executor
31
32
from mergekit .io .tasks import LoaderCache
@@ -46,31 +47,27 @@ def run_merge(
46
47
if not merge_config .models and not merge_config .slices :
47
48
raise RuntimeError ("No output requested" )
48
49
49
- ## TODO: ----------- reconcile these steps -------------
50
-
51
- model_arch_info = [
52
- get_architecture_info (m .config (trust_remote_code = options .trust_remote_code ))
53
- for m in merge_config .referenced_models ()
54
- ]
55
-
56
- arch_dict = {
57
- m .model .path : ConfiguredArchitectureInfo (
58
- info = get_architecture_info (
59
- m .config (trust_remote_code = options .trust_remote_code )
50
+ model_arch_info : List [Tuple [ModelReference , ConfiguredArchitectureInfo ]] = [
51
+ (
52
+ m ,
53
+ ConfiguredArchitectureInfo (
54
+ info = get_architecture_info (
55
+ m .config (trust_remote_code = options .trust_remote_code )
56
+ ),
57
+ config = m .config (trust_remote_code = options .trust_remote_code ),
60
58
),
61
- config = m .config (trust_remote_code = options .trust_remote_code ),
62
59
)
63
60
for m in merge_config .referenced_models ()
64
- }
65
-
66
- ## ----------------------------------------------------
61
+ ]
67
62
68
63
if not options .allow_crimes :
69
- if not all (a == model_arch_info [0 ] for a in model_arch_info [1 :]):
64
+ if not all (
65
+ a [1 ].info == model_arch_info [0 ][1 ].info for a in model_arch_info [1 :]
66
+ ):
70
67
raise RuntimeError (
71
68
"Must specify --allow-crimes to attempt to mix different architectures"
72
69
)
73
- arch_info = model_arch_info [0 ]
70
+ arch_info : ArchitectureInfo = model_arch_info [0 ][ 1 ]. info
74
71
75
72
# initialize loader cache and set options
76
73
loader_cache = LoaderCache ()
@@ -88,7 +85,7 @@ def run_merge(
88
85
targets = MergePlanner (
89
86
merge_config ,
90
87
arch_info ,
91
- arch_dict ,
88
+ dict ( model_arch_info ) ,
92
89
out_path = out_path ,
93
90
options = options ,
94
91
out_model_config = cfg_out ,
0 commit comments