diff --git a/mergekit/_data/architectures/gemma3.json b/mergekit/_data/architectures/gemma3.json
new file mode 100644
index 00000000..787fd041
--- /dev/null
+++ b/mergekit/_data/architectures/gemma3.json
@@ -0,0 +1,69 @@
+{
+    "model_type": "gemma3_text",
+    "architectures": [
+        "Gemma3ForCausalLM"
+    ],
+    "pre_weights": [
+        {
+            "name": "model.embed_tokens.weight",
+            "is_embed": true
+        }
+    ],
+    "num_layers_config_key": "num_hidden_layers",
+    "layer_templates": {
+        "weights": [
+            {
+                "name": "model.layers.${layer_index}.input_layernorm.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.self_attn.q_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.self_attn.q_norm.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.self_attn.k_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.self_attn.k_norm.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.self_attn.v_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.self_attn.o_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.post_attention_layernorm.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.mlp.up_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.mlp.gate_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.mlp.down_proj.weight"
+            },
+            {
+                "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight"
+            }
+        ]
+    },
+    "post_weights": [
+        {
+            "name": "model.norm.weight"
+        },
+        {
+            "name": "lm_head.weight",
+            "is_embed": true,
+            "optional": true,
+            "tied_names": [
+                "model.embed_tokens.weight"
+            ]
+        }
+    ]
+}
diff --git a/mergekit/_data/architectures/gemma3vl.json b/mergekit/_data/architectures/gemma3vl.json
new file mode 100644
index 00000000..cb45f7e2
--- /dev/null
+++ b/mergekit/_data/architectures/gemma3vl.json
@@ -0,0 +1,184 @@
+{
+    "kind": "modular",
+    "architectures": [
+        "Gemma3ForConditionalGeneration"
+    ],
+    "model_type": "gemma3",
+    "tagalong_files": [
+        "preprocessor_config.json",
+        "processor_config.json"
+    ],
+    "modules": {
+        "text_decoder": {
+            "weight_prefix": "language_model.",
+            "architecture": {
+                "model_type": "gemma3_text",
+                "architectures": [
+                    "Gemma3ForCausalLM"
+                ],
+                "pre_weights": [
+                    {
+                        "name": "model.embed_tokens.weight",
+                        "is_embed": true
+                    }
+                ],
+                "num_layers_config_key": "text_config.num_hidden_layers",
+                "layer_templates": {
+                    "weights": [
+                        {
+                            "name": "model.layers.${layer_index}.input_layernorm.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.self_attn.q_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.self_attn.q_norm.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.self_attn.k_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.self_attn.k_norm.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.self_attn.v_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.self_attn.o_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.post_attention_layernorm.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.pre_feedforward_layernorm.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.mlp.up_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.mlp.gate_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.mlp.down_proj.weight"
+                        },
+                        {
+                            "name": "model.layers.${layer_index}.post_feedforward_layernorm.weight"
+                        }
+                    ]
+                },
+                "post_weights": [
+                    {
+                        "name": "model.norm.weight"
+                    },
+                    {
+                        "name": "lm_head.weight",
+                        "is_embed": true,
+                        "optional": true,
+                        "tied_names": [
+                            "model.embed_tokens.weight"
+                        ]
+                    }
+                ]
+            }
+        },
+        "multi_modal_projector": {
+            "weight_prefix": "multi_modal_projector.",
+            "architecture": {
+                "model_type": "gemma3_mmproj",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "mm_input_projection_weight"
+                    },
+                    {
+                        "name": "mm_soft_emb_norm.weight"
+                    }
+                ],
+                "post_weights": [],
+                "layer_templates": {
+                    "weights": []
+                },
+                "override_num_layers": 0
+            }
+        },
+        "vision_tower": {
+            "weight_prefix": "vision_tower.vision_model.",
+            "architecture": {
+                "model_type": "siglip_vision_model",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "embeddings.patch_embedding.bias"
+                    },
+                    {
+                        "name": "embeddings.patch_embedding.weight"
+                    },
+                    {
+                        "name": "embeddings.position_embedding.weight"
+                    }
+                ],
+                "post_weights": [
+                    {
+                        "name": "post_layernorm.bias"
+                    },
+                    {
+                        "name": "post_layernorm.weight"
+                    }
+                ],
+                "layer_templates": {
+                    "weights": [
+                        {
+                            "name": "encoder.layers.${layer_index}.layer_norm1.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.layer_norm1.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.layer_norm2.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.layer_norm2.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.mlp.fc1.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.mlp.fc1.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.mlp.fc2.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.mlp.fc2.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.k_proj.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.k_proj.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.out_proj.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.out_proj.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.q_proj.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.q_proj.weight"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.v_proj.bias"
+                        },
+                        {
+                            "name": "encoder.layers.${layer_index}.self_attn.v_proj.weight"
+                        }
+                    ]
+                },
+                "num_layers_config_key": "vision_config.num_hidden_layers"
+            }
+        }
+    }
+}
diff --git a/mergekit/_data/architectures/t5.json b/mergekit/_data/architectures/t5.json
new file mode 100644
index 00000000..9ed8a7b8
--- /dev/null
+++ b/mergekit/_data/architectures/t5.json
@@ -0,0 +1,170 @@
+{
+    "kind": "modular",
+    "architectures": [
+        "T5ForConditionalGeneration"
+    ],
+    "model_type": "t5",
+    "modules": {
+        "decoder": {
+            "architecture": {
+                "model_type": "",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "decoder.embed_tokens.weight",
+                        "is_embed": true,
+                        "optional": true,
+                        "tied_names": [
+                            "shared.weight",
+                            "lm_head.weight",
+                            "encoder.embed_tokens.weight"
+                        ]
+                    }
+                ],
+                "num_layers_config_key": "num_decoder_layers",
+                "layer_templates": {
+                    "weights": [
+                        {
+                            "name": "decoder.block.${layer_index}.layer.0.layer_norm.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.0.SelfAttention.q.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.0.SelfAttention.k.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.0.SelfAttention.v.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.0.SelfAttention.o.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.0.SelfAttention.relative_attention_bias.weight",
+                            "optional": true
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.q.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.k.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.v.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.1.EncDecAttention.o.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.1.layer_norm.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.2.DenseReluDense.wi_0.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.2.DenseReluDense.wi_1.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.2.DenseReluDense.wo.weight"
+                        },
+                        {
+                            "name": "decoder.block.${layer_index}.layer.2.layer_norm.weight"
+                        }
+                    ]
+                },
+                "post_weights": [
+                    {
+                        "name": "decoder.final_layer_norm.weight"
+                    }
+                ]
+            }
+        },
+        "encoder": {
+            "architecture": {
+                "model_type": "",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "encoder.embed_tokens.weight",
+                        "is_embed": true,
+                        "optional": true,
+                        "tied_names": [
+                            "shared.weight",
+                            "lm_head.weight",
+                            "decoder.embed_tokens.weight"
+                        ]
+                    }
+                ],
+                "num_layers_config_key": "num_hidden_layers",
+                "layer_templates": {
+                    "weights": [
+                        {
+                            "name": "encoder.block.${layer_index}.layer.0.layer_norm.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.0.SelfAttention.q.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.0.SelfAttention.k.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.0.SelfAttention.v.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.0.SelfAttention.o.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.0.SelfAttention.relative_attention_bias.weight",
+                            "optional": true
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.1.DenseReluDense.wi_0.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.1.DenseReluDense.wi_1.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.1.DenseReluDense.wo.weight"
+                        },
+                        {
+                            "name": "encoder.block.${layer_index}.layer.1.layer_norm.weight"
+                        }
+                    ]
+                },
+                "post_weights": [
+                    {
+                        "name": "encoder.final_layer_norm.weight"
+                    }
+                ]
+            }
+        },
+        "shared": {
+            "architecture": {
+                "model_type": "",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "shared.weight",
+                        "is_embed": true
+                    }
+                ],
+                "layer_templates": {
+                    "weights": []
+                },
+                "post_weights": [
+                    {
+                        "name": "lm_head.weight",
+                        "is_embed": true,
+                        "optional": true,
+                        "tied_names": [
+                            "shared.weight",
+                            "encoder.embed_tokens.weight",
+                            "decoder.embed_tokens.weight"
+                        ]
+                    }
+                ],
+                "override_num_layers": 0
+            }
+        }
+    }
+}
diff --git a/mergekit/_data/architectures/whisper.json b/mergekit/_data/architectures/whisper.json
new file mode 100644
index 00000000..82ac444b
--- /dev/null
+++ b/mergekit/_data/architectures/whisper.json
@@ -0,0 +1,196 @@
+{
+    "kind": "modular",
+    "architectures": [
+        "WhisperForConditionalGeneration"
+    ],
+    "model_type": "whisper",
+    "tagalong_files": [
+        "preprocessor_config.json",
+        "normalizer.json"
+    ],
+    "modules": {
+        "decoder": {
+            "weight_prefix": "model.decoder",
+            "architecture": {
+                "model_type": "",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "embed_tokens.weight",
+                        "is_embed": true
+                    },
+                    {
+                        "name": "embed_positions.weight"
+                    }
+                ],
+                "num_layers_config_key": "decoder_layers",
+                "layer_templates": {
+                    "weights": [
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.k_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.out_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.out_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.q_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.q_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.v_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn.v_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn_layer_norm.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.encoder_attn_layer_norm.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc1.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc1.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc2.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc2.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.final_layer_norm.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.final_layer_norm.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.k_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.out_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.out_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.q_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.q_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.v_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.v_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn_layer_norm.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn_layer_norm.weight"
+                        }
+                    ]
+                },
+                "post_weights": [
+                    {
+                        "name": "layer_norm.bias"
+                    },
+                    {
+                        "name": "layer_norm.weight"
+                    }
+                ]
+            }
+        },
+        "encoder": {
+            "weight_prefix": "model.encoder.",
+            "architecture": {
+                "model_type": "",
+                "architectures": [],
+                "pre_weights": [
+                    {
+                        "name": "embed_positions.weight"
+                    },
+                    {
+                        "name": "conv1.bias"
+                    },
+                    {
+                        "name": "conv1.weight"
+                    },
+                    {
+                        "name": "conv2.bias"
+                    },
+                    {
+                        "name": "conv2.weight"
+                    }
+                ],
+                "post_weights": [
+                    {
+                        "name": "layer_norm.bias"
+                    },
+                    {
+                        "name": "layer_norm.weight"
+                    }
+                ],
+                "layer_templates": {
+                    "weights": [
+                        {
+                            "name": "layers.${layer_index}.fc1.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc1.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc2.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.fc2.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.final_layer_norm.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.final_layer_norm.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.k_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.out_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.out_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.q_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.q_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.v_proj.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn.v_proj.weight"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn_layer_norm.bias"
+                        },
+                        {
+                            "name": "layers.${layer_index}.self_attn_layer_norm.weight"
+                        }
+                    ]
+                },
+                "num_layers_config_key": "encoder_layers"
+            }
+        }
+    }
+}
diff --git a/mergekit/architecture.py b/mergekit/architecture.py
deleted file mode 100644
index 49840b73..00000000
--- a/mergekit/architecture.py
+++ /dev/null
@@ -1,779 +0,0 @@
-# Copyright (C) 2025 Arcee AI
-# SPDX-License-Identifier: BUSL-1.1
-
-import importlib.resources
-import logging
-import re
-import string
-import warnings
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from pathlib import Path
-from typing import ClassVar, Dict, List, Optional, Tuple, Union
-
-from huggingface_hub import snapshot_download
-from pydantic import BaseModel, Field
-from transformers import PretrainedConfig
-from typing_extensions import Literal
-
-import mergekit._data.architectures
-from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
-
-
-class WeightInfo(BaseModel, frozen=True):
-    """Information about an individual weight tensor in a model.
-
-    Attributes:
-        name (str):
-            The name of the tensor representing the weight.
-        is_embed (bool):
-            Indicates whether the weight is for an embedding or language model head.
-        input_space (Optional[str]):
-            The name of the input space associated with the weight, if applicable.
-        output_space (Optional[str]):
-            The name of the output space associated with the weight, if applicable.
-        optional (bool):
-            Indicates whether the weight can be omitted from a model.
-        aliases (Optional[List[str]]):
-            List of alternative names for the weight, if applicable.
-        tied_names (Optional[List[str]]):
-            List of names for weights that are tied to this weight, if applicable.
-        force_dtype (Optional[str]):
-            Mandatory dtype for the weight, if applicable.
-    """
-
-    name: str
-    is_embed: bool = False
-    input_space: Optional[str] = None
-    output_space: Optional[str] = None
-    optional: bool = False
-    tied: bool = False
-    aliases: Optional[Tuple[str, ...]] = None
-    tied_names: Optional[Tuple[str, ...]] = None
-    force_dtype: Optional[str] = None
-    head_split: Literal[None, "input", "output"] = None
-    is_kq: Optional[bool] = False
-
-
-class ProceduralSpaceInfo(BaseModel, frozen=True):
-    """Defines a procedural space computed from one or more other spaces.
-
-    Currently only supports residual connections.
-
-    Attributes:
-        name (str): The name of the space defined.
-        type (str): The type of procedural space.
-        inputs (List[str]): List of names of spaces used to define this space."""
-
-    name: str
-    type: Literal["residual"]
-    inputs: List[str]
-
-
-class ArchitectureInfo(ABC):
-    @abstractmethod
-    def name(self) -> str:
-        """Return the name of the architecture."""
-        ...
-
-    @abstractmethod
-    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """Return a list of all weights preceding the first layer."""
-        ...
-
-    @abstractmethod
-    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """Return a list of all weights following the final layer."""
-        ...
-
-    @abstractmethod
-    def layer_weights(
-        self, index: int, config: PretrainedConfig
-    ) -> Optional[List[WeightInfo]]:
-        """Return a list of all weights associated with a given layer."""
-        ...
-
-    @abstractmethod
-    def sliceable(self) -> bool:
-        """
-        Return True if the layers of this architecture can be meaningfully sliced.
-        """
-        ...
-
-    def num_layers_config_key(self) -> str:
-        """Key in config that represents number of layers"""
-        return "num_hidden_layers"
-
-    def num_layers(self, config: PretrainedConfig) -> int:
-        """Return the number of layers in a model."""
-        return getattr(config, self.num_layers_config_key())
-
-    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """Return all weights associated with a model."""
-        num_layers = self.num_layers(config)
-        res = list(self.pre_weights(config))
-        for layer_idx in range(num_layers):
-            res.extend(self.layer_weights(layer_idx, config))
-        res.extend(self.post_weights(config))
-        return res
-
-    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
-        """Return a list of all procedurally defined spaces in a model."""
-        return []
-
-    def has_defined_spaces(self) -> bool:
-        """
-        Return True if this architecture defines space information needed for
-        matching-based merge methods.
-        """
-        return False
-
-
-class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True):
-    info: ArchitectureInfo
-    config: PretrainedConfig
-
-    def name(self) -> str:
-        return self.info.name()
-
-    def num_layers(self) -> int:
-        return self.info.num_layers(self.config)
-
-    def pre_weights(self) -> List[WeightInfo]:
-        return self.info.pre_weights(self.config)
-
-    def post_weights(self) -> List[WeightInfo]:
-        return self.info.post_weights(self.config)
-
-    def layer_weights(self, index: int) -> List[WeightInfo]:
-        return self.info.layer_weights(index, self.config)
-
-    def procedural_spaces(self) -> List[ProceduralSpaceInfo]:
-        return self.info.procedural_spaces(self.config)
-
-    def all_weights(self) -> List[WeightInfo]:
-        return self.info.all_weights(self.config)
-
-
-class JSONLayerTemplates(BaseModel, frozen=True):
-    weights: List[WeightInfo]
-    procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None
-
-
-class JSONArchitectureDefinition(BaseModel, frozen=True):
-    expected_model_type: str = Field(alias="model_type")
-    architectures: List[str]
-    pre_weights: List[WeightInfo]
-    layer_templates: JSONLayerTemplates
-    post_weights: List[WeightInfo]
-    procedural_spaces: Optional[List[ProceduralSpaceInfo]] = None
-    num_layers_config_key: Optional[str] = None
-
-
-class TemplateWithArithmetic(string.Template):
-    idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)"
-
-
-def _template_substitution(
-    template: str, num_layers: int, layer_idx: Optional[int] = None
-) -> str:
-    if "{" not in template:
-        return template
-
-    substitutions = {
-        "num_layers": num_layers,
-        "num_layers+1": num_layers + 1,
-        "num_layers-1": num_layers - 1,
-    }
-
-    if layer_idx is not None:
-        substitutions.update(
-            {
-                "layer_index": layer_idx,
-                "layer_index+1": layer_idx + 1,
-                "layer_index-1": layer_idx - 1,
-            }
-        )
-
-    return TemplateWithArithmetic(template).substitute(substitutions)
-
-
-def _hierarchy(names, layer_prefix=r"\.\d+\.") -> Dict[str, List[str]]:
-    hierarchy = defaultdict(list)
-
-    # Regular expression to match layers (denoted by .{integer}. by default)
-    layer_pattern = re.compile(layer_prefix)
-
-    if names:
-        for name in names:
-            # Find the layer part of the string (e.g., 'model.layers.0.')
-            match = layer_pattern.search(name)
-            if match:
-                # Extract everything up to the layer identifier
-                layer_prefix = name[: match.end() - 1]  # e.g., 'model.layers.0'
-                # Extract the parameter name after the layer identifier
-                param_name = name[match.end() :]  # e.g., 'input_layernorm.weight'
-                # Add the parameter name to the corresponding layer in the hierarchy
-                hierarchy[layer_prefix].append(param_name)
-            else:
-                hierarchy[name].append("")
-
-    return hierarchy
-
-
-class AutomaticArchitectureInfo(ArchitectureInfo, BaseModel):
-    arch_name: str = Field(default="")
-    parameter_names: List[str] = Field(default_factory=list)
-    embed: List[str] = Field(default_factory=list)
-    layered_parameter_names: Dict[str, List[str]] = Field(default_factory=dict)
-    prefix_tracker: Dict[str, str] = Field(default_factory=dict)
-    post_fill_parameters: bool = False
-
-    def __init__(
-        self,
-        arch_name: str,
-        parameter_names: List[str],
-        prefix_tracker: Optional[Dict[str, str]] = None,
-        post_fill_parameters: bool = False,
-    ):
-        super().__init__()
-        self.arch_name = arch_name
-        self.parameter_names = parameter_names
-        self.layered_parameter_names = _hierarchy(self.parameter_names)
-        self.prefix_tracker = prefix_tracker or {}
-        self.embed = self._find_embed_params()
-        self.post_fill_parameters = post_fill_parameters
-
-    def _find_embed_params(self) -> List[str]:
-        """Identify embedding parameters (e.g., 'lm_head', 'embed') that may require special handling."""
-        embed_params = []
-        for name in self.parameter_names:
-            if any(embedding_name in name for embedding_name in ["lm_head", "embed"]):
-                embed_params.append(name)
-        return embed_params
-
-    def name(self) -> str:
-        """Returns the architecture name."""
-        return self.arch_name
-
-    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """This architecture does not distinguish pre-weights."""
-        return []
-
-    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """This architecture does not distinguish post-weights."""
-        return []
-
-    def layer_weights(
-        self, index: int, config: PretrainedConfig
-    ) -> Optional[List[WeightInfo]]:
-        """
-        Retrieves the weights for a specified layer, adjusting names for prefixes if applicable.
-        """
-        layer_name = list(self.layered_parameter_names.keys())[index]
-        adjusted_layer_name = self._adjust_layer_name(layer_name, config)
-
-        weights = [
-            WeightInfo(
-                name=f"{adjusted_layer_name}.{param}" if param else adjusted_layer_name,
-                is_embed=(layer_name in self.embed),
-            )
-            for param in self.layered_parameter_names[layer_name]
-        ]
-        return (
-            weights
-            if weights
-            else [
-                WeightInfo(
-                    name=adjusted_layer_name, is_embed=(layer_name in self.embed)
-                )
-            ]
-        )
-
-    def _adjust_layer_name(self, layer_name: str, config: PretrainedConfig) -> str:
-        """Adjust layer names by removing any prefix as indicated in the prefix tracker."""
-        if config and config.name_or_path in self.prefix_tracker:
-            prefix = self.prefix_tracker.get(config.name_or_path, "")
-            if layer_name.startswith(prefix):
-                return layer_name[len(prefix) :]
-        return layer_name
-
-    def sliceable(self) -> bool:
-        """Indicates if the architecture supports slicing."""
-        return True
-
-    def num_layers(self, config: PretrainedConfig) -> int:
-        """Returns the number of layers based on layered parameter names."""
-        return len(self.layered_parameter_names)
-
-
-class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True):
-    definition: JSONArchitectureDefinition
-
-    def _substitute(
-        self,
-        item: Union[WeightInfo, ProceduralSpaceInfo],
-        config: PretrainedConfig,
-        layer_idx: Optional[int] = None,
-    ) -> Union[WeightInfo, ProceduralSpaceInfo]:
-        num_layers = self.num_layers(config)
-
-        obj_dict = item.model_dump(mode="json", exclude_unset=True)
-        for key in obj_dict:
-            if isinstance(obj_dict[key], str):
-                obj_dict[key] = _template_substitution(
-                    obj_dict[key], num_layers, layer_idx
-                )
-            elif isinstance(obj_dict[key], list):
-                obj_dict[key] = [
-                    (
-                        _template_substitution(s, num_layers, layer_idx)
-                        if isinstance(s, str)
-                        else s
-                    )
-                    for s in obj_dict[key]
-                ]
-        return type(item).model_validate(obj_dict)
-
-    def name(self) -> str:
-        return self.definition.expected_model_type
-
-    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        return [
-            self._substitute(wi, config=config) for wi in self.definition.pre_weights
-        ]
-
-    def layer_weights(
-        self, index: int, config: PretrainedConfig
-    ) -> Optional[List[WeightInfo]]:
-        return [
-            self._substitute(wi, config=config, layer_idx=index)
-            for wi in self.definition.layer_templates.weights
-        ]
-
-    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        return [
-            self._substitute(wi, config=config) for wi in self.definition.post_weights
-        ]
-
-    def sliceable(self) -> bool:
-        return True
-
-    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
-        res = []
-        for s in self.definition.procedural_spaces or []:
-            res.append(self._substitute(s, config=config))
-        for idx in range(self.num_layers(config)):
-            for s in self.definition.layer_templates.procedural_spaces or []:
-                res.append(self._substitute(s, config=config, layer_idx=idx))
-        return res
-
-    def has_defined_spaces(self) -> bool:
-        if (
-            self.definition.procedural_spaces
-            or self.definition.layer_templates.procedural_spaces
-        ):
-            return True
-        for wi in (
-            self.definition.layer_templates.weights
-            + self.definition.pre_weights
-            + self.definition.post_weights
-        ):
-            if wi.input_space or wi.output_space:
-                return True
-        return False
-
-    def num_layers_config_key(self) -> str:
-        return self.definition.num_layers_config_key
-
-
-class MixtralTensorNames(ArchitectureInfo, BaseModel):
-    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
-    num_local_experts: int
-
-    def name(self) -> str:
-        return "mixtral"
-
-    @classmethod
-    def from_config(cls, config: PretrainedConfig):
-        return MixtralTensorNames(num_local_experts=config.num_local_experts)
-
-    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        return MISTRAL_INFO.pre_weights(config)
-
-    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        return MISTRAL_INFO.post_weights(config)
-
-    def num_layers_config_key(self) -> str:
-        return MISTRAL_INFO.num_layers_config_key()
-
-    def layer_weights(
-        self, index: int, config: PretrainedConfig
-    ) -> Optional[List[WeightInfo]]:
-        num_experts = self.num_local_experts
-        prefix = f"model.layers.{index}"
-        tensor_names = []
-        for expert_idx in range(num_experts):
-            for param in ("w1", "w2", "w3"):
-                tensor_names.append(
-                    prefix + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
-                )
-        tensor_names.append(prefix + ".block_sparse_moe.gate.weight")
-        res = []
-        for name in tensor_names:
-            res.append(WeightInfo(name=name))
-        for weight_info in MISTRAL_INFO.layer_weights(index, config):
-            if ".mlp." in weight_info.name:
-                continue
-            res.append(weight_info)
-        return res
-
-    def sliceable(self) -> bool:
-        return True
-
-    def has_defined_spaces(self) -> bool:
-        return False
-
-
-def _load_json_arch(name: str) -> JsonArchitectureInfo:
-    text = importlib.resources.read_text(mergekit._data.architectures, name)
-    return JsonArchitectureInfo(
-        definition=JSONArchitectureDefinition.model_validate_json(text)
-    )
-
-
-def _load_all_architectures() -> (
-    Tuple[List[JsonArchitectureInfo], Dict[str, List[JsonArchitectureInfo]]]
-):
-    architectures: List[JsonArchitectureInfo] = []
-    for f in importlib.resources.contents(mergekit._data.architectures):
-        if f.lower().endswith(".json"):
-            architectures.append(_load_json_arch(f))
-
-    name_to_arch: Dict[str, List[JsonArchitectureInfo]] = {}
-    for arch_info in architectures:
-        for name in arch_info.definition.architectures:
-            name_to_arch[name] = name_to_arch.get(name, [])
-            name_to_arch[name].append(arch_info)
-    return architectures, name_to_arch
-
-
-JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures()
-MISTRAL_INFO = _load_json_arch("mistral.json")
-QWEN2_INFO = _load_json_arch("qwen2.json")
-
-
-class ArchitectureInfoUtils:
-    """Functions for inferring architecture information from a merge configuration."""
-
-    @staticmethod
-    def get_architecture_info(config: PretrainedConfig) -> Optional[ArchitectureInfo]:
-        """Get architecture info from an existing model config."""
-        if len(config.architectures) != 1:
-            raise RuntimeError("More than one architecture in config?")
-
-        arch_name = config.architectures[0]
-
-        if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
-            return MixtralTensorNames.from_config(config)
-
-        if arch_name in NAME_TO_ARCH:
-            candidates = list(NAME_TO_ARCH[arch_name])
-            if len(candidates) == 1:
-                return candidates[0]
-
-            for c in candidates:
-                if c.definition.expected_model_type == config.model_type:
-                    return c
-
-        warnings.warn(f"No architecture config available for: {arch_name}.")
-        return None
-
-    @staticmethod
-    def infer_architecture_info(merge_config) -> AutomaticArchitectureInfo:
-        """
-        Infer architecture info and prefixes for alignment.
-        Prefixes typically denote where a model is used as a subcomponent of another model.
-        e.g., ['layer.0', 'layer.1', ...] and ['vision_tower.layer.0', 'vision_tower.layer.1', ...];
-            inferring prefix = 'vision_tower' is required to align the two models.
-
-        Usage:
-            Similar to `get_architecture_info`, but requires a merge configuration object rather than a model config.
-            This is so the common parameter names between all models can be inferred.
-        """
-        param_names = [
-            ParameterNamesUtils.get_model_parameter_names(source_model.model.path)
-            for source_model in merge_config.referenced_models()
-        ]
-        base_model = merge_config.base_model
-
-        paired_list = list(zip(param_names, merge_config.referenced_models()))
-        paired_list.sort(key=lambda x: len(x[0]), reverse=True)
-        for i, (_, model_name) in enumerate(paired_list):
-            if model_name == base_model:
-                paired_list.insert(0, paired_list.pop(i))
-                break
-        param_names, referenced_models = zip(*paired_list)
-        logging.info(f"Base model selected: {referenced_models[0].model.path}")
-
-        prefixes = [""]
-        for i in range(1, len(param_names)):
-            assert len(param_names[0]) >= len(
-                param_names[i]
-            ), f"base model names list can't be shorter than model {i} names list"
-            prefixes.append(
-                ParameterNamesUtils.find_prefix(param_names[0], param_names[i])
-            )
-
-        common_names = ParameterNamesUtils.find_common_ordered_names(
-            param_names, prefixes
-        )
-
-        common_names = ParameterNamesUtils.remove_size_conflicts(
-            common_names, referenced_models, prefixes
-        )
-
-        ArchitectureInfoUtils.log_info(common_names, param_names, referenced_models)
-
-        if not common_names or any([p is None for p in prefixes]):
-            raise ValueError("Could not resolve model architecture automatically.")
-
-        prefix_tracker = {
-            model.model.path: f"{prefix}." if prefix else ""
-            for model, prefix in zip(referenced_models, prefixes)
-        }
-
-        arch_name = referenced_models[0].model.path
-        parameter_names = common_names
-
-        return AutomaticArchitectureInfo(
-            arch_name=arch_name,
-            parameter_names=parameter_names,
-            prefix_tracker=prefix_tracker,
-            post_fill_parameters=(
-                referenced_models[0].model.path  # base model name
-                if len(common_names) != len(param_names[0])
-                else None  # no post-fill needed
-            ),
-        )
-
-    @staticmethod
-    def log_info(common_names, param_names, referenced_models):
-        for i in range(1, len(param_names)):
-            prefix, case_message = ParameterNamesUtils.report_names_similarity(
-                param_names[0], param_names[i]
-            )
-            logging.info(
-                f"Model {referenced_models[i].model.path}: \
-                    \n  {f'Best prefix found: {prefix}' if prefix else 'No prefix found'}\
-                    \n  {case_message.replace('MODEL_ID', referenced_models[i].model.path)}"
-            )
-
-        if len(common_names) != len(param_names[0]):
-            warnings.warn(
-                f"Merging {len(common_names)}/{len(param_names[0])} base model parameters. \
-                \n Base model selected: {referenced_models[0].model.path} \
-                \n copy_and_fill_missing_params will run when merge is complete, to fill in missing params from base model."
-            )
-
-        if len(common_names) < 0.3 * len(param_names[0]):
-            warnings.warn(
-                "Not many common parameters found. Are you sure you are merging the correct models?"
-            )
-
-
-class ParameterNamesUtils:
-    """Utility functions for handling parameter names."""
-
-    @staticmethod
-    def resolve_model_directory(repo_id: str) -> Path:
-        """Resolve the model directory (local or Hugging Face Hub)."""
-        if Path(repo_id).is_dir():
-            return Path(repo_id)
-
-        return Path(snapshot_download(repo_id))
-
-    @staticmethod
-    def get_model_parameter_names(repo_id: str) -> List[str]:
-        """Get parameter names of a model from a Hugging Face repo or local directory."""
-        model_dir = ParameterNamesUtils.resolve_model_directory(repo_id)
-        return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())
-
-    @staticmethod
-    def strip_prefix(name: str, prefix: str) -> str:
-        """Remove a single prefix from the start of a name."""
-        if prefix != "" and name.startswith(prefix + "."):
-            return name[len(prefix) + 1 :]
-        return name
-
-    @staticmethod
-    def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]:
-        """
-        Find a prefix in list1 that, after removal, makes list2 an ordered sublist.
-        """
-        assert len(list1) >= len(list2), "params name list1 can't be shorter than list2"
-
-        possible_prefixes = {item.split(".")[0] for item in list1 if "." in item}
-        possible_prefixes = [""] + list(possible_prefixes)
-
-        prefix_matches = {}
-        best_prefix = ""  # Default to no prefix
-        for prefix in possible_prefixes:
-            stripped_list1 = [
-                ParameterNamesUtils.strip_prefix(item, prefix) for item in list1
-            ]
-            prefix_matches[prefix] = len(
-                [item for item in list2 if item in stripped_list1]
-            )
-
-        if max(prefix_matches.values()) > prefix_matches[""]:
-            best_prefix = max(prefix_matches, key=prefix_matches.get)
-
-        return best_prefix
-
-    @staticmethod
-    def find_common_ordered_names(
-        param_names: List[List[str]], prefixes: List[str]
-    ) -> List[str]:
-        """Identify and return common parameter names across all models, ensuring correct order. Also account for prefix."""
-        common_names = set(param_names[0])
-        for i in range(1, len(param_names)):
-            prefix = f"{prefixes[i]}." if prefixes[i] else ""
-            common_names.intersection_update({prefix + name for name in param_names[i]})
-        return [name for name in param_names[0] if name in common_names]
-
-    @staticmethod
-    def remove_size_conflicts(common_names, referenced_models, prefixes):
-        model_dirs = [
-            ParameterNamesUtils.resolve_model_directory(m.model.path)
-            for m in referenced_models
-        ]
-        model_indices = [ShardedTensorIndex.from_disk(str(dir)) for dir in model_dirs]
-
-        common_name_and_shape = common_names.copy()
-        removed_names = []
-
-        for name in common_names:
-            base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0])
-
-            for i in range(1, len(referenced_models)):
-                other_name = name
-                prefix = f"{prefixes[i]}." if prefixes[i] else ""
-                if name.startswith(prefix) and prefix != "":
-                    other_name = name[len(prefix) :]
-                shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i])
-
-                if base_shape != shape:
-                    common_name_and_shape.remove(name)
-                    removed_names.append((name, base_shape, shape, i))
-                    break
-
-        size_mismatch_count = len(removed_names)
-        if size_mismatch_count > 0:
-            logging.warning(
-                f"Size mismatch detected for {size_mismatch_count}/{size_mismatch_count + len(common_names)} tensors. "
-                "These names were removed from the merge list."
-            )
-            logging.info(
-                "The following tensors have different shapes across models and were removed from the merge list:"
-            )
-            for name, base_shape, shape, i in removed_names:
-                logging.info(
-                    f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}"
-                )
-
-        return common_name_and_shape
-
-    @staticmethod
-    def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool:
-        """
-        Check if common elements of list2 maintain their relative order in list1.
-        """
-        common_params = set(list1).intersection(set(list2))
-        last_index = -1
-
-        for param in list2:
-            if param in common_params:
-                current_index = list1.index(param)
-                if current_index < last_index:
-                    return False
-                last_index = current_index
-        return True
-
-    @staticmethod
-    def ordered_sublist(list1: List[str], list2: List[str]) -> bool:
-        """
-        Check if list2 is a contiguous ordered sublist of list1.
-        """
-        n, m = len(list1), len(list2)
-
-        for i in range(n - m + 1):
-            if list1[i : i + m] == list2:
-                return True
-        return False
-
-    @staticmethod
-    def report_names_similarity(
-        base_names: List[str], other_names: List[str]
-    ) -> Tuple[Optional[str], str]:
-        """
-        Analyze similarity between parameter names of two models and identify shared prefixes.
-
-        Returns:
-            best_prefix (str): Best matching prefix for parameter names.
-            case_message (str): Explanation of the structural relationship.
-        """
-        possible_prefixes = {""}
-        possible_prefixes.update(
-            {item.split(".")[0] for item in base_names if "." in item}
-        )
-
-        prefixes_subset_overlap = {}
-        best_prefix = None
-        case_message = "No common parameter names found for any prefix"
-
-        for prefix in possible_prefixes:
-            base_names_stripped = [
-                ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names
-            ]
-
-            if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names):
-                return prefix, "All params in model have exact match in base model."
-
-            intersection = set(base_names_stripped).intersection(set(other_names))
-            prefixes_subset_overlap[prefix] = intersection
-
-        if prefixes_subset_overlap:
-            best_prefix = max(
-                prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x])
-            )
-            base_names_stripped = [
-                ParameterNamesUtils.strip_prefix(name, best_prefix)
-                for name in base_names
-            ]
-
-            overlap = len(prefixes_subset_overlap[best_prefix])
-            ordered = ParameterNamesUtils.are_common_params_ordered(
-                base_names_stripped, other_names
-            )
-            mismatched = [
-                item for item in other_names if item not in base_names_stripped
-            ]
-            mismatched = "\n    ".join(mismatched)
-            case_message = (
-                f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) "
-                f"of model parameters are in the base model. \n"
-                f"  Name ordering is {'preserved' if ordered else 'not preserved'}.\n"
-                f"  Missing parameters:\n    {mismatched}"
-            )
-
-        return best_prefix, case_message
-
-    @staticmethod
-    def tensor_shape(name, index) -> Tuple[int]:
-        from safetensors import safe_open
-
-        with safe_open(
-            Path(index.base_path) / index.tensor_paths[name], framework="pt"
-        ) as f:
-            return f.get_slice(name).get_shape()
diff --git a/mergekit/architecture/__init__.py b/mergekit/architecture/__init__.py
new file mode 100644
index 00000000..7f1310b8
--- /dev/null
+++ b/mergekit/architecture/__init__.py
@@ -0,0 +1,99 @@
+# Copyright (C) 2025 Arcee AI
+# SPDX-License-Identifier: BUSL-1.1
+
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from transformers import PretrainedConfig
+
+from mergekit.architecture.auto import infer_architecture_info
+from mergekit.architecture.base import (
+    ConfiguredModelArchitecture,
+    ConfiguredModuleArchitecture,
+    ModelArchitecture,
+    ModuleArchitecture,
+    ModuleDefinition,
+    WeightInfo,
+)
+from mergekit.architecture.json_definitions import NAME_TO_ARCH
+from mergekit.architecture.mixtral import MixtralTensorNames
+from mergekit.options import MergeOptions
+
+if TYPE_CHECKING:
+    from mergekit.config import MergeConfiguration
+
+logger = logging.getLogger(__name__)
+
+
+def arch_info_for_config(config: PretrainedConfig) -> Optional[ModelArchitecture]:
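+    """Return the predefined architecture description matching a transformers config, if any.
+
+    Checks the Mixtral special case first, then the bundled JSON definitions,
+    disambiguating by model_type when several definitions share an architecture name.
+    """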
+    if len(config.architectures) != 1:
+        raise RuntimeError("More than one architecture in config?")
+    arch_name = config.architectures[0]
+
+    if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
+        module = MixtralTensorNames.from_config(config)
+        return ModelArchitecture(
+            modules={"default": ModuleDefinition(architecture=module)},
+            architectures=[arch_name],
+        )
+    elif arch_name in NAME_TO_ARCH:
+        candidates = list(NAME_TO_ARCH[arch_name])
+        if len(candidates) == 1:
+            return candidates[0]
+
+        for c in candidates:
+            if c.expected_model_type == config.model_type:
+                return c
+        logger.warning(
+            f"Multiple architectures for {arch_name}, none match model type {config.model_type}"
+        )
+
+    logger.warning(f"No JSON architecture found for {arch_name}")
+    return None
+
+
+def get_architecture_info(
+    config: "MergeConfiguration", options: MergeOptions
+) -> ModelArchitecture:
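+    """Resolve the output architecture, inferring from tensor names if any model is unrecognized."""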
+    models = config.referenced_models()
+    if not models:
+        raise ValueError("No models referenced in config")
+
+    model_arch_info = [
+        arch_info_for_config(m.config(trust_remote_code=options.trust_remote_code))
+        for m in models
+    ]
+    if all(arch is not None for arch in model_arch_info):
+        if not options.allow_crimes and any(
+            arch != model_arch_info[0] for arch in model_arch_info
+        ):
+            raise RuntimeError(
+                "Must specify --allow-crimes to attempt to mix different architectures"
+            )
+        return model_arch_info[0]
+
+    # try to infer from all models
+    return infer_architecture_info(models, config.base_model, options)
+
+
+__all__ = [
+    "ModelArchitecture",
+    "ModuleArchitecture",
+    "ModuleDefinition",
+    "ConfiguredModuleArchitecture",
+    "ConfiguredModelArchitecture",
+    "WeightInfo",
+    "get_architecture_info",
+    "arch_info_for_config",
+]
diff --git a/mergekit/architecture/auto.py b/mergekit/architecture/auto.py
new file mode 100644
index 00000000..3eee5855
--- /dev/null
+++ b/mergekit/architecture/auto.py
@@ -0,0 +1,120 @@
+# Copyright (C) 2025 Arcee AI
+# SPDX-License-Identifier: BUSL-1.1
+
+import logging
+import re
+from collections import defaultdict
+from typing import List, Optional
+
+from mergekit.architecture.base import (
+    ModelArchitecture,
+    ModuleDefinition,
+    WeightInfo,
+)
+from mergekit.architecture.json_definitions import (
+    JsonLayerTemplates,
+    JsonModuleArchDef,
+    JsonModuleArchitecture,
+)
+from mergekit.common import ModelReference
+from mergekit.options import MergeOptions
+
+RE_LAYER_INDEX = re.compile(r"\.(\d+)\.")
+
+logger = logging.getLogger(__name__)
+
+
+def get_model_tensor_names(model: ModelReference, options: MergeOptions) -> List[str]:
+    loader = model.lazy_loader(
+        cache_dir=options.transformers_cache, lazy_unpickle=options.lazy_unpickle
+    )
+    return list(loader.index.tensor_paths.keys())
+
+
+def infer_architecture_info(
+    models: List[ModelReference],
+    base_model: Optional[ModelReference],
+    options: MergeOptions,
+) -> ModelArchitecture:
+    model_tensor_names = {
+        model: set(get_model_tensor_names(model, options))
+        for model in (set(models).union({base_model} if base_model else {}))
+    }
+    if base_model is None:
+        base_model = models.pop(0)
+    all_tensor_names = set().union(*model_tensor_names.values())
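+    # tensors present in every input model; names missing from some model are later marked optional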
+    in_all_models = all_tensor_names.intersection(*model_tensor_names.values())
+
+    module_prefixes = set()
+    module_layer_counts = defaultdict(int)
+    module_templates = defaultdict(set)
+    module_loose_weights = defaultdict(set)
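+    # group layer-indexed tensors into candidate modules keyed by the name prefix before the layer index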
+    for tensor_name in all_tensor_names:
+        if len(RE_LAYER_INDEX.findall(tensor_name)) > 1:
+            raise ValueError(
+                f"Tensor name {tensor_name} has more than one layer index - not supported"
+            )
+        elif match := RE_LAYER_INDEX.search(tensor_name):
+            prefix = tensor_name[: match.start()]
+            module_prefixes.add(prefix)
+            layer_idx = int(match.group(1))
+            module_layer_counts[prefix] = max(
+                module_layer_counts[prefix], layer_idx + 1
+            )
+            module_templates[prefix] = module_templates[prefix].union(
+                set([RE_LAYER_INDEX.sub(".${layer_index}.", tensor_name)])
+            )
+
+    # create a default module with no prefix
+    module_prefixes.add("")
+
+    for tensor_name in all_tensor_names:
+        if RE_LAYER_INDEX.search(tensor_name):
+            continue
+        for prefix in module_prefixes:
+            if tensor_name.startswith(prefix):
+                module_loose_weights[prefix].add(tensor_name[len(prefix) :])
+
+    if not (module_loose_weights[""] or module_templates[""]):
+        module_prefixes.remove("")
+    if not module_prefixes:
+        raise ValueError("No modules found in models")
+
+    logging.warning(f"Inferred {len(module_prefixes)} modules:")
+    for prefix in module_prefixes:
+        logging.warning(
+            f"  {repr(prefix or 'default')} with {module_layer_counts[prefix]} layers, {len(module_templates[prefix])} templates, and {len(module_loose_weights[prefix])} loose weights"
+        )
+
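+    # build a WeightInfo from a tensor name template, marking it optional if not present in every model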
+    def _wi(template: str) -> WeightInfo:
+        optional = template.replace("${layer_index}", "0") not in in_all_models
+        return WeightInfo(
+            name=template,
+            optional=optional,
+        )
+
+    module_archs = {}
+    for prefix in module_prefixes:
+        num_layers = module_layer_counts[prefix]
+        module_archs[prefix or "default"] = JsonModuleArchitecture(
+            definition=JsonModuleArchDef(
+                model_type="",
+                architectures=[],
+                pre_weights=[_wi(t) for t in module_loose_weights[prefix]],
+                layer_templates=JsonLayerTemplates(
+                    weights=[_wi(t) for t in module_templates[prefix]]
+                ),
+                post_weights=[],
+                num_layers_config_key=None,
+                override_num_layers=num_layers,
+            ),
+        )
+
+    return ModelArchitecture(
+        modules={
+            key: ModuleDefinition(architecture=value)
+            for key, value in module_archs.items()
+        },
+        architectures=[],
+        model_type="",
+    )
diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py
new file mode 100644
index 00000000..26c0d231
--- /dev/null
+++ b/mergekit/architecture/base.py
@@ -0,0 +1,152 @@
+# Copyright (C) 2025 Arcee AI
+# SPDX-License-Identifier: BUSL-1.1
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Tuple
+
+from pydantic import BaseModel, Field
+from transformers import PretrainedConfig
+
+from mergekit.common import get_config_value
+
+
+class WeightInfo(BaseModel, frozen=True):
+    """Information about an individual weight tensor in a model.
+
+    Attributes:
+        name (str):
+            The name of the tensor representing the weight.
+        is_embed (bool):
+            Indicates whether the weight is for an embedding or language model head.
+        optional (bool):
+            Indicates whether the weight can be omitted from a model.
+        aliases (Optional[List[str]]):
+            List of alternative names for the weight, if applicable.
+        force_dtype (Optional[str]):
+            Mandatory dtype for the weight, if applicable.
+    """
+
+    name: str
+    is_embed: bool = False
+    optional: bool = False
+    aliases: Optional[Tuple[str, ...]] = None
+    force_dtype: Optional[str] = None
+    tied_names: Optional[Tuple[str, ...]] = None
+
+
+def _prefix_weight(weight: WeightInfo, prefix: Optional[str] = None) -> WeightInfo:
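+    """Return a copy of weight with prefix prepended to its name, aliases, and tied names."""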
+    if prefix is None:
+        return weight
+    return WeightInfo(
+        name=prefix + weight.name,
+        aliases=tuple(prefix + alias for alias in weight.aliases or ()) or None,
+        tied_names=tuple(prefix + tied_name for tied_name in weight.tied_names or ())
+        or None,
+        **weight.model_dump(exclude={"name", "aliases", "tied_names"}),
+    )
+
+
+class ModuleArchitecture(ABC):
+    @abstractmethod
+    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        """Return a list of all weights preceding the first layer."""
+        ...
+
+    @abstractmethod
+    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        """Return a list of all weights following the final layer."""
+        ...
+
+    @abstractmethod
+    def layer_weights(
+        self, index: int, config: PretrainedConfig
+    ) -> Optional[List[WeightInfo]]:
+        """Return a list of all weights associated with a given layer."""
+        ...
+
+    def num_layers_config_key(self) -> str:
+        """Key in config that represents number of layers"""
+        return "num_hidden_layers"
+
+    def num_layers(self, config: PretrainedConfig) -> int:
+        """Return the number of layers in a model."""
+        return get_config_value(config, self.num_layers_config_key())
+
+    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        """Return all weights associated with a model."""
+        num_layers = self.num_layers(config)
+        res = list(self.pre_weights(config))
+        for layer_idx in range(num_layers):
+            res.extend(self.layer_weights(layer_idx, config))
+        res.extend(self.post_weights(config))
+        return res
+
+
+class ConfiguredModuleArchitecture(
+    BaseModel, frozen=True, arbitrary_types_allowed=True
+):
+    info: ModuleArchitecture
+    config: PretrainedConfig
+    weight_prefix: Optional[str] = None
+
+    def num_layers(self) -> int:
+        return self.info.num_layers(self.config)
+
+    def pre_weights(self) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.pre_weights(self.config)
+        ]
+
+    def post_weights(self) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.post_weights(self.config)
+        ]
+
+    def layer_weights(self, index: int) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.layer_weights(index, self.config)
+        ]
+
+    def all_weights(self) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.all_weights(self.config)
+        ]
+
+
+class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True):
+    architecture: ModuleArchitecture
+    weight_prefix: Optional[str] = None
+    subfolder: Optional[str] = None
+
+
+class ModelArchitecture(BaseModel, frozen=True):
+    modules: Dict[str, ModuleDefinition]
+    architectures: List[str]
+    expected_model_type: str = Field(alias="model_type")
+    tagalong_files: Optional[List[str]] = None
+
+    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        res = []
+        for module in self.modules.values():
+            for weight_info in module.architecture.all_weights(config=config):
+                res.append(_prefix_weight(weight_info, module.weight_prefix))
+        return res
+
+
+class ConfiguredModelArchitecture(BaseModel, frozen=True, arbitrary_types_allowed=True):
+    info: ModelArchitecture
+    config: PretrainedConfig
+
+    def all_weights(self) -> List[WeightInfo]:
+        return self.info.all_weights(self.config)
+
+    def get_module(self, module_name: str) -> ConfiguredModuleArchitecture:
+        return ConfiguredModuleArchitecture(
+            info=self.info.modules[module_name].architecture,
+            config=self.config,
+            weight_prefix=self.info.modules[module_name].weight_prefix,
+        )
diff --git a/mergekit/architecture/json_definitions.py b/mergekit/architecture/json_definitions.py
new file mode 100644
index 00000000..0c11c486
--- /dev/null
+++ b/mergekit/architecture/json_definitions.py
@@ -0,0 +1,187 @@
+# Copyright (C) 2025 Arcee AI
+# SPDX-License-Identifier: BUSL-1.1
+
+import importlib
+import importlib.resources
+import json
+import string
+from typing import Dict, List, Optional, Tuple
+
+from pydantic import BaseModel, Field
+from transformers import PretrainedConfig
+from typing_extensions import Literal
+
+import mergekit._data.architectures
+from mergekit.architecture.base import (
+    ModelArchitecture,
+    ModuleArchitecture,
+    ModuleDefinition,
+    WeightInfo,
+)
+
+
+class JsonLayerTemplates(BaseModel, frozen=True):
+    weights: List[WeightInfo]
+
+
+class JsonModuleArchDef(BaseModel, frozen=True):
+    expected_model_type: str = Field(alias="model_type")
+    architectures: List[str]
+    pre_weights: List[WeightInfo]
+    layer_templates: JsonLayerTemplates
+    post_weights: List[WeightInfo]
+    num_layers_config_key: Optional[str] = None
+    override_num_layers: Optional[int] = None
+
+
+class JsonModuleArchitecture(ModuleArchitecture, BaseModel, frozen=True):
+    kind: Literal["module"] = "module"
+    definition: JsonModuleArchDef
+
+    def _substitute(
+        self,
+        item: WeightInfo,
+        config: PretrainedConfig,
+        layer_idx: Optional[int] = None,
+    ) -> WeightInfo:
+        num_layers = self.num_layers(config)
+
+        obj_dict = item.model_dump(mode="json", exclude_unset=True)
+        for key in obj_dict:
+            if isinstance(obj_dict[key], str):
+                obj_dict[key] = _template_substitution(
+                    obj_dict[key], num_layers, layer_idx
+                )
+            elif isinstance(obj_dict[key], list):
+                obj_dict[key] = [
+                    (
+                        _template_substitution(s, num_layers, layer_idx)
+                        if isinstance(s, str)
+                        else s
+                    )
+                    for s in obj_dict[key]
+                ]
+        return type(item).model_validate(obj_dict)
+
+    def name(self) -> str:
+        return self.definition.expected_model_type
+
+    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        return [
+            self._substitute(wi, config=config) for wi in self.definition.pre_weights
+        ]
+
+    def layer_weights(
+        self, index: int, config: PretrainedConfig
+    ) -> Optional[List[WeightInfo]]:
+        return [
+            self._substitute(wi, config=config, layer_idx=index)
+            for wi in self.definition.layer_templates.weights
+        ]
+
+    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        return [
+            self._substitute(wi, config=config) for wi in self.definition.post_weights
+        ]
+
+    def num_layers_config_key(self) -> str:
+        return self.definition.num_layers_config_key
+
+    def num_layers(self, config):
+        if self.definition.override_num_layers is not None:
+            return self.definition.override_num_layers
+        return super().num_layers(config)
+
+
+class JsonModuleDefinition(BaseModel, frozen=True):
+    architecture: JsonModuleArchDef
+    weight_prefix: Optional[str] = None
+    subfolder: Optional[str] = None
+
+
+class JsonModularArchitectureDefinition(BaseModel, frozen=True):
+    kind: Literal["modular"]
+    modules: Dict[str, JsonModuleDefinition]
+    architectures: List[str]
+    expected_model_type: str = Field(alias="model_type")
+    tagalong_files: Optional[List[str]] = None
+
+
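+# string.Template subclass whose identifiers may carry a +1/-1 suffix, e.g. ${num_layers-1}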
+class TemplateWithArithmetic(string.Template):
+    idpattern = r"(?a:[_a-z][_a-z0-9]*([+-]1)?)"
+
+
+def _template_substitution(
+    template: str, num_layers: int, layer_idx: Optional[int] = None
+) -> str:
+    if "{" not in template:
+        return template
+
+    substitutions = {
+        "num_layers": num_layers,
+        "num_layers+1": num_layers + 1,
+        "num_layers-1": num_layers - 1,
+    }
+
+    if layer_idx is not None:
+        substitutions.update(
+            {
+                "layer_index": layer_idx,
+                "layer_index+1": layer_idx + 1,
+                "layer_index-1": layer_idx - 1,
+            }
+        )
+
+    return TemplateWithArithmetic(template).substitute(substitutions)
+
+
+def _load_architecture_json(name: str) -> ModelArchitecture:
+    with importlib.resources.open_text(mergekit._data.architectures, name) as f:
+        text = f.read()
+    data = json.loads(text)
+    kind = data.get("kind", "module")
+    if kind == "modular":
+        parsed = JsonModularArchitectureDefinition.model_validate_json(text)
+        return ModelArchitecture(
+            modules={
+                k: ModuleDefinition(
+                    architecture=JsonModuleArchitecture(definition=v.architecture),
+                    weight_prefix=v.weight_prefix,
+                    subfolder=v.subfolder,
+                )
+                for k, v in parsed.modules.items()
+            },
+            architectures=parsed.architectures,
+            model_type=parsed.expected_model_type,
+            tagalong_files=parsed.tagalong_files,
+        )
+    elif data.get("kind", "module") == "module":
+        module = JsonModuleArchitecture(
+            definition=JsonModuleArchDef.model_validate(data)
+        )
+        return ModelArchitecture(
+            modules={"default": ModuleDefinition(architecture=module)},
+            architectures=module.definition.architectures,
+            model_type=module.definition.expected_model_type,
+        )
+    else:
+        raise RuntimeError(f"Unexpected architecture kind: {data['kind']}")
+
+
+def _load_all_architectures() -> (
+    Tuple[List[ModelArchitecture], Dict[str, List[ModelArchitecture]]]
+):
+    architectures: List[ModelArchitecture] = []
+    for f in importlib.resources.contents(mergekit._data.architectures):
+        if f.lower().endswith(".json"):
+            architectures.append(_load_architecture_json(f))
+
+    name_to_arch: Dict[str, List[ModelArchitecture]] = {}
+    for arch_info in architectures:
+        for arch_name in arch_info.architectures:
+            name_to_arch[arch_name] = name_to_arch.get(arch_name, [])
+            name_to_arch[arch_name].append(arch_info)
+    return architectures, name_to_arch
+
+
+JSON_ARCHITECTURES, NAME_TO_ARCH = _load_all_architectures()
diff --git a/mergekit/architecture/mixtral.py b/mergekit/architecture/mixtral.py
new file mode 100644
index 00000000..47c9c440
--- /dev/null
+++ b/mergekit/architecture/mixtral.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2025 Arcee AI
+# SPDX-License-Identifier: BUSL-1.1
+
+from typing import ClassVar, List, Optional
+
+from pydantic import BaseModel
+from transformers import PretrainedConfig
+
+from mergekit.architecture.base import (
+    ModuleArchitecture,
+    WeightInfo,
+)
+from mergekit.architecture.json_definitions import NAME_TO_ARCH
+
+MISTRAL_INFO = NAME_TO_ARCH["MistralForCausalLM"][0]
+MISTRAL_MODULE_ARCH = MISTRAL_INFO.modules["default"].architecture
+
+
+class MixtralTensorNames(ModuleArchitecture, BaseModel):
+    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
+    num_local_experts: int
+
+    def name(self) -> str:
+        return "mixtral"
+
+    @classmethod
+    def from_config(cls, config: PretrainedConfig):
+        return MixtralTensorNames(num_local_experts=config.num_local_experts)
+
+    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        return MISTRAL_MODULE_ARCH.pre_weights(config)
+
+    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        return MISTRAL_MODULE_ARCH.post_weights(config)
+
+    def num_layers_config_key(self) -> str:
+        return MISTRAL_MODULE_ARCH.num_layers_config_key()
+
+    def layer_weights(
+        self, index: int, config: PretrainedConfig
+    ) -> Optional[List[WeightInfo]]:
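+        # per-expert w1/w2/w3 projections and the router gate, plus Mistral's non-MLP weights for this layer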
+        num_experts = self.num_local_experts
+        prefix = f"model.layers.{index}"
+        tensor_names = []
+        for expert_idx in range(num_experts):
+            for param in ("w1", "w2", "w3"):
+                tensor_names.append(
+                    prefix + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
+                )
+        tensor_names.append(prefix + ".block_sparse_moe.gate.weight")
+        res = []
+        for name in tensor_names:
+            res.append(WeightInfo(name=name))
+        for weight_info in MISTRAL_MODULE_ARCH.layer_weights(index, config):
+            if ".mlp." in weight_info.name:
+                continue
+            res.append(weight_info)
+        return res
diff --git a/mergekit/common.py b/mergekit/common.py
index 8a087543..f16ddde5 100644
--- a/mergekit/common.py
+++ b/mergekit/common.py
@@ -13,6 +13,7 @@
     Iterator,
     Mapping,
     Optional,
+    Protocol,
     Tuple,
     Union,
     get_args,
@@ -31,6 +32,32 @@
 from mergekit.io import LazyTensorLoader, ShardedTensorIndex
 
 
+def set_config_value(config: PretrainedConfig, key: str, value: Any):
+    """Set a value in a PretrainedConfig object."""
+    parts = key.split(".")
+    obj = config
+    for idx, part in enumerate(parts[:-1]):
+        if not hasattr(obj, part):
+            raise RuntimeError(
+                f"Config {config} has no attribute {'.'.join(parts[:idx+1])}"
+            )
+        obj = getattr(obj, part)
+    setattr(obj, parts[-1], value)
+
+
+def get_config_value(config: PretrainedConfig, key: str) -> Any:
+    """Get a value from a PretrainedConfig object."""
+    parts = key.split(".")
+    obj = config
+    for idx, part in enumerate(parts):
+        if not hasattr(obj, part):
+            raise RuntimeError(
+                f"Config {config} has no attribute {'.'.join(parts[:idx+1])}"
+            )
+        obj = getattr(obj, part)
+    return obj
+
+
 class ModelPath(BaseModel, frozen=True):
     path: str
     revision: Optional[str] = None
@@ -92,7 +119,7 @@ def merged(
             os.makedirs(out_path, exist_ok=True)
 
             config = self.config(trust_remote_code)
-            auto_cls = _get_auto_cls(config.architectures[0])
+            auto_cls = get_auto_cls(config.architectures[0])
 
             logging.info(f"Loading {self.model} for merge...")
             model = auto_cls.from_pretrained(
@@ -110,7 +137,7 @@ def merged(
             model.save_pretrained(out_path, safe_serialization=True)
             del model
 
-        return ModelReference(model=out_path)
+        return ModelReference(model=ModelPath(path=out_path))
 
     def config(self, trust_remote_code: bool = False) -> PretrainedConfig:
         res = AutoConfig.from_pretrained(
@@ -270,8 +297,70 @@ def values(self) -> Iterator[T_V]:
         return self.data.values()
 
 
-def _get_auto_cls(arch_name: str):
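+# map architecture class names (e.g. "LlamaForCausalLM") to the matching transformers Auto* class,
+# built from the model-mapping registries in transformers.models.auto.modeling_auto when available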
+ARCH_NAME_TO_AUTO_CLS = {}
+
+try:
+    import transformers.models.auto.modeling_auto as tf_auto
+except ImportError:
+    tf_auto = None
+
+if tf_auto is not None:
+    for map_name, cls_name in [
+        ("MODEL_MAPPING_NAMES", "AutoModel"),
+        (
+            "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES",
+            "AutoModelForAudioClassification",
+        ),
+        (
+            "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES",
+            "AutoModelForImageClassification",
+        ),
+        ("MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"),
+        (
+            "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES",
+            "AutoModelForSequenceClassification",
+        ),
+        ("MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
+        (
+            "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES",
+            "AutoModelForTokenClassification",
+        ),
+        ("MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
+        ("MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"),
+        ("MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
+        ("MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "AutoModelForCausalLM"),
+    ]:
+        cls = getattr(transformers, cls_name, None)
+        if cls is None:
+            logging.info(f"Could not find {cls_name} in transformers")
+            continue
+        if hasattr(tf_auto, map_name):
+            name_to_arch_name = getattr(tf_auto, map_name)
+            for arch_name in name_to_arch_name.values():
+                ARCH_NAME_TO_AUTO_CLS[arch_name] = cls
+
+
+class AutoClassProtocol(Protocol):
+    def from_pretrained(
+        self,
+        pretrained_model_name_or_path: str,
+        *model_args,
+        **kwargs,
+    ) -> transformers.PreTrainedModel: ...
+
+    def from_config(
+        self,
+        config: transformers.PretrainedConfig,
+        *model_args,
+        **kwargs,
+    ) -> transformers.PreTrainedModel: ...
+
+
+def get_auto_cls(arch_name: str) -> AutoClassProtocol:
     """Get the AutoModel class for a given architecture name."""
+    if arch_name in ARCH_NAME_TO_AUTO_CLS:
+        return ARCH_NAME_TO_AUTO_CLS[arch_name]
+
     if arch_name.endswith("ForMaskedLM"):
         auto_cls = transformers.AutoModelForMaskedLM
     elif arch_name.endswith("ForSequenceClassification"):
diff --git a/mergekit/config.py b/mergekit/config.py
index 532d30cf..c449e778 100644
--- a/mergekit/config.py
+++ b/mergekit/config.py
@@ -70,11 +70,24 @@ class OutputSliceDefinition(BaseModel):
     parameters: Optional[Dict[str, ParameterSetting]] = None
 
 
-class MergeConfiguration(BaseModel):
-    merge_method: str
+class OutputModuleDefinition(BaseModel):
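+    """Output definition for one module of a multi-module model."""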
     slices: Optional[List[OutputSliceDefinition]] = None
     models: Optional[List[InputModelDefinition]] = None
     parameters: Optional[Dict[str, ParameterSetting]] = None
+
+    @model_validator(mode="after")
+    def validate_inputs(self):
+        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
+            raise RuntimeError("Must specify either output slices or models to merge")
+        return self
+
+
+class MergeConfiguration(BaseModel):
+    modules: Optional[Dict[str, OutputModuleDefinition]] = None
+    slices: Optional[List[OutputSliceDefinition]] = None
+    models: Optional[List[InputModelDefinition]] = None
+
+    merge_method: str
     base_model: Optional[ModelReference] = None
     dtype: Optional[str] = None
     tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = (
@@ -83,6 +96,7 @@ class MergeConfiguration(BaseModel):
     tokenizer: Optional[TokenizerConfig] = None
     chat_template: Optional[str] = None
     out_dtype: Optional[str] = None
+    parameters: Optional[Dict[str, ParameterSetting]] = None
 
     def referenced_models(self) -> List[ModelReference]:
         models = set()
@@ -95,12 +109,31 @@ def referenced_models(self) -> List[ModelReference]:
             for s in self.slices:
                 for src in s.sources:
                     models.add(src.model)
+        if self.modules:
+            for m in self.modules.values():
+                if m.models:
+                    for model_in in m.models:
+                        models.add(model_in.model)
+                if m.slices:
+                    for s in m.slices:
+                        for src in s.sources:
+                            models.add(src.model)
         return list(models)
 
     @model_validator(mode="after")
     def validate_inputs(self):
-        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
-            raise RuntimeError("Must specify either output slices or models to merge")
+        set_ct = 0
+        if self.modules:
+            set_ct += 1
+        if self.slices:
+            set_ct += 1
+        if self.models:
+            set_ct += 1
+
+        if set_ct != 1:
+            raise RuntimeError(
+                "Exactly one of 'models', 'slices', or 'modules' must be present"
+            )
         return self
 
     @model_validator(mode="after")
@@ -121,6 +154,7 @@ class ConfigReader(BaseModel):
     t: float
     tensor_name: Optional[str] = None
     slice_out: Optional[OutputSliceDefinition] = None
+    module: Optional[OutputModuleDefinition] = None
 
     @property
     def base_model(self) -> Optional[ModelReference]:
@@ -137,6 +171,7 @@ def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
             t=self.t,
             tensor_name=self.tensor_name,
             slice_out=slice,
+            module=self.module,
         )
 
     def for_tensor(self, tensor_name: str) -> "ConfigReader":
@@ -145,6 +180,7 @@ def for_tensor(self, tensor_name: str) -> "ConfigReader":
             t=self.t,
             tensor_name=tensor_name,
             slice_out=self.slice_out,
+            module=self.module,
         )
 
     def with_t(self, t: float) -> "ConfigReader":
@@ -153,6 +189,16 @@ def with_t(self, t: float) -> "ConfigReader":
             t=t,
             tensor_name=self.tensor_name,
             slice_out=self.slice_out,
+            module=self.module,
+        )
+
+    def for_module(self, module: OutputModuleDefinition) -> "ConfigReader":
+        return ConfigReader(
+            config=self.config,
+            t=self.t,
+            tensor_name=self.tensor_name,
+            slice_out=self.slice_out,
+            module=module,
         )
 
     def parameter(
@@ -179,6 +225,15 @@ def parameter(
                 if value is not None:
                     return value
 
+        if self.module and self.module.parameters and name in self.module.parameters:
+            value = evaluate_setting(
+                self.tensor_name,
+                self.module.parameters[name],
+                self.t,
+            )
+            if value is not None:
+                return value
+
         if self.config.parameters and name in self.config.parameters:
             value = evaluate_setting(
                 self.tensor_name,
diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py
index 43abdcbf..c37ff7e0 100644
--- a/mergekit/evo/actors.py
+++ b/mergekit/evo/actors.py
@@ -17,13 +17,15 @@
 import transformers
 from transformers.utils import is_flash_attn_2_available
 
+from mergekit.architecture.base import ConfiguredModelArchitecture
+
 try:
     import vllm
 except ImportError:
     vllm = None
 
 
-from mergekit.architecture import ArchitectureInfoUtils, ConfiguredArchitectureInfo
+from mergekit.architecture import arch_info_for_config
 from mergekit.config import MergeConfiguration
 from mergekit.evo.config import EvolMergeConfiguration
 from mergekit.evo.genome import InvalidGenotypeError, ModelGenome
@@ -130,7 +132,7 @@ class InMemoryMergeEvaluator(MergeActorBase):
     model: Union[
         lm_eval.models.huggingface.HFLM, lm_eval.models.vllm_causallms.VLLM, None
     ] = None
-    arch_info: Optional[ConfiguredArchitectureInfo] = None
+    arch_info: Optional[ConfiguredModelArchitecture] = None
 
     def __init__(
         self,
@@ -142,9 +144,7 @@ def __init__(
         super().__init__(*args, vllm=vllm, **kwargs)
 
     def _maybe_init_model(self, config: MergeConfiguration):
-        ai = ArchitectureInfoUtils.get_architecture_info(
-            self.genome._input_config_example
-        )
+        ai = arch_info_for_config(self.genome._input_config_example)
         cfg_out = _model_out_config(
             config,
             ai,
@@ -167,7 +167,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
                     continue
 
                 if getattr(cfg_out, key) != getattr(self.arch_info.config, key, None):
-                    logger.warn(f"Config key {key} changed, reinitializing model")
+                    logger.warning(f"Config key {key} changed, reinitializing model")
                     different = True
                     break
 
@@ -240,7 +240,14 @@ def _maybe_init_model(self, config: MergeConfiguration):
                 )
         else:
             self.model = lm_eval.models.huggingface.HFLM(pretrained=inner_model)
-        self.arch_info = ConfiguredArchitectureInfo(info=ai, config=cfg_out)
+        self.arch_info = (
+            ConfiguredModelArchitecture(
+                info=ai,
+                config=cfg_out,
+            )
+            if ai
+            else None
+        )
         logger.info("Model initialized")
 
     def evaluate(self, genotype: torch.Tensor) -> dict:
diff --git a/mergekit/graph.py b/mergekit/graph.py
index d518ccae..1d6309b1 100644
--- a/mergekit/graph.py
+++ b/mergekit/graph.py
@@ -8,6 +8,7 @@
     Executor: Class for scheduling and executing directed acyclic task graphs.
 """
 
+import logging
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
 
@@ -19,6 +20,8 @@
 
 ValueT = TypeVar("ValueT")
 
+logger = logging.getLogger(__name__)
+
 
 class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
     """
@@ -243,20 +246,35 @@ def _move_tensors(
     DUMMY_TASK_VALUE = "!!DUMMY!!"
 
     def _make_schedule(self, targets: List[Task]) -> List[Task]:
+        logger.debug(f"Building schedule for {len(targets)} targets")
         self.schedule = []
         self.dependencies = self._build_dependencies(targets)
 
+        node_indices = {}
+        node_values = []
+
+        # instead of using the actual task objects as nodes in the graph,
+        # use an integer index to avoid reserializing the task objects
+        # inside networkx (slow)
+        def _index(node: Union[Task, str]) -> int:
+            if node not in node_indices:
+                node_indices[node] = len(node_indices)
+                node_values.append(node)
+            return node_indices[node]
+
         edge_tups = []
         for node in self.dependencies:
             for dependency in self.dependencies[node]:
-                edge_tups.append((dependency, node))
+                edge_tups.append((_index(dependency), _index(node)))
 
+        # add edges from a dummy node to each target to guarantee
+        # they will be included in the final schedule
+        dummy_index = _index(Executor.DUMMY_TASK_VALUE)
         for task in targets:
-            # add edges from a dummy node to each target to guarantee
-            # they will be included in the final schedule
-            edge_tups.append((Executor.DUMMY_TASK_VALUE, task))
+            edge_tups.append((dummy_index, _index(task)))
 
-        def _compare_key(task: Union[Task, str]):
+        def _compare_key(node: int) -> Tuple[str, int]:
+            task = node_values[node]
             if task == Executor.DUMMY_TASK_VALUE:
                 return ("", 0)
             return (
@@ -265,13 +283,14 @@ def _compare_key(task: Union[Task, str]):
             )
 
         graph = networkx.DiGraph(edge_tups)
-        res = [
-            t
-            for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)
-            if (t != Executor.DUMMY_TASK_VALUE)
-            and (t not in (self.cached_values or {}))
+        return [
+            node_values[idx]
+            for idx in networkx.lexicographical_topological_sort(
+                graph, key=_compare_key
+            )
+            if (idx != dummy_index)
+            and node_values[idx] not in (self.cached_values or {})
         ]
-        return res
 
     def _build_dependencies(self, targets: List[Task]) -> Dict[Task, Set[Task]]:
         task_dependencies: Dict[Task, Set[Task]] = {}
diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py
index 6ec474d1..19778f76 100644
--- a/mergekit/io/tensor_writer.py
+++ b/mergekit/io/tensor_writer.py
@@ -122,7 +122,7 @@ def finalize(self):
                 json.dump(
                     {
                         "metadata": {
-                            "mergekit_version": "0.1.1",
+                            "mergekit_version": "0.1.2",
                         },
                         "weight_map": self.weight_map,
                     },
diff --git a/mergekit/merge.py b/mergekit/merge.py
index 994774e3..cf18ad9e 100644
--- a/mergekit/merge.py
+++ b/mergekit/merge.py
@@ -6,16 +6,16 @@
 import logging
 import os
 import shutil
-import warnings
 from collections import Counter
-from typing import Optional
+from typing import List, Optional
 
 import tqdm
 import transformers
 
 from mergekit._data import chat_templates
-from mergekit.architecture import ArchitectureInfo, ArchitectureInfoUtils
+from mergekit.architecture import ModelArchitecture, get_architecture_info
 from mergekit.card import generate_card
+from mergekit.common import set_config_value
 from mergekit.config import MergeConfiguration
 from mergekit.graph import Executor
 from mergekit.io.tasks import LoaderCache
@@ -39,7 +39,7 @@ def run_merge(
-    if not merge_config.models and not merge_config.slices:
+    if not (merge_config.models or merge_config.slices or merge_config.modules):
         raise RuntimeError("No output requested")
 
-    arch_info = _load_arch_info(merge_config, options)
+    arch_info = get_architecture_info(merge_config, options)
 
     # initialize loader cache and set options
     loader_cache = LoaderCache()
@@ -112,7 +112,11 @@ def run_merge(
         ) as fp:
             fp.write(config_source)
 
-    if tokenizer is None:
+    if tokenizer is not None:
+        logger.info("Saving tokenizer")
+        _set_chat_template(tokenizer, merge_config)
+        tokenizer.save_pretrained(out_path, safe_serialization=True)
+    else:
         if options.copy_tokenizer:
             try:
                 _copy_tokenizer(
@@ -128,10 +132,12 @@ def run_merge(
                 "Chat template specified but no tokenizer found. Chat template will not be saved."
             )
 
-    if tokenizer:
-        logger.info("Saving tokenizer")
-        _set_chat_template(tokenizer, merge_config)
-        tokenizer.save_pretrained(out_path, safe_serialization=True)
+    _copy_tagalong_files(
+        merge_config,
+        out_path,
+        files=arch_info.tagalong_files or [],
+        trust_remote_code=options.trust_remote_code,
+    )
 
     if getattr(arch_info, "post_fill_parameters", False):
         from mergekit.scripts.fill_missing_params import copy_and_fill_missing_params
@@ -192,6 +198,25 @@ def _set_chat_template(
     tokenizer.chat_template = chat_template
 
 
+def _copy_tagalong_files(
+    merge_config: MergeConfiguration,
+    out_path: str,
+    files: List[str],
+    trust_remote_code: bool = False,
+):
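+    # copy architecture-declared tagalong files (if present) from the donor model into the output directory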
+    donor_model = merge_config.base_model or (merge_config.referenced_models()[0])
+
+    for file_name in files:
+        if os.path.exists(os.path.join(donor_model.model.path, file_name)):
+            logger.info(f"Copying {file_name} from {donor_model}")
+            shutil.copy(
+                os.path.join(donor_model.model.path, file_name),
+                os.path.join(out_path, file_name),
+            )
+
+    return
+
+
 def _copy_tokenizer(
     merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False
 ):
@@ -214,6 +239,8 @@ def _copy_tokenizer(
             "special_tokens_map.json",
             "tokenizer.json",
             "tokenizer.model",
+            "added_tokens.json",
+            "merges.txt",
         ]:
             if os.path.exists(os.path.join(donor_model.model.path, file_name)):
                 shutil.copy(
@@ -236,7 +263,7 @@ def _copy_tokenizer(
 
 def _model_out_config(
     config: MergeConfiguration,
-    arch_info: ArchitectureInfo,
+    arch_info: ModelArchitecture,
     trust_remote_code: bool = False,
 ) -> transformers.PretrainedConfig:
     """Return a configuration for the resulting model."""
@@ -249,19 +276,33 @@ def _model_out_config(
     elif config.dtype:
         res.torch_dtype = config.dtype
 
-    if config.slices:
-        try:
-            num_layers = sum(
+    module_layers = {}
+    for module_name in arch_info.modules:
+        if config.modules and module_name in config.modules:
+            module_def = config.modules.get(module_name)
+            module_layers[module_name] = sum(
                 s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
-                for s in config.slices
+                for s in module_def.slices
             )
-            setattr(res, arch_info.num_layers_config_key(), num_layers)
-        except Exception as e:
-            logger.warning(
-                "Unable to set number of layers in output config - you may need to manually correct it.",
-                exc_info=e,
+        elif config.slices:
+            module_layers[module_name] = sum(
+                s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
+                for s in config.slices
             )
 
+    if module_layers:
+        for module_name in module_layers:
+            try:
+                module_info = arch_info.modules[module_name]
+                cfg_key = module_info.architecture.num_layers_config_key()
+                set_config_value(res, cfg_key, module_layers[module_name])
+            except Exception as e:
+                logger.warning(
+                    f"Unable to set number of layers for module {module_name} in output config "
+                    "- you may need to manually correct it.",
+                    exc_info=e,
+                )
+
     return res
 
 
@@ -282,32 +323,4 @@ def _update_config_vocab(
         )
 
 
-def _load_arch_info(
-    merge_config: MergeConfiguration, options: MergeOptions
-) -> ArchitectureInfo:
-    """
-    Loads architecture information, handling cases where models lack predefined architecture info.
-    """
-    model_arch_info = [
-        ArchitectureInfoUtils.get_architecture_info(
-            m.config(trust_remote_code=options.trust_remote_code)
-        )
-        for m in merge_config.referenced_models()
-    ]
-
-    if all(a is not None for a in model_arch_info):
-        if not options.allow_crimes and not all(
-            a == model_arch_info[0] for a in model_arch_info[1:]
-        ):
-            raise RuntimeError(
-                "Must specify --allow-crimes to attempt to mix different architectures"
-            )
-        return model_arch_info[0]
-    else:
-        warnings.warn("Attempting Automatic Merge.")
-        model_arch_info = ArchitectureInfoUtils.infer_architecture_info(merge_config)
-
-    return model_arch_info
-
-
 __all__ = ["MergeOptions", "run_merge"]
diff --git a/mergekit/moe/deepseek.py b/mergekit/moe/deepseek.py
index 9f8a4b1f..dba6b78e 100644
--- a/mergekit/moe/deepseek.py
+++ b/mergekit/moe/deepseek.py
@@ -10,7 +10,7 @@
 import tqdm
 import transformers
 
-from mergekit.architecture import ArchitectureInfoUtils
+from mergekit.architecture import arch_info_for_config
 from mergekit.moe.arch import MoEOutputArchitecture
 from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
 from mergekit.moe.config import MoEMergeConfig
@@ -126,7 +126,7 @@ def write_model(
         loaders, base_loader, writer = initialize_io(config, out_path, merge_options)
         shared_loader = loaders.get(shared_def.source_model) if shared_def else None
         for weight_info in tqdm.tqdm(
-            ArchitectureInfoUtils.get_architecture_info(base_cfg).all_weights(base_cfg),
+            arch_info_for_config(base_cfg).all_weights(base_cfg),
             desc="Weights",
         ):
             tensor_name = weight_info.name
diff --git a/mergekit/moe/mixtral.py b/mergekit/moe/mixtral.py
index 5f0c7dfd..187e5f1e 100644
--- a/mergekit/moe/mixtral.py
+++ b/mergekit/moe/mixtral.py
@@ -8,7 +8,8 @@
 import tqdm
 import transformers
 
-from mergekit.architecture import MISTRAL_INFO, WeightInfo
+from mergekit.architecture import WeightInfo
+from mergekit.architecture.mixtral import MISTRAL_INFO
 from mergekit.moe.arch import MoEOutputArchitecture
 from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
 from mergekit.moe.config import MoEMergeConfig
diff --git a/mergekit/moe/qwen.py b/mergekit/moe/qwen.py
index f5730a6c..46cc820c 100644
--- a/mergekit/moe/qwen.py
+++ b/mergekit/moe/qwen.py
@@ -12,12 +12,14 @@
 # if the transformers version installed is too old
 from transformers.models.qwen2_moe import Qwen2MoeConfig
 
-from mergekit.architecture import QWEN2_INFO
+from mergekit.architecture.json_definitions import NAME_TO_ARCH
 from mergekit.moe.arch import MoEOutputArchitecture
 from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype
 from mergekit.moe.config import MoEMergeConfig
 from mergekit.options import MergeOptions
 
+QWEN2_INFO = NAME_TO_ARCH["Qwen2ForCausalLM"][0]
+
 
 class QwenMoE(MoEOutputArchitecture):
     def name(self) -> str:
diff --git a/mergekit/multigpu_executor.py b/mergekit/multigpu_executor.py
index 73c41e42..88aec975 100644
--- a/mergekit/multigpu_executor.py
+++ b/mergekit/multigpu_executor.py
@@ -50,7 +50,7 @@ def __init__(
             num_gpus: Number of GPUs to utilize (None = all available)
             storage_device: Device for storing tensors between stages
         """
-        self.results = {}
+        self.results: Dict[Task, Any] = {}
         self.targets = set(tasks)
         self.storage_device = storage_device
 
@@ -140,10 +140,10 @@ def update_progress():
                     )
 
                 for future in concurrent.futures.as_completed(futures):
-                    if future.exception():
+                    if ex := future.exception():
                         self.done_event.set()
                         executor.shutdown(wait=False)
-                        raise future.exception()
+                        raise ex
 
             self.done_event.set()
             progress_thread.join()
@@ -237,7 +237,7 @@ def _assign_islands_to_gpus(
 
         islands = list(nx.weakly_connected_components(island_graph))
         logger.info(f"Found {len(islands)} islands in parallel task graph")
-        assignments = {}
+        assignments: Dict[torch.device, List[Task]] = {}
         for island in islands:
             # Borrow orderings from original task list
             island_tasks = [t for t in tasks if t in island]
diff --git a/mergekit/options.py b/mergekit/options.py
index fb88f6a3..86ec4f9a 100644
--- a/mergekit/options.py
+++ b/mergekit/options.py
@@ -164,7 +164,7 @@ def wrapper(*args, **kwargs):
 
 class PrettyPrintHelp(click.Command):
     def format_options(self, ctx: Context, formatter: HelpFormatter) -> None:
-        categories = {None: []}
+        categories: dict[Optional[str], list[Parameter]] = {None: []}
         for param in ctx.command.params:
             if param.name in OPTION_CATEGORIES:
                 category = OPTION_CATEGORIES[param.name]
diff --git a/mergekit/plan.py b/mergekit/plan.py
index 65e63bef..973bc69c 100644
--- a/mergekit/plan.py
+++ b/mergekit/plan.py
@@ -7,15 +7,17 @@
 
 from mergekit import merge_methods
 from mergekit.architecture import (
-    ArchitectureInfo,
-    ConfiguredArchitectureInfo,
+    ConfiguredModuleArchitecture,
+    ModelArchitecture,
     WeightInfo,
 )
+from mergekit.architecture.base import ConfiguredModelArchitecture
 from mergekit.common import ImmutableMap, ModelReference
 from mergekit.config import (
     ConfigReader,
     InputSliceDefinition,
     MergeConfiguration,
+    OutputModuleDefinition,
     OutputSliceDefinition,
 )
 from mergekit.graph import Task
@@ -34,18 +36,18 @@
 
 class MergePlanner:
     config: MergeConfiguration
-    arch_info: ArchitectureInfo
+    arch_info: ModelArchitecture
     options: MergeOptions
     out_model_config: Any
     _method: MergeMethod
     _tensors: List[Tuple[WeightInfo, Task]]
-    _current_layers: int = 0
+    _current_module_layers: int = 0
     _tokenizer_task: Optional[BuildTokenizer] = None
 
     def __init__(
         self,
         config: MergeConfiguration,
-        arch_info: ArchitectureInfo,
+        arch_info: ModelArchitecture,
         options: MergeOptions,
         out_model_config: Any,
     ):
@@ -54,6 +56,7 @@ def __init__(
         self.options = options
         self.out_model_config = out_model_config
         self._method = merge_methods.get(config.merge_method)
+        self._tensors = []
 
         token_cfg = {}
         tokenizer_source = config.tokenizer_source
@@ -69,9 +72,17 @@ def __init__(
                 add_tokens=tuple(token_cfg.keys()),
             )
 
+    def _out_module_arch(self, module: str) -> ConfiguredModuleArchitecture:
+        module_def = self.arch_info.modules[module]
+        return ConfiguredModuleArchitecture(
+            info=module_def.architecture,
+            config=self.out_model_config,
+            weight_prefix=module_def.weight_prefix,
+        )
+
     @lru_cache
-    def model_arch_info(self, model: ModelReference):
-        return ConfiguredArchitectureInfo(
+    def _model_arch(self, model: ModelReference):
+        return ConfiguredModelArchitecture(
             info=self.arch_info,
             config=model.config(trust_remote_code=self.options.trust_remote_code),
         )
@@ -79,41 +90,70 @@ def model_arch_info(self, model: ModelReference):
     def normalize_config(self):
         base_model = self.config.base_model
 
-        # if models to merge are specified instead of output slices, compute them
+        # models -> modules.models
         if self.config.models:
-            if self.config.slices:
-                raise RuntimeError(
-                    "Must specify either models to merge or output slices"
+            self.config.modules = {}
+            for module_name in self.arch_info.modules:
+                self.config.modules[module_name] = OutputModuleDefinition(
+                    models=self.config.models
                 )
+            self.config.models = None
 
-            slices_in = []
-            base_included = False
-
-            for model_in in self.config.models:
-                if base_model and model_in.model == base_model:
-                    base_included = True
-
-                model_info = self.model_arch_info(model_in.model)
-                slices_in.append(
-                    InputSliceDefinition(
-                        layer_range=[0, model_info.num_layers()],
-                        model=model_in.model,
-                        parameters=model_in.parameters,
-                    )
+        # slices -> modules.slices
+        if self.config.slices:
+            if len(self.arch_info.modules) != 1:
+                raise RuntimeError(
+                    "Model has multiple modules, must use modules: config syntax "
+                    "to work with slices"
                 )
+            module_name = list(self.arch_info.modules.keys())[0]
+            self.config.modules = {
+                module_name: OutputModuleDefinition(slices=self.config.slices)
+            }
+            self.config.slices = None
+
+        # modules.models -> modules.slices
+        for module_name in self.config.modules:
+            module_out = self.config.modules[module_name]
+            module_arch = self.arch_info.modules[module_name].architecture
+
+            if module_out.models:
+                slices_in = []
+                base_included = False
+
+                for model_in in module_out.models:
+                    if base_model and model_in.model == base_model:
+                        base_included = True
+
+                    model_cfg = model_in.model.config(
+                        trust_remote_code=self.options.trust_remote_code
+                    )
+                    num_layers = module_arch.num_layers(model_cfg)
+                    slices_in.append(
+                        InputSliceDefinition(
+                            layer_range=[0, num_layers],
+                            model=model_in.model,
+                            parameters=model_in.parameters,
+                        )
+                    )
 
-            if base_model and not base_included:
-                logging.info("Base model specified but not in input models - adding")
-                base_info = self.model_arch_info(base_model)
-                slices_in.append(
-                    InputSliceDefinition(
-                        layer_range=[0, base_info.num_layers()],
-                        model=base_model,
+                if base_model and not base_included:
+                    logging.info(
+                        "Base model specified but not in input models - adding"
+                    )
+                    base_cfg = base_model.config(
+                        trust_remote_code=self.options.trust_remote_code
+                    )
+                    num_layers = module_arch.num_layers(base_cfg)
+                    slices_in.append(
+                        InputSliceDefinition(
+                            layer_range=[0, num_layers],
+                            model=base_model,
+                        )
                     )
-                )
 
-            self.config.slices = [OutputSliceDefinition(sources=slices_in)]
-            self.config.models = None
+                module_out.slices = [OutputSliceDefinition(sources=slices_in)]
+                module_out.models = None
 
     def plan_tensor(
         self,
@@ -201,15 +241,16 @@ def plan_layer(
         layer_offset: int,
         t: float,
         cfg_reader: ConfigReader,
+        module_name: str,
     ):
-        weights_out: List[WeightInfo] = self.arch_info.layer_weights(
-            index=self._current_layers,
-            config=self.out_model_config,
+        module_arch = self._out_module_arch(module_name)
+        weights_out: List[WeightInfo] = module_arch.layer_weights(
+            index=self._current_module_layers,
         )
         weights_in: List[List[WeightInfo]] = [
-            self.model_arch_info(s.model).layer_weights(
-                index=s.layer_range[0] + layer_offset
-            )
+            self._model_arch(s.model)
+            .get_module(module_name)
+            .layer_weights(index=s.layer_range[0] + layer_offset)
             for s in sources
         ]
 
@@ -221,9 +262,14 @@ def plan_layer(
                 cfg_reader=cfg_reader.with_t(t),
             )
 
-        self._current_layers += 1
+        self._current_module_layers += 1
 
-    def plan_slice(self, definition: OutputSliceDefinition):
+    def plan_slice(
+        self,
+        definition: OutputSliceDefinition,
+        module_def: OutputModuleDefinition,
+        module_name: str,
+    ):
         slice_lengths = [
             s.layer_range[1] - s.layer_range[0] for s in definition.sources
         ]
@@ -233,7 +279,9 @@ def plan_slice(self, definition: OutputSliceDefinition):
             )
         num_layers = slice_lengths[0]
 
-        cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0)
+        cfg_reader = ConfigReader(
+            config=self.config, slice_out=definition, t=0, module=module_def
+        )
         for idx in range(num_layers):
             # compute t for interpolated gradients
             if num_layers > 1:
@@ -246,6 +294,40 @@ def plan_slice(self, definition: OutputSliceDefinition):
                 layer_offset=idx,
                 t=t,
                 cfg_reader=cfg_reader,
+                module_name=module_name,
+            )
+
+    def plan_module(self, module_name: str, definition: OutputModuleDefinition):
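+        # plan the module's pre-weights, each output slice, then its post-weights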
+        self._current_module_layers = 0
+
+        module_arch = self._out_module_arch(module_name)
+        config_reader = ConfigReader(config=self.config, t=0, module=definition)
+
+        for weight_info in module_arch.pre_weights():
+            self.plan_tensor(
+                weight_info,
+                [weight_info] * len(definition.slices[0].sources),
+                [s.model for s in definition.slices[0].sources],
+                config_reader.for_tensor(tensor_name=weight_info.name).for_out_slice(
+                    definition.slices[0]
+                ),
+            )
+
+        for out_slice in definition.slices:
+            self.plan_slice(
+                out_slice,
+                module_def=definition,
+                module_name=module_name,
+            )
+
+        for weight_info in module_arch.post_weights():
+            self.plan_tensor(
+                weight_info,
+                [weight_info] * len(definition.slices[-1].sources),
+                [s.model for s in definition.slices[-1].sources],
+                config_reader.for_tensor(tensor_name=weight_info.name).for_out_slice(
+                    definition.slices[-1]
+                ),
             )
 
     def plan_to_disk(self, out_path: str) -> List[Task]:
@@ -292,31 +374,7 @@ def plan_in_memory(self) -> List[ReturnTensor]:
 
     def _plan(self):
         self.normalize_config()
-        self._tensors = []
-
-        for weight_info in self.arch_info.pre_weights(config=self.out_model_config):
-            self.plan_tensor(
-                weight_info,
-                [weight_info] * len(self.config.slices[0].sources),
-                [s.model for s in self.config.slices[0].sources],
-                ConfigReader(
-                    config=self.config,
-                    t=0,
-                    tensor_name=weight_info.name,
-                ).for_out_slice(self.config.slices[0]),
-            )
-
-        for out_slice in self.config.slices:
-            self.plan_slice(out_slice)
+        self._tensors = []
 
-        for weight_info in self.arch_info.post_weights(config=self.out_model_config):
-            self.plan_tensor(
-                weight_info,
-                [weight_info] * len(self.config.slices[-1].sources),
-                [s.model for s in self.config.slices[-1].sources],
-                ConfigReader(
-                    config=self.config,
-                    t=1,
-                    tensor_name=weight_info.name,
-                ).for_out_slice(self.config.slices[-1]),
-            )
+        for module_name in self.config.modules:
+            self.plan_module(module_name, self.config.modules[module_name])
diff --git a/mergekit/scripts/ABM/activations_based_merge.py b/mergekit/scripts/ABM/activations_based_merge.py
deleted file mode 100644
index 3834892d..00000000
--- a/mergekit/scripts/ABM/activations_based_merge.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import logging
-import os
-from typing import Optional
-
-import click
-import safetensors.torch
-import torch
-import tqdm
-from transformers import AutoTokenizer
-
-from mergekit.architecture import ArchitectureInfoUtils
-from mergekit.common import ModelReference, dtype_from_name
-from mergekit.io.tasks import LoaderCache
-from mergekit.io.tensor_writer import TensorWriter
-from mergekit.options import MergeOptions, add_merge_options
-
-
-@click.command("mergekit-activation-based-merge")
-@click.argument("model_path", type=str)
-@click.argument("secondary_model_path", type=str)
-@click.argument("merge_unmerge_directory", type=str)
-@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
-@click.option(
-    "--dtype",
-    type=str,
-    default="float16",
-    help="Data type to convert weights to",
-)
-@click.option(
-    "--device",
-    "-d",
-    type=str,
-    default="cuda",
-    help="Device to compute on (default: cuda)",
-)
-@add_merge_options
-def main(
-    model_path: str,
-    secondary_model_path,
-    merge_unmerge_directory: str,
-    out_path: str,
-    dtype: Optional[str],
-    device: Optional[str],
-    merge_options: MergeOptions,
-):
-    model = ModelReference.model_validate(model_path)
-    secondary_model = ModelReference.model_validate(secondary_model_path)
-
-    dtype = dtype_from_name(dtype) if dtype else None
-
-    cache = LoaderCache()
-    cache.lazy_unpickle = merge_options.lazy_unpickle
-    cache.hf_cache_dir = merge_options.transformers_cache
-
-    for m in tqdm.tqdm([model, secondary_model], desc="Preparing models"):
-        cache.get(m)
-
-    writer = TensorWriter(
-        out_path=out_path,
-        max_shard_size=merge_options.out_shard_size,
-        safe_serialization=merge_options.safe_serialization,
-    )
-
-    model_config = model.config(trust_remote_code=merge_options.trust_remote_code)
-    model_arch_info = ArchitectureInfoUtils.get_architecture_info(
-        model.config(trust_remote_code=merge_options.trust_remote_code)
-    )
-
-    loader_1 = cache.get(model)
-    loader_2 = cache.get(secondary_model)
-
-    os.makedirs(out_path, exist_ok=True)
-
-    merge_unmerge_dictionary = {}
-    # load files from merge_unmerge_directory
-    spaces = [
-        f.split("_unmerge")[0]
-        for f in os.listdir(merge_unmerge_directory)
-        if "_unmerge" in f
-    ]
-    for i in spaces:
-        logging.info(f"Loading merge/unmerge tensors for {i}")
-        m = safetensors.torch.load_file(
-            os.path.join(merge_unmerge_directory, f"{i}_merge.safetensor"),
-            device=device,
-        )
-        u = safetensors.torch.load_file(
-            os.path.join(merge_unmerge_directory, f"{i}_unmerge.safetensor"),
-            device=device,
-        )
-        merge_unmerge_dictionary[i] = (
-            m[i].to(device, dtype=dtype),
-            u[i].to(device, dtype=dtype),
-        )
-
-    for weight_info in model_arch_info.all_weights(config=model_config):
-        merge_matrix, unmerge_matrix = None, None
-
-        if weight_info.input_space in merge_unmerge_dictionary:
-            _, unmerge_matrix = merge_unmerge_dictionary[weight_info.input_space]
-            unmerge_matrix = unmerge_matrix.chunk(2, dim=0)
-
-        if weight_info.output_space in merge_unmerge_dictionary:
-            merge_matrix, _ = merge_unmerge_dictionary[weight_info.output_space]
-            merge_matrix = merge_matrix.chunk(2, dim=1)
-
-        original_w = loader_1.get_tensor(weight_info.name, device=device)
-        original_w2 = loader_2.get_tensor(weight_info.name, device=device)
-
-        if dtype is not None:
-            original_w = original_w.to(dtype=dtype)
-            original_w2 = original_w2.to(dtype=dtype)
-
-        w = torch.clone(original_w)
-        w2 = torch.clone(original_w2)
-
-        if not merge_matrix and not unmerge_matrix:
-            logging.warning(
-                f"❌ Weight {weight_info.name} for model 1 and model 2 has no merge or unmerge matrix"
-            )
-
-        if merge_matrix is not None:
-            if weight_info.is_embed:
-                w = (merge_matrix[0] @ w.T).T
-                w2 = (merge_matrix[1] @ w2.T).T
-            else:
-                w = merge_matrix[0] @ w
-                w2 = merge_matrix[1] @ w2
-
-        if unmerge_matrix is not None:
-            w = w @ unmerge_matrix[0]
-            w2 = w2 @ unmerge_matrix[1]
-
-        # check if weights have not mutated, if yes then  shoot warning
-        if torch.allclose(original_w, w):
-            logging.warning(
-                f"❌ Weight {weight_info.name} for model 1 has NOT mutated during merge"
-            )
-        else:
-            logging.warning(
-                f"✅ Weight {weight_info.name} for model 1 has mutated during merge"
-            )
-
-        if torch.allclose(original_w2, w2):
-            logging.warning(
-                f"❌ Weight {weight_info.name} for model 2 has NOT mutated during merge"
-            )
-        else:
-            logging.warning(
-                f"✅ Weight {weight_info.name} for model 2 has mutated during merge"
-            )
-
-        # average weights and save them
-        if merge_matrix:
-            w = w + w2
-        else:
-            w = (w + w2) / 2
-        writer.save_tensor(weight_info.name, w)
-    writer.finalize()
-
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    tokenizer.save_pretrained(out_path, safe_serialization=True)
-
-    # write config
-    model_out_config = model.config(trust_remote_code=merge_options.trust_remote_code)
-    if dtype:
-        model_out_config.torch_dtype = dtype
-    model_out_config.save_pretrained(out_path)
-
-
-main()
diff --git a/mergekit/scripts/ABM/extract_activations.py b/mergekit/scripts/ABM/extract_activations.py
deleted file mode 100644
index 3f7c151b..00000000
--- a/mergekit/scripts/ABM/extract_activations.py
+++ /dev/null
@@ -1,347 +0,0 @@
-import logging
-import os
-from collections import defaultdict
-from typing import List, Optional
-
-import click
-import datasets
-import numpy as np
-import torch
-from safetensors.torch import save_file
-from torch.utils.data import DataLoader
-from transformers import AutoModel, AutoTokenizer, DefaultDataCollator
-
-from mergekit.architecture import ArchitectureInfoUtils, _template_substitution
-from mergekit.common import ModelReference
-
-logging.basicConfig(level=logging.INFO)
-
-# set seed
-torch.manual_seed(42)
-np.random.seed(42)
-
-
-def clean_name(name):
-    return name.replace(".weight", "").replace("model.", "")
-
-
-def parse_items(ctx, param, value):
-    if value is not None:
-        return [item.strip() for item in value.split(",")]
-
-
-def remove_pads(attention_mask, feature_vector):
-    if (
-        len(feature_vector.shape) == 3
-    ):  # Hidden states: (batch_size, seq_length, embedding_dim)
-        # Expand mask to match the feature_vector dimensions and apply it
-        expanded_mask = attention_mask.unsqueeze(-1)
-        filtered_feature_vector = feature_vector * expanded_mask
-    else:
-        raise ValueError("Unsupported feature vector shape.")
-
-    return filtered_feature_vector
-
-
-def get_attention_output_hook(storage_dict, space_name, capture_input=True):
-    """
-    Returns a hook function that stores the output of the attention layer.
-    """
-
-    def hook(module, input, output):
-        # NOTE: shape of input is [batch, seq_len, dim] and output is Tuple[(seq_len, dim),...]
-        if capture_input:
-            o = input[0].detach()
-        else:
-            o = output.detach()
-
-        if space_name not in storage_dict:
-            storage_dict[space_name] = o
-        else:
-            storage_dict[space_name] = torch.cat((storage_dict[space_name], o), dim=0)
-
-    return hook
-
-
-"""
-
-What this script does:
-
-It tries to map input/output spaces to activation maps
-
-"""
-
-
-@click.command("mergekit-abm-extract-activations")
-@click.argument("model-path", type=str)
-@click.option(
-    "--dataset", "-d", required=True, type=str, help="Dataset to use for activations"
-)
-@click.option("--out-path", "-o", required=True, type=str, help="Output model path")
-@click.option("--batch-size", "-b", type=int, default=2, help="Batch size")
-@click.option(
-    "--dataset-size",
-    "-s",
-    type=int,
-    default=None,
-    help="Dataset size. If None, use full dataset",
-)
-@click.option(
-    "--dataset-column", "-c", type=str, default="text", help="Dataset column to use"
-)
-@click.option(
-    "--dataset-subset", "-u", type=str, default="eval", help="Dataset subset to use"
-)
-@click.option(
-    "--chat-template/--no-chat-template",
-    default=False,
-    help="use Chat template for inference",
-)
-@click.option("--max-length", "-l", type=int, default=512, help="Max length")
-@click.option("--dtype", type=str, default=None, help="Data type to convert weights to")
-@click.option(
-    "--device", type=str, default=None, help="device to compute the activations"
-)
-@click.option(
-    "--ignore-spaces",
-    "-i",
-    type=str,
-    default="",
-    callback=parse_items,
-    help="Spaces to ignore separated by comma. Example: up_${layer_index}",
-)
-def main(
-    model_path: str,
-    dataset: str,
-    dataset_column: str,
-    out_path: str,
-    batch_size: int,
-    max_length: int,
-    dataset_size: Optional[int],
-    dataset_subset: Optional[str],
-    chat_template: Optional[bool],
-    dtype: Optional[str],
-    device: Optional[str],
-    ignore_spaces: Optional[List[str]],
-):
-    # sorting out locations to hook into
-    # we do this via the predefined json architecture definitions in mergekit
-
-    model = ModelReference.model_validate(model_path)
-
-    model_config = model.config()
-    model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)
-
-    _json = model_arch_info.definition
-
-    residual_space = None
-
-    weights = []
-    for weight in _json.layer_templates.weights:
-        if weight.is_kq:
-            residual_space = weight.input_space
-        weights.append(weight)
-
-    if residual_space is None:
-        raise ValueError("No residual space found")
-
-    # ======================== Mapping spaces to weights ========================
-
-    # just a list of connected components
-    space_to_output_weight_templates = defaultdict(list)
-    space_to_input_weight_templates = defaultdict(list)
-
-    for layer_template in weights:
-        if (
-            not layer_template.input_space
-            or layer_template.input_space in ignore_spaces
-        ):
-            continue
-        space_to_output_weight_templates[layer_template.input_space].append(
-            layer_template.name
-        )
-
-    for layer_template in weights:
-        if (
-            not layer_template.output_space
-            or layer_template.output_space in ignore_spaces
-        ):
-            continue
-        space_to_input_weight_templates[layer_template.output_space].append(
-            layer_template.name
-        )
-
-    # remove the residual space from the input and output
-    space_to_input_weight_templates.pop(residual_space, None)
-    space_to_output_weight_templates.pop(residual_space, None)
-
-    # NOTE: if space has input and output weights, remove one or the other because hooking
-    # into both will result in duplicate activations
-    to_remove = []
-    for space, input_weights in space_to_input_weight_templates.items():
-        if space in space_to_output_weight_templates:
-            # if count of input weights and output weights is non zero, remove the space from space to output_weights
-            if (
-                len(input_weights) > 0
-                and len(space_to_output_weight_templates[space]) > 0
-            ):
-                to_remove.append(space)
-
-    # remove keys from output
-    space_to_output_weight_templates = {
-        k: v for k, v in space_to_output_weight_templates.items() if k not in to_remove
-    }
-
-    num_layers = model_arch_info.num_layers(model_config)
-
-    space_to_input_weights = {}
-    for k, v in space_to_input_weight_templates.items():
-        for j in range(num_layers):
-            f = lambda x: _template_substitution(x, num_layers=num_layers, layer_idx=j)
-            space_to_input_weights[f(k)] = [f(_v) for _v in v]
-
-    space_to_output_weights = {}
-    for k, v in space_to_output_weight_templates.items():
-        for j in range(num_layers):
-            f = lambda x: _template_substitution(x, num_layers=num_layers, layer_idx=j)
-            space_to_output_weights[f(k)] = [f(_v) for _v in v]
-
-    # ================== Load model, tokenizer for inference and prepare dataset ==================
-
-    model = AutoModel.from_pretrained(
-        model_path, output_attentions=True, attn_implementation="eager"
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-    if not tokenizer.pad_token:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    tokenize_function = None
-    if chat_template:
-        logging.info("Using chat template for inference")
-        tokenize_function = lambda x: tokenizer.apply_chat_template(
-            x,
-            padding="longest",
-            max_length=max_length,
-            truncation=True,
-            return_dict=True,
-        )
-    else:
-        logging.info("Using default tokenizer (no chat template) for inference")
-        tokenize_function = lambda x: tokenizer(
-            x,
-            padding="longest",
-            max_length=max_length,
-            truncation=True,
-        )
-
-    model.eval()
-    model.to(device)
-    if dtype is not None:
-        model = model.to(dtype=dtype)
-
-    dataset = datasets.load_dataset(dataset)[dataset_subset]
-
-    if dataset_size is not None:
-        logging.info("Using dataset size %s", dataset_size)
-        dataset = dataset.select(range(dataset_size))
-
-    def tokenize(element):
-        outputs = tokenize_function(element[dataset_column])
-        return {
-            "input_ids": outputs["input_ids"],
-            "attention_mask": outputs["attention_mask"],
-        }
-
-    dataset = dataset.map(tokenize).select_columns(["input_ids", "attention_mask"])
-
-    datasets_dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=False, collate_fn=DefaultDataCollator()
-    )
-
-    feature_storage = {}
-    storage_dict = {}
-
-    # ================== Hooking into the model ==================
-
-    # NOTE: if the capture input set to True seems confusing, a space's output is a weight recieving input from the space
-    for k, v in space_to_output_weights.items():
-        for weight in v:
-            weight = clean_name(weight)
-            model.get_submodule(weight).register_forward_hook(
-                get_attention_output_hook(feature_storage, k, capture_input=True)
-            )
-    for k, v in space_to_input_weights.items():
-        for weight in v:
-            weight = clean_name(weight)
-            model.get_submodule(weight).register_forward_hook(
-                get_attention_output_hook(feature_storage, k, capture_input=False)
-            )
-
-    # ================== Inference ==================
-
-    for batch in datasets_dataloader:
-        with torch.no_grad():
-            inputs = {k: v.to(device) for k, v in batch.items()}
-            outputs = model(
-                **inputs, output_hidden_states=True, output_attentions=False
-            )
-
-            # NOTE: https://huggingface.co/docs/transformers/en/main_classes/output#transformers.modeling_outputs.BaseModelOutput
-
-            # Store attention masks
-            attention_mask = inputs["attention_mask"]
-            if "attention_mask" not in feature_storage:
-                feature_storage["attention_mask"] = attention_mask.cpu().detach()
-            else:
-                feature_storage["attention_mask"] = torch.cat(
-                    (feature_storage["attention_mask"], attention_mask.cpu().detach()),
-                    dim=0,
-                )
-
-            hidden_states = [
-                remove_pads(attention_mask, hidden_state)
-                for hidden_state in outputs.hidden_states
-            ]
-            hidden_states = torch.stack(outputs.hidden_states, dim=1)
-
-            if residual_space not in feature_storage:
-                feature_storage[residual_space] = hidden_states
-            else:
-                feature_storage[residual_space] = torch.cat(
-                    (feature_storage[residual_space], hidden_states), dim=0
-                )
-
-            for space_name, v in storage_dict.items():
-                if space_name not in feature_storage:
-                    feature_storage[space_name] = v
-                else:
-                    feature_storage[space_name] = torch.cat(
-                        (feature_storage[space_name], v), dim=0
-                    )
-
-            storage_dict = {}
-
-    # ================== Save activations/features ==================
-
-    logging.info("Feature storage:")
-    for k, v in feature_storage.items():
-        if v is not None:
-            logging.info(f"{k}: Shape: {v.shape}")
-
-    abs_path = os.path.abspath(model_path)
-    if os.path.exists(abs_path):
-        model_path = abs_path
-
-    model_path = model_path.replace("/", "_")
-
-    # create output directory
-    os.makedirs(out_path, exist_ok=True)
-
-    save_file(
-        feature_storage, os.path.join(out_path, f"{model_path}_features.safetensor")
-    )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/mergekit/scripts/ABM/extract_permutation_matrices.py b/mergekit/scripts/ABM/extract_permutation_matrices.py
deleted file mode 100644
index 4c862664..00000000
--- a/mergekit/scripts/ABM/extract_permutation_matrices.py
+++ /dev/null
@@ -1,226 +0,0 @@
-import os
-import sys
-from collections import defaultdict
-
-import click
-import numpy as np
-import safetensors.torch
-import scipy
-import torch
-
-from mergekit.architecture import ArchitectureInfoUtils, _template_substitution
-from mergekit.common import ModelReference
-
-
-def calc_correlation_matrix(feats):
-    feats = feats.view(-1, feats.shape[-1])
-
-    return torch.corrcoef(feats.T)
-
-
-def match_tensors_permute(
-    absval=False,
-    correlation_matrix=None,
-):
-    """
-    This function is adapted from ZipIt! (https://github.com/gstoica27/ZipIt)
-    """
-
-    Om = correlation_matrix.shape[0] // 2
-    device = correlation_matrix.device
-
-    mats = [torch.eye(Om, device=device)]
-
-    corr_submatrix = correlation_matrix[:Om, Om:].cpu().numpy()
-    if absval:
-        corr_submatrix = np.absolute(corr_submatrix)
-    _, col_ind = scipy.optimize.linear_sum_assignment(corr_submatrix, maximize=True)
-
-    new_mat = torch.eye(Om, device=device)[torch.tensor(col_ind).long().to(device)]
-    mats.append(new_mat.T)
-
-    unmerge_mats = mats
-
-    unmerge = torch.cat(unmerge_mats, dim=0)
-
-    merge = torch.cat(mats, dim=0)
-    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
-
-    return merge.T, unmerge
-
-
-def match_tensors_permute_MHA(
-    n_heads=32,
-    absval=False,
-    correlation_matrix=None,
-):
-    """
-    Handles different head permutations in attention.
-    Modified version of the function here: https://github.com/nverma1/merging-text-transformers/blob/main/matching_functions.py#L76
-    """
-
-    Om = correlation_matrix.shape[0] // 2
-    device = correlation_matrix.device
-    query_size = Om // n_heads
-
-    mats = [torch.eye(Om, device=device)]
-    head_perms = []
-
-    costs = np.ones((n_heads, n_heads)) * -sys.maxsize
-
-    col_inds_storage = defaultdict(lambda: defaultdict(int))
-
-    for j in range(n_heads):
-        for k in range(n_heads):
-            head1_idx = [query_size * j, query_size * (j + 1)]
-            head2_idx = [query_size * k, query_size * (k + 1)]
-
-            corr_submatrix = (
-                correlation_matrix[
-                    head1_idx[0] : head1_idx[1],
-                    (Om + head2_idx[0]) : (Om + head2_idx[1]),
-                ]
-                .cpu()
-                .numpy()
-            )
-            if absval:
-                corr_submatrix = np.absolute(corr_submatrix)
-
-            # compute perm for head j & head k
-            row_ind, col_ind = scipy.optimize.linear_sum_assignment(
-                corr_submatrix, maximize=True
-            )
-
-            costs[j, k] = corr_submatrix[row_ind, col_ind].sum()
-
-            col_inds_storage[j][k] = col_ind
-
-    outer_row_ind, outer_col_ind = scipy.optimize.linear_sum_assignment(
-        costs, maximize=True
-    )
-
-    for j in range(n_heads):
-        head_1 = outer_row_ind[j]
-        head_2 = outer_col_ind[j]
-
-        head_perm = col_inds_storage[head_1][head_2]
-        head_perms.append(torch.tensor(head_perm + query_size * head_2))
-
-    new_mat = torch.eye(Om, device=device)[
-        torch.cat(head_perms).clone().detach().long().to(device)
-    ]
-    mats.append(new_mat.T)
-
-    unmerge_mats = mats
-
-    unmerge = torch.cat(unmerge_mats, dim=0)
-    merge = torch.cat(mats, dim=0)
-    merge = merge / (merge.sum(dim=0, keepdim=True) + 1e-5)
-
-    return merge.T, unmerge
-
-
-@click.command("mergekit-abm-extract-permutations")
-@click.argument("model1-ft", type=str, required=True)
-@click.argument("model2-ft", type=str, required=True)
-@click.option("--model_path", type=str, required=True, help="Model information")
-@click.option(
-    "--out_path", required=True, type=str, help="Output path for metric tensors"
-)
-@click.option(
-    "--absval/--no-absval",
-    required=False,
-    default=False,
-    help="Use absolute value on correlation matrices/submatrices while calculating merge/unmerge matrices",
-)
-@click.option(
-    "--device",
-    "-d",
-    type=str,
-    default="cpu",
-    help="Device to compute on (default: cpu)",
-)
-def main(model1_ft, model2_ft, model_path, out_path, absval, device):
-    os.makedirs(out_path, exist_ok=True)
-
-    model = ModelReference.model_validate(model_path)
-
-    model_config = model.config()
-
-    model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)
-
-    _json = model_arch_info.definition
-
-    residual_space = None
-    kq_space = None
-    v_space = None
-
-    # extract the residual, attention related spaces
-    for weight in _json.layer_templates.weights:
-        if weight.is_kq:
-            kq_space = weight.output_space
-            residual_space = weight.input_space
-            continue
-
-        # assuming order is observed
-        if (
-            not weight.is_kq
-            and weight.head_split
-            and (weight.input_space == residual_space)
-        ):
-            v_space = weight.output_space
-            continue
-
-    num_layers = model_arch_info.num_layers(model_config)
-
-    kq_spaces = []
-    v_spaces = []
-    for j in range(num_layers):
-        kq_spaces.append(
-            _template_substitution(kq_space, num_layers=num_layers, layer_idx=j)
-        )
-        v_spaces.append(
-            _template_substitution(v_space, num_layers=num_layers, layer_idx=j)
-        )
-
-    model1_features = safetensors.torch.load_file(model1_ft, device=device)
-    model2_features = safetensors.torch.load_file(model2_ft, device=device)
-
-    model1_features.pop("attention_mask")
-    model2_features.pop("attention_mask")
-
-    for feature_space in model1_features.keys():
-        concatenated_feature = torch.cat(
-            (model1_features[feature_space], model2_features[feature_space]), dim=-1
-        )
-
-        correlation_matrix = calc_correlation_matrix(concatenated_feature)
-
-        if feature_space in (kq_spaces + v_spaces):
-            merge, unmerge = match_tensors_permute_MHA(
-                correlation_matrix=correlation_matrix,
-                n_heads=model_config.num_attention_heads,
-                absval=absval,
-            )
-
-        else:
-            merge, unmerge = match_tensors_permute(
-                correlation_matrix=correlation_matrix,
-                absval=absval,
-            )
-
-        safetensors.torch.save_file(
-            {feature_space: merge.contiguous()},
-            f"{out_path}/{feature_space}_merge.safetensor",
-        )
-
-        safetensors.torch.save_file(
-            {feature_space: unmerge.contiguous()},
-            f"{out_path}/{feature_space}_unmerge.safetensor",
-        )
-
-        del merge, unmerge, correlation_matrix, concatenated_feature
-
-
-if __name__ == "__main__":
-    main()
diff --git a/mergekit/scripts/evolve.py b/mergekit/scripts/evolve.py
index 9ad35eb3..7dafc0bd 100644
--- a/mergekit/scripts/evolve.py
+++ b/mergekit/scripts/evolve.py
@@ -1,6 +1,7 @@
 # Copyright (C) 2025 Arcee AI
 # SPDX-License-Identifier: BUSL-1.1
 
+import importlib.util
 import logging
 import os
 import time
@@ -126,7 +127,7 @@ def main(
     vllm: bool,
     strategy: str,
     in_memory: bool,
-    storage_path: Optional[str],
+    storage_path: str,
     num_gpus: Optional[int],
     merge_cuda: bool,
     trust_remote_code: bool,
@@ -160,9 +161,7 @@ def main(
             raise ValueError("Cannot use vLLM with 4-bit or 8-bit models")
         if in_memory:
             raise ValueError("Cannot use in-memory mode with 4-bit or 8-bit models")
-        try:
-            import bitsandbytes
-        except ImportError:
+        if not importlib.util.find_spec("bitsandbytes"):
             raise RuntimeError("bitsandbytes is not installed")
 
         bnb_config = transformers.BitsAndBytesConfig(
@@ -271,7 +270,7 @@ def progress_callback(es: cma.CMAEvolutionStrategy):
         nonlocal xbest, xbest_cost
 
         res = es.result
-        if use_wandb:
+        if use_wandb and run is not None:
             best_params = genome.genotype_to_param_arrays(res.xbest)
             mean_params = genome.genotype_to_param_arrays(res.xfavorite)
             run.log(
@@ -377,7 +376,10 @@ def parallel_evaluate(x: List[np.ndarray]) -> List[float]:
 
 
 def _reshard_model(
-    model: ModelReference, storage_path: str, merge_cache: str, trust_remote_code: bool
+    model: ModelReference,
+    storage_path: str,
+    merge_cache: Optional[str],
+    trust_remote_code: bool,
 ) -> ModelReference:
     merged = model.merged(
         cache_dir=merge_cache,
diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py
index 023b10fe..53b055a4 100644
--- a/mergekit/scripts/extract_lora.py
+++ b/mergekit/scripts/extract_lora.py
@@ -12,12 +12,12 @@
 import torch
 import torch.nn as nn
 import tqdm
+import transformers
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM
 
-from mergekit.architecture import ArchitectureInfoUtils, WeightInfo
+from mergekit.architecture import WeightInfo, arch_info_for_config
 from mergekit.card import generate_card_lora
-from mergekit.common import ModelReference
+from mergekit.common import ModelReference, get_auto_cls
 from mergekit.graph import Executor, Task
 from mergekit.io.tasks import FinalizeModel, LoadTensor, SaveTensor, TensorWriterTask
 from mergekit.io.tensor_writer import TensorWriter
@@ -323,6 +323,20 @@ def _wi_load(model_ref: ModelReference, weight_info: WeightInfo) -> LoadTensor:
     )
 
 
+def _make_dummy_model(
+    model_ref: ModelReference, trust_remote_code: bool = False
+) -> transformers.PreTrainedModel:
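+    """Build a weight-less model instance on the meta device, just to inspect its module structure."""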
+    model_cfg = transformers.AutoConfig.from_pretrained(
+        model_ref.model.path,
+        revision=model_ref.model.revision,
+        trust_remote_code=trust_remote_code,
+    )
+    auto_cls = get_auto_cls(model_cfg.architectures[0])
+    with torch.device("meta"):
+        res = auto_cls.from_config(model_cfg, trust_remote_code=trust_remote_code)
+    return res
+
+
 class PlanResults(BaseModel):
     tasks: List[Task]
     base_vocab_size: int
@@ -352,20 +366,8 @@ def plan_extraction(
     )
 
     name_to_wi = all_weights_map(model_ref, options)
-    dummy_model = AutoModelForCausalLM.from_pretrained(
-        model_ref.model.path,
-        revision=model_ref.model.revision,
-        trust_remote_code=options.trust_remote_code,
-        device_map="meta",
-        state_dict={},
-    )
-    dummy_base = AutoModelForCausalLM.from_pretrained(
-        base_model_ref.model.path,
-        revision=base_model_ref.model.revision,
-        trust_remote_code=options.trust_remote_code,
-        device_map="meta",
-        state_dict={},
-    )
+    dummy_base = _make_dummy_model(base_model_ref, options.trust_remote_code)
+    dummy_model = _make_dummy_model(model_ref, options.trust_remote_code)
 
     embed_in = dummy_model.get_input_embeddings()
     embed_out = dummy_model.get_output_embeddings()
@@ -378,6 +380,7 @@ def plan_extraction(
         )
         logger.warning("Enforcing embeddings in modules_to_save, embed_lora=False")
         embed_lora = False
+    del dummy_base
 
     warned_modules = set()
 
@@ -553,7 +556,7 @@ def all_weights_map(
 ) -> Dict[str, WeightInfo]:
     name_to_wi = {}
     model_cfg = model_ref.config(trust_remote_code=options.trust_remote_code)
-    arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg)
+    arch_info = arch_info_for_config(model_cfg)
     for wi in arch_info.all_weights(model_cfg):
         name_to_wi[wi.name] = wi
     return name_to_wi
diff --git a/mergekit/scripts/fill_missing_params.py b/mergekit/scripts/fill_missing_params.py
index 81aec1b3..e8bc6d4d 100644
--- a/mergekit/scripts/fill_missing_params.py
+++ b/mergekit/scripts/fill_missing_params.py
@@ -3,13 +3,14 @@
 import logging
 import shutil
 from pathlib import Path
+from typing import List, Optional, Tuple
 
 import click
 import torch
+from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from tqdm import tqdm
 
-from mergekit.architecture import ParameterNamesUtils
 from mergekit.io.lazy_tensor_loader import ShardedTensorIndex
 from mergekit.io.tensor_writer import TensorWriter
 
@@ -197,3 +198,199 @@ def main(
 
 if __name__ == "__main__":
     main()
+
+
+class ParameterNamesUtils:
+    """Utility functions for handling parameter names."""
+
+    @staticmethod
+    def resolve_model_directory(repo_id: str) -> Path:
+        """Resolve the model directory (local or Hugging Face Hub)."""
+        if Path(repo_id).is_dir():
+            return Path(repo_id)
+
+        return Path(snapshot_download(repo_id))
+
+    @staticmethod
+    def get_model_parameter_names(repo_id: str) -> List[str]:
+        """Get parameter names of a model from a Hugging Face repo or local directory."""
+        model_dir = ParameterNamesUtils.resolve_model_directory(repo_id)
+        return list(ShardedTensorIndex.from_disk(str(model_dir)).tensor_paths.keys())
+
+    @staticmethod
+    def strip_prefix(name: str, prefix: str) -> str:
+        """Remove a single prefix from the start of a name."""
+        if prefix != "" and name.startswith(prefix + "."):
+            return name[len(prefix) + 1 :]
+        return name
+
+    @staticmethod
+    def find_prefix(list1: List[str], list2: List[str]) -> Optional[str]:
+        """
+        Find a prefix in list1 that, after removal, makes list2 an ordered sublist.
+        """
+        assert len(list1) >= len(list2), "params name list1 can't be shorter than list2"
+
+        possible_prefixes = {item.split(".")[0] for item in list1 if "." in item}
+        possible_prefixes = [""] + list(possible_prefixes)
+
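+        # for each candidate prefix, count how many names from list2 appear in list1 once the prefix is stripped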
+        prefix_matches = {}
+        best_prefix = ""  # Default to no prefix
+        for prefix in possible_prefixes:
+            stripped_list1 = [
+                ParameterNamesUtils.strip_prefix(item, prefix) for item in list1
+            ]
+            prefix_matches[prefix] = len(
+                [item for item in list2 if item in stripped_list1]
+            )
+
+        if max(prefix_matches.values()) > prefix_matches[""]:
+            best_prefix = max(prefix_matches, key=prefix_matches.get)
+
+        return best_prefix
+
+    @staticmethod
+    def find_common_ordered_names(
+        param_names: List[List[str]], prefixes: List[str]
+    ) -> List[str]:
+        """Identify and return common parameter names across all models, ensuring correct order. Also account for prefix."""
+        common_names = set(param_names[0])
+        for i in range(1, len(param_names)):
+            prefix = f"{prefixes[i]}." if prefixes[i] else ""
+            common_names.intersection_update({prefix + name for name in param_names[i]})
+        return [name for name in param_names[0] if name in common_names]
+
+    @staticmethod
+    def remove_size_conflicts(common_names, referenced_models, prefixes):
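+        """Drop names whose tensor shapes differ between the first referenced model and any other."""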
+        model_dirs = [
+            ParameterNamesUtils.resolve_model_directory(m.model.path)
+            for m in referenced_models
+        ]
+        model_indices = [ShardedTensorIndex.from_disk(str(dir)) for dir in model_dirs]
+
+        common_name_and_shape = common_names.copy()
+        removed_names = []
+
+        for name in common_names:
+            base_shape = ParameterNamesUtils.tensor_shape(name, model_indices[0])
+
+            for i in range(1, len(referenced_models)):
+                other_name = name
+                prefix = f"{prefixes[i]}." if prefixes[i] else ""
+                if name.startswith(prefix) and prefix != "":
+                    other_name = name[len(prefix) :]
+                shape = ParameterNamesUtils.tensor_shape(other_name, model_indices[i])
+
+                if base_shape != shape:
+                    common_name_and_shape.remove(name)
+                    removed_names.append((name, base_shape, shape, i))
+                    break
+
+        size_mismatch_count = len(removed_names)
+        if size_mismatch_count > 0:
+            logging.warning(
+                f"Size mismatch detected for {size_mismatch_count}/{size_mismatch_count + len(common_names)} tensors. "
+                "These names were removed from the merge list."
+            )
+            logging.info(
+                "The following tensors have different shapes across models and were removed from the merge list:"
+            )
+            for name, base_shape, shape, i in removed_names:
+                logging.info(
+                    f"Tensor name: {name}, Base model shape: {base_shape}, Mismatched shape: {shape} in model {referenced_models[i].model.path}"
+                )
+
+        return common_name_and_shape
+
+    @staticmethod
+    def are_common_params_ordered(list1: List[str], list2: List[str]) -> bool:
+        """
+        Check if common elements of list2 maintain their relative order in list1.
+        """
+        common_params = set(list1).intersection(set(list2))
+        last_index = -1
+
+        for param in list2:
+            if param in common_params:
+                current_index = list1.index(param)
+                if current_index < last_index:
+                    return False
+                last_index = current_index
+        return True
+
+    @staticmethod
+    def ordered_sublist(list1: List[str], list2: List[str]) -> bool:
+        """
+        Check if list2 is a contiguous ordered sublist of list1.
+        """
+        n, m = len(list1), len(list2)
+
+        for i in range(n - m + 1):
+            if list1[i : i + m] == list2:
+                return True
+        return False
+
+    @staticmethod
+    def report_names_similarity(
+        base_names: List[str], other_names: List[str]
+    ) -> Tuple[Optional[str], str]:
+        """
+        Analyze similarity between parameter names of two models and identify shared prefixes.
+        Returns:
+            best_prefix (str): Best matching prefix for parameter names.
+            case_message (str): Explanation of the structural relationship.
+        """
+        possible_prefixes = {""}
+        possible_prefixes.update(
+            {item.split(".")[0] for item in base_names if "." in item}
+        )
+
+        prefixes_subset_overlap = {}
+        best_prefix = None
+        case_message = "No common parameter names found for any prefix"
+
+        for prefix in possible_prefixes:
+            base_names_stripped = [
+                ParameterNamesUtils.strip_prefix(name, prefix) for name in base_names
+            ]
+
+            if ParameterNamesUtils.ordered_sublist(base_names_stripped, other_names):
+                return prefix, "All params in model have exact match in base model."
+
+            intersection = set(base_names_stripped).intersection(set(other_names))
+            prefixes_subset_overlap[prefix] = intersection
+
+        if prefixes_subset_overlap:
+            best_prefix = max(
+                prefixes_subset_overlap, key=lambda x: len(prefixes_subset_overlap[x])
+            )
+            base_names_stripped = [
+                ParameterNamesUtils.strip_prefix(name, best_prefix)
+                for name in base_names
+            ]
+
+            overlap = len(prefixes_subset_overlap[best_prefix])
+            ordered = ParameterNamesUtils.are_common_params_ordered(
+                base_names_stripped, other_names
+            )
+            mismatched = [
+                item for item in other_names if item not in base_names_stripped
+            ]
+            mismatched = "\n    ".join(mismatched)
+            case_message = (
+                f"{overlap}/{len(other_names)} ({100 * overlap / len(other_names):.2f}%) "
+                f"of model parameters are in the base model. \n"
+                f"  Name ordering is {'preserved' if ordered else 'not preserved'}.\n"
+                f"  Missing parameters:\n    {mismatched}"
+            )
+
+        return best_prefix, case_message
+
+    @staticmethod
+    def tensor_shape(name, index) -> Tuple[int, ...]:
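+        """Read a tensor's shape from its safetensors shard without loading the data."""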
+        from safetensors import safe_open
+
+        with safe_open(
+            Path(index.base_path) / index.tensor_paths[name], framework="pt"
+        ) as f:
+            return f.get_slice(name).get_shape()
diff --git a/mergekit/scripts/layershuffle.py b/mergekit/scripts/layershuffle.py
index 267e397c..b93c8bd5 100644
--- a/mergekit/scripts/layershuffle.py
+++ b/mergekit/scripts/layershuffle.py
@@ -7,7 +7,7 @@
 import click
 import yaml
 
-from mergekit.architecture import ArchitectureInfoUtils
+from mergekit.architecture import arch_info_for_config
 from mergekit.common import ModelReference
 from mergekit.config import (
     InputSliceDefinition,
@@ -64,7 +64,7 @@ def main(
     models = [ModelReference.parse(m) for m in model]
 
     m0_cfg = models[0].config()
-    arch_info = ArchitectureInfoUtils.get_architecture_info(m0_cfg)
+    arch_info = arch_info_for_config(m0_cfg)
     total_num_layers = arch_info.num_layers(m0_cfg)
 
     out_slices: List[OutputSliceDefinition] = []
diff --git a/mergekit/scripts/moe.py b/mergekit/scripts/moe.py
index 87eef5d0..b0c27594 100644
--- a/mergekit/scripts/moe.py
+++ b/mergekit/scripts/moe.py
@@ -163,9 +163,6 @@ def select_output_arch(
     help="Device to use to compute embeddings",
     show_default=True,
 )
-@click.option(
-    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
-)
 @click.option(
     "--i-understand-this-is-not-useful-without-training",
     type=bool,
@@ -180,7 +177,6 @@ def main(
     load_in_4bit: bool,
     load_in_8bit: bool,
     device: str,
-    verbose: bool,
     i_understand_this_is_not_useful_without_training: bool,
     merge_options: MergeOptions,
 ):
@@ -204,7 +200,7 @@ def main(
         load_in_8bit=load_in_8bit,
         device=device,
         allow_all_same=i_understand_this_is_not_useful_without_training,
-        verbose=verbose,
+        verbose=merge_options.verbose,
     )
 
     if merge_options.write_model_card:
diff --git a/mergekit/scripts/tokensurgeon.py b/mergekit/scripts/tokensurgeon.py
index 406e7aa6..98b9780c 100644
--- a/mergekit/scripts/tokensurgeon.py
+++ b/mergekit/scripts/tokensurgeon.py
@@ -13,9 +13,9 @@
 from typing_extensions import TypeAlias
 
 from mergekit.architecture import (
-    ArchitectureInfoUtils,
-    ConfiguredArchitectureInfo,
+    ConfiguredModelArchitecture,
     WeightInfo,
+    arch_info_for_config,
 )
 from mergekit.common import ModelReference
 from mergekit.io import TensorWriter
@@ -132,6 +132,7 @@ def main(
         barycentric=barycentric,
         cosine_similarity=cosine_similarity,
         name=embed_info.name,
+        log_reconstruction_error=verbosity > 0,
     )
 
     if lm_head_info:
@@ -269,21 +270,24 @@ def get_embedding_info(
 ) -> Tuple[WeightInfo, WeightInfo]:
     """Get WeightInfo for the input and output embeddings of a model."""
     cfg = model.config(trust_remote_code=options.trust_remote_code)
-    arch_info = ArchitectureInfoUtils.get_architecture_info(cfg)
+    arch_info = arch_info_for_config(cfg)
+
+    if len(arch_info.modules) != 1:
+        raise RuntimeError("Model has multiple modules - not supported by tokensurgeon")
+    module_def = next(iter(arch_info.modules.values()))
 
     embed, lm_head = None, None
-    for weight_info in arch_info.pre_weights(cfg):
+    for weight_info in module_def.architecture.pre_weights(cfg):
         if weight_info.is_embed:
             if embed is not None:
                 raise RuntimeError("Multiple input embeddings found")
             embed = weight_info
 
-    for weight_info in arch_info.post_weights(cfg):
+    for weight_info in module_def.architecture.post_weights(cfg):
         if weight_info.is_embed:
             if lm_head is not None:
                 raise RuntimeError("Multiple output embeddings found")
             lm_head = weight_info
-
     return embed, lm_head
 
 
@@ -466,12 +470,14 @@ def get_embeddings(
 
         if log_reconstruction_error:
             # compute reconstruction error in donor_embed space
-            knn_reconstruction_error.append(
-                torch.nn.functional.mse_loss(
-                    (knn_embeddings.T.to(weights.dtype) @ weights).squeeze(),
-                    token_embedding,
-                ).item()
+            reconstructed = (
+                (knn_embeddings.T.to(weights.dtype) @ weights)
+                .squeeze()
+                .to(token_embedding.dtype)
             )
+            diff = token_embedding - reconstructed
+            mse = diff.square().mean().item()
+            knn_reconstruction_error.append(mse)
 
         # Reconstruct the embedding in original_embed space
         res[idx_1] = (e_c_0[indices].T @ weights).squeeze()
@@ -576,7 +582,7 @@ def load_tokenizer(
 
 def validate_architecture(
     model: ModelReference, donor: ModelReference, options: MergeOptions
-) -> Tuple[ConfiguredArchitectureInfo, transformers.PretrainedConfig]:
+) -> Tuple[ConfiguredModelArchitecture, transformers.PretrainedConfig]:
     """
     Validate that the architectures of two models match.
 
@@ -584,15 +590,18 @@ def validate_architecture(
     """
     model_cfg = model.config(trust_remote_code=options.trust_remote_code)
     donor_cfg = donor.config(trust_remote_code=options.trust_remote_code)
-    model_arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg)
-    donor_arch_info = ArchitectureInfoUtils.get_architecture_info(donor_cfg)
+    model_arch_info = arch_info_for_config(model_cfg)
+    donor_arch_info = arch_info_for_config(donor_cfg)
     if donor_arch_info != model_arch_info:
         report_issue(
-            f"Model architectures do not match: {model_arch_info.name()} vs {donor_arch_info.name()}",
+            f"Model architectures do not match: {model_arch_info.expected_model_type} vs {donor_arch_info.expected_model_type}",
             error=not options.allow_crimes,
         )
 
-    return ConfiguredArchitectureInfo(info=model_arch_info, config=model_cfg), donor_cfg
+    return (
+        ConfiguredModelArchitecture(info=model_arch_info, config=model_cfg),
+        donor_cfg,
+    )
 
 
 if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index a8f11abc..db52c1f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "mergekit"
 description = "Tools for merging pre-trained large language models"
 readme = "README.md"
 license = { text = "BUSL-1.1" }
-version = "0.1.1"
+version = "0.1.2"
 authors = [{ name = "Charles Goddard", email = "chargoddard@gmail.com" }]
 requires-python = ">=3.10"
 dependencies = [
@@ -60,6 +60,7 @@ packages = [
     "mergekit.scripts",
     "mergekit.evo",
     "mergekit.tokenizer",
+    "mergekit.architecture",
     "mergekit._data",
     "mergekit._data.architectures",
     "mergekit._data.chat_templates",
diff --git a/tests/common.py b/tests/common.py
index 54068c54..9b7ceb9c 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -13,7 +13,10 @@
     LlavaForConditionalGeneration,
 )
 
-from mergekit.architecture import ArchitectureInfoUtils
+from mergekit.architecture import (
+    arch_info_for_config,
+    get_architecture_info,
+)
 from mergekit.config import MergeConfiguration
 from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
 from mergekit.merge import MergeOptions, run_merge
@@ -53,9 +56,9 @@ def run_and_check_merge(
         if check_tensors:
             model_config = AutoConfig.from_pretrained(tmpdir)
             if auto_arch:
-                arch_info = ArchitectureInfoUtils.infer_architecture_info(config)
+                arch_info = get_architecture_info(config, MergeOptions())
             else:
-                arch_info = ArchitectureInfoUtils.get_architecture_info(model_config)
+                arch_info = arch_info_for_config(model_config)
 
             index = ShardedTensorIndex.from_disk(tmpdir)
             for weight_info in arch_info.all_weights(model_config):
diff --git a/tests/test_basic_merges.py b/tests/test_basic_merges.py
index 8aac2322..15c03621 100644
--- a/tests/test_basic_merges.py
+++ b/tests/test_basic_merges.py
@@ -119,13 +119,6 @@ def test_slerp_merge(self, model_a, model_b):
         config.parameters = {"t": 0.35}
         run_and_check_merge(config)
 
-    def test_nearswap_merge(self, model_a, model_b):
-        config = self.two_model_config(
-            model_a, model_b, merge_method="nearswap", base_model=model_a
-        )
-        config.parameters = {"t": 0.0001}
-        run_and_check_merge(config)
-
     def test_nuslerp_merges(self, model_a, model_b, model_c):
         for base_model in [None, model_c]:
             for row_wise in [False, True]:
diff --git a/tests/test_chat_template.py b/tests/test_chat_template.py
index af511a2b..2bd41cde 100644
--- a/tests/test_chat_template.py
+++ b/tests/test_chat_template.py
@@ -1,13 +1,25 @@
 from typing import Optional
 
-from common import run_and_check_merge
-from test_basic_merges import model_b
-from test_tokenizer import model_base
+import pytest
+from common import make_picollama, run_and_check_merge
+from test_tokenizer import make_tokenizer
 from transformers import AutoTokenizer
 
 from mergekit.config import InputModelDefinition, MergeConfiguration
 
 
+@pytest.fixture(scope="session")
+def model_base(tmp_path_factory):
+    model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
+    make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
+    return model_path
+
+
+@pytest.fixture(scope="session")
+def model_b(tmp_path_factory):
+    return make_picollama(tmp_path_factory.mktemp("model_b"))
+
+
 def check_chat_template(model_path: str, needle: Optional[str] = None):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     if needle is None: