Skip to content

Commit

Permalink
Patch check_ndarray so that it is somewhat safer to use with structur…
Browse files Browse the repository at this point in the history
…al types.

When using the dtype argument, you should always wrap the argument in a list, even if there's only one valid dtype.
(And especially if that dtype is a structural type.)
  • Loading branch information
Tyler Coles committed Apr 25, 2024
1 parent 05c3f2e commit 49ebc0c
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion epymorph/engine/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _initialize(ctx: RumeContext) -> NDArray[SimDType]:

try:
_, N, C, _ = ctx.dim.TNCE
check_ndarray(result, SimDType, (N, C))
check_ndarray(result, [SimDType], (N, C))
except NumpyTypeError as e:
raise InitException(f"Invalid return type from '{init_name}'") from e
return result
2 changes: 1 addition & 1 deletion epymorph/geo/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def validate_geo_values(spec: GeoSpec, values: dict[str, NDArray]) -> None:
try:
v = values[a.name]
expected_shape = a.shape.as_tuple(N, T)
check_ndarray(v, dtype=a.dtype, shape=expected_shape)
check_ndarray(v, dtype=[a.dtype], shape=expected_shape)
except KeyError:
msg = f"Geo is missing values for attribute '{a.name}'."
attribute_errors.append(msg)
Expand Down
6 changes: 3 additions & 3 deletions epymorph/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def explicit(ctx: InitContext, initials: NDArray[SimDType]) -> NDArray[SimDType]
"""
_, N, C, _ = ctx.dim.TNCE
try:
check_ndarray(initials, SimDType, (N, C))
check_ndarray(initials, [SimDType], (N, C))
except NumpyTypeError as e:
raise InitException.for_arg('initials') from e
return initials.copy()
Expand Down Expand Up @@ -131,7 +131,7 @@ def indexed_locations(ctx: InitContext, selection: NDArray[np.intp], seed_size:
N = ctx.dim.nodes

try:
check_ndarray(selection, np.intp, dimensions=1)
check_ndarray(selection, [np.intp], dimensions=1)
except NumpyTypeError as e:
raise InitException.for_arg('selection') from e
if not np.all((-N < selection) & (selection < N)):
Expand Down Expand Up @@ -176,7 +176,7 @@ def labeled_locations(ctx: InitContext, labels: NDArray[np.str_], seed_size: int
- `seed_size` the number of individuals to infect in total
"""
try:
check_ndarray(labels, np.str_)
check_ndarray(labels, [np.str_])
except NumpyTypeError as e:
raise InitException.for_arg('label') from e

Expand Down
33 changes: 30 additions & 3 deletions epymorph/test/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def test_check_ndarray_01(self):
# Test is just to make sure none of these raise NumpyTypeError
arr = np.array([1, 2, 3], dtype=np.int64)
util.check_ndarray(arr)
util.check_ndarray(arr, dtype=np.int64)
util.check_ndarray(arr, dtype=[np.int64])
util.check_ndarray(arr, shape=(3,))
util.check_ndarray(arr, np.int64, (3,))
util.check_ndarray(arr, [np.int64], (3,))
util.check_ndarray(arr, [np.int64], (3,))
util.check_ndarray(arr, [np.int64, np.float64], (3,))
util.check_ndarray(arr, [np.float64, np.int64], (3,))
Expand Down Expand Up @@ -80,6 +80,33 @@ def test_check_ndarray_03(self):
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, shape=(3, 4, 1))
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=np.str_)
util.check_ndarray(arr, dtype=[np.str_])
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dimensions=3)

def test_check_ndarray_04(self):
# check_ndarray is poorly constructed to support structural types.
#
# Its original signature accepted a single DType or lists of DTypes intended to be joined with a logical "or".
# However in this scheme it's hard to distinguish structural types (which are also lists).
# With enough effort we could get this to work, but really this is a pretty unnecessary feature, so not worth it.
#
# For the time-being we're going to patch this by just requiring the dtype argument is passed as a list,
# even when there's only one option. This way structural types are wrapped in an extra list and we can drop
# the list detection logic. (Long term I want to phase out check_ndarray and its over-flexibility.)
#
# What makes this patch unfortunate is that if you accidentally pass a structural type without wrapping it in a list
# type-checking succeeds, but will fail at runtime:
#
# util.check_ndarray(arr, dtype=lnglat) --> ERROR: "data type 'longitude' not understood"

lnglat = [('longitude', np.float64), ('latitude', np.float64)]
arr = np.array([(1.0, 2.0), (3.0, 4.0)], dtype=lnglat)
# Doesn't raise...
util.check_ndarray(arr, dtype=[lnglat])
util.check_ndarray(arr, dtype=[np.float64, lnglat])
# Does raise...
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=[np.str_])
with self.assertRaises(util.NumpyTypeError):
util.check_ndarray(arr, dtype=[np.float64])
8 changes: 4 additions & 4 deletions epymorph/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def dtype_name(d: np.dtype) -> str:

def check_ndarray(
value: Any,
dtype: DTypeLike | list[DTypeLike] | None = None,
dtype: list[DTypeLike] | None = None,
shape: tuple[int, ...] | list[tuple[int, ...]] | None = None,
dimensions: int | list[int] | None = None,
) -> None:
Expand All @@ -255,7 +255,7 @@ def check_ndarray(
msg = f"Not a numpy shape match: got {value.shape}, expected {shape}"
raise NumpyTypeError(msg)
if dtype is not None:
npdtypes = [np.dtype(x) for x in as_list(dtype)]
npdtypes = [np.dtype(x) for x in dtype]
is_subtype = map(lambda x: np.issubdtype(value.dtype, x), npdtypes)
if not any(is_subtype):
if len(npdtypes) == 1:
Expand Down Expand Up @@ -305,7 +305,7 @@ def publish(self, event: T) -> None:
for subscriber in self._subscribers:
subscriber(event)

@property
@ property
def has_subscribers(self) -> bool:
"""True if at least one listener is subscribed to this event."""
return len(self._subscribers) > 0
Expand Down Expand Up @@ -334,7 +334,7 @@ def unsubscribe(self) -> None:
self._unsubscribers.clear()


@contextmanager
@ contextmanager
def subscriptions() -> Generator[Subscriber, None, None]:
"""
Manage a subscription context, where all subscriptions added through the returned Subscriber
Expand Down

0 comments on commit 49ebc0c

Please sign in to comment.