20
20
class KarcherTask (Task [torch .Tensor ]):
21
21
"""
22
22
Task for merging model weights using the Riemannian (Karcher) mean algorithm.
23
-
23
+
24
24
The Karcher mean provides a geometrically meaningful way to average points on a manifold,
25
25
which is particularly useful for merging model weights that can be interpreted as points
26
26
on a hypersphere.
27
27
"""
28
+
28
29
gather_tensors : MergeTensorInput
29
30
weight_info : WeightInfo
30
31
max_iter : int
@@ -39,23 +40,20 @@ def arguments(self) -> Dict[str, Task]:
39
40
def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
    """
    Merge the gathered per-model tensors using the Karcher mean.

    Args:
        tensors: Mapping from model reference to that model's tensor for
            the weight being merged.

    Returns:
        The lone input tensor, unchanged, when only one model is present;
        otherwise the result of ``karcher_merge_tensors`` called with an
        equal weight for every model and this task's ``max_iter``/``tol``.
    """
    if len(tensors) == 1:
        return next(iter(tensors.values()))

    # Extract tensors and prepare for merging
    model_tensors = list(tensors.values())

    # Ensure all tensors have compatible shapes. The helper's return value is
    # not used, so whatever adjustment it makes has to land in the list it is
    # handed. Pass the real list: a throwaway pairwise list such as
    # [model_tensors[0], model_tensors[i]] would be discarded right after the
    # call, leaving model_tensors untouched.
    # NOTE(review): assumes rectify_embed_sizes adjusts the list in place and
    # accepts any number of tensors - confirm against its definition.
    rectify_embed_sizes(self.weight_info, model_tensors)

    # Calculate weights (equal by default)
    alphas = [1.0 / len(model_tensors)] * len(model_tensors)

    # Apply Karcher mean algorithm
    return karcher_merge_tensors(
        model_tensors, alphas, max_iter=self.max_iter, tol=self.tol
    )
60
58
61
59
def group_label (self ) -> Optional [str ]:
@@ -65,10 +63,11 @@ def group_label(self) -> Optional[str]:
65
63
class KarcherMerge (MergeMethod ):
66
64
"""
67
65
Implementation of the Karcher mean merge method.
68
-
66
+
69
67
This method merges model weights using the Riemannian (Karcher) mean concept,
70
68
which provides a geometrically meaningful way to average points on a manifold.
71
69
"""
70
+
72
71
def name (self ) -> str :
73
72
return "karcher"
74
73
@@ -99,7 +98,7 @@ def make_task(
99
98
# Use default values from parameters() if not provided
100
99
max_iter = parameters ["max_iter" ] if "max_iter" in parameters else 10
101
100
tol = parameters ["tol" ] if "tol" in parameters else 1e-5
102
-
101
+
103
102
return KarcherTask (
104
103
gather_tensors = tensors ,
105
104
weight_info = output_weight ,
@@ -111,19 +110,19 @@ def make_task(
111
110
def karcher_merge_tensors (tensors , alphas , max_iter = 10 , tol = 1e-5 ):
112
111
"""
113
112
Implements weight fusion based on the Riemannian (Karcher) mean concept.
114
-
113
+
115
114
Args:
116
115
tensors: List of tensors to merge
117
116
alphas: List of weights for each tensor
118
117
max_iter: Maximum number of iterations for the Karcher mean algorithm
119
118
tol: Convergence tolerance
120
-
119
+
121
120
Returns:
122
121
Merged tensor using Karcher mean algorithm
123
122
"""
124
123
if len (tensors ) == 1 :
125
124
return tensors [0 ]
126
-
125
+
127
126
# Calculate norms and unit vectors
128
127
norms = []
129
128
units = []
@@ -137,12 +136,12 @@ def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
137
136
else :
138
137
norms .append (n_val )
139
138
units .append ((t / n ).to (t .dtype ))
140
-
139
+
141
140
# Keep only the tensors whose norm exceeds tol (i.e. effectively non-zero weight vectors)
142
141
valid_indices = [i for i , n in enumerate (norms ) if n > tol ]
143
142
if not valid_indices :
144
143
return torch .zeros_like (tensors [0 ])
145
-
144
+
146
145
valid_alphas = [alphas [i ] for i in valid_indices ]
147
146
alpha_sum = sum (valid_alphas )
148
147
normalized_alphas = [a / alpha_sum for a in valid_alphas ]
@@ -157,7 +156,7 @@ def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
157
156
u = valid_units [0 ].clone ()
158
157
else :
159
158
u = (u / norm_u ).to (u .dtype )
160
-
159
+
161
160
# Iterative Karcher mean computation
162
161
for _ in range (max_iter ):
163
162
T = torch .zeros_like (u )
@@ -172,25 +171,25 @@ def karcher_merge_tensors(tensors, alphas, max_iter=10, tol=1e-5):
172
171
# Ensure tensor operations
173
172
sin_theta = torch .sin (theta )
174
173
T += a * (theta / sin_theta ) * (ui - dot * u )
175
-
174
+
176
175
# Norm of the accumulated tangent vector T (kept as a tensor for the trig calls below)
177
176
norm_T = torch .linalg .norm (T .float ())
178
177
if norm_T .item () < tol :
179
178
break
180
-
179
+
181
180
# Use tensor for trigonometric calculations
182
181
cos_norm_T = torch .cos (norm_T )
183
182
sin_norm_T = torch .sin (norm_T )
184
183
u = (cos_norm_T * u + sin_norm_T * (T / norm_T )).to (u .dtype )
185
-
184
+
186
185
# Ensure u is a unit vector
187
186
u_norm = torch .linalg .norm (u .float ())
188
187
if u_norm .item () > tol :
189
188
u = (u / u_norm ).to (u .dtype )
190
-
189
+
191
190
# Global scale: Weighted sum of original tensor norms (including zero vectors)
192
191
s = 0.0
193
192
for a , n in zip (alphas , norms ):
194
193
s += a * n
195
-
194
+
196
195
return s * u
0 commit comments