Skip to content

Commit 3ee1f3d

Browse files
committed
Whoops
1 parent ab58c45 commit 3ee1f3d

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

mergekit/merge_methods/karcher.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
class KarcherTask(Task[torch.Tensor]):
2121
"""
2222
Task for merging model weights using the Riemannian (Karcher) mean algorithm.
23-
23+
2424
The Karcher mean provides a geometrically meaningful way to average points on a manifold,
2525
which is particularly useful for merging model weights that can be interpreted as points
2626
on a hypersphere.
2727
"""
28+
2829
gather_tensors: MergeTensorInput
2930
weight_info: WeightInfo
3031
max_iter: int
@@ -39,23 +40,20 @@ def arguments(self) -> Dict[str, Task]:
3940
def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
4041
if len(tensors) == 1:
4142
return list(tensors.values())[0]
42-
43+
4344
# Extract tensors and prepare for merging
4445
model_tensors = list(tensors.values())
45-
46+
4647
# Ensure all tensors have compatible shapes
4748
for i in range(1, len(model_tensors)):
4849
rectify_embed_sizes(self.weight_info, [model_tensors[0], model_tensors[i]])
49-
50+
5051
# Calculate weights (equal by default)
5152
alphas = [1.0 / len(model_tensors)] * len(model_tensors)
52-
53+
5354
# Apply Karcher mean algorithm
5455
return karcher_merge_tensors(
55-
model_tensors,
56-
alphas,
57-
max_iter=self.max_iter,
58-
tol=self.tol
56+
model_tensors, alphas, max_iter=self.max_iter, tol=self.tol
5957
)
6058

6159
def group_label(self) -> Optional[str]:
@@ -65,10 +63,11 @@ def group_label(self) -> Optional[str]:
6563
class KarcherMerge(MergeMethod):
6664
"""
6765
Implementation of the Karcher mean merge method.
68-
66+
6967
This method merges model weights using the Riemannian (Karcher) mean concept,
7068
which provides a geometrically meaningful way to average points on a manifold.
7169
"""
70+
7271
def name(self) -> str:
7372
return "karcher"
7473

@@ -99,7 +98,7 @@ def make_task(
9998
# Use default values from parameters() if not provided
10099
max_iter = parameters["max_iter"] if "max_iter" in parameters else 10
101100
tol = parameters["tol"] if "tol" in parameters else 1e-5
102-
101+
103102
return KarcherTask(
104103
gather_tensors=tensors,
105104
weight_info=output_weight,
@@ -111,19 +110,19 @@ def make_task(
111110
def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
112111
"""
113112
Implements weight fusion based on the Riemannian (Karcher) mean concept.
114-
113+
115114
Args:
116115
tensors: List of tensors to merge
117116
alphas: List of weights for each tensor
118117
max_iter: Maximum number of iterations for the Karcher mean algorithm
119118
tol: Convergence tolerance
120-
119+
121120
Returns:
122121
Merged tensor using Karcher mean algorithm
123122
"""
124123
if len(tensors) == 1:
125124
return tensors[0]
126-
125+
127126
# Calculate norms and unit vectors
128127
norms = []
129128
units = []
@@ -137,12 +136,12 @@ def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
137136
else:
138137
norms.append(n_val)
139138
units.append((t / n).to(t.dtype))
140-
139+
141140
# Select non-zero weight vectors
142141
valid_indices = [i for i, n in enumerate(norms) if n > tol]
143142
if not valid_indices:
144143
return torch.zeros_like(tensors[0])
145-
144+
146145
valid_alphas = [alphas[i] for i in valid_indices]
147146
alpha_sum = sum(valid_alphas)
148147
normalized_alphas = [a / alpha_sum for a in valid_alphas]
@@ -157,7 +156,7 @@ def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
157156
u = valid_units[0].clone()
158157
else:
159158
u = (u / norm_u).to(u.dtype)
160-
159+
161160
# Iterative Karcher mean computation
162161
for _ in range(max_iter):
163162
T = torch.zeros_like(u)
@@ -172,25 +171,25 @@ def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
172171
# Ensure tensor operations
173172
sin_theta = torch.sin(theta)
174173
T += a * (theta / sin_theta) * (ui - dot * u)
175-
174+
176175
# Convert norm_T to tensor
177176
norm_T = torch.linalg.norm(T.float())
178177
if norm_T.item() < tol:
179178
break
180-
179+
181180
# Use tensor for trigonometric calculations
182181
cos_norm_T = torch.cos(norm_T)
183182
sin_norm_T = torch.sin(norm_T)
184183
u = (cos_norm_T * u + sin_norm_T * (T / norm_T)).to(u.dtype)
185-
184+
186185
# Ensure u is a unit vector
187186
u_norm = torch.linalg.norm(u.float())
188187
if u_norm.item() > tol:
189188
u = (u / u_norm).to(u.dtype)
190-
189+
191190
# Global scale: Weighted sum of original tensor norms (including zero vectors)
192191
s = 0.0
193192
for a, n in zip(alphas, norms):
194193
s += a * n
195-
194+
196195
return s * u

0 commit comments

Comments
 (0)