Skip to content

Commit 3bbe3d4

Browse files
committed
ENH: take: test axis<0
1 parent b10aea5 commit 3bbe3d4

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

array_api_tests/test_indexing_functions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
)
1818
def test_take(x, data):
1919
# TODO:
20-
# * negative axis
2120
# * negative indices
2221
# * different dtypes for indices
2322

2423
# axis is optional but only if x.ndim == 1
25-
_axis_st = st.integers(0, max(x.ndim - 1, 0))
24+
_axis_st = st.integers(-x.ndim, max(x.ndim - 1, 0))
2625
if x.ndim == 1:
2726
kw = data.draw(hh.kwargs(axis=_axis_st))
2827
else:
@@ -32,6 +31,7 @@ def test_take(x, data):
3231
st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
3332
label="_indices",
3433
)
34+
n_axis = axis if axis>=0 else x.ndim + axis
3535
indices = xp.asarray(_indices, dtype=dh.default_int)
3636
note(f"{indices=}")
3737

@@ -41,15 +41,15 @@ def test_take(x, data):
4141
ph.assert_shape(
4242
"take",
4343
out_shape=out.shape,
44-
expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
44+
expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:],
4545
kw=dict(
4646
x=x,
4747
indices=indices,
4848
axis=axis,
4949
),
5050
)
5151
out_indices = sh.ndindex(out.shape)
52-
axis_indices = list(sh.axis_ndindex(x.shape, axis))
52+
axis_indices = list(sh.axis_ndindex(x.shape, n_axis))
5353
for axis_idx in axis_indices:
5454
f_axis_idx = sh.fmt_idx("x", axis_idx)
5555
for i in _indices:
@@ -69,7 +69,6 @@ def test_take(x, data):
6969
next(out_indices)
7070

7171

72-
7372
@pytest.mark.unvectorized
7473
@pytest.mark.min_version("2024.12")
7574
@given(

0 commit comments

Comments
 (0)