Skip to content

Commit 8507174

Browse files
authored
MAINT: bump to sparse >=0.17 (#318)
1 parent fb03063 commit 8507174

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -322,26 +322,28 @@ def capabilities(
322322
dict
323323
Capabilities of the namespace.
324324
"""
325-
if is_pydata_sparse_namespace(xp):
326-
# No __array_namespace_info__(); no indexing by sparse arrays
327-
return {
328-
"boolean indexing": False,
329-
"data-dependent shapes": True,
330-
"max dimensions": None,
331-
}
332325
out = xp.__array_namespace_info__().capabilities()
333-
if is_jax_namespace(xp) and out["boolean indexing"]:
334-
# FIXME https://github.com/jax-ml/jax/issues/27418
335-
# Fixed in jax >=0.6.0
336-
out = out.copy()
337-
out["boolean indexing"] = False
338-
if is_torch_namespace(xp):
326+
if is_pydata_sparse_namespace(xp):
327+
if out["boolean indexing"]:
328+
# FIXME https://github.com/pydata/sparse/issues/876
329+
# boolean indexing is supported, but not when the index is a sparse array.
330+
# boolean indexing by list or numpy array is not part of the Array API.
331+
out = out.copy()
332+
out["boolean indexing"] = False
333+
elif is_jax_namespace(xp):
334+
if out["boolean indexing"]: # pragma: no cover
335+
# Backwards compatibility with jax <0.6.0
336+
# https://github.com/jax-ml/jax/issues/27418
337+
out = out.copy()
338+
out["boolean indexing"] = False
339+
elif is_torch_namespace(xp):
339340
# FIXME https://github.com/data-apis/array-api/issues/945
340341
device = xp.get_default_device() if device is None else xp.device(device)
341342
if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
342343
out = out.copy()
343344
out["boolean indexing"] = False
344345
out["data-dependent shapes"] = False
346+
345347
return out
346348

347349

tests/test_funcs.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def test_complex(self, xp: ModuleType):
416416
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
417417
xp_assert_close(actual, expect)
418418

419-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="matmul with nan fillvalue")
419+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="sparse#877")
420420
def test_empty(self, xp: ModuleType):
421421
with warnings.catch_warnings(record=True):
422422
warnings.simplefilter("always", RuntimeWarning)
@@ -451,7 +451,7 @@ def test_xp(self, xp: ModuleType):
451451
)
452452

453453

454-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="backend doesn't have arange")
454+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
455455
class TestOneHot:
456456
@pytest.mark.parametrize("n_dim", range(4))
457457
@pytest.mark.parametrize("num_classes", [1, 3, 10])
@@ -816,7 +816,7 @@ def test_bool_dtype(self, xp: ModuleType):
816816
isclose(xp.asarray(True), b, atol=1), xp.asarray([True, True, True])
817817
)
818818

819-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
819+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
820820
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
821821
def test_none_shape(self, xp: ModuleType):
822822
a = xp.asarray([1, 5, 0])
@@ -825,7 +825,7 @@ def test_none_shape(self, xp: ModuleType):
825825
a = a[a < 5]
826826
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
827827

828-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
828+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
829829
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="unknown shape")
830830
def test_none_shape_bool(self, xp: ModuleType):
831831
a = xp.asarray([True, True, False])
@@ -1141,10 +1141,10 @@ def test_xp(self, xp: ModuleType):
11411141

11421142

11431143
class TestSinc:
1144-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no linspace")
11451144
def test_simple(self, xp: ModuleType):
11461145
xp_assert_equal(sinc(xp.asarray(0.0)), xp.asarray(1.0))
1147-
w = sinc(xp.linspace(-1, 1, 100))
1146+
x = xp.asarray(np.linspace(-1, 1, 100))
1147+
w = sinc(x)
11481148
# check symmetry
11491149
xp_assert_close(w, xp.flip(w, axis=0))
11501150

@@ -1153,11 +1153,12 @@ def test_dtype(self, xp: ModuleType, x: int | complex):
11531153
with pytest.raises(ValueError, match="real floating data type"):
11541154
_ = sinc(xp.asarray(x))
11551155

1156-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange")
11571156
def test_3d(self, xp: ModuleType):
1158-
x = xp.reshape(xp.arange(18, dtype=xp.float64), (3, 3, 2))
1159-
expected = xp.zeros((3, 3, 2), dtype=xp.float64)
1160-
expected = at(expected)[0, 0, 0].set(1.0)
1157+
x = np.arange(18, dtype=np.float64).reshape((3, 3, 2))
1158+
expected = np.zeros_like(x)
1159+
expected[0, 0, 0] = 1
1160+
x = xp.asarray(x)
1161+
expected = xp.asarray(expected)
11611162
xp_assert_close(sinc(x), expected, atol=1e-15)
11621163

11631164
def test_device(self, xp: ModuleType, device: Device):

tests/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def override(func):
4040
lazy_xp_function(in1d, jax_jit=False)
4141

4242

43-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse")
43+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse")
4444
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="no unique_inverse")
4545
class TestIn1D:
4646
# cover both code paths

tests/test_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_assert_less(self, xp: ModuleType):
140140
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
141141

142142
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
143-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
143+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
144144
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
145145
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
146146
"""On Dask and other lazy backends, test that a shape with NaN's or None's

0 commit comments

Comments
 (0)