From 49ebc0c90bca2ad3721fc5f10cc894a6c68eb5d7 Mon Sep 17 00:00:00 2001 From: Tyler Coles Date: Thu, 25 Apr 2024 16:09:01 -0700 Subject: [PATCH] Patch check_ndarray so that it is somewhat safer to use with structural types. When using the dtype argument, you should always wrap the argument in a list, even if there's only one valid dtype. (And especially if that dtype is a structural type.) --- epymorph/engine/context.py | 2 +- epymorph/geo/spec.py | 2 +- epymorph/initializer.py | 6 +++--- epymorph/test/util_test.py | 33 ++++++++++++++++++++++++++++++--- epymorph/util.py | 8 ++++---- 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/epymorph/engine/context.py b/epymorph/engine/context.py index 1d48e41b..20ece347 100644 --- a/epymorph/engine/context.py +++ b/epymorph/engine/context.py @@ -258,7 +258,7 @@ def _initialize(ctx: RumeContext) -> NDArray[SimDType]: try: _, N, C, _ = ctx.dim.TNCE - check_ndarray(result, SimDType, (N, C)) + check_ndarray(result, [SimDType], (N, C)) except NumpyTypeError as e: raise InitException(f"Invalid return type from '{init_name}'") from e return result diff --git a/epymorph/geo/spec.py b/epymorph/geo/spec.py index 07ae876f..cb157ef5 100644 --- a/epymorph/geo/spec.py +++ b/epymorph/geo/spec.py @@ -194,7 +194,7 @@ def validate_geo_values(spec: GeoSpec, values: dict[str, NDArray]) -> None: try: v = values[a.name] expected_shape = a.shape.as_tuple(N, T) - check_ndarray(v, dtype=a.dtype, shape=expected_shape) + check_ndarray(v, dtype=[a.dtype], shape=expected_shape) except KeyError: msg = f"Geo is missing values for attribute '{a.name}'." attribute_errors.append(msg) diff --git a/epymorph/initializer.py b/epymorph/initializer.py index 56e2de5a..cb01ae39 100644 --- a/epymorph/initializer.py +++ b/epymorph/initializer.py @@ -90,7 +90,7 @@ def explicit(ctx: InitContext, initials: NDArray[SimDType]) -> NDArray[SimDType] """ _, N, C, _ = ctx.dim.TNCE try: - check_ndarray(initials, SimDType, (N, C)) + check_ndarray(initials, [SimDType], (N, C)) except NumpyTypeError as e: raise InitException.for_arg('initials') from e return initials.copy() @@ -131,7 +131,7 @@ def indexed_locations(ctx: InitContext, selection: NDArray[np.intp], seed_size: N = ctx.dim.nodes try: - check_ndarray(selection, np.intp, dimensions=1) + check_ndarray(selection, [np.intp], dimensions=1) except NumpyTypeError as e: raise InitException.for_arg('selection') from e if not np.all((-N < selection) & (selection < N)): @@ -176,7 +176,7 @@ def labeled_locations(ctx: InitContext, labels: NDArray[np.str_], seed_size: int - `seed_size` the number of individuals to infect in total """ try: - check_ndarray(labels, np.str_) + check_ndarray(labels, [np.str_]) except NumpyTypeError as e: raise InitException.for_arg('label') from e diff --git a/epymorph/test/util_test.py b/epymorph/test/util_test.py index 92980bae..420f8b03 100644 --- a/epymorph/test/util_test.py +++ b/epymorph/test/util_test.py @@ -46,9 +46,9 @@ def test_check_ndarray_01(self): # Test is just to make sure none of these raise NumpyTypeError arr = np.array([1, 2, 3], dtype=np.int64) util.check_ndarray(arr) - util.check_ndarray(arr, dtype=np.int64) + util.check_ndarray(arr, dtype=[np.int64]) util.check_ndarray(arr, shape=(3,)) - util.check_ndarray(arr, np.int64, (3,)) + util.check_ndarray(arr, [np.int64], (3,)) util.check_ndarray(arr, [np.int64], (3,)) util.check_ndarray(arr, [np.int64, np.float64], (3,)) util.check_ndarray(arr, [np.float64, np.int64], (3,)) @@ -80,6 +80,33 @@ def test_check_ndarray_03(self): with self.assertRaises(util.NumpyTypeError): util.check_ndarray(arr, shape=(3, 4, 1)) with self.assertRaises(util.NumpyTypeError): - util.check_ndarray(arr, dtype=np.str_) + util.check_ndarray(arr, dtype=[np.str_]) with self.assertRaises(util.NumpyTypeError): util.check_ndarray(arr, dimensions=3) + + def test_check_ndarray_04(self): + # check_ndarray is poorly constructed to support structural types. + # + # Its original signature accepted a single DType or lists of DTypes intended to be joined with a logical "or". + # However in this scheme it's hard to distinguish structural types (which are also lists). + # With enough effort we could get this to work, but really this is a pretty unnecessary feature, so not worth it. + # + # For the time-being we're going to patch this by just requiring the dtype argument is passed as a list, + # even when there's only one option. This way structural types are wrapped in an extra list and we can drop + # the list detection logic. (Long term I want to phase out check_ndarray and its over-flexibility.) + # + # What makes this patch unfortunate is that if you accidentally pass a structural type without wrapping it in a list + # type-checking succeeds, but will fail at runtime: + # + # util.check_ndarray(arr, dtype=lnglat) --> ERROR: "data type 'longitude' not understood" + + lnglat = [('longitude', np.float64), ('latitude', np.float64)] + arr = np.array([(1.0, 2.0), (3.0, 4.0)], dtype=lnglat) + # Doesn't raise... + util.check_ndarray(arr, dtype=[lnglat]) + util.check_ndarray(arr, dtype=[np.float64, lnglat]) + # Does raise... + with self.assertRaises(util.NumpyTypeError): + util.check_ndarray(arr, dtype=[np.str_]) + with self.assertRaises(util.NumpyTypeError): + util.check_ndarray(arr, dtype=[np.float64]) diff --git a/epymorph/util.py b/epymorph/util.py index 66bd1c39..b28b8cac 100644 --- a/epymorph/util.py +++ b/epymorph/util.py @@ -236,7 +236,7 @@ def dtype_name(d: np.dtype) -> str: def check_ndarray( value: Any, - dtype: DTypeLike | list[DTypeLike] | None = None, + dtype: list[DTypeLike] | None = None, shape: tuple[int, ...] | list[tuple[int, ...]] | None = None, dimensions: int | list[int] | None = None, ) -> None: @@ -255,7 +255,7 @@ def check_ndarray( msg = f"Not a numpy shape match: got {value.shape}, expected {shape}" raise NumpyTypeError(msg) if dtype is not None: - npdtypes = [np.dtype(x) for x in as_list(dtype)] + npdtypes = [np.dtype(x) for x in dtype] is_subtype = map(lambda x: np.issubdtype(value.dtype, x), npdtypes) if not any(is_subtype): if len(npdtypes) == 1: @@ -305,7 +305,7 @@ def publish(self, event: T) -> None: for subscriber in self._subscribers: subscriber(event) - @property + @ property def has_subscribers(self) -> bool: """True if at least one listener is subscribed to this event.""" return len(self._subscribers) > 0 @@ -334,7 +334,7 @@ def unsubscribe(self) -> None: self._unsubscribers.clear() -@contextmanager +@ contextmanager def subscriptions() -> Generator[Subscriber, None, None]: """ Manage a subscription context, where all subscriptions added through the returned Subscriber