Skip to content

Commit 1e74d3f

Browse files
committed
Add OMP experiment
1 parent 023adf1 commit 1e74d3f

File tree

1 file changed

+109
-21
lines changed

1 file changed

+109
-21
lines changed

mergekit/scripts/tokensurgeon.py

+109-21
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def get(self, model: ModelReference) -> transformers.PreTrainedTokenizerBase:
8181
return self.loaded[model]
8282

8383

84-
class EmbeddingKnnTask(Task[Tuple[torch.Tensor, torch.Tensor]]):
84+
class EmbeddingKnnTask(Task[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
8585
target_tensor: Task
8686
common_embedding: Task
8787
k: int
@@ -110,12 +110,12 @@ def execute(self, target: torch.Tensor, common_embeddings: torch.Tensor):
110110
).squeeze()
111111
distances, indices = torch.topk(distances, self.k, largest=False)
112112
knn_embeddings = common_embeddings[indices]
113-
return distances, knn_embeddings
113+
return distances, knn_embeddings, indices
114114

115115

116116
class BarycentricWeightsTask(Task[torch.Tensor]):
117117
target_tensor: Task # [torch.Tensor]
118-
knn_task: Task # [Tuple[torch.Tensor, torch.Tensor]]
118+
knn_task: Task # [Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
119119

120120
def arguments(self):
121121
return {
@@ -126,8 +126,11 @@ def arguments(self):
126126
def uses_accelerator(self):
127127
return True
128128

129+
def priority(self):
130+
return 11
131+
129132
def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
130-
distances, knn_embeddings = knn
133+
_, knn_embeddings, _ = knn
131134

132135
# Find least squares barycentric weights
133136
# Constrain sum of weights to 1 by adding a row of 1s
@@ -147,14 +150,20 @@ def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
147150
# despite it being explicitly recommended for this use case in the docs
148151
# so pinv instead
149152
# also upcast to float32 for stability
150-
weights = torch.linalg.pinv(knn_e_c.to(torch.float32), rcond=1e-6) @ e_c.to(
153+
weights = torch.linalg.pinv(knn_e_c.to(torch.float32), rcond=1e-8) @ e_c.to(
151154
torch.float32
152155
)
153-
return weights[:-1].to(target.dtype)
156+
if torch.isnan(weights).any():
157+
# try again with slight ridge regression
158+
weights = torch.linalg.pinv(
159+
knn_e_c.to(torch.float32) + 1e-6 * torch.eye(knn_e_c.shape[0]),
160+
rcond=1e-8,
161+
) @ e_c.to(torch.float32)
162+
return weights.squeeze(-1)
154163

155164

156165
class DistanceWeightsTask(Task[torch.Tensor]):
157-
knn_task: Task # [Tuple[torch.Tensor, torch.Tensor]]
166+
knn_task: Task # [Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
158167

159168
def arguments(self):
160169
return {
@@ -164,19 +173,80 @@ def arguments(self):
164173
def uses_accelerator(self):
165174
return True
166175

176+
def priority(self):
177+
return 11
178+
167179
def execute(self, knn: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
    """Turn k-NN distances into normalized blending weights.

    Args:
        knn: The (distances, embeddings, indices) triple produced by the
            k-NN task; only the distances are used here.

    Returns:
        A tensor of weights over the k neighbors. softmin (= softmax of the
        negated distances) gives larger weight to closer neighbors and the
        weights sum to 1 along dim 0.
    """
    # NOTE: annotation fixed to a 3-tuple — the knn task now also yields
    # neighbor indices, and the unpack below expects three elements.
    distances, _, _ = knn
    return torch.nn.functional.softmin(distances, dim=0)
170182

171183

172-
class ReconstructedEmbeddingTask(Task[torch.Tensor]):
184+
class OrthogonalMatchingPursuitWeightsTask(Task[Tuple[torch.LongTensor, torch.Tensor]]):
    """Approximate a target embedding as a sparse combination of common
    embeddings via Orthogonal Matching Pursuit (OMP).

    Greedily selects the ``k`` common-embedding rows most correlated with the
    current residual, re-fitting least-squares coefficients over all selected
    atoms after each pick.

    Returns:
        ``(indices, coeffs)`` — a LongTensor of selected row indices into the
        common embedding matrix and their float32 least-squares coefficients.
    """

    target_tensor_task: Task  # [torch.Tensor]
    common_embeddings_task: Task  # [torch.Tensor]
    k: int

    def arguments(self):
        return {
            "target": self.target_tensor_task,
            "common_embeddings": self.common_embeddings_task,
        }

    def uses_accelerator(self):
        return True

    def priority(self):
        return 10

    def execute(self, target: torch.Tensor, common_embeddings: torch.Tensor):
        residual = target.clone()
        # Start with well-formed empty results so k <= 0 returns cleanly
        # instead of raising NameError on an undefined `coeffs`.
        indices = target.new_zeros(0, dtype=torch.long)
        coeffs = target.new_zeros(0, dtype=torch.float32)
        for _ in range(self.k):
            # Correlation of every atom with the residual (residual is 1-D,
            # so this equals residual @ common_embeddings.T).
            scores = torch.abs(common_embeddings @ residual)
            # Mask atoms already chosen: re-selecting one would duplicate a
            # column of B (rank-deficient) and waste the sparsity budget.
            scores[indices] = float("-inf")
            indices = torch.cat([indices, torch.argmax(scores).reshape(1)])
            B = common_embeddings[indices, :].T
            # pinv because rank-deficient and linalg.lstsq chokes on CUDA
            coeffs = torch.linalg.pinv(B.to(torch.float32)) @ target.to(torch.float32)
            residual = target - B @ coeffs.to(target.dtype)
        # LongTensor return matches the declared task result type (the
        # original returned a Python list of 0-dim tensors).
        return indices, coeffs
213+
214+
215+
class OmpReconstructedEmbeddingTask(Task[torch.Tensor]):
    """Reassemble an embedding from an OMP sparse decomposition.

    Consumes the ``(indices, coeffs)`` pair produced by the OMP weights task
    and returns the coefficient-weighted sum of the corresponding rows of the
    common embedding matrix.
    """

    omp_task: Task  # [Tuple[torch.LongTensor, torch.Tensor]]
    common_embeddings_task: Task  # [torch.Tensor]

    def arguments(self):
        return dict(
            omp=self.omp_task,
            common_embeddings=self.common_embeddings_task,
        )

    def uses_accelerator(self):
        return True

    def priority(self):
        return 100

    def execute(
        self,
        omp: Tuple[torch.LongTensor, torch.Tensor],
        common_embeddings: torch.Tensor,
    ):
        atom_indices, atom_weights = omp
        # Gather the selected atoms, scale each row by its coefficient,
        # and collapse over the atom dimension.
        atoms = common_embeddings[atom_indices]
        weighted = atom_weights.unsqueeze(-1) * atoms
        return weighted.sum(dim=0)
238+
239+
240+
class KnnReconstructedEmbeddingTask(Task[torch.Tensor]):
173241
weights_task: Task # [torch.Tensor]
174-
knn_task: Task # [Tuple[torch.Tensor, torch.Tensor]]
242+
knn_task: Task # [Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
243+
embeddings_task: Task # [torch.Tensor]
175244

176245
def arguments(self):
    """Map execute() keyword names to the upstream tasks that produce them."""
    return dict(
        weights=self.weights_task,
        knn=self.knn_task,
        embeddings=self.embeddings_task,
    )
181251

182252
def uses_accelerator(self):
@@ -188,9 +258,11 @@ def priority(self):
188258
def execute(
    self,
    weights: torch.Tensor,
    knn: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    embeddings: torch.Tensor,
):
    """Blend neighbor embeddings into one reconstructed embedding.

    ``knn`` carries the (distances, embeddings, indices) triple from the
    k-NN task; only the indices are used, to gather rows from the separately
    supplied ``embeddings`` table (presumably so neighbors found in one
    embedding space can be blended in another — confirm with caller).
    ``weights`` supplies one blending coefficient per neighbor.
    """
    _, _, neighbor_indices = knn
    neighbor_rows = embeddings[neighbor_indices]
    return (weights.unsqueeze(-1) * neighbor_rows).sum(dim=0)
195267

196268

@@ -243,6 +315,9 @@ def uses_accelerator(self):
243315
def execute(self, embeddings: torch.Tensor):
244316
return embeddings[self.index]
245317

318+
def priority(self):
319+
return 1
320+
246321

247322
class MultiIndexedEmbeddingTask(Task[torch.Tensor]):
248323
embeddings: Task
@@ -257,9 +332,6 @@ def uses_accelerator(self):
257332
def execute(self, embeddings: torch.Tensor):
258333
return torch.stack([embeddings[i] for i in self.indices], dim=0)
259334

260-
def main_thread_only(self):
261-
return True
262-
263335
def __hash__(self):
264336
# fun fact: hashing a tuple of 100k ints is very very slow
265337
# so just hash the embeddings task and let __eq__ sort it out
@@ -270,6 +342,9 @@ def __eq__(self, other):
270342
return False
271343
return self.indices == other.indices and self.embeddings == other.embeddings
272344

345+
def duplicate_per_gpu(self):
346+
return True
347+
273348

274349
class ZeroTensorTask(Task[torch.Tensor]):
275350
shape: Tuple[int, ...]
@@ -319,6 +394,7 @@ class ApproximationMethod(enum.Enum):
319394
SUBWORD = "subword"
320395
MEAN = "mean"
321396
ZERO = "zero"
397+
ORTHOGONAL_MATCHING_PURSUIT = "omp"
322398

323399

324400
class TokenSurgeonOptions(BaseModel):
@@ -440,15 +516,15 @@ def plan_embedding(
440516
optional=weight_info.optional,
441517
aliases=weight_info.aliases,
442518
tied_names=weight_info.tied_names,
443-
force_main_thread=True,
519+
per_gpu=True,
444520
)
445521
t_donor_embed = LoadTensor(
446522
model=options.donor,
447523
tensor=weight_info.name,
448524
optional=weight_info.optional,
449525
aliases=weight_info.aliases,
450526
tied_names=weight_info.tied_names,
451-
force_main_thread=True,
527+
per_gpu=True,
452528
)
453529
t_e_c_0 = MultiIndexedEmbeddingTask(
454530
embeddings=t_original_embed,
@@ -496,16 +572,16 @@ def plan_embedding(
496572
cosine_similarity=options.cosine_similarity,
497573
)
498574
if options.barycentric:
499-
weights_task = BarycentricWeightsTask(
575+
omp_task = BarycentricWeightsTask(
500576
target_tensor=IndexedEmbeddingTask(
501577
embeddings=t_donor_embed, index=idx_out
502578
),
503579
knn_task=knn_task,
504580
)
505581
else:
506-
weights_task = DistanceWeightsTask(knn_task=knn_task)
507-
reconstructed_task = ReconstructedEmbeddingTask(
508-
weights_task=weights_task,
582+
omp_task = DistanceWeightsTask(knn_task=knn_task)
583+
reconstructed_task = KnnReconstructedEmbeddingTask(
584+
weights_task=omp_task,
509585
knn_task=knn_task,
510586
embeddings_task=t_e_c_0,
511587
)
@@ -521,6 +597,18 @@ def plan_embedding(
521597
tok_embedding_task = mean_embed_task
522598
elif options.method == ApproximationMethod.ZERO:
523599
tok_embedding_task = ZeroTensorTask(shape=(hidden_size,))
600+
elif options.method == ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT:
601+
omp_task = OrthogonalMatchingPursuitWeightsTask(
602+
target_tensor_task=IndexedEmbeddingTask(
603+
embeddings=t_donor_embed, index=idx_out
604+
),
605+
common_embeddings_task=t_e_c_1,
606+
k=options.k,
607+
)
608+
tok_embedding_task = OmpReconstructedEmbeddingTask(
609+
omp_task=omp_task,
610+
common_embeddings_task=t_e_c_0,
611+
)
524612
else:
525613
raise RuntimeError(f"Unknown approximation method: {options.method}")
526614

0 commit comments

Comments
 (0)