Inter-architecture weights mapping #165

Merged · 29 commits · Feb 22, 2024
Changes from all commits
f18c0c3
WIP: most likely broken code to open up multiple architectures when p…
metric-space Feb 11, 2024
238419c
WIP: note and tiny correction
metric-space Feb 11, 2024
5c6de5b
WIP: corrections
metric-space Feb 12, 2024
cc4da48
Appeasing pre-commit
metric-space Feb 13, 2024
4d3b7f6
Add dependency scipy
metric-space Feb 13, 2024
81b2c14
Correction
metric-space Feb 13, 2024
86772ed
Attempt to compress code
metric-space Feb 13, 2024
0f4b44c
Correction
metric-space Feb 13, 2024
0dc55a4
Bypass mixed architecture user warning/error if align_weights is enabled
metric-space Feb 13, 2024
0e3d4cb
WIP: Draft of correction
metric-space Feb 15, 2024
e5134f4
WIP: Example mapping
metric-space Feb 15, 2024
fa87d74
WIP: Code for handling unmapped pre/post weights
metric-space Feb 16, 2024
0d38791
Add missing __init__
metric-space Feb 16, 2024
36f7639
Fix missing submodule for mappings
metric-space Feb 16, 2024
2536033
Spelling correction
metric-space Feb 16, 2024
3f552a1
Corrections post test run
metric-space Feb 17, 2024
4f179d7
Correction
metric-space Feb 20, 2024
d2f649f
Correction
metric-space Feb 20, 2024
a79a75a
Charles' review changes
metric-space Feb 21, 2024
c7c46b1
More feedback-based corrections
metric-space Feb 21, 2024
552f8d5
More corrections
metric-space Feb 21, 2024
3ebe13d
Code review corrections
metric-space Feb 21, 2024
cf6a77e
Type annotation and better error_message
metric-space Feb 21, 2024
4d6a9ba
Better error message
metric-space Feb 21, 2024
a7b52da
Better error message
metric-space Feb 21, 2024
f86a2da
Add basic failing test for testing set_override
metric-space Feb 21, 2024
2eaaceb
Test based corrections
metric-space Feb 21, 2024
77ade32
Correct test model name
metric-space Feb 21, 2024
40725e5
Correct mistake
metric-space Feb 21, 2024
Empty file.
9 changes: 9 additions & 0 deletions mergekit/_data/mappings/a_b.json
@@ -0,0 +1,9 @@
{
"start_architectures": ["A"],
"destination_architectures": ["B"],
"weights_mapping":{
"a" : "d",
"b" : "e",
"c" : "m"
}
}
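For illustration, here is a minimal sketch of how a mapping file like the placeholder a_b.json above gets parsed into the MappingInfo model introduced further down in mergekit/architecture.py (the architecture and weight names are just the placeholder entries from this file):

import importlib.resources

import mergekit._data.mappings
from mergekit.architecture import MappingInfo

# Read the bundled placeholder mapping and validate it, mirroring _load_arch_mappings() below.
text = importlib.resources.read_text(mergekit._data.mappings, "a_b.json")
mapping = MappingInfo.model_validate_json(text)
assert mapping.weights_mapping == {"a": "d", "b": "e", "c": "m"}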
155 changes: 130 additions & 25 deletions mergekit/architecture.py
@@ -23,6 +23,7 @@
from typing_extensions import Literal

import mergekit._data.architectures
import mergekit._data.mappings


class WeightInfo(BaseModel, frozen=True):
@@ -123,28 +124,94 @@ def has_defined_spaces(self) -> bool:
return False


class MappingInfo(BaseModel, frozen=True):
"""Information about a mapping between two models.

Attributes:
from_model (str):
The name of the model from which the mapping originates.
to_model (str):
The name of the model to which the mapping applies.
"""

start_architectures: List[str]
destination_architectures: List[str]
weights_mapping: Dict[str, str]


class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True):
info: ArchitectureInfo
config: PretrainedConfig
overrides: Optional[Dict[str, str]] = None

def _substitute_name(self, weight: WeightInfo) -> WeightInfo:
if self.overrides and weight.name in self.overrides:
return WeightInfo(
name=self.overrides[weight.name], **weight.model_dump(exclude={"name"})
)
return weight

def num_layers(self) -> int:
return self.info.num_layers(self.config)

def pre_weights(self) -> List[WeightInfo]:
return self.info.pre_weights(self.config)
return [
self._substitute_name(wi)
for wi in self.info.pre_weights(self.config)
]

def post_weights(self) -> List[WeightInfo]:
return self.info.post_weights(self.config)
return [
self._substitute_name(wi)
for wi in self.info.post_weights(self.config)
]

def layer_weights(self, index: int) -> List[WeightInfo]:
return self.info.layer_weights(index, self.config)
return [
self._substitute_name(wi)
for wi in self.info.layer_weights(index, self.config)
]

def procedural_spaces(self) -> List[ProceduralSpaceInfo]:
return self.info.procedural_spaces(self.config)

def all_weights(self) -> List[WeightInfo]:
return self.info.all_weights(self.config)

def set_overrides(self, overrides: Dict[str, str]) -> "ConfiguredArchitectureInfo":
# NOTE: this makes sure template strings in overrides are filled in

def detect_layer_template(s: str) -> bool:
return "${" in s and "layer_index" in s

new_overrides = {}

num_layers = self.num_layers()

for k, v in overrides.items():
if detect_layer_template(k):
if not detect_layer_template(v):
raise RuntimeError(
f"Cross-architecture merging requires one-to-one mapping between architectures. A template was found in {k} but not in {v}"
)

for layer_idx in range(num_layers):
new_overrides[
_template_substitution(k, num_layers, layer_idx)
] = _template_substitution(v, num_layers, layer_idx)
elif detect_layer_template(v):
raise RuntimeError(
f"Cross-architecture merging requires one-to-one mapping between architectures. A template was found in {v} but not in {k}"
)
else:
new_overrides[
_template_substitution(k, num_layers)
] = _template_substitution(v, num_layers)

return ConfiguredArchitectureInfo(
info=self.info, config=self.config, overrides=new_overrides
)
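# Illustrative expansion (not part of this change): with num_layers == 2, a
# templated override such as
#   {"model.layers.${layer_index}.q.weight": "model.layers.${layer_index}.wq.weight"}
# is turned by set_overrides into one concrete entry per layer:
#   {"model.layers.0.q.weight": "model.layers.0.wq.weight",
#    "model.layers.1.q.weight": "model.layers.1.wq.weight"}
# The weight names above are placeholders, not taken from a real architecture.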


class JSONLayerTemplates(BaseModel, frozen=True):
weights: List[WeightInfo]
@@ -165,6 +232,30 @@ class TemplateWithArithmetic(string.Template):
idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)"


def _template_substitution(
template: str, num_layers: int, layer_idx: Optional[int] = None
) -> str:
if "{" not in template:
return template

substitutions = {
"num_layers": num_layers,
"num_layers+1": num_layers + 1,
"num_layers-1": num_layers - 1,
}

if layer_idx is not None:
substitutions.update(
{
"layer_index": layer_idx,
"layer_index+1": layer_idx + 1,
"layer_index-1": layer_idx - 1,
}
)

return TemplateWithArithmetic(template).substitute(substitutions)
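# Illustrative examples (not part of this change); ${layer_index} and arithmetic
# forms such as ${num_layers-1} are resolved by _template_substitution:
assert _template_substitution("model.layers.${layer_index}.mlp.up_proj.weight", 32, 3) == "model.layers.3.mlp.up_proj.weight"
assert _template_substitution("model.layers.${num_layers-1}.norm.weight", 32) == "model.layers.31.norm.weight"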


class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
definition: JSONArchitectureDefinition

@@ -175,46 +266,34 @@ def _substitute(
layer_idx: Optional[int] = None,
) -> Union[WeightInfo, ProceduralSpaceInfo]:
num_layers = self.num_layers(config)
substitutions = {
"num_layers": num_layers,
"num_layers+1": num_layers + 1,
"num_layers-1": num_layers - 1,
}
if layer_idx is not None:
substitutions.update(
{
"layer_index": layer_idx,
"layer_index+1": layer_idx + 1,
"layer_index-1": layer_idx - 1,
}
)

obj_dict = item.model_dump(mode="json", exclude_unset=True)
for key in obj_dict:
if isinstance(obj_dict[key], str) and "{" in obj_dict[key]:
obj_dict[key] = TemplateWithArithmetic(obj_dict[key]).substitute(
substitutions
if isinstance(obj_dict[key], str):
obj_dict[key] = _template_substitution(
obj_dict[key], num_layers, layer_idx
)
return type(item).model_validate(obj_dict)

def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
return [
weights = [
self._substitute(wi, config=config) for wi in self.definition.pre_weights
]

def layer_weights(
self, index: int, config: PretrainedConfig
) -> Optional[List[WeightInfo]]:
return weights

def layer_weights(self, index: int, config: PretrainedConfig) -> List[WeightInfo]:
return [
self._substitute(wi, config=config, layer_idx=index)
for wi in self.definition.layer_templates.weights
]

def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
return [
weights = [
self._substitute(wi, config=config) for wi in self.definition.post_weights
]

return weights

def sliceable(self) -> bool:
return True

@@ -337,3 +416,29 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
raise RuntimeError(
f"Unsupported model_type {config.model_type} for architecture {arch_name}"
)


def _load_arch_mappings(name) -> MappingInfo:
text = importlib.resources.read_text(mergekit._data.mappings, name)
return MappingInfo.model_validate_json(text)


def _load_all_mappings() -> Dict[str, Dict[str, Dict[str, str]]]:
mappings: Dict[str, Dict[str, Dict[str, str]]] = {}
for f in importlib.resources.contents(mergekit._data.mappings):
if f.lower().endswith(".json"):
mapping = _load_arch_mappings(f)
for start_architecture in mapping.start_architectures:
if start_architecture not in mappings:
mappings[start_architecture] = {}

for destination_architecture in mapping.destination_architectures:
if destination_architecture not in mappings[start_architecture]:
mappings[start_architecture][
destination_architecture
] = mapping.weights_mapping

return mappings


JSON_MAPPINGS = _load_all_mappings()
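As a usage sketch (the architecture names are the placeholders from the a_b.json example above, not real model types), the resulting structure is keyed by start architecture and then destination architecture:

from mergekit.architecture import JSON_MAPPINGS

# Placeholder lookup: for "A" -> "B" this yields {"a": "d", "b": "e", "c": "m"}.
rename = JSON_MAPPINGS.get("A", {}).get("B", {})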
15 changes: 7 additions & 8 deletions mergekit/merge.py
@@ -15,13 +15,14 @@

import logging
import os
from typing import Optional
from typing import Dict, List, Optional, Tuple

import tqdm
import transformers

from mergekit.architecture import ArchitectureInfo, get_architecture_info
from mergekit.card import generate_card
from mergekit.common import ModelReference
from mergekit.config import MergeConfiguration
from mergekit.graph import Executor
from mergekit.io.tasks import LoaderCache
@@ -42,16 +43,18 @@ def run_merge(
if not merge_config.models and not merge_config.slices:
raise RuntimeError("No output requested")

model_arch_info = [
model_arch_info: List[ArchitectureInfo] = [
get_architecture_info(m.config(trust_remote_code=options.trust_remote_code))
for m in merge_config.referenced_models()
]
if not options.allow_crimes:

if not (options.allow_crimes or merge_config.align_weights):
if not all(a == model_arch_info[0] for a in model_arch_info[1:]):
raise RuntimeError(
"Must specify --allow-crimes to attempt to mix different architectures"
)
arch_info = model_arch_info[0]

arch_info: ArchitectureInfo = model_arch_info[0]

# initialize loader cache and set options
loader_cache = LoaderCache()
@@ -68,7 +71,6 @@
logging.info("Planning operations")
targets = MergePlanner(
merge_config,
arch_info,
out_path=out_path,
options=options,
out_model_config=cfg_out,
@@ -184,6 +186,3 @@ def _update_config_vocab(
"Unable to set vocabulary size in output config - you may need to manually correct it.",
exc_info=e,
)


__all__ = ["MergeOptions", "run_merge"]