diff --git a/mergekit/_data/architectures/gemma3.json b/mergekit/_data/architectures/gemma3.json new file mode 100644 index 00000000..787fd041 --- /dev/null +++ b/mergekit/_data/architectures/gemma3.json @@ -0,0 +1,69 @@ +{ + "model_type": "gemma3_text", + "architectures": [ + "Gemma3ForCausalLM" + ], + "pre_weights": [ + { + "name": "model.embed_tokens.weight", + "is_embed": true + } + ], + "num_layers_config_key": "num_hidden_layers", + "layer_templates": { + "weights": [ + { + "name": "model.layers.${layer_index}.input_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_norm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_norm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.o_proj.weight" + }, + { + "name": "model.layers.${layer_index}.post_attention_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.up_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.gate_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.down_proj.weight" + }, + { + "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight" + } + ] + }, + "post_weights": [ + { + "name": "model.norm.weight" + }, + { + "name": "lm_head.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] + } + ] +} diff --git a/mergekit/_data/architectures/gemma3vl.json b/mergekit/_data/architectures/gemma3vl.json new file mode 100644 index 00000000..cb45f7e2 --- /dev/null +++ b/mergekit/_data/architectures/gemma3vl.json @@ -0,0 +1,184 @@ +{ + "kind": "modular", + "architectures": [ + "Gemma3ForConditionalGeneration" + ], + "model_type": "gemma3", + "tagalong_files": [ + "preprocessor_config.json", + "processor_config.json" + ], + "modules": { + "text_decoder": { + "weight_prefix": "language_model.", + "architecture": { + "model_type": "gemma3_text", + "architectures": [ + "Gemma3ForCausalLM" + ], + "pre_weights": [ + { + "name": "model.embed_tokens.weight", + "is_embed": true + } + ], + "num_layers_config_key": "text_config.num_hidden_layers", + "layer_templates": { + "weights": [ + { + "name": "model.layers.${layer_index}.input_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.q_norm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.k_norm.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "model.layers.${layer_index}.self_attn.o_proj.weight" + }, + { + "name": "model.layers.${layer_index}.post_attention_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.up_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.gate_proj.weight" + }, + { + "name": "model.layers.${layer_index}.mlp.down_proj.weight" + }, + { + "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight" + } + ] + }, + "post_weights": [ + { + "name": "model.norm.weight" + }, + { + "name": "lm_head.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + 
"model.embed_tokens.weight" + ] + } + ] + } + }, + "multi_modal_projector": { + "weight_prefix": "multi_modal_projector.", + "architecture": { + "model_type": "gemma3_mmproj", + "architectures": [], + "pre_weights": [ + { + "name": "mm_input_projection_weight" + }, + { + "name": "mm_soft_emb_norm.weight" + } + ], + "post_weights": [], + "layer_templates": { + "weights": [] + }, + "override_num_layers": 0 + } + }, + "vision_tower": { + "weight_prefix": "vision_tower.vision_model.", + "architecture": { + "model_type": "siglip_vision_model", + "architectures": [], + "pre_weights": [ + { + "name": "embeddings.patch_embedding.bias" + }, + { + "name": "embeddings.patch_embedding.weight" + }, + { + "name": "embeddings.position_embedding.weight" + } + ], + "post_weights": [ + { + "name": "post_layernorm.bias" + }, + { + "name": "post_layernorm.weight" + } + ], + "layer_templates": { + "weights": [ + { + "name": "encoder.layers.${layer_index}.layer_norm1.bias" + }, + { + "name": "encoder.layers.${layer_index}.layer_norm1.weight" + }, + { + "name": "encoder.layers.${layer_index}.layer_norm2.bias" + }, + { + "name": "encoder.layers.${layer_index}.layer_norm2.weight" + }, + { + "name": "encoder.layers.${layer_index}.mlp.fc1.bias" + }, + { + "name": "encoder.layers.${layer_index}.mlp.fc1.weight" + }, + { + "name": "encoder.layers.${layer_index}.mlp.fc2.bias" + }, + { + "name": "encoder.layers.${layer_index}.mlp.fc2.weight" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.k_proj.bias" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.out_proj.bias" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.out_proj.weight" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.q_proj.bias" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.v_proj.bias" + }, + { + "name": "encoder.layers.${layer_index}.self_attn.v_proj.weight" + } + ] + }, + "num_layers_config_key": "vision_config.num_hidden_layers" + } + } + } +} diff --git a/mergekit/_data/architectures/t5.json b/mergekit/_data/architectures/t5.json new file mode 100644 index 00000000..9ed8a7b8 --- /dev/null +++ b/mergekit/_data/architectures/t5.json @@ -0,0 +1,170 @@ +{ + "kind": "modular", + "architectures": [ + "T5ForConditionalGeneration" + ], + "model_type": "t5", + "modules": { + "decoder": { + "architecture": { + "model_type": "", + "architectures": [], + "pre_weights": [ + { + "name": "decoder.embed_tokens.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "shared.weight", + "lm_head.weight", + "encoder.embed_tokens.weight" + ] + } + ], + "num_layers_config_key": "num_decoder_layers", + "layer_templates": { + "weights": [ + { + "name": "decoder.block.${layer_index}.layer.0.layer_norm.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.0.SelfAttention.q.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.0.SelfAttention.k.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.0.SelfAttention.v.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.0.SelfAttention.o.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.0.SelfAttention.relative_attention_bias.weight", + "optional": true + }, + { + "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.q.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.k.weight" + }, + { + "name": 
"decoder.block.${layer_index}.layer.1.EncDecAttention.v.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.o.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.1.layer_norm.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.2.DenseReluDense.wi_0.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.2.DenseReluDense.wi_1.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.2.DenseReluDense.wo.weight" + }, + { + "name": "decoder.block.${layer_index}.layer.2.layer_norm.weight" + } + ] + }, + "post_weights": [ + { + "name": "decoder.final_layer_norm.weight" + } + ] + } + }, + "encoder": { + "architecture": { + "model_type": "", + "architectures": [], + "pre_weights": [ + { + "name": "encoder.embed_tokens.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "shared.weight", + "lm_head.weight", + "decoder.embed_tokens.weight" + ] + } + ], + "num_layers_config_key": "num_hidden_layers", + "layer_templates": { + "weights": [ + { + "name": "encoder.block.${layer_index}.layer.0.layer_norm.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.0.SelfAttention.q.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.0.SelfAttention.k.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.0.SelfAttention.v.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.0.SelfAttention.o.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.0.SelfAttention.relative_attention_bias.weight", + "optional": true + }, + { + "name": "encoder.block.${layer_index}.layer.1.DenseReluDense.wi_0.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.1.DenseReluDense.wi_1.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.1.DenseReluDense.wo.weight" + }, + { + "name": "encoder.block.${layer_index}.layer.1.layer_norm.weight" + } + ] + }, + "post_weights": [ + { + "name": "encoder.final_layer_norm.weight" + } + ] + } + }, + "shared": { + "architecture": { + "model_type": "", + "architectures": [], + "pre_weights": [ + { + "name": "shared.weight", + "is_embed": true + } + ], + "layer_templates": { + "weights": [] + }, + "post_weights": [ + { + "name": "lm_head.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "shared.weight", + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight" + ] + } + ], + "override_num_layers": 0 + } + } + } +} diff --git a/mergekit/_data/architectures/whisper.json b/mergekit/_data/architectures/whisper.json new file mode 100644 index 00000000..82ac444b --- /dev/null +++ b/mergekit/_data/architectures/whisper.json @@ -0,0 +1,196 @@ +{ + "kind": "modular", + "architectures": [ + "WhisperForConditionalGeneration" + ], + "model_type": "whisper", + "tagalong_files": [ + "preprocessor_config.json", + "normalizer.json" + ], + "modules": { + "decoder": { + "weight_prefix": "model.decoder", + "architecture": { + "model_type": "", + "architectures": [], + "pre_weights": [ + { + "name": "embed_tokens.weight", + "is_embed": true + }, + { + "name": "embed_positions.weight" + } + ], + "num_layers_config_key": "decoder_layers", + "layer_templates": { + "weights": [ + { + "name": "layers.${layer_index}.encoder_attn.k_proj.weight" + }, + { + "name": "layers.${layer_index}.encoder_attn.out_proj.bias" + }, + { + "name": "layers.${layer_index}.encoder_attn.out_proj.weight" + }, + { + "name": "layers.${layer_index}.encoder_attn.q_proj.bias" + }, + { + "name": "layers.${layer_index}.encoder_attn.q_proj.weight" + }, + { + "name": 
"layers.${layer_index}.encoder_attn.v_proj.bias" + }, + { + "name": "layers.${layer_index}.encoder_attn.v_proj.weight" + }, + { + "name": "layers.${layer_index}.encoder_attn_layer_norm.bias" + }, + { + "name": "layers.${layer_index}.encoder_attn_layer_norm.weight" + }, + { + "name": "layers.${layer_index}.fc1.bias" + }, + { + "name": "layers.${layer_index}.fc1.weight" + }, + { + "name": "layers.${layer_index}.fc2.bias" + }, + { + "name": "layers.${layer_index}.fc2.weight" + }, + { + "name": "layers.${layer_index}.final_layer_norm.bias" + }, + { + "name": "layers.${layer_index}.final_layer_norm.weight" + }, + { + "name": "layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn.out_proj.bias" + }, + { + "name": "layers.${layer_index}.self_attn.out_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn.q_proj.bias" + }, + { + "name": "layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn.v_proj.bias" + }, + { + "name": "layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn_layer_norm.bias" + }, + { + "name": "layers.${layer_index}.self_attn_layer_norm.weight" + } + ] + }, + "post_weights": [ + { + "name": "layer_norm.bias" + }, + { + "name": "layer_norm.weight" + } + ] + } + }, + "encoder": { + "weight_prefix": "model.encoder.", + "architecture": { + "model_type": "", + "architectures": [], + "pre_weights": [ + { + "name": "embed_positions.weight" + }, + { + "name": "conv1.bias" + }, + { + "name": "conv1.weight" + }, + { + "name": "conv2.bias" + }, + { + "name": "conv2.weight" + } + ], + "post_weights": [ + { + "name": "layer_norm.bias" + }, + { + "name": "layer_norm.weight" + } + ], + "layer_templates": { + "weights": [ + { + "name": "layers.${layer_index}.fc1.bias" + }, + { + "name": "layers.${layer_index}.fc1.weight" + }, + { + "name": "layers.${layer_index}.fc2.bias" + }, + { + "name": "layers.${layer_index}.fc2.weight" + }, + { + "name": "layers.${layer_index}.final_layer_norm.bias" + }, + { + "name": "layers.${layer_index}.final_layer_norm.weight" + }, + { + "name": "layers.${layer_index}.self_attn.k_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn.out_proj.bias" + }, + { + "name": "layers.${layer_index}.self_attn.out_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn.q_proj.bias" + }, + { + "name": "layers.${layer_index}.self_attn.q_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn.v_proj.bias" + }, + { + "name": "layers.${layer_index}.self_attn.v_proj.weight" + }, + { + "name": "layers.${layer_index}.self_attn_layer_norm.bias" + }, + { + "name": "layers.${layer_index}.self_attn_layer_norm.weight" + } + ] + }, + "num_layers_config_key": "encoder_layers" + } + } + } +} diff --git a/mergekit/architecture.py b/mergekit/architecture.py deleted file mode 100644 index 49840b73..00000000 --- a/mergekit/architecture.py +++ /dev/null @@ -1,779 +0,0 @@ -# Copyright (C) 2025 Arcee AI -# SPDX-License-Identifier: BUSL-1.1 - -import importlib.resources -import logging -import re -import string -import warnings -from abc import ABC, abstractmethod -from collections import defaultdict -from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Tuple, Union - -from huggingface_hub import snapshot_download -from pydantic import BaseModel, Field -from transformers import PretrainedConfig -from typing_extensions import Literal - -import mergekit._data.architectures -from mergekit.io.lazy_tensor_loader 
import ShardedTensorIndex - - -class WeightInfo(BaseModel, frozen=True): - """Information about an individual weight tensor in a model. - - Attributes: - name (str): - The name of the tensor representing the weight. - is_embed (bool): - Indicates whether the weight is for an embedding or language model head. - input_space (Optional[str]): - The name of the input space associated with the weight, if applicable. - output_space (Optional[str]): - The name of the output space associated with the weight, if applicable. - optional (bool): - Indicates whether the weight can be omitted from a model. - aliases (Optional[List[str]]): - List of alternative names for the weight, if applicable. - tied_names (Optional[List[str]]): - List of names for weights that are tied to this weight, if applicable. - force_dtype (Optional[str]): - Mandatory dtype for the weight, if applicable. - """ - - name: str - is_embed: bool = False - input_space: Optional[str] = None - output_space: Optional[str] = None - optional: bool = False - tied: bool = False - aliases: Optional[Tuple[str, ...]] = None - tied_names: Optional[Tuple[str, ...]] = None - force_dtype: Optional[str] = None - head_split: Literal[None, "input", "output"] = None - is_kq: Optional[bool] = False - - -class ProceduralSpaceInfo(BaseModel, frozen=True): - """Defines a procedural space computed from one or more other spaces. - - Currently only supports residual connections. - - Attributes: - name (str): The name of the space defined. - type (str): The type of procedural space. - inputs (List[str]): List of names of spaces used to define this space.""" - - name: str - type: Literal["residual"] - inputs: List[str] - - -class ArchitectureInfo(ABC): - @abstractmethod - def name(self) -> str: - """Return the name of the architecture.""" - ... - - @abstractmethod - def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """Return a list of all weights preceding the first layer.""" - ... - - @abstractmethod - def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """Return a list of all weights following the final layer.""" - ... - - @abstractmethod - def layer_weights( - self, index: int, config: PretrainedConfig - ) -> Optional[List[WeightInfo]]: - """Return a list of all weights associated with a given layer.""" - ... - - @abstractmethod - def sliceable(self) -> bool: - """ - Return True if the layers of this architecture can be meaningfully sliced. - """ - ... - - def num_layers_config_key(self) -> str: - """Key in config that represents number of layers""" - return "num_hidden_layers" - - def num_layers(self, config: PretrainedConfig) -> int: - """Return the number of layers in a model.""" - return getattr(config, self.num_layers_config_key()) - - def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """Return all weights associated with a model.""" - num_layers = self.num_layers(config) - res = list(self.pre_weights(config)) - for layer_idx in range(num_layers): - res.extend(self.layer_weights(layer_idx, config)) - res.extend(self.post_weights(config)) - return res - - def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]: - """Return a list of all procedurally defined spaces in a model.""" - return [] - - def has_defined_spaces(self) -> bool: - """ - Return True if this architecture defines space information needed for - matching-based merge methods. 
- """ - return False - - -class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): - info: ArchitectureInfo - config: PretrainedConfig - - def name(self) -> str: - return self.info.name() - - def num_layers(self) -> int: - return self.info.num_layers(self.config) - - def pre_weights(self) -> List[WeightInfo]: - return self.info.pre_weights(self.config) - - def post_weights(self) -> List[WeightInfo]: - return self.info.post_weights(self.config) - - def layer_weights(self, index: int) -> List[WeightInfo]: - return self.info.layer_weights(index, self.config) - - def procedural_spaces(self) -> List[ProceduralSpaceInfo]: - return self.info.procedural_spaces(self.config) - - def all_weights(self) -> List[WeightInfo]: - return self.info.all_weights(self.config) - - -class JSONLayerTemplates(BaseModel, frozen=True): - weights: List[WeightInfo] - procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None - - -class JSONArchitectureDefinition(BaseModel, frozen=True): - expected_model_type: str = Field(alias="model_type") - architectures: List[str] - pre_weights: List[WeightInfo] - layer_templates: JSONLayerTemplates - post_weights: List[WeightInfo] - procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None - num_layers_config_key: Optional[str] = None - - -class TemplateWithArithmetic(string.Template): - idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)" - - -def _template_substitution( - template: str, num_layers: int, layer_idx: Optional[int] = None -) -> str: - if "{" not in template: - return template - - substitutions = { - "num_layers": num_layers, - "num_layers+1": num_layers + 1, - "num_layers-1": num_layers - 1, - } - - if layer_idx is not None: - substitutions.update( - { - "layer_index": layer_idx, - "layer_index+1": layer_idx + 1, - "layer_index-1": layer_idx - 1, - } - ) - - return TemplateWithArithmetic(template).substitute(substitutions) - - -def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]: - hierarchy = defaultdict(list) - - # Regular expression to match layers (denoted by .{integer}. 
by default) - layer_pattern = re.compile(layer_prefix) - - if names: - for name in names: - # Find the layer part of the string (e.g., 'model.layers.0.') - match = layer_pattern.search(name) - if match: - # Extract everything up to the layer identifier - layer_prefix = name[: match.end() - 1] # e.g., 'model.layers.0' - # Extract the parameter name after the layer identifier - param_name = name[match.end() :] # e.g., 'input_layernorm.weight' - # Add the parameter name to the corresponding layer in the hierarchy - hierarchy[layer_prefix].append(param_name) - else: - hierarchy[name].append("") - - return hierarchy - - -class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel): - arch_name: str = Field(default="") - parameter_names: List[str] = Field(default_factory=list) - embed: List[str] = Field(default_factory=list) - layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict) - prefix_tracker: Dict[str, str] = Field(default_factory=dict) - post_fill_parameters: bool = False - - def __init__( - self, - arch_name: str, - parameter_names: List[str], - prefix_tracker: Optional[Dict[str, str]] = None, - post_fill_parameters: bool = False, - ): - super().__init__() - self.arch_name = arch_name - self.parameter_names = parameter_names - self.layered_parameter_names = _hierarchy(self.parameter_names) - self.prefix_tracker = prefix_tracker or {} - self.embed = self._find_embed_params() - self.post_fill_parameters = post_fill_parameters - - def _find_embed_params(self) -> List[str]: - """Identify embedding parameters (e.g., 'lm_head', 'embed') that may require special handling.""" - embed_params = [] - for name in self.parameter_names: - if any(embedding_name in name for embedding_name in ["lm_head", "embed"]): - embed_params.append(name) - return embed_params - - def name(self) -> str: - """Returns the architecture name.""" - return self.arch_name - - def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """This architecture does not distinguish pre-weights.""" - return [] - - def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """This architecture does not distinguish post-weights.""" - return [] - - def layer_weights( - self, index: int, config: PretrainedConfig - ) -> Optional[List[WeightInfo]]: - """ - Retrieves the weights for a specified layer, adjusting names for prefixes if applicable. 
- """ - layer_name = list(self.layered_parameter_names.keys())[index] - adjusted_layer_name = self._adjust_layer_name(layer_name, config) - - weights = [ - WeightInfo( - name=f"{adjusted_layer_name}.{param}" if param else adjusted_layer_name, - is_embed=(layer_name in self.embed), - ) - for param in self.layered_parameter_names[layer_name] - ] - return ( - weights - if weights - else [ - WeightInfo( - name=adjusted_layer_name, is_embed=(layer_name in self.embed) - ) - ] - ) - - def _adjust_layer_name(self, layer_name: str, config: PretrainedConfig) -> str: - """Adjust layer names by removing any prefix as indicated in the prefix tracker.""" - if config and config.name_or_path in self.prefix_tracker: - prefix = self.prefix_tracker.get(config.name_or_path, "") - if layer_name.startswith(prefix): - return layer_name[len(prefix) :] - return layer_name - - def sliceable(self) -> bool: - """Indicates if the architecture supports slicing.""" - return True - - def num_layers(self, config: PretrainedConfig) -> int: - """Returns the number of layers based on layered parameter names.""" - return len(self.layered_parameter_names) - - -class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True): - definition: JSONArchitectureDefinition - - def _substitute( - self, - item: Union[WeightInfo, ProceduralSpaceInfo], - config: PretrainedConfig, - layer_idx: Optional[int] = None, - ) -> Union[WeightInfo, ProceduralSpaceInfo]: - num_layers = self.num_layers(config) - - obj_dict = item.model_dump(mode="json", exclude_unset=True) - for key in obj_dict: - if isinstance(obj_dict[key], str): - obj_dict[key] = _template_substitution( - obj_dict[key], num_layers, layer_idx - ) - elif isinstance(obj_dict[key], list): - obj_dict[key] = [ - ( - _template_substitution(s, num_layers, layer_idx) - if isinstance(s, str) - else s - ) - for s in obj_dict[key] - ] - return type(item).model_validate(obj_dict) - - def name(self) -> str: - return self.definition.expected_model_type - - def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - return [ - self._substitute(wi, config=config) for wi in self.definition.pre_weights - ] - - def layer_weights( - self, index: int, config: PretrainedConfig - ) -> Optional[List[WeightInfo]]: - return [ - self._substitute(wi, config=config, layer_idx=index) - for wi in self.definition.layer_templates.weights - ] - - def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - return [ - self._substitute(wi, config=config) for wi in self.definition.post_weights - ] - - def sliceable(self) -> bool: - return True - - def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]: - res = [] - for s in self.definition.procedural_spaces or []: - res.append(self._substitute(s, config=config)) - for idx in range(self.num_layers(config)): - for s in self.definition.layer_templates.procedural_spaces or []: - res.append(self._substitute(s, config=config, layer_idx=idx)) - return res - - def has_defined_spaces(self) -> bool: - if ( - self.definition.procedural_spaces - or self.definition.layer_templates.procedural_spaces - ): - return True - for wi in ( - self.definition.layer_templates.weights - + self.definition.pre_weights - + self.definition.post_weights - ): - if wi.input_space or wi.output_space: - return True - return False - - def num_layers_config_key(self) -> str: - return self.definition.num_layers_config_key - - -class MixtralTensorNames(ArchitectureInfo, BaseModel): - ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM" - 
num_local_experts: int - - def name(self) -> str: - return "mixtral" - - @classmethod - def from_config(cls, config: PretrainedConfig): - return MixtralTensorNames(num_local_experts=config.num_local_experts) - - def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - return MISTRAL_INFO.pre_weights(config) - - def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - return MISTRAL_INFO.post_weights(config) - - def num_layers_config_key(self) -> str: - return MISTRAL_INFO.num_layers_config_key() - - def layer_weights( - self, index: int, config: PretrainedConfig - ) -> Optional[List[WeightInfo]]: - num_experts = self.num_local_experts - prefix = f"model.layers.{index}" - tensor_names = [] - for expert_idx in range(num_experts): - for param in ("w1", "w2", "w3"): - tensor_names.append( - prefix + f".block_sparse_moe.experts.{expert_idx}.{param}.weight" - ) - tensor_names.append(prefix + ".block_sparse_moe.gate.weight") - res = [] - for name in tensor_names: - res.append(WeightInfo(name=name)) - for weight_info in MISTRAL_INFO.layer_weights(index, config): - if ".mlp." in weight_info.name: - continue - res.append(weight_info) - return res - - def sliceable(self) -> bool: - return True - - def has_defined_spaces(self) -> bool: - return False - - -def _load_json_arch(name: str) -> JsonArchitectureInfo: - text = importlib.resources.read_text(mergekit._data.architectures, name) - return JsonArchitectureInfo( - definition=JSONArchitectureDefinition.model_validate_json(text) - ) - - -def _load_all_architectures() -> ( - Tuple[List[JsonArchitectureInfo], Dict[str, List[JsonArchitectureInfo]]] -): - architectures: List[JsonArchitectureInfo] = [] - for f in importlib.resources.contents(mergekit._data.architectures): - if f.lower().endswith(".json"): - architectures.append(_load_json_arch(f)) - - name_to_arch: Dict[str, List[JsonArchitectureInfo]] = {} - for arch_info in architectures: - for name in arch_info.definition.architectures: - name_to_arch[name] = name_to_arch.get(name, []) - name_to_arch[name].append(arch_info) - return architectures, name_to_arch - - -JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures() -MISTRAL_INFO = _load_json_arch("mistral.json") -QWEN2_INFO = _load_json_arch("qwen2.json") - - -class ArchitectureInfoUtils: - """Functions for inferring architecture information from a merge configuration.""" - - @staticmethod - def get_architecture_info(config: PretrainedConfig) -> Optional[ArchitectureInfo]: - """Get architecture info from an existing model config.""" - if len(config.architectures) != 1: - raise RuntimeError("More than one architecture in config?") - - arch_name = config.architectures[0] - - if arch_name == MixtralTensorNames.ARCHITECTURE_NAME: - return MixtralTensorNames.from_config(config) - - if arch_name in NAME_TO_ARCH: - candidates = list(NAME_TO_ARCH[arch_name]) - if len(candidates) == 1: - return candidates[0] - - for c in candidates: - if c.definition.expected_model_type == config.model_type: - return c - - warnings.warn(f"No architecture config available for: {arch_name}.") - return None - - @staticmethod - def infer_architecture_info(merge_config) -> AutomaticArchitectureInfo: - """ - Infer architecture info and prefixes for alignment. - Prefixes typically denote where a model is used as a subcomponent of another model. - e.g., [layer.0, layer.1, ...] and []'vision_tower.layer.0', vision_tower.layer.1', ...] - inferring ßprefix = 'vision_tower' is required to align the two models. 
- - Usage: - Similar to `get_architecture_info`, but requires a merge configuration object rather than a model config. - This is so the common parameter names between all models can be inferred. - """ - param_names = [ - ParameterNamesUtils.get_model_parameter_names(source_model.model.path) - for source_model in merge_config.referenced_models() - ] - base_model = merge_config.base_model - - paired_list = list(zip(param_names, merge_config.referenced_models())) - paired_list.sort(key=lambda x: len(x[0]), reverse=True) - for i, (_, model_name) in enumerate(paired_list): - if model_name == base_model: - paired_list.insert(0, paired_list.pop(i)) - break - param_names, referenced_models = zip(*paired_list) - logging.info(f"Base model selected: {referenced_models[0].model.path}") - - prefixes = [""] - for i in range(1, len(param_names)): - assert len(param_names[0]) >= len( - param_names[i] - ), f"base model names list can't be shorter than model {i} names list" - prefixes.append( - ParameterNamesUtils.find_prefix(param_names[0], param_names[i]) - ) - - common_names = ParameterNamesUtils.find_common_ordered_names( - param_names, prefixes - ) - - common_names = ParameterNamesUtils.remove_size_conflicts( - common_names, referenced_models, prefixes - ) - - ArchitectureInfoUtils.log_info(common_names, param_names, referenced_models) - - if not common_names or any([p is None for p in prefixes]): - raise ValueError("Could not resolve model architecture automatically.") - - prefix_tracker = { - model.model.path: f"{prefix}." if prefix else "" - for model, prefix in zip(referenced_models, prefixes) - } - - arch_name = referenced_models[0].model.path - parameter_names = common_names - - return AutomaticArchitectureInfo( - arch_name=arch_name, - parameter_names=parameter_names, - prefix_tracker=prefix_tracker, - post_fill_parameters=( - referenced_models[0].model.path # base model name - if len(common_names) != len(param_names[0]) - else None # no post-fill needed - ), - ) - - @staticmethod - def log_info(common_names, param_names, referenced_models): - for i in range(1, len(param_names)): - prefix, case_message = ParameterNamesUtils.report_names_similarity( - param_names[0], param_names[i] - ) - logging.info( - f"Model {referenced_models[i].model.path}: \ - \n {f'Best prefix found: {prefix}' if prefix else 'No prefix found'}\ - \n {case_message.replace('MODEL_ID', referenced_models[i].model.path)}" - ) - - if len(common_names) != len(param_names[0]): - warnings.warn( - f"Merging {len(common_names)}/{len(param_names[0])} base model parameters. \ - \n Base model selected: {referenced_models[0].model.path} \ - \n copy_and_fill_missing_params will run when merge is complete, to fill in missing params from base model." - ) - - if len(common_names) < 0.3 * len(param_names[0]): - warnings.warn( - "Not many common parameters found. Are you sure you are merging the correct models?" 
- ) - - -class ParameterNamesUtils: - """Utility functions for handling parameter names.""" - - @staticmethod - def resolve_model_directory(repo_id: str) -> Path: - """Resolve the model directory (local or Hugging Face Hub).""" - if Path(repo_id).is_dir(): - return Path(repo_id) - - return Path(snapshot_download(repo_id)) - - @staticmethod - def get_model_parameter_names(repo_id: str) -> List[str]: - """Get parameter names of a model from a Hugging Face repo or local directory.""" - model_dir = ParameterNamesUtils.resolve_model_directory(repo_id) - return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys()) - - @staticmethod - def strip_prefix(name: str, prefix: str) -> str: - """Remove a single prefix from the start of a name.""" - if prefix != "" and name.startswith(prefix + "."): - return name[len(prefix) + 1 :] - return name - - @staticmethod - def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]: - """ - Find a prefix in list1 that, after removal, makes list2 an ordered sublist. - """ - assert len(list1) >= len(list2), "params name list1 can't be shorter than list2" - - possible_prefixes = {item.split(".")[0] for item in list1 if "." in item} - possible_prefixes = [""] + list(possible_prefixes) - - prefix_matches = {} - best_prefix = "" # Default to no prefix - for prefix in possible_prefixes: - stripped_list1 = [ - ParameterNamesUtils.strip_prefix(item, prefix) for item in list1 - ] - prefix_matches[prefix] = len( - [item for item in list2 if item in stripped_list1] - ) - - if max(prefix_matches.values()) > prefix_matches[""]: - best_prefix = max(prefix_matches, key=prefix_matches.get) - - return best_prefix - - @staticmethod - def find_common_ordered_names( - param_names: List[List[str]], prefixes: List[str] - ) -> List[str]: - """Identify and return common parameter names across all models, ensuring correct order. Also account for prefix.""" - common_names = set(param_names[0]) - for i in range(1, len(param_names)): - prefix = f"{prefixes[i]}." if prefixes[i] else "" - common_names.intersection_update({prefix + name for name in param_names[i]}) - return [name for name in param_names[0] if name in common_names] - - @staticmethod - def remove_size_conflicts(common_names, referenced_models, prefixes): - model_dirs = [ - ParameterNamesUtils.resolve_model_directory(m.model.path) - for m in referenced_models - ] - model_indices = [ShardedTensorIndex.from_disk(str(dir)) for dir in model_dirs] - - common_name_and_shape = common_names.copy() - removed_names = [] - - for name in common_names: - base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0]) - - for i in range(1, len(referenced_models)): - other_name = name - prefix = f"{prefixes[i]}." if prefixes[i] else "" - if name.startswith(prefix) and prefix != "": - other_name = name[len(prefix) :] - shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i]) - - if base_shape != shape: - common_name_and_shape.remove(name) - removed_names.append((name, base_shape, shape, i)) - break - - size_mismatch_count = len(removed_names) - if size_mismatch_count > 0: - logging.warning( - f"Size mismatch detected for {size_mismatch_count}/{size_mismatch_count + len(common_names)} tensors. " - "These names were removed from the merge list." 
- ) - logging.info( - "The following tensors have different shapes across models and were removed from the merge list:" - ) - for name, base_shape, shape, i in removed_names: - logging.info( - f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}" - ) - - return common_name_and_shape - - @staticmethod - def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool: - """ - Check if common elements of list2 maintain their relative order in list1. - """ - common_params = set(list1).intersection(set(list2)) - last_index = -1 - - for param in list2: - if param in common_params: - current_index = list1.index(param) - if current_index < last_index: - return False - last_index = current_index - return True - - @staticmethod - def ordered_sublist(list1: List[str], list2: List[str]) -> bool: - """ - Check if list2 is a contiguous ordered sublist of list1. - """ - n, m = len(list1), len(list2) - - for i in range(n - m + 1): - if list1[i : i + m] == list2: - return True - return False - - @staticmethod - def report_names_similarity( - base_names: List[str], other_names: List[str] - ) -> Tuple[Optional[str], str]: - """ - Analyze similarity between parameter names of two models and identify shared prefixes. - - Returns: - best_prefix (str): Best matching prefix for parameter names. - case_message (str): Explanation of the structural relationship. - """ - possible_prefixes = {""} - possible_prefixes.update( - {item.split(".")[0] for item in base_names if "." in item} - ) - - prefixes_subset_overlap = {} - best_prefix = None - case_message = "No common parameter names found for any prefix" - - for prefix in possible_prefixes: - base_names_stripped = [ - ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names - ] - - if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names): - return prefix, "All params in model have exact match in base model." - - intersection = set(base_names_stripped).intersection(set(other_names)) - prefixes_subset_overlap[prefix] = intersection - - if prefixes_subset_overlap: - best_prefix = max( - prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x]) - ) - base_names_stripped = [ - ParameterNamesUtils.strip_prefix(name, best_prefix) - for name in base_names - ] - - overlap = len(prefixes_subset_overlap[best_prefix]) - ordered = ParameterNamesUtils.are_common_params_ordered( - base_names_stripped, other_names - ) - mismatched = [ - item for item in other_names if item not in base_names_stripped - ] - mismatched = "\n ".join(mismatched) - case_message = ( - f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) " - f"of model parameters are in the base model. 
\n" - f" Name ordering is {'preserved' if ordered else 'not preserved'}.\n" - f" Missing parameters:\n {mismatched}" - ) - - return best_prefix, case_message - - @staticmethod - def tensor_shape(name, index) -> Tuple[int]: - from safetensors import safe_open - - with safe_open( - Path(index.base_path) / index.tensor_paths[name], framework="pt" - ) as f: - return f.get_slice(name).get_shape() diff --git a/mergekit/architecture/__init__.py b/mergekit/architecture/__init__.py new file mode 100644 index 00000000..7f1310b8 --- /dev/null +++ b/mergekit/architecture/__init__.py @@ -0,0 +1,88 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +import logging +from typing import TYPE_CHECKING, Optional + +from transformers import PretrainedConfig + +from mergekit.architecture.auto import infer_architecture_info +from mergekit.architecture.base import ( + ConfiguredModelArchitecture, + ConfiguredModuleArchitecture, + ModelArchitecture, + ModuleArchitecture, + ModuleDefinition, + WeightInfo, +) +from mergekit.architecture.json_definitions import NAME_TO_ARCH +from mergekit.architecture.mixtral import MixtralTensorNames +from mergekit.options import MergeOptions + +if TYPE_CHECKING: + from mergekit.config import MergeConfiguration + +logger = logging.getLogger(__name__) + + +def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]: + if len(config.architectures) != 1: + raise RuntimeError("More than one architecture in config?") + arch_name = config.architectures[0] + + if arch_name == MixtralTensorNames.ARCHITECTURE_NAME: + module = MixtralTensorNames.from_config(config) + return ModelArchitecture( + modules={"default": ModuleDefinition(architecture=module)}, + architectures=[arch_name], + ) + elif arch_name in NAME_TO_ARCH: + candidates = list(NAME_TO_ARCH[arch_name]) + if len(candidates) == 1: + return candidates[0] + + for c in candidates: + if c.expected_model_type == config.model_type: + return c + logger.warning( + f"Multiple architectures for {arch_name}, none match model type {config.model_type}" + ) + + logger.warning(f"No JSON architecture found for {arch_name}") + return None + + +def get_architecture_info( + config: "MergeConfiguration", options: MergeOptions +) -> ModelArchitecture: + models = config.referenced_models() + if not models: + raise ValueError("No models referenced in config") + + model_arch_info = [ + arch_info_for_config(m.config(trust_remote_code=options.trust_remote_code)) + for m in models + ] + if all(arch is not None for arch in model_arch_info): + if not options.allow_crimes and any( + arch != model_arch_info[0] for arch in model_arch_info + ): + raise RuntimeError( + "Must specify --allow-crimes to attempt to mix different architectures" + ) + return model_arch_info[0] + + # try to infer from all models + return infer_architecture_info(models, config.base_model, options) + + +__all__ = [ + "ModelArchitecture", + "ModuleArchitecture", + "ModuleDefinition", + "ConfiguredModuleArchitecture", + "ConfiguredModelArchitecture", + "WeightInfo", + "get_architecture_info", + "arch_info_for_config", +] diff --git a/mergekit/architecture/auto.py b/mergekit/architecture/auto.py new file mode 100644 index 00000000..3eee5855 --- /dev/null +++ b/mergekit/architecture/auto.py @@ -0,0 +1,120 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +import logging +import re +from collections import defaultdict +from typing import List, Optional + +from mergekit.architecture.base import ( + ModelArchitecture, + 
ModuleDefinition, + WeightInfo, +) +from mergekit.architecture.json_definitions import ( + JsonLayerTemplates, + JsonModuleArchDef, + JsonModuleArchitecture, +) +from mergekit.common import ModelReference +from mergekit.options import MergeOptions + +RE_LAYER_INDEX = re.compile(r"\.(\d+)\.") + +logger = logging.getLogger(__name__) + + +def get_model_tensor_names(model: ModelReference, options: MergeOptions) -> List[str]: + loader = model.lazy_loader( + cache_dir=options.transformers_cache, lazy_unpickle=options.lazy_unpickle + ) + return list(loader.index.tensor_paths.keys()) + + +def infer_architecture_info( + models: List[ModelReference], + base_model: Optional[ModelReference], + options: MergeOptions, +) -> ModelArchitecture: + model_tensor_names = { + model: set(get_model_tensor_names(model, options)) + for model in (set(models).union({base_model} if base_model else {})) + } + if base_model is None: + base_model = models.pop(0) + all_tensor_names = set().union(*model_tensor_names.values()) + in_all_models = all_tensor_names.intersection(*model_tensor_names.values()) + + module_prefixes = set() + module_layer_counts = defaultdict(int) + module_templates = defaultdict(set) + module_loose_weights = defaultdict(set) + for tensor_name in all_tensor_names: + if len(RE_LAYER_INDEX.findall(tensor_name)) > 1: + raise ValueError( + f"Tensor name {tensor_name} has more than one layer index - not supported" + ) + elif match := RE_LAYER_INDEX.search(tensor_name): + prefix = tensor_name[: match.start()] + module_prefixes.add(prefix) + layer_idx = int(match.group(1)) + module_layer_counts[prefix] = max( + module_layer_counts[prefix], layer_idx + 1 + ) + module_templates[prefix] = module_templates[prefix].union( + set([RE_LAYER_INDEX.sub("{layer_index}", tensor_name)]) + ) + + # create a default module with no prefix + module_prefixes.add("") + + for tensor_name in all_tensor_names: + if RE_LAYER_INDEX.search(tensor_name): + continue + for prefix in module_prefixes: + if tensor_name.startswith(prefix): + module_loose_weights[prefix].add(tensor_name[len(prefix) :]) + + if not (module_loose_weights[""] or module_templates[""]): + module_prefixes.remove("") + if not module_prefixes: + raise ValueError("No modules found in models") + + logging.warning(f"Inferred {len(module_prefixes)} modules:") + for prefix in module_prefixes: + logging.warning( + f" {repr(prefix or 'default')} with {module_layer_counts[prefix]} layers, {len(module_templates[prefix])} templates, and {len(module_loose_weights[prefix])} loose weights" + ) + + def _wi(template: str) -> WeightInfo: + optional = template.replace("{layer_index}", "0") not in in_all_models + return WeightInfo( + name=template, + optional=optional, + ) + + module_archs = {} + for prefix in module_prefixes: + num_layers = module_layer_counts[prefix] + module_archs[prefix or "default"] = JsonModuleArchitecture( + definition=JsonModuleArchDef( + model_type="", + architectures=[], + pre_weights=[_wi(t) for t in module_loose_weights[prefix]], + layer_templates=JsonLayerTemplates( + weights=[_wi(t) for t in module_templates[prefix]] + ), + post_weights=[], + num_layers_config_key=None, + override_num_layers=num_layers, + ), + ) + + return ModelArchitecture( + modules={ + key: ModuleDefinition(architecture=value) + for key, value in module_archs.items() + }, + architectures=[], + model_type="", + ) diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py new file mode 100644 index 00000000..26c0d231 --- /dev/null +++ b/mergekit/architecture/base.py @@ 
-0,0 +1,152 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field +from transformers import PretrainedConfig + +from mergekit.common import get_config_value + + +class WeightInfo(BaseModel, frozen=True): + """Information about an individual weight tensor in a model. + + Attributes: + name (str): + The name of the tensor representing the weight. + is_embed (bool): + Indicates whether the weight is for an embedding or language model head. + optional (bool): + Indicates whether the weight can be omitted from a model. + aliases (Optional[List[str]]): + List of alternative names for the weight, if applicable. + force_dtype (Optional[str]): + Mandatory dtype for the weight, if applicable. + """ + + name: str + is_embed: bool = False + optional: bool = False + aliases: Optional[Tuple[str, ...]] = None + force_dtype: Optional[str] = None + tied_names: Optional[Tuple[str, ...]] = None + + +def _prefix_weight(weight: WeightInfo, prefix: Optional[str] = None) -> WeightInfo: + if prefix is None: + return weight + return WeightInfo( + name=prefix + weight.name, + aliases=tuple(prefix + alias for alias in weight.aliases or ()) or None, + tied_names=tuple(prefix + tied_name for tied_name in weight.tied_names or ()) + or None, + **weight.model_dump(exclude={"name", "aliases", "tied_names"}), + ) + + +class ModuleArchitecture(ABC): + @abstractmethod + def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + """Return a list of all weights preceding the first layer.""" + ... + + @abstractmethod + def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + """Return a list of all weights following the final layer.""" + ... + + @abstractmethod + def layer_weights( + self, index: int, config: PretrainedConfig + ) -> Optional[List[WeightInfo]]: + """Return a list of all weights associated with a given layer.""" + ... 
+ + def num_layers_config_key(self) -> str: + """Key in config that represents number of layers""" + return "num_hidden_layers" + + def num_layers(self, config: PretrainedConfig) -> int: + """Return the number of layers in a model.""" + return get_config_value(config, self.num_layers_config_key()) + + def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + """Return all weights associated with a model.""" + num_layers = self.num_layers(config) + res = list(self.pre_weights(config)) + for layer_idx in range(num_layers): + res.extend(self.layer_weights(layer_idx, config)) + res.extend(self.post_weights(config)) + return res + + +class ConfiguredModuleArchitecture( + BaseModel, frozen=True, arbitrary_types_allowed=True +): + info: ModuleArchitecture + config: PretrainedConfig + weight_prefix: Optional[str] = None + + def num_layers(self) -> int: + return self.info.num_layers(self.config) + + def pre_weights(self) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.pre_weights(self.config) + ] + + def post_weights(self) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.post_weights(self.config) + ] + + def layer_weights(self, index: int) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.layer_weights(index, self.config) + ] + + def all_weights(self) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.all_weights(self.config) + ] + + +class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True): + architecture: ModuleArchitecture + weight_prefix: Optional[str] = None + subfolder: Optional[str] = None + + +class ModelArchitecture(BaseModel, frozen=True): + modules: Dict[str, ModuleDefinition] + architectures: List[str] + expected_model_type: str = Field(alias="model_type") + tagalong_files: Optional[List[str]] = None + + def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + res = [] + for module in self.modules.values(): + for weight_info in module.architecture.all_weights(config=config): + res.append(_prefix_weight(weight_info, module.weight_prefix)) + return res + + +class ConfiguredModelArchitecture(BaseModel, frozen=True, arbitrary_types_allowed=True): + info: ModelArchitecture + config: PretrainedConfig + + def all_weights(self) -> List[WeightInfo]: + return self.info.all_weights(self.config) + + def get_module(self, module_name: str) -> ConfiguredModuleArchitecture: + return ConfiguredModuleArchitecture( + info=self.info.modules[module_name].architecture, + config=self.config, + weight_prefix=self.info.modules[module_name].weight_prefix, + ) diff --git a/mergekit/architecture/json_definitions.py b/mergekit/architecture/json_definitions.py new file mode 100644 index 00000000..0c11c486 --- /dev/null +++ b/mergekit/architecture/json_definitions.py @@ -0,0 +1,187 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +import importlib +import importlib.resources +import json +import string +from typing import Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field +from transformers import PretrainedConfig +from typing_extensions import Literal + +import mergekit._data.architectures +from mergekit.architecture.base import ( + ModelArchitecture, + ModuleArchitecture, + ModuleDefinition, + WeightInfo, +) + + +class JsonLayerTemplates(BaseModel, frozen=True): + weights: List[WeightInfo] + + +class JsonModuleArchDef(BaseModel, frozen=True): + 
expected_model_type: str = Field(alias="model_type") + architectures: List[str] + pre_weights: List[WeightInfo] + layer_templates: JsonLayerTemplates + post_weights: List[WeightInfo] + num_layers_config_key: Optional[str] = None + override_num_layers: Optional[int] = None + + +class JsonModuleArchitecture(ModuleArchitecture, BaseModel, frozen=True): + kind: Literal["module"] = "module" + definition: JsonModuleArchDef + + def _substitute( + self, + item: WeightInfo, + config: PretrainedConfig, + layer_idx: Optional[int] = None, + ) -> WeightInfo: + num_layers = self.num_layers(config) + + obj_dict = item.model_dump(mode="json", exclude_unset=True) + for key in obj_dict: + if isinstance(obj_dict[key], str): + obj_dict[key] = _template_substitution( + obj_dict[key], num_layers, layer_idx + ) + elif isinstance(obj_dict[key], list): + obj_dict[key] = [ + ( + _template_substitution(s, num_layers, layer_idx) + if isinstance(s, str) + else s + ) + for s in obj_dict[key] + ] + return type(item).model_validate(obj_dict) + + def name(self) -> str: + return self.definition.expected_model_type + + def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + return [ + self._substitute(wi, config=config) for wi in self.definition.pre_weights + ] + + def layer_weights( + self, index: int, config: PretrainedConfig + ) -> Optional[List[WeightInfo]]: + return [ + self._substitute(wi, config=config, layer_idx=index) + for wi in self.definition.layer_templates.weights + ] + + def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + return [ + self._substitute(wi, config=config) for wi in self.definition.post_weights + ] + + def num_layers_config_key(self) -> str: + return self.definition.num_layers_config_key + + def num_layers(self, config): + if self.definition.override_num_layers is not None: + return self.definition.override_num_layers + return super().num_layers(config) + + +class JsonModuleDefinition(BaseModel, frozen=True): + architecture: JsonModuleArchDef + weight_prefix: Optional[str] = None + subfolder: Optional[str] = None + + +class JsonModularArchitectureDefinition(BaseModel, frozen=True): + kind: Literal["modular"] + modules: Dict[str, JsonModuleDefinition] + architectures: List[str] + expected_model_type: str = Field(alias="model_type") + tagalong_files: Optional[List[str]] = None + + +class TemplateWithArithmetic(string.Template): + idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)" + + +def _template_substitution( + template: str, num_layers: int, layer_idx: Optional[int] = None +) -> str: + if "{" not in template: + return template + + substitutions = { + "num_layers": num_layers, + "num_layers+1": num_layers + 1, + "num_layers-1": num_layers - 1, + } + + if layer_idx is not None: + substitutions.update( + { + "layer_index": layer_idx, + "layer_index+1": layer_idx + 1, + "layer_index-1": layer_idx - 1, + } + ) + + return TemplateWithArithmetic(template).substitute(substitutions) + + +def _load_architecture_json(name: str) -> ModelArchitecture: + with importlib.resources.open_text(mergekit._data.architectures, name) as f: + text = f.read() + data = json.loads(text) + kind = data.get("kind", "module") + if kind == "modular": + parsed = JsonModularArchitectureDefinition.model_validate_json(text) + return ModelArchitecture( + modules={ + k: ModuleDefinition( + architecture=JsonModuleArchitecture(definition=v.architecture), + weight_prefix=v.weight_prefix, + subfolder=v.subfolder, + ) + for k, v in parsed.modules.items() + }, + architectures=parsed.architectures, + 
model_type=parsed.expected_model_type, + tagalong_files=parsed.tagalong_files, + ) + elif data.get("kind", "module") == "module": + module = JsonModuleArchitecture( + definition=JsonModuleArchDef.model_validate(data) + ) + return ModelArchitecture( + modules={"default": ModuleDefinition(architecture=module)}, + architectures=module.definition.architectures, + model_type=module.definition.expected_model_type, + ) + else: + raise RuntimeError(f"Unexpected architecture kind: {data['kind']}") + + +def _load_all_architectures() -> ( + Tuple[List[ModelArchitecture], Dict[str, List[ModelArchitecture]]] +): + architectures: List[ModelArchitecture] = [] + for f in importlib.resources.contents(mergekit._data.architectures): + if f.lower().endswith(".json"): + architectures.append(_load_architecture_json(f)) + + name_to_arch: Dict[str, List[JsonModuleArchitecture]] = {} + for arch_info in architectures: + for arch_name in arch_info.architectures: + name_to_arch[arch_name] = name_to_arch.get(arch_name, []) + name_to_arch[arch_name].append(arch_info) + return architectures, name_to_arch + + +JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures() diff --git a/mergekit/architecture/mixtral.py b/mergekit/architecture/mixtral.py new file mode 100644 index 00000000..47c9c440 --- /dev/null +++ b/mergekit/architecture/mixtral.py @@ -0,0 +1,58 @@ +# Copyright (C) 2025 Arcee AI +# SPDX-License-Identifier: BUSL-1.1 + +from typing import ClassVar, List, Optional + +from pydantic import BaseModel +from transformers import PretrainedConfig + +from mergekit.architecture.base import ( + ModuleArchitecture, + WeightInfo, +) +from mergekit.architecture.json_definitions import NAME_TO_ARCH + +MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0] +MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture + + +class MixtralTensorNames(ModuleArchitecture, BaseModel): + ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM" + num_local_experts: int + + def name(self) -> str: + return "mixtral" + + @classmethod + def from_config(cls, config: PretrainedConfig): + return MixtralTensorNames(num_local_experts=config.num_local_experts) + + def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + return MISTRAL_MODULE_ARCH.pre_weights(config) + + def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + return MISTRAL_MODULE_ARCH.post_weights(config) + + def num_layers_config_key(self) -> str: + return MISTRAL_MODULE_ARCH.num_layers_config_key() + + def layer_weights( + self, index: int, config: PretrainedConfig + ) -> Optional[List[WeightInfo]]: + num_experts = self.num_local_experts + prefix = f"model.layers.{index}" + tensor_names = [] + for expert_idx in range(num_experts): + for param in ("w1", "w2", "w3"): + tensor_names.append( + prefix + f".block_sparse_moe.experts.{expert_idx}.{param}.weight" + ) + tensor_names.append(prefix + ".block_sparse_moe.gate.weight") + res = [] + for name in tensor_names: + res.append(WeightInfo(name=name)) + for weight_info in MISTRAL_MODULE_ARCH.layer_weights(index, config): + if ".mlp." 
in weight_info.name: + continue + res.append(weight_info) + return res diff --git a/mergekit/common.py b/mergekit/common.py index 8a087543..f16ddde5 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -13,6 +13,7 @@ Iterator, Mapping, Optional, + Protocol, Tuple, Union, get_args, @@ -31,6 +32,32 @@ from mergekit.io import LazyTensorLoader, ShardedTensorIndex +def set_config_value(config: PretrainedConfig, key: str, value: Any): + """Set a value in a PretrainedConfig object.""" + parts = key.split(".") + obj = config + for idx, part in enumerate(parts[:-1]): + if not hasattr(obj, part): + raise RuntimeError( + f"Config {config} has no attribute {'.'.join(parts[:idx+1])}" + ) + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + + +def get_config_value(config: PretrainedConfig, key: str) -> Any: + """Get a value from a PretrainedConfig object.""" + parts = key.split(".") + obj = config + for idx, part in enumerate(parts): + if not hasattr(obj, part): + raise RuntimeError( + f"Config {config} has no attribute {'.'.join(parts[:idx+1])}" + ) + obj = getattr(obj, part) + return obj + + class ModelPath(BaseModel, frozen=True): path: str revision: Optional[str] = None @@ -92,7 +119,7 @@ def merged( os.makedirs(out_path, exist_ok=True) config = self.config(trust_remote_code) - auto_cls = _get_auto_cls(config.architectures[0]) + auto_cls = get_auto_cls(config.architectures[0]) logging.info(f"Loading {self.model} for merge...") model = auto_cls.from_pretrained( @@ -110,7 +137,7 @@ def merged( model.save_pretrained(out_path, safe_serialization=True) del model - return ModelReference(model=out_path) + return ModelReference(model=ModelPath(path=out_path)) def config(self, trust_remote_code: bool = False) -> PretrainedConfig: res = AutoConfig.from_pretrained( @@ -270,8 +297,70 @@ def values(self) -> Iterator[T_V]: return self.data.values() -def _get_auto_cls(arch_name: str): +ARCH_NAME_TO_AUTO_CLS = {} + +try: + import transformers.models.auto.modeling_auto as tf_auto +except ImportError: + tf_auto = None + +if tf_auto is not None: + for map_name, cls_name in [ + ("MODEL_MAPPING_NAMES", "AutoModel"), + ( + "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", + "AutoModelForAudioClassification", + ), + ( + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", + "AutoModelForImageClassification", + ), + ("MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"), + ( + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", + "AutoModelForSequenceClassification", + ), + ("MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"), + ( + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", + "AutoModelForTokenClassification", + ), + ("MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"), + ("MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"), + ("MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"), + ("MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "AutoModelForCausalLM"), + ]: + cls = getattr(transformers, cls_name, None) + if cls is None: + logging.info(f"Could not find {cls_name} in transformers") + continue + if hasattr(tf_auto, map_name): + name_to_arch_name = getattr(tf_auto, map_name) + for arch_name in name_to_arch_name.values(): + ARCH_NAME_TO_AUTO_CLS[arch_name] = cls + + +class AutoClassProtocol(Protocol): + def from_pretrained( + self, + pretrained_model_name_or_path: str, + *model_args, + **kwargs, + ) -> transformers.PreTrainedModel: ... 
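# Illustrative sketch, not part of the diff: the dotted-key helpers
# set_config_value/get_config_value introduced in mergekit/common.py above walk nested
# PretrainedConfig objects one attribute segment at a time, which is what lets the
# modular architectures use keys like "text_config.num_hidden_layers". A minimal,
# self-contained usage example with a hand-built nested config (the attribute names
# here are chosen only for the demonstration):
from transformers import PretrainedConfig

from mergekit.common import get_config_value, set_config_value

cfg = PretrainedConfig()
cfg.text_config = PretrainedConfig(num_hidden_layers=34)  # nested sub-config
set_config_value(cfg, "text_config.num_hidden_layers", 17)  # descends into text_config
assert get_config_value(cfg, "text_config.num_hidden_layers") == 17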
+ + def from_config( + self, + config: transformers.PretrainedConfig, + *model_args, + **kwargs, + ) -> transformers.PreTrainedModel: ... + + +def get_auto_cls(arch_name: str) -> AutoClassProtocol: """Get the AutoModel class for a given architecture name.""" + if arch_name in ARCH_NAME_TO_AUTO_CLS: + return ARCH_NAME_TO_AUTO_CLS[arch_name] + if arch_name.endswith("ForMaskedLM"): auto_cls = transformers.AutoModelForMaskedLM elif arch_name.endswith("ForSequenceClassification"): diff --git a/mergekit/config.py b/mergekit/config.py index 532d30cf..c449e778 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -70,11 +70,24 @@ class OutputSliceDefinition(BaseModel): parameters: Optional[Dict[str, ParameterSetting]] = None -class MergeConfiguration(BaseModel): - merge_method: str +class OutputModuleDefinition(BaseModel): slices: Optional[List[OutputSliceDefinition]] = None models: Optional[List[InputModelDefinition]] = None parameters: Optional[Dict[str, ParameterSetting]] = None + + @model_validator(mode="after") + def validate_inputs(self): + if ((not self.slices) and (not self.models)) or (self.slices and self.models): + raise RuntimeError("Must specify either output slices or models to merge") + return self + + +class MergeConfiguration(BaseModel): + modules: Optional[Dict[str, OutputModuleDefinition]] = None + slices: Optional[List[OutputSliceDefinition]] = None + models: Optional[List[InputModelDefinition]] = None + + merge_method: str base_model: Optional[ModelReference] = None dtype: Optional[str] = None tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = ( @@ -83,6 +96,7 @@ class MergeConfiguration(BaseModel): tokenizer: Optional[TokenizerConfig] = None chat_template: Optional[str] = None out_dtype: Optional[str] = None + parameters: Optional[Dict[str, ParameterSetting]] = None def referenced_models(self) -> List[ModelReference]: models = set() @@ -95,12 +109,31 @@ def referenced_models(self) -> List[ModelReference]: for s in self.slices: for src in s.sources: models.add(src.model) + if self.modules: + for m in self.modules.values(): + if m.models: + for model_in in m.models: + models.add(model_in.model) + if m.slices: + for s in m.slices: + for src in s.sources: + models.add(src.model) return list(models) @model_validator(mode="after") def validate_inputs(self): - if ((not self.slices) and (not self.models)) or (self.slices and self.models): - raise RuntimeError("Must specify either output slices or models to merge") + set_ct = 0 + if self.modules: + set_ct += 1 + if self.slices: + set_ct += 1 + if self.models: + set_ct += 1 + + if set_ct != 1: + raise RuntimeError( + "Exactly one of 'models', 'slices', or 'modules' must be present" + ) return self @model_validator(mode="after") @@ -121,6 +154,7 @@ class ConfigReader(BaseModel): t: float tensor_name: Optional[str] = None slice_out: Optional[OutputSliceDefinition] = None + module: Optional[OutputModuleDefinition] = None @property def base_model(self) -> Optional[ModelReference]: @@ -137,6 +171,7 @@ def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader": t=self.t, tensor_name=self.tensor_name, slice_out=slice, + module=self.module, ) def for_tensor(self, tensor_name: str) -> "ConfigReader": @@ -145,6 +180,7 @@ def for_tensor(self, tensor_name: str) -> "ConfigReader": t=self.t, tensor_name=tensor_name, slice_out=self.slice_out, + module=self.module, ) def with_t(self, t: float) -> "ConfigReader": @@ -153,6 +189,16 @@ def with_t(self, t: float) -> "ConfigReader": t=t, 
tensor_name=self.tensor_name, slice_out=self.slice_out, + module=self.module, + ) + + def for_module(self, module: OutputModuleDefinition) -> "ConfigReader": + return ConfigReader( + config=self.config, + t=self.t, + tensor_name=self.tensor_name, + slice_out=self.slice_out, + module=module, ) def parameter( @@ -179,6 +225,15 @@ def parameter( if value is not None: return value + if self.module and self.module.parameters and name in self.module.parameters: + value = evaluate_setting( + self.tensor_name, + self.module.parameters[name], + self.t, + ) + if value is not None: + return value + if self.config.parameters and name in self.config.parameters: value = evaluate_setting( self.tensor_name, diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py index 43abdcbf..c37ff7e0 100644 --- a/mergekit/evo/actors.py +++ b/mergekit/evo/actors.py @@ -17,13 +17,15 @@ import transformers from transformers.utils import is_flash_attn_2_available +from mergekit.architecture.base import ConfiguredModelArchitecture + try: import vllm except ImportError: vllm = None -from mergekit.architecture import ArchitectureInfoUtils, ConfiguredArchitectureInfo +from mergekit.architecture import arch_info_for_config from mergekit.config import MergeConfiguration from mergekit.evo.config import EvolMergeConfiguration from mergekit.evo.genome import InvalidGenotypeError, ModelGenome @@ -130,7 +132,7 @@ class InMemoryMergeEvaluator(MergeActorBase): model: Union[ lm_eval.models.huggingface.HFLM, lm_eval.models.vllm_causallms.VLLM, None ] = None - arch_info: Optional[ConfiguredArchitectureInfo] = None + arch_info: Optional[ConfiguredModelArchitecture] = None def __init__( self, @@ -142,9 +144,7 @@ def __init__( super().__init__(*args, vllm=vllm, **kwargs) def _maybe_init_model(self, config: MergeConfiguration): - ai = ArchitectureInfoUtils.get_architecture_info( - self.genome._input_config_example - ) + ai = arch_info_for_config(self.genome._input_config_example) cfg_out = _model_out_config( config, ai, @@ -167,7 +167,7 @@ def _maybe_init_model(self, config: MergeConfiguration): continue if getattr(cfg_out, key) != getattr(self.arch_info.config, key, None): - logger.warn(f"Config key {key} changed, reinitializing model") + logger.warning(f"Config key {key} changed, reinitializing model") different = True break @@ -240,7 +240,14 @@ def _maybe_init_model(self, config: MergeConfiguration): ) else: self.model = lm_eval.models.huggingface.HFLM(pretrained=inner_model) - self.arch_info = ConfiguredArchitectureInfo(info=ai, config=cfg_out) + self.arch_info = ( + ConfiguredModelArchitecture( + info=ai, + config=cfg_out, + ) + if ai + else None + ) logger.info("Model initialized") def evaluate(self, genotype: torch.Tensor) -> dict: diff --git a/mergekit/graph.py b/mergekit/graph.py index d518ccae..1d6309b1 100644 --- a/mergekit/graph.py +++ b/mergekit/graph.py @@ -8,6 +8,7 @@ Executor: Class for scheduling and executing directed acyclic task graphs. """ +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union @@ -19,6 +20,8 @@ ValueT = TypeVar("ValueT") +logger = logging.getLogger(__name__) + class Task(ABC, BaseModel, Generic[ValueT], frozen=True): """ @@ -243,20 +246,35 @@ def _move_tensors( DUMMY_TASK_VALUE = "!!DUMMY!!" 
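# Illustrative sketch, not part of the diff: the _make_schedule rewrite below maps each
# Task object to an integer node id before building the networkx graph, so networkx
# never has to hash or reserialize the (potentially large, frozen-pydantic) task
# objects while sorting. A minimal standalone version of the same pattern, using plain
# strings as stand-ins for tasks (all names below are invented for the example):
import networkx as nx

tasks = ["load_a", "load_b", "merge", "save"]
index = {t: i for i, t in enumerate(tasks)}  # task -> integer node id
edges = [
    (index["load_a"], index["merge"]),  # merge depends on load_a
    (index["load_b"], index["merge"]),  # merge depends on load_b
    (index["merge"], index["save"]),    # save depends on merge
]
graph = nx.DiGraph(edges)
order = [
    tasks[i]
    for i in nx.lexicographical_topological_sort(graph, key=lambda i: tasks[i])
]
assert order == ["load_a", "load_b", "merge", "save"]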
def _make_schedule(self, targets: List[Task]) -> List[Task]: + logger.debug(f"Building schedule for {len(targets)} targets") self.schedule = [] self.dependencies = self._build_dependencies(targets) + node_indices = {} + node_values = [] + + # instead of using the actual task objects as nodes in the graph, + # use an integer index to avoid reserializing the task objects + # inside networkx (slow) + def _index(node: Union[Task, str]) -> int: + if node not in node_indices: + node_indices[node] = len(node_indices) + node_values.append(node) + return node_indices[node] + edge_tups = [] for node in self.dependencies: for dependency in self.dependencies[node]: - edge_tups.append((dependency, node)) + edge_tups.append((_index(dependency), _index(node))) + # add edges from a dummy node to each target to guarantee + # they will be included in the final schedule + dummy_index = _index(Executor.DUMMY_TASK_VALUE) for task in targets: - # add edges from a dummy node to each target to guarantee - # they will be included in the final schedule - edge_tups.append((Executor.DUMMY_TASK_VALUE, task)) + edge_tups.append((dummy_index, _index(task))) - def _compare_key(task: Union[Task, str]): + def _compare_key(node: int) -> Tuple[str, int]: + task = node_values[node] if task == Executor.DUMMY_TASK_VALUE: return ("", 0) return ( @@ -265,13 +283,14 @@ def _compare_key(task: Union[Task, str]): ) graph = networkx.DiGraph(edge_tups) - res = [ - t - for t in networkx.lexicographical_topological_sort(graph, key=_compare_key) - if (t != Executor.DUMMY_TASK_VALUE) - and (t not in (self.cached_values or {})) + return [ + node_values[idx] + for idx in networkx.lexicographical_topological_sort( + graph, key=_compare_key + ) + if (idx != dummy_index) + and node_values[idx] not in (self.cached_values or {}) ] - return res def _build_dependencies(self, targets: List[Task]) -> Dict[Task, Set[Task]]: task_dependencies: Dict[Task, Set[Task]] = {} diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py index 6ec474d1..19778f76 100644 --- a/mergekit/io/tensor_writer.py +++ b/mergekit/io/tensor_writer.py @@ -122,7 +122,7 @@ def finalize(self): json.dump( { "metadata": { - "mergekit_version": "0.1.1", + "mergekit_version": "0.1.2", }, "weight_map": self.weight_map, }, diff --git a/mergekit/merge.py b/mergekit/merge.py index 994774e3..cf18ad9e 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -6,16 +6,16 @@ import logging import os import shutil -import warnings from collections import Counter -from typing import Optional +from typing import List, Optional import tqdm import transformers from mergekit._data import chat_templates -from mergekit.architecture import ArchitectureInfo, ArchitectureInfoUtils +from mergekit.architecture import ModelArchitecture, get_architecture_info from mergekit.card import generate_card +from mergekit.common import set_config_value from mergekit.config import MergeConfiguration from mergekit.graph import Executor from mergekit.io.tasks import LoaderCache @@ -39,7 +39,7 @@ def run_merge( if not merge_config.models and not merge_config.slices: raise RuntimeError("No output requested") - arch_info = _load_arch_info(merge_config, options) + arch_info = get_architecture_info(merge_config, options) # initialize loader cache and set options loader_cache = LoaderCache() @@ -112,7 +112,11 @@ def run_merge( ) as fp: fp.write(config_source) - if tokenizer is None: + if tokenizer is not None: + logger.info("Saving tokenizer") + _set_chat_template(tokenizer, merge_config) + 
tokenizer.save_pretrained(out_path, safe_serialization=True) + else: if options.copy_tokenizer: try: _copy_tokenizer( @@ -128,10 +132,12 @@ def run_merge( "Chat template specified but no tokenizer found. Chat template will not be saved." ) - if tokenizer: - logger.info("Saving tokenizer") - _set_chat_template(tokenizer, merge_config) - tokenizer.save_pretrained(out_path, safe_serialization=True) + _copy_tagalong_files( + merge_config, + out_path, + files=arch_info.tagalong_files or [], + trust_remote_code=options.trust_remote_code, + ) if getattr(arch_info, "post_fill_parameters", False): from mergekit.scripts.fill_missing_params import copy_and_fill_missing_params @@ -192,6 +198,25 @@ def _set_chat_template( tokenizer.chat_template = chat_template +def _copy_tagalong_files( + merge_config: MergeConfiguration, + out_path: str, + files: List[str], + trust_remote_code: bool = False, +): + donor_model = merge_config.base_model or (merge_config.referenced_models()[0]) + + for file_name in files: + if os.path.exists(os.path.join(donor_model.model.path, file_name)): + logger.info(f"Copying {file_name} from {donor_model}") + shutil.copy( + os.path.join(donor_model.model.path, file_name), + os.path.join(out_path, file_name), + ) + + return + + def _copy_tokenizer( merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False ): @@ -214,6 +239,8 @@ def _copy_tokenizer( "special_tokens_map.json", "tokenizer.json", "tokenizer.model", + "added_tokens.json", + "merges.txt", ]: if os.path.exists(os.path.join(donor_model.model.path, file_name)): shutil.copy( @@ -236,7 +263,7 @@ def _copy_tokenizer( def _model_out_config( config: MergeConfiguration, - arch_info: ArchitectureInfo, + arch_info: ModelArchitecture, trust_remote_code: bool = False, ) -> transformers.PretrainedConfig: """Return a configuration for the resulting model.""" @@ -249,19 +276,33 @@ def _model_out_config( elif config.dtype: res.torch_dtype = config.dtype - if config.slices: - try: - num_layers = sum( + module_layers = {} + for module_name in arch_info.modules: + if config.modules and module_name in config.modules: + module_def = config.modules.get(module_name) + module_layers[module_name] = sum( s.sources[0].layer_range[1] - s.sources[0].layer_range[0] - for s in config.slices + for s in module_def.slices ) - setattr(res, arch_info.num_layers_config_key(), num_layers) - except Exception as e: - logger.warning( - "Unable to set number of layers in output config - you may need to manually correct it.", - exc_info=e, + elif config.slices: + module_layers[module_name] = sum( + s.sources[0].layer_range[1] - s.sources[0].layer_range[0] + for s in config.slices ) + if module_layers: + for module_name in module_layers: + try: + module_info = arch_info.modules[module_name] + cfg_key = module_info.architecture.num_layers_config_key() + set_config_value(res, cfg_key, module_layers[module_name]) + except Exception as e: + logger.warning( + f"Unable to set number of layers for module {module_name} in output config " + "- you may need to manually correct it.", + exc_info=e, + ) + return res @@ -282,32 +323,4 @@ def _update_config_vocab( ) -def _load_arch_info( - merge_config: MergeConfiguration, options: MergeOptions -) -> ArchitectureInfo: - """ - Loads architecture information, handling cases where models lack predefined architecture info. 
- """ - model_arch_info = [ - ArchitectureInfoUtils.get_architecture_info( - m.config(trust_remote_code=options.trust_remote_code) - ) - for m in merge_config.referenced_models() - ] - - if all(a is not None for a in model_arch_info): - if not options.allow_crimes and not all( - a == model_arch_info[0] for a in model_arch_info[1:] - ): - raise RuntimeError( - "Must specify --allow-crimes to attempt to mix different architectures" - ) - return model_arch_info[0] - else: - warnings.warn("Attempting Automatic Merge.") - model_arch_info = ArchitectureInfoUtils.infer_architecture_info(merge_config) - - return model_arch_info - - __all__ = ["MergeOptions", "run_merge"] diff --git a/mergekit/moe/deepseek.py b/mergekit/moe/deepseek.py index 9f8a4b1f..dba6b78e 100644 --- a/mergekit/moe/deepseek.py +++ b/mergekit/moe/deepseek.py @@ -10,7 +10,7 @@ import tqdm import transformers -from mergekit.architecture import ArchitectureInfoUtils +from mergekit.architecture import arch_info_for_config from mergekit.moe.arch import MoEOutputArchitecture from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig @@ -126,7 +126,7 @@ def write_model( loaders, base_loader, writer = initialize_io(config, out_path, merge_options) shared_loader = loaders.get(shared_def.source_model) if shared_def else None for weight_info in tqdm.tqdm( - ArchitectureInfoUtils.get_architecture_info(base_cfg).all_weights(base_cfg), + arch_info_for_config(base_cfg).all_weights(base_cfg), desc="Weights", ): tensor_name = weight_info.name diff --git a/mergekit/moe/mixtral.py b/mergekit/moe/mixtral.py index 5f0c7dfd..187e5f1e 100644 --- a/mergekit/moe/mixtral.py +++ b/mergekit/moe/mixtral.py @@ -8,7 +8,8 @@ import tqdm import transformers -from mergekit.architecture import MISTRAL_INFO, WeightInfo +from mergekit.architecture import WeightInfo +from mergekit.architecture.mixtral import MISTRAL_INFO from mergekit.moe.arch import MoEOutputArchitecture from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig diff --git a/mergekit/moe/qwen.py b/mergekit/moe/qwen.py index f5730a6c..46cc820c 100644 --- a/mergekit/moe/qwen.py +++ b/mergekit/moe/qwen.py @@ -12,12 +12,14 @@ # if the transformers version installed is too old from transformers.models.qwen2_moe import Qwen2MoeConfig -from mergekit.architecture import QWEN2_INFO +from mergekit.architecture.json_definitions import NAME_TO_ARCH from mergekit.moe.arch import MoEOutputArchitecture from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig from mergekit.options import MergeOptions +QWEN2_INFO = NAME_TO_ARCH["Qwen2ForCausalLM"][0] + class QwenMoE(MoEOutputArchitecture): def name(self) -> str: diff --git a/mergekit/multigpu_executor.py b/mergekit/multigpu_executor.py index 73c41e42..88aec975 100644 --- a/mergekit/multigpu_executor.py +++ b/mergekit/multigpu_executor.py @@ -50,7 +50,7 @@ def __init__( num_gpus: Number of GPUs to utilize (None = all available) storage_device: Device for storing tensors between stages """ - self.results = {} + self.results: Dict[Task, Any] = {} self.targets = set(tasks) self.storage_device = storage_device @@ -140,10 +140,10 @@ def update_progress(): ) for future in concurrent.futures.as_completed(futures): - if future.exception(): + if ex := future.exception(): self.done_event.set() executor.shutdown(wait=False) - raise future.exception() + raise ex self.done_event.set() 
progress_thread.join() @@ -237,7 +237,7 @@ def _assign_islands_to_gpus( islands = list(nx.weakly_connected_components(island_graph)) logger.info(f"Found {len(islands)} islands in parallel task graph") - assignments = {} + assignments: Dict[torch.device, List[Task]] = {} for island in islands: # Borrow orderings from original task list island_tasks = [t for t in tasks if t in island] diff --git a/mergekit/options.py b/mergekit/options.py index fb88f6a3..86ec4f9a 100644 --- a/mergekit/options.py +++ b/mergekit/options.py @@ -164,7 +164,7 @@ def wrapper(*args, **kwargs): class PrettyPrintHelp(click.Command): def format_options(self, ctx: Context, formatter: HelpFormatter) -> None: - categories = {None: []} + categories: dict[str, list[Parameter]] = {None: []} for param in ctx.command.params: if param.name in OPTION_CATEGORIES: category = OPTION_CATEGORIES[param.name] diff --git a/mergekit/plan.py b/mergekit/plan.py index 65e63bef..973bc69c 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -7,15 +7,17 @@ from mergekit import merge_methods from mergekit.architecture import ( - ArchitectureInfo, - ConfiguredArchitectureInfo, + ConfiguredModuleArchitecture, + ModelArchitecture, WeightInfo, ) +from mergekit.architecture.base import ConfiguredModelArchitecture from mergekit.common import ImmutableMap, ModelReference from mergekit.config import ( ConfigReader, InputSliceDefinition, MergeConfiguration, + OutputModuleDefinition, OutputSliceDefinition, ) from mergekit.graph import Task @@ -34,18 +36,18 @@ class MergePlanner: config: MergeConfiguration - arch_info: ArchitectureInfo + arch_info: ModelArchitecture options: MergeOptions out_model_config: Any _method: MergeMethod _tensors: List[Tuple[WeightInfo, Task]] - _current_layers: int = 0 + _current_module_layers: int = 0 _tokenizer_task: Optional[BuildTokenizer] = None def __init__( self, config: MergeConfiguration, - arch_info: ArchitectureInfo, + arch_info: ModelArchitecture, options: MergeOptions, out_model_config: Any, ): @@ -54,6 +56,7 @@ def __init__( self.options = options self.out_model_config = out_model_config self._method = merge_methods.get(config.merge_method) + self._tensors = [] token_cfg = {} tokenizer_source = config.tokenizer_source @@ -69,9 +72,17 @@ def __init__( add_tokens=tuple(token_cfg.keys()), ) + def _out_module_arch(self, module: str) -> ConfiguredModuleArchitecture: + module_def = self.arch_info.modules[module] + return ConfiguredModuleArchitecture( + info=module_def.architecture, + config=self.out_model_config, + weight_prefix=module_def.weight_prefix, + ) + @lru_cache - def model_arch_info(self, model: ModelReference): - return ConfiguredArchitectureInfo( + def _model_arch(self, model: ModelReference): + return ConfiguredModelArchitecture( info=self.arch_info, config=model.config(trust_remote_code=self.options.trust_remote_code), ) @@ -79,41 +90,70 @@ def model_arch_info(self, model: ModelReference): def normalize_config(self): base_model = self.config.base_model - # if models to merge are specified instead of output slices, compute them + # models -> modules.models if self.config.models: - if self.config.slices: - raise RuntimeError( - "Must specify either models to merge or output slices" + self.config.modules = {} + for module_name in self.arch_info.modules: + self.config.modules[module_name] = OutputModuleDefinition( + name=module_name, models=self.config.models ) + self.config.models = None - slices_in = [] - base_included = False - - for model_in in self.config.models: - if base_model and model_in.model == 
base_model: - base_included = True - - model_info = self.model_arch_info(model_in.model) - slices_in.append( - InputSliceDefinition( - layer_range=[0, model_info.num_layers()], - model=model_in.model, - parameters=model_in.parameters, - ) + # slices -> modules.slices + if self.config.slices: + if len(self.arch_info.modules) != 1: + raise RuntimeError( + "Model has multiple modules, must use modules: config syntax " + "to work with slices" ) + module_name = list(self.arch_info.modules.keys())[0] + self.config.modules = { + module_name: OutputModuleDefinition(slices=self.config.slices) + } + self.config.slices = None + + # modules.models -> modules.slices + for module_name in self.config.modules: + module_out = self.config.modules[module_name] + module_arch = self.arch_info.modules[module_name].architecture + + if module_out.models: + slices_in = [] + base_included = False + + for model_in in module_out.models: + if base_model and model_in.model == base_model: + base_included = True + + model_cfg = model_in.model.config( + trust_remote_code=self.options.trust_remote_code + ) + num_layers = module_arch.num_layers(model_cfg) + slices_in.append( + InputSliceDefinition( + layer_range=[0, num_layers], + model=model_in.model, + parameters=model_in.parameters, + ) + ) - if base_model and not base_included: - logging.info("Base model specified but not in input models - adding") - base_info = self.model_arch_info(base_model) - slices_in.append( - InputSliceDefinition( - layer_range=[0, base_info.num_layers()], - model=base_model, + if base_model and not base_included: + logging.info( + "Base model specified but not in input models - adding" + ) + base_cfg = base_model.config( + trust_remote_code=self.options.trust_remote_code + ) + num_layers = module_arch.num_layers(base_cfg) + slices_in.append( + InputSliceDefinition( + layer_range=[0, num_layers], + model=base_model, + ) ) - ) - self.config.slices = [OutputSliceDefinition(sources=slices_in)] - self.config.models = None + module_out.slices = [OutputSliceDefinition(sources=slices_in)] + module_out.models = None def plan_tensor( self, @@ -201,15 +241,16 @@ def plan_layer( layer_offset: int, t: float, cfg_reader: ConfigReader, + module_name: str, ): - weights_out: List[WeightInfo] = self.arch_info.layer_weights( - index=self._current_layers, - config=self.out_model_config, + module_arch = self._out_module_arch(module_name) + weights_out: List[WeightInfo] = module_arch.layer_weights( + index=self._current_module_layers, ) weights_in: List[List[WeightInfo]] = [ - self.model_arch_info(s.model).layer_weights( - index=s.layer_range[0] + layer_offset - ) + self._model_arch(s.model) + .get_module(module_name) + .layer_weights(index=s.layer_range[0] + layer_offset) for s in sources ] @@ -221,9 +262,14 @@ def plan_layer( cfg_reader=cfg_reader.with_t(t), ) - self._current_layers += 1 + self._current_module_layers += 1 - def plan_slice(self, definition: OutputSliceDefinition): + def plan_slice( + self, + definition: OutputSliceDefinition, + module_def: OutputModuleDefinition, + module_name: str, + ): slice_lengths = [ s.layer_range[1] - s.layer_range[0] for s in definition.sources ] @@ -233,7 +279,9 @@ def plan_slice(self, definition: OutputSliceDefinition): ) num_layers = slice_lengths[0] - cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0) + cfg_reader = ConfigReader( + config=self.config, slice_out=definition, t=0, module=module_def + ) for idx in range(num_layers): # compute t for interpolated gradients if num_layers > 1: @@ -246,6 
+294,40 @@ def plan_slice(self, definition: OutputSliceDefinition): layer_offset=idx, t=t, cfg_reader=cfg_reader, + module_name=module_name, + ) + + def plan_module(self, module_name: str, definition: OutputModuleDefinition): + self._current_module_layers = 0 + + module_arch = self._out_module_arch(module_name) + config_reader = ConfigReader(config=self.config, t=0, module=definition) + + for weight_info in module_arch.pre_weights(): + self.plan_tensor( + weight_info, + [weight_info] * len(definition.slices[0].sources), + [s.model for s in definition.slices[0].sources], + config_reader.for_tensor(tensor_name=weight_info.name).for_out_slice( + definition.slices[0] + ), + ) + + for out_slice in definition.slices: + self.plan_slice( + out_slice, + module_def=definition, + module_name=module_name, + ) + + for weight_info in module_arch.post_weights(): + self.plan_tensor( + weight_info, + [weight_info] * len(definition.slices[0].sources), + [s.model for s in definition.slices[-1].sources], + config_reader.for_tensor(tensor_name=weight_info.name).for_out_slice( + definition.slices[-1] + ), ) def plan_to_disk(self, out_path: str) -> List[Task]: @@ -292,31 +374,7 @@ def plan_in_memory(self) -> List[ReturnTensor]: def _plan(self): self.normalize_config() - self._tensors = [] - - for weight_info in self.arch_info.pre_weights(config=self.out_model_config): - self.plan_tensor( - weight_info, - [weight_info] * len(self.config.slices[0].sources), - [s.model for s in self.config.slices[0].sources], - ConfigReader( - config=self.config, - t=0, - tensor_name=weight_info.name, - ).for_out_slice(self.config.slices[0]), - ) - - for out_slice in self.config.slices: - self.plan_slice(out_slice) + self._tasks = [] - for weight_info in self.arch_info.post_weights(config=self.out_model_config): - self.plan_tensor( - weight_info, - [weight_info] * len(self.config.slices[-1].sources), - [s.model for s in self.config.slices[-1].sources], - ConfigReader( - config=self.config, - t=1, - tensor_name=weight_info.name, - ).for_out_slice(self.config.slices[-1]), - ) + for module_name in self.config.modules: + self.plan_module(module_name, self.config.modules[module_name]) diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py deleted file mode 100644 index 3834892d..00000000 --- a/mergekit/scripts/ABM/activations_based_merge.py +++ /dev/null @@ -1,171 +0,0 @@ -import logging -import os -from typing import Optional - -import click -import safetensors.torch -import torch -import tqdm -from transformers import AutoTokenizer - -from mergekit.architecture import ArchitectureInfoUtils -from mergekit.common import ModelReference, dtype_from_name -from mergekit.io.tasks import LoaderCache -from mergekit.io.tensor_writer import TensorWriter -from mergekit.options import MergeOptions, add_merge_options - - -@click.command("mergekit-activation-based-merge") -@click.argument("model_path", type=str) -@click.argument("secondary_model_path", type=str) -@click.argument("merge_unmerge_directory", type=str) -@click.option("--out-path", "-o", required=True, type=str, help="Output model path") -@click.option( - "--dtype", - type=str, - default="float16", - help="Data type to convert weights to", -) -@click.option( - "--device", - "-d", - type=str, - default="cuda", - help="Device to compute on (default: cuda)", -) -@add_merge_options -def main( - model_path: str, - secondary_model_path, - merge_unmerge_directory: str, - out_path: str, - dtype: Optional[str], - device: Optional[str], - 
merge_options: MergeOptions, -): - model = ModelReference.model_validate(model_path) - secondary_model = ModelReference.model_validate(secondary_model_path) - - dtype = dtype_from_name(dtype) if dtype else None - - cache = LoaderCache() - cache.lazy_unpickle = merge_options.lazy_unpickle - cache.hf_cache_dir = merge_options.transformers_cache - - for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"): - cache.get(m) - - writer = TensorWriter( - out_path=out_path, - max_shard_size=merge_options.out_shard_size, - safe_serialization=merge_options.safe_serialization, - ) - - model_config = model.config(trust_remote_code=merge_options.trust_remote_code) - model_arch_info = ArchitectureInfoUtils.get_architecture_info( - model.config(trust_remote_code=merge_options.trust_remote_code) - ) - - loader_1 = cache.get(model) - loader_2 = cache.get(secondary_model) - - os.makedirs(out_path, exist_ok=True) - - merge_unmerge_dictionary = {} - # load files from merge_unmerge_directory - spaces = [ - f.split("_unmerge")[0] - for f in os.listdir(merge_unmerge_directory) - if "_unmerge" in f - ] - for i in spaces: - logging.info(f"Loading merge/unmerge tensors for {i}") - m = safetensors.torch.load_file( - os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"), - device=device, - ) - u = safetensors.torch.load_file( - os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"), - device=device, - ) - merge_unmerge_dictionary[i] = ( - m[i].to(device, dtype=dtype), - u[i].to(device, dtype=dtype), - ) - - for weight_info in model_arch_info.all_weights(config=model_config): - merge_matrix, unmerge_matrix = None, None - - if weight_info.input_space in merge_unmerge_dictionary: - _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space] - unmerge_matrix = unmerge_matrix.chunk(2, dim=0) - - if weight_info.output_space in merge_unmerge_dictionary: - merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space] - merge_matrix = merge_matrix.chunk(2, dim=1) - - original_w = loader_1.get_tensor(weight_info.name, device=device) - original_w2 = loader_2.get_tensor(weight_info.name, device=device) - - if dtype is not None: - original_w = original_w.to(dtype=dtype) - original_w2 = original_w2.to(dtype=dtype) - - w = torch.clone(original_w) - w2 = torch.clone(original_w2) - - if not merge_matrix and not unmerge_matrix: - logging.warning( - f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix" - ) - - if merge_matrix is not None: - if weight_info.is_embed: - w = (merge_matrix[0] @ w.T).T - w2 = (merge_matrix[1] @ w2.T).T - else: - w = merge_matrix[0] @ w - w2 = merge_matrix[1] @ w2 - - if unmerge_matrix is not None: - w = w @ unmerge_matrix[0] - w2 = w2 @ unmerge_matrix[1] - - # check if weights have not mutated, if yes then shoot warning - if torch.allclose(original_w, w): - logging.warning( - f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge" - ) - else: - logging.warning( - f"✅ Weight {weight_info.name} for model 1 has mutated during merge" - ) - - if torch.allclose(original_w2, w2): - logging.warning( - f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge" - ) - else: - logging.warning( - f"✅ Weight {weight_info.name} for model 2 has mutated during merge" - ) - - # average weights and save them - if merge_matrix: - w = w + w2 - else: - w = (w + w2) / 2 - writer.save_tensor(weight_info.name, w) - writer.finalize() - - tokenizer = AutoTokenizer.from_pretrained(model_path) - 
tokenizer.save_pretrained(out_path, safe_serialization=True) - - # write config - model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code) - if dtype: - model_out_config.torch_dtype = dtype - model_out_config.save_pretrained(out_path) - - -main() diff --git a/mergekit/scripts/ABM/extract_activations.py b/mergekit/scripts/ABM/extract_activations.py deleted file mode 100644 index 3f7c151b..00000000 --- a/mergekit/scripts/ABM/extract_activations.py +++ /dev/null @@ -1,347 +0,0 @@ -import logging -import os -from collections import defaultdict -from typing import List, Optional - -import click -import datasets -import numpy as np -import torch -from safetensors.torch import save_file -from torch.utils.data import DataLoader -from transformers import AutoModel, AutoTokenizer, DefaultDataCollator - -from mergekit.architecture import ArchitectureInfoUtils, _template_substitution -from mergekit.common import ModelReference - -logging.basicConfig(level=logging.INFO) - -# set seed -torch.manual_seed(42) -np.random.seed(42) - - -def clean_name(name): - return name.replace(".weight", "").replace("model.", "") - - -def parse_items(ctx, param, value): - if value is not None: - return [item.strip() for item in value.split(",")] - - -def remove_pads(attention_mask, feature_vector): - if ( - len(feature_vector.shape) == 3 - ): # Hidden states: (batch_size, seq_length, embedding_dim) - # Expand mask to match the feature_vector dimensions and apply it - expanded_mask = attention_mask.unsqueeze(-1) - filtered_feature_vector = feature_vector * expanded_mask - else: - raise ValueError("Unsupported feature vector shape.") - - return filtered_feature_vector - - -def get_attention_output_hook(storage_dict, space_name, capture_input=True): - """ - Returns a hook function that stores the output of the attention layer. - """ - - def hook(module, input, output): - # NOTE: shape of input is [batch, seq_len, dim] and output is Tuple[(seq_len, dim),...] - if capture_input: - o = input[0].detach() - else: - o = output.detach() - - if space_name not in storage_dict: - storage_dict[space_name] = o - else: - storage_dict[space_name] = torch.cat((storage_dict[space_name], o), dim=0) - - return hook - - -""" - -What this script does: - -It tries to map input/output spaces to activation maps - -""" - - -@click.command("mergekit-abm-extract-activations") -@click.argument("model-path", type=str) -@click.option( - "--dataset", "-d", required=True, type=str, help="Dataset to use for activations" -) -@click.option("--out-path", "-o", required=True, type=str, help="Output model path") -@click.option("--batch-size", "-b", type=int, default=2, help="Batch size") -@click.option( - "--dataset-size", - "-s", - type=int, - default=None, - help="Dataset size. If None, use full dataset", -) -@click.option( - "--dataset-column", "-c", type=str, default="text", help="Dataset column to use" -) -@click.option( - "--dataset-subset", "-u", type=str, default="eval", help="Dataset subset to use" -) -@click.option( - "--chat-template/--no-chat-template", - default=False, - help="use Chat template for inference", -) -@click.option("--max-length", "-l", type=int, default=512, help="Max length") -@click.option("--dtype", type=str, default=None, help="Data type to convert weights to") -@click.option( - "--device", type=str, default=None, help="device to compute the activations" -) -@click.option( - "--ignore-spaces", - "-i", - type=str, - default="", - callback=parse_items, - help="Spaces to ignore separated by comma. 
Example: up_${layer_index}", -) -def main( - model_path: str, - dataset: str, - dataset_column: str, - out_path: str, - batch_size: int, - max_length: int, - dataset_size: Optional[int], - dataset_subset: Optional[str], - chat_template: Optional[bool], - dtype: Optional[str], - device: Optional[str], - ignore_spaces: Optional[List[str]], -): - # sorting out locations to hook into - # we do this via the predefined json architecture definitions in mergekit - - model = ModelReference.model_validate(model_path) - - model_config = model.config() - model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config) - - _json = model_arch_info.definition - - residual_space = None - - weights = [] - for weight in _json.layer_templates.weights: - if weight.is_kq: - residual_space = weight.input_space - weights.append(weight) - - if residual_space is None: - raise ValueError("No residual space found") - - # ======================== Mapping spaces to weights ======================== - - # just a list of connected components - space_to_output_weight_templates = defaultdict(list) - space_to_input_weight_templates = defaultdict(list) - - for layer_template in weights: - if ( - not layer_template.input_space - or layer_template.input_space in ignore_spaces - ): - continue - space_to_output_weight_templates[layer_template.input_space].append( - layer_template.name - ) - - for layer_template in weights: - if ( - not layer_template.output_space - or layer_template.output_space in ignore_spaces - ): - continue - space_to_input_weight_templates[layer_template.output_space].append( - layer_template.name - ) - - # remove the residual space from the input and output - space_to_input_weight_templates.pop(residual_space, None) - space_to_output_weight_templates.pop(residual_space, None) - - # NOTE: if space has input and output weights, remove one or the other because hooking - # into both will result in duplicate activations - to_remove = [] - for space, input_weights in space_to_input_weight_templates.items(): - if space in space_to_output_weight_templates: - # if count of input weights and output weights is non zero, remove the space from space to output_weights - if ( - len(input_weights) > 0 - and len(space_to_output_weight_templates[space]) > 0 - ): - to_remove.append(space) - - # remove keys from output - space_to_output_weight_templates = { - k: v for k, v in space_to_output_weight_templates.items() if k not in to_remove - } - - num_layers = model_arch_info.num_layers(model_config) - - space_to_input_weights = {} - for k, v in space_to_input_weight_templates.items(): - for j in range(num_layers): - f = lambda x: _template_substitution(x, num_layers=num_layers, layer_idx=j) - space_to_input_weights[f(k)] = [f(_v) for _v in v] - - space_to_output_weights = {} - for k, v in space_to_output_weight_templates.items(): - for j in range(num_layers): - f = lambda x: _template_substitution(x, num_layers=num_layers, layer_idx=j) - space_to_output_weights[f(k)] = [f(_v) for _v in v] - - # ================== Load model, tokenizer for inference and prepare dataset ================== - - model = AutoModel.from_pretrained( - model_path, output_attentions=True, attn_implementation="eager" - ) - tokenizer = AutoTokenizer.from_pretrained(model_path) - - if not tokenizer.pad_token: - tokenizer.pad_token = tokenizer.eos_token - - tokenize_function = None - if chat_template: - logging.info("Using chat template for inference") - tokenize_function = lambda x: tokenizer.apply_chat_template( - x, - padding="longest", - 
max_length=max_length, - truncation=True, - return_dict=True, - ) - else: - logging.info("Using default tokenizer (no chat template) for inference") - tokenize_function = lambda x: tokenizer( - x, - padding="longest", - max_length=max_length, - truncation=True, - ) - - model.eval() - model.to(device) - if dtype is not None: - model = model.to(dtype=dtype) - - dataset = datasets.load_dataset(dataset)[dataset_subset] - - if dataset_size is not None: - logging.info("Using dataset size %s", dataset_size) - dataset = dataset.select(range(dataset_size)) - - def tokenize(element): - outputs = tokenize_function(element[dataset_column]) - return { - "input_ids": outputs["input_ids"], - "attention_mask": outputs["attention_mask"], - } - - dataset = dataset.map(tokenize).select_columns(["input_ids", "attention_mask"]) - - datasets_dataloader = DataLoader( - dataset, batch_size=batch_size, shuffle=False, collate_fn=DefaultDataCollator() - ) - - feature_storage = {} - storage_dict = {} - - # ================== Hooking into the model ================== - - # NOTE: if the capture input set to True seems confusing, a space's output is a weight recieving input from the space - for k, v in space_to_output_weights.items(): - for weight in v: - weight = clean_name(weight) - model.get_submodule(weight).register_forward_hook( - get_attention_output_hook(feature_storage, k, capture_input=True) - ) - for k, v in space_to_input_weights.items(): - for weight in v: - weight = clean_name(weight) - model.get_submodule(weight).register_forward_hook( - get_attention_output_hook(feature_storage, k, capture_input=False) - ) - - # ================== Inference ================== - - for batch in datasets_dataloader: - with torch.no_grad(): - inputs = {k: v.to(device) for k, v in batch.items()} - outputs = model( - **inputs, output_hidden_states=True, output_attentions=False - ) - - # NOTE: https://huggingface.co/docs/transformers/en/main_classes/output#transformers.modeling_outputs.BaseModelOutput - - # Store attention masks - attention_mask = inputs["attention_mask"] - if "attention_mask" not in feature_storage: - feature_storage["attention_mask"] = attention_mask.cpu().detach() - else: - feature_storage["attention_mask"] = torch.cat( - (feature_storage["attention_mask"], attention_mask.cpu().detach()), - dim=0, - ) - - hidden_states = [ - remove_pads(attention_mask, hidden_state) - for hidden_state in outputs.hidden_states - ] - hidden_states = torch.stack(outputs.hidden_states, dim=1) - - if residual_space not in feature_storage: - feature_storage[residual_space] = hidden_states - else: - feature_storage[residual_space] = torch.cat( - (feature_storage[residual_space], hidden_states), dim=0 - ) - - for space_name, v in storage_dict.items(): - if space_name not in feature_storage: - feature_storage[space_name] = v - else: - feature_storage[space_name] = torch.cat( - (feature_storage[space_name], v), dim=0 - ) - - storage_dict = {} - - # ================== Save activations/features ================== - - logging.info("Feature storage:") - for k, v in feature_storage.items(): - if v is not None: - logging.info(f"{k}: Shape: {v.shape}") - - abs_path = os.path.abspath(model_path) - if os.path.exists(abs_path): - model_path = abs_path - - model_path = model_path.replace("/", "_") - - # create output directory - os.makedirs(out_path, exist_ok=True) - - save_file( - feature_storage, os.path.join(out_path, f"{model_path}_features.safetensor") - ) - - -if __name__ == "__main__": - main() diff --git 
a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py deleted file mode 100644 index 4c862664..00000000 --- a/mergekit/scripts/ABM/extract_permutation_matrices.py +++ /dev/null @@ -1,226 +0,0 @@ -import os -import sys -from collections import defaultdict - -import click -import numpy as np -import safetensors.torch -import scipy -import torch - -from mergekit.architecture import ArchitectureInfoUtils, _template_substitution -from mergekit.common import ModelReference - - -def calc_correlation_matrix(feats): - feats = feats.view(-1, feats.shape[-1]) - - return torch.corrcoef(feats.T) - - -def match_tensors_permute( - absval=False, - correlation_matrix=None, -): - """ - This function is adapted from ZipIt! (https://github.com/gstoica27/ZipIt) - """ - - Om = correlation_matrix.shape[0] // 2 - device = correlation_matrix.device - - mats = [torch.eye(Om, device=device)] - - corr_submatrix = correlation_matrix[:Om, Om:].cpu().numpy() - if absval: - corr_submatrix = np.absolute(corr_submatrix) - _, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True) - - new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)] - mats.append(new_mat.T) - - unmerge_mats = mats - - unmerge = torch.cat(unmerge_mats, dim=0) - - merge = torch.cat(mats, dim=0) - merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) - - return merge.T, unmerge - - -def match_tensors_permute_MHA( - n_heads=32, - absval=False, - correlation_matrix=None, -): - """ - Handles different head permutations in attention. - Modified version of the function here: https://github.com/nverma1/merging-text-transformers/blob/main/matching_functions.py#L76 - """ - - Om = correlation_matrix.shape[0] // 2 - device = correlation_matrix.device - query_size = Om // n_heads - - mats = [torch.eye(Om, device=device)] - head_perms = [] - - costs = np.ones((n_heads, n_heads)) * -sys.maxsize - - col_inds_storage = defaultdict(lambda: defaultdict(int)) - - for j in range(n_heads): - for k in range(n_heads): - head1_idx = [query_size * j, query_size * (j + 1)] - head2_idx = [query_size * k, query_size * (k + 1)] - - corr_submatrix = ( - correlation_matrix[ - head1_idx[0] : head1_idx[1], - (Om + head2_idx[0]) : (Om + head2_idx[1]), - ] - .cpu() - .numpy() - ) - if absval: - corr_submatrix = np.absolute(corr_submatrix) - - # compute perm for head j & head k - row_ind, col_ind = scipy.optimize.linear_sum_assignment( - corr_submatrix, maximize=True - ) - - costs[j, k] = corr_submatrix[row_ind, col_ind].sum() - - col_inds_storage[j][k] = col_ind - - outer_row_ind, outer_col_ind = scipy.optimize.linear_sum_assignment( - costs, maximize=True - ) - - for j in range(n_heads): - head_1 = outer_row_ind[j] - head_2 = outer_col_ind[j] - - head_perm = col_inds_storage[head_1][head_2] - head_perms.append(torch.tensor(head_perm + query_size * head_2)) - - new_mat = torch.eye(Om, device=device)[ - torch.cat(head_perms).clone().detach().long().to(device) - ] - mats.append(new_mat.T) - - unmerge_mats = mats - - unmerge = torch.cat(unmerge_mats, dim=0) - merge = torch.cat(mats, dim=0) - merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5) - - return merge.T, unmerge - - -@click.command("mergekit-abm-extract-permutations") -@click.argument("model1-ft", type=str, required=True) -@click.argument("model2-ft", type=str, required=True) -@click.option("--model_path", type=str, required=True, help="Model information") -@click.option( - "--out_path", required=True, type=str, help="Output 
path for metric tensors" -) -@click.option( - "--absval/--no-absval", - required=False, - default=False, - help="Use absolute value on correlation matrices/submatrices while calculating merge/unmerge matrices", -) -@click.option( - "--device", - "-d", - type=str, - default="cpu", - help="Device to compute on (default: cpu)", -) -def main(model1_ft, model2_ft, model_path, out_path, absval, device): - os.makedirs(out_path, exist_ok=True) - - model = ModelReference.model_validate(model_path) - - model_config = model.config() - - model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config) - - _json = model_arch_info.definition - - residual_space = None - kq_space = None - v_space = None - - # extract the residual, attention related spaces - for weight in _json.layer_templates.weights: - if weight.is_kq: - kq_space = weight.output_space - residual_space = weight.input_space - continue - - # assuming order is observed - if ( - not weight.is_kq - and weight.head_split - and (weight.input_space == residual_space) - ): - v_space = weight.output_space - continue - - num_layers = model_arch_info.num_layers(model_config) - - kq_spaces = [] - v_spaces = [] - for j in range(num_layers): - kq_spaces.append( - _template_substitution(kq_space, num_layers=num_layers, layer_idx=j) - ) - v_spaces.append( - _template_substitution(v_space, num_layers=num_layers, layer_idx=j) - ) - - model1_features = safetensors.torch.load_file(model1_ft, device=device) - model2_features = safetensors.torch.load_file(model2_ft, device=device) - - model1_features.pop("attention_mask") - model2_features.pop("attention_mask") - - for feature_space in model1_features.keys(): - concatenated_feature = torch.cat( - (model1_features[feature_space], model2_features[feature_space]), dim=-1 - ) - - correlation_matrix = calc_correlation_matrix(concatenated_feature) - - if feature_space in (kq_spaces + v_spaces): - merge, unmerge = match_tensors_permute_MHA( - correlation_matrix=correlation_matrix, - n_heads=model_config.num_attention_heads, - absval=absval, - ) - - else: - merge, unmerge = match_tensors_permute( - correlation_matrix=correlation_matrix, - absval=absval, - ) - - safetensors.torch.save_file( - {feature_space: merge.contiguous()}, - f"{out_path}/{feature_space}_merge.safetensor", - ) - - safetensors.torch.save_file( - {feature_space: unmerge.contiguous()}, - f"{out_path}/{feature_space}_unmerge.safetensor", - ) - - del merge, unmerge, correlation_matrix, concatenated_feature - - -if __name__ == "__main__": - main() diff --git a/mergekit/scripts/evolve.py b/mergekit/scripts/evolve.py index 9ad35eb3..7dafc0bd 100644 --- a/mergekit/scripts/evolve.py +++ b/mergekit/scripts/evolve.py @@ -1,6 +1,7 @@ # Copyright (C) 2025 Arcee AI # SPDX-License-Identifier: BUSL-1.1 +import importlib.util import logging import os import time @@ -126,7 +127,7 @@ def main( vllm: bool, strategy: str, in_memory: bool, - storage_path: Optional[str], + storage_path: str, num_gpus: Optional[int], merge_cuda: bool, trust_remote_code: bool, @@ -160,9 +161,7 @@ def main( raise ValueError("Cannot use vLLM with 4-bit or 8-bit models") if in_memory: raise ValueError("Cannot use in-memory mode with 4-bit or 8-bit models") - try: - import bitsandbytes - except ImportError: + if not importlib.util.find_spec("bitsandbytes"): raise RuntimeError("bitsandbytes is not installed") bnb_config = transformers.BitsAndBytesConfig( @@ -271,7 +270,7 @@ def progress_callback(es: cma.CMAEvolutionStrategy): nonlocal xbest, xbest_cost res = es.result - if use_wandb: 
+ if use_wandb and run is not None: best_params = genome.genotype_to_param_arrays(res.xbest) mean_params = genome.genotype_to_param_arrays(res.xfavorite) run.log( @@ -377,7 +376,10 @@ def parallel_evaluate(x: List[np.ndarray]) -> List[float]: def _reshard_model( - model: ModelReference, storage_path: str, merge_cache: str, trust_remote_code: bool + model: ModelReference, + storage_path: str, + merge_cache: Optional[str], + trust_remote_code: bool, ) -> ModelReference: merged = model.merged( cache_dir=merge_cache, diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py index 023b10fe..53b055a4 100644 --- a/mergekit/scripts/extract_lora.py +++ b/mergekit/scripts/extract_lora.py @@ -12,12 +12,12 @@ import torch import torch.nn as nn import tqdm +import transformers from pydantic import BaseModel -from transformers import AutoModelForCausalLM -from mergekit.architecture import ArchitectureInfoUtils, WeightInfo +from mergekit.architecture import WeightInfo, arch_info_for_config from mergekit.card import generate_card_lora -from mergekit.common import ModelReference +from mergekit.common import ModelReference, get_auto_cls from mergekit.graph import Executor, Task from mergekit.io.tasks import FinalizeModel, LoadTensor, SaveTensor, TensorWriterTask from mergekit.io.tensor_writer import TensorWriter @@ -323,6 +323,20 @@ def _wi_load(model_ref: ModelReference, weight_info: WeightInfo) -> LoadTensor: ) +def _make_dummy_model( + model_ref: ModelReference, trust_remote_code: bool = False +) -> transformers.PreTrainedModel: + model_cfg = transformers.AutoConfig.from_pretrained( + model_ref.model.path, + revision=model_ref.model.revision, + trust_remote_code=trust_remote_code, + ) + auto_cls = get_auto_cls(model_cfg.architectures[0]) + with torch.device("meta"): + res = auto_cls.from_config(model_cfg, trust_remote_code=trust_remote_code) + return res + + class PlanResults(BaseModel): tasks: List[Task] base_vocab_size: int @@ -352,20 +366,8 @@ def plan_extraction( ) name_to_wi = all_weights_map(model_ref, options) - dummy_model = AutoModelForCausalLM.from_pretrained( - model_ref.model.path, - revision=model_ref.model.revision, - trust_remote_code=options.trust_remote_code, - device_map="meta", - state_dict={}, - ) - dummy_base = AutoModelForCausalLM.from_pretrained( - base_model_ref.model.path, - revision=base_model_ref.model.revision, - trust_remote_code=options.trust_remote_code, - device_map="meta", - state_dict={}, - ) + dummy_base = _make_dummy_model(base_model_ref, options.trust_remote_code) + dummy_model = _make_dummy_model(model_ref, options.trust_remote_code) embed_in = dummy_model.get_input_embeddings() embed_out = dummy_model.get_output_embeddings() @@ -378,6 +380,7 @@ def plan_extraction( ) logger.warning("Enforcing embeddings in modules_to_save, embed_lora=False") embed_lora = False + del dummy_base warned_modules = set() @@ -553,7 +556,7 @@ def all_weights_map( ) -> Dict[str, WeightInfo]: name_to_wi = {} model_cfg = model_ref.config(trust_remote_code=options.trust_remote_code) - arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg) + arch_info = arch_info_for_config(model_cfg) for wi in arch_info.all_weights(model_cfg): name_to_wi[wi.name] = wi return name_to_wi diff --git a/mergekit/scripts/fill_missing_params.py b/mergekit/scripts/fill_missing_params.py index 81aec1b3..e8bc6d4d 100644 --- a/mergekit/scripts/fill_missing_params.py +++ b/mergekit/scripts/fill_missing_params.py @@ -3,13 +3,14 @@ import logging import shutil from pathlib import Path 
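# Illustrative sketch, not part of the diff: _make_dummy_model in
# mergekit/scripts/extract_lora.py above instantiates the model on the "meta" device,
# so the module tree (embedding modules, parameter names, shapes) can be inspected
# without allocating or loading any real weights. A standalone version of the same
# idea; the "gpt2" checkpoint is only a placeholder chosen for the example:
import torch
import transformers

config = transformers.AutoConfig.from_pretrained("gpt2")
with torch.device("meta"):
    dummy = transformers.AutoModelForCausalLM.from_config(config)

embed_in = dummy.get_input_embeddings()  # module exists, but its weight is a meta tensor
print(embed_in.weight.device)            # -> meta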
+from typing import List, Optional, Tuple

 import click
 import torch
+from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from tqdm import tqdm

-from mergekit.architecture import ParameterNamesUtils
 from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
 from mergekit.io.tensor_writer import TensorWriter
@@ -197,3 +198,199 @@ def main(

 if __name__ == "__main__":
     main()
+
+
+class ParameterNamesUtils:
+    """Utility functions for handling parameter names."""
+
+    @staticmethod
+    def resolve_model_directory(repo_id: str) -> Path:
+        """Resolve the model directory (local or Hugging Face Hub)."""
+        if Path(repo_id).is_dir():
+            return Path(repo_id)
+
+        return Path(snapshot_download(repo_id))
+
+    @staticmethod
+    def get_model_parameter_names(repo_id: str) -> List[str]:
+        """Get parameter names of a model from a Hugging Face repo or local directory."""
+        model_dir = ParameterNamesUtils.resolve_model_directory(repo_id)
+        return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())
+
+    @staticmethod
+    def strip_prefix(name: str, prefix: str) -> str:
+        """Remove a single prefix from the start of a name."""
+        if prefix != "" and name.startswith(prefix + "."):
+            return name[len(prefix) + 1 :]
+        return name
+
+    @staticmethod
+    def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]:
+        """
+        Find a prefix in list1 that, after removal, makes list2 an ordered sublist.
+        """
+        assert len(list1) >= len(list2), "params name list1 can't be shorter than list2"
+
+        possible_prefixes = {item.split(".")[0] for item in list1 if "." in item}
+        possible_prefixes = [""] + list(possible_prefixes)
+
+        prefix_matches = {}
+        best_prefix = ""  # Default to no prefix
+        for prefix in possible_prefixes:
+            stripped_list1 = [
+                ParameterNamesUtils.strip_prefix(item, prefix) for item in list1
+            ]
+            prefix_matches[prefix] = len(
+                [item for item in list2 if item in stripped_list1]
+            )
+
+        if max(prefix_matches.values()) > prefix_matches[""]:
+            best_prefix = max(prefix_matches, key=prefix_matches.get)
+
+        return best_prefix
+
+    @staticmethod
+    def find_common_ordered_names(
+        param_names: List[List[str]], prefixes: List[str]
+    ) -> List[str]:
+        """Identify and return common parameter names across all models, ensuring correct order. Also account for prefix."""
+        common_names = set(param_names[0])
+        for i in range(1, len(param_names)):
+            prefix = f"{prefixes[i]}." if prefixes[i] else ""
+            common_names.intersection_update({prefix + name for name in param_names[i]})
+        return [name for name in param_names[0] if name in common_names]
+
+    @staticmethod
+    def remove_size_conflicts(common_names, referenced_models, prefixes):
+        model_dirs = [
+            ParameterNamesUtils.resolve_model_directory(m.model.path)
+            for m in referenced_models
+        ]
+        model_indices = [ShardedTensorIndex.from_disk(str(dir)) for dir in model_dirs]
+
+        common_name_and_shape = common_names.copy()
+        removed_names = []
+
+        for name in common_names:
+            base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0])
+
+            for i in range(1, len(referenced_models)):
+                other_name = name
+                prefix = f"{prefixes[i]}." if prefixes[i] else ""
+                if name.startswith(prefix) and prefix != "":
+                    other_name = name[len(prefix) :]
+                shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i])
+
+                if base_shape != shape:
+                    common_name_and_shape.remove(name)
+                    removed_names.append((name, base_shape, shape, i))
+                    break
+
+        size_mismatch_count = len(removed_names)
+        if size_mismatch_count > 0:
+            logging.warning(
+                f"Size mismatch detected for {size_mismatch_count}/{size_mismatch_count + len(common_names)} tensors. "
+                "These names were removed from the merge list."
+            )
+            logging.info(
+                "The following tensors have different shapes across models and were removed from the merge list:"
+            )
+            for name, base_shape, shape, i in removed_names:
+                logging.info(
+                    f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}"
+                )
+
+        return common_name_and_shape
+
+    @staticmethod
+    def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool:
+        """
+        Check if common elements of list2 maintain their relative order in list1.
+        """
+        common_params = set(list1).intersection(set(list2))
+        last_index = -1
+
+        for param in list2:
+            if param in common_params:
+                current_index = list1.index(param)
+                if current_index < last_index:
+                    return False
+                last_index = current_index
+        return True
+
+    @staticmethod
+    def ordered_sublist(list1: List[str], list2: List[str]) -> bool:
+        """
+        Check if list2 is a contiguous ordered sublist of list1.
+        """
+        n, m = len(list1), len(list2)
+
+        for i in range(n - m + 1):
+            if list1[i : i + m] == list2:
+                return True
+        return False
+
+    @staticmethod
+    def report_names_similarity(
+        base_names: List[str], other_names: List[str]
+    ) -> Tuple[Optional[str], str]:
+        """
+        Analyze similarity between parameter names of two models and identify shared prefixes.
+
+        Returns:
+            best_prefix (str): Best matching prefix for parameter names.
+            case_message (str): Explanation of the structural relationship.
+        """
+        possible_prefixes = {""}
+        possible_prefixes.update(
+            {item.split(".")[0] for item in base_names if "." in item}
+        )
+
+        prefixes_subset_overlap = {}
+        best_prefix = None
+        case_message = "No common parameter names found for any prefix"
+
+        for prefix in possible_prefixes:
+            base_names_stripped = [
+                ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names
+            ]
+
+            if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names):
+                return prefix, "All params in model have exact match in base model."
+
+            intersection = set(base_names_stripped).intersection(set(other_names))
+            prefixes_subset_overlap[prefix] = intersection
+
+        if prefixes_subset_overlap:
+            best_prefix = max(
+                prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x])
+            )
+            base_names_stripped = [
+                ParameterNamesUtils.strip_prefix(name, best_prefix)
+                for name in base_names
+            ]
+
+            overlap = len(prefixes_subset_overlap[best_prefix])
+            ordered = ParameterNamesUtils.are_common_params_ordered(
+                base_names_stripped, other_names
+            )
+            mismatched = [
+                item for item in other_names if item not in base_names_stripped
+            ]
+            mismatched = "\n ".join(mismatched)
+            case_message = (
+                f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) "
+                f"of model parameters are in the base model. \n"
+                f" Name ordering is {'preserved' if ordered else 'not preserved'}.\n"
+                f" Missing parameters:\n {mismatched}"
+            )
+
+        return best_prefix, case_message
+
+    @staticmethod
+    def tensor_shape(name, index) -> Tuple[int]:
+        from safetensors import safe_open
+
+        with safe_open(
+            Path(index.base_path) / index.tensor_paths[name], framework="pt"
+        ) as f:
+            return f.get_slice(name).get_shape()
diff --git a/mergekit/scripts/layershuffle.py b/mergekit/scripts/layershuffle.py
index 267e397c..b93c8bd5 100644
--- a/mergekit/scripts/layershuffle.py
+++ b/mergekit/scripts/layershuffle.py
@@ -7,7 +7,7 @@
 import click
 import yaml

-from mergekit.architecture import ArchitectureInfoUtils
+from mergekit.architecture import arch_info_for_config
 from mergekit.common import ModelReference
 from mergekit.config import (
     InputSliceDefinition,
@@ -64,7 +64,7 @@ def main(
     models = [ModelReference.parse(m) for m in model]

     m0_cfg = models[0].config()
-    arch_info = ArchitectureInfoUtils.get_architecture_info(m0_cfg)
+    arch_info = arch_info_for_config(m0_cfg)
     total_num_layers = arch_info.num_layers(m0_cfg)

     out_slices: List[OutputSliceDefinition] = []
diff --git a/mergekit/scripts/moe.py b/mergekit/scripts/moe.py
index 87eef5d0..b0c27594 100644
--- a/mergekit/scripts/moe.py
+++ b/mergekit/scripts/moe.py
@@ -163,9 +163,6 @@ def select_output_arch(
     help="Device to use to compute embeddings",
     show_default=True,
 )
-@click.option(
-    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
-)
 @click.option(
     "--i-understand-this-is-not-useful-without-training",
     type=bool,
@@ -180,7 +177,6 @@ def main(
     load_in_4bit: bool,
     load_in_8bit: bool,
     device: str,
-    verbose: bool,
     i_understand_this_is_not_useful_without_training: bool,
     merge_options: MergeOptions,
 ):
@@ -204,7 +200,7 @@ def main(
         load_in_8bit=load_in_8bit,
         device=device,
         allow_all_same=i_understand_this_is_not_useful_without_training,
-        verbose=verbose,
+        verbose=merge_options.verbose,
     )

     if merge_options.write_model_card:
diff --git a/mergekit/scripts/tokensurgeon.py b/mergekit/scripts/tokensurgeon.py
index 406e7aa6..98b9780c 100644
--- a/mergekit/scripts/tokensurgeon.py
+++ b/mergekit/scripts/tokensurgeon.py
@@ -13,9 +13,9 @@
 from typing_extensions import TypeAlias

 from mergekit.architecture import (
-    ArchitectureInfoUtils,
-    ConfiguredArchitectureInfo,
+    ConfiguredModelArchitecture,
     WeightInfo,
+    arch_info_for_config,
 )
 from mergekit.common import ModelReference
 from mergekit.io import TensorWriter
@@ -132,6 +132,7 @@ def main(
         barycentric=barycentric,
         cosine_similarity=cosine_similarity,
         name=embed_info.name,
+        log_reconstruction_error=verbosity > 0,
     )

     if lm_head_info:
@@ -269,21 +270,24 @@ def get_embedding_info(
 ) -> Tuple[WeightInfo, WeightInfo]:
     """Get WeightInfo for the input and output embeddings of a model."""
     cfg = model.config(trust_remote_code=options.trust_remote_code)
-    arch_info = ArchitectureInfoUtils.get_architecture_info(cfg)
+    arch_info = arch_info_for_config(cfg)
+
+    if len(arch_info.modules) != 1:
+        raise RuntimeError("Model has multiple modules - not supported by tokensurgeon")
+    module_def = next(iter(arch_info.modules.values()))

     embed, lm_head = None, None
-    for weight_info in arch_info.pre_weights(cfg):
+    for weight_info in module_def.architecture.pre_weights(cfg):
         if weight_info.is_embed:
             if embed is not None:
                 raise RuntimeError("Multiple input embeddings found")
             embed = weight_info
-    for weight_info in arch_info.post_weights(cfg):
+    for weight_info in module_def.architecture.post_weights(cfg):
         if weight_info.is_embed:
             if lm_head is not None:
                 raise RuntimeError("Multiple output embeddings found")
             lm_head = weight_info
-
     return embed, lm_head
@@ -466,12 +470,14 @@ def get_embeddings(

         if log_reconstruction_error:
             # compute reconstruction error in donor_embed space
-            knn_reconstruction_error.append(
-                torch.nn.functional.mse_loss(
-                    (knn_embeddings.T.to(weights.dtype) @ weights).squeeze(),
-                    token_embedding,
-                ).item()
+            reconstructed = (
+                (knn_embeddings.T.to(weights.dtype) @ weights)
+                .squeeze()
+                .to(token_embedding.dtype)
             )
+            diff = token_embedding - reconstructed
+            mse = diff.square().mean().item()
+            knn_reconstruction_error.append(mse)

         # Reconstruct the embedding in original_embed space
         res[idx_1] = (e_c_0[indices].T @ weights).squeeze()
@@ -576,7 +582,7 @@ def load_tokenizer(

 def validate_architecture(
     model: ModelReference, donor: ModelReference, options: MergeOptions
-) -> Tuple[ConfiguredArchitectureInfo, transformers.PretrainedConfig]:
+) -> Tuple[ConfiguredModelArchitecture, transformers.PretrainedConfig]:
     """
     Validate that the architectures of two models match.
@@ -584,15 +590,18 @@ def validate_architecture(
     model_cfg = model.config(trust_remote_code=options.trust_remote_code)
     donor_cfg = donor.config(trust_remote_code=options.trust_remote_code)
-    model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg)
-    donor_arch_info = ArchitectureInfoUtils.get_architecture_info(donor_cfg)
+    model_arch_info = arch_info_for_config(model_cfg)
+    donor_arch_info = arch_info_for_config(donor_cfg)
     if donor_arch_info != model_arch_info:
         report_issue(
-            f"Model architectures do not match: {model_arch_info.name()} vs {donor_arch_info.name()}",
+            f"Model architectures do not match: {model_arch_info.expected_model_type} vs {donor_arch_info.expected_model_type}",
             error=not options.allow_crimes,
         )
-    return ConfiguredArchitectureInfo(info=model_arch_info, config=model_cfg), donor_cfg
+    return (
+        ConfiguredModelArchitecture(info=model_arch_info, config=model_cfg),
+        donor_cfg,
+    )


 if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index a8f11abc..db52c1f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "mergekit"
 description = "Tools for merging pre-trained large language models"
 readme = "README.md"
 license = { text = "BUSL-1.1" }
-version = "0.1.1"
+version = "0.1.2"
 authors = [{ name = "Charles Goddard", email = "chargoddard@gmail.com" }]
 requires-python = ">=3.10"
 dependencies = [
@@ -60,6 +60,7 @@ packages = [
     "mergekit.scripts",
     "mergekit.evo",
     "mergekit.tokenizer",
+    "mergekit.architecture",
     "mergekit._data",
     "mergekit._data.architectures",
     "mergekit._data.chat_templates",
diff --git a/tests/common.py b/tests/common.py
index 54068c54..9b7ceb9c 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -13,7 +13,10 @@
     LlavaForConditionalGeneration,
 )

-from mergekit.architecture import ArchitectureInfoUtils
+from mergekit.architecture import (
+    arch_info_for_config,
+    get_architecture_info,
+)
 from mergekit.config import MergeConfiguration
 from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
 from mergekit.merge import MergeOptions, run_merge
@@ -53,9 +56,9 @@ def run_and_check_merge(
     if check_tensors:
         model_config = AutoConfig.from_pretrained(tmpdir)
         if auto_arch:
-            arch_info = ArchitectureInfoUtils.infer_architecture_info(config)
+            arch_info = get_architecture_info(config, MergeOptions())
         else:
-            arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)
+            arch_info = arch_info_for_config(model_config)
         index = ShardedTensorIndex.from_disk(tmpdir)
         for weight_info in arch_info.all_weights(model_config):
diff --git a/tests/test_basic_merges.py b/tests/test_basic_merges.py
index 8aac2322..15c03621 100644
--- a/tests/test_basic_merges.py
+++ b/tests/test_basic_merges.py
@@ -119,13 +119,6 @@ def test_slerp_merge(self, model_a, model_b):
         config.parameters = {"t": 0.35}
         run_and_check_merge(config)

-    def test_nearswap_merge(self, model_a, model_b):
-        config = self.two_model_config(
-            model_a, model_b, merge_method="nearswap", base_model=model_a
-        )
-        config.parameters = {"t": 0.0001}
-        run_and_check_merge(config)
-
     def test_nuslerp_merges(self, model_a, model_b, model_c):
         for base_model in [None, model_c]:
             for row_wise in [False, True]:
diff --git a/tests/test_chat_template.py b/tests/test_chat_template.py
index af511a2b..2bd41cde 100644
--- a/tests/test_chat_template.py
+++ b/tests/test_chat_template.py
@@ -1,13 +1,25 @@
 from typing import Optional

-from common import run_and_check_merge
-from test_basic_merges import model_b
-from test_tokenizer import model_base
+import pytest
+from common import make_picollama, run_and_check_merge
+from test_tokenizer import make_tokenizer
 from transformers import AutoTokenizer

 from mergekit.config import InputModelDefinition, MergeConfiguration


+@pytest.fixture(scope="session")
+def model_base(tmp_path_factory):
+    model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
+    make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
+    return model_path
+
+
+@pytest.fixture(scope="session")
+def model_b(tmp_path_factory):
+    return make_picollama(tmp_path_factory.mktemp("model_b"))
+
+
 def check_chat_template(model_path: str, needle: Optional[str] = None):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     if needle is None: