@@ -274,6 +274,9 @@ def __eq__(self, other):
274
274
class ZeroTensorTask (Task [torch .Tensor ]):
275
275
shape : Tuple [int , ...]
276
276
277
+ def arguments (self ):
278
+ return {}
279
+
277
280
def uses_accelerator (self ):
278
281
return True
279
282
@@ -447,10 +450,6 @@ def plan_embedding(
447
450
tied_names = weight_info .tied_names ,
448
451
force_main_thread = True ,
449
452
)
450
- # e_c_0 = torch.stack(
451
- # [original_embed[original_vocab[token]] for token in common_tokens]
452
- # )
453
- # e_c_1 = torch.stack([donor_embed[donor_vocab[token]] for token in common_tokens])
454
453
t_e_c_0 = MultiIndexedEmbeddingTask (
455
454
embeddings = t_original_embed ,
456
455
indices = tuple (original_vocab [token ] for token in common_tokens ),
@@ -459,7 +458,7 @@ def plan_embedding(
459
458
embeddings = t_donor_embed ,
460
459
indices = tuple (donor_vocab [token ] for token in common_tokens ),
461
460
)
462
- mean_donor_embed_task = EmbeddingMeanTask (embeddings = t_donor_embed )
461
+ mean_embed_task = EmbeddingMeanTask (embeddings = t_original_embed )
463
462
464
463
stats = TokenAssignmentStats ()
465
464
embedding_tasks = []
@@ -519,7 +518,7 @@ def plan_embedding(
519
518
average = options .average ,
520
519
)
521
520
elif options .method == ApproximationMethod .MEAN :
522
- tok_embedding_task = mean_donor_embed_task
521
+ tok_embedding_task = mean_embed_task
523
522
elif options .method == ApproximationMethod .ZERO :
524
523
tok_embedding_task = ZeroTensorTask (shape = (hidden_size ,))
525
524
else :
0 commit comments