Skip to content

Commit 5c036bf

Browse files
committed
Fix mean and zero approximation methods
1 parent bb03596 commit 5c036bf

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

mergekit/scripts/tokensurgeon.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ def __eq__(self, other):
274274
class ZeroTensorTask(Task[torch.Tensor]):
275275
shape: Tuple[int, ...]
276276

277+
def arguments(self):
278+
return {}
279+
277280
def uses_accelerator(self):
278281
return True
279282

@@ -447,10 +450,6 @@ def plan_embedding(
447450
tied_names=weight_info.tied_names,
448451
force_main_thread=True,
449452
)
450-
# e_c_0 = torch.stack(
451-
# [original_embed[original_vocab[token]] for token in common_tokens]
452-
# )
453-
# e_c_1 = torch.stack([donor_embed[donor_vocab[token]] for token in common_tokens])
454453
t_e_c_0 = MultiIndexedEmbeddingTask(
455454
embeddings=t_original_embed,
456455
indices=tuple(original_vocab[token] for token in common_tokens),
@@ -459,7 +458,7 @@ def plan_embedding(
459458
embeddings=t_donor_embed,
460459
indices=tuple(donor_vocab[token] for token in common_tokens),
461460
)
462-
mean_donor_embed_task = EmbeddingMeanTask(embeddings=t_donor_embed)
461+
mean_embed_task = EmbeddingMeanTask(embeddings=t_original_embed)
463462

464463
stats = TokenAssignmentStats()
465464
embedding_tasks = []
@@ -519,7 +518,7 @@ def plan_embedding(
519518
average=options.average,
520519
)
521520
elif options.method == ApproximationMethod.MEAN:
522-
tok_embedding_task = mean_donor_embed_task
521+
tok_embedding_task = mean_embed_task
523522
elif options.method == ApproximationMethod.ZERO:
524523
tok_embedding_task = ZeroTensorTask(shape=(hidden_size,))
525524
else:

0 commit comments

Comments
 (0)