From f18c0c3998ea4bdd702a0a174481f532da48107c Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Sun, 11 Feb 2024 06:08:30 -0500 Subject: [PATCH 01/29] WIP: most likely broken code to open up multiple architectures when planning --- mergekit/merge.py | 3 ++- mergekit/plan.py | 69 ++++++++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index 14d6ef46..61168f73 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -68,7 +68,7 @@ def run_merge( logging.info("Planning operations") targets = MergePlanner( merge_config, - arch_info, + model_arch_info, out_path=out_path, options=options, out_model_config=cfg_out, @@ -159,6 +159,7 @@ def _model_out_config( res.torch_dtype = config.dtype try: + print(config.slices) num_layers = sum( s.sources[0].layer_range[1] - s.sources[0].layer_range[0] for s in config.slices diff --git a/mergekit/plan.py b/mergekit/plan.py index d5d47a13..cf93e5ca 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -15,7 +15,7 @@ import logging from functools import lru_cache -from typing import Any, List, Optional +from typing import Any, List, Optional, Dict import torch @@ -44,6 +44,7 @@ class MergePlanner: config: MergeConfiguration arch_info: ArchitectureInfo + arch_dict: Dict[str, ConfiguredArchitectureInfo] clone_tensors: bool trust_remote_code: bool out_model_config: Any @@ -58,12 +59,14 @@ def __init__( self, config: MergeConfiguration, arch_info: ArchitectureInfo, + arch_dict: Dict[str, ConfiguredArchitectureInfo], # perhaps this should no longer be a disjoint step out_path: str, options: MergeOptions, out_model_config: Any, ): self.config = config - self.arch_info = arch_info + self.arch_info = arch_info # Special because how referenced models list is constructed ? 
+ self.arch_dict = arch_dict self.clone_tensors = options.clone_tensors self.trust_remote_code = options.trust_remote_code self.out_model_config = out_model_config @@ -85,16 +88,7 @@ def __init__( if config.base_model and config.align_weights: self._space_planner = SpacePlanner(config.base_model) - @lru_cache - def model_arch_info(self, model: ModelReference): - return ConfiguredArchitectureInfo( - info=self.arch_info, - config=model.config(trust_remote_code=self.trust_remote_code), - ) - def normalize_config(self): - base_model = self.config.base_model - # if models to merge are specified instead of output slices, compute them if self.config.models: if self.config.slices: @@ -103,13 +97,13 @@ def normalize_config(self): ) slices_in = [] - base_included = False for model_in in self.config.models: - if base_model and model_in.model == base_model: - base_included = True - model_info = self.model_arch_info(model_in.model) + if model_in == self.config.base_model: + continue + + model_info = self.arch_dict[model_in.model.path] slices_in.append( InputSliceDefinition( layer_range=[0, model_info.num_layers()], @@ -118,19 +112,22 @@ def normalize_config(self): ) ) - if base_model and not base_included: - logging.info("Base model specified but not in input models - adding") - base_info = self.model_arch_info(base_model) - slices_in.append( + # Ensures base model is first in list + + if self.config.base_model: + base_model_info = self.arch_dict[self.config.base_model.path] + slices_in = [ InputSliceDefinition( - layer_range=[0, base_info.num_layers()], - model=base_model, + layer_range=[0, base_model_info.num_layers()], + model=self.config.base_model, + parameters=self.config.base_model.parameters, ) - ) + ] + slices_in self.config.slices = [OutputSliceDefinition(sources=slices_in)] self.config.models = None + def plan_tensor( self, weight: WeightInfo, @@ -217,7 +214,7 @@ def plan_layer( config=self.out_model_config, ) weights_in: List[List[WeightInfo]] = [ - 
self.model_arch_info(s.model).layer_weights( + self.arch_dict(s.model.path).layer_weights( index=s.layer_range[0] + layer_offset ) for s in sources @@ -234,6 +231,8 @@ def plan_layer( self._current_layers += 1 def plan_slice(self, definition: OutputSliceDefinition): + print("plan_slice:") + print(definition) slice_lengths = [ s.layer_range[1] - s.layer_range[0] for s in definition.sources ] @@ -266,30 +265,32 @@ def plan(self): for space in self.arch_info.procedural_spaces(config=self.out_model_config): self._space_planner.add_procedural_space(space) - for weight_info in self.arch_info.pre_weights(config=self.out_model_config): + models_ = [s.model for s in self.config.slices[0].sources] + for weight_infos in zip(*[self.arch_dict[m.name].pre_weights(config=self.out_model_config) for m in models_.name]): self.plan_tensor( - weight_info, - [weight_info] * len(self.config.slices[0].sources), - [s.model for s in self.config.slices[0].sources], - ConfigReader( + weight_infos[0], + list(weight_infos), + models_, + ConfigReader( # possible trouble here? 
config=self.config, t=0, - tensor_name=weight_info.name, + tensor_name=weight_infos[0].name, ).for_out_slice(self.config.slices[0]), ) for out_slice in self.config.slices: self.plan_slice(out_slice) - for weight_info in self.arch_info.post_weights(config=self.out_model_config): + models_ = [s.model for s in self.config.slices[-1].sources] + for weight_infos in zip(*[self.arch_dict[m.name].post_weights(config=self.out_model_config) for m in models_.name]): self.plan_tensor( - weight_info, - [weight_info] * len(self.config.slices[-1].sources), - [s.model for s in self.config.slices[-1].sources], + weight_infos[0], + list(weight_infos), + models_, ConfigReader( config=self.config, t=1, - tensor_name=weight_info.name, + tensor_name=weight_infos[0].name, ).for_out_slice(self.config.slices[-1]), ) From 238419cfafa84bcc77bff50d52a4db98949341e7 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Sun, 11 Feb 2024 06:11:34 -0500 Subject: [PATCH 02/29] WIP: note and tiny correction --- mergekit/merge.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index 61168f73..4e0c4671 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -20,7 +20,7 @@ import tqdm import transformers -from mergekit.architecture import ArchitectureInfo, get_architecture_info +from mergekit.architecture import ArchitectureInfo, get_architecture_info, ConfiguredArchitectureInfo from mergekit.card import generate_card from mergekit.config import MergeConfiguration from mergekit.graph import Executor @@ -42,10 +42,26 @@ def run_merge( if not merge_config.models and not merge_config.slices: raise RuntimeError("No output requested") + ## TODO: ----------- reconcile these steps ------------- + model_arch_info = [ get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) for m in merge_config.referenced_models() ] + + arch_dict = { + m.model.path: ConfiguredArchitectureInfo( + info=get_architecture_info( + 
m.config(trust_remote_code=options.trust_remote_code) + ), + config=m.model.config(trust_remote_code=options.trust_remote_code) + ) + for m in merge_config.referenced_models() + } + + ## ---------------------------------------------------- + + if not options.allow_crimes: if not all(a == model_arch_info[0] for a in model_arch_info[1:]): raise RuntimeError( @@ -68,7 +84,8 @@ def run_merge( logging.info("Planning operations") targets = MergePlanner( merge_config, - model_arch_info, + arch_info, + arch_dict, out_path=out_path, options=options, out_model_config=cfg_out, From 5c6de5be18b22c80d0e9509e736407db39f6440b Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Sun, 11 Feb 2024 19:21:35 -0500 Subject: [PATCH 03/29] WIP: corrections --- mergekit/merge.py | 2 +- mergekit/plan.py | 36 +++++++++++++++++++++--------------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index 4e0c4671..0e224085 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -54,7 +54,7 @@ def run_merge( info=get_architecture_info( m.config(trust_remote_code=options.trust_remote_code) ), - config=m.model.config(trust_remote_code=options.trust_remote_code) + config=m.config(trust_remote_code=options.trust_remote_code) ) for m in merge_config.referenced_models() } diff --git a/mergekit/plan.py b/mergekit/plan.py index cf93e5ca..56a408b4 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -97,32 +97,33 @@ def normalize_config(self): ) slices_in = [] + base_model = None for model_in in self.config.models: - if model_in == self.config.base_model: - continue - - model_info = self.arch_dict[model_in.model.path] - slices_in.append( - InputSliceDefinition( + model_info = self.arch_dict[model_in.model.model.path] + slice = InputSliceDefinition( layer_range=[0, model_info.num_layers()], model=model_in.model, parameters=model_in.parameters, ) - ) - # Ensures base model is first in list + if model_in.model == self.config.base_model: + base_model = 
slice + else: + slices_in.append(slice) + # Ensures base model is first in list if self.config.base_model: - base_model_info = self.arch_dict[self.config.base_model.path] - slices_in = [ - InputSliceDefinition( + if not base_model: + base_model_info = self.arch_dict[self.config.base_model.model.path] + base_model = InputSliceDefinition( layer_range=[0, base_model_info.num_layers()], model=self.config.base_model, - parameters=self.config.base_model.parameters, + # TODO: possible problematic area + # parameters=self.config.base_model.parameters, ) - ] + slices_in + slices_in = [base_model] + slices_in self.config.slices = [OutputSliceDefinition(sources=slices_in)] self.config.models = None @@ -266,7 +267,9 @@ def plan(self): self._space_planner.add_procedural_space(space) models_ = [s.model for s in self.config.slices[0].sources] - for weight_infos in zip(*[self.arch_dict[m.name].pre_weights(config=self.out_model_config) for m in models_.name]): + print("==========================") + print(list(zip(*[self.arch_dict[m.model.path].info.pre_weights(config=self.out_model_config) for m in models_]))) + for weight_infos in zip(*[self.arch_dict[m.model.path].info.pre_weights(config=self.out_model_config) for m in models_]): self.plan_tensor( weight_infos[0], list(weight_infos), @@ -282,7 +285,10 @@ def plan(self): self.plan_slice(out_slice) models_ = [s.model for s in self.config.slices[-1].sources] - for weight_infos in zip(*[self.arch_dict[m.name].post_weights(config=self.out_model_config) for m in models_.name]): + print("==========POST ==============") + for a in list(zip(*[self.arch_dict[m.model.path].info.post_weights(config=self.out_model_config) for m in models_])): + print(a) + for weight_infos in zip(*[self.arch_dict[m.model.path].info.post_weights(config=self.out_model_config) for m in models_]): self.plan_tensor( weight_infos[0], list(weight_infos), From cc4da48ae74559b334d313479c05da8e0153a3bf Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 12 Feb 2024 
23:24:49 -0500 Subject: [PATCH 04/29] Appeasing pre-commit --- mergekit/merge.py | 13 +++++++----- mergekit/plan.py | 54 +++++++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index 0e224085..ab3a2b1d 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -20,7 +20,11 @@ import tqdm import transformers -from mergekit.architecture import ArchitectureInfo, get_architecture_info, ConfiguredArchitectureInfo +from mergekit.architecture import ( + ArchitectureInfo, + ConfiguredArchitectureInfo, + get_architecture_info, +) from mergekit.card import generate_card from mergekit.config import MergeConfiguration from mergekit.graph import Executor @@ -54,14 +58,13 @@ def run_merge( info=get_architecture_info( m.config(trust_remote_code=options.trust_remote_code) ), - config=m.config(trust_remote_code=options.trust_remote_code) - ) - for m in merge_config.referenced_models() + config=m.config(trust_remote_code=options.trust_remote_code), + ) + for m in merge_config.referenced_models() } ## ---------------------------------------------------- - if not options.allow_crimes: if not all(a == model_arch_info[0] for a in model_arch_info[1:]): raise RuntimeError( diff --git a/mergekit/plan.py b/mergekit/plan.py index 56a408b4..85c852a1 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -15,7 +15,7 @@ import logging from functools import lru_cache -from typing import Any, List, Optional, Dict +from typing import Any, Dict, List, Optional import torch @@ -59,13 +59,17 @@ def __init__( self, config: MergeConfiguration, arch_info: ArchitectureInfo, - arch_dict: Dict[str, ConfiguredArchitectureInfo], # perhaps this should no longer be a disjoint step + arch_dict: Dict[ + str, ConfiguredArchitectureInfo + ], # perhaps this should no longer be a disjoint step out_path: str, options: MergeOptions, out_model_config: Any, ): self.config = config - self.arch_info = arch_info # Special because 
how referenced models list is constructed ? + self.arch_info = ( + arch_info # Special because how referenced models list is constructed ? + ) self.arch_dict = arch_dict self.clone_tensors = options.clone_tensors self.trust_remote_code = options.trust_remote_code @@ -100,35 +104,32 @@ def normalize_config(self): base_model = None for model_in in self.config.models: - model_info = self.arch_dict[model_in.model.model.path] slice = InputSliceDefinition( - layer_range=[0, model_info.num_layers()], - model=model_in.model, - parameters=model_in.parameters, - ) + layer_range=[0, model_info.num_layers()], + model=model_in.model, + parameters=model_in.parameters, + ) if model_in.model == self.config.base_model: base_model = slice else: slices_in.append(slice) - # Ensures base model is first in list if self.config.base_model: if not base_model: base_model_info = self.arch_dict[self.config.base_model.model.path] base_model = InputSliceDefinition( layer_range=[0, base_model_info.num_layers()], model=self.config.base_model, - # TODO: possible problematic area - # parameters=self.config.base_model.parameters, ) + + # Ensures base model is first in list slices_in = [base_model] + slices_in self.config.slices = [OutputSliceDefinition(sources=slices_in)] self.config.models = None - def plan_tensor( self, weight: WeightInfo, @@ -215,7 +216,7 @@ def plan_layer( config=self.out_model_config, ) weights_in: List[List[WeightInfo]] = [ - self.arch_dict(s.model.path).layer_weights( + self.arch_dict(s.model.model.path).layer_weights( index=s.layer_range[0] + layer_offset ) for s in sources @@ -232,8 +233,6 @@ def plan_layer( self._current_layers += 1 def plan_slice(self, definition: OutputSliceDefinition): - print("plan_slice:") - print(definition) slice_lengths = [ s.layer_range[1] - s.layer_range[0] for s in definition.sources ] @@ -267,14 +266,19 @@ def plan(self): self._space_planner.add_procedural_space(space) models_ = [s.model for s in self.config.slices[0].sources] - 
print("==========================") - print(list(zip(*[self.arch_dict[m.model.path].info.pre_weights(config=self.out_model_config) for m in models_]))) - for weight_infos in zip(*[self.arch_dict[m.model.path].info.pre_weights(config=self.out_model_config) for m in models_]): + for weight_infos in zip( + *[ + self.arch_dict[m.model.path].info.pre_weights( + config=self.out_model_config + ) + for m in models_ + ] + ): self.plan_tensor( weight_infos[0], list(weight_infos), models_, - ConfigReader( # possible trouble here? + ConfigReader( # possible trouble here? config=self.config, t=0, tensor_name=weight_infos[0].name, @@ -285,10 +289,14 @@ def plan(self): self.plan_slice(out_slice) models_ = [s.model for s in self.config.slices[-1].sources] - print("==========POST ==============") - for a in list(zip(*[self.arch_dict[m.model.path].info.post_weights(config=self.out_model_config) for m in models_])): - print(a) - for weight_infos in zip(*[self.arch_dict[m.model.path].info.post_weights(config=self.out_model_config) for m in models_]): + for weight_infos in zip( + *[ + self.arch_dict[m.model.path].info.post_weights( + config=self.out_model_config + ) + for m in models_ + ] + ): self.plan_tensor( weight_infos[0], list(weight_infos), From 4d3b7f65fb67959b5b0f506594fb4159c59f2698 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 12 Feb 2024 23:29:44 -0500 Subject: [PATCH 05/29] Add dependency scipy --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c5acd4fc..fbd3fbde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "typing-extensions", "sentencepiece", "protobuf", + "scipy" ] [project.optional-dependencies] From 81b2c140b15723436a0824a8bbebd3ae069212bb Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Mon, 12 Feb 2024 23:38:17 -0500 Subject: [PATCH 06/29] Correction --- mergekit/plan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mergekit/plan.py 
b/mergekit/plan.py index 85c852a1..2cec12b2 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -216,7 +216,7 @@ def plan_layer( config=self.out_model_config, ) weights_in: List[List[WeightInfo]] = [ - self.arch_dict(s.model.model.path).layer_weights( + self.arch_dict[s.model.model.path].layer_weights( index=s.layer_range[0] + layer_offset ) for s in sources From 86772ed2db82fef2e126a4d1eae719e860aab033 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 13 Feb 2024 01:48:21 -0500 Subject: [PATCH 07/29] Attempt to compress code --- mergekit/merge.py | 35 ++++++++++++++++------------------- mergekit/plan.py | 16 +++++----------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index ab3a2b1d..bf444287 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -15,7 +15,7 @@ import logging import os -from typing import Optional +from typing import Dict, List, Optional, Tuple import tqdm import transformers @@ -26,6 +26,7 @@ get_architecture_info, ) from mergekit.card import generate_card +from mergekit.common import ModelReference from mergekit.config import MergeConfiguration from mergekit.graph import Executor from mergekit.io.tasks import LoaderCache @@ -46,31 +47,27 @@ def run_merge( if not merge_config.models and not merge_config.slices: raise RuntimeError("No output requested") - ## TODO: ----------- reconcile these steps ------------- - - model_arch_info = [ - get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) - for m in merge_config.referenced_models() - ] - - arch_dict = { - m.model.path: ConfiguredArchitectureInfo( - info=get_architecture_info( - m.config(trust_remote_code=options.trust_remote_code) + model_arch_info: List[Tuple[ModelReference, ConfiguredArchitectureInfo]] = [ + ( + m, + ConfiguredArchitectureInfo( + info=get_architecture_info( + m.config(trust_remote_code=options.trust_remote_code) + ), + config=m.config(trust_remote_code=options.trust_remote_code), ), - 
config=m.config(trust_remote_code=options.trust_remote_code), ) for m in merge_config.referenced_models() - } - - ## ---------------------------------------------------- + ] if not options.allow_crimes: - if not all(a == model_arch_info[0] for a in model_arch_info[1:]): + if not all( + a[1].info == model_arch_info[0][1].info for a in model_arch_info[1:] + ): raise RuntimeError( "Must specify --allow-crimes to attempt to mix different architectures" ) - arch_info = model_arch_info[0] + arch_info: ArchitectureInfo = model_arch_info[0][1].info # initialize loader cache and set options loader_cache = LoaderCache() @@ -88,7 +85,7 @@ def run_merge( targets = MergePlanner( merge_config, arch_info, - arch_dict, + dict(model_arch_info), out_path=out_path, options=options, out_model_config=cfg_out, diff --git a/mergekit/plan.py b/mergekit/plan.py index 2cec12b2..c5d0f7c8 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -44,7 +44,7 @@ class MergePlanner: config: MergeConfiguration arch_info: ArchitectureInfo - arch_dict: Dict[str, ConfiguredArchitectureInfo] + arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo] clone_tensors: bool trust_remote_code: bool out_model_config: Any @@ -60,7 +60,7 @@ def __init__( config: MergeConfiguration, arch_info: ArchitectureInfo, arch_dict: Dict[ - str, ConfiguredArchitectureInfo + ModelReference, ConfiguredArchitectureInfo ], # perhaps this should no longer be a disjoint step out_path: str, options: MergeOptions, @@ -216,9 +216,7 @@ def plan_layer( config=self.out_model_config, ) weights_in: List[List[WeightInfo]] = [ - self.arch_dict[s.model.model.path].layer_weights( - index=s.layer_range[0] + layer_offset - ) + self.arch_dict[s.model].layer_weights(index=s.layer_range[0] + layer_offset) for s in sources ] @@ -268,9 +266,7 @@ def plan(self): models_ = [s.model for s in self.config.slices[0].sources] for weight_infos in zip( *[ - self.arch_dict[m.model.path].info.pre_weights( - config=self.out_model_config - ) + 
self.arch_dict[m].info.pre_weights(config=self.out_model_config) for m in models_ ] ): @@ -291,9 +287,7 @@ def plan(self): models_ = [s.model for s in self.config.slices[-1].sources] for weight_infos in zip( *[ - self.arch_dict[m.model.path].info.post_weights( - config=self.out_model_config - ) + self.arch_dict[m].info.post_weights(config=self.out_model_config) for m in models_ ] ): From 0f4b44c4db7e3782ac7d04b6cd4743f6836df191 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 13 Feb 2024 02:06:39 -0500 Subject: [PATCH 08/29] Correction --- mergekit/plan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mergekit/plan.py b/mergekit/plan.py index c5d0f7c8..7ed6445b 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -104,7 +104,7 @@ def normalize_config(self): base_model = None for model_in in self.config.models: - model_info = self.arch_dict[model_in.model.model.path] + model_info = self.arch_dict[model_in.model] slice = InputSliceDefinition( layer_range=[0, model_info.num_layers()], model=model_in.model, @@ -118,7 +118,7 @@ def normalize_config(self): if self.config.base_model: if not base_model: - base_model_info = self.arch_dict[self.config.base_model.model.path] + base_model_info = self.arch_dict[self.config.base_model] base_model = InputSliceDefinition( layer_range=[0, base_model_info.num_layers()], model=self.config.base_model, From 0dc55a4297642b824ec8bb688aa401cbf0316ed6 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 13 Feb 2024 04:37:30 -0500 Subject: [PATCH 09/29] Bypass mixed architecture user warning/error if align_weights is enabled --- mergekit/merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mergekit/merge.py b/mergekit/merge.py index bf444287..711bc9b1 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -60,7 +60,7 @@ def run_merge( for m in merge_config.referenced_models() ] - if not options.allow_crimes: + if not (options.allow_crimes or options.align_weights): if not all( 
a[1].info == model_arch_info[0][1].info for a in model_arch_info[1:] ): From 0e3d4cbd31ab179f020db223ac0b5e554ecd939f Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Thu, 15 Feb 2024 03:31:26 -0500 Subject: [PATCH 10/29] WIP: Draft of correction --- mergekit/architecture.py | 94 +++++++++++++++++++++++++++++++++++++--- mergekit/merge.py | 25 +++++++++-- 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 708faeea..8ba16103 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -16,7 +16,7 @@ import importlib.resources import string from abc import ABC, abstractmethod -from typing import ClassVar, Dict, List, Optional, Tuple, Union +from typing import ClassVar, Dict, List, Optional, Tuple, TypeAlias, Union from pydantic import BaseModel, Field from transformers import PretrainedConfig @@ -123,18 +123,48 @@ def has_defined_spaces(self) -> bool: return False +class MappingInfo(BaseModel, frozen=True): + """Information about a mapping between two models. + + Attributes: + from_model (str): + The name of the model from which the mapping originates. + to_model (str): + The name of the model to which the mapping applies. 
+ """ + + start_architectures: List[str] + destination_architectures: List[str] + pre_weights_mapping: Dict[str, List[str]] + post_weights_mapping: Dict[str, List[str]] + + +class Mapping(BaseModel, frozen=True): + pre_weights: List[str] + post_weights: List[str] + + class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): info: ArchitectureInfo config: PretrainedConfig + overrides: Optional[ + Dict[str, List[str]] + ] # TODO: check if the optional is necessary def num_layers(self) -> int: return self.info.num_layers(self.config) def pre_weights(self) -> List[WeightInfo]: - return self.info.pre_weights(self.config) + if not self.overrides: + return self.info.pre_weights(self.config) + + return self.overrides["pre_weights"] def post_weights(self) -> List[WeightInfo]: - return self.info.post_weights(self.config) + if not self.overrides: + return self.info.post_weights(self.config) + + return self.overrides["post_weights"] def layer_weights(self, index: int) -> List[WeightInfo]: return self.info.layer_weights(index, self.config) @@ -145,6 +175,13 @@ def procedural_spaces(self) -> List[ProceduralSpaceInfo]: def all_weights(self) -> List[WeightInfo]: return self.info.all_weights(self.config) + def update_overrides( + self, overrides: Dict[str, List[str]] + ) -> "ConfiguredArchitectureInfo": + return ConfiguredArchitectureInfo( + info=self.info, config=self.config, overrides=overrides + ) + class JSONLayerTemplates(BaseModel, frozen=True): weights: List[WeightInfo] @@ -198,10 +235,16 @@ def _substitute( return type(item).model_validate(obj_dict) def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - return [ + # assume fleshed out names for now in self.overrides + weights = [ self._substitute(wi, config=config) for wi in self.definition.pre_weights ] + if self.overrides: + weights = [wi for wi in weights if wi.name in self.overrides["pre_weights"]] + + return weights + def layer_weights( self, index: int, config: 
PretrainedConfig ) -> Optional[List[WeightInfo]]: @@ -211,10 +254,18 @@ def layer_weights( ] def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - return [ + # assume fleshed out names for now in self.overrides + weights = [ self._substitute(wi, config=config) for wi in self.definition.post_weights ] + if self.overrides: + weights = [ + wi for wi in weights if wi.name in self.overrides["post_weights"] + ] + + return weights + def sliceable(self) -> bool: return True @@ -314,7 +365,7 @@ def _load_all_architectures() -> ( MISTRAL_INFO = _load_json_arch("mistral.json") -def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: +def _load_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: if len(config.architectures) != 1: raise RuntimeError("More than one architecture in config?") @@ -337,3 +388,34 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: raise RuntimeError( f"Unsupported model_type {config.model_type} for architecture {arch_name}" ) + + +def _load_arch_mappings(name) -> JsonArchitectureInfo: + text = importlib.resources.read_text(mergekit._data.mappings, name) + return MappingInfo.model_validate_json(text) + + +# TODO: should be immutable map +def _load_all_mappings() -> Dict[str, Dict[str, Dict[str, List[str]]]]: + mappings: Dict[str, MappingInfo] = {} + for f in importlib.resources.contents(mergekit._data.mappings): + if f.lower().endswith(".json"): + mapping = _load_arch_mappings(f) + for start_architecture in mapping.start_architectures: + if start_architecture not in mappings: + mappings[start_architecture] = {} + + for destination_architecture in mapping.destination_architectures: + if destination_architecture not in mappings[start_architecture]: + pre_weights = mapping["pre_weights"]["destination"] + post_weights = mapping["post_weights"]["destination"] + mappings[start_architecture][destination_architecture] = { + "pre_weights": pre_weights, + "post_weights": post_weights, + } + 
+ # TODO: think about reverse mapping + return mappings + + +JSON_MAPPINGS = _load_all_mappings() diff --git a/mergekit/merge.py b/mergekit/merge.py index 711bc9b1..9036bc33 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -20,6 +20,9 @@ import tqdm import transformers +from mergekit.architecture import ( + JSON_MAPPINGS, # TODO best write a function to get the JSON_MAPPINGS +) from mergekit.architecture import ( ArchitectureInfo, ConfiguredArchitectureInfo, @@ -60,15 +63,32 @@ def run_merge( for m in merge_config.referenced_models() ] - if not (options.allow_crimes or options.align_weights): + if not (options.allow_crimes or merge_config.align_weights): if not all( a[1].info == model_arch_info[0][1].info for a in model_arch_info[1:] ): raise RuntimeError( "Must specify --allow-crimes to attempt to mix different architectures" ) + arch_info: ArchitectureInfo = model_arch_info[0][1].info + if merge_config.align_weights: + new_model_arch_info = [model_arch_info[0]] + for m, destination_arch_info in model_arch_info[1:]: + mapping = JSON_MAPPINGS.get(arch_info.config.architecture, {}).get( + destination_arch_info.config.architecture_version + ) + if mapping: + new_model_arch_info.append( + (m, destination_arch_info.update_overrides(mapping)) + ) + else: + # warn user + new_model_arch_info.append((m, destination_arch_info)) + + model_arch_info = new_model_arch_info + # initialize loader cache and set options loader_cache = LoaderCache() loader_cache.lazy_unpickle = options.lazy_unpickle @@ -202,6 +222,3 @@ def _update_config_vocab( "Unable to set vocabulary size in output config - you may need to manually correct it.", exc_info=e, ) - - -__all__ = ["MergeOptions", "run_merge"] From e5134f4bdfe02a4fbeda9787472cfc0fb2c474ef Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Thu, 15 Feb 2024 03:32:15 -0500 Subject: [PATCH 11/29] WIP: Example mapping --- mergekit/_data/mappings/a_b.json | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 
mergekit/_data/mappings/a_b.json diff --git a/mergekit/_data/mappings/a_b.json b/mergekit/_data/mappings/a_b.json new file mode 100644 index 00000000..bb468a42 --- /dev/null +++ b/mergekit/_data/mappings/a_b.json @@ -0,0 +1,12 @@ +{ + "start_architectures": ["A"], + "destination_architectures": ["B"], + "pre_weights_mapping":{ + "start" : ["a", "b", "c"], + "destination" : ["x", "y", "z"] + }, + "post_weights_mapping":{ + "start" : ["d", "e", "f"], + "destination" : ["m", "n", "o"] + } +} From fa87d743bfa7f9001dfe672089e866cf6953851a Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Fri, 16 Feb 2024 02:45:27 -0500 Subject: [PATCH 12/29] WIP: Code for handling unmapped pre/post weights --- mergekit/_data/mappings/a_b.json | 4 ++-- mergekit/architecture.py | 36 +++++++++++++++++++++----------- mergekit/common.py | 12 +++++++++++ mergekit/plan.py | 8 +++---- 4 files changed, 42 insertions(+), 18 deletions(-) diff --git a/mergekit/_data/mappings/a_b.json b/mergekit/_data/mappings/a_b.json index bb468a42..329028dd 100644 --- a/mergekit/_data/mappings/a_b.json +++ b/mergekit/_data/mappings/a_b.json @@ -3,10 +3,10 @@ "destination_architectures": ["B"], "pre_weights_mapping":{ "start" : ["a", "b", "c"], - "destination" : ["x", "y", "z"] + "destination" : [null, "y", "z"] }, "post_weights_mapping":{ "start" : ["d", "e", "f"], - "destination" : ["m", "n", "o"] + "destination" : ["m", null, null] } } diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 8ba16103..e587206c 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -16,7 +16,7 @@ import importlib.resources import string from abc import ABC, abstractmethod -from typing import ClassVar, Dict, List, Optional, Tuple, TypeAlias, Union +from typing import ClassVar, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, Field from transformers import PretrainedConfig @@ -135,20 +135,20 @@ class MappingInfo(BaseModel, frozen=True): start_architectures: List[str] 
destination_architectures: List[str] - pre_weights_mapping: Dict[str, List[str]] - post_weights_mapping: Dict[str, List[str]] + pre_weights_mapping: Dict[str, List[Optional[str]]] + post_weights_mapping: Dict[str, List[Optional[str]]] class Mapping(BaseModel, frozen=True): - pre_weights: List[str] - post_weights: List[str] + pre_weights: List[Optional[str]] + post_weights: List[Optional[str]] class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): info: ArchitectureInfo config: PretrainedConfig overrides: Optional[ - Dict[str, List[str]] + Dict[str, List[Optional[str]]] ] # TODO: check if the optional is necessary def num_layers(self) -> int: @@ -234,14 +234,21 @@ def _substitute( ) return type(item).model_validate(obj_dict) - def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + def pre_weights(self, config: PretrainedConfig) -> List[Optional[WeightInfo]]: # assume fleshed out names for now in self.overrides weights = [ self._substitute(wi, config=config) for wi in self.definition.pre_weights ] if self.overrides: - weights = [wi for wi in weights if wi.name in self.overrides["pre_weights"]] + new_weights = [] + for wi in self.overrides["pre_weights"]: + for w in weights: + if not w: + new_weights.append(w) + elif w.name == wi: + new_weights.append(w) + weights = new_weights return weights @@ -253,16 +260,21 @@ def layer_weights( for wi in self.definition.layer_templates.weights ] - def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + def post_weights(self, config: PretrainedConfig) -> List[Optional[WeightInfo]]: # assume fleshed out names for now in self.overrides weights = [ self._substitute(wi, config=config) for wi in self.definition.post_weights ] if self.overrides: - weights = [ - wi for wi in weights if wi.name in self.overrides["post_weights"] - ] + new_weights = [] + for wi in self.overrides["post_weights"]: + for w in weights: + if not w: + new_weights.append(w) + elif w.name == wi: + 
new_weights.append(w) + weights = new_weights return weights diff --git a/mergekit/common.py b/mergekit/common.py index 00a85b5b..366fe293 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -271,3 +271,15 @@ def items(self) -> Iterator[Tuple[T_K, T_V]]: def values(self) -> Iterator[T_V]: return self.data.values() + + +def zip_remove_nones(*args) -> List[List[Any]]: + """ + Example: + + >>> zip_remove_nones([1, None, 3], [2, 5, 6], [None, None, 7]) + [[1, 2], [5], [3, 6, 7]] + + + """ + return [[element for element in t if element is not None] for t in zip(args)] diff --git a/mergekit/plan.py b/mergekit/plan.py index 7ed6445b..48bfecda 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -25,7 +25,7 @@ ConfiguredArchitectureInfo, WeightInfo, ) -from mergekit.common import ImmutableMap, ModelReference +from mergekit.common import ImmutableMap, ModelReference, zip_remove_nones from mergekit.config import ( ConfigReader, InputSliceDefinition, @@ -264,7 +264,7 @@ def plan(self): self._space_planner.add_procedural_space(space) models_ = [s.model for s in self.config.slices[0].sources] - for weight_infos in zip( + for weight_infos in zip_remove_nones( *[ self.arch_dict[m].info.pre_weights(config=self.out_model_config) for m in models_ @@ -274,7 +274,7 @@ def plan(self): weight_infos[0], list(weight_infos), models_, - ConfigReader( # possible trouble here? 
+ ConfigReader( config=self.config, t=0, tensor_name=weight_infos[0].name, @@ -285,7 +285,7 @@ def plan(self): self.plan_slice(out_slice) models_ = [s.model for s in self.config.slices[-1].sources] - for weight_infos in zip( + for weight_infos in zip_remove_nones( *[ self.arch_dict[m].info.post_weights(config=self.out_model_config) for m in models_ From 0d38791dee0a3f21644057bd9d03ca23a2675222 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Fri, 16 Feb 2024 02:50:03 -0500 Subject: [PATCH 13/29] Add missing __init__ --- mergekit/_data/mappings/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 mergekit/_data/mappings/__init__.py diff --git a/mergekit/_data/mappings/__init__.py b/mergekit/_data/mappings/__init__.py new file mode 100644 index 00000000..e69de29b From 36f76391f2f998b3ca87c53f14b324b4fede4f2c Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Fri, 16 Feb 2024 02:56:31 -0500 Subject: [PATCH 14/29] Fix missing submodule for mappings --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index fbd3fbde..87f764aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ packages = [ "mergekit.scripts", "mergekit._data", "mergekit._data.architectures", + "mergekiy._data.mappings", ] [tool.isort] From 2536033995beae2a7dd0ff887bf0997c961db3e1 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Fri, 16 Feb 2024 02:58:40 -0500 Subject: [PATCH 15/29] Spelling correction --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 87f764aa..ced30a20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ packages = [ "mergekit.scripts", "mergekit._data", "mergekit._data.architectures", - "mergekiy._data.mappings", + "mergekit._data.mappings", ] [tool.isort] From 3f552a13da07e0c5e0b4bf6f5b98385dddbaeaf1 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Sat, 17 Feb 2024 00:20:24 -0500 Subject: [PATCH 16/29] 
Corrections post test run --- mergekit/architecture.py | 72 +++++++++++++++++++++------------------- mergekit/common.py | 2 +- mergekit/merge.py | 13 ++++++-- mergekit/plan.py | 20 +++++++++-- 4 files changed, 66 insertions(+), 41 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index e587206c..dcb45994 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -23,6 +23,7 @@ from typing_extensions import Literal import mergekit._data.architectures +import mergekit._data.mappings class WeightInfo(BaseModel, frozen=True): @@ -123,6 +124,11 @@ def has_defined_spaces(self) -> bool: return False +class Mapping(BaseModel, frozen=True): + pre_weights: List[Optional[str]] + post_weights: List[Optional[str]] + + class MappingInfo(BaseModel, frozen=True): """Information about a mapping between two models. @@ -138,33 +144,30 @@ class MappingInfo(BaseModel, frozen=True): pre_weights_mapping: Dict[str, List[Optional[str]]] post_weights_mapping: Dict[str, List[Optional[str]]] - -class Mapping(BaseModel, frozen=True): - pre_weights: List[Optional[str]] - post_weights: List[Optional[str]] + def destination_map(self) -> Mapping: + return Mapping( + pre_weights=self.pre_weights_mapping["destination"], + post_weights=self.post_weights_mapping["destination"], + ) class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): info: ArchitectureInfo config: PretrainedConfig - overrides: Optional[ - Dict[str, List[Optional[str]]] - ] # TODO: check if the optional is necessary + overrides: Optional[Mapping] = None # TODO: check if the optional is necessary def num_layers(self) -> int: return self.info.num_layers(self.config) - def pre_weights(self) -> List[WeightInfo]: - if not self.overrides: - return self.info.pre_weights(self.config) - - return self.overrides["pre_weights"] - - def post_weights(self) -> List[WeightInfo]: - if not self.overrides: - return self.info.post_weights(self.config) + def pre_weights(self) -> 
List[Optional[WeightInfo]]: + if self.overrides: + return self.info.pre_weights(self.config, self.overrides.pre_weights) + return self.info.pre_weights(self.config) - return self.overrides["post_weights"] + def post_weights(self) -> List[Optional[WeightInfo]]: + if self.overrides: + return self.info.post_weights(self.config, self.overrides.post_weights) + return self.info.post_weights(self.config) def layer_weights(self, index: int) -> List[WeightInfo]: return self.info.layer_weights(index, self.config) @@ -175,9 +178,7 @@ def procedural_spaces(self) -> List[ProceduralSpaceInfo]: def all_weights(self) -> List[WeightInfo]: return self.info.all_weights(self.config) - def update_overrides( - self, overrides: Dict[str, List[str]] - ) -> "ConfiguredArchitectureInfo": + def update_overrides(self, overrides: Mapping) -> "ConfiguredArchitectureInfo": return ConfiguredArchitectureInfo( info=self.info, config=self.config, overrides=overrides ) @@ -234,15 +235,17 @@ def _substitute( ) return type(item).model_validate(obj_dict) - def pre_weights(self, config: PretrainedConfig) -> List[Optional[WeightInfo]]: + def pre_weights( + self, config: PretrainedConfig, overrides: Optional[List[Optional[str]]] = None + ) -> List[Optional[WeightInfo]]: # assume fleshed out names for now in self.overrides weights = [ self._substitute(wi, config=config) for wi in self.definition.pre_weights ] - if self.overrides: + if overrides: new_weights = [] - for wi in self.overrides["pre_weights"]: + for wi in overrides: for w in weights: if not w: new_weights.append(w) @@ -260,15 +263,17 @@ def layer_weights( for wi in self.definition.layer_templates.weights ] - def post_weights(self, config: PretrainedConfig) -> List[Optional[WeightInfo]]: + def post_weights( + self, config: PretrainedConfig, overrides: Optional[List[Optional[str]]] = None + ) -> List[Optional[WeightInfo]]: # assume fleshed out names for now in self.overrides weights = [ self._substitute(wi, config=config) for wi in 
self.definition.post_weights ] - if self.overrides: + if overrides: new_weights = [] - for wi in self.overrides["post_weights"]: + for wi in overrides: for w in weights: if not w: new_weights.append(w) @@ -377,7 +382,7 @@ def _load_all_architectures() -> ( MISTRAL_INFO = _load_json_arch("mistral.json") -def _load_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: +def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: if len(config.architectures) != 1: raise RuntimeError("More than one architecture in config?") @@ -402,13 +407,13 @@ def _load_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: ) -def _load_arch_mappings(name) -> JsonArchitectureInfo: +def _load_arch_mappings(name) -> MappingInfo: text = importlib.resources.read_text(mergekit._data.mappings, name) return MappingInfo.model_validate_json(text) # TODO: should be immutable map -def _load_all_mappings() -> Dict[str, Dict[str, Dict[str, List[str]]]]: +def _load_all_mappings() -> Dict[str, Dict[str, Mapping]]: mappings: Dict[str, MappingInfo] = {} for f in importlib.resources.contents(mergekit._data.mappings): if f.lower().endswith(".json"): @@ -419,12 +424,9 @@ def _load_all_mappings() -> Dict[str, Dict[str, Dict[str, List[str]]]]: for destination_architecture in mapping.destination_architectures: if destination_architecture not in mappings[start_architecture]: - pre_weights = mapping["pre_weights"]["destination"] - post_weights = mapping["post_weights"]["destination"] - mappings[start_architecture][destination_architecture] = { - "pre_weights": pre_weights, - "post_weights": post_weights, - } + mappings[start_architecture][ + destination_architecture + ] = mapping.destination_map() # TODO: think about reverse mapping return mappings diff --git a/mergekit/common.py b/mergekit/common.py index 366fe293..d7c39285 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -282,4 +282,4 @@ def zip_remove_nones(*args) -> List[List[Any]]: """ - return [[element for 
element in t if element is not None] for t in zip(args)] + return [[element for element in t if element is not None] for t in zip(*args)] diff --git a/mergekit/merge.py b/mergekit/merge.py index 9036bc33..4f75e164 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -75,10 +75,17 @@ def run_merge( if merge_config.align_weights: new_model_arch_info = [model_arch_info[0]] + base_model = model_arch_info[0][1] for m, destination_arch_info in model_arch_info[1:]: - mapping = JSON_MAPPINGS.get(arch_info.config.architecture, {}).get( - destination_arch_info.config.architecture_version - ) + mapping = None + for arch in base_model.config.architectures: + for destination_arch in destination_arch_info.config.architectures: + mapping = JSON_MAPPINGS.get(arch, {}).get(destination_arch) + if mapping: + break + if mapping: + break + if mapping: new_model_arch_info.append( (m, destination_arch_info.update_overrides(mapping)) diff --git a/mergekit/plan.py b/mergekit/plan.py index 48bfecda..87eebb8a 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -264,12 +264,26 @@ def plan(self): self._space_planner.add_procedural_space(space) models_ = [s.model for s in self.config.slices[0].sources] + print(models_) + print( + [ + # TODO: this is wayy too complicated + self.arch_dict[m].info.pre_weights( + config=self.out_model_config, overrides=self.arch_dict[m].overrides + ) + for m in models_ + ] + ) for weight_infos in zip_remove_nones( *[ - self.arch_dict[m].info.pre_weights(config=self.out_model_config) + # TODO: this is wayy too complicated + self.arch_dict[m].info.pre_weights( + config=self.out_model_config, overrides=self.arch_dict[m].overrides + ) for m in models_ ] ): + print(weight_infos) self.plan_tensor( weight_infos[0], list(weight_infos), @@ -287,7 +301,9 @@ def plan(self): models_ = [s.model for s in self.config.slices[-1].sources] for weight_infos in zip_remove_nones( *[ - self.arch_dict[m].info.post_weights(config=self.out_model_config) + 
self.arch_dict[m].info.post_weights( + config=self.out_model_config, overrides=self.arch_dict[m].overrides + ) for m in models_ ] ): From 4f179d77fd98d4bdad6c910322cbc140eb4be05d Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 20 Feb 2024 16:18:24 -0500 Subject: [PATCH 17/29] Correction --- mergekit/common.py | 21 ++++++++++++++++----- mergekit/plan.py | 23 +++++++---------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/mergekit/common.py b/mergekit/common.py index d7c39285..70e5f1d6 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -273,13 +273,24 @@ def values(self) -> Iterator[T_V]: return self.data.values() -def zip_remove_nones(*args) -> List[List[Any]]: +def zip_remove_nones(keys: List[Any], *args) -> List[Tuple[List[Any], List[Any]]]: """ Example: - >>> zip_remove_nones([1, None, 3], [2, 5, 6], [None, None, 7]) - [[1, 2], [5], [3, 6, 7]] - + >>> zip_remove_nones(['a','b','c'], [1, None, 3], [2, 5, 6], [None, None, 7]) + [ + (['a','b'], [1,2]), + (['b'], [5]), + (['a','b','c'], [3, 6, 7]) + ] """ - return [[element for element in t if element is not None] for t in zip(*args)] + result = [] + + for zipped in zip(*args): + r = [(k, v) for k, v in zip(keys, zipped) if v is not None] + _keys = [k for k, v in r] + _values = [v for k, v in r] + result.append((_keys, _values)) + + return result diff --git a/mergekit/plan.py b/mergekit/plan.py index 87eebb8a..29c45c96 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -264,30 +264,21 @@ def plan(self): self._space_planner.add_procedural_space(space) models_ = [s.model for s in self.config.slices[0].sources] - print(models_) - print( - [ - # TODO: this is wayy too complicated - self.arch_dict[m].info.pre_weights( - config=self.out_model_config, overrides=self.arch_dict[m].overrides - ) - for m in models_ - ] - ) - for weight_infos in zip_remove_nones( + + for models, weight_infos in zip_remove_nones( + models_, *[ # TODO: this is wayy too complicated 
self.arch_dict[m].info.pre_weights( config=self.out_model_config, overrides=self.arch_dict[m].overrides ) for m in models_ - ] + ], ): - print(weight_infos) self.plan_tensor( weight_infos[0], list(weight_infos), - models_, + models, ConfigReader( config=self.config, t=0, @@ -299,7 +290,7 @@ def plan(self): self.plan_slice(out_slice) models_ = [s.model for s in self.config.slices[-1].sources] - for weight_infos in zip_remove_nones( + for models, weight_infos in zip_remove_nones( *[ self.arch_dict[m].info.post_weights( config=self.out_model_config, overrides=self.arch_dict[m].overrides @@ -310,7 +301,7 @@ def plan(self): self.plan_tensor( weight_infos[0], list(weight_infos), - models_, + models, ConfigReader( config=self.config, t=1, From d2f649fdc11ae008e3c25d44febf41ed106e2997 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 20 Feb 2024 17:06:31 -0500 Subject: [PATCH 18/29] Correction --- mergekit/plan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mergekit/plan.py b/mergekit/plan.py index 29c45c96..1251fcf6 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -291,12 +291,13 @@ def plan(self): models_ = [s.model for s in self.config.slices[-1].sources] for models, weight_infos in zip_remove_nones( + models_, *[ self.arch_dict[m].info.post_weights( config=self.out_model_config, overrides=self.arch_dict[m].overrides ) for m in models_ - ] + ], ): self.plan_tensor( weight_infos[0], From a79a75a786d472b0fe80f823263bc04c766c367f Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Tue, 20 Feb 2024 21:14:52 -0500 Subject: [PATCH 19/29] Charles' review changes --- mergekit/_data/mappings/a_b.json | 11 ++-- mergekit/architecture.py | 90 +++++++++++--------------------- mergekit/common.py | 23 -------- mergekit/plan.py | 31 +++++------ 4 files changed, 47 insertions(+), 108 deletions(-) diff --git a/mergekit/_data/mappings/a_b.json b/mergekit/_data/mappings/a_b.json index 329028dd..ec81181e 100644 --- a/mergekit/_data/mappings/a_b.json +++ 
b/mergekit/_data/mappings/a_b.json @@ -1,12 +1,9 @@ { "start_architectures": ["A"], "destination_architectures": ["B"], - "pre_weights_mapping":{ - "start" : ["a", "b", "c"], - "destination" : [null, "y", "z"] - }, - "post_weights_mapping":{ - "start" : ["d", "e", "f"], - "destination" : ["m", null, null] + "weights_mapping":{ + "a" : "d", + "b" : "e", + "c" : "m" } } diff --git a/mergekit/architecture.py b/mergekit/architecture.py index dcb45994..685bb30a 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -124,11 +124,6 @@ def has_defined_spaces(self) -> bool: return False -class Mapping(BaseModel, frozen=True): - pre_weights: List[Optional[str]] - post_weights: List[Optional[str]] - - class MappingInfo(BaseModel, frozen=True): """Information about a mapping between two models. @@ -141,36 +136,41 @@ class MappingInfo(BaseModel, frozen=True): start_architectures: List[str] destination_architectures: List[str] - pre_weights_mapping: Dict[str, List[Optional[str]]] - post_weights_mapping: Dict[str, List[Optional[str]]] - - def destination_map(self) -> Mapping: - return Mapping( - pre_weights=self.pre_weights_mapping["destination"], - post_weights=self.post_weights_mapping["destination"], - ) + weights_mapping: Dict[str, str] class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): info: ArchitectureInfo config: PretrainedConfig - overrides: Optional[Mapping] = None # TODO: check if the optional is necessary + overrides: Optional[Dict[str, str]] = None + + def _substitute_name(self, weight: WeightInfo) -> WeightInfo: + if self.overrides and weight.name in self.overrides: + return WeightInfo( + name=self.overrides[weight.name], **self.model_dump(exclude=["name"]) + ) + return weight def num_layers(self) -> int: return self.info.num_layers(self.config) - def pre_weights(self) -> List[Optional[WeightInfo]]: - if self.overrides: - return self.info.pre_weights(self.config, self.overrides.pre_weights) - return 
self.info.pre_weights(self.config) + def pre_weights(self) -> List[WeightInfo]: + return [ + self._substitute_name(wi) + for wi in self.info.pre_weights(self.config, self.overrides.pre_weights) + ] - def post_weights(self) -> List[Optional[WeightInfo]]: - if self.overrides: - return self.info.post_weights(self.config, self.overrides.post_weights) - return self.info.post_weights(self.config) + def post_weights(self) -> List[WeightInfo]: + return [ + self._substitute_name(wi) + for wi in self.info.post_weights(self.config, self.overrides.pre_weights) + ] def layer_weights(self, index: int) -> List[WeightInfo]: - return self.info.layer_weights(index, self.config) + return [ + self._substitute_name(wi) + for wi in self.info.layer_weights(index, self.config) + ] def procedural_spaces(self) -> List[ProceduralSpaceInfo]: return self.info.procedural_spaces(self.config) @@ -178,7 +178,9 @@ def procedural_spaces(self) -> List[ProceduralSpaceInfo]: def all_weights(self) -> List[WeightInfo]: return self.info.all_weights(self.config) - def update_overrides(self, overrides: Mapping) -> "ConfiguredArchitectureInfo": + def update_overrides( + self, overrides: Dict[str, str] + ) -> "ConfiguredArchitectureInfo": return ConfiguredArchitectureInfo( info=self.info, config=self.config, overrides=overrides ) @@ -235,52 +237,24 @@ def _substitute( ) return type(item).model_validate(obj_dict) - def pre_weights( - self, config: PretrainedConfig, overrides: Optional[List[Optional[str]]] = None - ) -> List[Optional[WeightInfo]]: - # assume fleshed out names for now in self.overrides + def pre_weights(self, config: PretrainedConfig) -> List[Optional[WeightInfo]]: weights = [ self._substitute(wi, config=config) for wi in self.definition.pre_weights ] - if overrides: - new_weights = [] - for wi in overrides: - for w in weights: - if not w: - new_weights.append(w) - elif w.name == wi: - new_weights.append(w) - weights = new_weights - return weights - def layer_weights( - self, index: int, config: 
PretrainedConfig - ) -> Optional[List[WeightInfo]]: + def layer_weights(self, index: int, config: PretrainedConfig) -> List[WeightInfo]: return [ self._substitute(wi, config=config, layer_idx=index) for wi in self.definition.layer_templates.weights ] - def post_weights( - self, config: PretrainedConfig, overrides: Optional[List[Optional[str]]] = None - ) -> List[Optional[WeightInfo]]: - # assume fleshed out names for now in self.overrides + def post_weights(self, config: PretrainedConfig) -> Optional[WeightInfo]: weights = [ self._substitute(wi, config=config) for wi in self.definition.post_weights ] - if overrides: - new_weights = [] - for wi in overrides: - for w in weights: - if not w: - new_weights.append(w) - elif w.name == wi: - new_weights.append(w) - weights = new_weights - return weights def sliceable(self) -> bool: @@ -412,8 +386,7 @@ def _load_arch_mappings(name) -> MappingInfo: return MappingInfo.model_validate_json(text) -# TODO: should be immutable map -def _load_all_mappings() -> Dict[str, Dict[str, Mapping]]: +def _load_all_mappings() -> Dict[str, Dict[str, Dict[str, str]]]: mappings: Dict[str, MappingInfo] = {} for f in importlib.resources.contents(mergekit._data.mappings): if f.lower().endswith(".json"): @@ -426,9 +399,8 @@ def _load_all_mappings() -> Dict[str, Dict[str, Mapping]]: if destination_architecture not in mappings[start_architecture]: mappings[start_architecture][ destination_architecture - ] = mapping.destination_map() + ] = mapping.weights_mapping - # TODO: think about reverse mapping return mappings diff --git a/mergekit/common.py b/mergekit/common.py index 70e5f1d6..00a85b5b 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -271,26 +271,3 @@ def items(self) -> Iterator[Tuple[T_K, T_V]]: def values(self) -> Iterator[T_V]: return self.data.values() - - -def zip_remove_nones(keys: List[Any], *args) -> List[Tuple[List[Any], List[Any]]]: - """ - Example: - - >>> zip_remove_nones(['a','b','c'], [1, None, 3], [2, 5, 6], [None, 
None, 7]) - [ - (['a','b'], [1,2]), - (['b'], [5]), - (['a','b','c'], [3, 6, 7]) - ] - - """ - result = [] - - for zipped in zip(*args): - r = [(k, v) for k, v in zip(keys, zipped) if v is not None] - _keys = [k for k, v in r] - _values = [v for k, v in r] - result.append((_keys, _values)) - - return result diff --git a/mergekit/plan.py b/mergekit/plan.py index 1251fcf6..0877455e 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -25,7 +25,7 @@ ConfiguredArchitectureInfo, WeightInfo, ) -from mergekit.common import ImmutableMap, ModelReference, zip_remove_nones +from mergekit.common import ImmutableMap, ModelReference from mergekit.config import ( ConfigReader, InputSliceDefinition, @@ -263,22 +263,18 @@ def plan(self): for space in self.arch_info.procedural_spaces(config=self.out_model_config): self._space_planner.add_procedural_space(space) - models_ = [s.model for s in self.config.slices[0].sources] + _models = [s.model for s in self.config.slices[0].sources] - for models, weight_infos in zip_remove_nones( - models_, + for weight_infos in zip( *[ - # TODO: this is wayy too complicated - self.arch_dict[m].info.pre_weights( - config=self.out_model_config, overrides=self.arch_dict[m].overrides - ) - for m in models_ - ], + self.arch_dict[m].info.pre_weights(config=self.out_model_config) + for m in _models + ] ): self.plan_tensor( weight_infos[0], list(weight_infos), - models, + _models, ConfigReader( config=self.config, t=0, @@ -289,20 +285,17 @@ def plan(self): for out_slice in self.config.slices: self.plan_slice(out_slice) - models_ = [s.model for s in self.config.slices[-1].sources] - for models, weight_infos in zip_remove_nones( - models_, + _models = [s.model for s in self.config.slices[-1].sources] + for weight_infos in zip( *[ - self.arch_dict[m].info.post_weights( - config=self.out_model_config, overrides=self.arch_dict[m].overrides - ) - for m in models_ + self.arch_dict[m].info.post_weights(config=self.out_model_config) + for m in _models ], ): 
self.plan_tensor( weight_infos[0], list(weight_infos), - models, + _models, ConfigReader( config=self.config, t=1, From c7c46b1bedf5c55a1d2151333a85ac61de4083dd Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 03:40:30 -0500 Subject: [PATCH 20/29] More feedback base corrections --- mergekit/architecture.py | 18 +++++++++++++- mergekit/merge.py | 53 ++++------------------------------------ mergekit/plan.py | 42 +++++++++++++++++++++++-------- 3 files changed, 54 insertions(+), 59 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 685bb30a..8afac67c 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -123,6 +123,16 @@ def has_defined_spaces(self) -> bool: """ return False + @abstractmethod + def _substitute( + self, + item: Union[WeightInfo, ProceduralSpaceInfo], + config: PretrainedConfig, + layer_idx: Optional[int] = None, + ) -> Union[WeightInfo, ProceduralSpaceInfo]: + """Substitute any template variables in the item with values from the config""" + ... + class MappingInfo(BaseModel, frozen=True): """Information about a mapping between two models. 
@@ -146,7 +156,7 @@ class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed def _substitute_name(self, weight: WeightInfo) -> WeightInfo: if self.overrides and weight.name in self.overrides: - return WeightInfo( + return self.info.WeightInfo( name=self.overrides[weight.name], **self.model_dump(exclude=["name"]) ) return weight @@ -181,6 +191,12 @@ def all_weights(self) -> List[WeightInfo]: def update_overrides( self, overrides: Dict[str, str] ) -> "ConfiguredArchitectureInfo": + # NOTE: this makes sure strings in overrides if templates are filled in + overrides = { + self.info._substitute(k, self.config): self.info._substitute(v, self.config) + for k, v in overrides.items() + } + return ConfiguredArchitectureInfo( info=self.info, config=self.config, overrides=overrides ) diff --git a/mergekit/merge.py b/mergekit/merge.py index 4f75e164..368fa245 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -20,14 +20,7 @@ import tqdm import transformers -from mergekit.architecture import ( - JSON_MAPPINGS, # TODO best write a function to get the JSON_MAPPINGS -) -from mergekit.architecture import ( - ArchitectureInfo, - ConfiguredArchitectureInfo, - get_architecture_info, -) +from mergekit.architecture import ArchitectureInfo, get_architecture_info from mergekit.card import generate_card from mergekit.common import ModelReference from mergekit.config import MergeConfiguration @@ -50,51 +43,18 @@ def run_merge( if not merge_config.models and not merge_config.slices: raise RuntimeError("No output requested") - model_arch_info: List[Tuple[ModelReference, ConfiguredArchitectureInfo]] = [ - ( - m, - ConfiguredArchitectureInfo( - info=get_architecture_info( - m.config(trust_remote_code=options.trust_remote_code) - ), - config=m.config(trust_remote_code=options.trust_remote_code), - ), - ) + model_arch_info: List[ArchitectureInfo] = [ + get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) for m in 
merge_config.referenced_models() ] if not (options.allow_crimes or merge_config.align_weights): - if not all( - a[1].info == model_arch_info[0][1].info for a in model_arch_info[1:] - ): + if not all(a == model_arch_info[0] for a in model_arch_info[1:]): raise RuntimeError( "Must specify --allow-crimes to attempt to mix different architectures" ) - arch_info: ArchitectureInfo = model_arch_info[0][1].info - - if merge_config.align_weights: - new_model_arch_info = [model_arch_info[0]] - base_model = model_arch_info[0][1] - for m, destination_arch_info in model_arch_info[1:]: - mapping = None - for arch in base_model.config.architectures: - for destination_arch in destination_arch_info.config.architectures: - mapping = JSON_MAPPINGS.get(arch, {}).get(destination_arch) - if mapping: - break - if mapping: - break - - if mapping: - new_model_arch_info.append( - (m, destination_arch_info.update_overrides(mapping)) - ) - else: - # warn user - new_model_arch_info.append((m, destination_arch_info)) - - model_arch_info = new_model_arch_info + arch_info: ArchitectureInfo = model_arch_info[0] # initialize loader cache and set options loader_cache = LoaderCache() @@ -111,8 +71,6 @@ def run_merge( logging.info("Planning operations") targets = MergePlanner( merge_config, - arch_info, - dict(model_arch_info), out_path=out_path, options=options, out_model_config=cfg_out, @@ -203,7 +161,6 @@ def _model_out_config( res.torch_dtype = config.dtype try: - print(config.slices) num_layers = sum( s.sources[0].layer_range[1] - s.sources[0].layer_range[0] for s in config.slices diff --git a/mergekit/plan.py b/mergekit/plan.py index 0877455e..7c0f7b4a 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -21,9 +21,11 @@ from mergekit import merge_methods from mergekit.architecture import ( + JSON_MAPPINGS, ArchitectureInfo, ConfiguredArchitectureInfo, WeightInfo, + get_architecture_info, ) from mergekit.common import ImmutableMap, ModelReference from mergekit.config import ( @@ -43,8 +45,6 @@ 
class MergePlanner: config: MergeConfiguration - arch_info: ArchitectureInfo - arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo] clone_tensors: bool trust_remote_code: bool out_model_config: Any @@ -58,19 +58,11 @@ class MergePlanner: def __init__( self, config: MergeConfiguration, - arch_info: ArchitectureInfo, - arch_dict: Dict[ - ModelReference, ConfiguredArchitectureInfo - ], # perhaps this should no longer be a disjoint step out_path: str, options: MergeOptions, out_model_config: Any, ): self.config = config - self.arch_info = ( - arch_info # Special because how referenced models list is constructed ? - ) - self.arch_dict = arch_dict self.clone_tensors = options.clone_tensors self.trust_remote_code = options.trust_remote_code self.out_model_config = out_model_config @@ -92,6 +84,36 @@ def __init__( if config.base_model and config.align_weights: self._space_planner = SpacePlanner(config.base_model) + self.arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo] = {} + _models = config.referenced_models() + base_model = _models[0] + base_config = base_model.config(trust_remote_code=options.trust_remote_code) + self.arch_info = get_architecture_info(base_config) + + for m in config.referenced_models(): + m_config = m.config(trust_remote_code=options.trust_remote_code) + configured_arch_info = ConfiguredArchitectureInfo( + info=get_architecture_info(m_config), + config=m.config(m_config), + ) + + if config.align_weights: + mapping = None + for arch in base_config.architectures: + for destination_arch in m_config.architectures: + mapping = JSON_MAPPINGS.get(arch, {}).get(destination_arch) + if mapping: + break + if mapping: + break + + if mapping: + configured_arch_info = configured_arch_info.update_overrides( + mapping + ) + + self.arch_dict[m] = configured_arch_info + def normalize_config(self): # if models to merge are specified instead of output slices, compute them if self.config.models: From 552f8d51c6d873de04a6f372e62d01f1ff2765ba Mon Sep 17 
00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 03:57:14 -0500 Subject: [PATCH 21/29] More corrections --- mergekit/architecture.py | 10 ++++------ mergekit/plan.py | 19 +++++++++---------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 8afac67c..2d8cdde9 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -188,10 +188,8 @@ def procedural_spaces(self) -> List[ProceduralSpaceInfo]: def all_weights(self) -> List[WeightInfo]: return self.info.all_weights(self.config) - def update_overrides( - self, overrides: Dict[str, str] - ) -> "ConfiguredArchitectureInfo": - # NOTE: this makes sure strings in overrides if templates are filled in + def set_overrides(self, overrides: Dict[str, str]) -> "ConfiguredArchitectureInfo": + # NOTE: this makes sure template strings in overrides are filled in overrides = { self.info._substitute(k, self.config): self.info._substitute(v, self.config) for k, v in overrides.items() @@ -253,7 +251,7 @@ def _substitute( ) return type(item).model_validate(obj_dict) - def pre_weights(self, config: PretrainedConfig) -> List[Optional[WeightInfo]]: + def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: weights = [ self._substitute(wi, config=config) for wi in self.definition.pre_weights ] @@ -266,7 +264,7 @@ def layer_weights(self, index: int, config: PretrainedConfig) -> List[WeightInfo for wi in self.definition.layer_templates.weights ] - def post_weights(self, config: PretrainedConfig) -> Optional[WeightInfo]: + def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: weights = [ self._substitute(wi, config=config) for wi in self.definition.post_weights ] diff --git a/mergekit/plan.py b/mergekit/plan.py index 7c0f7b4a..67c282b6 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -99,19 +99,18 @@ def __init__( if config.align_weights: mapping = None - for arch in base_config.architectures: - for destination_arch in 
m_config.architectures: - mapping = JSON_MAPPINGS.get(arch, {}).get(destination_arch) - if mapping: - break + for arch, destination_arch in [ + (arch1, arch2) + for arch1 in base_config.architectures + for arch2 in m_config.architectures + ]: + mapping = JSON_MAPPINGS.get(arch, {}).get(destination_arch) if mapping: + configured_arch_info = configured_arch_info.set_overrides( + mapping + ) break - if mapping: - configured_arch_info = configured_arch_info.update_overrides( - mapping - ) - self.arch_dict[m] = configured_arch_info def normalize_config(self): From 3ebe13d8d66ce159ed07835a9f65d7bd489163e0 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 16:04:40 -0500 Subject: [PATCH 22/29] Code review corrections --- mergekit/architecture.py | 85 +++++++++++++++++++++++++--------------- mergekit/plan.py | 13 +++++- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 2d8cdde9..a85884f6 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -123,16 +123,6 @@ def has_defined_spaces(self) -> bool: """ return False - @abstractmethod - def _substitute( - self, - item: Union[WeightInfo, ProceduralSpaceInfo], - config: PretrainedConfig, - layer_idx: Optional[int] = None, - ) -> Union[WeightInfo, ProceduralSpaceInfo]: - """Substitute any template variables in the item with values from the config""" - ... - class MappingInfo(BaseModel, frozen=True): """Information about a mapping between two models. 
@@ -190,13 +180,34 @@ def all_weights(self) -> List[WeightInfo]: def set_overrides(self, overrides: Dict[str, str]) -> "ConfiguredArchitectureInfo": # NOTE: this makes sure template strings in overrides are filled in - overrides = { - self.info._substitute(k, self.config): self.info._substitute(v, self.config) - for k, v in overrides.items() - } + + def detect_layer_template(s: str) -> bool: + return "{" in s and "layer_index" in s + + new_overrides = {} + + for k, v in overrides.items(): + if detect_layer_template(k): + if not detect_layer_template(v): + raise RuntimeError( + f"Usage of mapping requires one-to-one mapping between architectures. A template was found in {k} but not in {v}" + ) + + for layer_idx in range(self.num_layers()): + new_overrides[ + _template_substitution(k, self.config, layer_idx) + ] = _template_substitution(v, self.config, layer_idx) + elif detect_layer_template(v): + raise RuntimeError( + f"Usage requires one-to-one mapping for {k} and {v}. A template was found in {v} but not in {k}" + ) + else: + new_overrides[ + _template_substitution(k, self.config) + ] = _template_substitution(v, self.config) return ConfiguredArchitectureInfo( - info=self.info, config=self.config, overrides=overrides + info=self.info, config=self.config, overrides=new_overrides ) @@ -219,6 +230,30 @@ class TemplateWithArithmetic(string.Template): idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)" +def _template_substitution( + template: str, num_layers: int, layer_idx: Optional[int] +) -> str: + if "{" not in template: + return template + + substitutions = { + "num_layers": num_layers, + "num_layers+1": num_layers + 1, + "num_layers-1": num_layers - 1, + } + + if layer_idx is not None: + substitutions.update( + { + "layer_index": layer_idx, + "layer_index+1": layer_idx + 1, + "layer_index-1": layer_idx - 1, + } + ) + + return TemplateWithArithmetic(template).substitute(substitutions) + + class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True): definition: 
JSONArchitectureDefinition @@ -229,25 +264,11 @@ def _substitute( layer_idx: Optional[int] = None, ) -> Union[WeightInfo, ProceduralSpaceInfo]: num_layers = self.num_layers(config) - substitutions = { - "num_layers": num_layers, - "num_layers+1": num_layers + 1, - "num_layers-1": num_layers - 1, - } - if layer_idx is not None: - substitutions.update( - { - "layer_index": layer_idx, - "layer_index+1": layer_idx + 1, - "layer_index-1": layer_idx - 1, - } - ) - obj_dict = item.model_dump(mode="json", exclude_unset=True) for key in obj_dict: - if isinstance(obj_dict[key], str) and "{" in obj_dict[key]: - obj_dict[key] = TemplateWithArithmetic(obj_dict[key]).substitute( - substitutions + if isinstance(obj_dict[key], str): + obj_dict[key] = _template_substitution( + obj_dict[key], num_layers, layer_idx ) return type(item).model_validate(obj_dict) diff --git a/mergekit/plan.py b/mergekit/plan.py index 67c282b6..3e5b0fc1 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -44,6 +44,7 @@ class MergePlanner: + arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo] config: MergeConfiguration clone_tensors: bool trust_remote_code: bool @@ -84,7 +85,7 @@ def __init__( if config.base_model and config.align_weights: self._space_planner = SpacePlanner(config.base_model) - self.arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo] = {} + self.arch_dict = {} _models = config.referenced_models() base_model = _models[0] base_config = base_model.config(trust_remote_code=options.trust_remote_code) @@ -99,11 +100,16 @@ def __init__( if config.align_weights: mapping = None + is_same_arch = False + for arch, destination_arch in [ (arch1, arch2) for arch1 in base_config.architectures for arch2 in m_config.architectures ]: + if arch == destination_arch: + is_same_arch = True + break mapping = JSON_MAPPINGS.get(arch, {}).get(destination_arch) if mapping: configured_arch_info = configured_arch_info.set_overrides( @@ -111,6 +117,11 @@ def __init__( ) break + if not (mapping or 
is_same_arch): + raise RuntimeError( + "For cross architecture merges, a mapping is required!!!" + ) + self.arch_dict[m] = configured_arch_info def normalize_config(self): From cf6a77e6a1c305444fb65a665f1d4741f0f10fc0 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 16:28:37 -0500 Subject: [PATCH 23/29] Type annotation and better error_message --- mergekit/plan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mergekit/plan.py b/mergekit/plan.py index 3e5b0fc1..0bf140fb 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -45,6 +45,7 @@ class MergePlanner: arch_dict: Dict[ModelReference, ConfiguredArchitectureInfo] + arch_info: ArchitectureInfo config: MergeConfiguration clone_tensors: bool trust_remote_code: bool @@ -119,7 +120,7 @@ def __init__( if not (mapping or is_same_arch): raise RuntimeError( - "For cross architecture merges, a mapping is required!!!" + f"For cross architecture merge between architectures {base_config.architectures[0]} and {m_config.architectures[0]}, a mapping must be provided" ) self.arch_dict[m] = configured_arch_info From 4d6a9ba42cb53991db60a842c176ea1c1f48db9d Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 16:39:41 -0500 Subject: [PATCH 24/29] Better error message --- mergekit/architecture.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index a85884f6..c9fbfa27 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -190,7 +190,7 @@ def detect_layer_template(s: str) -> bool: if detect_layer_template(k): if not detect_layer_template(v): raise RuntimeError( - f"Usage of mapping requires one-to-one mapping between architectures. A template was found in {k} but not in {v}" + f"Cross-architecture merging requires one-to-one mapping between architectures. 
A template was found in {k} but not in {v}" ) for layer_idx in range(self.num_layers()): @@ -199,7 +199,7 @@ def detect_layer_template(s: str) -> bool: ] = _template_substitution(v, self.config, layer_idx) elif detect_layer_template(v): raise RuntimeError( - f"Usage requires one-to-one mapping for {k} and {v}. A template was found in {v} but not in {k}" + f"Cross-architecture merging requires one-to-one mapping between architectures. A template was found in {v} but not in {k}" ) else: new_overrides[ From a7b52da1e589b06a23a9ac5a371266670b495c55 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 16:44:06 -0500 Subject: [PATCH 25/29] Better error message --- mergekit/plan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mergekit/plan.py b/mergekit/plan.py index 0bf140fb..e379e82a 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -120,7 +120,7 @@ def __init__( if not (mapping or is_same_arch): raise RuntimeError( - f"For cross architecture merge between architectures {base_config.architectures[0]} and {m_config.architectures[0]}, a mapping must be provided" + f"For cross architecture merge between architectures {base_config.architectures[0]} and {m_config.architectures[0]}, a mapping must be provided in the directory _data/mappings" ) self.arch_dict[m] = configured_arch_info From f86a2dab5eec2fc84327c42d4c4546f85b425d7f Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 17:47:35 -0500 Subject: [PATCH 26/29] Add basic failing test for testing set_override --- tests/test_architecture.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/test_architecture.py diff --git a/tests/test_architecture.py b/tests/test_architecture.py new file mode 100644 index 00000000..11f88845 --- /dev/null +++ b/tests/test_architecture.py @@ -0,0 +1,16 @@ +import pytest +from transformers import LlamaConfig + +from mergekit.architecture import ConfiguredArchitectureInfo, get_architecture_info + + +class 
TestArchitecture: + def test_set_overrides(self): + cfg = LlamaConfig(vocab_size=64, hidden_size=32) + arch_info = get_architecture_info(cfg) + configured_arch_info = ConfiguredArchitectureInfo(arch_info, cfg) + + overrides = {"a_{layer_index}": "b_layer_{layer_index}"} + new_config = configured_arch_info.set_overrides(overrides) + + assert new_config.overrides == overrides From 2eaacebe33801f61d31bbfa0ec165fb7c74ace80 Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 18:15:35 -0500 Subject: [PATCH 27/29] Test based corrections --- mergekit/architecture.py | 16 +++++++++------- tests/test_architecture.py | 8 ++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mergekit/architecture.py b/mergekit/architecture.py index c9fbfa27..4a9116bb 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -182,10 +182,12 @@ def set_overrides(self, overrides: Dict[str, str]) -> "ConfiguredArchitectureInf # NOTE: this makes sure template strings in overrides are filled in def detect_layer_template(s: str) -> bool: - return "{" in s and "layer_index" in s + return "${" in s and "layer_index" in s new_overrides = {} + num_layers = self.num_layers() + for k, v in overrides.items(): if detect_layer_template(k): if not detect_layer_template(v): @@ -193,18 +195,18 @@ def detect_layer_template(s: str) -> bool: f"Cross-architecture merging requires one-to-one mapping between architectures. A template was found in {k} but not in {v}" ) - for layer_idx in range(self.num_layers()): + for layer_idx in range(num_layers): new_overrides[ - _template_substitution(k, self.config, layer_idx) - ] = _template_substitution(v, self.config, layer_idx) + _template_substitution(k, num_layers, layer_idx) + ] = _template_substitution(v, num_layers, layer_idx) elif detect_layer_template(v): raise RuntimeError( f"Cross-architecture merging requires one-to-one mapping between architectures. 
A template was found in {v} but not in {k}" ) else: new_overrides[ - _template_substitution(k, self.config) - ] = _template_substitution(v, self.config) + _template_substitution(k, num_layers) + ] = _template_substitution(v, num_layers) return ConfiguredArchitectureInfo( info=self.info, config=self.config, overrides=new_overrides @@ -231,7 +233,7 @@ class TemplateWithArithmetic(string.Template): def _template_substitution( - template: str, num_layers: int, layer_idx: Optional[int] + template: str, num_layers: int, layer_idx: Optional[int] = None ) -> str: if "{" not in template: return template diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 11f88845..d70b46d9 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -1,16 +1,16 @@ import pytest -from transformers import LlamaConfig +from transformers import AutoConfig, LlamaConfig from mergekit.architecture import ConfiguredArchitectureInfo, get_architecture_info class TestArchitecture: def test_set_overrides(self): - cfg = LlamaConfig(vocab_size=64, hidden_size=32) + cfg = AutoConfig.from_pretrained("gpt-2") arch_info = get_architecture_info(cfg) configured_arch_info = ConfiguredArchitectureInfo(arch_info, cfg) - overrides = {"a_{layer_index}": "b_layer_{layer_index}"} + overrides = {"a_${layer_index}": "b_${layer_index}"} new_config = configured_arch_info.set_overrides(overrides) - assert new_config.overrides == overrides + assert new_config.overrides == {f"a_{i}": f"b_{i}" for i in range(12)} From 77ade329a595d7156b5aaa72467273d2cbeb7d3a Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 18:19:19 -0500 Subject: [PATCH 28/29] Correct test model name --- tests/test_architecture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_architecture.py b/tests/test_architecture.py index d70b46d9..1797dffd 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -6,7 +6,7 @@ class TestArchitecture: def 
test_set_overrides(self): - cfg = AutoConfig.from_pretrained("gpt-2") + cfg = AutoConfig.from_pretrained("gpt2") arch_info = get_architecture_info(cfg) configured_arch_info = ConfiguredArchitectureInfo(arch_info, cfg) From 40725e53cbc80db4f56dc6b0a2c47f32959df7ee Mon Sep 17 00:00:00 2001 From: Luke Meyers Date: Wed, 21 Feb 2024 18:24:44 -0500 Subject: [PATCH 29/29] Correct mistake --- tests/test_architecture.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 1797dffd..6b97e5bb 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -8,7 +8,7 @@ class TestArchitecture: def test_set_overrides(self): cfg = AutoConfig.from_pretrained("gpt2") arch_info = get_architecture_info(cfg) - configured_arch_info = ConfiguredArchitectureInfo(arch_info, cfg) + configured_arch_info = ConfiguredArchitectureInfo(info=arch_info, config=cfg) overrides = {"a_${layer_index}": "b_${layer_index}"} new_config = configured_arch_info.set_overrides(overrides)