17
17
)
18
18
def test_take (x , data ):
19
19
# TODO:
20
- # * negative axis
21
20
# * negative indices
22
21
# * different dtypes for indices
23
- axis = data .draw (st .integers (0 , max (x .ndim - 1 , 0 )), label = "axis" )
22
+
23
+ # axis is optional but only if x.ndim == 1
24
+ _axis_st = st .integers (- x .ndim , max (x .ndim - 1 , 0 ))
25
+ if x .ndim == 1 :
26
+ kw = data .draw (hh .kwargs (axis = _axis_st ))
27
+ else :
28
+ kw = {"axis" : data .draw (_axis_st )}
29
+ axis = kw .get ("axis" , 0 )
24
30
_indices = data .draw (
25
31
st .lists (st .integers (0 , x .shape [axis ] - 1 ), min_size = 1 , unique = True ),
26
32
label = "_indices" ,
27
33
)
34
+ n_axis = axis if axis >= 0 else x .ndim + axis
28
35
indices = xp .asarray (_indices , dtype = dh .default_int )
29
36
note (f"{ indices = } " )
30
37
31
- out = xp .take (x , indices , axis = axis )
38
+ out = xp .take (x , indices , ** kw )
32
39
33
40
ph .assert_dtype ("take" , in_dtype = x .dtype , out_dtype = out .dtype )
34
41
ph .assert_shape (
35
42
"take" ,
36
43
out_shape = out .shape ,
37
- expected = x .shape [:axis ] + (len (_indices ),) + x .shape [axis + 1 :],
44
+ expected = x .shape [:n_axis ] + (len (_indices ),) + x .shape [n_axis + 1 :],
38
45
kw = dict (
39
46
x = x ,
40
47
indices = indices ,
41
48
axis = axis ,
42
49
),
43
50
)
44
51
out_indices = sh .ndindex (out .shape )
45
- axis_indices = list (sh .axis_ndindex (x .shape , axis ))
52
+ axis_indices = list (sh .axis_ndindex (x .shape , n_axis ))
46
53
for axis_idx in axis_indices :
47
54
f_axis_idx = sh .fmt_idx ("x" , axis_idx )
48
55
for i in _indices :
@@ -62,7 +69,6 @@ def test_take(x, data):
62
69
next (out_indices )
63
70
64
71
65
-
66
72
@pytest .mark .unvectorized
67
73
@pytest .mark .min_version ("2024.12" )
68
74
@given (
0 commit comments