AST manipulation applied to Movement and Parameter code (#82)
Includes:
* Better code execution security via AST transforms (a brief sketch follows this list).
* Simplify proxy use -- geo and dim (new).
* Use AST manipulation on parameter functions.
* Use AST manipulation on movement functions to avoid the need to recompile them when the context changes.
* Movement model predef now managed by the movement executor.
* Add np.pi to namespaces.
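
As a rough, hypothetical sketch of the first bullet, using the helpers added in `epymorph/code.py` below (the example functions here are made up for illustration, not taken from epymorph):

```python
from epymorph.code import (CodeSecurityException, base_namespace,
                           compile_function, parse_function)

# Hypothetical example: a harmless user-defined function is parsed, scrubbed
# by the AST transformer, and compiled against an empty-builtins namespace.
safe_src = """
def beta(t):
    return 0.4 if t < 30 else 0.2
"""
beta = compile_function(parse_function(safe_src), base_namespace())
print(beta(10))  # 0.4

# A function referencing a forbidden name is rejected while it is still an AST.
unsafe_src = """
def beta(t):
    return eval("2 + 2")
"""
try:
    parse_function(unsafe_src)
except CodeSecurityException as err:
    print(err)  # Illegal reference to `eval`.
```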
Tyler authored Jan 24, 2024
1 parent 101d20f commit c48a15b
Showing 23 changed files with 1,193 additions and 367 deletions.
194 changes: 194 additions & 0 deletions doc/devlog/2024-01-08.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions doc/devlog/README.md
@@ -34,8 +34,12 @@ This folder is a handy place to put Jupyter notebooks or other documents which h
| 2023-10-10.ipynb | Tyler | | A demo of various epymorph workflows in a Notebook environment, designed for a live presentation. |
| 2023-10-26.ipynb | Tyler | | Describes a major Geo system refactor and introduces new systems. |
| 2023-11-03-seirs-example.ipynb | Ajay | | Demonstrates the building and running of an SEIRS model. |
| 2023-11-08.ipynb | Ajay | | Demonstration of using proxy geo to access data in parameter functions. |
| 2023-11-15.ipynb | Ajay | | Detailed description of parameter functions functionality. |
| 2023-11-20-adrio-phase-2-demo.ipynb | Trevor | | Demonstrates the refactor work on DynamicGeos and the ADRIO system, and geo cache handling. |
| 2023-11-22-ipm-probs.ipynb | Tyler | | Analyzing statistical correctness of our IPM processing algorithms. |
| 2023-12-05.ipynb | Tyler | | A brief tour of changes to epymorph due to the refactor effort. |
| 2024-01-08.ipynb | Tyler | | Another functional parameters demonstration, revisiting the Bonus Example from 2023-10-10. |

## Contributing

5 changes: 3 additions & 2 deletions epymorph/__init__.py
@@ -4,8 +4,8 @@
from epymorph.data import geo_library, ipm_library, mm_library
from epymorph.data_shape import Shapes
from epymorph.engine.standard_sim import StandardSimulation
from epymorph.geo.abstract import geo as proxy_geo
from epymorph.plots import plot_event, plot_pop
from epymorph.proxy import dim, geo
from epymorph.simulation import SimDType, TimeFrame, default_rng, sim_messaging

__all__ = [
@@ -15,7 +15,8 @@
'geo_library',
'Shapes',
'StandardSimulation',
'proxy_geo',
'geo',
'dim',
'plot_event',
'plot_pop',
'SimDType',
151 changes: 151 additions & 0 deletions epymorph/code.py
@@ -0,0 +1,151 @@
"""
Utilities for handling user code parsed from strings.
While the primary goal is to extract runnable Python functions,
it is also important to take some security measures to mitigate
the execution of malicious code.
"""
import ast
import re
import textwrap
from typing import Any, Callable


def has_function_structure(string: str) -> bool:
"""Check if a string seems to have the structure of a function definition."""
return re.search(r"^\s*def\s+\w+\s*\(.*?\):", string, flags=re.MULTILINE) is not None


def parse_function(code_string: str, unsafe: bool = False) -> ast.FunctionDef:
"""
Parse a function from a code string, returning the function's AST.
The resulting AST will have security mitigations applied, unless `unsafe` is True.
The string must contain a single top-level Python function definition,
or else ValueError is raised. Raises SyntaxError if the function is not valid Python.
"""
tree = ast.parse(textwrap.dedent(code_string), '<string>', mode='exec')
functions = [statement for statement in tree.body
if isinstance(statement, ast.FunctionDef)]
if (n := len(functions)) != 1:
msg = f"Code must contain exactly one top-level function definition: found {n}"
raise ValueError(msg)
return functions[0] if unsafe else scrub_function(functions[0])


class CodeCompileException(Exception):
"""An exception raised when code cannot be compiled for some reason."""


class CodeSecurityException(CodeCompileException):
"""An exception raised when code cannot be safely compiled due to security rules."""


_FORBIDDEN_NAMES = frozenset({
'eval', 'exec', 'compile', 'object', 'print', 'open',
'quit', 'exit', 'globals', 'locals', 'help', 'breakpoint'
})
"""Names which should not exist in a user-defined function."""


class SecureTransformer(ast.NodeTransformer):
"""AST transformer for applying basic security mitigations."""

def visit_Import(self, _node: ast.Import) -> Any:
"""Silently remove imports."""
return None

def visit_ImportFrom(self, _node: ast.ImportFrom) -> Any:
"""Silently remove imports."""
return None

def visit_Name(self, node: ast.Name) -> Any:
"""No referencing sensitive names like eval or exec, or anything starting with an underscore."""
if node.id.startswith('_'):
raise CodeSecurityException(f"Illegal reference to `{node.id}`.")
if node.id in _FORBIDDEN_NAMES:
raise CodeSecurityException(f"Illegal reference to `{node.id}`.")
return super().generic_visit(node)

def visit_Attribute(self, node: ast.Attribute) -> Any:
"""Disallow accessing potentially sensitive attributes (any with a leading underscore)."""
if node.attr.startswith('_'):
msg = f"Illegal reference to attribute `{node.attr}`."
raise CodeSecurityException(msg)
return super().generic_visit(node)


def scrub_function(function_def: ast.FunctionDef) -> ast.FunctionDef:
"""
Applies security mitigations to an AST, returning the transformed AST.
"""
return SecureTransformer().visit(function_def)


def compile_function(function_def: ast.FunctionDef, global_namespace: dict[str, Any] | None) -> Callable:
"""
Compile the given function's AST using the given global namespace.
Returns the function.
Args:
function_def: The function definition (AST) to compile.
global_namespace: A dictionary of global variables to make available to the compiled function; if None, `base_namespace()` is used.
Returns:
A callable object representing the compiled function.
"""

# Compile the code and execute it, providing global and local namespaces
module = ast.Module(body=[function_def], type_ignores=[])
code = compile(module, '<string>', mode='exec')
if global_namespace is None:
global_namespace = base_namespace()
local_namespace = dict[str, Any]()
exec(code, global_namespace, local_namespace)
# Now our function is defined in the local namespace, retrieve it.
function = local_namespace[function_def.name]
if not isinstance(function, Callable):
msg = f"`{function_def.name}` did not compile to a callable function."
raise CodeCompileException(msg)
return function


def base_namespace() -> dict[str, Any]:
"""Make a safer namespace for user-defined functions."""
return {'__builtins__': {}}


class ImmutableNamespace:
"""A simple dot-accessible dictionary."""

__slots__ = ['_data']

_data: dict[str, Any]

def __init__(self, data: dict[str, Any] | None = None):
if data is None:
data = {}
object.__setattr__(self, '_data', data)

def __getattribute__(self, __name: str) -> Any:
if __name == '_data':
__cls = self.__class__.__name__
raise AttributeError(f"{__cls} object has no attribute '{__name}'")
try:
return object.__getattribute__(self, __name)
except AttributeError:
data = object.__getattribute__(self, '_data')
if __name not in data:
__cls = self.__class__.__name__
msg = f"{__cls} object has no attribute '{__name}'"
raise AttributeError(msg) from None
return data[__name]

def __setattr__(self, __name: str, __value: Any) -> None:
raise AttributeError(f"{self.__class__.__name__} is immutable.")

def __delattr__(self, __name: str) -> None:
raise AttributeError(f"{self.__class__.__name__} is immutable.")

def to_dict_shallow(self) -> dict[str, Any]:
"""Make a shallow copy of this Namespace as a dict."""
# This is necessary in order to pass it to exec or eval.
# The shallow copy allows child-namespaces to remain dot-accessible.
return object.__getattribute__(self, '_data').copy()
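
A minimal end-to-end sketch of how this new module fits together. The namespace contents and the example function are hypothetical; epymorph builds its real namespaces elsewhere:

```python
import numpy as np

from epymorph.code import (ImmutableNamespace, base_namespace,
                           compile_function, parse_function)

# Hypothetical user function; in epymorph this would come from a spec file or notebook.
src = """
def dispersal(phi):
    return 1.0 / np.exp(1.0 / phi)
"""

# Start from the empty-builtins namespace and add only what we explicitly allow.
# (Assumed usage: the real namespace construction lives outside this module.)
namespace = base_namespace() | {'np': np}
dispersal = compile_function(parse_function(src), namespace)
print(dispersal(40.0))

# ImmutableNamespace gives read-only, dot-accessible data (handy for simulation context).
ctx = ImmutableNamespace({'nodes': 6, 'days': 150})
print(ctx.nodes)  # 6
# ctx.nodes = 10  # raises AttributeError: ImmutableNamespace is immutable.
```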
2 changes: 1 addition & 1 deletion epymorph/data/mm/centroids.movement
@@ -4,7 +4,7 @@
def centroids_movement():
centroid = geo['centroid']
distance = pairwise_haversine(centroid['longitude'], centroid['latitude'])
dispersal_kernel = row_normalize(1 / np.exp(distance / param['phi']))
dispersal_kernel = row_normalize(1 / np.exp(distance / params['phi']))
return { 'dispersal_kernel': dispersal_kernel }
]

6 changes: 3 additions & 3 deletions epymorph/data/mm/pei.movement
@@ -26,7 +26,7 @@ def pei_movement():
[mtype: days=all; leave=1; duration=0d; return=2; function=
def commuters(t):
typical = predef['commuters_by_node']
actual = np.binomial(typical, param['move_control'])
actual = np.binomial(typical, params['move_control'])
return np.multinomial(actual, predef['commuting_probability'])
]

@@ -35,11 +35,11 @@ def commuters(t):
#[mtype: days=all; leave=1; duration=0d; return=2; function=
#def dispersers(t, src, dst):
# avg = predef['commuters_average'][src, dst]
# return poisson(avg * param['theta'])
# return poisson(avg * params['theta'])
#]

[mtype: days=all; leave=1; duration=0d; return=2; function=
def dispersers(t):
avg = predef['commuters_average']
return np.poisson(avg * param['theta'])
return np.poisson(avg * params['theta'])
]
2 changes: 1 addition & 1 deletion epymorph/data/mm/sparsemod.movement
@@ -4,7 +4,7 @@
def sparsemod_predef():
centroid = geo['centroid']
distance = pairwise_haversine(centroid['longitude'], centroid['latitude'])
dispersal_kernel = row_normalize(1 / np.exp(distance / param['phi']))
dispersal_kernel = row_normalize(1 / np.exp(distance / params['phi']))
return {
'commuters_by_node': np.sum(geo['commuters'], axis=1),
'dispersal_kernel': dispersal_kernel
22 changes: 12 additions & 10 deletions epymorph/engine/context.py
@@ -16,7 +16,6 @@
from epymorph.compartment_model import CompartmentModel
from epymorph.error import (AttributeException, InitException,
IpmValidationException)
from epymorph.geo.abstract import proxy_geo
from epymorph.geo.geo import Geo
from epymorph.initializer import (InitContext, Initializer,
normalize_init_params)
@@ -55,6 +54,11 @@ class RumeContext:
time_frame: TimeFrame
initializer: Initializer
rng: np.random.Generator
version: int = field(init=False, default=0)
"""
`version` indicates when changes have been made to the context.
If `version` hasn't changed, no other changes have been made.
"""

_attribute_getters: MemoDict[AttributeDef, AttributeGetter] = field(init=False)

@@ -83,15 +87,12 @@ def from_config(cls, config: RumeConfig) -> Self:
for a in config.ipm.attributes
}

with proxy_geo(config.geo):
# Parameters might be functions reference the proxy geo,
# so to evaluate them we must be in the `proxy_geo` context.
ctx_params = normalize_params(
config.params,
config.geo,
config.time_frame.duration_days,
attr_dtypes
)
ctx_params = normalize_params(
config.params,
config.geo,
dim,
attr_dtypes
)

return cls(dim, config.geo, config.ipm, config.mm,
ctx_params, config.params, config.time_frame,
@@ -124,6 +125,7 @@ def update_param(self, name: str, value: ParamNp) -> None:
attrs = [a for a in self._attribute_getters if a.name == name]
for a in attrs:
del self._attribute_getters[a]
self.version += 1

def _get_attribute_value(self, attr: AttributeDef) -> NDArray:
"""Retrieve the value associated with the given attribute."""
39 changes: 31 additions & 8 deletions epymorph/engine/mm_exec.py
@@ -9,8 +9,10 @@

from epymorph.engine.context import RumeContext
from epymorph.engine.world import World
from epymorph.error import AttributeException, MmCompileException
from epymorph.movement.compile import compile_spec
from epymorph.movement.movement_model import TravelClause
from epymorph.movement.movement_model import (MovementModel, PredefParams,
TravelClause)
from epymorph.movement.parser import MovementSpec
from epymorph.simulation import SimDType, Tick
from epymorph.util import row_normalize
@@ -38,27 +40,48 @@ class StandardMovementExecutor(MovementExecutor):

_ctx: RumeContext
_log: Logger
_clauses: list[TravelClause]
_model: MovementModel
_clause_masks: dict[TravelClause, NDArray[np.bool_]]
_predef: PredefParams = {}
_predef_hash: int | None = None

def __init__(self, ctx: RumeContext):
# If we were given a MovementSpec, we need to compile it to get its clauses.
if isinstance(ctx.mm, MovementSpec):
clauses = compile_spec(ctx, ctx.mm).clauses
self._model = compile_spec(ctx, ctx.mm)
else:
clauses = ctx.mm.clauses
self._model = ctx.mm

self._ctx = ctx
self._log = getLogger('movement')
self._clauses = clauses
self._clause_masks = {c: c.mask(ctx) for c in clauses}
self._clause_masks = {c: c.mask(ctx) for c in self._model.clauses}
self._check_predef()

def _check_predef(self) -> None:
"""Check if predef needs to be re-calc'd, and if so, do so."""
curr_hash = self._model.predef_context_hash(self._ctx)
if curr_hash != self._predef_hash:
try:
self._predef = self._model.predef(self._ctx)
self._predef_hash = curr_hash
except KeyError as e:
# NOTE: catching KeyError here will be necessary (to get nice error messages)
# until we can properly validate the MM clauses.
msg = f"Missing attribute {e} required by movement model predef."
raise AttributeException(msg) from None

if not isinstance(self._predef, dict):
msg = f"Movement predef: did not return a dictionary result (got: {type(self._predef)})"
raise MmCompileException(msg)

def apply(self, world: World, tick: Tick) -> None:
"""Applies movement for this tick, mutating the world state."""
self._log.debug('Processing movement for day %s, step %s', tick.day, tick.step)

self._check_predef()

# Process travel clauses.
for clause in self._clauses:
for clause in self._model.clauses:
if not clause.predicate(self._ctx, tick):
continue
local_array = world.get_local_array()
@@ -81,7 +104,7 @@ def _travelers(self, clause: TravelClause, tick: Tick, local_cohorts: NDArray[Si
clause_log = self._log.getChild(clause.name)
_, N, C, _ = self._ctx.dim.TNCE

requested_movers = clause.requested(self._ctx, tick)
requested_movers = clause.requested(self._ctx, self._predef, tick)
np.fill_diagonal(requested_movers, 0)
requested_sum = requested_movers.sum(axis=1, dtype=SimDType)

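
The hash-guarded predef above amounts to a small memoization pattern: recompute the predef dictionary only when the context pieces it depends on have changed (for example, after `update_param` bumps the context's `version`), rather than recompiling the movement functions. A self-contained sketch of the idea, using illustrative stand-in classes rather than epymorph's actual API:

```python
from typing import Any, Callable


class Context:
    """Stand-in for RumeContext: tracks a version that changes with every update."""

    def __init__(self, params: dict[str, float]):
        self.params = params
        self.version = 0

    def update_param(self, name: str, value: float) -> None:
        self.params[name] = value
        self.version += 1


class Executor:
    """Stand-in for the movement executor: caches predef keyed on a context hash."""

    def __init__(self, ctx: Context, predef: Callable[[Context], dict[str, Any]]):
        self._ctx = ctx
        self._predef_fn = predef
        self._predef: dict[str, Any] = {}
        self._predef_hash: int | None = None

    def check_predef(self) -> dict[str, Any]:
        curr_hash = hash((self._ctx.version,))  # stand-in for predef_context_hash
        if curr_hash != self._predef_hash:
            self._predef = self._predef_fn(self._ctx)  # recompute, but never recompile
            self._predef_hash = curr_hash
        return self._predef


ctx = Context({'phi': 40.0})
ex = Executor(ctx, lambda c: {'kernel': 1.0 / c.params['phi']})
ex.check_predef()              # computed on first use
ex.check_predef()              # hash unchanged -> cached value reused
ctx.update_param('phi', 35.0)  # bumps version
ex.check_predef()              # hash changed -> predef recomputed
```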
19 changes: 18 additions & 1 deletion epymorph/engine/standard_sim.py
@@ -20,7 +20,7 @@
from epymorph.initializer import DEFAULT_INITIALIZER, Initializer
from epymorph.movement.movement_model import MovementModel
from epymorph.movement.parser import MovementSpec
from epymorph.params import Params
from epymorph.params import ContextParams, Params
from epymorph.simulation import (OnStart, SimDimensions, SimDType, SimTick,
SimulationEvents, TimeFrame)
from epymorph.util import Event
@@ -92,6 +92,7 @@ class StandardSimulation(SimulationEvents):
"""Runs singular simulation passes, producing time-series output."""

_config: RumeConfig
_params: ContextParams | None = None
on_tick: Event[SimTick] # this class supports on_tick; so narrow the type def

def __init__(self,
@@ -123,6 +124,22 @@ def validate(self) -> None:
ctx.validate_ipm()
# ctx.validate_init()

@property
def params(self) -> ContextParams:
"""Simulation parameters as used by this simulation."""
# Here we lazily-evaluate and then cache params from the context.
# Why not just cache the whole context when StandardSim is constructed? The problem is mutability.
# Params is a dictionary, which allows mutation, and many of its values are
# numpy arrays, which also allow mutation.
# We can't really guarantee immutability in userland code or even our simulation code
# so it's safest to reconstruct a fresh copy of the context every time we need it.
# Of course, the user can still muck with this cached version of params, but the blast radius
# for doing so is sufficiently contained by this approach because sim runs use a fresh context.
# It would be nice to be able to deep-freeze the entire context object tree, but alas...
if self._params is None:
self._params = RumeContext.from_config(self._config).params
return self._params

def run(self) -> Output:
"""
Run the simulation. It is safe to call this multiple times