Skip to content

Commit ce98bac

Browse files
committed
Changes
1 parent 000f6b5 commit ce98bac

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

mergekit/scripts/tokensurgeon.py

+26
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ class TokenSurgeonOptions(BaseModel):
8888
cosine_similarity: bool = False
8989
subword_method: SubwordMethod = SubwordMethod.MEAN
9090
batch_size: Optional[int] = None
91+
new_vocab_noise: Optional[float] = None
92+
new_vocab_scale: Optional[float] = None
9193

9294

9395
def get_arch_info(
@@ -489,6 +491,10 @@ def build_embedding_matrix(
489491
token_basis=token_basis,
490492
options=options,
491493
)
494+
if options.new_vocab_noise:
495+
new_embeds += torch.randn_like(new_embeds) * options.new_vocab_noise
496+
if options.new_vocab_scale:
497+
new_embeds *= options.new_vocab_scale
492498
for ne_idx, token in enumerate(
493499
new_tokens[base_idx : base_idx + batch_size]
494500
):
@@ -592,6 +598,22 @@ class AllowMatch(enum.Enum):
592598
help="Filter out poorly trained tokens",
593599
show_default=True,
594600
)
601+
@click.option(
602+
"--new-vocab-noise",
603+
"-nvn",
604+
type=float,
605+
default=None,
606+
help="Add gaussian noise to new vocab embeddings",
607+
show_default=True,
608+
)
609+
@click.option(
610+
"--new-vocab-scale",
611+
"-nvs",
612+
type=float,
613+
default=None,
614+
help="Scale computed new vocab embeddings by this factor",
615+
show_default=True,
616+
)
595617
@add_merge_options
596618
def main(
597619
model: str,
@@ -607,6 +629,8 @@ def main(
607629
prefix_match: str,
608630
byte_match: str,
609631
magikarp: bool,
632+
new_vocab_noise: Optional[float],
633+
new_vocab_scale: Optional[float],
610634
merge_options: MergeOptions,
611635
):
612636
merge_options.apply_global_options()
@@ -622,6 +646,8 @@ def main(
622646
weight_scheme=WeightingScheme(weight_scheme),
623647
subword_method=SubwordMethod(subword_method),
624648
batch_size=batch_size,
649+
new_vocab_noise=new_vocab_noise,
650+
new_vocab_scale=new_vocab_scale,
625651
)
626652
prefix_match = AllowMatch(prefix_match)
627653
byte_match = AllowMatch(byte_match)

mergekit/tokensurgeon/token_basis.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,29 @@
1010

1111

1212
def sparse_linear_basis(
13-
embeddings: torch.Tensor,
13+
points: torch.Tensor,
1414
k: int,
1515
d: int,
1616
eps: float = 1e-8,
1717
) -> Tuple[torch.LongTensor, torch.Tensor]:
1818
"""
19-
Form an approximate orthogonal basis from sparse linear combinations of the input embeddings.
19+
Form an approximate orthogonal basis from sparse linear combinations of input points.
2020
Args:
21-
embeddings: (num_pts, embed_dim) tensor of embeddings
21+
points: (num_pts, embed_dim) tensor of input points
2222
k: number of points to select per basis vector
2323
d: dimensionality of the basis
2424
eps: numerical stability parameter
2525
Returns:
2626
indices: (d, k) tensor of selected indices
2727
coeffs: (d, k) tensor of coefficients for each selected point
2828
"""
29-
assert embeddings.dim() == 2
30-
num_pts, embed_dim = embeddings.shape
29+
assert points.dim() == 2
30+
num_pts, embed_dim = points.shape
3131
assert k <= num_pts, "k must be less than or equal to the number of points"
3232
assert d <= embed_dim, "d must be less than or equal to the embedding dimension"
3333

34-
mean_embed = embeddings.mean(dim=0)
35-
centered_embeddings = (embeddings - mean_embed).to(torch.float32)
34+
mean_embed = points.mean(dim=0)
35+
centered_embeddings = (points - mean_embed).to(torch.float32)
3636
covariance_matrix = (
3737
centered_embeddings.T @ centered_embeddings
3838
) / num_pts # (embed_dim, embed_dim)

0 commit comments

Comments
 (0)