Skip to content

Commit 49b24e5

Browse files
committed
Merge branch 'main' into nogil
2 parents ae579ff + 13d4a70 commit 49b24e5

File tree

16 files changed

+359
-107
lines changed

16 files changed

+359
-107
lines changed

pixi.lock

Lines changed: 308 additions & 54 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ array-api-extra = { path = ".", editable = true }
5757
typing-extensions = ">=4.13.2"
5858
pre-commit = ">=4.2.0"
5959
pylint = ">=3.3.7"
60-
basedmypy = ">=2.10.0"
60+
mypy = ">=1.16.0"
6161
basedpyright = ">=1.29.2"
6262
numpydoc = ">=1.8.0,<2"
6363
# import dependencies for mypy:
@@ -236,16 +236,17 @@ python_version = "3.10"
236236
warn_unused_configs = true
237237
strict = true
238238
enable_error_code = ["ignore-without-code", "truthy-bool"]
239-
# https://github.com/data-apis/array-api-typing
240-
disallow_any_expr = false
241-
# false positives with input validation
242-
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
239+
disable_error_code = ["no-any-return"]
243240

244241
[[tool.mypy.overrides]]
245242
# slow or unavailable on Windows; do not add to the lint env
246243
module = ["cupy.*", "jax.*", "sparse.*", "torch.*"]
247244
ignore_missing_imports = true
248245

246+
[[tool.mypy.overrides]]
247+
module = ["tests/*"]
248+
disable_error_code = ["no-untyped-def"] # test(...) without -> None
249+
249250
# pyright
250251

251252
[tool.basedpyright]

src/array_api_extra/_lib/_at.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class _AtOp(Enum):
3737
MAX = "max"
3838

3939
# @override from Python 3.12
40-
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride]
40+
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride]
4141
"""
4242
Return string representation (useful for pytest logs).
4343

src/array_api_extra/_lib/_backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def like(self, *others: Backend) -> bool: # numpydoc ignore=PR01,RT01
4545
"""Check if this backend uses the same module as others."""
4646
return any(self.modname == other.modname for other in others)
4747

48-
def pytest_param(self) -> Any: # type: ignore[explicit-any]
48+
def pytest_param(self) -> Any:
4949
"""
5050
Backend as a pytest parameter
5151

src/array_api_extra/_lib/_funcs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
@overload
37-
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
37+
def apply_where( # numpydoc ignore=GL08
3838
cond: Array,
3939
args: Array | tuple[Array, ...],
4040
f1: Callable[..., Array],
@@ -46,7 +46,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
4646

4747

4848
@overload
49-
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
49+
def apply_where( # numpydoc ignore=GL08
5050
cond: Array,
5151
args: Array | tuple[Array, ...],
5252
f1: Callable[..., Array],
@@ -57,7 +57,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
5757
) -> Array: ...
5858

5959

60-
def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
60+
def apply_where( # numpydoc ignore=PR01,PR02
6161
cond: Array,
6262
args: Array | tuple[Array, ...],
6363
f1: Callable[..., Array],
@@ -143,7 +143,7 @@ def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
143143
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
144144

145145

146-
def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
146+
def _apply_where( # numpydoc ignore=PR01,RT01
147147
cond: Array,
148148
f1: Callable[..., Array],
149149
f2: Callable[..., Array] | None,
@@ -813,8 +813,7 @@ def pad(
813813
else:
814814
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
815815

816-
# https://github.com/python/typeshed/issues/13376
817-
slices: list[slice] = [] # type: ignore[explicit-any]
816+
slices: list[slice] = []
818817
newshape: list[int] = []
819818
for ax, w_tpl in enumerate(pad_width_seq):
820819
if len(w_tpl) != 2:
@@ -826,6 +825,7 @@ def pad(
826825
if w_tpl[0] == 0 and w_tpl[1] == 0:
827826
sl = slice(None, None, None)
828827
else:
828+
stop: int | None
829829
start, stop = w_tpl
830830
stop = None if stop == 0 else -stop
831831

src/array_api_extra/_lib/_lazy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
from numpy.typing import ArrayLike
2424

25-
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[explicit-any]
25+
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic
2626
else:
2727
# Sphinx hack
2828
NumPyObject = Any
@@ -31,7 +31,7 @@
3131

3232

3333
@overload
34-
def lazy_apply( # type: ignore[decorated-any, valid-type]
34+
def lazy_apply( # type: ignore[valid-type]
3535
func: Callable[P, Array | ArrayLike],
3636
*args: Array | complex | None,
3737
shape: tuple[int | None, ...] | None = None,
@@ -43,7 +43,7 @@ def lazy_apply( # type: ignore[decorated-any, valid-type]
4343

4444

4545
@overload
46-
def lazy_apply( # type: ignore[decorated-any, valid-type]
46+
def lazy_apply( # type: ignore[valid-type]
4747
func: Callable[P, Sequence[Array | ArrayLike]],
4848
*args: Array | complex | None,
4949
shape: Sequence[tuple[int | None, ...]],
@@ -313,7 +313,7 @@ def _is_jax_jit_enabled(xp: ModuleType) -> bool: # numpydoc ignore=PR01,RT01
313313
return True
314314

315315

316-
def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
316+
def _lazy_apply_wrapper( # numpydoc ignore=PR01,RT01
317317
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
318318
as_numpy: bool,
319319
multi_output: bool,
@@ -331,7 +331,7 @@ def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,R
331331

332332
# On Dask, @wraps causes the graph key to contain the wrapped function's name
333333
@wraps(func)
334-
def wrapper( # type: ignore[decorated-any,explicit-any]
334+
def wrapper(
335335
*args: Array | complex | None, **kwargs: Any
336336
) -> tuple[Array, ...]: # numpydoc ignore=GL08
337337
args_list = []
@@ -343,7 +343,7 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
343343
if as_numpy:
344344
import numpy as np
345345

346-
arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast] # noqa: PLW2901
346+
arg = cast(Array, np.asarray(arg)) # pyright: ignore[reportInvalidCast] # noqa: PLW2901
347347
args_list.append(arg)
348348
assert device is not None
349349

src/array_api_extra/_lib/_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _is_materializable(x: Array) -> bool:
110110
return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
111111

112112

113-
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
113+
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
114114
"""
115115
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
116116
"""

src/array_api_extra/_lib/_utils/_compat.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def is_torch_array(x: object, /) -> TypeGuard[Array]: ...
3636
def is_lazy_array(x: object, /) -> TypeGuard[Array]: ...
3737
def is_writeable_array(x: object, /) -> TypeGuard[Array]: ...
3838
def size(x: Array, /) -> int | None: ...
39-
def to_device( # type: ignore[explicit-any]
39+
def to_device(
4040
x: Array,
4141
device: Device, # pylint: disable=redefined-outer-name
4242
/,

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def asarrays(
210210
float: ("real floating", "complex floating"),
211211
complex: "complex floating",
212212
}
213-
kind = same_dtype[type(cast(complex, b))] # type: ignore[index]
213+
kind = same_dtype[type(cast(complex, b))]
214214
if xp.isdtype(a.dtype, kind):
215215
xb = xp.asarray(b, dtype=a.dtype)
216216
else:
@@ -458,7 +458,7 @@ def persistent_id(
458458
return instances, (f.getvalue(), *rest)
459459

460460

461-
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any]
461+
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any:
462462
"""
463463
Reverse of ``pickle_flatten``.
464464
@@ -521,7 +521,7 @@ def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
521521
self.obj = obj
522522

523523
@classmethod
524-
def _register(cls): # numpydoc ignore=SS06
524+
def _register(cls) -> None: # numpydoc ignore=SS06
525525
"""
526526
Register upon first use instead of at import time, to avoid
527527
globally importing JAX.
@@ -583,7 +583,7 @@ def f(x: Array, y: float, plus: bool) -> Array:
583583
import jax
584584

585585
@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator]
586-
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
586+
def inner( # numpydoc ignore=GL08
587587
wargs: _AutoJITWrapper[Any],
588588
) -> _AutoJITWrapper[T]:
589589
args, kwargs = wargs.obj

src/array_api_extra/_lib/_utils/_typing.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ class DType(Protocol): # pylint: disable=missing-class-docstring
9595
class Device(Protocol): # pylint: disable=missing-class-docstring
9696
pass
9797

98-
SetIndex: TypeAlias = ( # type: ignore[explicit-any]
98+
SetIndex: TypeAlias = (
9999
int | slice | EllipsisType | Array | tuple[int | slice | EllipsisType | Array, ...]
100100
)
101-
GetIndex: TypeAlias = ( # type: ignore[explicit-any]
101+
GetIndex: TypeAlias = (
102102
SetIndex | None | tuple[int | slice | EllipsisType | None | Array, ...]
103103
)
104104

src/array_api_extra/testing.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def override(func):
3636
P = ParamSpec("P")
3737
T = TypeVar("T")
3838

39-
_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[explicit-any]
39+
_ufuncs_tags: dict[object, dict[str, Any]] = {}
4040

4141

4242
class Deprecated(enum.Enum):
@@ -48,7 +48,7 @@ class Deprecated(enum.Enum):
4848
DEPRECATED = Deprecated.DEPRECATED
4949

5050

51-
def lazy_xp_function( # type: ignore[explicit-any]
51+
def lazy_xp_function(
5252
func: Callable[..., Any],
5353
*,
5454
allow_dask_compute: bool | int = False,
@@ -297,12 +297,12 @@ def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
297297
# Enable using patch_lazy_xp_function not as a context manager
298298
temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
299299

300-
def iter_tagged() -> ( # type: ignore[explicit-any]
300+
def iter_tagged() -> (
301301
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
302302
):
303303
for mod in mods:
304304
for name, func in mod.__dict__.items():
305-
tags: dict[str, Any] | None = None # type: ignore[explicit-any]
305+
tags: dict[str, Any] | None = None
306306
with contextlib.suppress(AttributeError):
307307
tags = func._lazy_xp_function # pylint: disable=protected-access
308308
if tags is None:
@@ -366,15 +366,17 @@ def __init__(self, max_count: int, msg: str): # numpydoc ignore=GL08
366366
self.msg = msg
367367

368368
@override
369-
def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
369+
def __call__(
370+
self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any
371+
) -> Any: # numpydoc ignore=GL08
370372
import dask
371373

372374
self.count += 1
373375
# This should yield a nice traceback to the
374376
# offending line in the user's code
375377
assert self.count <= self.max_count, self.msg
376378

377-
return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
379+
return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]
378380

379381

380382
def _dask_wrap(
@@ -407,7 +409,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
407409
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
408410
# not work on scheduler='distributed', as it would not block.
409411
arrays, rest = pickle_flatten(out, da.Array)
410-
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
412+
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
411413
return pickle_unflatten(arrays, rest) # pyright: ignore[reportUnknownArgumentType]
412414

413415
return wrapper

tests/test_at.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def at_op(
4040
just a workaround for when one wants to apply jax.jit to `at()` directly,
4141
which is not a common use case.
4242
"""
43-
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any]
43+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value))
4444
return meth(y, copy=copy, xp=xp)
4545

4646

@@ -157,7 +157,7 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
157157
"""
158158
x = xp.asarray([1.0, 10.0, 20.0])
159159
expect_copy = not is_writeable_array(x)
160-
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value)) # type: ignore[explicit-any]
160+
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value))
161161
with assert_copy(x, None, expect_copy):
162162
_ = meth(2.0)
163163

@@ -166,7 +166,7 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
166166
# even if the arrays are writeable.
167167
expect_copy = not is_writeable_array(x) or library is Backend.DASK
168168
idx = xp.asarray([True, True, False])
169-
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any]
169+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value))
170170
with assert_copy(x, None, expect_copy):
171171
_ = meth(2.0)
172172

@@ -178,7 +178,7 @@ def test_copy_invalid():
178178

179179

180180
def test_xp():
181-
a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast]
181+
a = cast(Array, np.asarray([1, 2, 3])) # pyright: ignore[reportInvalidCast]
182182
_ = at(a, 0).set(4, xp=np)
183183
_ = at(a, 0).add(4, xp=np)
184184
_ = at(a, 0).subtract(4, xp=np)
@@ -190,7 +190,7 @@ def test_xp():
190190

191191

192192
def test_alternate_index_syntax():
193-
xp = cast(ModuleType, np) # pyright: ignore[reportInvalidCast]
193+
xp = cast(ModuleType, np) # type: ignore[redundant-cast] # pyright: ignore[reportInvalidCast]
194194
a = cast(Array, xp.asarray([1, 2, 3]))
195195
xp_assert_equal(at(a, 0).set(4, copy=True), xp.asarray([4, 2, 3]))
196196
xp_assert_equal(at(a)[0].set(4, copy=True), xp.asarray([4, 2, 3]))

tests/test_funcs.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@
3434
from array_api_extra._lib._utils._typing import Array, Device
3535
from array_api_extra.testing import lazy_xp_function
3636

37-
# some xp backends are untyped
38-
# mypy: disable-error-code=no-untyped-def
39-
4037
lazy_xp_function(apply_where)
4138
lazy_xp_function(atleast_nd)
4239
lazy_xp_function(cov)
@@ -211,7 +208,7 @@ def test_device(self, xp: ModuleType, device: Device):
211208
p=st.floats(min_value=0, max_value=1),
212209
data=st.data(),
213210
)
214-
def test_hypothesis( # type: ignore[explicit-any,decorated-any]
211+
def test_hypothesis(
215212
self,
216213
n_arrays: int,
217214
rng_seed: int,

tests/test_helpers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
def override(func):
3333
return func
3434

35-
# mypy: disable-error-code=no-untyped-usage
3635

3736
T = TypeVar("T")
3837

@@ -387,7 +386,7 @@ def test_static_hashable(self, jnp: ModuleType):
387386
"""Static argument/return value is hashable, but not serializable"""
388387

389388
class C:
390-
def __reduce__(self) -> object: # type: ignore[explicit-override,override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
389+
def __reduce__(self) -> object: # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
391390
raise Exception()
392391

393392
@jax_autojit
@@ -399,12 +398,12 @@ def f(x: object) -> object:
399398
assert out is inp
400399

401400
# Serializable opaque input contains non-serializable object plus array
402-
inp = Wrapper((C(), jnp.asarray([1, 2])))
403-
out = f(inp)
401+
winp = Wrapper((C(), jnp.asarray([1, 2])))
402+
out = f(winp)
404403
assert isinstance(out, Wrapper)
405-
assert out.x[0] is inp.x[0]
406-
assert out.x[1] is not inp.x[1]
407-
xp_assert_equal(out.x[1], inp.x[1]) # pyright: ignore[reportUnknownArgumentType]
404+
assert out.x[0] is winp.x[0]
405+
assert out.x[1] is not winp.x[1]
406+
xp_assert_equal(out.x[1], winp.x[1]) # pyright: ignore[reportUnknownArgumentType]
408407

409408
def test_arraylikes_are_static(self):
410409
pytest.importorskip("jax")

tests/test_lazy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def f(x: Array) -> Array:
141141
xp = array_namespace(x)
142142
return xp.sum(x, axis=0) + x
143143

144-
x_np = cast(Array, np.arange(15).reshape(5, 3)) # type: ignore[bad-cast] # pyright: ignore[reportInvalidCast]
144+
x_np = cast(Array, np.arange(15).reshape(5, 3)) # pyright: ignore[reportInvalidCast]
145145
expect = da.asarray(f(x_np))
146146
x_da = da.asarray(x_np).rechunk(3)
147147

0 commit comments

Comments
 (0)