Skip to content

Commit

Permalink
Fix bug with structured attribute types (like centroid).
Browse files Browse the repository at this point in the history
Lists are unhashable, but AttributeDef needs to be hashable. Structured attribute types must now be declared as tuples; they are converted to lists only when needed for numpy.
  • Loading branch information
Tyler Coles committed Aug 16, 2024
1 parent f2afa9f commit e4bdd89
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 32 deletions.
41 changes: 23 additions & 18 deletions epymorph/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
Types for source data and attributes in epymorph.
"""
from datetime import date
from typing import Any, Sequence
from typing import Any

import numpy as np
from numpy.typing import DTypeLike, NDArray

# Types for attribute declarations:
# these are expressed as Python types for simplicity.

# NOTE: In epymorph, we express structured types as tuples-of-tuples;
# this makes them hashable, which is important for AttributeDef.
# However, numpy expresses them as lists-of-tuples, so we have to convert;
# thankfully, we already had infrastructure for this sort of conversion.

ScalarType = type[int | float | str | date]
StructType = Sequence[tuple[str, ScalarType]]
StructType = tuple[tuple[str, ScalarType], ...]
AttributeType = ScalarType | StructType
"""The allowed type declarations for epymorph attributes."""

Expand All @@ -38,16 +43,16 @@ def dtype_as_np(dtype: AttributeType) -> np.dtype:
return np.dtype(np.str_)
if dtype == date:
return np.dtype(np.datetime64)
if isinstance(dtype, Sequence):
dtype = list(dtype)
if len(dtype) == 0:
if isinstance(dtype, tuple):
fields = list(dtype)
if len(fields) == 0:
raise ValueError(f"Unsupported dtype: {dtype}")
try:
return np.dtype([
(field_name, dtype_as_np(field_dtype))
for field_name, field_dtype in dtype
for field_name, field_dtype in fields
])
except TypeError:
except (TypeError, ValueError):
raise ValueError(f"Unsupported dtype: {dtype}") from None
raise ValueError(f"Unsupported dtype: {dtype}")

Expand All @@ -62,17 +67,17 @@ def dtype_str(dtype: AttributeType) -> str:
return "str"
if dtype == date:
return "date"
if isinstance(dtype, Sequence):
dtype = list(dtype)
if len(dtype) == 0:
if isinstance(dtype, tuple):
fields = list(dtype)
if len(fields) == 0:
raise ValueError(f"Unsupported dtype: {dtype}")
try:
values = [
f"({field_name}, {dtype_str(field_dtype)})"
for field_name, field_dtype in dtype
for field_name, field_dtype in fields
]
return f"[{', '.join(values)}]"
except TypeError:
except (TypeError, ValueError):
raise ValueError(f"Unsupported dtype: {dtype}") from None
raise ValueError(f"Unsupported dtype: {dtype}")

Expand All @@ -81,22 +86,22 @@ def dtype_check(dtype: AttributeType, value: Any) -> bool:
"""Checks that a value conforms to a given dtype. (Python types only.)"""
if dtype in (int, float, str, date):
return isinstance(value, dtype)
if isinstance(dtype, Sequence):
dtype = list(dtype)
if isinstance(dtype, tuple):
fields = list(dtype)
if not isinstance(value, tuple):
return False
if len(value) != len(dtype):
if len(value) != len(fields):
return False
return all((
dtype_check(field_dtype, field_value)
for ((_, field_dtype), field_value) in zip(dtype, value)
for ((_, field_dtype), field_value) in zip(fields, value)
))
raise ValueError(f"Unsupported dtype: {dtype}")


CentroidType: AttributeType = [('longitude', float), ('latitude', float)]
CentroidType: AttributeType = (('longitude', float), ('latitude', float))
"""Structured epymorph type declaration for long/lat coordinates."""
CentroidDType: DTypeLike = [('longitude', np.float64), ('latitude', np.float64)]
CentroidDType: DTypeLike = dtype_as_np(CentroidType)
"""The numpy equivalent of `CentroidType` (structured dtype for long/lat coordinates)."""

# SimDType being centrally-located means we can change it reliably.
Expand Down
18 changes: 5 additions & 13 deletions epymorph/test/data_type_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_dtype_as_np(self):
self.assertEqual(dtype_as_np(str), np.str_)
self.assertEqual(dtype_as_np(date), np.datetime64)

struct = [('foo', float), ('bar', int), ('baz', str), ('bux', date)]
struct = (('foo', float), ('bar', int), ('baz', str), ('bux', date))
self.assertEqual(
dtype_as_np(struct),
[('foo', np.float64), ('bar', np.int64),
Expand All @@ -28,7 +28,7 @@ def test_dtype_str(self):
self.assertEqual(dtype_str(str), "str")
self.assertEqual(dtype_str(date), "date")

struct = [('foo', float), ('bar', int), ('baz', str), ('bux', date)]
struct = (('foo', float), ('bar', int), ('baz', str), ('bux', date))
self.assertEqual(
dtype_str(struct),
"[(foo, float), (bar, int), (baz, str), (bux, date)]"
Expand All @@ -53,8 +53,8 @@ def test_dtype_check(self):
self.assertTrue(dtype_check(str, ""))
self.assertTrue(dtype_check(date, date(2024, 1, 1)))
self.assertTrue(dtype_check(date, date(1066, 10, 14)))
self.assertTrue(dtype_check([('x', int), ('y', int)], (1, 2)))
self.assertTrue(dtype_check([('a', str), ('b', float)], ("hi", 9273.3)))
self.assertTrue(dtype_check((('x', int), ('y', int)), (1, 2)))
self.assertTrue(dtype_check((('a', str), ('b', float)), ("hi", 9273.3)))

self.assertFalse(dtype_check(int, "hi"))
self.assertFalse(dtype_check(int, 42.42))
Expand All @@ -68,18 +68,10 @@ def test_dtype_check(self):
self.assertFalse(dtype_check(date, '2024-01-01'))
self.assertFalse(dtype_check(date, 123))

dt1 = [('x', int), ('y', int)]
dt1 = (('x', int), ('y', int))
self.assertFalse(dtype_check(dt1, 1))
self.assertFalse(dtype_check(dt1, 78923.1))
self.assertFalse(dtype_check(dt1, "hi"))
self.assertFalse(dtype_check(dt1, ()))
self.assertFalse(dtype_check(dt1, (1, 237.8)))
self.assertFalse(dtype_check(dt1, (1, 2, 3)))

dt2 = (('x', int), ('y', int))
self.assertFalse(dtype_check(dt2, 1))
self.assertFalse(dtype_check(dt2, 78923.1))
self.assertFalse(dtype_check(dt2, "hi"))
self.assertFalse(dtype_check(dt2, ()))
self.assertFalse(dtype_check(dt2, (1, 237.8)))
self.assertFalse(dtype_check(dt2, (1, 2, 3)))
2 changes: 1 addition & 1 deletion epymorph/test/rume_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def test_create_multistrata_2(self):
comment="The total population at each node."),

AbsoluteName("gpm:aaa", "mm", "centroid"):
AttributeDef("centroid", [('longitude', float), ('latitude', float)], Shapes.N,
AttributeDef("centroid", (('longitude', float), ('latitude', float)), Shapes.N,
comment="The centroids for each node as (longitude, latitude) tuples."),

AbsoluteName("gpm:aaa", "mm", "phi"):
Expand Down

0 comments on commit e4bdd89

Please sign in to comment.