     NormalizedToken,
     normalized_vocabulary,
     token_prefixes,
-    unnormalize_token,
 )
 from mergekit.tokensurgeon import (
     SubwordMethod,
     WeightingScheme,
     batch_omp,
     common_interp_approximate,
+    compute_token_basis,
+    landmark_pca_approximate,
     subword_approximate,
+    well_trained_tokens,
 )
 from mergekit.tokensurgeon.common_interpolation import DistanceMetric
 
@@ -72,7 +74,6 @@ class ApproximationMethod(enum.Enum):
     JOHN_HEWITT = "random_matching_distribution"
     ORTHOGONAL_MATCHING_PURSUIT = "omp"
     LANDMARK_PCA = "landmark_pca"
-    RBF = "rbf"
     SPARSE_TOKEN_BASIS = "stb"
 
 
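For orientation: the enum's string values are what a configuration or CLI flag would name, so deleting a member simply makes its value unconstructible. A minimal sketch of that lookup behavior (standard Python `enum`, abbreviated member list for illustration only):

```python
import enum

class ApproximationMethod(enum.Enum):
    # Abbreviated copy of the enum above, for illustration only.
    ORTHOGONAL_MATCHING_PURSUIT = "omp"
    LANDMARK_PCA = "landmark_pca"
    SPARSE_TOKEN_BASIS = "stb"

# Lookup by value is how a string option resolves to a member;
# after this change, ApproximationMethod("rbf") raises ValueError.
method = ApproximationMethod("omp")
assert method is ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT
```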
@@ -244,116 +245,6 @@ def john_hewitt_init(orig_embed: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
     return new_embeds.to(orig_embed.dtype)
 
 
-def landmark_pca_approximate(
-    targets: torch.Tensor,
-    points_a: torch.Tensor,
-    points_b: torch.Tensor,
-) -> torch.Tensor:
-    """Given target points in space a and a set of reference points in both space a and b,
-    approximate the target points in space b."""
-    # points_a: (N, D_a)
-    # points_b: (N, D_b)
-    # 1:1 correspondence between points_a and points_b
-    # targets: (B, D_a)
-    num_points, d_a = points_a.shape
-    batch_size, _ = targets.shape
-    _, d_b = points_b.shape
-    assert (
-        points_a.shape[0] == points_b.shape[0]
-    ), "Number of points in A and B must match"
-    assert targets.shape == (batch_size, d_a)
-
-    effective_dim = min(d_a, d_b)
-
-    out_dtype = targets.dtype
-    points_a = points_a.float()
-    points_b = points_b.float()
-    targets = targets.float()
-
-    # Compute the mean of all points in A and B
-    mean_a = points_a.mean(dim=0, keepdim=True)  # (1, D_a)
-    mean_b = points_b.mean(dim=0, keepdim=True)  # (1, D_b)
-    centered_a = points_a - mean_a  # (N, D_a)
-    centered_b = points_b - mean_b  # (N, D_b)
-    centered_targets = targets - mean_a  # (B, D_a)
-
-    # Perform PCA to get the principal components
-    U_a, S_a, V_a = torch.pca_lowrank(centered_a, q=effective_dim)
-    U_b, S_b, V_b = torch.pca_lowrank(centered_b, q=effective_dim)
-
-    # Project reference points into PCA space
-    A_pca = torch.mm(centered_a, V_a)  # (N, effective_dim)
-    B_pca = torch.mm(centered_b, V_b)  # (N, effective_dim)
-
-    # Compute Procrustes matrix and solve for optimal rotation
-    M = torch.mm(B_pca.t(), A_pca)  # (effective_dim, effective_dim)
-    U, S, V = torch.svd(M)
-    R = torch.mm(U, V.t())  # (effective_dim, effective_dim)
-
-    # Transform targets through PCA spaces and rotation
-    projected_a = torch.mm(centered_targets, V_a)  # (B, effective_dim)
-    rotated = torch.mm(projected_a, R)  # (B, effective_dim)
-    projected_b = torch.mm(rotated, V_b.t())  # (B, D_b)
-
-    # Translate back to original space B
-    approximated_b = projected_b + mean_b
-
-    return approximated_b.to(out_dtype)
-
-
-def rbf_approximate(
-    targets: torch.Tensor,
-    points_a: torch.Tensor,
-    points_b: torch.Tensor,
-    epsilon: float = 1e-6,
-) -> torch.Tensor:
-    """
-    Approximate target points from space 'a' to space 'b' using RBF interpolation.
-
-    Args:
-        targets: Tensor of shape (B, D_a), points to approximate.
-        points_a: Reference points in space 'a', shape (N, D_a).
-        points_b: Corresponding points in space 'b', shape (N, D_b).
-        epsilon: Small number to ensure numerical stability.
-
-    Returns:
-        Approximate points in space 'b', tensor of shape (B, D_b).
-    """
-    N, D_a = points_a.shape
-    B, _ = targets.shape
-    _, D_b = points_b.shape
-
-    assert (
-        points_a.shape[0] == points_b.shape[0]
-    ), "points_a and points_b must have the same number of points."
-    assert (
-        targets.shape[1] == D_a
-    ), "targets and points_a must have the same dimensionality."
-
-    # Compute pairwise squared distances between points_a
-    dist_matrix = torch.cdist(points_a, points_a, p=2).pow(2)  # shape (N, N)
-
-    # Use Gaussian Radial Basis Function kernel
-    sigma = torch.median(dist_matrix) + epsilon  # heuristic sigma value
-    rbf_kernel = torch.exp(-dist_matrix / (2 * sigma**2))  # (N, N)
-
-    # Solve for weights to map from points_a to points_b
-    weights, _ = torch.lstsq(
-        points_b, rbf_kernel + epsilon * torch.eye(N, device=points_a.device)
-    )
-
-    # Compute distances between targets and points_a
-    dist_targets = torch.cdist(targets, points_a, p=2).pow(2)  # shape (B, N)
-
-    # Apply RBF kernel to target points
-    rbf_targets = torch.exp(-dist_targets / (2 * sigma**2))  # shape (B, N)
-
-    # Approximate targets in space 'b'
-    approximations = rbf_targets @ weights[:N]
-
-    return approximations
-
-
 def debug_reconstruction_for_random_tokens(
     coeffs: torch.Tensor,
     donor_shared_embeds: torch.Tensor,
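Two observations on this hunk. First, `landmark_pca_approximate` is not gone: the import block at the top of the diff now pulls it from `mergekit.tokensurgeon`, so only this in-file copy is removed. Second, the deleted `rbf_approximate` relied on `torch.lstsq`, which PyTorch deprecated in 1.9 and later removed, so the RBF path was likely already broken on current PyTorch. For reference, a sketch of the same Gaussian-RBF interpolation idea on modern `torch.linalg` (my reconstruction under those assumptions, not mergekit code):

```python
import torch

def rbf_approximate_sketch(
    targets: torch.Tensor,   # (B, D_a) points to map into space b
    points_a: torch.Tensor,  # (N, D_a) reference points in space a
    points_b: torch.Tensor,  # (N, D_b) corresponding points in space b
    epsilon: float = 1e-6,
) -> torch.Tensor:
    # Squared pairwise distances between the reference points.
    sq_dists = torch.cdist(points_a, points_a).pow(2)  # (N, N)
    # Median heuristic for the kernel bandwidth, as in the removed code.
    sigma = sq_dists.median() + epsilon
    kernel = torch.exp(-sq_dists / (2 * sigma**2))  # (N, N)
    # Ridge-regularized solve stands in for the removed torch.lstsq call.
    n = points_a.shape[0]
    eye = torch.eye(n, dtype=kernel.dtype, device=kernel.device)
    weights = torch.linalg.solve(kernel + epsilon * eye, points_b)  # (N, D_b)
    # Evaluate the kernel between targets and references, then interpolate.
    sq_t = torch.cdist(targets, points_a).pow(2)  # (B, N)
    return torch.exp(-sq_t / (2 * sigma**2)) @ weights  # (B, D_b)
```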
@@ -397,8 +288,6 @@ def debug_reconstruction_for_random_tokens(
         )
         donor_tok_embed = donor_embed[donor_vocab[target_tokens[i]]]
         reconstructed = reconstructed_in_donor[i]
-        err_rms = (reconstructed - donor_tok_embed).norm()
-        err_rel = err_rms / donor_tok_embed.norm().clamp_min(1e-6)
         cos_sim = torch.nn.functional.cosine_similarity(
             donor_tok_embed,
             reconstructed,
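As far as this hunk shows, the two deleted values were computed but never printed; the surviving code only reports cosine similarity, so this reads as dead-code cleanup. Should the relative-error diagnostic ever be wanted again, it is just (hypothetical re-addition, same tensors as above):

```python
# Absolute and norm-relative reconstruction error for one token embedding,
# matching the deleted lines; clamp_min guards against a zero-norm row.
err_rms = (reconstructed - donor_tok_embed).norm()
err_rel = err_rms / donor_tok_embed.norm().clamp_min(1e-6)
```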
@@ -408,66 +297,6 @@ def debug_reconstruction_for_random_tokens(
         print()
 
 
-def sparse_linear_basis(
-    embeddings: torch.Tensor,
-    k: int,
-    d: int,
-    eps: float = 1e-8,
-) -> Tuple[torch.LongTensor, torch.Tensor]:
-    """
-    Form an approximate orthogonal basis from sparse linear combinations of the input embeddings.
-    Args:
-        embeddings: (num_pts, embed_dim) tensor of embeddings
-        k: number of points to select per basis vector
-        d: dimensionality of the basis
-        eps: numerical stability parameter
-    Returns:
-        indices: (d, k) tensor of selected indices
-        coeffs: (d, k) tensor of coefficients for each selected point
-    """
-    assert embeddings.dim() == 2
-    num_pts, embed_dim = embeddings.shape
-    assert k <= num_pts, "k must be less than or equal to the number of points"
-    assert d <= embed_dim, "d must be less than or equal to the embedding dimension"
-
-    mean_embed = embeddings.mean(dim=0)
-    centered_embeddings = (embeddings - mean_embed).to(torch.float32)
-    covariance_matrix = (
-        centered_embeddings.T @ centered_embeddings
-    ) / num_pts  # (embed_dim, embed_dim)
-
-    U, S, V = torch.linalg.svd(covariance_matrix)
-    # Select the top d singular vectors
-    U_d = U[:, :d]  # (embed_dim, d)
-    V_d = V[:, :d]  # (embed_dim, d)
-    S_d = S[:d]  # (d,)
-
-    # use OMP to approximate the singular vectors
-    indices, coeffs = batch_omp(
-        U_d.t(),  # (d, embed_dim)
-        centered_embeddings,  # (num_pts, embed_dim)
-        k,
-        eps=eps,
-    )
-
-    if LOG.isEnabledFor(logging.DEBUG):
-        rc_basis = torch.bmm(
-            coeffs.unsqueeze(1).to(torch.float),
-            centered_embeddings[indices].to(torch.float),
-        ).squeeze(1)
-        for i in range(d):
-            v_0 = U_d[:, i]
-            v_1 = rc_basis[i]
-            cos_sim = torch.nn.functional.cosine_similarity(v_0, v_1, dim=0)
-            rms = torch.norm(v_0 - v_1)
-            norm_rms = torch.norm(v_0 - (v_1 / v_1.norm().clamp_min(1e-6)))
-            LOG.debug(
-                f"Basis vector {i}: cos_sim = {cos_sim.item():.4f}, RMS = {rms.item():.4f}, norm_rms = {norm_rms.item():.4f}"
-            )
-
-    return indices, coeffs
-
-
 def compute_new_embeddings(
     orig_embed: torch.Tensor,
     donor_embed: torch.Tensor,
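The removed `sparse_linear_basis` (its functionality now arrives through the imported `compute_token_basis`) expresses the top-d principal directions of an embedding matrix as k-sparse combinations of actual embedding rows, with OMP doing the support selection. A self-contained sketch of the idea, substituting a simple greedy pick plus least squares for `batch_omp` (an illustration of the technique, not the mergekit implementation):

```python
import torch

def sparse_basis_sketch(embeddings: torch.Tensor, k: int, d: int):
    """Approximate top-d principal directions as k-sparse row combinations.

    Simplified stand-in for the removed sparse_linear_basis: pick the k
    rows most aligned with each principal direction, then fit their
    coefficients by least squares (real OMP re-selects the support
    greedily against the residual).
    """
    centered = (embeddings - embeddings.mean(dim=0)).float()
    # Right singular vectors of the centered matrix = principal directions.
    _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
    directions = Vh[:d]  # (d, embed_dim)
    indices, coeffs = [], []
    for v in directions:
        scores = (centered @ v).abs()  # alignment of each row with v
        idx = scores.topk(k).indices   # support: the k best-aligned rows
        A = centered[idx]              # (k, embed_dim)
        # Solve min_w ||A^T w - v||, so that w @ A approximates v.
        w = torch.linalg.lstsq(A.T, v.unsqueeze(1)).solution.squeeze(1)
        indices.append(idx)
        coeffs.append(w)
    return torch.stack(indices), torch.stack(coeffs)  # (d, k) each
```

The `(d, k)` index/coefficient pair is exactly the shape contract documented in the removed docstring.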
@@ -502,7 +331,6 @@ def compute_new_embeddings(
         ApproximationMethod.COMMON_INTERPOLATION,
         ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT,
         ApproximationMethod.LANDMARK_PCA,
-        ApproximationMethod.RBF,
     ):
         shared_vocab = list(
             sorted(
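All the methods gated here work from anchor tokens shared by both tokenizers. The hunk cuts off mid-expression, so the exact sort key is not visible, but the construction is presumably an ordered intersection along these lines (a hedged sketch; the key is an assumption):

```python
from typing import Dict
import torch

def shared_anchor_embeddings(
    orig_embed: torch.Tensor,
    donor_embed: torch.Tensor,
    orig_vocab: Dict[str, int],
    donor_vocab: Dict[str, int],
):
    # Deterministic ordering of the tokens present in both vocabularies.
    # Sorting by original-vocab index is an assumption; the real key is
    # not shown in this hunk.
    shared = sorted(set(orig_vocab) & set(donor_vocab), key=lambda t: orig_vocab[t])
    orig_rows = orig_embed[torch.tensor([orig_vocab[t] for t in shared])]
    donor_rows = donor_embed[torch.tensor([donor_vocab[t] for t in shared])]
    return shared, orig_rows, donor_rows
```

The same gather pattern appears verbatim in the removed `compute_token_basis` below.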
@@ -524,12 +352,6 @@ def compute_new_embeddings(
             donor_shared_embeds,
             orig_shared_embeds,
         )
-    elif options.method == ApproximationMethod.RBF:
-        return rbf_approximate(
-            targets,
-            donor_shared_embeds,
-            orig_shared_embeds,
-        )
     elif options.method == ApproximationMethod.COMMON_INTERPOLATION:
         indices, coeffs = common_interp_approximate(
             targets,
@@ -682,141 +504,13 @@ def build_embedding_matrix(
     return res
 
 
-def compute_token_basis(
-    orig_embed: torch.Tensor,
-    donor_embed: torch.Tensor,
-    orig_vocab: Dict[NormalizedToken, int],
-    donor_vocab: Dict[NormalizedToken, int],
-    junk_tokens: List[int],
-    options: TokenSurgeonOptions,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    common_vocab = set(orig_vocab.keys()) & set(donor_vocab.keys())
-    junk_set = set(junk_tokens)
-    common_vocab = [
-        tok
-        for tok in common_vocab
-        if (tok not in donor_vocab or donor_vocab[tok] not in junk_set)
-    ]
-    effective_dim = min(orig_embed.shape[1], donor_embed.shape[1])
-    orig_shared_embeds = orig_embed[torch.tensor([orig_vocab[t] for t in common_vocab])]
-    donor_shared_embeds = donor_embed[
-        torch.tensor([donor_vocab[t] for t in common_vocab])
-    ]
-    if donor_embed.shape[1] < orig_embed.shape[1]:
-        basis_src_embeds = donor_shared_embeds
-        LOG.debug(f"Using donor embeds to compute token basis")
-    else:
-        basis_src_embeds = orig_shared_embeds
-        LOG.debug(f"Using original embeds to compute token basis")
-    LOG.debug(f"Basis dimension: {effective_dim}")
-    tb_indices, tb_weights = sparse_linear_basis(
-        basis_src_embeds,
-        k=options.k,
-        d=effective_dim,
-    )
-    donor_basis = (
-        torch.bmm(
-            tb_weights.unsqueeze(1).to(torch.float),
-            donor_shared_embeds[tb_indices].to(torch.float),
-        )
-        .squeeze(1)
-        .to(donor_embed.dtype)
-    )
-    orig_basis = (
-        torch.bmm(
-            tb_weights.unsqueeze(1).to(torch.float),
-            orig_shared_embeds[tb_indices].to(torch.float),
-        )
-        .squeeze(1)
-        .to(orig_embed.dtype)
-    )
-    return (donor_basis, orig_basis)
-
-
 class AllowMatch(enum.Enum):
     LM_HEAD_ONLY = "lm_head"
     EMBED_ONLY = "embed"
     YES = "yes"
     NO = "no"
 
 
-def well_trained_tokens(
-    vocab: Dict[NormalizedToken, int],
-    embed: torch.Tensor,
-    lm_head: Optional[torch.Tensor],
-    known_unused: Optional[List[NormalizedToken]] = None,
-) -> List[NormalizedToken]:
-    """Get a list of tokens that are well-trained in the model."""
-    unused_indices = set(range(embed.shape[0])) - set(vocab.values())
-    if known_unused:
-        unused_indices.update(vocab[tok] for tok in known_unused if tok in vocab)
-    for tok in vocab:
-        tok_text = unnormalize_token(tok)
-        if "unused_token" in tok_text or "reserved_special_token" in tok_text:
-            LOG.debug(f"Assuming {tok_text} is unused")
-            unused_indices.add(vocab[tok])
-
-    if unused_indices:
-        mean_unused_in = embed[list(unused_indices)].mean(dim=0)
-        mean_unused_out = (
-            lm_head[list(unused_indices)].mean(dim=0) if lm_head is not None else None
-        )
-        LOG.info(f"Found {len(unused_indices)} unused tokens")
-    else:
-        mean_unused_in = None
-        mean_unused_out = None
-
-    bad_indices = set(unused_indices)
-
-    if lm_head is not None:
-        # check L2 norm of input embeddings - use 5th percentile as threshold
-        l2_norms = embed.norm(dim=1).float()
-        threshold = torch.quantile(l2_norms, 0.05, dim=0)
-        LOG.debug(f"Unused token threshold: {threshold.item():.4f} (5th percentile)")
-        l2_bad_indices = torch.where(l2_norms < threshold)[0]
-        if len(l2_bad_indices) > 0:
-            bad_indices.update(l2_bad_indices.tolist())
-            LOG.info(f"Discarding {len(l2_bad_indices)} low-l2 tokens")
-
-    if mean_unused_in is not None:
-        # check cosine similarity of input embeddings
-        cos_sim = torch.nn.functional.cosine_similarity(
-            embed, mean_unused_in.unsqueeze(0), dim=1
-        ).float()
-        threshold = torch.quantile(cos_sim, 0.9, dim=0)
-        LOG.debug(
-            f"Unused token threshold in embed_tokens: {threshold.item():.4f} (90th percentile)"
-        )
-        cos_bad_indices = torch.where(cos_sim > threshold)[0]
-        if len(cos_bad_indices) > 0:
-            bad_indices.update(cos_bad_indices.tolist())
-            LOG.info(
-                f"Discarding {len(cos_bad_indices)} high-sim to unused mean tokens"
-            )
-
-    if lm_head is not None and mean_unused_out is not None:
-        # check cosine similarity of output embeddings
-        cos_sim = torch.nn.functional.cosine_similarity(
-            lm_head, mean_unused_out.unsqueeze(0), dim=1
-        ).float()
-        threshold = torch.quantile(cos_sim, 0.9, dim=0)
-        LOG.debug(
-            f"Unused token threshold in lm_head: {threshold.item():.4f} (90th percentile)"
-        )
-        cos_bad_indices = torch.where(cos_sim > threshold)[0]
-        if len(cos_bad_indices) > 0:
-            bad_indices.update(cos_bad_indices.tolist())
-            LOG.info(
-                f"Discarding {len(cos_bad_indices)} high-sim to unused mean tokens"
-            )
-
-    good_tokens = [tok for tok, idx in vocab.items() if idx not in bad_indices]
-    LOG.info(
-        f"Found {len(good_tokens)} well-trained tokens, {len(bad_indices)} bad tokens"
-    )
-    return good_tokens
-
-
 @click.command("mergekit-tokensurgeon", cls=PrettyPrintHelp)
 @click.argument("model", type=str)
 @click.argument("donor", type=str)
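As with `landmark_pca_approximate`, both functions deleted in this final hunk reappear in the new import list (`compute_token_basis`, `well_trained_tokens`), so the commit reads as a relocation into the `mergekit.tokensurgeon` package rather than a feature removal. The `well_trained_tokens` heuristics merit a compact illustration: a token is treated as suspect when its input-embedding L2 norm falls below the 5th percentile, or when it sits unusually close to the mean of known-unused rows. A toy demonstration of the norm filter (illustrative sketch only, not the mergekit API):

```python
import torch

def low_norm_token_indices(embed: torch.Tensor, q: float = 0.05) -> torch.Tensor:
    """Indices of embedding rows whose L2 norm is below the q-quantile.

    Mirrors the percentile filter in well_trained_tokens: rows that were
    rarely or never updated during training tend to keep the small norms
    of their initialization.
    """
    l2 = embed.norm(dim=1).float()         # (vocab_size,)
    threshold = torch.quantile(l2, q)      # scalar cutoff
    return torch.where(l2 < threshold)[0]  # indices of suspect tokens

# Toy usage: four healthy rows plus one near-zero "untrained" row.
emb = torch.cat([torch.randn(4, 8), torch.full((1, 8), 1e-4)])
print(low_norm_token_indices(emb, q=0.2))  # expected: tensor([4])
```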