@@ -162,8 +162,6 @@ def common_interp_approximate(
162
162
targets .size (0 ), - 1
163
163
)
164
164
knn_distances = distances
165
- print (f"indices: { indices .shape } " )
166
- print (f"knn_distances: { knn_distances .shape } " )
167
165
168
166
weights = approximate_from_landmarks (
169
167
targets ,
@@ -240,6 +238,7 @@ class TokenSurgeonOptions(BaseModel):
240
238
knn : bool = True
241
239
cosine_similarity : bool = False
242
240
average : bool = True
241
+ batch_size : Optional [int ] = None
243
242
244
243
245
244
def get_arch_info (
@@ -480,11 +479,23 @@ def build_embedding_matrix(
480
479
LOG .info (stats .pretty_print ())
481
480
if new_tokens :
482
481
LOG .info (f"Approximating { len (new_tokens )} tokens" )
483
- new_embeds = compute_new_embeddings (
484
- orig_embed , donor_embed , orig_vocab , donor_vocab , new_tokens , options
485
- )
486
- for ne_idx , token in enumerate (new_tokens ):
487
- res [donor_vocab [token ]] = new_embeds [ne_idx ]
482
+ batch_size = options .batch_size or len (new_tokens )
483
+ for base_idx in tqdm .tqdm (
484
+ range (0 , len (new_tokens ), batch_size ),
485
+ desc = "Approximating tokens" ,
486
+ ):
487
+ new_embeds = compute_new_embeddings (
488
+ orig_embed ,
489
+ donor_embed ,
490
+ orig_vocab ,
491
+ donor_vocab ,
492
+ new_tokens [base_idx : base_idx + batch_size ],
493
+ options ,
494
+ )
495
+ for ne_idx , token in enumerate (
496
+ new_tokens [base_idx : base_idx + batch_size ]
497
+ ):
498
+ res [donor_vocab [token ]] = new_embeds [ne_idx ]
488
499
return res
489
500
490
501
@@ -538,6 +549,13 @@ def build_embedding_matrix(
538
549
help = "Use average instead of sum for subword embedding approximation" ,
539
550
show_default = True ,
540
551
)
552
+ @click .option (
553
+ "--batch-size" ,
554
+ type = int ,
555
+ default = None ,
556
+ help = "Number of tokens to process in each batch" ,
557
+ show_default = True ,
558
+ )
541
559
@add_merge_options
542
560
def main (
543
561
model : str ,
@@ -549,6 +567,7 @@ def main(
549
567
approximation_method : str ,
550
568
weight_scheme : str ,
551
569
average : bool ,
570
+ batch_size : Optional [int ],
552
571
merge_options : MergeOptions ,
553
572
):
554
573
merge_options .apply_global_options ()
@@ -563,6 +582,7 @@ def main(
563
582
method = ApproximationMethod (approximation_method ),
564
583
weight_scheme = WeightingScheme (weight_scheme ),
565
584
average = average ,
585
+ batch_size = batch_size ,
566
586
)
567
587
568
588
cache = LoaderCache ()
0 commit comments