@@ -94,6 +94,9 @@ def arguments(self):
94
94
def uses_accelerator (self ):
95
95
return True
96
96
97
+ def priority (self ):
98
+ return 10
99
+
97
100
def execute (self , target : torch .Tensor , common_embeddings : torch .Tensor ):
98
101
if self .cosine_similarity :
99
102
distances = 1 - torch .nn .functional .cosine_similarity (
@@ -127,7 +130,9 @@ def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
127
130
# Find least squares barycentric weights
128
131
# Constrain sum of weights to 1 by adding a row of 1s
129
132
constraint_row = torch .ones (
130
- (1 , knn_embeddings .shape [0 ]), device = target .device
133
+ (1 , knn_embeddings .shape [0 ]),
134
+ device = target .device ,
135
+ dtype = knn_embeddings .dtype ,
131
136
) # (1, k)
132
137
knn_e_c = torch .cat ([knn_embeddings .T , constraint_row ], dim = 0 )
133
138
e_c = torch .cat (
@@ -139,8 +144,11 @@ def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
139
144
# torch.linalg.lstsq doesn't work for rank-deficient matrices on CUDA
140
145
# despite it being explicitly recommended for this use case in the docs
141
146
# so use the pseudo-inverse (torch.linalg.pinv) instead
142
- weights = torch .linalg .pinv (knn_e_c , rcond = 1e-6 ) @ e_c
143
- return weights [:- 1 ]
147
+ # also upcast to float32 for numerical stability; the result is cast back to target.dtype
148
+ weights = torch .linalg .pinv (knn_e_c .to (torch .float32 ), rcond = 1e-6 ) @ e_c .to (
149
+ torch .float32
150
+ )
151
+ return weights [:- 1 ].to (target .dtype )
144
152
145
153
146
154
class DistanceWeightsTask (Task [torch .Tensor ]):
@@ -488,7 +496,10 @@ def plan_embedding(
488
496
)
489
497
if options .barycentric :
490
498
weights_task = BarycentricWeightsTask (
491
- target_tensor = t_donor_embed , knn_task = knn_task
499
+ target_tensor = IndexedEmbeddingTask (
500
+ embeddings = t_donor_embed , index = idx_out
501
+ ),
502
+ knn_task = knn_task ,
492
503
)
493
504
else :
494
505
weights_task = DistanceWeightsTask (knn_task = knn_task )
0 commit comments