Skip to content

Commit aa9b499

Browse files
committed
Improve OMP
1 parent 5fa8660 commit aa9b499

File tree

1 file changed

+41
-24
lines changed

1 file changed

+41
-24
lines changed

mergekit/scripts/tokensurgeon.py

+41-24
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def approximate_from_landmarks(
9595
weights = 1 - distances
9696
else:
9797
weights = 1 / distances.clamp_min(1e-6)
98-
weights = weights / weights.sum(dim=1, keepdim=True)
98+
weights = weights / weights.sum(dim=1, keepdim=True).clamp_min(1e-6)
9999
elif scheme == WeightingScheme.BARYCENTRIC:
100100
weights = barycentric_weights(targets, points)
101101
elif scheme == WeightingScheme.LEAST_SQUARES:
@@ -179,52 +179,65 @@ def common_interp_approximate(
179179
return weights, indices, res
180180

181181

182-
def batch_omp(targets: torch.Tensor, pts: torch.Tensor, k: int):
182+
def batch_omp(
183+
targets: torch.Tensor, candidate_points: torch.Tensor, k: int
184+
) -> Tuple[torch.LongTensor, torch.Tensor]:
183185
"""
184-
Batched Orthogonal Matching Pursuit (OMP) to select `k` points from `pts` that best approximate each target in `targets`.
186+
Batched Orthogonal Matching Pursuit (OMP) to select `k` points from `candidate_points` that best approximate each target in `targets`.
185187
186188
Args:
187189
targets: (B, D) tensor of target vectors.
188-
pts: (N, D) tensor of candidate points.
190+
candidate_points: (N, D) tensor of candidate points.
189191
k: Number of points to select (sparsity level).
190192
191193
Returns:
192-
(B, k) tensor of indices selected for each target.
194+
selected_indices: (B, k) tensor of indices selected for each target.
195+
coeff: (B, k) tensor of coefficients for each selected point.
193196
"""
194197
B, D = targets.shape
195-
N, _ = pts.shape
198+
N, _ = candidate_points.shape
196199
device = targets.device
200+
if k > N:
201+
raise ValueError(f"Cannot select {k} points from {N} candidates")
202+
work_dtype = (
203+
targets.dtype
204+
if targets.dtype in (torch.float32, torch.float64)
205+
else torch.float32
206+
)
197207
# Initialize selected indices and residuals
198208
selected_indices = torch.zeros((B, k), dtype=torch.long, device=device)
199-
residuals = targets.clone()
209+
targets_work = targets.to(dtype=work_dtype)
210+
residuals = targets_work.clone()
211+
points_work = candidate_points.to(dtype=work_dtype)
212+
mask = torch.zeros((B, N), dtype=torch.bool, device=device)
213+
200214
for t in range(k):
201-
LOG.debug(f"OMP iteration {t} - current rms: {residuals.norm(dim=1).mean()}")
215+
rms_0 = residuals.norm(dim=1).mean()
202216
# Compute absolute inner products between residuals and points
203-
abs_inner = (residuals @ pts.T).abs() # (B, N)
204-
# Mask previously selected indices
205-
if t > 0:
206-
mask = torch.zeros((B, N), dtype=torch.bool, device=device)
207-
mask.scatter_(1, selected_indices[:, :t], True)
208-
abs_inner = abs_inner.masked_fill(mask, -torch.inf)
217+
abs_inner = (residuals @ points_work.T).abs() # (B, N)
218+
# Mask out already selected points
219+
abs_inner.masked_fill_(mask, -float("inf"))
220+
209221
# Select new index with maximum correlation
210222
_, new_idx = torch.max(abs_inner, dim=1) # (B,)
211223
selected_indices[:, t] = new_idx
224+
225+
# Update mask
226+
mask[torch.arange(B, device=device), new_idx] = True
227+
212228
# Gather selected points (B, t+1, D)
213-
batch_indices = selected_indices[:, : t + 1].unsqueeze(-1).expand(-1, -1, D)
214-
selected_points = torch.gather(
215-
pts.unsqueeze(0).expand(B, -1, -1), 1, batch_indices
216-
)
217-
selected_points_transposed = selected_points.transpose(1, 2) # Fix here
229+
selected_points = points_work[selected_indices[:, : t + 1]]
218230
# Solve least squares
219231
coeff = torch.linalg.lstsq(
220-
selected_points_transposed.float(), # (B, D, t+1)
221-
targets.unsqueeze(-1).float(), # (B, D, 1)
232+
selected_points.transpose(1, 2), # (B, D, t+1)
233+
targets_work.unsqueeze(-1), # (B, D, 1)
222234
).solution.squeeze(
223235
-1
224236
) # (B, t+1)
225237
# Update residuals
226-
approx = torch.bmm(coeff.unsqueeze(1), selected_points.float()).squeeze(1)
227-
residuals = targets - approx.to(targets.dtype)
238+
approx = torch.bmm(coeff.unsqueeze(1), selected_points).squeeze(1)
239+
residuals = targets_work - approx
240+
LOG.debug(f"OMP iteration {t}: RMS {rms_0} -> {residuals.norm(dim=1).mean()}")
228241
return selected_indices, coeff
229242

230243

@@ -432,7 +445,11 @@ def compute_new_embeddings(
432445
]
433446
targets = donor_embed[torch.tensor([donor_vocab[t] for t in target_tokens])]
434447
indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
435-
return torch.bmm(coeffs.unsqueeze(1), orig_shared_embeds[indices]).squeeze(1)
448+
return (
449+
torch.bmm(coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float))
450+
.squeeze(1)
451+
.to(orig_embed.dtype)
452+
)
436453
elif options.method == ApproximationMethod.SUBWORD:
437454
raise NotImplementedError("Subword approximation not yet implemented")
438455
else:

0 commit comments

Comments
 (0)