diff --git a/kornia/utils/helpers.py b/kornia/utils/helpers.py index e33405fd01..d6381c2b19 100644 --- a/kornia/utils/helpers.py +++ b/kornia/utils/helpers.py @@ -248,11 +248,15 @@ def _torch_linalg_svdvals(input: Tensor) -> Tensor: def _torch_solve_cast(A: Tensor, B: Tensor) -> Tensor: """Make torch.solve work with other than fp32/64. - For stable operation, the input matrices should be cast to fp64, and the output will be cast back to the input - dtype. + For stable operation, the input matrices should be cast to fp64, and the output will + be cast back to the input dtype. However, fp64 is not yet supported on MPS. """ - # cast to fp64 and solve - out = torch.linalg.solve(A.to(torch.float64), B.to(torch.float64)) + if is_mps_tensor_safe(A): + dtype = torch.float32 + else: + dtype = torch.float64 + + out = torch.linalg.solve(A.to(dtype), B.to(dtype)) # cast back to the input dtype return out.to(A.dtype)