Skip to content

Commit f036121

Browse files
authored
Utils: support resize on MPS (kornia#3145)
* Utils: support resize on MPS * MPS support not available in all PyTorch versions
1 parent 5729688 commit f036121

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

kornia/utils/helpers.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,15 @@ def _torch_linalg_svdvals(input: Tensor) -> Tensor:
248248
def _torch_solve_cast(A: Tensor, B: Tensor) -> Tensor:
249249
"""Make torch.solve work with other than fp32/64.
250250
251-
For stable operation, the input matrices should be cast to fp64, and the output will be cast back to the input
252-
dtype.
251+
For stable operation, the input matrices should be cast to fp64, and the output will
252+
be cast back to the input dtype. However, fp64 is not yet supported on MPS.
253253
"""
254-
# cast to fp64 and solve
255-
out = torch.linalg.solve(A.to(torch.float64), B.to(torch.float64))
254+
if is_mps_tensor_safe(A):
255+
dtype = torch.float32
256+
else:
257+
dtype = torch.float64
258+
259+
out = torch.linalg.solve(A.to(dtype), B.to(dtype))
256260

257261
# cast back to the input dtype
258262
return out.to(A.dtype)

0 commit comments

Comments
 (0)