Skip to content

Commit 5fa8660

Browse files
committed
Specify batch size
1 parent 1d683c0 commit 5fa8660

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

mergekit/scripts/tokensurgeon.py

+27-7
Original file line number | Diff line number | Diff line change
@@ -162,8 +162,6 @@ def common_interp_approximate(
162162
targets.size(0), -1
163163
)
164164
knn_distances = distances
165-
print(f"indices: {indices.shape}")
166-
print(f"knn_distances: {knn_distances.shape}")
167165

168166
weights = approximate_from_landmarks(
169167
targets,
@@ -240,6 +238,7 @@ class TokenSurgeonOptions(BaseModel):
240238
knn: bool = True
241239
cosine_similarity: bool = False
242240
average: bool = True
241+
batch_size: Optional[int] = None
243242

244243

245244
def get_arch_info(
@@ -480,11 +479,23 @@ def build_embedding_matrix(
480479
LOG.info(stats.pretty_print())
481480
if new_tokens:
482481
LOG.info(f"Approximating {len(new_tokens)} tokens")
483-
new_embeds = compute_new_embeddings(
484-
orig_embed, donor_embed, orig_vocab, donor_vocab, new_tokens, options
485-
)
486-
for ne_idx, token in enumerate(new_tokens):
487-
res[donor_vocab[token]] = new_embeds[ne_idx]
482+
batch_size = options.batch_size or len(new_tokens)
483+
for base_idx in tqdm.tqdm(
484+
range(0, len(new_tokens), batch_size),
485+
desc="Approximating tokens",
486+
):
487+
new_embeds = compute_new_embeddings(
488+
orig_embed,
489+
donor_embed,
490+
orig_vocab,
491+
donor_vocab,
492+
new_tokens[base_idx : base_idx + batch_size],
493+
options,
494+
)
495+
for ne_idx, token in enumerate(
496+
new_tokens[base_idx : base_idx + batch_size]
497+
):
498+
res[donor_vocab[token]] = new_embeds[ne_idx]
488499
return res
489500

490501

@@ -538,6 +549,13 @@ def build_embedding_matrix(
538549
help="Use average instead of sum for subword embedding approximation",
539550
show_default=True,
540551
)
552+
@click.option(
553+
"--batch-size",
554+
type=int,
555+
default=None,
556+
help="Number of tokens to process in each batch",
557+
show_default=True,
558+
)
541559
@add_merge_options
542560
def main(
543561
model: str,
@@ -549,6 +567,7 @@ def main(
549567
approximation_method: str,
550568
weight_scheme: str,
551569
average: bool,
570+
batch_size: Optional[int],
552571
merge_options: MergeOptions,
553572
):
554573
merge_options.apply_global_options()
@@ -563,6 +582,7 @@ def main(
563582
method=ApproximationMethod(approximation_method),
564583
weight_scheme=WeightingScheme(weight_scheme),
565584
average=average,
585+
batch_size=batch_size,
566586
)
567587

568588
cache = LoaderCache()

0 commit comments

Comments
 (0)