Skip to content

Commit a50f202

Browse files
authored
Merge pull request #385 from ev-br/pr/192_
`test_take`: test axis being optional if `x.ndim==1`
2 parents 2e91ca4 + 3bbe3d4 commit a50f202

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,39 @@
1717
)
1818
def test_take(x, data):
1919
# TODO:
20-
# * negative axis
2120
# * negative indices
2221
# * different dtypes for indices
23-
axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
22+
23+
# axis is optional but only if x.ndim == 1
24+
_axis_st = st.integers(-x.ndim, max(x.ndim - 1, 0))
25+
if x.ndim == 1:
26+
kw = data.draw(hh.kwargs(axis=_axis_st))
27+
else:
28+
kw = {"axis": data.draw(_axis_st)}
29+
axis = kw.get("axis", 0)
2430
_indices = data.draw(
2531
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
2632
label="_indices",
2733
)
34+
n_axis = axis if axis>=0 else x.ndim + axis
2835
indices = xp.asarray(_indices, dtype=dh.default_int)
2936
note(f"{indices=}")
3037

31-
out = xp.take(x, indices, axis=axis)
38+
out = xp.take(x, indices, **kw)
3239

3340
ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
3441
ph.assert_shape(
3542
"take",
3643
out_shape=out.shape,
37-
expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
44+
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
3845
kw=dict(
3946
x=x,
4047
indices=indices,
4148
axis=axis,
4249
),
4350
)
4451
out_indices = sh.ndindex(out.shape)
45-
axis_indices = list(sh.axis_ndindex(x.shape, axis))
52+
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
4653
for axis_idx in axis_indices:
4754
f_axis_idx = sh.fmt_idx("x", axis_idx)
4855
for i in _indices:
@@ -62,7 +69,6 @@ def test_take(x, data):
6269
next(out_indices)
6370

6471

65-
6672
@pytest.mark.unvectorized
6773
@pytest.mark.min_version("2024.12")
6874
@given(

0 commit comments

Comments
 (0)