diff --git a/epymorph/engine/mm_exec.py b/epymorph/engine/mm_exec.py index 278e1fc8..fe049d5e 100644 --- a/epymorph/engine/mm_exec.py +++ b/epymorph/engine/mm_exec.py @@ -114,23 +114,24 @@ def calculate_travelers( class StandardMovementExecutor(MovementEventsMixin, MovementExecutor): """The standard implementation of movement model execution.""" - _ctx: MovementContext _model: MovementModel - _clause_masks: dict[TravelClause, NDArray[np.bool_]] - _predef: PredefData = {} - _predef_hash: int | None = None + _predef: PredefData + _predef_hash: int | None + _mobility: NDArray[np.bool_] + _ctx: MovementContext def __init__( self, ctx: MovementContext, mm: MovementSpec, + compartment_mobility: NDArray[np.bool_], ): MovementEventsMixin.__init__(self) self._model = compile_spec(mm, ctx.rng) self._predef = {} self._predef_hash = None + self._mobility = compartment_mobility self._ctx = ctx - self._clause_masks = {c: c.mask(ctx) for c in self._model.clauses} self._check_predef() def _check_predef(self) -> None: @@ -141,8 +142,6 @@ def _check_predef(self) -> None: self._predef = self._model.predef(self._ctx) self._predef_hash = curr_hash except KeyError as e: - # NOTE: catching KeyError here will be necessary (to get nice error messages) - # until we can properly validate the MM clauses. msg = f"Missing attribute {e} required by movement model predef." raise AttributeException(msg) from None @@ -167,7 +166,7 @@ def apply(self, world: World, tick: Tick) -> None: clause_event = calculate_travelers( self._ctx, self._predef, - clause, self._clause_masks[clause], tick, local_array, + clause, self._mobility, tick, local_array, ) self.on_movement_clause.publish(clause_event) travelers = clause_event.actual diff --git a/epymorph/engine/standard_sim.py b/epymorph/engine/standard_sim.py index cfbb8756..bbcea9c3 100644 --- a/epymorph/engine/standard_sim.py +++ b/epymorph/engine/standard_sim.py @@ -127,7 +127,17 @@ def run(self) -> Output: with error_gate("compiling the simulation", CompilationException): ipm_exec = StandardIpmExecutor(ctx, self.ipm) - movement_exec = StandardMovementExecutor(ctx, self.mm) + + compartment_mobility = np.array( + ['immobile' not in tags for tags in self.ipm.compartment_tags], + dtype=np.bool_ + ) + + movement_exec = StandardMovementExecutor( + ctx=ctx, + mm=self.mm, + compartment_mobility=compartment_mobility, + ) # Proxy the movement_exec's events, if anyone is listening for them. if MovementEventsMixin.has_subscribers(self): diff --git a/epymorph/movement/compile.py b/epymorph/movement/compile.py index 9b51efab..f4187f52 100644 --- a/epymorph/movement/compile.py +++ b/epymorph/movement/compile.py @@ -126,10 +126,6 @@ def _compile_clause( if d in clause.days ) - # TODO: @cache? - def mask_predicate(ctx: MovementContext) -> NDArray[np.bool_]: - return ctx.compartment_mobility - def move_predicate(_ctx: MovementContext, tick: Tick) -> bool: return clause.leave_step == tick.step and \ tick.date.weekday() in clause_weekdays @@ -142,7 +138,6 @@ def returns(_ctx: MovementContext, _tick: Tick) -> TickDelta: return DynamicTravelClause( name=name_override(fn_ast.name), - mask_predicate=mask_predicate, move_predicate=move_predicate, requested=_adapt_move_function(fn, fn_ast), returns=returns diff --git a/epymorph/movement/movement_model.py b/epymorph/movement/movement_model.py index 1b4f8087..a2e20f25 100644 --- a/epymorph/movement/movement_model.py +++ b/epymorph/movement/movement_model.py @@ -31,11 +31,6 @@ def geo(self) -> GeoData: """The geo data.""" raise NotImplementedError - @property - def compartment_mobility(self) -> NDArray[np.bool_]: - """Which compartments from the IPM are subject to movement?""" - raise NotImplementedError - @property def params(self) -> ParamsData: """The parameter data.""" @@ -64,10 +59,6 @@ class TravelClause(ABC): name: str - @abstractmethod - def mask(self, ctx: MovementContext) -> NDArray[np.bool_]: - """Calculate the movement mask for this clause.""" - @abstractmethod def predicate(self, ctx: MovementContext, tick: Tick) -> bool: """Should this clause apply this tick?""" @@ -81,12 +72,6 @@ def returns(self, ctx: MovementContext, tick: Tick) -> TickDelta: """Calculate when this clause's movers should return (which may vary from tick-to-tick).""" -MaskPredicate = Callable[[MovementContext], NDArray[np.bool_]] -""" -A predicate which creates a per-IPM-compartment mask: -should this compartment be subject to movement by this clause? -""" - MovementPredicate = Callable[[MovementContext, Tick], bool] """A predicate which decides if a clause should fire this tick.""" @@ -108,26 +93,20 @@ class DynamicTravelClause(TravelClause): name: str - _mask: MaskPredicate _move: MovementPredicate _requested: MovementFunction _returns: ReturnsFunction def __init__(self, name: str, - mask_predicate: MaskPredicate, move_predicate: MovementPredicate, requested: MovementFunction, returns: ReturnsFunction): self.name = name - self._mask = mask_predicate self._move = move_predicate self._requested = requested self._returns = returns - def mask(self, ctx: MovementContext) -> NDArray[np.bool_]: - return self._mask(ctx) - def predicate(self, ctx: MovementContext, tick: Tick) -> bool: return self._move(ctx, tick)