Skip to content

Commit d84402b

Browse files
committed
More experiments
1 parent ff24630 commit d84402b

File tree

1 file changed

+47
-18
lines changed

1 file changed

+47
-18
lines changed

mergekit/scripts/tokensurgeon.py

+47-18
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ class WeightingScheme(enum.Enum):
7373
LEAST_SQUARES = "least_squares"
7474

7575

76+
class SubwordMethod(enum.Enum):
77+
MEAN = "mean"
78+
SUM = "sum"
79+
WEIGHTED_MEAN = "weighted_mean"
80+
FIRST_LAST = "first_last"
81+
82+
7683
def approximate_from_landmarks(
7784
targets: torch.Tensor,
7885
points: torch.Tensor,
@@ -304,7 +311,7 @@ class TokenSurgeonOptions(BaseModel):
304311
k: int = 8
305312
knn: bool = True
306313
cosine_similarity: bool = False
307-
average: bool = True
314+
subword_method: SubwordMethod = SubwordMethod.MEAN
308315
batch_size: Optional[int] = None
309316

310317

@@ -447,6 +454,7 @@ def get_out_arch_info(
447454
def subword_approximate(
448455
orig_embed: torch.Tensor,
449456
target_tokens: List[NormalizedToken],
457+
is_lm_head: bool,
450458
options: TokenSurgeonOptions,
451459
) -> torch.Tensor:
452460
res = torch.zeros(
@@ -463,10 +471,31 @@ def subword_approximate(
463471
for idx, token in enumerate(target_tokens):
464472
text = unnormalize_token(token)
465473
token_ids = tok_0(text, add_special_tokens=False)["input_ids"]
466-
for id in token_ids:
467-
res[idx] += orig_embed[id]
468-
if options.average and len(token_ids) > 0:
469-
res[idx] /= len(token_ids)
474+
475+
if options.subword_method in (SubwordMethod.MEAN, SubwordMethod.SUM):
476+
for id in token_ids:
477+
res[idx] += orig_embed[id]
478+
if options.subword_method == SubwordMethod.MEAN and len(token_ids) > 0:
479+
res[idx] /= len(token_ids)
480+
elif options.subword_method == SubwordMethod.WEIGHTED_MEAN:
481+
weights = list(range(1, len(token_ids) + 1))
482+
if not is_lm_head:
483+
# for embed_tokens, want last token to have highest weight
484+
# (vs. first token for lm_head)
485+
weights = weights[::-1]
486+
for id, weight in zip(token_ids, weights):
487+
res[idx] += weight * orig_embed[id]
488+
if len(token_ids) > 0:
489+
res[idx] /= sum(weights)
490+
elif options.subword_method == SubwordMethod.FIRST_LAST:
491+
if len(token_ids) == 0:
492+
continue
493+
if is_lm_head:
494+
res[idx] = orig_embed[token_ids[0]]
495+
else:
496+
res[idx] = orig_embed[token_ids[-1]]
497+
else:
498+
raise ValueError(f"Unknown subword method: {options.subword_method}")
470499
return res
471500

472501

@@ -476,6 +505,7 @@ def compute_new_embeddings(
476505
orig_vocab: Dict[NormalizedToken, int],
477506
donor_vocab: Dict[NormalizedToken, int],
478507
target_tokens: List[NormalizedToken],
508+
is_lm_head: bool,
479509
options: TokenSurgeonOptions,
480510
) -> torch.Tensor:
481511
assert all(t in donor_vocab for t in target_tokens)
@@ -524,21 +554,15 @@ def compute_new_embeddings(
524554
torch.tensor([orig_vocab[t] for t in shared_vocab])
525555
]
526556
targets = donor_embed[torch.tensor([donor_vocab[t] for t in target_tokens])]
527-
print(
528-
f"OMP: {len(shared_vocab)} shared tokens, {len(target_tokens)} targets, k={options.k}"
529-
)
530557
indices, coeffs = batch_omp(targets, donor_shared_embeds, options.k)
531-
print(f"OMP: coeffs shape {coeffs.shape}, indices shape {indices.shape}")
532558
res = (
533559
torch.bmm(coeffs.unsqueeze(1), orig_shared_embeds[indices].to(torch.float))
534560
.squeeze(1)
535561
.to(orig_embed.dtype)
536562
)
537-
print(f"OMP: res shape {res.shape}")
538-
print(repr(res))
539563
return res
540564
elif options.method == ApproximationMethod.SUBWORD:
541-
return subword_approximate(orig_embed, target_tokens, options)
565+
return subword_approximate(orig_embed, target_tokens, is_lm_head, options)
542566
else:
543567
raise ValueError(f"Unknown approximation method: {options.method}")
544568

@@ -551,6 +575,7 @@ def build_embedding_matrix(
551575
donor_vocab: Dict[NormalizedToken, int],
552576
allow_prefix: bool,
553577
allow_byte: bool,
578+
is_lm_head: bool,
554579
options: TokenSurgeonOptions,
555580
) -> torch.Tensor:
556581
LOG.info(f"Building new tensor for {weight_info.name}")
@@ -594,6 +619,7 @@ def build_embedding_matrix(
594619
orig_vocab,
595620
donor_vocab,
596621
new_tokens[base_idx : base_idx + batch_size],
622+
is_lm_head,
597623
options,
598624
)
599625
for ne_idx, token in enumerate(
@@ -647,10 +673,11 @@ def build_embedding_matrix(
647673
show_default=True,
648674
)
649675
@click.option(
650-
"--average/--no-average",
651-
is_flag=True,
652-
default=True,
653-
help="Use average instead of sum for subword embedding approximation",
676+
"--subword-method",
677+
"-s",
678+
type=click.Choice([m.value for m in SubwordMethod]),
679+
default=SubwordMethod.MEAN.value,
680+
help="Method for approximating embeddings with subword tokens",
654681
show_default=True,
655682
)
656683
@click.option(
@@ -670,7 +697,7 @@ def main(
670697
cosine_similarity: bool,
671698
approximation_method: str,
672699
weight_scheme: str,
673-
average: bool,
700+
subword_method: str,
674701
batch_size: Optional[int],
675702
merge_options: MergeOptions,
676703
):
@@ -685,7 +712,7 @@ def main(
685712
cosine_similarity=cosine_similarity,
686713
method=ApproximationMethod(approximation_method),
687714
weight_scheme=WeightingScheme(weight_scheme),
688-
average=average,
715+
subword_method=SubwordMethod(subword_method),
689716
batch_size=batch_size,
690717
)
691718

@@ -716,6 +743,7 @@ def main(
716743
donor_vocab=donor_vocab,
717744
allow_prefix=False,
718745
allow_byte=True,
746+
is_lm_head=False,
719747
options=options,
720748
)
721749
else:
@@ -738,6 +766,7 @@ def main(
738766
donor_vocab=donor_vocab,
739767
allow_prefix=True,
740768
allow_byte=True,
769+
is_lm_head=True,
741770
options=options,
742771
)
743772
else:

0 commit comments

Comments
 (0)