Skip to content

Commit ff24630

Browse files
committed
Re-add subword approaches
1 parent c917a85 commit ff24630

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

mergekit/scripts/tokensurgeon.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,32 @@ def get_out_arch_info(
444444
return ConfiguredModelArchitecture(info=arch_info_out, config=cfg_out)
445445

446446

447+
def subword_approximate(
448+
orig_embed: torch.Tensor,
449+
target_tokens: List[NormalizedToken],
450+
options: TokenSurgeonOptions,
451+
) -> torch.Tensor:
452+
res = torch.zeros(
453+
len(target_tokens),
454+
orig_embed.shape[1],
455+
device=orig_embed.device,
456+
dtype=orig_embed.dtype,
457+
)
458+
tok_0 = transformers.AutoTokenizer.from_pretrained(
459+
options.model.model.path,
460+
revision=options.model.model.revision,
461+
trust_remote_code=False,
462+
)
463+
for idx, token in enumerate(target_tokens):
464+
text = unnormalize_token(token)
465+
token_ids = tok_0(text, add_special_tokens=False)["input_ids"]
466+
for id in token_ids:
467+
res[idx] += orig_embed[id]
468+
if options.average and len(token_ids) > 0:
469+
res[idx] /= len(token_ids)
470+
return res
471+
472+
447473
def compute_new_embeddings(
448474
orig_embed: torch.Tensor,
449475
donor_embed: torch.Tensor,
@@ -498,14 +524,21 @@ def compute_new_embeddings(
498524
torch.tensor([orig_vocab[t] for t in shared_vocab])
499525
]
500526
targets = donor_embed[torch.tensor([donor_vocab[t] for t in target_tokens])]
527+
print(
528+
f"OMP: {len(shared_vocab)} shared tokens, {len(target_tokens)} targets, k={options.k}"
529+
)
501530
indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
502-
return (
531+
print(f"OMP: coeffs shape {coeffs.shape}, indices shape {indices.shape}")
532+
res = (
503533
torch.bmm(coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float))
504534
.squeeze(1)
505535
.to(orig_embed.dtype)
506536
)
537+
print(f"OMP: res shape {res.shape}")
538+
print(repr(res))
539+
return res
507540
elif options.method == ApproximationMethod.SUBWORD:
508-
raise NotImplementedError("Subword approximation not yet implemented")
541+
return subword_approximate(orig_embed, target_tokens, options)
509542
else:
510543
raise ValueError(f"Unknown approximation method: {options.method}")
511544

0 commit comments

Comments
 (0)