@@ -73,6 +73,13 @@ class WeightingScheme(enum.Enum):
     LEAST_SQUARES = "least_squares"
 
 
+class SubwordMethod(enum.Enum):
+    MEAN = "mean"
+    SUM = "sum"
+    WEIGHTED_MEAN = "weighted_mean"
+    FIRST_LAST = "first_last"
+
+
 def approximate_from_landmarks(
     targets: torch.Tensor,
     points: torch.Tensor,
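
For reference, a minimal standalone sketch (toy tensors, not part of the patch) of what each SubwordMethod computes over the subword embeddings of a single new token:

    import torch

    pieces = torch.randn(3, 8)  # embeddings of a token's 3 subword pieces, hidden size 8

    mean_vec = pieces.mean(dim=0)                              # MEAN
    sum_vec = pieces.sum(dim=0)                                # SUM
    w = torch.arange(1, 4, dtype=pieces.dtype)                 # 1..n ramp over the pieces
    weighted = (w.unsqueeze(1) * pieces).sum(dim=0) / w.sum()  # WEIGHTED_MEAN
    first_vec, last_vec = pieces[0], pieces[-1]                # FIRST_LAST picks one of these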
@@ -304,7 +311,7 @@ class TokenSurgeonOptions(BaseModel):
     k: int = 8
     knn: bool = True
     cosine_similarity: bool = False
-    average: bool = True
+    subword_method: SubwordMethod = SubwordMethod.MEAN
     batch_size: Optional[int] = None
 
 
@@ -447,6 +454,7 @@ def get_out_arch_info(
 def subword_approximate(
     orig_embed: torch.Tensor,
     target_tokens: List[NormalizedToken],
+    is_lm_head: bool,
     options: TokenSurgeonOptions,
 ) -> torch.Tensor:
     res = torch.zeros(
@@ -463,10 +471,31 @@ def subword_approximate(
     for idx, token in enumerate(target_tokens):
         text = unnormalize_token(token)
         token_ids = tok_0(text, add_special_tokens=False)["input_ids"]
-        for id in token_ids:
-            res[idx] += orig_embed[id]
-        if options.average and len(token_ids) > 0:
-            res[idx] /= len(token_ids)
+
+        if options.subword_method in (SubwordMethod.MEAN, SubwordMethod.SUM):
+            for id in token_ids:
+                res[idx] += orig_embed[id]
+            if options.subword_method == SubwordMethod.MEAN and len(token_ids) > 0:
+                res[idx] /= len(token_ids)
+        elif options.subword_method == SubwordMethod.WEIGHTED_MEAN:
+            weights = list(range(1, len(token_ids) + 1))
+            if is_lm_head:
+                # for lm_head, want first token to have highest weight
+                # (vs. last token for embed_tokens)
+                weights = weights[::-1]
+            for id, weight in zip(token_ids, weights):
+                res[idx] += weight * orig_embed[id]
+            if len(token_ids) > 0:
+                res[idx] /= sum(weights)
+        elif options.subword_method == SubwordMethod.FIRST_LAST:
+            if len(token_ids) == 0:
+                continue
+            if is_lm_head:
+                res[idx] = orig_embed[token_ids[0]]
+            else:
+                res[idx] = orig_embed[token_ids[-1]]
+        else:
+            raise ValueError(f"Unknown subword method: {options.subword_method}")
     return res
 
 
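As a sanity check on the direction of the position weighting, a toy rewrite of the WEIGHTED_MEAN ramp (hypothetical helper, not in the patch), consistent with the FIRST_LAST branch:

    from typing import List

    def position_weights(n: int, is_lm_head: bool) -> List[int]:
        # Ramp 1..n; the last piece dominates for embed_tokens,
        # the first piece for lm_head, mirroring FIRST_LAST.
        w = list(range(1, n + 1))
        return w[::-1] if is_lm_head else w

    assert position_weights(3, is_lm_head=False) == [1, 2, 3]  # embed_tokens: last heaviest
    assert position_weights(3, is_lm_head=True) == [3, 2, 1]   # lm_head: first heaviest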
@@ -476,6 +505,7 @@ def compute_new_embeddings(
     orig_vocab: Dict[NormalizedToken, int],
     donor_vocab: Dict[NormalizedToken, int],
     target_tokens: List[NormalizedToken],
+    is_lm_head: bool,
     options: TokenSurgeonOptions,
 ) -> torch.Tensor:
     assert all(t in donor_vocab for t in target_tokens)
@@ -524,21 +554,15 @@ def compute_new_embeddings(
524
554
torch .tensor ([orig_vocab [t ] for t in shared_vocab ])
525
555
]
526
556
targets = donor_embed [torch .tensor ([donor_vocab [t ] for t in target_tokens ])]
527
- print (
528
- f"OMP: { len (shared_vocab )} shared tokens, { len (target_tokens )} targets, k={ options .k } "
529
- )
530
557
indices , coeffs = batch_omp (targets , donor_shared_embeds , options .k )
531
- print (f"OMP: coeffs shape { coeffs .shape } , indices shape { indices .shape } " )
532
558
res = (
533
559
torch .bmm (coeffs .unsqueeze (1 ), orig_shared_embeds [indices ].to (torch .float ))
534
560
.squeeze (1 )
535
561
.to (orig_embed .dtype )
536
562
)
537
- print (f"OMP: res shape { res .shape } " )
538
- print (repr (res ))
539
563
return res
540
564
elif options .method == ApproximationMethod .SUBWORD :
541
- return subword_approximate (orig_embed , target_tokens , options )
565
+ return subword_approximate (orig_embed , target_tokens , is_lm_head , options )
542
566
else :
543
567
raise ValueError (f"Unknown approximation method: { options .method } " )
544
568
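The torch.bmm call above contracts each target's k OMP coefficients against its k gathered shared-token embeddings; an equivalent elementwise form (toy shapes, hypothetical names) is:

    import torch

    n, k, d = 5, 8, 16                 # targets, sparsity level, hidden size
    coeffs = torch.randn(n, k)         # batch_omp coefficients per target
    gathered = torch.randn(n, k, d)    # stand-in for orig_shared_embeds[indices]

    via_bmm = torch.bmm(coeffs.unsqueeze(1), gathered).squeeze(1)  # (n,1,k) @ (n,k,d) -> (n,d)
    via_sum = (coeffs.unsqueeze(-1) * gathered).sum(dim=1)
    assert torch.allclose(via_bmm, via_sum, atol=1e-5)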
@@ -551,6 +575,7 @@ def build_embedding_matrix(
     donor_vocab: Dict[NormalizedToken, int],
     allow_prefix: bool,
     allow_byte: bool,
+    is_lm_head: bool,
     options: TokenSurgeonOptions,
 ) -> torch.Tensor:
     LOG.info(f"Building new tensor for {weight_info.name}")
@@ -594,6 +619,7 @@ def build_embedding_matrix(
             orig_vocab,
             donor_vocab,
             new_tokens[base_idx : base_idx + batch_size],
+            is_lm_head,
             options,
         )
         for ne_idx, token in enumerate(
@@ -647,10 +673,11 @@ def build_embedding_matrix(
     show_default=True,
 )
 @click.option(
-    "--average/--no-average",
-    is_flag=True,
-    default=True,
-    help="Use average instead of sum for subword embedding approximation",
+    "--subword-method",
+    "-s",
+    type=click.Choice([m.value for m in SubwordMethod]),
+    default=SubwordMethod.MEAN.value,
+    help="Method for approximating embeddings with subword tokens",
     show_default=True,
 )
 @click.option(
@@ -670,7 +697,7 @@ def main(
     cosine_similarity: bool,
     approximation_method: str,
     weight_scheme: str,
-    average: bool,
+    subword_method: str,
     batch_size: Optional[int],
     merge_options: MergeOptions,
 ):
@@ -685,7 +712,7 @@ def main(
         cosine_similarity=cosine_similarity,
         method=ApproximationMethod(approximation_method),
         weight_scheme=WeightingScheme(weight_scheme),
-        average=average,
+        subword_method=SubwordMethod(subword_method),
         batch_size=batch_size,
     )
 
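Click hands main() the raw string chosen via --subword-method; the SubwordMethod(...) call above round-trips it to the enum. A quick check of that wiring, assuming the enum as defined earlier in this diff:

    assert [m.value for m in SubwordMethod] == ["mean", "sum", "weighted_mean", "first_last"]
    assert SubwordMethod("weighted_mean") is SubwordMethod.WEIGHTED_MEAN
    assert SubwordMethod.MEAN.value == "mean"  # the click default shown above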
@@ -716,6 +743,7 @@ def main(
             donor_vocab=donor_vocab,
             allow_prefix=False,
             allow_byte=True,
+            is_lm_head=False,
             options=options,
         )
     else:
@@ -738,6 +766,7 @@ def main(
             donor_vocab=donor_vocab,
             allow_prefix=True,
             allow_byte=True,
+            is_lm_head=True,
             options=options,
         )
     else: