Skip to content

Commit a9561e9

Browse files
committed
More experiments
1 parent ce98bac commit a9561e9

File tree

4 files changed

+456
-7
lines changed

4 files changed

+456
-7
lines changed

mergekit/scripts/tokensurgeon.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from mergekit.tokensurgeon import (
3232
SubwordMethod,
3333
WeightingScheme,
34+
batch_mp_rope,
3435
batch_omp,
3536
common_interp_approximate,
3637
compute_token_basis,
@@ -75,6 +76,7 @@ class ApproximationMethod(enum.Enum):
7576
ORTHOGONAL_MATCHING_PURSUIT = "omp"
7677
LANDMARK_PCA = "landmark_pca"
7778
SPARSE_TOKEN_BASIS = "stb"
79+
MATCHING_PURSUIT_ROPE = "mp_rope"
7880

7981

8082
class TokenSurgeonOptions(BaseModel):
@@ -333,6 +335,7 @@ def compute_new_embeddings(
333335
ApproximationMethod.COMMON_INTERPOLATION,
334336
ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
335337
ApproximationMethod.LANDMARK_PCA,
338+
ApproximationMethod.MATCHING_PURSUIT_ROPE,
336339
):
337340
shared_vocab = list(
338341
sorted(
@@ -347,6 +350,7 @@ def compute_new_embeddings(
347350
orig_shared_embeds = orig_embed[
348351
torch.tensor([orig_vocab[t] for t in shared_vocab])
349352
]
353+
res = None
350354
targets = donor_embed[torch.tensor([donor_vocab[t] for t in target_tokens])]
351355
if options.method == ApproximationMethod.LANDMARK_PCA:
352356
return landmark_pca_approximate(
@@ -366,6 +370,17 @@ def compute_new_embeddings(
366370
),
367371
weight_scheme=options.weight_scheme,
368372
)
373+
elif options.method == ApproximationMethod.MATCHING_PURSUIT_ROPE:
374+
indices, coeffs, res = batch_mp_rope(
375+
targets,
376+
donor_shared_embeds,
377+
orig_shared_embeds,
378+
k=options.k,
379+
num_heads_a=28,
380+
num_heads_b=32,
381+
a_rope_base=1000000,
382+
b_rope_base=500000,
383+
)
369384
else:
370385
indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
371386

@@ -381,11 +396,14 @@ def compute_new_embeddings(
381396
options,
382397
)
383398

384-
res = (
385-
torch.bmm(coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float))
386-
.squeeze(1)
387-
.to(orig_embed.dtype)
388-
)
399+
if res is None:
400+
res = (
401+
torch.bmm(
402+
coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float)
403+
)
404+
.squeeze(1)
405+
.to(orig_embed.dtype)
406+
)
389407
return res
390408
elif options.method == ApproximationMethod.SUBWORD:
391409
return subword_approximate(orig_embed, target_tokens, is_lm_head, options)

mergekit/tokensurgeon/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
common_interp_approximate,
88
)
99
from .magikarp import well_trained_tokens
10-
from .omp import batch_omp
10+
from .omp import batch_mp_rope, batch_omp
1111
from .pca import landmark_pca_approximate
1212
from .subword import SubwordMethod, subword_approximate
1313
from .token_basis import compute_token_basis
@@ -17,6 +17,7 @@
1717
"DistanceMetric",
1818
"WeightingScheme",
1919
"batch_omp",
20+
"batch_mp_rope",
2021
"SubwordMethod",
2122
"subword_approximate",
2223
"well_trained_tokens",

mergekit/tokensurgeon/omp.py

+175-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
# SPDX-License-Identifier: BUSL-1.1
33

44
import logging
5-
from typing import Tuple
5+
from typing import Optional, Tuple
66

77
import torch
88

9+
from .rope_helpers import apply_rope, estimate_pos_id_best
10+
911
LOG = logging.getLogger(__name__)
1012

1113

@@ -114,3 +116,175 @@ def batch_omp(
114116
LOG.debug(f"OMP final RMS: {residuals.norm(dim=1).mean()}")
115117

116118
return selected_indices, final_coeff
119+
120+
121+
def batch_mp_resets(
    targets: torch.Tensor,
    candidate_points: torch.Tensor,
    k: int,
    eps: float = 1e-8,
    total_iterations: Optional[int] = None,
) -> Tuple[torch.LongTensor, torch.Tensor]:
    """
    Matching Pursuit with Resets

    Algorithm:
    1. first perform k iterations of standard matching pursuit
    2. then, for each excess iteration, pick one of the k selected slots
       (slots are visited in randomly permuted order) and remove its
       candidate from the set of selected indices
    3. select a new one from the remaining candidates (may be the same, may be different)
    4. repeat until total_iterations are completed

    Args:
        targets: (B, D) batch of vectors to approximate.
        candidate_points: (N, D) dictionary of candidate atoms.
        k: number of atoms to select per target.
        eps: added to each atom's squared norm to avoid division by zero.
        total_iterations: total number of selection steps (initial picks
            plus resets). Defaults to ``3 * k``.

    Returns:
        ``(selected_indices, coeff)``: a (B, k) long tensor of candidate
        indices and a (B, k) tensor of the matching coefficients. The
        coefficients are float32 unless ``targets`` is float64.

    Raises:
        ValueError: if ``total_iterations < k`` or ``k > N``.
    """
    if total_iterations is None:
        total_iterations = k * 3
    if total_iterations < k:
        raise ValueError(
            f"total_iterations {total_iterations} must be greater than or equal to k {k}"
        )
    B, _ = targets.shape
    N, _ = candidate_points.shape
    device = targets.device
    if k > N:
        raise ValueError(f"Cannot select {k} points from {N} candidates")
    work_dtype = (
        targets.dtype
        if targets.dtype in (torch.float32, torch.float64)
        else torch.float32
    )
    selected_indices = torch.zeros((B, k), dtype=torch.long, device=device)
    coeff = torch.zeros((B, k), dtype=work_dtype, device=device)
    if k == 0:
        # Nothing to select. Returning here also avoids an endless loop below:
        # extending the schedule with an empty permutation never makes progress.
        return selected_indices, coeff

    targets_work = targets.to(dtype=work_dtype)
    points_work = candidate_points.to(dtype=work_dtype)
    mask = torch.zeros((B, N), dtype=torch.bool, device=device)
    residuals = targets_work.clone()

    # Slot schedule: fill slots 0..k-1 in order, then revisit slots in
    # randomly permuted rounds until total_iterations steps are scheduled.
    schedule = list(range(k))
    while len(schedule) < total_iterations:
        schedule.extend(torch.randperm(k).tolist())
    schedule = schedule[:total_iterations]

    for step, slot in enumerate(schedule):
        if step >= k:
            # Reset: undo this slot's previous contribution and make its
            # candidate selectable again before re-picking.
            prev_idx = selected_indices[:, slot]
            prev_points = points_work[prev_idx]
            residuals += coeff[:, slot].unsqueeze(-1) * prev_points
            mask.scatter_(1, prev_idx.unsqueeze(1), False)

        # Greedy pick: unmasked candidate with the largest (signed) inner
        # product against the current residual.
        scores = torch.matmul(residuals, points_work.T)  # B x N
        scores = scores.masked_fill(mask, -float("inf"))
        best_scores, best_idx = torch.max(scores, dim=1)  # B, B
        chosen = points_work[best_idx]  # B x D
        step_coeff = best_scores / (torch.sum(chosen**2, dim=1) + eps)  # B
        residuals -= step_coeff.unsqueeze(-1) * chosen
        selected_indices[:, slot] = best_idx
        coeff[:, slot] = step_coeff
        mask.scatter_(1, best_idx.unsqueeze(1), True)

    return selected_indices, coeff
204+
205+
206+
def batch_mp_rope(
    targets: torch.Tensor,
    points_a: torch.Tensor,
    points_b: torch.Tensor,
    k: int,
    num_heads_a: int,
    num_heads_b: int,
    eps: float = 1e-8,
    a_rope_base: float = 10000.0,
    b_rope_base: float = 10000.0,
) -> Tuple[torch.LongTensor, torch.Tensor, torch.Tensor]:
    """
    Matching pursuit in space A with a per-atom position id, rebuilt in space B.

    At each of the k steps the unselected row of ``points_a`` with the
    largest absolute inner product against the residual is chosen, a
    position id is estimated for it with ``estimate_pos_id_best``, the atom
    is transformed with ``apply_rope`` at that position, and the residual is
    projected onto the transformed atom. The approximation is then rebuilt
    from the matching rows of ``points_b``, transformed with the same
    position ids and weighted by the same coefficients.

    NOTE(review): the exact semantics of ``estimate_pos_id_best`` and
    ``apply_rope`` live in ``rope_helpers`` and are not verified here.

    Args:
        targets: (B, D_a) vectors to approximate in space A.
        points_a: (N, D_a) candidate atoms in space A.
        points_b: (N, D_b) candidate atoms in space B, row-aligned with
            ``points_a``.
        k: number of atoms to select per target.
        num_heads_a: head count passed to the rope helpers for space A;
            the head dimension used is ``D_a // num_heads_a``.
        num_heads_b: head count passed to ``apply_rope`` for space B;
            the head dimension used is ``D_b // num_heads_b``.
        eps: lower clamp on the atom's squared norm when computing
            coefficients.
        a_rope_base: ``base`` passed to the rope helpers for space A.
        b_rope_base: ``base`` passed to ``apply_rope`` for space B.

    Returns:
        ``(selected_indices, coeffs, final_tensor)``: (B, k) long indices
        into the candidate rows, (B, k) coefficients in the working dtype
        (float32 unless ``targets`` is float64), and the (B, D_b)
        approximation in space B cast to the original dtype of ``targets``.

    Raises:
        ValueError: if ``points_a`` and ``points_b`` have a different number
            of rows, or if ``k`` exceeds the number of candidates.
    """
    B, D_a = targets.shape
    N, _ = points_a.shape
    _, D_b = points_b.shape
    # Raise instead of assert so the check survives `python -O`.
    if points_a.shape[0] != points_b.shape[0]:
        raise ValueError("Number of points in A and B must match")
    device = targets.device
    if k > N:
        raise ValueError(f"Cannot select {k} points from {N} candidates")
    work_dtype = (
        targets.dtype
        if targets.dtype in (torch.float32, torch.float64)
        else torch.float32
    )
    out_dtype = targets.dtype
    points_a = points_a.to(dtype=work_dtype)
    points_b = points_b.to(dtype=work_dtype)
    targets = targets.to(dtype=work_dtype)
    selected_indices = torch.zeros((B, k), dtype=torch.long, device=device)
    coeffs = torch.zeros((B, k), dtype=work_dtype, device=device)
    pos_ids = torch.zeros((B, k), dtype=work_dtype, device=device)
    mask = torch.zeros((B, N), dtype=torch.bool, device=device)
    residuals = targets.clone()

    # Loop-invariant values.
    head_dim_a = D_a // num_heads_a
    head_dim_b = D_b // num_heads_b
    batch_rows = torch.arange(B, device=device)

    for t in range(k):
        abs_inner = (residuals @ points_a.T).abs()  # (B, N)
        abs_inner.masked_fill_(mask, -float("inf"))

        # Select new index with maximum correlation
        _, new_idx = torch.max(abs_inner, dim=1)  # (B,)

        # update state
        selected_indices[:, t] = new_idx
        mask[batch_rows, new_idx] = True
        new_atom = points_a[new_idx]

        # compute position id for new atom
        pos_id = estimate_pos_id_best(
            new_atom,
            residuals,
            num_heads=num_heads_a,
            head_dim=head_dim_a,
            base=a_rope_base,
        )
        pos_ids[:, t] = pos_id
        new_atom = apply_rope(
            new_atom,
            pos_id,
            num_heads=num_heads_a,
            head_dim=head_dim_a,
            base=a_rope_base,
        )

        # compute coefficients: projection of the residual onto the
        # transformed atom
        current_coeff = (residuals * new_atom).sum(dim=1) / (
            new_atom.pow(2).sum(dim=1).clamp(min=eps)
        )
        coeffs[:, t] = current_coeff

        # update residuals
        residuals = residuals - current_coeff.unsqueeze(1) * new_atom

    # return result in b space
    selected_points_b = points_b[selected_indices]  # (B, k, D_b)
    atoms_b = apply_rope(
        selected_points_b,
        pos_ids,
        num_heads=num_heads_b,
        head_dim=head_dim_b,
        base=b_rope_base,
    )
    approx_b = (atoms_b * coeffs.unsqueeze(-1)).sum(dim=1)
    final_tensor = approx_b.to(out_dtype)
    return selected_indices, coeffs, final_tensor

0 commit comments

Comments
 (0)