Skip to content

Commit

Permalink
Movement model mobility masking refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tyler Coles committed May 23, 2024
1 parent f1808de commit 27a2cdd
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 35 deletions.
15 changes: 7 additions & 8 deletions epymorph/engine/mm_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,24 @@ def calculate_travelers(
class StandardMovementExecutor(MovementEventsMixin, MovementExecutor):
"""The standard implementation of movement model execution."""

_ctx: MovementContext
_model: MovementModel
_clause_masks: dict[TravelClause, NDArray[np.bool_]]
_predef: PredefData = {}
_predef_hash: int | None = None
_predef: PredefData
_predef_hash: int | None
_mobility: NDArray[np.bool_]
_ctx: MovementContext

def __init__(
self,
ctx: MovementContext,
mm: MovementSpec,
compartment_mobility: NDArray[np.bool_],
):
MovementEventsMixin.__init__(self)
self._model = compile_spec(mm, ctx.rng)
self._predef = {}
self._predef_hash = None
self._mobility = compartment_mobility
self._ctx = ctx
self._clause_masks = {c: c.mask(ctx) for c in self._model.clauses}
self._check_predef()

def _check_predef(self) -> None:
Expand All @@ -141,8 +142,6 @@ def _check_predef(self) -> None:
self._predef = self._model.predef(self._ctx)
self._predef_hash = curr_hash
except KeyError as e:
# NOTE: catching KeyError here will be necessary (to get nice error messages)
# until we can properly validate the MM clauses.
msg = f"Missing attribute {e} required by movement model predef."
raise AttributeException(msg) from None

Expand All @@ -167,7 +166,7 @@ def apply(self, world: World, tick: Tick) -> None:

clause_event = calculate_travelers(
self._ctx, self._predef,
clause, self._clause_masks[clause], tick, local_array,
clause, self._mobility, tick, local_array,
)
self.on_movement_clause.publish(clause_event)
travelers = clause_event.actual
Expand Down
12 changes: 11 additions & 1 deletion epymorph/engine/standard_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,17 @@ def run(self) -> Output:

with error_gate("compiling the simulation", CompilationException):
ipm_exec = StandardIpmExecutor(ctx, self.ipm)
movement_exec = StandardMovementExecutor(ctx, self.mm)

compartment_mobility = np.array(
['immobile' not in tags for tags in self.ipm.compartment_tags],
dtype=np.bool_
)

movement_exec = StandardMovementExecutor(
ctx=ctx,
mm=self.mm,
compartment_mobility=compartment_mobility,
)

# Proxy the movement_exec's events, if anyone is listening for them.
if MovementEventsMixin.has_subscribers(self):
Expand Down
5 changes: 0 additions & 5 deletions epymorph/movement/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ def _compile_clause(
if d in clause.days
)

# TODO: @cache?
def mask_predicate(ctx: MovementContext) -> NDArray[np.bool_]:
return ctx.compartment_mobility

def move_predicate(_ctx: MovementContext, tick: Tick) -> bool:
return clause.leave_step == tick.step and \
tick.date.weekday() in clause_weekdays
Expand All @@ -142,7 +138,6 @@ def returns(_ctx: MovementContext, _tick: Tick) -> TickDelta:

return DynamicTravelClause(
name=name_override(fn_ast.name),
mask_predicate=mask_predicate,
move_predicate=move_predicate,
requested=_adapt_move_function(fn, fn_ast),
returns=returns
Expand Down
21 changes: 0 additions & 21 deletions epymorph/movement/movement_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ def geo(self) -> GeoData:
"""The geo data."""
raise NotImplementedError

@property
def compartment_mobility(self) -> NDArray[np.bool_]:
"""Which compartments from the IPM are subject to movement?"""
raise NotImplementedError

@property
def params(self) -> ParamsData:
"""The parameter data."""
Expand Down Expand Up @@ -64,10 +59,6 @@ class TravelClause(ABC):

name: str

@abstractmethod
def mask(self, ctx: MovementContext) -> NDArray[np.bool_]:
"""Calculate the movement mask for this clause."""

@abstractmethod
def predicate(self, ctx: MovementContext, tick: Tick) -> bool:
"""Should this clause apply this tick?"""
Expand All @@ -81,12 +72,6 @@ def returns(self, ctx: MovementContext, tick: Tick) -> TickDelta:
"""Calculate when this clause's movers should return (which may vary from tick-to-tick)."""


MaskPredicate = Callable[[MovementContext], NDArray[np.bool_]]
"""
A predicate which creates a per-IPM-compartment mask:
should this compartment be subject to movement by this clause?
"""

MovementPredicate = Callable[[MovementContext, Tick], bool]
"""A predicate which decides if a clause should fire this tick."""

Expand All @@ -108,26 +93,20 @@ class DynamicTravelClause(TravelClause):

name: str

_mask: MaskPredicate
_move: MovementPredicate
_requested: MovementFunction
_returns: ReturnsFunction

def __init__(self,
name: str,
mask_predicate: MaskPredicate,
move_predicate: MovementPredicate,
requested: MovementFunction,
returns: ReturnsFunction):
self.name = name
self._mask = mask_predicate
self._move = move_predicate
self._requested = requested
self._returns = returns

def mask(self, ctx: MovementContext) -> NDArray[np.bool_]:
return self._mask(ctx)

def predicate(self, ctx: MovementContext, tick: Tick) -> bool:
return self._move(ctx, tick)

Expand Down

0 comments on commit 27a2cdd

Please sign in to comment.