diff --git a/epymorph/data_type.py b/epymorph/data_type.py index 37dc2f3f..201f123b 100644 --- a/epymorph/data_type.py +++ b/epymorph/data_type.py @@ -2,7 +2,7 @@ Types for source data and attributes in epymorph. """ from datetime import date -from typing import Any, Sequence +from typing import Any import numpy as np from numpy.typing import DTypeLike, NDArray @@ -10,8 +10,13 @@ # Types for attribute declarations: # these are expressed as Python types for simplicity. +# NOTE: In epymorph, we express structured types as tuples-of-tuples; +# this way they're hashable, which is important for AttributeDef. +# However numpy expresses them as lists-of-tuples, so we have to convert; +# thankfully we had an infrastructure for this sort of thing already. + ScalarType = type[int | float | str | date] -StructType = Sequence[tuple[str, ScalarType]] +StructType = tuple[tuple[str, ScalarType], ...] AttributeType = ScalarType | StructType """The allowed type declarations for epymorph attributes.""" @@ -38,16 +43,16 @@ def dtype_as_np(dtype: AttributeType) -> np.dtype: return np.dtype(np.str_) if dtype == date: return np.dtype(np.datetime64) - if isinstance(dtype, Sequence): - dtype = list(dtype) - if len(dtype) == 0: + if isinstance(dtype, tuple): + fields = list(dtype) + if len(fields) == 0: raise ValueError(f"Unsupported dtype: {dtype}") try: return np.dtype([ (field_name, dtype_as_np(field_dtype)) - for field_name, field_dtype in dtype + for field_name, field_dtype in fields ]) - except TypeError: + except (TypeError, ValueError): raise ValueError(f"Unsupported dtype: {dtype}") from None raise ValueError(f"Unsupported dtype: {dtype}") @@ -62,17 +67,17 @@ def dtype_str(dtype: AttributeType) -> str: return "str" if dtype == date: return "date" - if isinstance(dtype, Sequence): - dtype = list(dtype) - if len(dtype) == 0: + if isinstance(dtype, tuple): + fields = list(dtype) + if len(fields) == 0: raise ValueError(f"Unsupported dtype: {dtype}") try: values = [ f"({field_name}, {dtype_str(field_dtype)})" - for field_name, field_dtype in dtype + for field_name, field_dtype in fields ] return f"[{', '.join(values)}]" - except TypeError: + except (TypeError, ValueError): raise ValueError(f"Unsupported dtype: {dtype}") from None raise ValueError(f"Unsupported dtype: {dtype}") @@ -81,22 +86,22 @@ def dtype_check(dtype: AttributeType, value: Any) -> bool: """Checks that a value conforms to a given dtype. (Python types only.)""" if dtype in (int, float, str, date): return isinstance(value, dtype) - if isinstance(dtype, Sequence): - dtype = list(dtype) + if isinstance(dtype, tuple): + fields = list(dtype) if not isinstance(value, tuple): return False - if len(value) != len(dtype): + if len(value) != len(fields): return False return all(( dtype_check(field_dtype, field_value) - for ((_, field_dtype), field_value) in zip(dtype, value) + for ((_, field_dtype), field_value) in zip(fields, value) )) raise ValueError(f"Unsupported dtype: {dtype}") -CentroidType: AttributeType = [('longitude', float), ('latitude', float)] +CentroidType: AttributeType = (('longitude', float), ('latitude', float)) """Structured epymorph type declaration for long/lat coordinates.""" -CentroidDType: DTypeLike = [('longitude', np.float64), ('latitude', np.float64)] +CentroidDType: DTypeLike = dtype_as_np(CentroidType) """The numpy equivalent of `CentroidType` (structured dtype for long/lat coordinates).""" # SimDType being centrally-located means we can change it reliably. diff --git a/epymorph/test/data_type_test.py b/epymorph/test/data_type_test.py index 85b97ff3..6a1d0033 100644 --- a/epymorph/test/data_type_test.py +++ b/epymorph/test/data_type_test.py @@ -15,7 +15,7 @@ def test_dtype_as_np(self): self.assertEqual(dtype_as_np(str), np.str_) self.assertEqual(dtype_as_np(date), np.datetime64) - struct = [('foo', float), ('bar', int), ('baz', str), ('bux', date)] + struct = (('foo', float), ('bar', int), ('baz', str), ('bux', date)) self.assertEqual( dtype_as_np(struct), [('foo', np.float64), ('bar', np.int64), @@ -28,7 +28,7 @@ def test_dtype_str(self): self.assertEqual(dtype_str(str), "str") self.assertEqual(dtype_str(date), "date") - struct = [('foo', float), ('bar', int), ('baz', str), ('bux', date)] + struct = (('foo', float), ('bar', int), ('baz', str), ('bux', date)) self.assertEqual( dtype_str(struct), "[(foo, float), (bar, int), (baz, str), (bux, date)]" @@ -53,8 +53,8 @@ def test_dtype_check(self): self.assertTrue(dtype_check(str, "")) self.assertTrue(dtype_check(date, date(2024, 1, 1))) self.assertTrue(dtype_check(date, date(1066, 10, 14))) - self.assertTrue(dtype_check([('x', int), ('y', int)], (1, 2))) - self.assertTrue(dtype_check([('a', str), ('b', float)], ("hi", 9273.3))) + self.assertTrue(dtype_check((('x', int), ('y', int)), (1, 2))) + self.assertTrue(dtype_check((('a', str), ('b', float)), ("hi", 9273.3))) self.assertFalse(dtype_check(int, "hi")) self.assertFalse(dtype_check(int, 42.42)) @@ -68,18 +68,10 @@ def test_dtype_check(self): self.assertFalse(dtype_check(date, '2024-01-01')) self.assertFalse(dtype_check(date, 123)) - dt1 = [('x', int), ('y', int)] + dt1 = (('x', int), ('y', int)) self.assertFalse(dtype_check(dt1, 1)) self.assertFalse(dtype_check(dt1, 78923.1)) self.assertFalse(dtype_check(dt1, "hi")) self.assertFalse(dtype_check(dt1, ())) self.assertFalse(dtype_check(dt1, (1, 237.8))) self.assertFalse(dtype_check(dt1, (1, 2, 3))) - - dt2 = (('x', int), ('y', int)) - self.assertFalse(dtype_check(dt2, 1)) - self.assertFalse(dtype_check(dt2, 78923.1)) - self.assertFalse(dtype_check(dt2, "hi")) - self.assertFalse(dtype_check(dt2, ())) - self.assertFalse(dtype_check(dt2, (1, 237.8))) - self.assertFalse(dtype_check(dt2, (1, 2, 3))) diff --git a/epymorph/test/rume_test.py b/epymorph/test/rume_test.py index f15f773a..c8fcac1c 100644 --- a/epymorph/test/rume_test.py +++ b/epymorph/test/rume_test.py @@ -285,7 +285,7 @@ def test_create_multistrata_2(self): comment="The total population at each node."), AbsoluteName("gpm:aaa", "mm", "centroid"): - AttributeDef("centroid", [('longitude', float), ('latitude', float)], Shapes.N, + AttributeDef("centroid", (('longitude', float), ('latitude', float)), Shapes.N, comment="The centroids for each node as (longitude, latitude) tuples."), AbsoluteName("gpm:aaa", "mm", "phi"):