Skip to content

Commit

Permalink
Patch check_ndarray... (#105)
Browse files Browse the repository at this point in the history
* Patch check_ndarray so that it is somewhat safer to use with structured types (NumPy structured dtypes).
When using the dtype argument, you should always wrap the argument in a list, even if there's only one valid dtype.
(And especially if that dtype is a structured type.)
  • Loading branch information
JavadocMD authored Apr 25, 2024
1 parent 05c3f2e commit 7433936
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 10 deletions.
2 changes: 1 addition & 1 deletion epymorph/engine/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _initialize(ctx: RumeContext) -> NDArray[SimDType]:

try:
_, N, C, _ = ctx.dim.TNCE
check_ndarray(result, SimDType, (N, C))
check_ndarray(result, [SimDType], (N, C))
except NumpyTypeError as e:
raise InitException(f"Invalid return type from '{init_name}'") from e
return result
2 changes: 1 addition & 1 deletion epymorph/geo/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def validate_geo_values(spec: GeoSpec, values: dict[str, NDArray]) -> None:
try:
v = values[a.name]
expected_shape = a.shape.as_tuple(N, T)
check_ndarray(v, dtype=a.dtype, shape=expected_shape)
check_ndarray(v, dtype=[a.dtype], shape=expected_shape)
except KeyError:
msg = f"Geo is missing values for attribute '{a.name}'."
attribute_errors.append(msg)
Expand Down
6 changes: 3 additions & 3 deletions epymorph/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def explicit(ctx: InitContext, initials: NDArray[SimDType]) -> NDArray[SimDType]
"""
_, N, C, _ = ctx.dim.TNCE
try:
check_ndarray(initials, SimDType, (N, C))
check_ndarray(initials, [SimDType], (N, C))
except NumpyTypeError as e:
raise InitException.for_arg('initials') from e
return initials.copy()
Expand Down Expand Up @@ -131,7 +131,7 @@ def indexed_locations(ctx: InitContext, selection: NDArray[np.intp], seed_size:
N = ctx.dim.nodes

try:
check_ndarray(selection, np.intp, dimensions=1)
check_ndarray(selection, [np.intp], dimensions=1)
except NumpyTypeError as e:
raise InitException.for_arg('selection') from e
if not np.all((-N < selection) & (selection < N)):
Expand Down Expand Up @@ -176,7 +176,7 @@ def labeled_locations(ctx: InitContext, labels: NDArray[np.str_], seed_size: int
- `seed_size` the number of individuals to infect in total
"""
try:
check_ndarray(labels, np.str_)
check_ndarray(labels, [np.str_])
except NumpyTypeError as e:
raise InitException.for_arg('label') from e

Expand Down
33 changes: 30 additions & 3 deletions epymorph/test/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def test_check_ndarray_01(self):
# Test is just to make sure none of these raise NumpyTypeError
arr = np.array([1, 2, 3], dtype=np.int64)
util.check_ndarray(arr)
util.check_ndarray(arr, dtype=np.int64)
util.check_ndarray(arr, dtype=[np.int64])
util.check_ndarray(arr, shape=(3,))
util.check_ndarray(arr, np.int64, (3,))
util.check_ndarray(arr, [np.int64], (3,))
util.check_ndarray(arr, [np.int64], (3,))
util.check_ndarray(arr, [np.int64, np.float64], (3,))
util.check_ndarray(arr, [np.float64, np.int64], (3,))
Expand Down Expand Up @@ -80,6 +80,33 @@ def test_check_ndarray_03(self):
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, shape=(3, 4, 1))
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=np.str_)
util.check_ndarray(arr, dtype=[np.str_])
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dimensions=3)

def test_check_ndarray_04(self):
# Verifies that check_ndarray handles a NumPy structured dtype when the dtype argument is wrapped in a list.
#
# check_ndarray is poorly constructed to support structured types.
#
# Its original signature accepted a single DType or a list of DTypes intended to be joined with a logical "or".
# However, in this scheme it's hard to distinguish structured types (which are also written as lists).
# With enough effort we could get this to work, but really this is a pretty unnecessary feature, so not worth it.
#
# For the time being we're going to patch this by requiring that the dtype argument be passed as a list,
# even when there's only one option. This way structured types are wrapped in an extra list and we can drop
# the list detection logic. (Long term I want to phase out check_ndarray and its over-flexibility.)
#
# What makes this patch unfortunate is that if you accidentally pass a structured type without wrapping it
# in a list, type-checking succeeds but the call fails at runtime:
#
# util.check_ndarray(arr, dtype=lnglat) --> ERROR: "data type 'longitude' not understood"

lnglat = [('longitude', np.float64), ('latitude', np.float64)]
arr = np.array([(1.0, 2.0), (3.0, 4.0)], dtype=lnglat)
# Doesn't raise: the structured dtype is one of the listed options.
util.check_ndarray(arr, dtype=[lnglat])
util.check_ndarray(arr, dtype=[np.float64, lnglat])
# Does raise: none of the listed options is the array's structured dtype.
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=[np.str_])
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=[np.float64])
4 changes: 2 additions & 2 deletions epymorph/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def dtype_name(d: np.dtype) -> str:

def check_ndarray(
value: Any,
dtype: DTypeLike | list[DTypeLike] | None = None,
dtype: list[DTypeLike] | None = None,
shape: tuple[int, ...] | list[tuple[int, ...]] | None = None,
dimensions: int | list[int] | None = None,
) -> None:
Expand All @@ -255,7 +255,7 @@ def check_ndarray(
msg = f"Not a numpy shape match: got {value.shape}, expected {shape}"
raise NumpyTypeError(msg)
if dtype is not None:
npdtypes = [np.dtype(x) for x in as_list(dtype)]
npdtypes = [np.dtype(x) for x in dtype]
is_subtype = map(lambda x: np.issubdtype(value.dtype, x), npdtypes)
if not any(is_subtype):
if len(npdtypes) == 1:
Expand Down

0 comments on commit 7433936

Please sign in to comment.