@@ -555,6 +555,7 @@ def compute_new_embeddings(
555
555
]
556
556
targets = donor_embed [torch .tensor ([donor_vocab [t ] for t in target_tokens ])]
557
557
indices , coeffs = batch_omp (targets , donor_shared_embeds , options .k )
558
+
558
559
res = (
559
560
torch .bmm (coeffs .unsqueeze (1 ), orig_shared_embeds [indices ].to (torch .float ))
560
561
.squeeze (1 )
@@ -660,7 +661,7 @@ def build_embedding_matrix(
660
661
"--approximation-method" ,
661
662
"-a" ,
662
663
type = click .Choice ([m .value for m in ApproximationMethod ]),
663
- default = ApproximationMethod .COMMON_INTERPOLATION .value ,
664
+ default = ApproximationMethod .ORTHOGONAL_MATCHING_PURSUIT .value ,
664
665
help = "Method for approximating missing tokens" ,
665
666
show_default = True ,
666
667
)
@@ -669,7 +670,7 @@ def build_embedding_matrix(
669
670
"-w" ,
670
671
type = click .Choice ([w .value for w in WeightingScheme ]),
671
672
default = WeightingScheme .DISTANCE_PROPORTIONAL .value ,
672
- help = "Weighting scheme for KNN interpolation" ,
673
+ help = "Weighting scheme for common-vocabulary interpolation" ,
673
674
show_default = True ,
674
675
)
675
676
@click .option (
@@ -690,7 +691,7 @@ def build_embedding_matrix(
690
691
@click .option (
691
692
"--allow-lm-head-prefix-match/--no-allow-lm-head-prefix-match" ,
692
693
is_flag = True ,
693
- default = True ,
694
+ default = False ,
694
695
help = "Allow prefix matches for LM head tokens" ,
695
696
show_default = True ,
696
697
)
0 commit comments