File tree 1 file changed +8
-4
lines changed
1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -248,11 +248,15 @@ def _torch_linalg_svdvals(input: Tensor) -> Tensor:
248
248
def _torch_solve_cast (A : Tensor , B : Tensor ) -> Tensor :
249
249
"""Make torch.solve work with other than fp32/64.
250
250
251
- For stable operation, the input matrices should be cast to fp64, and the output will be cast back to the input
252
- dtype.
251
+ For stable operation, the input matrices should be cast to fp64, and the output will
252
+ be cast back to the input dtype. However, fp64 is not yet supported on MPS .
253
253
"""
254
- # cast to fp64 and solve
255
- out = torch .linalg .solve (A .to (torch .float64 ), B .to (torch .float64 ))
254
+ if is_mps_tensor_safe (A ):
255
+ dtype = torch .float32
256
+ else :
257
+ dtype = torch .float64
258
+
259
+ out = torch .linalg .solve (A .to (dtype ), B .to (dtype ))
256
260
257
261
# cast back to the input dtype
258
262
return out .to (A .dtype )
You can’t perform that action at this time.
0 commit comments