General changes/fixes.
Tyler Coles committed Jul 23, 2024
1 parent 75de9e2 commit df9165b
Showing 6 changed files with 140 additions and 26 deletions.
39 changes: 38 additions & 1 deletion epymorph/data_shape.py
@@ -229,22 +229,55 @@ def __str__(self):
return "TxN"


class NodeAndArbitrary(DataShape):
"""An array of size exactly-N by any dimension."""

def matches(self, dim: SimDimensions, value: NDArray, allow_broadcast: bool) -> bool:
if value.ndim == 2 and value.shape[0] == dim.nodes:
return True
return False

def adapt(self, dim: SimDimensions, value: NDArray[T], allow_broadcast: bool) -> NDArray[T] | None:
return value if self.matches(dim, value, allow_broadcast) else None

def __str__(self):
return "NxA"


class ArbitraryAndNode(DataShape):
"""An array of size any dimension by exactly-N."""

def matches(self, dim: SimDimensions, value: NDArray, allow_broadcast: bool) -> bool:
if value.ndim == 2 and value.shape[1] == dim.nodes:
return True
return False

def adapt(self, dim: SimDimensions, value: NDArray[T], allow_broadcast: bool) -> NDArray[T] | None:
return value if self.matches(dim, value, allow_broadcast) else None

def __str__(self):
return "AxN"


@dataclass(frozen=True)
class Shapes:
"""Static instances for all available shapes."""

# Data can be in any of these shapes, where:
# - S is a single scalar value
# - T is the number of ticks
# - T is the number of days
# - N is the number of nodes
# - C is the number of IPM compartments
# - A is any length (arbitrary; this dimension is effectively unchecked)

S = Scalar()
T = Time()
N = Node()
NxC = NodeAndCompartment()
NxN = NodeAndNode()
TxN = TimeAndNode()
NxA = NodeAndArbitrary()
AxN = ArbitraryAndNode()


def parse_shape(shape: str) -> DataShape:
@@ -262,6 +295,10 @@ def parse_shape(shape: str) -> DataShape:
return Shapes.NxN
case "TxN":
return Shapes.TxN
case "NxA":
return Shapes.NxA
case "AxN":
return Shapes.AxN
case _:
raise ValueError(f"'{shape}' is not a valid shape specification.")

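For reference, a minimal usage sketch of the new shapes. The SimpleNamespace stand-in for SimDimensions is an assumption for illustration; these particular matches() checks only read dim.nodes, so it is enough here.

import numpy as np
from types import SimpleNamespace
from epymorph.data_shape import Shapes, parse_shape

dim = SimpleNamespace(nodes=6)  # stand-in: only dim.nodes is consulted by these checks

assert parse_shape("NxA") == Shapes.NxA
assert parse_shape("AxN") == Shapes.AxN

# NxA: the first axis must be exactly N; the second ("A") is unchecked.
assert Shapes.NxA.matches(dim, np.zeros((6, 17)), False)
assert not Shapes.NxA.matches(dim, np.zeros((17, 6)), False)

# AxN is the mirror image: the second axis must be exactly N.
assert Shapes.AxN.matches(dim, np.zeros((17, 6)), False)
assert not Shapes.AxN.matches(dim, np.zeros((6, 17)), False)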
4 changes: 3 additions & 1 deletion epymorph/geography/scope.py
@@ -28,7 +28,9 @@ class CustomScope(GeoScope):

_nodes: NDArray[np.str_]

def __init__(self, nodes: NDArray[np.str_]):
def __init__(self, nodes: NDArray[np.str_] | list[str]):
if isinstance(nodes, list):
nodes = np.array(nodes, dtype=np.str_)
self._nodes = nodes

def get_node_ids(self) -> NDArray[np.str_]:
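A quick sketch of the relaxed constructor: a plain list of node IDs is now accepted and converted to a numpy string array, so the two forms below should be equivalent (illustrative only).

import numpy as np
from epymorph.geography.scope import CustomScope

scope_a = CustomScope(["AZ", "NM", "UT"])  # list[str] is now accepted
scope_b = CustomScope(np.array(["AZ", "NM", "UT"], dtype=np.str_))
assert (scope_a.get_node_ids() == scope_b.get_node_ids()).all()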
4 changes: 2 additions & 2 deletions epymorph/geography/us_census.py
@@ -405,14 +405,14 @@ def verify_fips(granularity: CensusGranularityName, year: int, fips: Sequence[st
def state_code_to_fips(year: int) -> Mapping[str, str]:
"""Mapping from state postal code to FIPS code."""
states = get_us_states(year)
return dict(zip(states.code, states.geoid))
return dict(zip(states.code.tolist(), states.geoid.tolist()))


@cache
def state_fips_to_code(year: int) -> Mapping[str, str]:
"""Mapping from state FIPS code to postal code."""
states = get_us_states(year)
return dict(zip(states.geoid, states.code))
return dict(zip(states.geoid.tolist(), states.code.tolist()))


def validate_state_codes_as_fips(year: int, codes: Sequence[str]) -> Sequence[str]:
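The .tolist() calls above make the cached mappings hold plain Python str keys and values rather than numpy scalar strings. A hedged usage sketch, assuming the 2020 vintage is available:

from epymorph.geography.us_census import state_code_to_fips, state_fips_to_code

code_to_fips = state_code_to_fips(2020)
assert type(code_to_fips["AZ"]) is str         # plain str, not np.str_
assert state_fips_to_code(2020)["04"] == "AZ"  # Arizona's FIPS code is 04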
46 changes: 26 additions & 20 deletions epymorph/geography/us_tiger.py
@@ -40,6 +40,8 @@
# Below there are some commented-code remnants which demonstrate what it takes to support the additional
# territories, in case we ever want to reverse this choice.

# NOTE: TIGER files express areas in meters-squared.

TigerYear = Literal[2000, 2009, 2010, 2011, 2012, 2013, 2014,
2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023]
"""A supported TIGER file year."""
@@ -156,30 +158,33 @@ def _get_states_config(year: TigerYear) -> tuple[list[str], list[str], list[str]
"""Produce the args for _get_info or _get_geo (states)."""
match year:
case year if year in range(2011, 2024):
cols = ["GEOID", "NAME", "STUSPS"]
cols = ["GEOID", "NAME", "STUSPS", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/STATE/tl_{year}_us_state.zip"
]
case 2010:
cols = ["GEOID10", "NAME10", "STUSPS10"]
cols = ["GEOID10", "NAME10", "STUSPS10",
"ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/STATE/2010/tl_2010_{xx}_state10.zip"
for xx in _SUPPORTED_STATE_FILES
]
case 2009:
cols = ["STATEFP00", "NAME00", "STUSPS00"]
cols = ["STATEFP00", "NAME00", "STUSPS00",
"ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/tl_2009_us_state00.zip"
]
case 2000:
cols = ["STATEFP00", "NAME00", "STUSPS00"]
cols = ["STATEFP00", "NAME00", "STUSPS00",
"ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/STATE/2000/tl_2010_{xx}_state00.zip"
for xx in _SUPPORTED_STATE_FILES
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID", "NAME", "STUSPS"]
return cols, urls, ["GEOID", "NAME", "STUSPS", "ALAND", "INTPTLAT", "INTPTLON"]


def get_states_geo(year: TigerYear) -> GeoDataFrame:
@@ -201,30 +206,31 @@ def _get_counties_config(year: TigerYear) -> tuple[list[str], list[str], list[st
"""Produce the args for _get_info or _get_geo (counties)."""
match year:
case year if year in range(2011, 2024):
cols = ["GEOID", "NAME"]
cols = ["GEOID", "NAME", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/COUNTY/tl_{year}_us_county.zip"
]
case 2010:
cols = ["GEOID10", "NAME10"]
cols = ["GEOID10", "NAME10", "ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/COUNTY/2010/tl_2010_{xx}_county10.zip"
for xx in _SUPPORTED_STATE_FILES
]
case 2009:
cols = ["CNTYIDFP00", "NAME00"]
cols = ["CNTYIDFP00", "NAME00", "ALAND00",
"AWATER00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/tl_2009_us_county00.zip"
]
case 2000:
cols = ["CNTYIDFP00", "NAME00"]
cols = ["CNTYIDFP00", "NAME00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/COUNTY/2000/tl_2010_{xx}_county00.zip"
for xx in _SUPPORTED_STATE_FILES
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID", "NAME"]
return cols, urls, ["GEOID", "NAME", "ALAND", "INTPTLAT", "INTPTLON"]


def get_counties_geo(year: TigerYear) -> GeoDataFrame:
@@ -247,13 +253,13 @@ def _get_tracts_config(year: TigerYear) -> tuple[list[str], list[str], list[str]
states = get_states_info(year)
match year:
case year if year in range(2011, 2024):
cols = ["GEOID"]
cols = ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/TRACT/tl_{year}_{xx}_tract.zip"
for xx in states["GEOID"]
]
case 2010:
cols = ["GEOID10"]
cols = ["GEOID10", "ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/TRACT/2010/tl_2010_{xx}_tract10.zip"
for xx in states["GEOID"]
@@ -262,20 +268,20 @@ def _get_tracts_config(year: TigerYear) -> tuple[list[str], list[str], list[str]
def state_folder(fips, name):
return f"{fips}_{name.upper().replace(' ', '_')}"

cols = ["CTIDFP00"]
cols = ["CTIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/{state_folder(xx, name)}/tl_2009_{xx}_tract00.zip"
for xx, name in zip(states["GEOID"], states["NAME"])
]
case 2000:
cols = ["CTIDFP00"]
cols = ["CTIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/TRACT/2000/tl_2010_{xx}_tract00.zip"
for xx in states["GEOID"]
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID"]
return cols, urls, ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]


def get_tracts_geo(year: TigerYear) -> GeoDataFrame:
@@ -298,13 +304,13 @@ def _get_block_groups_config(year: TigerYear) -> tuple[list[str], list[str], lis
states = get_states_info(year)
match year:
case year if year in range(2011, 2024):
cols = ["GEOID"]
cols = ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/BG/tl_{year}_{xx}_bg.zip"
for xx in states["GEOID"]
]
case 2010:
cols = ["GEOID10"]
cols = ["GEOID10", "ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/BG/2010/tl_2010_{xx}_bg10.zip"
for xx in states["GEOID"]
@@ -313,20 +319,20 @@ def _get_block_groups_config(year: TigerYear) -> tuple[list[str], list[str], lis
def state_folder(fips, name):
return f"{fips}_{name.upper().replace(' ', '_')}"

cols = ["BKGPIDFP00"]
cols = ["BKGPIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/{state_folder(xx, name)}/tl_2009_{xx}_bg00.zip"
for xx, name in zip(states["GEOID"], states["NAME"])
]
case 2000:
cols = ["BKGPIDFP00"]
cols = ["BKGPIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/BG/2000/tl_2010_{xx}_bg00.zip"
for xx in states["GEOID"]
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID"]
return cols, urls, ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]


def get_block_groups_geo(year: TigerYear) -> GeoDataFrame:
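A hedged sketch of consuming the newly exposed columns. It assumes get_states_info(year), used above, returns a DataFrame whose columns have been renamed to GEOID, NAME, STUSPS, ALAND, INTPTLAT, INTPTLON; per the NOTE at the top of the file, ALAND is in square meters.

from epymorph.geography.us_tiger import get_states_info

states = get_states_info(2020)
# ALAND arrives in square meters; convert to square kilometers.
states["ALAND_KM2"] = states["ALAND"].astype(float) / 1e6
# Internal-point coordinates, usable as simple per-state centroids.
centroids = states[["GEOID", "INTPTLAT", "INTPTLON"]]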
7 changes: 5 additions & 2 deletions epymorph/simulation.py
@@ -329,6 +329,9 @@ def resolve_name(self, attr_name: str) -> NDArray:
T_co = TypeVar('T_co', bound=np.generic, covariant=True)
"""The result type of a SimulationFunction."""

_DeferredT = TypeVar('_DeferredT', bound=np.generic)
"""The result type of a SimulationFunction during deference."""


class _Context:
def data(self, attribute: AttributeKey) -> NDArray:
@@ -387,7 +390,7 @@ def rng(self) -> np.random.Generator:
"""The simulation's random number generator."""
return self._rng

def defer(self, other: 'SimulationFunction[T_co]') -> NDArray[T_co]:
def defer(self, other: 'SimulationFunction[_DeferredT]') -> NDArray[_DeferredT]:
"""Defer processing to another similarly-typed instance of a SimulationFunction."""
return other(self._data, self._dim, self._rng)

@@ -450,6 +453,6 @@ def rng(self) -> np.random.Generator:
return self._ctx.rng

@final
def defer(self, other: 'SimulationFunction[T_co]') -> NDArray[T_co]:
def defer(self, other: 'SimulationFunction[_DeferredT]') -> NDArray[_DeferredT]:
"""Defer processing to another similarly-typed instance of a SimulationFunction."""
return self._ctx.defer(other)
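A standalone typing sketch (not epymorph code) of why defer() switches from the class-level covariant T_co to its own _DeferredT: with the old signature, a type checker ties the deferred function's result type to the caller's own T_co, so deferring to a function with a different result dtype would not type-check.

from typing import Generic, TypeVar
import numpy as np
from numpy.typing import NDArray

T_co = TypeVar("T_co", bound=np.generic, covariant=True)
_DeferredT = TypeVar("_DeferredT", bound=np.generic)

class Func(Generic[T_co]):
    def defer(self, other: "Func[_DeferredT]") -> NDArray[_DeferredT]:
        # The result type follows the deferred-to function, not self,
        # so a Func[np.int64] may defer to a Func[np.float64].
        ...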
66 changes: 66 additions & 0 deletions epymorph/test/data_shape_test.py
@@ -107,6 +107,34 @@ def test_time_and_node(self):
ttt(np.arange(90), True)
fff(np.arange(90), False)

def test_node_and_arbitrary(self):
ttt, fff = self.as_bool_asserts(
lambda x, bc=False: Shapes.NxA.matches(_dim, x, bc)
)

ttt(np.arange(6 * 111).reshape((6, 111)))
ttt(np.arange(6 * 222).reshape((6, 222)))

fff(np.arange(6 * 111).reshape((111, 6)))
fff(np.arange(6 * 222).reshape((222, 6)))

fff(np.arange(4 * 111).reshape((4, 111)))
fff(np.arange(4 * 222).reshape((4, 222)))

def test_arbitrary_and_node(self):
ttt, fff = self.as_bool_asserts(
lambda x, bc=False: Shapes.AxN.matches(_dim, x, bc)
)

ttt(np.arange(6 * 111).reshape((111, 6)))
ttt(np.arange(6 * 222).reshape((222, 6)))

fff(np.arange(6 * 111).reshape((6, 111)))
fff(np.arange(6 * 222).reshape((6, 222)))

fff(np.arange(4 * 111).reshape((111, 4)))
fff(np.arange(4 * 222).reshape((222, 4)))

def adapt_test_framework(self, shape, cases):
for i, (input_value, broadcast, expected) in enumerate(cases):
error = f"Failure in test case {i}: ({shape}, {input_value}, {broadcast}, {expected})"
@@ -241,6 +269,42 @@ def test_adapt_time_and_node(self):
(np.arange(27).reshape((3, 3, 3)), False, None),
])

def test_adapt_node_and_arbitrary(self):
arr1 = np.arange(6 * 111).reshape((6, 111))
arr2 = np.arange(6 * 222).reshape((6, 222))
arr3 = np.arange(6 * 333).reshape((6, 333))

arr4 = np.arange(5 * 111).reshape((5, 111))
arr5 = np.arange(111)
arr6 = np.arange(6)

self.adapt_test_framework(Shapes.NxA, [
(arr1, True, arr1),
(arr2, True, arr2),
(arr3, True, arr3),
(arr4, True, None),
(arr5, True, None),
(arr6, True, None),
])

def test_adapt_arbitrary_and_node(self):
arr1 = np.arange(6 * 111).reshape((111, 6))
arr2 = np.arange(6 * 222).reshape((222, 6))
arr3 = np.arange(6 * 333).reshape((333, 6))

arr4 = np.arange(5 * 111).reshape((111, 5))
arr5 = np.arange(111)
arr6 = np.arange(6)

self.adapt_test_framework(Shapes.AxN, [
(arr1, True, arr1),
(arr2, True, arr2),
(arr3, True, arr3),
(arr4, True, None),
(arr5, True, None),
(arr6, True, None),
])


class TestParseShape(unittest.TestCase):
def test_successful(self):
@@ -250,6 +314,8 @@ def test_successful(self):
eq(parse_shape('N'), Shapes.N)
eq(parse_shape('NxN'), Shapes.NxN)
eq(parse_shape('TxN'), Shapes.TxN)
eq(parse_shape('AxN'), Shapes.AxN)
eq(parse_shape('NxA'), Shapes.NxA)

def test_failure(self):
def test(s):
