
Commit 000f6b5

More splitting up
1 parent 5147c84 commit 000f6b5

5 files changed (+307, −309 lines)


mergekit/scripts/tokensurgeon.py (+3, −309)
@@ -27,14 +27,16 @@
     NormalizedToken,
     normalized_vocabulary,
     token_prefixes,
-    unnormalize_token,
 )
 from mergekit.tokensurgeon import (
     SubwordMethod,
     WeightingScheme,
     batch_omp,
     common_interp_approximate,
+    compute_token_basis,
+    landmark_pca_approximate,
     subword_approximate,
+    well_trained_tokens,
 )
 from mergekit.tokensurgeon.common_interpolation import DistanceMetric

@@ -72,7 +74,6 @@ class ApproximationMethod(enum.Enum):
     JOHN_HEWITT = "random_matching_distribution"
     ORTHOGONAL_MATCHING_PURSUIT = "omp"
     LANDMARK_PCA = "landmark_pca"
-    RBF = "rbf"
     SPARSE_TOKEN_BASIS = "stb"

@@ -244,116 +245,6 @@ def john_hewitt_init(orig_embed: torch.Tensor, num_new_tokens: int) -> torch.Ten
     return new_embeds.to(orig_embed.dtype)


-def landmark_pca_approximate(
-    targets: torch.Tensor,
-    points_a: torch.Tensor,
-    points_b: torch.Tensor,
-) -> torch.Tensor:
-    """Given target points in space a and a set of reference points in both space a and b,
-    approximate the target points in space b."""
-    # points_a: (N, D_a)
-    # points_b: (N, D_b)
-    # 1:1 correspondence between points_a and points_b
-    # targets: (B, D_a)
-    num_points, d_a = points_a.shape
-    batch_size, _ = targets.shape
-    _, d_b = points_b.shape
-    assert (
-        points_a.shape[0] == points_b.shape[0]
-    ), "Number of points in A and B must match"
-    assert targets.shape == (batch_size, d_a)
-
-    effective_dim = min(d_a, d_b)
-
-    out_dtype = targets.dtype
-    points_a = points_a.float()
-    points_b = points_b.float()
-    targets = targets.float()
-
-    # Compute the mean of all points in A and B
-    mean_a = points_a.mean(dim=0, keepdim=True)  # (1, D_a)
-    mean_b = points_b.mean(dim=0, keepdim=True)  # (1, D_b)
-    centered_a = points_a - mean_a  # (N, D_a)
-    centered_b = points_b - mean_b  # (N, D_b)
-    centered_targets = targets - mean_a  # (B, D_a)
-
-    # Perform PCA to get the principal components
-    U_a, S_a, V_a = torch.pca_lowrank(centered_a, q=effective_dim)
-    U_b, S_b, V_b = torch.pca_lowrank(centered_b, q=effective_dim)
-
-    # Project reference points into PCA space
-    A_pca = torch.mm(centered_a, V_a)  # (N, effective_dim)
-    B_pca = torch.mm(centered_b, V_b)  # (N, effective_dim)
-
-    # Compute Procrustes matrix and solve for optimal rotation
-    M = torch.mm(B_pca.t(), A_pca)  # (effective_dim, effective_dim)
-    U, S, V = torch.svd(M)
-    R = torch.mm(U, V.t())  # (effective_dim, effective_dim)
-
-    # Transform targets through PCA spaces and rotation
-    projected_a = torch.mm(centered_targets, V_a)  # (B, effective_dim)
-    rotated = torch.mm(projected_a, R)  # (B, effective_dim)
-    projected_b = torch.mm(rotated, V_b.t())  # (B, D_b)
-
-    # Translate back to original space B
-    approximated_b = projected_b + mean_b
-
-    return approximated_b.to(out_dtype)
-
-
-def rbf_approximate(
-    targets: torch.Tensor,
-    points_a: torch.Tensor,
-    points_b: torch.Tensor,
-    epsilon: float = 1e-6,
-) -> torch.Tensor:
-    """
-    Approximate target points from space 'a' to space 'b' using RBF interpolation.
-
-    Args:
-        targets: Tensor of shape (B, D_a), points to approximate.
-        points_a: Reference points in space 'a', shape (N, D_a).
-        points_b: Corresponding points in space 'b', shape (N, D_b).
-        epsilon: Small number to ensure numerical stability.
-
-    Returns:
-        Approximate points in space 'b', tensor of shape (B, D_b).
-    """
-    N, D_a = points_a.shape
-    B, _ = targets.shape
-    _, D_b = points_b.shape
-
-    assert (
-        points_a.shape[0] == points_b.shape[0]
-    ), "points_a and points_b must have the same number of points."
-    assert (
-        targets.shape[1] == D_a
-    ), "targets and points_a must have the same dimensionality."
-
-    # Compute pairwise squared distances between points_a
-    dist_matrix = torch.cdist(points_a, points_a, p=2).pow(2)  # shape (N, N)
-
-    # Use Gaussian Radial Basis Function kernel
-    sigma = torch.median(dist_matrix) + epsilon  # heuristic sigma value
-    rbf_kernel = torch.exp(-dist_matrix / (2 * sigma**2))  # (N, N)
-
-    # Solve for weights to map from points_a to points_b
-    weights, _ = torch.lstsq(
-        points_b, rbf_kernel + epsilon * torch.eye(N, device=points_a.device)
-    )
-
-    # Compute distances between targets and points_a
-    dist_targets = torch.cdist(targets, points_a, p=2).pow(2)  # shape (B, N)
-
-    # Apply RBF kernel to target points
-    rbf_targets = torch.exp(-dist_targets / (2 * sigma**2))  # shape (B, N)
-
-    # Approximate targets in space 'b'
-    approximations = rbf_targets @ weights[:N]
-
-    return approximations
-
-
 def debug_reconstruction_for_random_tokens(
     coeffs: torch.Tensor,
     donor_shared_embeds: torch.Tensor,
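
Note: landmark_pca_approximate is deleted from the script because it is now imported from the mergekit.tokensurgeon package (see the import hunk above); rbf_approximate appears to be dropped outright, since the RBF enum value and its dispatch branch are removed as well. Below is a minimal usage sketch, not part of the commit, assuming mergekit is installed and the relocated function keeps the signature shown in the removed code.

# Usage sketch only; the import path is taken from the import hunk above and the
# signature from the removed code, both of which may differ in the package version.
import torch

from mergekit.tokensurgeon import landmark_pca_approximate

N, D_A, D_B, B = 256, 64, 48, 8
points_a = torch.randn(N, D_A)   # reference tokens embedded in space A (e.g. donor)
points_b = torch.randn(N, D_B)   # the same tokens embedded in space B (e.g. original)
targets = torch.randn(B, D_A)    # tokens only available in space A

approx_b = landmark_pca_approximate(targets, points_a, points_b)
print(approx_b.shape)  # expected: torch.Size([8, 48])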
@@ -397,8 +288,6 @@ def debug_reconstruction_for_random_tokens(
         )
         donor_tok_embed = donor_embed[donor_vocab[target_tokens[i]]]
         reconstructed = reconstructed_in_donor[i]
-        err_rms = (reconstructed - donor_tok_embed).norm()
-        err_rel = err_rms / donor_tok_embed.norm().clamp_min(1e-6)
         cos_sim = torch.nn.functional.cosine_similarity(
             donor_tok_embed,
             reconstructed,
@@ -408,66 +297,6 @@
         print()


-def sparse_linear_basis(
-    embeddings: torch.Tensor,
-    k: int,
-    d: int,
-    eps: float = 1e-8,
-) -> Tuple[torch.LongTensor, torch.Tensor]:
-    """
-    Form an approximate orthogonal basis from sparse linear combinations of the input embeddings.
-    Args:
-        embeddings: (num_pts, embed_dim) tensor of embeddings
-        k: number of points to select per basis vector
-        d: dimensionality of the basis
-        eps: numerical stability parameter
-    Returns:
-        indices: (d, k) tensor of selected indices
-        coeffs: (d, k) tensor of coefficients for each selected point
-    """
-    assert embeddings.dim() == 2
-    num_pts, embed_dim = embeddings.shape
-    assert k <= num_pts, "k must be less than or equal to the number of points"
-    assert d <= embed_dim, "d must be less than or equal to the embedding dimension"
-
-    mean_embed = embeddings.mean(dim=0)
-    centered_embeddings = (embeddings - mean_embed).to(torch.float32)
-    covariance_matrix = (
-        centered_embeddings.T @ centered_embeddings
-    ) / num_pts  # (embed_dim, embed_dim)
-
-    U, S, V = torch.linalg.svd(covariance_matrix)
-    # Select the top d singular vectors
-    U_d = U[:, :d]  # (embed_dim, d)
-    V_d = V[:, :d]  # (embed_dim, d)
-    S_d = S[:d]  # (d,)
-
-    # use OMP to approximate the singular vectors
-    indices, coeffs = batch_omp(
-        U_d.t(),  # (d, embed_dim)
-        centered_embeddings,  # (num_pts, embed_dim)
-        k,
-        eps=eps,
-    )
-
-    if LOG.isEnabledFor(logging.DEBUG):
-        rc_basis = torch.bmm(
-            coeffs.unsqueeze(1).to(torch.float),
-            centered_embeddings[indices].to(torch.float),
-        ).squeeze(1)
-        for i in range(d):
-            v_0 = U_d[:, i]
-            v_1 = rc_basis[i]
-            cos_sim = torch.nn.functional.cosine_similarity(v_0, v_1, dim=0)
-            rms = torch.norm(v_0 - v_1)
-            norm_rms = torch.norm(v_0 - (v_1 / v_1.norm().clamp_min(1e-6)))
-            LOG.debug(
-                f"Basis vector {i}: cos_sim = {cos_sim.item():.4f}, RMS = {rms.item():.4f}, norm_rms = {norm_rms.item():.4f}"
-            )
-
-    return indices, coeffs
-
-
 def compute_new_embeddings(
     orig_embed: torch.Tensor,
     donor_embed: torch.Tensor,
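
Note: sparse_linear_basis expresses each of the top-d principal directions of the centered embeddings as a sparse combination of k actual embedding rows, using batch_omp to pick the rows. The sketch below is not part of the commit and only illustrates that problem shape on synthetic data; a plain least-squares fit over randomly chosen rows stands in for OMP, which selects the rows greedily, and every name and size here is illustrative.

import torch

num_pts, embed_dim, d, k = 500, 64, 16, 8
embeddings = torch.randn(num_pts, embed_dim)
centered = embeddings - embeddings.mean(dim=0)

# Top-d principal directions of the centered embeddings, as in sparse_linear_basis.
cov = (centered.T @ centered) / num_pts
U, _, _ = torch.linalg.svd(cov)
U_d = U[:, :d]  # (embed_dim, d)

# For each direction, pick k embedding rows (randomly here, greedily via OMP in the
# real code) and fit their coefficients by least squares.
indices = torch.stack([torch.randperm(num_pts)[:k] for _ in range(d)])  # (d, k)
coeffs = torch.empty(d, k)
for i in range(d):
    A = centered[indices[i]].T   # (embed_dim, k)
    b = U_d[:, i : i + 1]        # (embed_dim, 1)
    coeffs[i] = torch.linalg.lstsq(A, b).solution.squeeze(1)

# Same check as the removed debug block: how close is each sparse combination to
# the principal direction it approximates?
recon = torch.bmm(coeffs.unsqueeze(1), centered[indices]).squeeze(1)  # (d, embed_dim)
cos = torch.nn.functional.cosine_similarity(recon, U_d.T, dim=1)
print(cos)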
@@ -502,7 +331,6 @@ def compute_new_embeddings(
         ApproximationMethod.COMMON_INTERPOLATION,
         ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
         ApproximationMethod.LANDMARK_PCA,
-        ApproximationMethod.RBF,
     ):
         shared_vocab = list(
             sorted(
@@ -524,12 +352,6 @@
             donor_shared_embeds,
             orig_shared_embeds,
         )
-    elif options.method == ApproximationMethod.RBF:
-        return rbf_approximate(
-            targets,
-            donor_shared_embeds,
-            orig_shared_embeds,
-        )
     elif options.method == ApproximationMethod.COMMON_INTERPOLATION:
         indices, coeffs = common_interp_approximate(
             targets,
@@ -682,141 +504,13 @@ def build_embedding_matrix(
     return res


-def compute_token_basis(
-    orig_embed: torch.Tensor,
-    donor_embed: torch.Tensor,
-    orig_vocab: Dict[NormalizedToken, int],
-    donor_vocab: Dict[NormalizedToken, int],
-    junk_tokens: List[int],
-    options: TokenSurgeonOptions,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    common_vocab = set(orig_vocab.keys()) & set(donor_vocab.keys())
-    junk_set = set(junk_tokens)
-    common_vocab = [
-        tok
-        for tok in common_vocab
-        if (tok not in donor_vocab or donor_vocab[tok] not in junk_set)
-    ]
-    effective_dim = min(orig_embed.shape[1], donor_embed.shape[1])
-    orig_shared_embeds = orig_embed[torch.tensor([orig_vocab[t] for t in common_vocab])]
-    donor_shared_embeds = donor_embed[
-        torch.tensor([donor_vocab[t] for t in common_vocab])
-    ]
-    if donor_embed.shape[1] < orig_embed.shape[1]:
-        basis_src_embeds = donor_shared_embeds
-        LOG.debug(f"Using donor embeds to compute token basis")
-    else:
-        basis_src_embeds = orig_shared_embeds
-        LOG.debug(f"Using original embeds to compute token basis")
-    LOG.debug(f"Basis dimension: {effective_dim}")
-    tb_indices, tb_weights = sparse_linear_basis(
-        basis_src_embeds,
-        k=options.k,
-        d=effective_dim,
-    )
-    donor_basis = (
-        torch.bmm(
-            tb_weights.unsqueeze(1).to(torch.float),
-            donor_shared_embeds[tb_indices].to(torch.float),
-        )
-        .squeeze(1)
-        .to(donor_embed.dtype)
-    )
-    orig_basis = (
-        torch.bmm(
-            tb_weights.unsqueeze(1).to(torch.float),
-            orig_shared_embeds[tb_indices].to(torch.float),
-        )
-        .squeeze(1)
-        .to(orig_embed.dtype)
-    )
-    return (donor_basis, orig_basis)
-
-
 class AllowMatch(enum.Enum):
     LM_HEAD_ONLY = "lm_head"
     EMBED_ONLY = "embed"
     YES = "yes"
     NO = "no"


-def well_trained_tokens(
-    vocab: Dict[NormalizedToken, int],
-    embed: torch.Tensor,
-    lm_head: Optional[torch.Tensor],
-    known_unused: Optional[List[NormalizedToken]] = None,
-) -> List[NormalizedToken]:
-    """Get a list of tokens that are well-trained in the model."""
-    unused_indices = set(range(embed.shape[0])) - set(vocab.values())
-    if known_unused:
-        unused_indices.update(vocab[tok] for tok in known_unused if tok in vocab)
-    for tok in vocab:
-        tok_text = unnormalize_token(tok)
-        if "unused_token" in tok_text or "reserved_special_token" in tok_text:
-            LOG.debug(f"Assuming {tok_text} is unused")
-            unused_indices.add(vocab[tok])
-
-    if unused_indices:
-        mean_unused_in = embed[list(unused_indices)].mean(dim=0)
-        mean_unused_out = (
-            lm_head[list(unused_indices)].mean(dim=0) if lm_head is not None else None
-        )
-        LOG.info(f"Found {len(unused_indices)} unused tokens")
-    else:
-        mean_unused_in = None
-        mean_unused_out = None
-
-    bad_indices = set(unused_indices)
-
-    if lm_head is not None:
-        # check L2 norm of input embeddings - use 5th percentile as threshold
-        l2_norms = embed.norm(dim=1).float()
-        threshold = torch.quantile(l2_norms, 0.05, dim=0)
-        LOG.debug(f"Unused token threshold: {threshold.item():.4f} (5th percentile)")
-        l2_bad_indices = torch.where(l2_norms < threshold)[0]
-        if len(l2_bad_indices) > 0:
-            bad_indices.update(l2_bad_indices.tolist())
-            LOG.info(f"Discarding {len(l2_bad_indices)} low-l2 tokens")
-
-    if mean_unused_in is not None:
-        # check cosine similarity of input embeddings
-        cos_sim = torch.nn.functional.cosine_similarity(
-            embed, mean_unused_in.unsqueeze(0), dim=1
-        ).float()
-        threshold = torch.quantile(cos_sim, 0.9, dim=0)
-        LOG.debug(
-            f"Unused token threshold in embed_tokens: {threshold.item():.4f} (90th percentile)"
-        )
-        cos_bad_indices = torch.where(cos_sim > threshold)[0]
-        if len(cos_bad_indices) > 0:
-            bad_indices.update(cos_bad_indices.tolist())
-            LOG.info(
-                f"Discarding {len(cos_bad_indices)} high-sim to unused mean tokens"
-            )
-
-    if lm_head is not None and mean_unused_out is not None:
-        # check cosine similarity of output embeddings
-        cos_sim = torch.nn.functional.cosine_similarity(
-            lm_head, mean_unused_out.unsqueeze(0), dim=1
-        ).float()
-        threshold = torch.quantile(cos_sim, 0.9, dim=0)
-        LOG.debug(
-            f"Unused token threshold in lm_head: {threshold.item():.4f} (90th percentile)"
-        )
-        cos_bad_indices = torch.where(cos_sim > threshold)[0]
-        if len(cos_bad_indices) > 0:
-            bad_indices.update(cos_bad_indices.tolist())
-            LOG.info(
-                f"Discarding {len(cos_bad_indices)} high-sim to unused mean tokens"
-            )
-
-    good_tokens = [tok for tok, idx in vocab.items() if idx not in bad_indices]
-    LOG.info(
-        f"Found {len(good_tokens)} well-trained tokens, {len(bad_indices)} bad tokens"
-    )
-    return good_tokens
-
-
 @click.command("mergekit-tokensurgeon", cls=PrettyPrintHelp)
 @click.argument("model", type=str)
 @click.argument("donor", type=str)
