Skip to content

Commit 4bb0949

Browse files
committed
ENH: one more example of the failure reporting
$ ARRAY_API_TESTS_MODULE=array_api_compat.torch pytest array_api_tests/test_array_object.py::test_getitem
1 parent 0744535 commit 4bb0949

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

array_api_tests/test_array_object.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,29 @@ def test_getitem(shape, dtype, data):
8585
note(f"{x=}")
8686
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
8787

88-
out = x[key]
89-
90-
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
91-
_key = normalize_key(key, shape)
92-
axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape)
93-
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
94-
out_zero_sided = any(side == 0 for side in expected_shape)
95-
if not zero_sided and not out_zero_sided:
96-
out_obj = []
97-
for idx in product(*axes_indices):
98-
val = obj
99-
for i in idx:
100-
val = val[i]
101-
out_obj.append(val)
102-
out_obj = sh.reshape(out_obj, expected_shape)
103-
expected = xp.asarray(out_obj, dtype=dtype)
104-
ph.assert_array_elements("__getitem__", out=out, expected=expected)
105-
88+
repro_snippet = ph.format_snippet(f"{x!r}[{key!r}]")
89+
90+
try:
91+
out = x[key]
92+
93+
ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
94+
_key = normalize_key(key, shape)
95+
axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape)
96+
ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
97+
out_zero_sided = any(side == 0 for side in expected_shape)
98+
if not zero_sided and not out_zero_sided:
99+
out_obj = []
100+
for idx in product(*axes_indices):
101+
val = obj
102+
for i in idx:
103+
val = val[i]
104+
out_obj.append(val)
105+
out_obj = sh.reshape(out_obj, expected_shape)
106+
expected = xp.asarray(out_obj, dtype=dtype)
107+
ph.assert_array_elements("__getitem__", out=out, expected=expected)
108+
except Exception as exc:
109+
exc.add_note(repro_snippet)
110+
raise
106111

107112
@pytest.mark.unvectorized
108113
@given(

0 commit comments

Comments
 (0)