Skip to content

Commit b961335

Browse files
committed
Barycentric working too
1 parent 0062f15 commit b961335

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

mergekit/multigpu_executor.py

-1
Original file line number | Diff line number | Diff line change
@@ -155,7 +155,6 @@ def update_progress():
155155

156156
# Run main thread tasks
157157
if self.trailing_main_handles:
158-
logger.debug("Running trailing tasks on main thread")
159158
exec = Executor(
160159
self.trailing_main_handles,
161160
math_device=self.storage_device or torch.device("cpu"),

mergekit/scripts/tokensurgeon.py

+15 -4
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,9 @@ def arguments(self):
9494
def uses_accelerator(self):
9595
return True
9696

97+
def priority(self):
98+
return 10
99+
97100
def execute(self, target: torch.Tensor, common_embeddings: torch.Tensor):
98101
if self.cosine_similarity:
99102
distances = 1 - torch.nn.functional.cosine_similarity(
@@ -127,7 +130,9 @@ def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
127130
# Find least squares barycentric weights
128131
# Constrain sum of weights to 1 by adding a row of 1s
129132
constraint_row = torch.ones(
130-
(1, knn_embeddings.shape[0]), device=target.device
133+
(1, knn_embeddings.shape[0]),
134+
device=target.device,
135+
dtype=knn_embeddings.dtype,
131136
) # (1, k)
132137
knn_e_c = torch.cat([knn_embeddings.T, constraint_row], dim=0)
133138
e_c = torch.cat(
@@ -139,8 +144,11 @@ def execute(self, target: torch.Tensor, knn: Tuple[torch.Tensor, torch.Tensor]):
139144
# torch.linalg.lstsq doesn't work for rank-deficient matrices on CUDA
140145
# despite it being explicitly recommended for this use case in the docs
141146
# so pinv instead
142-
weights = torch.linalg.pinv(knn_e_c, rcond=1e-6) @ e_c
143-
return weights[:-1]
147+
# also upcast to float32 for stability
148+
weights = torch.linalg.pinv(knn_e_c.to(torch.float32), rcond=1e-6) @ e_c.to(
149+
torch.float32
150+
)
151+
return weights[:-1].to(target.dtype)
144152

145153

146154
class DistanceWeightsTask(Task[torch.Tensor]):
@@ -488,7 +496,10 @@ def plan_embedding(
488496
)
489497
if options.barycentric:
490498
weights_task = BarycentricWeightsTask(
491-
target_tensor=t_donor_embed, knn_task=knn_task
499+
target_tensor=IndexedEmbeddingTask(
500+
embeddings=t_donor_embed, index=idx_out
501+
),
502+
knn_task=knn_task,
492503
)
493504
else:
494505
weights_task = DistanceWeightsTask(knn_task=knn_task)

0 commit comments

Comments
 (0)