@@ -95,7 +95,7 @@ def approximate_from_landmarks(
95
95
weights = 1 - distances
96
96
else :
97
97
weights = 1 / distances .clamp_min (1e-6 )
98
- weights = weights / weights .sum (dim = 1 , keepdim = True )
98
+ weights = weights / weights .sum (dim = 1 , keepdim = True ). clamp_min ( 1e-6 )
99
99
elif scheme == WeightingScheme .BARYCENTRIC :
100
100
weights = barycentric_weights (targets , points )
101
101
elif scheme == WeightingScheme .LEAST_SQUARES :
@@ -179,52 +179,65 @@ def common_interp_approximate(
179
179
return weights , indices , res
180
180
181
181
182
def batch_omp(
    targets: torch.Tensor, candidate_points: torch.Tensor, k: int
) -> Tuple[torch.LongTensor, torch.Tensor]:
    """
    Batched Orthogonal Matching Pursuit (OMP) to select `k` points from
    `candidate_points` that best approximate each target in `targets`.

    At every iteration the candidate with the largest absolute inner product
    against the current residual is picked (already picked candidates are
    masked out), then the coefficients of *all* picked candidates are re-fit
    by least squares against the original target.

    All arithmetic is done in a working dtype: the dtype of `targets` when it
    is float32 or float64, float32 otherwise. The returned coefficients are
    in that working dtype.

    Args:
        targets: (B, D) tensor of target vectors.
        candidate_points: (N, D) tensor of candidate points.
        k: Number of points to select (sparsity level).

    Returns:
        selected_indices: (B, k) tensor of indices selected for each target.
        coeff: (B, k) tensor of coefficients for each selected point.
            When `k == 0` both tensors are empty with shape (B, 0).

    Raises:
        ValueError: If `k` is negative or greater than the number of
            candidate points.
    """
    import logging

    B, D = targets.shape
    N, _ = candidate_points.shape
    device = targets.device
    if k < 0:
        raise ValueError(f"Cannot select {k} points: k must be non-negative")
    if k > N:
        raise ValueError(f"Cannot select {k} points from {N} candidates")
    work_dtype = (
        targets.dtype
        if targets.dtype in (torch.float32, torch.float64)
        else torch.float32
    )
    # Initialize selected indices and residuals
    selected_indices = torch.zeros((B, k), dtype=torch.long, device=device)
    targets_work = targets.to(dtype=work_dtype)
    residuals = targets_work.clone()
    points_work = candidate_points.to(dtype=work_dtype)
    mask = torch.zeros((B, N), dtype=torch.bool, device=device)
    batch_arange = torch.arange(B, device=device)  # loop-invariant row index
    # Bound up front so that k == 0 returns an empty result instead of
    # hitting an unbound local at the return statement.
    coeff = targets_work.new_zeros((B, 0))
    # The RMS diagnostics need extra reductions, and formatting them forces
    # the tensor values to be materialised; only pay for that when DEBUG
    # logging is actually enabled.
    debug_enabled = LOG.isEnabledFor(logging.DEBUG)

    for t in range(k):
        if debug_enabled:
            rms_0 = residuals.norm(dim=1).mean()
        # Compute absolute inner products between residuals and points
        abs_inner = (residuals @ points_work.T).abs()  # (B, N)
        # Mask out already selected points
        abs_inner.masked_fill_(mask, -float("inf"))

        # Select new index with maximum correlation
        _, new_idx = torch.max(abs_inner, dim=1)  # (B,)
        selected_indices[:, t] = new_idx

        # Update mask
        mask[batch_arange, new_idx] = True

        # Gather selected points (B, t+1, D)
        selected_points = points_work[selected_indices[:, : t + 1]]
        # Solve least squares against the original targets (not the
        # residuals) so every coefficient is re-fit each iteration.
        coeff = torch.linalg.lstsq(
            selected_points.transpose(1, 2),  # (B, D, t+1)
            targets_work.unsqueeze(-1),  # (B, D, 1)
        ).solution.squeeze(-1)  # (B, t+1)
        # Update residuals
        approx = torch.bmm(coeff.unsqueeze(1), selected_points).squeeze(1)
        residuals = targets_work - approx
        if debug_enabled:
            LOG.debug(f"OMP iteration {t}: RMS {rms_0} -> {residuals.norm(dim=1).mean()}")
    return selected_indices, coeff
229
242
230
243
@@ -432,7 +445,11 @@ def compute_new_embeddings(
432
445
]
433
446
targets = donor_embed [torch .tensor ([donor_vocab [t ] for t in target_tokens ])]
434
447
indices , coeffs = batch_omp (targets , donor_shared_embeds , options .k )
435
- return torch .bmm (coeffs .unsqueeze (1 ), orig_shared_embeds [indices ]).squeeze (1 )
448
+ return (
449
+ torch .bmm (coeffs .unsqueeze (1 ), orig_shared_embeds [indices ].to (torch .float ))
450
+ .squeeze (1 )
451
+ .to (orig_embed .dtype )
452
+ )
436
453
elif options .method == ApproximationMethod .SUBWORD :
437
454
raise NotImplementedError ("Subword approximation not yet implemented" )
438
455
else :
0 commit comments