17
17
)
18
18
def test_take (x , data ):
19
19
# TODO:
20
- # * negative axis
21
20
# * negative indices
22
21
# * different dtypes for indices
23
22
24
23
# axis is optional but only if x.ndim == 1
25
- _axis_st = st .integers (0 , max (x .ndim - 1 , 0 ))
24
+ _axis_st = st .integers (- x . ndim , max (x .ndim - 1 , 0 ))
26
25
if x .ndim == 1 :
27
26
kw = data .draw (hh .kwargs (axis = _axis_st ))
28
27
else :
@@ -32,6 +31,7 @@ def test_take(x, data):
32
31
st .lists (st .integers (0 , x .shape [axis ] - 1 ), min_size = 1 , unique = True ),
33
32
label = "_indices" ,
34
33
)
34
+ n_axis = axis if axis >= 0 else x .ndim + axis
35
35
indices = xp .asarray (_indices , dtype = dh .default_int )
36
36
note (f"{ indices = } " )
37
37
@@ -41,15 +41,15 @@ def test_take(x, data):
41
41
ph .assert_shape (
42
42
"take" ,
43
43
out_shape = out .shape ,
44
- expected = x .shape [:axis ] + (len (_indices ),) + x .shape [axis + 1 :],
44
+ expected = x .shape [:n_axis ] + (len (_indices ),) + x .shape [n_axis + 1 :],
45
45
kw = dict (
46
46
x = x ,
47
47
indices = indices ,
48
48
axis = axis ,
49
49
),
50
50
)
51
51
out_indices = sh .ndindex (out .shape )
52
- axis_indices = list (sh .axis_ndindex (x .shape , axis ))
52
+ axis_indices = list (sh .axis_ndindex (x .shape , n_axis ))
53
53
for axis_idx in axis_indices :
54
54
f_axis_idx = sh .fmt_idx ("x" , axis_idx )
55
55
for i in _indices :
@@ -69,7 +69,6 @@ def test_take(x, data):
69
69
next (out_indices )
70
70
71
71
72
-
73
72
@pytest .mark .unvectorized
74
73
@pytest .mark .min_version ("2024.12" )
75
74
@given (
0 commit comments