Skip to content

Commit 5dc8023

Browse files
authoredMar 12, 2025
Fix tensor shape change in della pruning (#531)
Resolves #528.
1 parent efd4ea0 commit 5dc8023

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed
 

‎mergekit/sparsify.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def della_magprune(
145145
return tensor
146146
if density <= 0:
147147
return torch.zeros_like(tensor)
148+
orig_shape = tensor.shape
148149

149150
if density + epsilon >= 1 or density - epsilon <= 0:
150151
raise ValueError(
@@ -171,7 +172,7 @@ def della_magprune(
171172
mask = torch.bernoulli(probs).to(work_dtype)
172173

173174
res = rescaled_masked_tensor(tensor.to(work_dtype), mask, rescale_norm)
174-
return res.view_as(tensor)
175+
return res.to(tensor.dtype).reshape(orig_shape)
175176

176177

177178
def sparsify(

0 commit comments

Comments
 (0)