Add ChipAlign geodesic interpolation method #529

Open
wants to merge 1 commit into
base: main
16 changes: 16 additions & 0 deletions README.md
@@ -251,6 +251,7 @@ A quick overview of the currently supported merge methods:
| [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` | ✅ | ✅ |
| [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
| NuSLERP | `nuslerp` | ❌ | ✅ |
| [ChipAlign](https://arxiv.org/abs/2412.19819) | `nuslerp` + `geodesic: true` | ❌ | ❌ |
| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
| [SCE](https://arxiv.org/abs/2408.07990) | `sce` | ✅ | ✅ |
@@ -333,9 +334,24 @@ Parameters:
- `weight`: relative weighting of a given tensor
- `nuslerp_flatten`: set to false to do row-wise/column-wise interpolation instead of treating tensors as vectors
- `nuslerp_row_wise`: SLERP row vectors instead of column vectors
- `geodesic`: (boolean) when true, use ChipAlign-style geodesic interpolation
- `lambda`: interpolation factor for geodesic interpolation (required when `geodesic` is true)

To replicate the behavior of the original `slerp` method, set `weight` to `1-t` and `t` for your first and second model respectively.
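
As a small illustration of the weight convention above, the following Python sketch mirrors how the task in this PR derives the interpolation factor from the two `weight` values (`t = weights[1] / sum(weights)`, with a fallback of 0.5 when the weights sum to roughly zero). The helper name `weights_to_t` is hypothetical and is not part of mergekit.

```python
def weights_to_t(w0: float, w1: float, tol: float = 1e-6) -> float:
    """Map two relative weights to an interpolation factor t (hypothetical helper)."""
    total = w0 + w1
    if abs(total) < tol:
        # Weights cancel out: use the midpoint instead of dividing by ~0.
        return 0.5
    return w1 / total


# Weights (1 - t, t) recover the classic slerp factor t.
t = 0.3
print(weights_to_t(1 - t, t))

# Only the ratio of the weights matters, not their scale.
print(weights_to_t(2.0, 6.0), weights_to_t(1.0, 3.0))
```

Because only the ratio matters, weights of `(1, 3)` and `(2, 6)` describe the same merge.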

### ChipAlign

[ChipAlign](https://arxiv.org/abs/2412.19819) is implemented as an extension to NuSLERP. It normalizes the weights of both models, interpolates between them along the geodesic on the unit sphere, and rescales the result by the weighted geometric mean of the original norms. The method is designed to merge a general instruction-aligned LLM with a domain-specific LLM so that the merged model keeps both instruction-following ability and domain expertise.

To use ChipAlign, set `merge_method: nuslerp` and add the parameter `geodesic: true`. ChipAlign requires exactly two models, typically:
- An instruction-aligned general LLM (first model)
- A domain-specific expert LLM (second model)

Parameters:
- `lambda`: Interpolation factor between 0.0 and 1.0. At 0.0 the result reproduces the instruction model (first model); at 1.0 it reproduces the domain model (second model). Required when `geodesic` is true.

See the [chipalign.yml](examples/chipalign.yml) example for a complete configuration.
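
The geodesic interpolation described in this section can be sketched in plain Python as follows. This is a minimal illustration on lists of floats rather than the PR's tensor code: the function names `slerp_unit` and `geodesic_merge` are hypothetical, and the sketch assumes both inputs are non-zero vectors.

```python
import math
from typing import List


def slerp_unit(lam: float, u0: List[float], u1: List[float], eps: float = 1e-8) -> List[float]:
    """SLERP between two unit vectors, with a linear fallback when they are nearly colinear."""
    cos_theta = max(-1.0, min(1.0, sum(a * b for a, b in zip(u0, u1))))
    theta = math.acos(cos_theta)
    sin_theta = math.sin(theta)
    if abs(sin_theta) < eps:
        return [(1 - lam) * a + lam * b for a, b in zip(u0, u1)]
    c0 = math.sin((1 - lam) * theta) / sin_theta
    c1 = math.sin(lam * theta) / sin_theta
    return [c0 * a + c1 * b for a, b in zip(u0, u1)]


def geodesic_merge(lam: float, w_instruct: List[float], w_domain: List[float]) -> List[float]:
    """Interpolate directions on the unit sphere, then rescale by the
    weighted geometric mean of the two norms (assumes non-zero inputs)."""
    n0 = math.sqrt(sum(x * x for x in w_instruct))
    n1 = math.sqrt(sum(x * x for x in w_domain))
    u0 = [x / n0 for x in w_instruct]
    u1 = [x / n1 for x in w_domain]
    scale = (n0 ** (1 - lam)) * (n1 ** lam)
    return [scale * x for x in slerp_unit(lam, u0, u1)]


merged = geodesic_merge(0.5, [2.0, 0.0], [0.0, 8.0])
print(merged, math.sqrt(sum(x * x for x in merged)))
```

With norms 2 and 8 and `lambda = 0.5`, the merged vector has a norm of about 4, the geometric mean of the two input norms, rather than the arithmetic mean of 5.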

### [DELLA](https://arxiv.org/abs/2406.11617)

Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters similar to [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign elect step of TIES.
19 changes: 19 additions & 0 deletions examples/chipalign.yml
@@ -0,0 +1,19 @@
# ChipAlign merge example
# Merge an instruction-aligned model with a domain-specific chip model
# Based on https://arxiv.org/abs/2412.19819

models:
  - model: instruction_model_path  # Path to instruction-aligned model
  - model: chip_model_path  # Path to chip-specific model

merge_method: nuslerp
parameters:
  geodesic: true
  lambda: 0.5  # Interpolation factor between the two models (0.0 = instruction model, 1.0 = chip model)
  nuslerp_flatten: true  # Treat tensors as flattened vectors for geodesic interpolation

# Optionally configure tokenizer
tokenizer:
  source: "union"  # Combine vocabularies from both models

dtype: bfloat16  # Output precision
113 changes: 110 additions & 3 deletions mergekit/merge_methods/nuslerp.py
@@ -19,12 +19,19 @@


class NuSlerpTask(Task[torch.Tensor]):
    """Task for performing NuSLERP or ChipAlign merges between two model tensors.

    Supports both traditional NuSLERP and ChipAlign-style geodesic interpolation
    with magnitude preservation, as described in https://arxiv.org/abs/2412.19819.
    """
    gather_tensors: MergeTensorInput
    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
    weight_info: WeightInfo
    row_wise: bool
    flatten: bool
    base_model: Optional[ModelReference]
    geodesic: bool  # Whether to use ChipAlign-style geodesic interpolation
    lambda_val: Optional[float]  # Interpolation factor for geodesic mode

    def uses_accelerator(self) -> bool:
        return True
@@ -33,9 +40,11 @@ def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
        # Fast path for single-model case
        if len(tensors) == 1:
            return list(tensors.values())[0]

        # Handle base model if provided
        if self.base_model is not None:
            if len(tensors) != 3:
                raise RuntimeError(
@@ -45,34 +54,74 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
        else:
            base_tensor = None

        # Extract tensors and weights
        keys = list(tensors.keys())
        tensors = [tensors[key] for key in keys]
        weights = [self.tensor_parameters[key]["weight"] for key in keys]

        # Verify exactly two models are provided
        if len(tensors) != 2:
            raise RuntimeError(
                "NuSlerp merge expects exactly two models (plus optional base model)"
            )

        # Calculate interpolation factor from weights
        if abs(sum(weights)) < 1e-6:
            # this is fairly arbitrary, but it's more sane than exploding
            t = 0.5  # Default when weights sum to zero
        else:
            t = weights[1] / sum(weights)
Review comment (Collaborator):

Is there a reason for introducing a new lambda parameter instead of using t?


        # Handle embedding tensors with different sizes
        if base_tensor is not None:
            tensors.append(base_tensor)
        rectify_embed_sizes(self.weight_info, tensors)

        # ChipAlign geodesic interpolation path
        if self.geodesic:
            if base_tensor is not None:
                raise ValueError("ChipAlign-style geodesic interpolation does not support a base model.")
            if self.lambda_val is None:
                raise ValueError("lambda must be specified when geodesic=True")

            # Extract the instruction and domain-specific tensors
            instruction_tensor = tensors[0]
            domain_tensor = tensors[1]

            # Calculate norms for magnitude preservation
            instruction_tensor_norm = torch.norm(instruction_tensor)
            domain_tensor_norm = torch.norm(domain_tensor)

            # Normalize to unit vectors
            instruction_tensor_unit = instruction_tensor / instruction_tensor_norm
            domain_tensor_unit = domain_tensor / domain_tensor_norm

            # Perform spherical interpolation on unit vectors
            from mergekit.merge_methods.slerp import slerp
            merged_tensor_unit = slerp(
Review comment (Collaborator):

I'd suggest using the nuslerp function here instead - the old slerp moves tensors to CPU so it's a lot slower. It should give the same results though.

Review comment (Collaborator):

(This also lets us respect nuslerp_flatten and nuslerp_row_wise which would be good.)

                self.lambda_val, instruction_tensor_unit, domain_tensor_unit
            )

            # Apply magnitude scaling using weighted geometric mean (from ChipAlign paper)
            merged_tensor = (
                (instruction_tensor_norm ** (1 - self.lambda_val))
                * (domain_tensor_norm ** self.lambda_val)
                * merged_tensor_unit
            )
            return merged_tensor

        # Standard NuSlerp path
        if base_tensor is not None:
            base_tensor = tensors.pop()
            # For task vector mode (with base model)
            return base_tensor + nuslerp(
                t,
                tensors[0] - base_tensor,
                tensors[1] - base_tensor,
                dim=0 if self.row_wise else -1,
                flatten=self.flatten,
            )

        # Direct tensor mode (no base model)
        return nuslerp(
            t,
            tensors[0],
@@ -83,24 +132,54 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:


class NuSlerpMerge(MergeMethod):
    """Merge method implementing both NuSLERP and ChipAlign geodesic interpolation.

    Provides a flexible, enhanced implementation of spherical linear interpolation
    with additional options for interpolation mode and parameter customization.
    """
    def name(self) -> str:
        return "nuslerp"

    @override
    def pretty_name(self):
        return "NuSLERP"

    @override
    def reference_url(self):
        return "https://arxiv.org/abs/2412.19819" if self.is_chipalign() else None

    def is_chipalign(self) -> bool:
        """Check if configured as ChipAlign mode based on parameters."""
        try:
            return self._parameters and self._parameters.get("geodesic", False)
        except AttributeError:
            return False

    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(
                name="nuslerp_row_wise",
                required=False,
                default_value=False,
                description="SLERP row vectors instead of column vectors",
            ),
            ConfigParameterDef(
                name="nuslerp_flatten",
                required=False,
                default_value=True,
                description="Treat tensors as flattened vectors",
            ),
            ConfigParameterDef(
                name="geodesic",
                required=False,
                default_value=False,
                description="Enable ChipAlign-style geodesic interpolation with magnitude preservation",
            ),
            ConfigParameterDef(
                name="lambda",
                required=False,
                default_value=None,
                description="Interpolation factor (0.0-1.0) for geodesic mode; 0=first model, 1=second model",
            ),
        ]

@@ -117,13 +196,18 @@ def make_task(
        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
        **_kwargs,
    ) -> Task:
        # Store parameters for reference_url to detect ChipAlign mode
        self._parameters = parameters

        return NuSlerpTask(
            gather_tensors=tensors,
            tensor_parameters=tensor_parameters,
            weight_info=output_weight,
            row_wise=parameters["nuslerp_row_wise"],
            flatten=parameters["nuslerp_flatten"],
            base_model=base_model,
            geodesic=parameters["geodesic"],
            lambda_val=parameters["lambda"],
        )


@@ -135,31 +219,54 @@ def nuslerp(
    eps: float = 1e-8,
    flatten: bool = False,
):
    """Enhanced spherical linear interpolation (SLERP) with flexible tensor handling.

    Args:
        t: Interpolation factor between 0.0 and 1.0
        v0: First tensor
        v1: Second tensor
        dim: Dimension along which to perform row/column-wise interpolation
        eps: Small value to prevent division by zero
        flatten: Whether to flatten tensors before interpolation

    Returns:
        Interpolated tensor with the same shape as inputs
    """
    out_shape = v0.shape

    def _normalize(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
        """Normalize tensor along last dimension with numeric stability."""
        return x / torch.norm(x, dim=-1, keepdim=True).clamp(min=eps)

    # Handle tensor reshaping based on interpolation mode
    if flatten:
        # Treat entire tensor as a single vector
        v0 = v0.view(-1)
        v1 = v1.view(-1)
    elif dim != -1:
        # Perform interpolation along specified dimension
        v0 = v0.transpose(dim, -1)
        v1 = v1.transpose(dim, -1)

    # Normalize to unit vectors
    v0_u = _normalize(v0)
    v1_u = _normalize(v1)

    # Calculate angle between vectors
    cos_theta = torch.sum(v0_u * v1_u, dim=-1, keepdim=True)
    theta = torch.acos(cos_theta.clamp(-1, 1))
    sin_theta = torch.sin(theta)

    # Handle (nearly) colinear vectors to avoid numerical issues
    colinear = (sin_theta.abs() < eps).squeeze()

    # SLERP formula: (sin((1-t)*θ)/sin(θ))*v0 + (sin(t*θ)/sin(θ))*v1
    res = (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / sin_theta

    # Fall back to linear interpolation for numerically colinear vectors
    res[colinear] = (1 - t) * v0[colinear] + t * v1[colinear]

    # Restore original tensor shape
    if dim != -1 and not flatten:
        res = res.transpose(dim, -1)
    return res.view(out_shape)
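
To make the effect of the `flatten` option concrete, here is a rough pure-Python sketch that contrasts treating a whole 2x2 matrix as one vector with interpolating each vector along the last dimension separately, which is what `nuslerp` does when `flatten=False` and `dim=-1`. It works on nested lists instead of tensors, and the names `slerp_vec`, `slerp_flat`, and `slerp_rows` are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def slerp_vec(t: float, v0: Vector, v1: Vector, eps: float = 1e-8) -> Vector:
    """SLERP coefficients from the angle between v0 and v1, applied to the raw vectors."""
    n0 = max(math.sqrt(sum(x * x for x in v0)), 1e-7)
    n1 = max(math.sqrt(sum(x * x for x in v1)), 1e-7)
    cos_theta = sum((a / n0) * (b / n1) for a, b in zip(v0, v1))
    theta = math.acos(max(-1.0, min(1.0, cos_theta)))
    sin_theta = math.sin(theta)
    if abs(sin_theta) < eps:
        # Nearly colinear: fall back to linear interpolation.
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    return [
        (math.sin((1 - t) * theta) * a + math.sin(t * theta) * b) / sin_theta
        for a, b in zip(v0, v1)
    ]


def slerp_flat(t: float, m0: Matrix, m1: Matrix) -> Matrix:
    """Flatten both matrices, interpolate once, and restore the shape."""
    cols = len(m0[0])
    flat = slerp_vec(t, [x for r in m0 for x in r], [x for r in m1 for x in r])
    return [flat[i * cols:(i + 1) * cols] for i in range(len(m0))]


def slerp_rows(t: float, m0: Matrix, m1: Matrix) -> Matrix:
    """Interpolate each last-dimension vector independently."""
    return [slerp_vec(t, r0, r1) for r0, r1 in zip(m0, m1)]


m0 = [[1.0, 0.0], [0.0, 2.0]]
m1 = [[0.0, 1.0], [0.0, 2.0]]
print("per-vector:", slerp_rows(0.5, m0, m1))
print("flattened: ", slerp_flat(0.5, m0, m1))
```

In the per-vector case the second row is identical in both inputs, so it takes the colinear fallback and is returned unchanged. In the flattened case a single angle is computed for the whole matrix, so that same row is rescaled along with everything else.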