Improvements to optimization debug printing

brandonwillard committed Feb 15, 2021
1 parent 13b08f0 commit 03487af
Showing 4 changed files with 37 additions and 26 deletions.
5 changes: 4 additions & 1 deletion aesara/graph/fg.py
@@ -485,7 +485,10 @@ def replace(self, var, new_var, reason=None, verbose=None, import_missing=False)
         if verbose is None:
             verbose = config.optimizer_verbose
         if verbose:
-            print(reason, var, new_var)
+            print(
+                f"{reason}:\t{var.owner or var} [{var.name or var.auto_name}] -> "
+                f"{new_var.owner or new_var} [{new_var.name or new_var.auto_name}]"
+            )
 
         new_var = var.type.filter_variable(new_var, allow_convert=True)
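For context, the replacement log line now has the shape `reason:\t<node-or-var> [<name-or-auto_name>] -> <node-or-var> [<name-or-auto_name>]`, so a verbose run shows which rewrite fired and on what. A minimal sketch of driving this directly; the exact graph text printed depends on the Aesara version, and the names below are illustrative:

```python
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph

x = at.vector("x")
y = at.exp(at.log(x))

# `clone=False` keeps the original variables so we can refer to `y` below.
fgraph = FunctionGraph([x], [y], clone=False)

# `verbose=True` triggers the new-style message on each replacement:
fgraph.replace(y, x, reason="manual_simplification", verbose=True)
# Prints something along the lines of:
#   manual_simplification:	Elemwise{exp,no_inplace}(...) [auto_name] -> x [x]
```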

43 changes: 28 additions & 15 deletions aesara/graph/opt.py
@@ -124,6 +124,11 @@ def __hash__(self):
         _optimizer_idx[0] += 1
         return self._optimizer_idx
 
+    def __str__(self):
+        if hasattr(self, "name"):
+            return f"{type(self).__name__}[{self.name}]"
+        return repr(self)
+
 
 class FromFunctionOptimizer(GlobalOptimizer):
     """A `GlobalOptimizer` constructed from a given function."""
@@ -1074,6 +1079,11 @@ def add_requirements(self, fgraph):
     def print_summary(self, stream=sys.stdout, level=0, depth=-1):
         print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream)
 
+    def __str__(self):
+        if hasattr(self, "name"):
+            return f"{type(self).__name__}[{self.name}]"
+        return repr(self)
+
 
 class LocalMetaOptimizer(LocalOptimizer):
     """
@@ -1253,6 +1263,14 @@ def decorator(f):
 class LocalOptGroup(LocalOptimizer):
     """Takes a list of `LocalOptimizer`s and applies them to a node.
 
+    This is where the "tracks" parameters are largely used.  If one of the
+    `LocalOptimizer`s in `LocalOptGroup.optimizers` is set to track a specific
+    `Op` instance, this optimizer will only apply that `LocalOptimizer` to
+    nodes whose `Op` exactly matches the tracked `Op` (the matching is
+    performed using a `dict` lookup).
+
+    TODO: Use type-based matching (e.g. like `singledispatch`).
+
     Parameters
     ----------
     optimizers :
@@ -1269,10 +1287,12 @@ class LocalOptGroup(LocalOptimizer):
"""

def __init__(self, *optimizers, **kwargs):
def __init__(self, *optimizers, apply_all_opts=False, profile=False, name=None):
self.name = name

if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])

self.opts = optimizers
assert isinstance(self.opts, tuple)

@@ -1281,10 +1301,10 @@ def __init__(self, *optimizers, **kwargs):
             getattr(opt, "retains_inputs", False) for opt in optimizers
         )
 
-        self.apply_all_opts = kwargs.pop("apply_all_opts", False)
-        self.profile = kwargs.pop("profile", False)
-        self.track_map = defaultdict(lambda: [])
-        assert len(kwargs) == 0
+        self.apply_all_opts = apply_all_opts
+        self.profile = profile
+        self.track_map = defaultdict(list)
+
         if self.profile:
             self.time_opts = {}
             self.process_count = {}
@@ -1304,12 +1324,8 @@ def __init__(self, *optimizers, **kwargs):
             for c in tracks:
                 self.track_map[c].append(o)
 
-    def __str__(self):
-        return getattr(
-            self,
-            "__name__",
-            f"LocalOptGroup({','.join([str(o) for o in self.opts])})",
-        )
+    def __repr__(self):
+        return f"LocalOptGroup([{', '.join([str(o) for o in self.opts])}])"
 
     def tracks(self):
         t = []
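Taken together: the constructor now takes explicit keyword-only arguments instead of popping them from `**kwargs`, tracked `Op`s are filed in the `dict`-backed `track_map` exactly as the docstring above describes, and `__repr__` (the member listing) is now separate from the inherited `__str__` (which prefers a user-supplied `name`). A sketch under those assumptions; the dummy optimizer is hypothetical, and `transform`'s exact signature varies across versions:

```python
import aesara.tensor as at
from aesara.graph.opt import LocalOptGroup, LocalOptimizer

class DummyLocalOpt(LocalOptimizer):
    """Hypothetical tracked local optimizer that never rewrites anything."""

    def __init__(self, name, op):
        self.name = name
        self.op = op

    def tracks(self):
        # Track a single `Op` instance; `LocalOptGroup` files us under it.
        return [self.op]

    def transform(self, *args, **kwargs):
        return False

opt_a = DummyLocalOpt("opt_a", at.exp)
opt_b = DummyLocalOpt("opt_b", at.log)

group = LocalOptGroup(opt_a, opt_b, apply_all_opts=True, name="toy_group")

print(group.track_map[at.exp][0])  # DummyLocalOpt[opt_a] -- plain dict lookup by Op
print(repr(group))  # LocalOptGroup([DummyLocalOpt[opt_a], DummyLocalOpt[opt_b]])
print(str(group))   # LocalOptGroup[toy_group]
```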
@@ -2189,9 +2205,6 @@ def print_profile(stream, prof, level=0):
                 level=level + 1,
             )
 
-    def __str__(self):
-        return getattr(self, "__name__", "<TopoOptimizer instance>")
-
 
 def out2in(*local_opts, **kwargs):
     """
11 changes: 4 additions & 7 deletions aesara/graph/optdb.py
@@ -460,22 +460,19 @@ def register(self, name, obj, *tags, **kwargs):
         self.__position__[name] = position
 
     def query(self, *tags, **kwtags):
-        # For the new `useless` optimizer
         opts = list(super().query(*tags, **kwtags))
         opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))
 
         ret = self.local_opt(
-            *opts, apply_all_opts=self.apply_all_opts, profile=self.profile
+            *opts,
+            apply_all_opts=self.apply_all_opts,
+            profile=self.profile,
         )
         return ret
 
 
 class TopoDB(DB):
-    """
-    Generate a `GlobalOptimizer` of type TopoOptimizer.
-    """
+    """Generate a `GlobalOptimizer` of type TopoOptimizer."""
 
     def __init__(
         self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
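This `query` assembles the position-sorted entries into a group via `self.local_opt` (by default `LocalOptGroup`), and now spells out the keyword arguments since the group's constructor no longer swallows `**kwargs`. A hedged usage sketch, reusing the hypothetical `DummyLocalOpt` and imports from the earlier sketch; the exact registration signature may differ by version:

```python
from aesara.graph.optdb import LocalGroupDB

db = LocalGroupDB()
db.register("opt_a", DummyLocalOpt("opt_a", at.exp), "fast_run")
db.register("opt_b", DummyLocalOpt("opt_b", at.log), "fast_run")

# Entries come back sorted by their registered position, wrapped in a
# LocalOptGroup with `apply_all_opts`/`profile` forwarded explicitly:
group = db.query("+fast_run")
```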
4 changes: 1 addition & 3 deletions aesara/graph/toolbox.py
@@ -571,7 +571,7 @@ def replace_all_validate(

         for r, new_r in replacements:
             try:
-                fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
+                fgraph.replace(r, new_r, reason=reason, verbose=verbose, **kwargs)
             except Exception as e:
                 msg = str(e)
                 s1 = "The type of the replacement must be the same"
@@ -626,8 +626,6 @@ def replace_all_validate(
             print(
                 "Scan removed", nb, nb2, getattr(reason, "name", reason), r, new_r
             )
-        if verbose:
-            print(reason, r, new_r)
         # The return is needed by replace_all_validate_remove
         return chk
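Net effect of the toolbox change: `replace_all_validate` now forwards its `verbose` flag into `fgraph.replace`, so each replacement is reported once, in the new format from `fg.py`, rather than by the old duplicate `print(reason, r, new_r)` removed above. A rough end-to-end sketch (the feature wiring below is my reading of the API; details may vary by version):

```python
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.toolbox import ReplaceValidate

x = at.vector("x")
y = at.exp(at.log(x))

fgraph = FunctionGraph([x], [y], clone=False)
fgraph.attach_feature(ReplaceValidate())

# Validates the swap and prints a single new-style message per replacement:
fgraph.replace_all_validate([(y, x)], reason="manual_simplification", verbose=True)
```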

