Skip to content

Commit

Permalink
General changes/fixes. (#135)
Browse files Browse the repository at this point in the history
Add NxA and AxN shapes.
TIGER info includes area land and internal point coordinates.
Census postal code to fips (and vice versa) maps contain Python strings instead of np.str_ instances.
Create CustomScope with list in addition to nparray.
SimFunction typing improvement.
  • Loading branch information
JavadocMD authored Jul 23, 2024
1 parent 75de9e2 commit a10b9ca
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 26 deletions.
39 changes: 38 additions & 1 deletion epymorph/data_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,55 @@ def __str__(self):
return "TxN"


class NodeAndArbitrary(DataShape):
"""An array of size exactly-N by any dimension."""

def matches(self, dim: SimDimensions, value: NDArray, allow_broadcast: bool) -> bool:
if value.ndim == 2 and value.shape[0] == dim.nodes:
return True
return False

def adapt(self, dim: SimDimensions, value: NDArray[T], allow_broadcast: bool) -> NDArray[T] | None:
return value if self.matches(dim, value, allow_broadcast) else None

def __str__(self):
return "NxA"


class ArbitraryAndNode(DataShape):
"""An array of size any dimension by exactly-N."""

def matches(self, dim: SimDimensions, value: NDArray, allow_broadcast: bool) -> bool:
if value.ndim == 2 and value.shape[1] == dim.nodes:
return True
return False

def adapt(self, dim: SimDimensions, value: NDArray[T], allow_broadcast: bool) -> NDArray[T] | None:
return value if self.matches(dim, value, allow_broadcast) else None

def __str__(self):
return "AxN"


@dataclass(frozen=True)
class Shapes:
"""Static instances for all available shapes."""

# Data can be in any of these shapes, where:
# - S is a single scalar value
# - T is the number of ticks
# - T is the number of days
# - N is the number of nodes
# - C is the number of IPM compartments
# - A is any length (arbitrary; this dimension is effectively unchecked)

S = Scalar()
T = Time()
N = Node()
NxC = NodeAndCompartment()
NxN = NodeAndNode()
TxN = TimeAndNode()
NxA = NodeAndArbitrary()
AxN = ArbitraryAndNode()


def parse_shape(shape: str) -> DataShape:
Expand All @@ -262,6 +295,10 @@ def parse_shape(shape: str) -> DataShape:
return Shapes.NxN
case "TxN":
return Shapes.TxN
case "NxA":
return Shapes.NxA
case "AxN":
return Shapes.AxN
case _:
raise ValueError(f"'{shape}' is not a valid shape specification.")

Expand Down
4 changes: 3 additions & 1 deletion epymorph/geography/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class CustomScope(GeoScope):

_nodes: NDArray[np.str_]

def __init__(self, nodes: NDArray[np.str_]):
def __init__(self, nodes: NDArray[np.str_] | list[str]):
if isinstance(nodes, list):
nodes = np.array(nodes, dtype=np.str_)
self._nodes = nodes

def get_node_ids(self) -> NDArray[np.str_]:
Expand Down
4 changes: 2 additions & 2 deletions epymorph/geography/us_census.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,14 @@ def verify_fips(granularity: CensusGranularityName, year: int, fips: Sequence[st
def state_code_to_fips(year: int) -> Mapping[str, str]:
"""Mapping from state postal code to FIPS code."""
states = get_us_states(year)
return dict(zip(states.code, states.geoid))
return dict(zip(states.code.tolist(), states.geoid.tolist()))


@cache
def state_fips_to_code(year: int) -> Mapping[str, str]:
"""Mapping from state FIPS code to postal code."""
states = get_us_states(year)
return dict(zip(states.geoid, states.code))
return dict(zip(states.geoid.tolist(), states.code.tolist()))


def validate_state_codes_as_fips(year: int, codes: Sequence[str]) -> Sequence[str]:
Expand Down
46 changes: 26 additions & 20 deletions epymorph/geography/us_tiger.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
# Below there are some commented-code remnants which demonstrate what it takes to support the additional
# territories, in case we ever want to reverse this choice.

# NOTE: TIGER files express areas in meters-squared.

TigerYear = Literal[2000, 2009, 2010, 2011, 2012, 2013, 2014,
2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023]
"""A supported TIGER file year."""
Expand Down Expand Up @@ -156,30 +158,33 @@ def _get_states_config(year: TigerYear) -> tuple[list[str], list[str], list[str]
"""Produce the args for _get_info or _get_geo (states)."""
match year:
case year if year in range(2011, 2024):
cols = ["GEOID", "NAME", "STUSPS"]
cols = ["GEOID", "NAME", "STUSPS", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/STATE/tl_{year}_us_state.zip"
]
case 2010:
cols = ["GEOID10", "NAME10", "STUSPS10"]
cols = ["GEOID10", "NAME10", "STUSPS10",
"ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/STATE/2010/tl_2010_{xx}_state10.zip"
for xx in _SUPPORTED_STATE_FILES
]
case 2009:
cols = ["STATEFP00", "NAME00", "STUSPS00"]
cols = ["STATEFP00", "NAME00", "STUSPS00",
"ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/tl_2009_us_state00.zip"
]
case 2000:
cols = ["STATEFP00", "NAME00", "STUSPS00"]
cols = ["STATEFP00", "NAME00", "STUSPS00",
"ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/STATE/2000/tl_2010_{xx}_state00.zip"
for xx in _SUPPORTED_STATE_FILES
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID", "NAME", "STUSPS"]
return cols, urls, ["GEOID", "NAME", "STUSPS", "ALAND", "INTPTLAT", "INTPTLON"]


def get_states_geo(year: TigerYear) -> GeoDataFrame:
Expand All @@ -201,30 +206,31 @@ def _get_counties_config(year: TigerYear) -> tuple[list[str], list[str], list[st
"""Produce the args for _get_info or _get_geo (counties)."""
match year:
case year if year in range(2011, 2024):
cols = ["GEOID", "NAME"]
cols = ["GEOID", "NAME", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/COUNTY/tl_{year}_us_county.zip"
]
case 2010:
cols = ["GEOID10", "NAME10"]
cols = ["GEOID10", "NAME10", "ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/COUNTY/2010/tl_2010_{xx}_county10.zip"
for xx in _SUPPORTED_STATE_FILES
]
case 2009:
cols = ["CNTYIDFP00", "NAME00"]
cols = ["CNTYIDFP00", "NAME00", "ALAND00",
"AWATER00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/tl_2009_us_county00.zip"
]
case 2000:
cols = ["CNTYIDFP00", "NAME00"]
cols = ["CNTYIDFP00", "NAME00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/COUNTY/2000/tl_2010_{xx}_county00.zip"
for xx in _SUPPORTED_STATE_FILES
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID", "NAME"]
return cols, urls, ["GEOID", "NAME", "ALAND", "INTPTLAT", "INTPTLON"]


def get_counties_geo(year: TigerYear) -> GeoDataFrame:
Expand All @@ -247,13 +253,13 @@ def _get_tracts_config(year: TigerYear) -> tuple[list[str], list[str], list[str]
states = get_states_info(year)
match year:
case year if year in range(2011, 2024):
cols = ["GEOID"]
cols = ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/TRACT/tl_{year}_{xx}_tract.zip"
for xx in states["GEOID"]
]
case 2010:
cols = ["GEOID10"]
cols = ["GEOID10", "ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/TRACT/2010/tl_2010_{xx}_tract10.zip"
for xx in states["GEOID"]
Expand All @@ -262,20 +268,20 @@ def _get_tracts_config(year: TigerYear) -> tuple[list[str], list[str], list[str]
def state_folder(fips, name):
return f"{fips}_{name.upper().replace(' ', '_')}"

cols = ["CTIDFP00"]
cols = ["CTIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/{state_folder(xx, name)}/tl_2009_{xx}_tract00.zip"
for xx, name in zip(states["GEOID"], states["NAME"])
]
case 2000:
cols = ["CTIDFP00"]
cols = ["CTIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/TRACT/2000/tl_2010_{xx}_tract00.zip"
for xx in states["GEOID"]
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID"]
return cols, urls, ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]


def get_tracts_geo(year: TigerYear) -> GeoDataFrame:
Expand All @@ -298,13 +304,13 @@ def _get_block_groups_config(year: TigerYear) -> tuple[list[str], list[str], lis
states = get_states_info(year)
match year:
case year if year in range(2011, 2024):
cols = ["GEOID"]
cols = ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]
urls = [
f"{_TIGER_URL}/TIGER{year}/BG/tl_{year}_{xx}_bg.zip"
for xx in states["GEOID"]
]
case 2010:
cols = ["GEOID10"]
cols = ["GEOID10", "ALAND10", "INTPTLAT10", "INTPTLON10"]
urls = [
f"{_TIGER_URL}/TIGER2010/BG/2010/tl_2010_{xx}_bg10.zip"
for xx in states["GEOID"]
Expand All @@ -313,20 +319,20 @@ def _get_block_groups_config(year: TigerYear) -> tuple[list[str], list[str], lis
def state_folder(fips, name):
return f"{fips}_{name.upper().replace(' ', '_')}"

cols = ["BKGPIDFP00"]
cols = ["BKGPIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2009/{state_folder(xx, name)}/tl_2009_{xx}_bg00.zip"
for xx, name in zip(states["GEOID"], states["NAME"])
]
case 2000:
cols = ["BKGPIDFP00"]
cols = ["BKGPIDFP00", "ALAND00", "INTPTLAT00", "INTPTLON00"]
urls = [
f"{_TIGER_URL}/TIGER2010/BG/2000/tl_2010_{xx}_bg00.zip"
for xx in states["GEOID"]
]
case _:
raise GeographyError(f"Unsupported year: {year}")
return cols, urls, ["GEOID"]
return cols, urls, ["GEOID", "ALAND", "INTPTLAT", "INTPTLON"]


def get_block_groups_geo(year: TigerYear) -> GeoDataFrame:
Expand Down
7 changes: 5 additions & 2 deletions epymorph/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def resolve_name(self, attr_name: str) -> NDArray:
T_co = TypeVar('T_co', bound=np.generic, covariant=True)
"""The result type of a SimulationFunction."""

_DeferredT = TypeVar('_DeferredT', bound=np.generic)
"""The result type of a SimulationFunction during deference."""


class _Context:
def data(self, attribute: AttributeKey) -> NDArray:
Expand Down Expand Up @@ -387,7 +390,7 @@ def rng(self) -> np.random.Generator:
"""The simulation's random number generator."""
return self._rng

def defer(self, other: 'SimulationFunction[T_co]') -> NDArray[T_co]:
def defer(self, other: 'SimulationFunction[_DeferredT]') -> NDArray[_DeferredT]:
"""Defer processing to another similarly-typed instance of a SimulationFunction."""
return other(self._data, self._dim, self._rng)

Expand Down Expand Up @@ -450,6 +453,6 @@ def rng(self) -> np.random.Generator:
return self._ctx.rng

@final
def defer(self, other: 'SimulationFunction[T_co]') -> NDArray[T_co]:
def defer(self, other: 'SimulationFunction[_DeferredT]') -> NDArray[_DeferredT]:
"""Defer processing to another similarly-typed instance of a SimulationFunction."""
return self._ctx.defer(other)
66 changes: 66 additions & 0 deletions epymorph/test/data_shape_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,34 @@ def test_time_and_node(self):
ttt(np.arange(90), True)
fff(np.arange(90), False)

def test_node_and_arbitrary(self):
ttt, fff = self.as_bool_asserts(
lambda x, bc=False: Shapes.NxA.matches(_dim, x, bc)
)

ttt(np.arange(6 * 111).reshape((6, 111)))
ttt(np.arange(6 * 222).reshape((6, 222)))

fff(np.arange(6 * 111).reshape((111, 6)))
fff(np.arange(6 * 222).reshape((222, 6)))

fff(np.arange(4 * 111).reshape((4, 111)))
fff(np.arange(4 * 222).reshape((4, 222)))

def test_arbitrary_and_node(self):
ttt, fff = self.as_bool_asserts(
lambda x, bc=False: Shapes.AxN.matches(_dim, x, bc)
)

ttt(np.arange(6 * 111).reshape((111, 6)))
ttt(np.arange(6 * 222).reshape((222, 6)))

fff(np.arange(6 * 111).reshape((6, 111)))
fff(np.arange(6 * 222).reshape((6, 222)))

fff(np.arange(4 * 111).reshape((111, 4)))
fff(np.arange(4 * 222).reshape((222, 4)))

def adapt_test_framework(self, shape, cases):
for i, (input_value, broadcast, expected) in enumerate(cases):
error = f"Failure in test case {i}: ({shape}, {input_value}, {broadcast}, {expected})"
Expand Down Expand Up @@ -241,6 +269,42 @@ def test_adapt_time_and_node(self):
(np.arange(27).reshape((3, 3, 3)), False, None),
])

def test_adapt_node_and_arbitrary(self):
arr1 = np.arange(6 * 111).reshape((6, 111))
arr2 = np.arange(6 * 222).reshape((6, 222))
arr3 = np.arange(6 * 333).reshape((6, 333))

arr4 = np.arange(5 * 111).reshape((5, 111))
arr5 = np.arange(111)
arr6 = np.arange(6)

self.adapt_test_framework(Shapes.NxA, [
(arr1, True, arr1),
(arr2, True, arr2),
(arr3, True, arr3),
(arr4, True, None),
(arr5, True, None),
(arr6, True, None),
])

def test_adapt_arbitrary_and_node(self):
arr1 = np.arange(6 * 111).reshape((111, 6))
arr2 = np.arange(6 * 222).reshape((222, 6))
arr3 = np.arange(6 * 333).reshape((333, 6))

arr4 = np.arange(5 * 111).reshape((111, 5))
arr5 = np.arange(111)
arr6 = np.arange(6)

self.adapt_test_framework(Shapes.AxN, [
(arr1, True, arr1),
(arr2, True, arr2),
(arr3, True, arr3),
(arr4, True, None),
(arr5, True, None),
(arr6, True, None),
])


class TestParseShape(unittest.TestCase):
def test_successful(self):
Expand All @@ -250,6 +314,8 @@ def test_successful(self):
eq(parse_shape('N'), Shapes.N)
eq(parse_shape('NxN'), Shapes.NxN)
eq(parse_shape('TxN'), Shapes.TxN)
eq(parse_shape('AxN'), Shapes.AxN)
eq(parse_shape('NxA'), Shapes.NxA)

def test_failure(self):
def test(s):
Expand Down

0 comments on commit a10b9ca

Please sign in to comment.