From a10b9ca8e5cdaeb2c79ce8375cd8e33c6d0d6fbb Mon Sep 17 00:00:00 2001 From: Tyler Date: Tue, 23 Jul 2024 09:51:07 -0700 Subject: [PATCH] General changes/fixes. (#135) Add NxA and AxN shapes. TIGER info includes area land and internal point coordinates. Census postal code to fips (and vice versa) maps contain Python strings instead of np.str_ instances. Create CustomScope with list in addition to nparray. SimFunction typing improvement. --- epymorph/data_shape.py | 39 ++++++++++++++++++- epymorph/geography/scope.py | 4 +- epymorph/geography/us_census.py | 4 +- epymorph/geography/us_tiger.py | 46 ++++++++++++---------- epymorph/simulation.py | 7 +++- epymorph/test/data_shape_test.py | 66 ++++++++++++++++++++++++++++++++ 6 files changed, 140 insertions(+), 26 deletions(-) diff --git a/epymorph/data_shape.py b/epymorph/data_shape.py index b8c4126f..4a5d57c9 100644 --- a/epymorph/data_shape.py +++ b/epymorph/data_shape.py @@ -229,15 +229,46 @@ def __str__(self): return "TxN" +class NodeAndArbitrary(DataShape): + """An array of size exactly-N by any dimension.""" + + def matches(self, dim: SimDimensions, value: NDArray, allow_broadcast: bool) -> bool: + if value.ndim == 2 and value.shape[0] == dim.nodes: + return True + return False + + def adapt(self, dim: SimDimensions, value: NDArray[T], allow_broadcast: bool) -> NDArray[T] | None: + return value if self.matches(dim, value, allow_broadcast) else None + + def __str__(self): + return "NxA" + + +class ArbitraryAndNode(DataShape): + """An array of size any dimension by exactly-N.""" + + def matches(self, dim: SimDimensions, value: NDArray, allow_broadcast: bool) -> bool: + if value.ndim == 2 and value.shape[1] == dim.nodes: + return True + return False + + def adapt(self, dim: SimDimensions, value: NDArray[T], allow_broadcast: bool) -> NDArray[T] | None: + return value if self.matches(dim, value, allow_broadcast) else None + + def __str__(self): + return "AxN" + + @dataclass(frozen=True) class Shapes: """Static instances for all available shapes.""" # Data can be in any of these shapes, where: # - S is a single scalar value - # - T is the number of ticks + # - T is the number of days # - N is the number of nodes # - C is the number of IPM compartments + # - A is any length (arbitrary; this dimension is effectively unchecked) S = Scalar() T = Time() @@ -245,6 +276,8 @@ class Shapes: NxC = NodeAndCompartment() NxN = NodeAndNode() TxN = TimeAndNode() + NxA = NodeAndArbitrary() + AxN = ArbitraryAndNode() def parse_shape(shape: str) -> DataShape: @@ -262,6 +295,10 @@ def parse_shape(shape: str) -> DataShape: return Shapes.NxN case "TxN": return Shapes.TxN + case "NxA": + return Shapes.NxA + case "AxN": + return Shapes.AxN case _: raise ValueError(f"'{shape}' is not a valid shape specification.") diff --git a/epymorph/geography/scope.py b/epymorph/geography/scope.py index c3f172c9..00aacb74 100644 --- a/epymorph/geography/scope.py +++ b/epymorph/geography/scope.py @@ -28,7 +28,9 @@ class CustomScope(GeoScope): _nodes: NDArray[np.str_] - def __init__(self, nodes: NDArray[np.str_]): + def __init__(self, nodes: NDArray[np.str_] | list[str]): + if isinstance(nodes, list): + nodes = np.array(nodes, dtype=np.str_) self._nodes = nodes def get_node_ids(self) -> NDArray[np.str_]: diff --git a/epymorph/geography/us_census.py b/epymorph/geography/us_census.py index da177a56..7452ba1c 100644 --- a/epymorph/geography/us_census.py +++ b/epymorph/geography/us_census.py @@ -405,14 +405,14 @@ def verify_fips(granularity: CensusGranularityName, year: int, fips: Sequence[st def state_code_to_fips(year: int) -> Mapping[str, str]: """Mapping from state postal code to FIPS code.""" states = get_us_states(year) - return dict(zip(states.code, states.geoid)) + return dict(zip(states.code.tolist(), states.geoid.tolist())) @cache def state_fips_to_code(year: int) -> Mapping[str, str]: """Mapping from state FIPS code to postal code.""" states = get_us_states(year) - return dict(zip(states.geoid, states.code)) + return dict(zip(states.geoid.tolist(), states.code.tolist())) def validate_state_codes_as_fips(year: int, codes: Sequence[str]) -> Sequence[str]: diff --git a/epymorph/geography/us_tiger.py b/epymorph/geography/us_tiger.py index c6881350..b8503ecf 100644 --- a/epymorph/geography/us_tiger.py +++ b/epymorph/geography/us_tiger.py @@ -40,6 +40,8 @@ # Below there are some commented-code remnants which demonstrate what it takes to support the additional # territories, in case we ever want to reverse this choice. +# NOTE: TIGER files express areas in meters-squared. + TigerYear = Literal[2000, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023] """A supported TIGER file year.""" @@ -156,30 +158,33 @@ def _get_states_config(year: TigerYear) -> tuple[list[str], list[str], list[str] """Produce the args for _get_info or _get_geo (states).""" match year: case year if year in range(2011, 2024): - cols = ["GEOID", "NAME", "STUSPS"] + cols = ["GEOID", "NAME", "STUSPS", "ALAND", "INTPTLAT", "INTPTLON"] urls = [ f"{_TIGER_URL}/TIGER{year}/STATE/tl_{year}_us_state.zip" ] case 2010: - cols = ["GEOID10", "NAME10", "STUSPS10"] + cols = ["GEOID10", "NAME10", "STUSPS10", + "ALAND10", "INTPTLAT10", "INTPTLON10"] urls = [ f"{_TIGER_URL}/TIGER2010/STATE/2010/tl_2010_{xx}_state10.zip" for xx in _SUPPORTED_STATE_FILES ] case 2009: - cols = ["STATEFP00", "NAME00", "STUSPS00"] + cols = ["STATEFP00", "NAME00", "STUSPS00", + "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2009/tl_2009_us_state00.zip" ] case 2000: - cols = ["STATEFP00", "NAME00", "STUSPS00"] + cols = ["STATEFP00", "NAME00", "STUSPS00", + "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2010/STATE/2000/tl_2010_{xx}_state00.zip" for xx in _SUPPORTED_STATE_FILES ] case _: raise GeographyError(f"Unsupported year: {year}") - return cols, urls, ["GEOID", "NAME", "STUSPS"] + return cols, urls, ["GEOID", "NAME", "STUSPS", "ALAND", "INTPTLAT", "INTPTLON"] def get_states_geo(year: TigerYear) -> GeoDataFrame: @@ -201,30 +206,31 @@ def _get_counties_config(year: TigerYear) -> tuple[list[str], list[str], list[st """Produce the args for _get_info or _get_geo (counties).""" match year: case year if year in range(2011, 2024): - cols = ["GEOID", "NAME"] + cols = ["GEOID", "NAME", "ALAND", "INTPTLAT", "INTPTLON"] urls = [ f"{_TIGER_URL}/TIGER{year}/COUNTY/tl_{year}_us_county.zip" ] case 2010: - cols = ["GEOID10", "NAME10"] + cols = ["GEOID10", "NAME10", "ALAND10", "INTPTLAT10", "INTPTLON10"] urls = [ f"{_TIGER_URL}/TIGER2010/COUNTY/2010/tl_2010_{xx}_county10.zip" for xx in _SUPPORTED_STATE_FILES ] case 2009: - cols = ["CNTYIDFP00", "NAME00"] + cols = ["CNTYIDFP00", "NAME00", "ALAND00", + "AWATER00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2009/tl_2009_us_county00.zip" ] case 2000: - cols = ["CNTYIDFP00", "NAME00"] + cols = ["CNTYIDFP00", "NAME00", "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2010/COUNTY/2000/tl_2010_{xx}_county00.zip" for xx in _SUPPORTED_STATE_FILES ] case _: raise GeographyError(f"Unsupported year: {year}") - return cols, urls, ["GEOID", "NAME"] + return cols, urls, ["GEOID", "NAME", "ALAND", "INTPTLAT", "INTPTLON"] def get_counties_geo(year: TigerYear) -> GeoDataFrame: @@ -247,13 +253,13 @@ def _get_tracts_config(year: TigerYear) -> tuple[list[str], list[str], list[str] states = get_states_info(year) match year: case year if year in range(2011, 2024): - cols = ["GEOID"] + cols = ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"] urls = [ f"{_TIGER_URL}/TIGER{year}/TRACT/tl_{year}_{xx}_tract.zip" for xx in states["GEOID"] ] case 2010: - cols = ["GEOID10"] + cols = ["GEOID10", "ALAND10", "INTPTLAT10", "INTPTLON10"] urls = [ f"{_TIGER_URL}/TIGER2010/TRACT/2010/tl_2010_{xx}_tract10.zip" for xx in states["GEOID"] @@ -262,20 +268,20 @@ def _get_tracts_config(year: TigerYear) -> tuple[list[str], list[str], list[str] def state_folder(fips, name): return f"{fips}_{name.upper().replace(' ', '_')}" - cols = ["CTIDFP00"] + cols = ["CTIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2009/{state_folder(xx, name)}/tl_2009_{xx}_tract00.zip" for xx, name in zip(states["GEOID"], states["NAME"]) ] case 2000: - cols = ["CTIDFP00"] + cols = ["CTIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2010/TRACT/2000/tl_2010_{xx}_tract00.zip" for xx in states["GEOID"] ] case _: raise GeographyError(f"Unsupported year: {year}") - return cols, urls, ["GEOID"] + return cols, urls, ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"] def get_tracts_geo(year: TigerYear) -> GeoDataFrame: @@ -298,13 +304,13 @@ def _get_block_groups_config(year: TigerYear) -> tuple[list[str], list[str], lis states = get_states_info(year) match year: case year if year in range(2011, 2024): - cols = ["GEOID"] + cols = ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"] urls = [ f"{_TIGER_URL}/TIGER{year}/BG/tl_{year}_{xx}_bg.zip" for xx in states["GEOID"] ] case 2010: - cols = ["GEOID10"] + cols = ["GEOID10", "ALAND10", "INTPTLAT10", "INTPTLON10"] urls = [ f"{_TIGER_URL}/TIGER2010/BG/2010/tl_2010_{xx}_bg10.zip" for xx in states["GEOID"] @@ -313,20 +319,20 @@ def _get_block_groups_config(year: TigerYear) -> tuple[list[str], list[str], lis def state_folder(fips, name): return f"{fips}_{name.upper().replace(' ', '_')}" - cols = ["BKGPIDFP00"] + cols = ["BKGPIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2009/{state_folder(xx, name)}/tl_2009_{xx}_bg00.zip" for xx, name in zip(states["GEOID"], states["NAME"]) ] case 2000: - cols = ["BKGPIDFP00"] + cols = ["BKGPIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"] urls = [ f"{_TIGER_URL}/TIGER2010/BG/2000/tl_2010_{xx}_bg00.zip" for xx in states["GEOID"] ] case _: raise GeographyError(f"Unsupported year: {year}") - return cols, urls, ["GEOID"] + return cols, urls, ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"] def get_block_groups_geo(year: TigerYear) -> GeoDataFrame: diff --git a/epymorph/simulation.py b/epymorph/simulation.py index 4f8e70d9..44230143 100644 --- a/epymorph/simulation.py +++ b/epymorph/simulation.py @@ -329,6 +329,9 @@ def resolve_name(self, attr_name: str) -> NDArray: T_co = TypeVar('T_co', bound=np.generic, covariant=True) """The result type of a SimulationFunction.""" +_DeferredT = TypeVar('_DeferredT', bound=np.generic) +"""The result type of a SimulationFunction during deference.""" + class _Context: def data(self, attribute: AttributeKey) -> NDArray: @@ -387,7 +390,7 @@ def rng(self) -> np.random.Generator: """The simulation's random number generator.""" return self._rng - def defer(self, other: 'SimulationFunction[T_co]') -> NDArray[T_co]: + def defer(self, other: 'SimulationFunction[_DeferredT]') -> NDArray[_DeferredT]: """Defer processing to another similarly-typed instance of a SimulationFunction.""" return other(self._data, self._dim, self._rng) @@ -450,6 +453,6 @@ def rng(self) -> np.random.Generator: return self._ctx.rng @final - def defer(self, other: 'SimulationFunction[T_co]') -> NDArray[T_co]: + def defer(self, other: 'SimulationFunction[_DeferredT]') -> NDArray[_DeferredT]: """Defer processing to another similarly-typed instance of a SimulationFunction.""" return self._ctx.defer(other) diff --git a/epymorph/test/data_shape_test.py b/epymorph/test/data_shape_test.py index 26c9575a..5b148b6c 100644 --- a/epymorph/test/data_shape_test.py +++ b/epymorph/test/data_shape_test.py @@ -107,6 +107,34 @@ def test_time_and_node(self): ttt(np.arange(90), True) fff(np.arange(90), False) + def test_node_and_arbitrary(self): + ttt, fff = self.as_bool_asserts( + lambda x, bc=False: Shapes.NxA.matches(_dim, x, bc) + ) + + ttt(np.arange(6 * 111).reshape((6, 111))) + ttt(np.arange(6 * 222).reshape((6, 222))) + + fff(np.arange(6 * 111).reshape((111, 6))) + fff(np.arange(6 * 222).reshape((222, 6))) + + fff(np.arange(4 * 111).reshape((4, 111))) + fff(np.arange(4 * 222).reshape((4, 222))) + + def test_arbitrary_and_node(self): + ttt, fff = self.as_bool_asserts( + lambda x, bc=False: Shapes.AxN.matches(_dim, x, bc) + ) + + ttt(np.arange(6 * 111).reshape((111, 6))) + ttt(np.arange(6 * 222).reshape((222, 6))) + + fff(np.arange(6 * 111).reshape((6, 111))) + fff(np.arange(6 * 222).reshape((6, 222))) + + fff(np.arange(4 * 111).reshape((111, 4))) + fff(np.arange(4 * 222).reshape((222, 4))) + def adapt_test_framework(self, shape, cases): for i, (input_value, broadcast, expected) in enumerate(cases): error = f"Failure in test case {i}: ({shape}, {input_value}, {broadcast}, {expected})" @@ -241,6 +269,42 @@ def test_adapt_time_and_node(self): (np.arange(27).reshape((3, 3, 3)), False, None), ]) + def test_adapt_node_and_arbitrary(self): + arr1 = np.arange(6 * 111).reshape((6, 111)) + arr2 = np.arange(6 * 222).reshape((6, 222)) + arr3 = np.arange(6 * 333).reshape((6, 333)) + + arr4 = np.arange(5 * 111).reshape((5, 111)) + arr5 = np.arange(111) + arr6 = np.arange(6) + + self.adapt_test_framework(Shapes.NxA, [ + (arr1, True, arr1), + (arr2, True, arr2), + (arr3, True, arr3), + (arr4, True, None), + (arr5, True, None), + (arr6, True, None), + ]) + + def test_adapt_arbitrary_and_node(self): + arr1 = np.arange(6 * 111).reshape((111, 6)) + arr2 = np.arange(6 * 222).reshape((222, 6)) + arr3 = np.arange(6 * 333).reshape((333, 6)) + + arr4 = np.arange(5 * 111).reshape((111, 5)) + arr5 = np.arange(111) + arr6 = np.arange(6) + + self.adapt_test_framework(Shapes.AxN, [ + (arr1, True, arr1), + (arr2, True, arr2), + (arr3, True, arr3), + (arr4, True, None), + (arr5, True, None), + (arr6, True, None), + ]) + class TestParseShape(unittest.TestCase): def test_successful(self): @@ -250,6 +314,8 @@ def test_successful(self): eq(parse_shape('N'), Shapes.N) eq(parse_shape('NxN'), Shapes.NxN) eq(parse_shape('TxN'), Shapes.TxN) + eq(parse_shape('AxN'), Shapes.AxN) + eq(parse_shape('NxA'), Shapes.NxA) def test_failure(self): def test(s):