@@ -88,6 +88,8 @@ class TokenSurgeonOptions(BaseModel):
88
88
cosine_similarity : bool = False
89
89
subword_method : SubwordMethod = SubwordMethod .MEAN
90
90
batch_size : Optional [int ] = None
91
+ new_vocab_noise : Optional [float ] = None
92
+ new_vocab_scale : Optional [float ] = None
91
93
92
94
93
95
def get_arch_info (
@@ -489,6 +491,10 @@ def build_embedding_matrix(
489
491
token_basis = token_basis ,
490
492
options = options ,
491
493
)
494
+ if options .new_vocab_noise :
495
+ new_embeds += torch .randn_like (new_embeds ) * options .new_vocab_noise
496
+ if options .new_vocab_scale :
497
+ new_embeds *= options .new_vocab_scale
492
498
for ne_idx , token in enumerate (
493
499
new_tokens [base_idx : base_idx + batch_size ]
494
500
):
@@ -592,6 +598,22 @@ class AllowMatch(enum.Enum):
592
598
help = "Filter out poorly trained tokens" ,
593
599
show_default = True ,
594
600
)
601
+ @click .option (
602
+ "--new-vocab-noise" ,
603
+ "-nvn" ,
604
+ type = float ,
605
+ default = None ,
606
+ help = "Add gaussian noise to new vocab embeddings" ,
607
+ show_default = True ,
608
+ )
609
+ @click .option (
610
+ "--new-vocab-scale" ,
611
+ "-nvs" ,
612
+ type = float ,
613
+ default = None ,
614
+ help = "Scale computed new vocab embeddings by this factor" ,
615
+ show_default = True ,
616
+ )
595
617
@add_merge_options
596
618
def main (
597
619
model : str ,
@@ -607,6 +629,8 @@ def main(
607
629
prefix_match : str ,
608
630
byte_match : str ,
609
631
magikarp : bool ,
632
+ new_vocab_noise : Optional [float ],
633
+ new_vocab_scale : Optional [float ],
610
634
merge_options : MergeOptions ,
611
635
):
612
636
merge_options .apply_global_options ()
@@ -622,6 +646,8 @@ def main(
622
646
weight_scheme = WeightingScheme (weight_scheme ),
623
647
subword_method = SubwordMethod (subword_method ),
624
648
batch_size = batch_size ,
649
+ new_vocab_noise = new_vocab_noise ,
650
+ new_vocab_scale = new_vocab_scale ,
625
651
)
626
652
prefix_match = AllowMatch (prefix_match )
627
653
byte_match = AllowMatch (byte_match )
0 commit comments