diff --git a/README.md b/README.md
index 8ad07333..cf244575 100644
--- a/README.md
+++ b/README.md
@@ -251,6 +251,7 @@ A quick overview of the currently supported merge methods:
 | [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` | ✅ | ✅ |
 | [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
 | NuSLERP | `nuslerp` | ❌ | ✅ |
+| [ChipAlign](https://arxiv.org/abs/2412.19819) | `nuslerp` + `geodesic: true` | ❌ | ❌ |
 | [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
 | [DELLA](https://arxiv.org/abs/2406.11617) + [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
 | [SCE](https://arxiv.org/abs/2408.07990) | `sce` | ✅ | ✅ |
@@ -333,9 +334,24 @@ Parameters:
 - `weight`: relative weighting of a given tensor
 - `nuslerp_flatten`: set to false to do row-wise/column-wise interpolation instead of treating tensors as vectors
 - `nuslerp_row_wise`: SLERP row vectors instead of column vectors
+- `geodesic`: set to true to use ChipAlign-style geodesic interpolation (see below)
+- `lambda`: interpolation factor for geodesic interpolation (required when `geodesic` is true)
 
 To replicate the behavior of the original `slerp` method, set `weight` to `1-t` and `t` for your first and second model respectively.
 
+### ChipAlign
+
+[ChipAlign](https://arxiv.org/abs/2412.19819) is implemented as an extension to NuSLERP that uses geodesic interpolation on the weight manifold to merge a general instruction-aligned LLM with a domain-specific LLM. This is particularly useful for creating models that retain both strong instruction-following capability and domain expertise.
+
+To use ChipAlign, set `merge_method: nuslerp` and add the parameter `geodesic: true`. ChipAlign requires exactly two models, typically:
+- An instruction-aligned general LLM (first model)
+- A domain-specific expert LLM (second model)
+
+Parameters:
+- `lambda`: Interpolation factor between 0.0 and 1.0. At 0.0 the result behaves more like the instruction model; at 1.0, more like the domain model.
+
+See the [chipalign.yml](examples/chipalign.yml) example for a complete configuration.
+
 ### [DELLA](https://arxiv.org/abs/2406.11617)
 
 Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters similar to [DARE](#dare).
 DELLA can be used with (`della`) or without (`della_linear`) the sign elect step of TIES
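The ChipAlign section added above boils down to a three-step update per weight tensor: normalize both tensors to unit vectors, SLERP the directions with factor λ, then rescale by the weighted geometric mean of the original norms. A minimal sketch of that formula in plain PyTorch, for orientation only — the function name `chipalign_geodesic` is hypothetical, and the patch's actual implementation is the geodesic branch of `NuSlerpTask.execute` below:

```python
import torch


def chipalign_geodesic(instr: torch.Tensor, domain: torch.Tensor, lam: float) -> torch.Tensor:
    """SLERP the unit directions, then restore magnitude via a weighted geometric mean of norms."""
    n_i, n_d = torch.norm(instr), torch.norm(domain)
    u_i, u_d = instr / n_i, domain / n_d
    theta = torch.acos(torch.clamp(torch.sum(u_i * u_d), -1.0, 1.0))
    if theta.abs() < 1e-6:
        # (nearly) colinear directions: fall back to linear interpolation
        u = (1 - lam) * u_i + lam * u_d
    else:
        u = (torch.sin((1 - lam) * theta) * u_i + torch.sin(lam * theta) * u_d) / torch.sin(theta)
    return (n_i ** (1 - lam)) * (n_d ** lam) * u


a, b = torch.randn(64, 64), torch.randn(64, 64)
merged = chipalign_geodesic(a, b, lam=0.5)
# The merged norm matches the geometric mean of the input norms (magnitude preservation).
print(merged.norm().item(), (a.norm() * b.norm()).sqrt().item())
```

At `lam=0.0` this recovers the first tensor (up to floating-point error) and at `lam=1.0` the second, matching the `lambda` semantics documented above.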
diff --git a/examples/chipalign.yml b/examples/chipalign.yml
new file mode 100644
index 00000000..95a62d8e
--- /dev/null
+++ b/examples/chipalign.yml
@@ -0,0 +1,19 @@
+# ChipAlign merge example
+# Merge an instruction-aligned model with a domain-specific chip model
+# Based on https://arxiv.org/abs/2412.19819
+
+models:
+  - model: instruction_model_path # Path to instruction-aligned model
+  - model: chip_model_path # Path to chip-specific model
+
+merge_method: nuslerp
+parameters:
+  geodesic: true
+  lambda: 0.5 # Interpolation factor between the two models (0.0 = instruction model, 1.0 = chip model)
+  nuslerp_flatten: true # Treat tensors as flattened vectors for geodesic interpolation
+
+# Optionally configure tokenizer
+tokenizer:
+  source: "union" # Combine vocabularies from both models
+
+dtype: bfloat16 # Output precision
\ No newline at end of file
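Before launching a merge with the example config above, it can be sanity-checked against the constraints the patch documents and enforces: exactly two models, and `lambda` present (and, per the README section, in [0, 1]) whenever `geodesic` is true. A small illustrative check using PyYAML, not part of the patch:

```python
# Sanity-check examples/chipalign.yml against the constraints that
# NuSlerpTask.execute enforces at merge time (illustrative helper only).
import yaml

with open("examples/chipalign.yml") as f:
    cfg = yaml.safe_load(f)

params = cfg.get("parameters", {})
assert cfg["merge_method"] == "nuslerp"
assert len(cfg["models"]) == 2, "ChipAlign-style geodesic merges expect exactly two models"
if params.get("geodesic"):
    lam = params.get("lambda")
    assert lam is not None, "lambda must be specified when geodesic is true"
    assert 0.0 <= lam <= 1.0, "lambda should lie in [0, 1]"
print("config looks consistent")
```

The merge itself runs through the usual entry point, e.g. `mergekit-yaml examples/chipalign.yml ./output-model-directory`.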
diff --git a/mergekit/merge_methods/nuslerp.py b/mergekit/merge_methods/nuslerp.py
index 43c80bb4..55157fb3 100644
--- a/mergekit/merge_methods/nuslerp.py
+++ b/mergekit/merge_methods/nuslerp.py
@@ -19,12 +19,19 @@ class NuSlerpTask(Task[torch.Tensor]):
+    """Task for performing NuSLERP or ChipAlign merges between two model tensors.
+
+    Supports both traditional NuSLERP and ChipAlign-style geodesic interpolation
+    with magnitude preservation, as described in https://arxiv.org/abs/2412.19819.
+    """
     gather_tensors: MergeTensorInput
     tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
     weight_info: WeightInfo
     row_wise: bool
     flatten: bool
     base_model: Optional[ModelReference]
+    geodesic: bool  # Whether to use ChipAlign-style geodesic interpolation
+    lambda_val: Optional[float]  # Interpolation factor for geodesic mode
 
     def uses_accelerator(self) -> bool:
         return True
@@ -33,9 +40,11 @@ def arguments(self) -> Dict[str, Task]:
         return {"tensors": self.gather_tensors}
 
     def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
+        # Fast path for the single-model case
         if len(tensors) == 1:
            return list(tensors.values())[0]
 
+        # Handle the base model if one is provided
         if self.base_model is not None:
             if len(tensors) != 3:
                 raise RuntimeError(
@@ -45,27 +54,65 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
             )
         else:
             base_tensor = None
 
+        # Extract tensors and weights
         keys = list(tensors.keys())
         tensors = [tensors[key] for key in keys]
         weights = [self.tensor_parameters[key]["weight"] for key in keys]
+
+        # Verify exactly two models are provided
         if len(tensors) != 2:
             raise RuntimeError(
                 "NuSlerp merge expects exactly two models (plus optional base model)"
             )
+
+        # Calculate the interpolation factor from the weights
         if abs(sum(weights)) < 1e-6:
-            # this is fairly arbitrary, but it's more sane than exploding
-            t = 0.5
+            t = 0.5  # Default when the weights sum to zero
         else:
             t = weights[1] / sum(weights)
 
+        # Handle embedding tensors with mismatched sizes
         if base_tensor is not None:
             tensors.append(base_tensor)
         rectify_embed_sizes(self.weight_info, tensors)
 
+        # ChipAlign geodesic interpolation path
+        if self.geodesic:
+            if base_tensor is not None:
+                raise ValueError(
+                    "ChipAlign-style geodesic interpolation does not support a base model."
+                )
+            if self.lambda_val is None:
+                raise ValueError("lambda must be specified when geodesic=True")
+
+            # Extract the instruction and domain-specific tensors
+            instruction_tensor = tensors[0]
+            domain_tensor = tensors[1]
+
+            # Calculate norms for magnitude preservation
+            instruction_tensor_norm = torch.norm(instruction_tensor)
+            domain_tensor_norm = torch.norm(domain_tensor)
+
+            # Normalize to unit vectors
+            instruction_tensor_unit = instruction_tensor / instruction_tensor_norm
+            domain_tensor_unit = domain_tensor / domain_tensor_norm
+
+            # Perform spherical interpolation on the unit vectors
+            from mergekit.merge_methods.slerp import slerp
+            merged_tensor_unit = slerp(
+                self.lambda_val, instruction_tensor_unit, domain_tensor_unit
+            )
+
+            # Apply magnitude scaling using the weighted geometric mean (from the ChipAlign paper)
+            merged_tensor = (
+                (instruction_tensor_norm ** (1 - self.lambda_val))
+                * (domain_tensor_norm ** self.lambda_val)
+                * merged_tensor_unit
+            )
+            return merged_tensor
+
+        # Standard NuSLERP path
         if base_tensor is not None:
             base_tensor = tensors.pop()
+            # Task vector mode (with base model)
             return base_tensor + nuslerp(
                 t,
                 tensors[0] - base_tensor,
@@ -73,6 +120,8 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
                 dim=0 if self.row_wise else -1,
                 flatten=self.flatten,
             )
+
+        # Direct tensor mode (no base model)
         return nuslerp(
             t,
             tensors[0],
@@ -83,6 +132,11 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
 
 
 class NuSlerpMerge(MergeMethod):
+    """Merge method implementing both NuSLERP and ChipAlign geodesic interpolation.
+
+    Provides spherical linear interpolation of two models, optionally as task
+    vectors against a base model, plus a ChipAlign-style geodesic mode.
+    """
     def name(self) -> str:
         return "nuslerp"
@@ -90,17 +144,42 @@ def name(self) -> str:
     def pretty_name(self):
         return "NuSLERP"
 
+    @override
+    def reference_url(self):
+        return "https://arxiv.org/abs/2412.19819" if self.is_chipalign() else None
+
+    def is_chipalign(self) -> bool:
+        """Check whether the method was configured in ChipAlign mode."""
+        try:
+            return bool(self._parameters and self._parameters.get("geodesic", False))
+        except AttributeError:
+            return False
+
     def parameters(self) -> List[ConfigParameterDef]:
         return [
             ConfigParameterDef(
                 name="nuslerp_row_wise",
                 required=False,
                 default_value=False,
+                description="SLERP row vectors instead of column vectors",
             ),
             ConfigParameterDef(
                 name="nuslerp_flatten",
                 required=False,
                 default_value=True,
+                description="Treat tensors as flattened vectors",
             ),
+            ConfigParameterDef(
+                name="geodesic",
+                required=False,
+                default_value=False,
+                description="Enable ChipAlign-style geodesic interpolation with magnitude preservation",
+            ),
+            ConfigParameterDef(
+                name="lambda",
+                required=False,
+                default_value=None,
+                description="Interpolation factor (0.0-1.0) for geodesic mode; 0=first model, 1=second model",
+            ),
         ]
@@ -117,6 +196,9 @@ def make_task(
         tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
         **_kwargs,
     ) -> Task:
+        # Store parameters so reference_url can detect ChipAlign mode
+        self._parameters = parameters
+
         return NuSlerpTask(
             gather_tensors=tensors,
             tensor_parameters=tensor_parameters,
@@ -124,6 +206,8 @@ def make_task(
             row_wise=parameters["nuslerp_row_wise"],
             flatten=parameters["nuslerp_flatten"],
             base_model=base_model,
+            geodesic=parameters["geodesic"],
+            lambda_val=parameters["lambda"],
         )
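One behavioral detail worth noting from the changes above: in standard NuSLERP the per-model `weight` values determine the interpolation factor `t`, while in geodesic mode `lambda` is used directly and neither the weights nor `nuslerp_row_wise`/`nuslerp_flatten` affect the result, since the geodesic branch returns before `nuslerp` is reached. A small sketch of that selection logic, mirroring `NuSlerpTask.execute` (the helper name is hypothetical and not part of the patch):

```python
from typing import Optional, Sequence


def interpolation_factor(
    weights: Sequence[float], geodesic: bool, lambda_val: Optional[float]
) -> float:
    """Return the factor used to mix the second model into the first."""
    if geodesic:
        # ChipAlign mode: lambda is used as-is; per-model weights are ignored.
        if lambda_val is None:
            raise ValueError("lambda must be specified when geodesic=True")
        return lambda_val
    if abs(sum(weights)) < 1e-6:
        return 0.5  # degenerate weights: split the difference
    return weights[1] / sum(weights)


print(interpolation_factor([1.0, 3.0], geodesic=False, lambda_val=None))  # 0.75
print(interpolation_factor([1.0, 3.0], geodesic=True, lambda_val=0.5))    # 0.5
```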
@@ -135,31 +219,54 @@ def nuslerp(
     eps: float = 1e-8,
     flatten: bool = False,
 ):
+    """Enhanced spherical linear interpolation (SLERP) with flexible tensor handling.
+
+    Args:
+        t: Interpolation factor between 0.0 and 1.0
+        v0: First tensor
+        v1: Second tensor
+        dim: Dimension along which to perform row/column-wise interpolation
+        eps: Small value to prevent division by zero
+        flatten: Whether to flatten tensors before interpolation
+
+    Returns:
+        Interpolated tensor with the same shape as the inputs
+    """
     out_shape = v0.shape
 
     def _normalize(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+        """Normalize a tensor along its last dimension with numerical stability."""
         return x / torch.norm(x, dim=-1, keepdim=True).clamp(min=eps)
 
+    # Reshape the tensors according to the interpolation mode
     if flatten:
+        # Treat the entire tensor as a single vector
         v0 = v0.view(-1)
         v1 = v1.view(-1)
     elif dim != -1:
+        # Move the target dimension to the end so interpolation happens along it
         v0 = v0.transpose(dim, -1)
         v1 = v1.transpose(dim, -1)
 
+    # Normalize to unit vectors
     v0_u = _normalize(v0)
     v1_u = _normalize(v1)
 
+    # Calculate the angle between the unit vectors
     cos_theta = torch.sum(v0_u * v1_u, dim=-1, keepdim=True)
     theta = torch.acos(cos_theta.clamp(-1, 1))
     sin_theta = torch.sin(theta)
 
+    # Identify (nearly) colinear vectors to avoid numerical issues
     colinear = (sin_theta.abs() < eps).squeeze()
 
+    # SLERP formula: (sin((1-t)*θ)/sin(θ))*v0 + (sin(t*θ)/sin(θ))*v1
     res = (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / sin_theta
-    # Use linear interpolation for (nearly) colinear vectors
+
+    # Fall back to linear interpolation for numerically colinear vectors
     res[colinear] = (1 - t) * v0[colinear] + t * v1[colinear]
 
+    # Restore the original tensor shape
     if dim != -1 and not flatten:
         res = res.transpose(dim, -1)
     return res.view(out_shape)
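For reviewers who want to exercise the helper directly, a short demo of `nuslerp` as defined above. It assumes a checkout with this patch installed; the tensors are random and the printed values are illustrative, not recorded results:

```python
# Quick interactive check of the nuslerp helper (illustrative, not part of the patch).
import torch

from mergekit.merge_methods.nuslerp import nuslerp

a = torch.randn(8, 16)
b = torch.randn(8, 16)

# flatten=True interpolates the whole matrix as one long vector.
whole = nuslerp(0.5, a, b, dim=-1, flatten=True)
# dim=0 corresponds to nuslerp_row_wise=True in NuSlerpTask; dim=-1 is the default path.
sliced = nuslerp(0.5, a, b, dim=0, flatten=False)

print(whole.shape, sliced.shape)  # both torch.Size([8, 16])
# t=0 returns the first tensor and t=1 the second, up to floating-point error.
print(torch.allclose(nuslerp(0.0, a, b, dim=-1, flatten=True), a, atol=1e-5))
print(torch.allclose(nuslerp(1.0, a, b, dim=-1, flatten=True), b, atol=1e-5))
```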