From b10aea52f5f5802be50071c23f6218be3d8f2c4a Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Sat, 15 Jul 2023 14:21:35 -0500 Subject: [PATCH 1/2] Fix test_take to make axis optional when ndim == 1 I didn't explicitly test axis=None because it's not clear to me that should actually be supported, given that that's the same as axis=0. --- array_api_tests/test_indexing_functions.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index a599d218..a5c3802c 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -20,7 +20,14 @@ def test_take(x, data): # * negative axis # * negative indices # * different dtypes for indices - axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") + + # axis is optional but only if x.ndim == 1 + _axis_st = st.integers(0, max(x.ndim - 1, 0)) + if x.ndim == 1: + kw = data.draw(hh.kwargs(axis=_axis_st)) + else: + kw = {"axis": data.draw(_axis_st)} + axis = kw.get("axis", 0) _indices = data.draw( st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), label="_indices", @@ -28,7 +35,7 @@ def test_take(x, data): indices = xp.asarray(_indices, dtype=dh.default_int) note(f"{indices=}") - out = xp.take(x, indices, axis=axis) + out = xp.take(x, indices, **kw) ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_shape( From 3bbe3d4ee31192d7eb78bd59b5e92bd9bef46a44 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 17 Jun 2025 10:21:41 +0200 Subject: [PATCH 2/2] ENH: take: test axis<0 --- array_api_tests/test_indexing_functions.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py index a5c3802c..7b8c8763 100644 --- a/array_api_tests/test_indexing_functions.py +++ b/array_api_tests/test_indexing_functions.py @@ -17,12 +17,11 @@ ) def test_take(x, data): # TODO: - # * negative axis # * negative indices # * different dtypes for indices # axis is optional but only if x.ndim == 1 - _axis_st = st.integers(0, max(x.ndim - 1, 0)) + _axis_st = st.integers(-x.ndim, max(x.ndim - 1, 0)) if x.ndim == 1: kw = data.draw(hh.kwargs(axis=_axis_st)) else: @@ -32,6 +31,7 @@ def test_take(x, data): st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), label="_indices", ) + n_axis = axis if axis>=0 else x.ndim + axis indices = xp.asarray(_indices, dtype=dh.default_int) note(f"{indices=}") @@ -41,7 +41,7 @@ def test_take(x, data): ph.assert_shape( "take", out_shape=out.shape, - expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :], + expected=x.shape[:n_axis] + (len(_indices),) + x.shape[n_axis + 1:], kw=dict( x=x, indices=indices, @@ -49,7 +49,7 @@ def test_take(x, data): ), ) out_indices = sh.ndindex(out.shape) - axis_indices = list(sh.axis_ndindex(x.shape, axis)) + axis_indices = list(sh.axis_ndindex(x.shape, n_axis)) for axis_idx in axis_indices: f_axis_idx = sh.fmt_idx("x", axis_idx) for i in _indices: @@ -69,7 +69,6 @@ def test_take(x, data): next(out_indices) - @pytest.mark.unvectorized @pytest.mark.min_version("2024.12") @given(