@@ -444,6 +444,32 @@ def get_out_arch_info(
444
444
return ConfiguredModelArchitecture (info = arch_info_out , config = cfg_out )
445
445
446
446
447
+ def subword_approximate (
448
+ orig_embed : torch .Tensor ,
449
+ target_tokens : List [NormalizedToken ],
450
+ options : TokenSurgeonOptions ,
451
+ ) -> torch .Tensor :
452
+ res = torch .zeros (
453
+ len (target_tokens ),
454
+ orig_embed .shape [1 ],
455
+ device = orig_embed .device ,
456
+ dtype = orig_embed .dtype ,
457
+ )
458
+ tok_0 = transformers .AutoTokenizer .from_pretrained (
459
+ options .model .model .path ,
460
+ revision = options .model .model .revision ,
461
+ trust_remote_code = False ,
462
+ )
463
+ for idx , token in enumerate (target_tokens ):
464
+ text = unnormalize_token (token )
465
+ token_ids = tok_0 (text , add_special_tokens = False )["input_ids" ]
466
+ for id in token_ids :
467
+ res [idx ] += orig_embed [id ]
468
+ if options .average and len (token_ids ) > 0 :
469
+ res [idx ] /= len (token_ids )
470
+ return res
471
+
472
+
447
473
def compute_new_embeddings (
448
474
orig_embed : torch .Tensor ,
449
475
donor_embed : torch .Tensor ,
@@ -498,14 +524,21 @@ def compute_new_embeddings(
498
524
torch .tensor ([orig_vocab [t ] for t in shared_vocab ])
499
525
]
500
526
targets = donor_embed [torch .tensor ([donor_vocab [t ] for t in target_tokens ])]
527
+ print (
528
+ f"OMP: { len (shared_vocab )} shared tokens, { len (target_tokens )} targets, k={ options .k } "
529
+ )
501
530
indices , coeffs = batch_omp (targets , donor_shared_embeds , options .k )
502
- return (
531
+ print (f"OMP: coeffs shape { coeffs .shape } , indices shape { indices .shape } " )
532
+ res = (
503
533
torch .bmm (coeffs .unsqueeze (1 ), orig_shared_embeds [indices ].to (torch .float ))
504
534
.squeeze (1 )
505
535
.to (orig_embed .dtype )
506
536
)
537
+ print (f"OMP: res shape { res .shape } " )
538
+ print (repr (res ))
539
+ return res
507
540
elif options .method == ApproximationMethod .SUBWORD :
508
- raise NotImplementedError ( "Subword approximation not yet implemented" )
541
+ return subword_approximate ( orig_embed , target_tokens , options )
509
542
else :
510
543
raise ValueError (f"Unknown approximation method: { options .method } " )
511
544
0 commit comments