@@ -81,7 +81,7 @@ def get(self, model: ModelReference) -> transformers.PreTrainedTokenizerBase:
         return self.loaded[model]


-class EmbeddingKnnTask(Task[Tuple[torch.Tensor, torch.Tensor]]):
+class EmbeddingKnnTask(Task[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
     target_tensor: Task
     common_embedding: Task
     k: int
@@ -110,12 +110,12 @@ def execute(self, target: torch.Tensor, common_embeddings: torch.Tensor):
         ).squeeze()
         distances, indices = torch.topk(distances, self.k, largest=False)
         knn_embeddings = common_embeddings[indices]
-        return distances, knn_embeddings
+        return distances, knn_embeddings, indices


 class BarycentricWeightsTask(Task[torch.Tensor]):
     target_tensor: Task  # [torch.Tensor]
-    knn_task: Task  # [Tuple[torch.Tensor, torch.Tensor]]
+    knn_task: Task  # [Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

     def arguments(self):
         return {
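For reference, the three-tuple that this hunk makes EmbeddingKnnTask return can be reproduced outside the Task framework with a sketch like the one below. It assumes a plain Euclidean distance; the surrounding code also supports a cosine-similarity mode, which is not shown here, and the helper name is illustrative.

import torch

def knn_common_embeddings(target: torch.Tensor, common: torch.Tensor, k: int):
    # target: (d,), common: (n, d)
    distances = torch.cdist(target.unsqueeze(0), common).squeeze(0)  # (n,)
    # k smallest distances and the matching rows of the common-embedding matrix
    distances, indices = torch.topk(distances, k, largest=False)
    knn_embeddings = common[indices]  # (k, d)
    return distances, knn_embeddings, indices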
@@ -126,8 +126,11 @@ def arguments(self):
     def uses_accelerator(self):
         return True

+    def priority(self):
+        return 11
+
     def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
-        distances, knn_embeddings = knn
+        _, knn_embeddings, _ = knn

         # Find least squares barycentric weights
         # Constrain sum of weights to 1 by adding a row of 1s
@@ -147,14 +150,20 @@ def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
         # despite it being explicitly recommended for this use case in the docs
         # so pinv instead
         # also upcast to float32 for stability
-        weights = torch.linalg.pinv(knn_e_c.to(torch.float32), rcond=1e-6) @ e_c.to(
+        weights = torch.linalg.pinv(knn_e_c.to(torch.float32), rcond=1e-8) @ e_c.to(
             torch.float32
         )
-        return weights[:-1].to(target.dtype)
+        if torch.isnan(weights).any():
+            # try again with slight ridge regression
+            weights = torch.linalg.pinv(
+                knn_e_c.to(torch.float32) + 1e-6 * torch.eye(knn_e_c.shape[0]),
+                rcond=1e-8,
+            ) @ e_c.to(torch.float32)
+        return weights.squeeze(-1)


 class DistanceWeightsTask(Task[torch.Tensor]):
-    knn_task: Task  # [Tuple[torch.Tensor, torch.Tensor]]
+    knn_task: Task  # [Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]

     def arguments(self):
         return {
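The task above solves a least-squares problem for barycentric weights over the k nearest common embeddings, with the sum-to-one constraint handled by an extra row of ones; the change tightens rcond and falls back to a lightly ridge-regularized solve when the pinv result contains NaNs. The construction of knn_e_c and e_c sits outside this hunk, so the following is only a sketch of one standard formulation (a KKT system with a Lagrange multiplier, which would also explain the old `weights[:-1]`), not a transcript of the task's actual matrices:

import torch

def barycentric_weights(target: torch.Tensor, knn_embeddings: torch.Tensor) -> torch.Tensor:
    # knn_embeddings: (k, d), target: (d,)
    k = knn_embeddings.shape[0]
    gram = (knn_embeddings @ knn_embeddings.T).to(torch.float32)  # (k, k)
    ones = torch.ones(k, 1)
    # KKT system: minimize ||knn_embeddings.T @ w - target|| subject to sum(w) == 1
    lhs = torch.cat(
        [
            torch.cat([gram, ones], dim=1),
            torch.cat([ones.T, torch.zeros(1, 1)], dim=1),
        ],
        dim=0,
    )  # (k + 1, k + 1)
    rhs = torch.cat(
        [knn_embeddings.to(torch.float32) @ target.to(torch.float32), torch.ones(1)]
    )
    sol = torch.linalg.pinv(lhs, rcond=1e-8) @ rhs
    if torch.isnan(sol).any():
        # retry with a small ridge term, mirroring the fallback added above
        sol = torch.linalg.pinv(lhs + 1e-6 * torch.eye(lhs.shape[0]), rcond=1e-8) @ rhs
    return sol[:-1]  # drop the Lagrange multiplier; the weights sum to ~1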
@@ -164,19 +173,80 @@ def arguments(self):
     def uses_accelerator(self):
         return True

+    def priority(self):
+        return 11
+
     def execute(self, knn: Tuple[torch.Tensor, torch.Tensor]):
-        distances, _ = knn
+        distances, _, _ = knn
         return torch.nn.functional.softmin(distances, dim=0)


-class ReconstructedEmbeddingTask(Task[torch.Tensor]):
+class OrthogonalMatchingPursuitWeightsTask(Task[Tuple[torch.LongTensor, torch.Tensor]]):
+    target_tensor_task: Task  # [torch.Tensor]
+    common_embeddings_task: Task  # [torch.Tensor]
+    k: int
+
+    def arguments(self):
+        return {
+            "target": self.target_tensor_task,
+            "common_embeddings": self.common_embeddings_task,
+        }
+
+    def uses_accelerator(self):
+        return True
+
+    def priority(self):
+        return 10
+
+    def execute(self, target: torch.Tensor, common_embeddings: torch.Tensor):
+        residual = target.clone()
+        selected = []
+        for _ in range(self.k):
+            idx = torch.argmax(torch.abs(residual @ common_embeddings.T))
+            selected.append(idx)
+            B = common_embeddings[selected, :].T
+            # pinv because rank-deficient and linalg.lstsq chokes on CUDA
+            coeffs = torch.linalg.pinv(B.to(torch.float32)) @ target.to(torch.float32)
+            residual = target - B @ coeffs.to(target.dtype)
+
+        return selected, coeffs
+
+
+class OmpReconstructedEmbeddingTask(Task[torch.Tensor]):
+    omp_task: Task  # [Tuple[torch.LongTensor, torch.Tensor]]
+    common_embeddings_task: Task  # [torch.Tensor]
+
+    def arguments(self):
+        return {
+            "omp": self.omp_task,
+            "common_embeddings": self.common_embeddings_task,
+        }
+
+    def uses_accelerator(self):
+        return True
+
+    def priority(self):
+        return 100
+
+    def execute(
+        self,
+        omp: Tuple[torch.LongTensor, torch.Tensor],
+        common_embeddings: torch.Tensor,
+    ):
+        indices, coeffs = omp
+        return torch.sum(coeffs.unsqueeze(-1) * common_embeddings[indices], dim=0)
+
+
+class KnnReconstructedEmbeddingTask(Task[torch.Tensor]):
     weights_task: Task  # [torch.Tensor]
-    knn_task: Task  # [Tuple[torch.Tensor, torch.Tensor]]
+    knn_task: Task  # [Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
+    embeddings_task: Task  # [torch.Tensor]

     def arguments(self):
         return {
             "weights": self.weights_task,
             "knn": self.knn_task,
+            "embeddings": self.embeddings_task,
         }

     def uses_accelerator(self):
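Stripped of the Task scaffolding, the new OrthogonalMatchingPursuitWeightsTask is a plain greedy OMP loop: pick the common embedding most correlated with the current residual, refit all selected coefficients jointly by least squares (via pinv), and repeat k times. A standalone sketch with illustrative names:

import torch
from typing import List, Tuple

def omp_weights(
    target: torch.Tensor, common: torch.Tensor, k: int
) -> Tuple[List[int], torch.Tensor]:
    # target: (d,), common: (n, d)
    residual = target.clone()
    selected: List[int] = []
    coeffs = torch.zeros(0)
    for _ in range(k):
        # greedy atom selection: largest absolute correlation with the residual
        idx = int(torch.argmax(torch.abs(residual @ common.T)))
        selected.append(idx)
        B = common[selected, :].T  # (d, len(selected))
        # joint least-squares refit over everything selected so far
        coeffs = torch.linalg.pinv(B.to(torch.float32)) @ target.to(torch.float32)
        residual = target - B @ coeffs.to(target.dtype)
    return selected, coeffs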
@@ -188,9 +258,11 @@ def priority(self):
     def execute(
         self,
         weights: torch.Tensor,
-        knn: Tuple[torch.Tensor, torch.Tensor],
+        knn: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        embeddings: torch.Tensor,
     ):
-        _, knn_embeddings = knn
+        _, _, knn_indices = knn
+        knn_embeddings = embeddings[knn_indices]
         return torch.sum(weights.unsqueeze(-1) * knn_embeddings, dim=0)

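Shape-wise, the reconstruction above gathers the k neighbor rows from the shared embedding matrix and broadcasts the weights across the hidden dimension:

# weights: (k,), embeddings: (n, d), knn_indices: (k,)
knn_embeddings = embeddings[knn_indices]  # (k, d)
reconstructed = torch.sum(weights.unsqueeze(-1) * knn_embeddings, dim=0)  # (d,)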
@@ -243,6 +315,9 @@ def uses_accelerator(self):
     def execute(self, embeddings: torch.Tensor):
         return embeddings[self.index]

+    def priority(self):
+        return 1
+

 class MultiIndexedEmbeddingTask(Task[torch.Tensor]):
     embeddings: Task
@@ -257,9 +332,6 @@ def uses_accelerator(self):
     def execute(self, embeddings: torch.Tensor):
         return torch.stack([embeddings[i] for i in self.indices], dim=0)

-    def main_thread_only(self):
-        return True
-
     def __hash__(self):
         # fun fact: hashing a tuple of 100k ints is very very slow
         # so just hash the embeddings task and let __eq__ sort it out
@@ -270,6 +342,9 @@ def __eq__(self, other):
             return False
         return self.indices == other.indices and self.embeddings == other.embeddings

+    def duplicate_per_gpu(self):
+        return True
+

 class ZeroTensorTask(Task[torch.Tensor]):
     shape: Tuple[int, ...]
@@ -319,6 +394,7 @@ class ApproximationMethod(enum.Enum):
     SUBWORD = "subword"
     MEAN = "mean"
     ZERO = "zero"
+    ORTHOGONAL_MATCHING_PURSUIT = "omp"


 class TokenSurgeonOptions(BaseModel):
@@ -440,15 +516,15 @@ def plan_embedding(
         optional=weight_info.optional,
         aliases=weight_info.aliases,
         tied_names=weight_info.tied_names,
-        force_main_thread=True,
+        per_gpu=True,
     )
     t_donor_embed = LoadTensor(
         model=options.donor,
         tensor=weight_info.name,
         optional=weight_info.optional,
         aliases=weight_info.aliases,
         tied_names=weight_info.tied_names,
-        force_main_thread=True,
+        per_gpu=True,
     )
     t_e_c_0 = MultiIndexedEmbeddingTask(
         embeddings=t_original_embed,
@@ -496,16 +572,16 @@ def plan_embedding(
             cosine_similarity=options.cosine_similarity,
         )
         if options.barycentric:
-            weights_task = BarycentricWeightsTask(
+            omp_task = BarycentricWeightsTask(
                 target_tensor=IndexedEmbeddingTask(
                     embeddings=t_donor_embed, index=idx_out
                 ),
                 knn_task=knn_task,
             )
         else:
-            weights_task = DistanceWeightsTask(knn_task=knn_task)
-        reconstructed_task = ReconstructedEmbeddingTask(
-            weights_task=weights_task,
+            omp_task = DistanceWeightsTask(knn_task=knn_task)
+        reconstructed_task = KnnReconstructedEmbeddingTask(
+            weights_task=omp_task,
             knn_task=knn_task,
             embeddings_task=t_e_c_0,
         )
@@ -521,6 +597,18 @@ def plan_embedding(
         tok_embedding_task = mean_embed_task
     elif options.method == ApproximationMethod.ZERO:
         tok_embedding_task = ZeroTensorTask(shape=(hidden_size,))
+    elif options.method == ApproximationMethod.ORTHOGONAL_MATCHING_PURSUIT:
+        omp_task = OrthogonalMatchingPursuitWeightsTask(
+            target_tensor_task=IndexedEmbeddingTask(
+                embeddings=t_donor_embed, index=idx_out
+            ),
+            common_embeddings_task=t_e_c_1,
+            k=options.k,
+        )
+        tok_embedding_task = OmpReconstructedEmbeddingTask(
+            omp_task=omp_task,
+            common_embeddings_task=t_e_c_0,
+        )
     else:
         raise RuntimeError(f"Unknown approximation method: {options.method}")

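Taken together with the plan_embedding wiring above, the OMP path computes its sparse weights against the donor model's embeddings of the shared tokens (t_e_c_1) and then rebuilds the new token's vector from the corresponding rows of the base model's shared-token embeddings (t_e_c_0). A rough end-to-end sketch of that flow, reusing the omp_weights helper sketched earlier (names and shape comments are illustrative, not taken from the source):

import torch

def omp_approximate_embedding(
    donor_target: torch.Tensor,  # donor embedding of the token to approximate, (d_donor,)
    donor_common: torch.Tensor,  # donor embeddings of tokens shared with the base model, (n, d_donor)
    base_common: torch.Tensor,   # base-model embeddings of the same shared tokens, (n, d_base)
    k: int,
) -> torch.Tensor:
    indices, coeffs = omp_weights(donor_target, donor_common, k)
    # carry the sparse combination over into the base model's embedding space
    return torch.sum(coeffs.unsqueeze(-1) * base_common[indices].to(coeffs.dtype), dim=0)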