-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompile.py
383 lines (329 loc) · 15.3 KB
/
compile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
"""
Compilation of movement models.
"""
import ast
from functools import wraps
from typing import Any, Callable, Mapping, Sequence
import numpy as np
from numpy.typing import NDArray
from epymorph.code import (ImmutableNamespace, compile_function,
epymorph_namespace, parse_function)
from epymorph.data_type import SimDType
from epymorph.error import AttributeException, MmCompileException, error_gate
from epymorph.movement.movement_model import (DynamicTravelClause,
MovementContext,
MovementFunction, MovementModel,
PredefData, TravelClause)
from epymorph.movement.parser import (ALL_DAYS, DailyClause, MovementClause,
MovementSpec)
from epymorph.simulation import AttributeDef, Tick, TickDelta
from epymorph.util import identity
def _empty_predef(_ctx: MovementContext) -> PredefData:
"""A placeholder predef function for when none is given by the movement spec."""
return {}
def compile_spec(
    spec: MovementSpec,
    rng: np.random.Generator,
    name_override: Callable[[str], str] = identity,
) -> MovementModel:
    """
    Compile a MovementSpec into a MovementModel.

    `rng` is the random number generator the compiled movement functions will use;
    it backs the `np.poisson`, `np.binomial`, and `np.multinomial` functions made
    available to them. Each clause is named after its function in the spec unless
    `name_override` maps that name to something else.

    All of the compilation runs inside
    `error_gate("compiling the movement model", MmCompileException, AttributeException)`
    (see epymorph.error for how that gate reports those exceptions).
    """
    with error_gate("compiling the movement model", MmCompileException, AttributeException):
        # One namespace is shared by the predef and every clause function.
        namespace = _movement_global_namespace(rng)

        # Predef: compile the spec's function, or fall back to the no-op version.
        predef_spec = spec.predef
        if predef_spec is not None:
            parsed = parse_function(predef_spec.function)
            rewriter = PredefFunctionTransformer(spec.attributes)
            rewritten = rewriter.visit_and_fix(parsed)
            predef_fn = compile_function(rewritten, namespace)
        else:
            predef_fn = _empty_predef

        def predef_context_hash(ctx: MovementContext) -> int:
            # NOTE: placeholder. Hashing the context version means the predef is
            # recalculated after any change to the context. A finer-grained approach
            # would extract the references the predef AST actually uses and only
            # recalculate when one of those changes.
            return hash(ctx.version)

        compiled_clauses = [
            _compile_clause(c, spec.attributes, namespace, name_override)
            for c in spec.clauses
        ]

        return MovementModel(
            tau_steps=spec.steps.step_lengths,
            attributes=spec.attributes,
            predef=predef_fn,
            predef_context_hash=predef_context_hash,
            clauses=compiled_clauses,
        )
def _movement_global_namespace(rng: np.random.Generator) -> dict[str, Any]:
    """
    Build the restricted namespace in which user-defined movement functions run.

    Starting from `epymorph_namespace(SimDType)`, this:
    - replaces `np` with an ImmutableNamespace holding the original `np` entries
      plus `poisson`, `binomial`, and `multinomial` drawn from `rng`, with their
      results converted to SimDType;
    - exposes `MovementContext` and `PredefData`.
    """

    def to_simdtype(sampler):
        # Coerce whatever `sampler` returns to SimDType: scalars are converted
        # directly, anything else via `.astype(SimDType)`.
        @wraps(sampler)
        def sampler_as_simdtype(*args, **kwargs):
            drawn = sampler(*args, **kwargs)
            if not np.isscalar(drawn):
                return drawn.astype(SimDType)
            return SimDType(drawn)  # type: ignore
        return sampler_as_simdtype

    namespace = epymorph_namespace(SimDType)

    # An `np` that also carries the rng-backed sampling functions.
    np_with_rng = ImmutableNamespace({
        **namespace['np'].to_dict_shallow(),
        'poisson': to_simdtype(rng.poisson),
        'binomial': to_simdtype(rng.binomial),
        'multinomial': to_simdtype(rng.multinomial),
    })

    # Simulation types the movement code may reference, plus the augmented `np`.
    namespace.update({
        'MovementContext': MovementContext,
        'PredefData': PredefData,
        'np': np_with_rng,
    })
    return namespace
def _compile_clause(
clause: MovementClause,
model_attributes: Sequence[AttributeDef],
global_namespace: dict[str, Any],
name_override: Callable[[str], str] = identity,
) -> TravelClause:
"""Compiles a movement clause in a given namespace."""
# Parse AST for the function.
try:
orig_ast = parse_function(clause.function)
transformer = ClauseFunctionTransformer(model_attributes)
fn_ast = transformer.visit_and_fix(orig_ast)
fn = compile_function(fn_ast, global_namespace)
except MmCompileException as e:
raise e
except Exception as e:
msg = "Unable to parse and compile movement clause function."
raise MmCompileException(msg) from e
# Handle different types of MovementClause.
match clause:
case DailyClause():
clause_weekdays = set(
i for (i, d) in enumerate(ALL_DAYS)
if d in clause.days
)
def move_predicate(_ctx: MovementContext, tick: Tick) -> bool:
return clause.leave_step == tick.step and \
tick.date.weekday() in clause_weekdays
def returns(_ctx: MovementContext, _tick: Tick) -> TickDelta:
return TickDelta(
days=clause.duration.to_days(),
step=clause.return_step
)
return DynamicTravelClause(
name=name_override(fn_ast.name),
move_predicate=move_predicate,
requested=_adapt_move_function(fn, fn_ast),
returns=returns
)
def _adapt_move_function(fn: Callable, fn_ast: ast.FunctionDef) -> MovementFunction:
"""
Wrap the user-provided function in order to handle functions of different arity.
Movement functions as specified by the user can have signature:
f(tick); f(tick, src); or f(tick, src, dst).
"""
match len(fn_ast.args.args):
# Remember `fn` has been transformed, so if the user gave 1 arg we added 2 for a total of 3.
case 3:
@wraps(fn)
def fn_arity1(ctx: MovementContext, predef: PredefData, tick: Tick) -> NDArray[SimDType]:
requested = fn(ctx, predef, tick)
np.fill_diagonal(requested, 0)
return requested
return fn_arity1
case 4:
@wraps(fn)
def fn_arity2(ctx: MovementContext, predef: PredefData, tick: Tick) -> NDArray[SimDType]:
N = ctx.dim.nodes
requested = np.zeros((N, N), dtype=SimDType)
for n in range(N):
requested[n, :] = fn(ctx, predef, tick, n)
np.fill_diagonal(requested, 0)
return requested
return fn_arity2
case 5:
@wraps(fn)
def fn_arity3(ctx: MovementContext, predef: PredefData, tick: Tick) -> NDArray[SimDType]:
N = ctx.dim.nodes
requested = np.zeros((N, N), dtype=SimDType)
for i, j in np.ndindex(N, N):
requested[i, j] = fn(ctx, predef, tick, i, j)
np.fill_diagonal(requested, 0)
return requested
return fn_arity3
case invalid_num_args:
msg = f"Movement clause '{fn_ast.name}' has an invalid number of arguments ({invalid_num_args})"
raise MmCompileException(msg)
# Code transformers
class _MovementCodeTransformer(ast.NodeTransformer):
    """
    Shared AST-rewriting logic for movement model predef and clause functions.

    Rewrites performed:
    - `geo[<const>]` and `params[<const>]` become `ctx.geo[<const>]` and
      `ctx.params[<const>]`;
    - `predef[<const>]` keeps its base name;
    - in all three cases the constant key is passed through the matching remapping
      function (identity by default);
    - `dim.<attr>` becomes `ctx.dim.<attr>`.

    When at least one attribute declaration is supplied, each `geo[...]`/`params[...]`
    access with a constant key is validated against those declarations. Some of this
    may be more than either subclass strictly needs, but it is harmless to both.
    """

    # Whether geo/params accesses are validated against `attributes`.
    check_attributes: bool
    # Declared attributes keyed by name (empty when none were declared).
    attributes: Mapping[str, AttributeDef]
    # Key-renaming functions applied to geo, params, and predef subscripts respectively.
    geo_remapping: Callable[[str], str]
    params_remapping: Callable[[str], str]
    predef_remapping: Callable[[str], str]

    def __init__(
        self,
        attributes: Sequence[AttributeDef],
        geo_remapping: Callable[[str], str] = identity,
        params_remapping: Callable[[str], str] = identity,
        predef_remapping: Callable[[str], str] = identity,
    ):
        # Attribute declarations are optional for backwards compatibility, so
        # validation is only switched on when at least one declaration exists.
        self.check_attributes = len(attributes) > 0
        self.attributes = {a.name: a for a in attributes}
        # The remapping hooks were meant for multi-strata movement models. That feature
        # ended up remapping the source data instead, but the hooks are kept because
        # they may still be useful.
        self.geo_remapping = geo_remapping
        self.params_remapping = params_remapping
        self.predef_remapping = predef_remapping

    def _report_line(self, node: ast.AST):
        """Describe where `node` appears, for use in error messages."""
        return f"Line: {node.lineno}"

    @staticmethod
    def _from_ctx(name: str) -> ast.Attribute:
        """Build the expression `ctx.<name>` (in load context)."""
        return ast.Attribute(
            value=ast.Name(id='ctx', ctx=ast.Load()),
            attr=name,
            ctx=ast.Load(),
        )

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Rewrite constant-key lookups on `geo`, `params`, and `predef`."""
        base = node.value
        key = node.slice
        if not isinstance(base, ast.Name) \
                or not isinstance(key, ast.Constant) \
                or base.id not in ('geo', 'params', 'predef'):
            # Not a constant-key lookup on a known source: just recurse into it.
            return self.generic_visit(node)

        source = base.id
        attr_name = key.value

        # Validate geo/params usage against the attribute declarations.
        if self.check_attributes and source != 'predef':
            if attr_name not in self.attributes:
                msg = f"Movement model is using an undeclared attribute: `{source}[{attr_name}]`. "\
                    f"Please add a suitable attribute declaration. ({self._report_line(node)})"
                raise MmCompileException(msg)
            declared = self.attributes[attr_name]
            if source != declared.source:
                msg = "Movement model is using an attribute from a source other than the one that's declared. "\
                    f"It's trying to access `{source}[{attr_name}]` "\
                    f"but the attribute declaration says this should come from {declared.source}. "\
                    "Please correct either the attribute declaration or the model function code. "\
                    f"({self._report_line(node)})"
                raise MmCompileException(msg)
        # NOTE: predef keys are deliberately *not* checked against what the predef
        # function returns; proving that at compile time would require analyzing all
        # code contributing to the returned dict's keys. It stays a simulation-time error.

        # Rename the key according to its source.
        if source == 'geo':
            renamed = self.geo_remapping(attr_name)
        elif source == 'params':
            renamed = self.params_remapping(attr_name)
        else:
            renamed = self.predef_remapping(attr_name)
        node.slice = ast.Constant(value=renamed)

        # geo and params are reached through the context; predef is passed separately.
        if source != 'predef':
            node.value = self._from_ctx(source)
        return node

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Rewrite `dim.<attr>` to `ctx.dim.<attr>`."""
        target = node.value
        if not (isinstance(target, ast.Name) and target.id == 'dim'):
            return self.generic_visit(node)
        node.value = self._from_ctx(target.id)
        return node

    def visit_and_fix(self, node: ast.AST) -> Any:
        """
        Visit `node`, then run ast.fix_missing_locations() on the result so that
        newly created nodes get source positions, and return it.
        """
        return ast.fix_missing_locations(self.visit(node))
class PredefFunctionTransformer(_MovementCodeTransformer):
    """
    Transforms movement model predef code. This is the counterpart of
    ClauseFunctionTransformer (see that class for a fuller description), but
    specialized for predef functions, which are similar but slightly different.

    Most importantly, it rewrites the predef function's signature so the movement
    context (`ctx: MovementContext`) is its first parameter. Only the outermost
    function is rewritten: any helper function nested inside the predef keeps its
    own signature and reaches `ctx` through the enclosing function's scope.
    """

    # True while the outermost function definition is being visited.
    _in_function: bool = False

    def _report_line(self, node: ast.AST):
        return f"Predef line: {node.lineno}"

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Prepend `ctx` to the outermost function's parameters."""
        if self._in_function:
            # Nested helper: transform its body, but leave its signature alone.
            # Prepending `ctx` here would shift its parameters and break calls to it.
            return self.generic_visit(node)
        self._in_function = True
        try:
            new_node = self.generic_visit(node)
        finally:
            self._in_function = False
        if isinstance(new_node, ast.FunctionDef):
            ctx_arg = ast.arg(
                arg='ctx',
                annotation=ast.Name(id='MovementContext', ctx=ast.Load()),
            )
            new_node.args.args = [ctx_arg, *new_node.args.args]
        return new_node
class ClauseFunctionTransformer(_MovementCodeTransformer):
    """
    Transforms movement clause code so that we can pass context, etc.,
    via function arguments instead of the namespace. The goal is to
    simplify the function interface for end users while still maintaining
    good performance characteristics when parameters change during
    a simulation run (i.e., not have to recompile the functions every time
    the params change).

    A function like:

        def commuters(t):
            typical = np.minimum(
                geo['population'][:],
                predef['commuters_by_node'],
            )
            actual = np.binomial(typical, params['move_control'])
            return np.multinomial(actual, predef['commuting_probability'])

    Will be rewritten as:

        def commuters(ctx, predef, t):
            typical = np.minimum(
                ctx.geo['population'][:],
                predef['commuters_by_node'],
            )
            actual = np.binomial(typical, ctx.params['move_control'])
            return np.multinomial(actual, predef['commuting_probability'])

    Only the outermost function's signature is rewritten. Helper functions nested
    inside the clause keep their own parameters and see `ctx` and `predef` through
    the enclosing function's scope.
    """

    # Label used in error messages; set from the outermost function's name.
    clause_name: str = "<unknown clause>"
    # True while the outermost function definition is being visited.
    _in_function: bool = False

    def _report_line(self, node: ast.AST):
        return f"{self.clause_name} line: {node.lineno}"

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Prepend `ctx` and `predef` to the outermost function's parameters."""
        if self._in_function:
            # Nested helper: transform its body, but leave its signature (and the
            # clause name used in error messages) untouched.
            return self.generic_visit(node)
        self.clause_name = f"`{node.name}`"
        self._in_function = True
        try:
            new_node = self.generic_visit(node)
        finally:
            self._in_function = False
        if isinstance(new_node, ast.FunctionDef):
            ctx_arg = ast.arg(
                arg='ctx',
                annotation=ast.Name(id='MovementContext', ctx=ast.Load()),
            )
            predef_arg = ast.arg(
                arg='predef',
                annotation=ast.Name(id='PredefData', ctx=ast.Load()),
            )
            new_node.args.args = [ctx_arg, predef_arg, *new_node.args.args]
        return new_node