Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactors #116

Merged
merged 3 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions doc/devlog/2023-08-17.ipynb

Large diffs are not rendered by default.

28 changes: 1 addition & 27 deletions epymorph/engine/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
with simulation data, for example, accessing geo and parameter attributes,
calculating the simulation clock, initializing the world state, and so on.
"""
from datetime import timedelta
from functools import cached_property
from typing import Iterable

import numpy as np
from numpy.typing import NDArray

Expand All @@ -15,7 +11,7 @@
from epymorph.geo.geo import Geo
from epymorph.params import NormalizedParams, NormalizedParamsDict
from epymorph.simulation import (CachingGetAttributeMixin, GeoData,
SimDimensions, Tick, TickDelta)
SimDimensions)


class RumeContext(CachingGetAttributeMixin):
Expand Down Expand Up @@ -47,15 +43,6 @@ def __init__(
self._params = params
CachingGetAttributeMixin.__init__(self, geo, params, dim)

def clock(self) -> Iterable[Tick]:
"""Generate the simulation clock signal: a series of Tick objects describing each time step."""
return _simulation_clock(self.dim)

def resolve_tick(self, tick: Tick, delta: TickDelta) -> int:
"""Add a delta to a tick to get the index of the resulting tick."""
return -1 if delta.days == -1 else \
tick.index - tick.step + (self.dim.tau_steps * delta.days) + delta.step

def update_param(self, attr_name: str, value: AttributeArray) -> None:
"""Updates a params value."""
self._params[attr_name] = value.copy()
Expand All @@ -79,16 +66,3 @@ def compartment_mobility(self) -> NDArray[np.bool_]:
def params(self) -> NormalizedParams:
"""The params values."""
return self._params


def _simulation_clock(dim: SimDimensions) -> Iterable[Tick]:
"""Generator for the sequence of ticks which makes up the simulation clock."""
one_day = timedelta(days=1)
tau_steps = list(enumerate(dim.tau_step_lengths))
curr_index = 0
curr_date = dim.start_date
for day in range(dim.days):
for step, tau in tau_steps:
yield Tick(curr_index, day, curr_date, step, tau)
curr_index += 1
curr_date += one_day
4 changes: 2 additions & 2 deletions epymorph/engine/ipm_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _events(self, node: int, tick: Tick, effective_pop: NDArray[SimDType]) -> ND
raise IpmSimNaNException(
self._get_zero_division_args(
rate_args, node, tick, t)
)
) from None
# check for < 0 rate, throw error in this case
if rate < 0:
raise IpmSimLessThanZeroException(
Expand All @@ -204,7 +204,7 @@ def _events(self, node: int, tick: Tick, effective_pop: NDArray[SimDType]) -> ND
raise IpmSimNaNException(
self._get_zero_division_args(
rate_args, node, tick, t)
)
) from None
# check for < 0 base, throw error in this case
if rate < 0:
raise IpmSimLessThanZeroException(
Expand Down
88 changes: 9 additions & 79 deletions epymorph/engine/mm_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from epymorph.movement.movement_model import (MovementContext, MovementModel,
PredefData, TravelClause)
from epymorph.movement.parser import MovementSpec
from epymorph.simulation import Tick
from epymorph.simulation import Tick, resolve_tick
from epymorph.util import row_normalize


Expand Down Expand Up @@ -114,23 +114,24 @@ def calculate_travelers(
class StandardMovementExecutor(MovementEventsMixin, MovementExecutor):
"""The standard implementation of movement model execution."""

_ctx: MovementContext
_model: MovementModel
_clause_masks: dict[TravelClause, NDArray[np.bool_]]
_predef: PredefData = {}
_predef_hash: int | None = None
_predef: PredefData
_predef_hash: int | None
_mobility: NDArray[np.bool_]
_ctx: MovementContext

def __init__(
self,
ctx: MovementContext,
mm: MovementSpec,
compartment_mobility: NDArray[np.bool_],
):
MovementEventsMixin.__init__(self)
self._model = compile_spec(mm, ctx.rng)
self._predef = {}
self._predef_hash = None
self._mobility = compartment_mobility
self._ctx = ctx
self._clause_masks = {c: c.mask(ctx) for c in self._model.clauses}
self._check_predef()

def _check_predef(self) -> None:
Expand All @@ -141,8 +142,6 @@ def _check_predef(self) -> None:
self._predef = self._model.predef(self._ctx)
self._predef_hash = curr_hash
except KeyError as e:
# NOTE: catching KeyError here will be necessary (to get nice error messages)
# until we can properly validate the MM clauses.
msg = f"Missing attribute {e} required by movement model predef."
raise AttributeException(msg) from None

Expand All @@ -167,13 +166,13 @@ def apply(self, world: World, tick: Tick) -> None:

clause_event = calculate_travelers(
self._ctx, self._predef,
clause, self._clause_masks[clause], tick, local_array,
clause, self._mobility, tick, local_array,
)
self.on_movement_clause.publish(clause_event)
travelers = clause_event.actual

returns = clause.returns(self._ctx, tick)
return_tick = self._ctx.resolve_tick(tick, returns)
return_tick = resolve_tick(self._ctx.dim, tick, returns)
world.apply_travel(travelers, return_tick)
total += travelers.sum()

Expand All @@ -197,72 +196,3 @@ def apply(self, world: World, tick: Tick) -> None:

self.on_movement_finish.publish(
OnMovementFinish(tick.index, tick.day, tick.step, total))

def _travelers(self, clause: TravelClause, tick: Tick, local_cohorts: NDArray[SimDType]) -> NDArray[SimDType]:
"""
Calculate the number of travelers resulting from this movement clause for this tick.
This evaluates the requested number movers, modulates that based on the available movers,
then selects exactly which individuals (by compartment) should move.
Returns an (N,N,C) array; from-source-to-destination-by-compartment.
"""
_, N, C, _ = self._ctx.dim.TNCE

clause_movers = clause.requested(self._ctx, self._predef, tick)
np.fill_diagonal(clause_movers, 0)
clause_sum = clause_movers.sum(axis=1, dtype=SimDType)

available_movers = local_cohorts * self._clause_masks[clause]
available_sum = available_movers.sum(axis=1, dtype=SimDType)

# If clause requested total is greater than the total available,
# use mvhg to select as many as possible.
if not np.any(clause_sum > available_sum):
throttled = False
requested_movers = clause_movers
requested_sum = clause_sum
else:
throttled = True
requested_movers = clause_movers.copy()
for src in range(N):
if clause_sum[src] > available_sum[src]:
requested_movers[src, :] = self._ctx.rng.multivariate_hypergeometric(
colors=requested_movers[src, :],
nsample=available_sum[src]
)
requested_sum = requested_movers.sum(axis=1, dtype=SimDType)

# The probability a mover from a src will go to a dst.
requested_prb = row_normalize(requested_movers, requested_sum, dtype=SimDType)

travelers_cs = np.zeros((N, N, C), dtype=SimDType)
for src in range(N):
if requested_sum[src] == 0:
continue

# Select which individuals will be leaving this node.
mover_cs = self._ctx.rng.multivariate_hypergeometric(
available_movers[src, :],
requested_sum[src]
).astype(SimDType)

# Select which location they are each going to.
# (Each row contains the compartments for a destination.)
travelers_cs[src, :, :] = self._ctx.rng.multinomial(
mover_cs,
requested_prb[src, :]
).T.astype(SimDType)

self.on_movement_clause.publish(
OnMovementClause(
tick.index,
tick.day,
tick.step,
clause.name,
clause_movers,
travelers_cs,
requested_sum.sum(),
throttled,
)
)

return travelers_cs
16 changes: 13 additions & 3 deletions epymorph/engine/standard_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from epymorph.movement.movement_model import MovementModel, validate_mm
from epymorph.movement.parser import MovementSpec
from epymorph.params import NormalizedParamsDict, RawParams, normalize_params
from epymorph.simulation import TimeFrame
from epymorph.simulation import TimeFrame, simulation_clock
from epymorph.util import Subscriber


Expand Down Expand Up @@ -127,7 +127,17 @@ def run(self) -> Output:

with error_gate("compiling the simulation", CompilationException):
ipm_exec = StandardIpmExecutor(ctx, self.ipm)
movement_exec = StandardMovementExecutor(ctx, self.mm)

compartment_mobility = np.array(
['immobile' not in tags for tags in self.ipm.compartment_tags],
dtype=np.bool_
)

movement_exec = StandardMovementExecutor(
ctx=ctx,
mm=self.mm,
compartment_mobility=compartment_mobility,
)

# Proxy the movement_exec's events, if anyone is listening for them.
if MovementEventsMixin.has_subscribers(self):
Expand Down Expand Up @@ -157,7 +167,7 @@ def run(self) -> Output:

self.on_start.publish(OnStart(dim=ctx.dim, time_frame=self.time_frame))

for tick in ctx.clock():
for tick in simulation_clock(self.dim):
# First do movement
with error_gate("executing the movement model", MmSimException, AttributeException):
movement_exec.apply(world, tick)
Expand Down
5 changes: 0 additions & 5 deletions epymorph/movement/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ def _compile_clause(
if d in clause.days
)

# TODO: @cache?
def mask_predicate(ctx: MovementContext) -> NDArray[np.bool_]:
return ctx.compartment_mobility

def move_predicate(_ctx: MovementContext, tick: Tick) -> bool:
return clause.leave_step == tick.step and \
tick.date.weekday() in clause_weekdays
Expand All @@ -142,7 +138,6 @@ def returns(_ctx: MovementContext, _tick: Tick) -> TickDelta:

return DynamicTravelClause(
name=name_override(fn_ast.name),
mask_predicate=mask_predicate,
move_predicate=move_predicate,
requested=_adapt_move_function(fn, fn_ast),
returns=returns
Expand Down
25 changes: 0 additions & 25 deletions epymorph/movement/movement_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def geo(self) -> GeoData:
"""The geo data."""
raise NotImplementedError

@property
def compartment_mobility(self) -> NDArray[np.bool_]:
"""Which compartments from the IPM are subject to movement?"""
raise NotImplementedError

@property
def params(self) -> ParamsData:
"""The parameter data."""
Expand All @@ -54,10 +49,6 @@ def version(self) -> int:
"""
raise NotImplementedError

@abstractmethod
def resolve_tick(self, tick: Tick, delta: TickDelta) -> int:
"""Add a delta to a tick to get the index of the resulting tick."""


PredefData = ParamsData
PredefClause = Callable[[MovementContext], PredefData]
Expand All @@ -68,10 +59,6 @@ class TravelClause(ABC):

name: str

@abstractmethod
def mask(self, ctx: MovementContext) -> NDArray[np.bool_]:
"""Calculate the movement mask for this clause."""

@abstractmethod
def predicate(self, ctx: MovementContext, tick: Tick) -> bool:
"""Should this clause apply this tick?"""
Expand All @@ -85,12 +72,6 @@ def returns(self, ctx: MovementContext, tick: Tick) -> TickDelta:
"""Calculate when this clause's movers should return (which may vary from tick-to-tick)."""


MaskPredicate = Callable[[MovementContext], NDArray[np.bool_]]
"""
A predicate which creates a per-IPM-compartment mask:
should this compartment be subject to movement by this clause?
"""

MovementPredicate = Callable[[MovementContext, Tick], bool]
"""A predicate which decides if a clause should fire this tick."""

Expand All @@ -112,26 +93,20 @@ class DynamicTravelClause(TravelClause):

name: str

_mask: MaskPredicate
_move: MovementPredicate
_requested: MovementFunction
_returns: ReturnsFunction

def __init__(self,
name: str,
mask_predicate: MaskPredicate,
move_predicate: MovementPredicate,
requested: MovementFunction,
returns: ReturnsFunction):
self.name = name
self._mask = mask_predicate
self._move = move_predicate
self._requested = requested
self._returns = returns

def mask(self, ctx: MovementContext) -> NDArray[np.bool_]:
return self._mask(ctx)

def predicate(self, ctx: MovementContext, tick: Tick) -> bool:
return self._move(ctx, tick)

Expand Down
Loading
Loading