Skip to content

Commit ddfc14b

Browse files
qqaatw authored and pytorchmergebot committed
[MPS] Fix where (pytorch#151176)
Fixes pytorch#150967.
Pull Request resolved: pytorch#151176
Approved by: https://github.com/kulinseth, https://github.com/malfet
1 parent 8494d55 commit ddfc14b

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

aten/src/ATen/native/mps/operations/TensorCompare.mm

+14-1
Original file line number | Diff line number | Diff line change
@@ -421,6 +421,11 @@ static void where_kernel_mps(TensorIterator& iter) {
421421
return;
422422
}
423423

424+
Tensor out_;
425+
if (needsGather(out)) {
426+
out_ = out.contiguous();
427+
}
428+
424429
// Derive from MPSCachedGraph
425430
struct CachedGraph : public MPSCachedGraph {
426431
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@@ -459,11 +464,19 @@ static void where_kernel_mps(TensorIterator& iter) {
459464
Placeholder(cachedGraph->selfTensor_, self, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, selfDataType);
460465
Placeholder otherPlaceholder =
461466
Placeholder(cachedGraph->otherTensor_, other, /*mpsShape=*/nullptr, /*gatherTensorData=*/true, otherDataType);
462-
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
467+
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
468+
needsGather(out) ? out_ : out,
469+
/*mpsShape=*/nullptr,
470+
/*gatherTensorData=*/needsGather(out),
471+
getMPSScalarType(out.scalar_type()));
463472

464473
auto feeds = dictionaryFromPlaceholders(conditionPlaceholder, selfPlaceholder, otherPlaceholder);
465474
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
466475
}
476+
477+
if (needsGather(out)) {
478+
out.copy_(out_);
479+
}
467480
}
468481

469482
Tensor& nan_to_num_out_mps(const Tensor& self,

test/test_mps.py

+1-7
Original file line number | Diff line number | Diff line change
@@ -12937,15 +12937,9 @@ def tearDownClass(cls):
1293712937
def test_numpy_ref_mps(self, device, dtype, op):
1293812938
# Unlike `test_numpy_ref`, this test compares in `float32` since at the time of this test's creation MPS
1293912939
# does not support float64 Tensors.
12940-
# A few ops are currently broken on their reference inputs, but not their sample inputs. These should
12941-
# get patched up and this workaround removed.
12942-
broken_on_ref_inputs = op.name in ('where',)
1294312940

1294412941
# TODO: Enable per-sample seed setting and tweak tolerances / fix xfails
12945-
inputs = (
12946-
op.reference_inputs(device, dtype, set_seed=False) if not broken_on_ref_inputs
12947-
else op.sample_inputs(device, dtype, set_seed=False)
12948-
)
12942+
inputs = op.reference_inputs(device, dtype, set_seed=False)
1294912943
for sample_input in inputs:
1295012944
self.compare_with_reference(op, op.ref, sample_input)
1295112945

torch/testing/_internal/common_methods_invocations.py

+4
Original file line number | Diff line number | Diff line change
@@ -7800,6 +7800,10 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs):
78007800
# NOTE that the OpInfo for where takes samples of the form a, cond, b
78017801
yield SampleInput(a, args=(c, b))
78027802

7803+
# MPS does not support float64, which causes issues in the following tests
7804+
if torch.device(device).type == "mps":
7805+
return
7806+
78037807
# type promoting
78047808
# FIXME(rec): shouldn't other_dtype be used two lines below?
78057809
other_dtype = torch.double if dtype is not torch.double else torch.long # noqa: F841

0 commit comments

Comments (0)