From 926cd2e2388d6d517adcc71ac7e49d36e3f0eb53 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 17 Aug 2022 15:15:42 -0500 Subject: [PATCH] Fix misc. issues surrounding ProfileStats --- aesara/compile/function/pfunc.py | 14 +-- aesara/compile/function/types.py | 10 ++- aesara/compile/profiling.py | 147 +++++-------------------------- aesara/scan/utils.py | 12 +-- tests/compile/test_profiling.py | 27 ++++-- tests/d3viz/test_d3viz.py | 4 +- tests/test_printing.py | 3 +- 7 files changed, 62 insertions(+), 155 deletions(-) diff --git a/aesara/compile/function/pfunc.py b/aesara/compile/function/pfunc.py index 94be2db2c6..3bf6490c00 100644 --- a/aesara/compile/function/pfunc.py +++ b/aesara/compile/function/pfunc.py @@ -5,7 +5,7 @@ import logging from copy import copy -from typing import Optional +from typing import Optional, Union from aesara.compile.function.types import Function, UnusedInputError, orig_function from aesara.compile.io import In, Out @@ -282,7 +282,7 @@ def pfunc( name=None, rebuild_strict=True, allow_input_downcast=None, - profile=None, + profile: Optional[Union[bool, str, ProfileStats]] = None, on_unused_input=None, output_keys=None, fgraph: Optional[FunctionGraph] = None, @@ -322,13 +322,13 @@ def pfunc( general, or precise, type. None (default) is almost like False, but allows downcasting of Python float scalars to floatX. - profile : None, True, str, or ProfileStats instance - Accumulate profiling information into a given ProfileStats instance. + profile + Accumulate profiling information into a given `ProfileStats` instance. None is the default, and means to use the value of config.profile. - If argument is `True` then a new ProfileStats instance will be used. - If argument is a string, a new ProfileStats instance will be created + If argument is ``True`` then a new `ProfileStats` instance will be used. + If argument is a string, a new `ProfileStats` instance will be created with that string as its `message` attribute. This profiling object will - be available via self.profile. + be available via `Function.profile`. on_unused_input : {'raise', 'warn','ignore', None} What to do if a variable in the 'inputs' list is not used in the graph. fgraph diff --git a/aesara/compile/function/types.py b/aesara/compile/function/types.py index 2444145e29..f9e4077cb2 100644 --- a/aesara/compile/function/types.py +++ b/aesara/compile/function/types.py @@ -14,6 +14,7 @@ import aesara.compile.profiling from aesara.compile.io import In, SymbolicInput, SymbolicOutput from aesara.compile.ops import deep_copy_op, view_op +from aesara.compile.profiling import ProfileStats from aesara.configdefaults import config from aesara.graph.basic import ( Constant, @@ -731,10 +732,10 @@ def checkSV(sv_ori, sv_rpl): message = name else: message = str(profile.message) + " copy" - profile = aesara.compile.profiling.ProfileStats(message=message) + profile = ProfileStats(message=message) # profile -> object elif isinstance(profile, str): - profile = aesara.compile.profiling.ProfileStats(message=profile) + profile = ProfileStats(message=profile) f_cpy = maker.__class__( inputs=ins, @@ -1688,7 +1689,7 @@ def orig_function( mode=None, accept_inplace=False, name=None, - profile=None, + profile: Optional[ProfileStats] = None, on_unused_input=None, output_keys=None, fgraph: Optional[FunctionGraph] = None, @@ -1712,7 +1713,8 @@ def orig_function( accept_inplace : bool True iff the graph can contain inplace operations prior to the rewrite phase (default is False). - profile : None or ProfileStats instance + profile : + `ProfileStats` instance. on_unused_input : {'raise', 'warn', 'ignore', None} What to do if a variable in the 'inputs' list is not used in the graph. output_keys diff --git a/aesara/compile/profiling.py b/aesara/compile/profiling.py index 15c57fdf74..44c70ea6a0 100644 --- a/aesara/compile/profiling.py +++ b/aesara/compile/profiling.py @@ -8,15 +8,12 @@ # TODO: what to do about 'diff summary'? (ask Fred?) # -import atexit -import copy -import logging import operator import sys import time from collections import defaultdict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import numpy as np @@ -41,107 +38,11 @@ def extended_open(filename, mode="r"): yield f -logger = logging.getLogger("aesara.compile.profiling") - aesara_imported_time: float = time.time() total_fct_exec_time: float = 0.0 total_graph_rewrite_time: float = 0.0 total_time_linker: float = 0.0 -_atexit_print_list: List["ProfileStats"] = [] -_atexit_registered: bool = False - - -def _atexit_print_fn(): - """Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`.""" - if config.profile: - to_sum = [] - - if config.profiling__destination == "stderr": - destination_file = "" - elif config.profiling__destination == "stdout": - destination_file = "" - else: - destination_file = config.profiling__destination - - with extended_open(destination_file, mode="w"): - - # Reverse sort in the order of compile+exec time - for ps in sorted( - _atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time - )[::-1]: - if ( - ps.fct_callcount >= 1 - or ps.compile_time > 1 - or getattr(ps, "callcount", 0) > 1 - ): - ps.summary( - file=destination_file, - n_ops_to_print=config.profiling__n_ops, - n_apply_to_print=config.profiling__n_apply, - ) - - if ps.show_sum: - to_sum.append(ps) - else: - # TODO print the name if there is one! - print("Skipping empty Profile") - if len(to_sum) > 1: - # Make a global profile - cum = copy.copy(to_sum[0]) - msg = f"Sum of all({len(to_sum)}) printed profiles at exit." - cum.message = msg - for ps in to_sum[1:]: - for attr in [ - "compile_time", - "fct_call_time", - "fct_callcount", - "vm_call_time", - "rewriter_time", - "linker_time", - "validate_time", - "import_time", - "linker_node_make_thunks", - ]: - setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr)) - - # merge dictionary - for attr in [ - "apply_time", - "apply_callcount", - "apply_cimpl", - "variable_shape", - "variable_strides", - "variable_offset", - "linker_make_thunk_time", - ]: - cum_attr = getattr(cum, attr) - for key, val in getattr(ps, attr.items()): - assert key not in cum_attr, (key, cum_attr) - cum_attr[key] = val - - if cum.rewriter_profile and ps.rewriter_profile: - try: - merge = cum.rewriter_profile[0].merge_profile( - cum.rewriter_profile[1], ps.rewriter_profile[1] - ) - assert len(merge) == len(cum.rewriter_profile[1]) - cum.rewriter_profile = (cum.rewriter_profile[0], merge) - except Exception as e: - print(e) - cum.rewriter_profile = None - else: - cum.rewriter_profile = None - - cum.summary( - file=destination_file, - n_ops_to_print=config.profiling__n_ops, - n_apply_to_print=config.profiling__n_apply, - ) - - if config.print_global_stats: - print_global_stats() - def print_global_stats(): """ @@ -190,26 +91,12 @@ class ProfileStats: Parameters ---------- - atexit_print : bool - True means that this object will be printed to stderr (using .summary()) - at the end of the program. **kwargs : misc initializers These should (but need not) match the names of the class vars declared in this class. """ - def reset(self): - """Ignore previous function call""" - # self.compile_time = 0. - self.fct_call_time = 0.0 - self.fct_callcount = 0 - self.vm_call_time = 0.0 - self.apply_time = {} - self.apply_callcount = {} - # self.apply_cimpl = None - # self.message = None - # # Note on implementation: # Class variables are used here so that each one can be @@ -277,7 +164,7 @@ def reset(self): linker_make_thunk_time: Dict = {} - line_width = config.profiling__output_line_width + line_width: int = config.profiling__output_line_width nb_nodes: int = -1 # The number of nodes in the graph. We need the information separately in @@ -289,7 +176,7 @@ def reset(self): # param is called flag_time_thunks because most other attributes with time # in the name are times *of* something, rather than configuration flags. - def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs): + def __init__(self, flag_time_thunks=None, message=None): self.apply_callcount = {} self.output_size = {} # Keys are `(FunctionGraph, Variable)` @@ -298,20 +185,25 @@ def __init__(self, atexit_print=True, flag_time_thunks=None, **kwargs): self.variable_shape = {} self.variable_strides = {} self.variable_offset = {} + self.message = message if flag_time_thunks is None: self.flag_time_thunks = config.profiling__time_thunks else: self.flag_time_thunks = flag_time_thunks - self.__dict__.update(kwargs) - if atexit_print: - global _atexit_print_list - _atexit_print_list.append(self) - global _atexit_registered - if not _atexit_registered: - atexit.register(_atexit_print_fn) - _atexit_registered = True + self.ignore_first_call = config.profiling__ignore_first_call + def reset(self): + """Ignore previous function call""" + # self.compile_time = 0. + self.fct_call_time = 0.0 + self.fct_callcount = 0 + self.vm_call_time = 0.0 + self.apply_time = {} + self.apply_callcount = {} + self.apply_cimpl = None + self.message = None + def class_time(self): """ dict op -> total time on thunks @@ -360,7 +252,7 @@ def class_impl(self): rval = {} for (fgraph, node) in self.apply_callcount: typ = type(node.op) - if self.apply_cimpl[node]: + if self.apply_cimpl and self.apply_cimpl[node]: impl = "C " else: impl = "Py" @@ -438,7 +330,7 @@ def op_impl(self): # timing is stored by node, we compute timing by Op on demand rval = {} for (fgraph, node) in self.apply_callcount: - if self.apply_cimpl[node]: + if self.apply_cimpl and self.apply_cimpl[node]: rval[node.op] = "C " else: rval[node.op] = "Py" @@ -785,7 +677,8 @@ def summary_nodes(self, file=sys.stderr, N=None): def summary_function(self, file): print("Function profiling", file=file) print("==================", file=file) - print(f" Message: {self.message}", file=file) + if self.message: + print(f" Message: {self.message}", file=file) print( f" Time in {self.fct_callcount} calls to Function.__call__: {self.fct_call_time:e}s", file=file, diff --git a/aesara/scan/utils.py b/aesara/scan/utils.py index 8ec792fad0..879d23cea7 100644 --- a/aesara/scan/utils.py +++ b/aesara/scan/utils.py @@ -136,13 +136,13 @@ def __init__(self, condition): class ScanProfileStats(ProfileStats): - show_sum = False - callcount = 0 - nbsteps = 0 - call_time = 0.0 + show_sum: bool = False + callcount: int = 0 + nbsteps: int = 0 + call_time: float = 0.0 - def __init__(self, atexit_print=True, name=None, **kwargs): - super().__init__(atexit_print, **kwargs) + def __init__(self, name: Optional[str] = None, **kwargs): + super().__init__(**kwargs) self.name = name def summary_globals(self, file): diff --git a/tests/compile/test_profiling.py b/tests/compile/test_profiling.py index 240a26fd7a..80093f7f64 100644 --- a/tests/compile/test_profiling.py +++ b/tests/compile/test_profiling.py @@ -1,9 +1,7 @@ -# Test of memory profiling - - from io import StringIO import numpy as np +import pytest import aesara.tensor as at from aesara.compile import ProfileStats @@ -13,8 +11,13 @@ from aesara.tensor.type import fvector, scalars +pytestmark = pytest.mark.filterwarnings("error") + + class TestProfiling: - # Test of Aesara profiling with min_peak_memory=True + """ + Test Aesara profiling with ``min_peak_memory=True``. + """ def test_profiling(self): @@ -32,14 +35,17 @@ def test_profiling(self): z += [at.outer(x[i], x[i + 1]).sum(axis=1) for i in range(len(x) - 1)] z += [x[i] + x[i + 1] for i in range(len(x) - 1)] - p = ProfileStats(False, gpu_checks=False) + p = ProfileStats() if config.mode in ("DebugMode", "DEBUG_MODE", "FAST_COMPILE"): m = "FAST_RUN" else: m = None - f = function(x, z, profile=p, name="test_profiling", mode=m) + with pytest.warns( + UserWarning, match=".*CVM does not support memory profiling.*" + ): + f = function(x, z, profile=p, name="test_profiling", mode=m) inp = [np.arange(1024, dtype="float32") + 1 for i in range(len(x))] f(*inp) @@ -87,14 +93,19 @@ def test_ifelse(self): z = ifelse(at.lt(a, b), x * 2, y * 2) - p = ProfileStats(False, gpu_checks=False) + p = ProfileStats() if config.mode in ("DebugMode", "DEBUG_MODE", "FAST_COMPILE"): m = "FAST_RUN" else: m = None - f_ifelse = function([a, b, x, y], z, profile=p, name="test_ifelse", mode=m) + with pytest.warns( + UserWarning, match=".*CVM does not support memory profiling.*" + ): + f_ifelse = function( + [a, b, x, y], z, profile=p, name="test_ifelse", mode=m + ) val1 = 0.0 val2 = 1.0 diff --git a/tests/d3viz/test_d3viz.py b/tests/d3viz/test_d3viz.py index 41425d72f6..efe1ff899d 100644 --- a/tests/d3viz/test_d3viz.py +++ b/tests/d3viz/test_d3viz.py @@ -6,8 +6,8 @@ import pytest import aesara.d3viz as d3v -from aesara import compile from aesara.compile.function import function +from aesara.compile.profiling import ProfileStats from aesara.configdefaults import config from aesara.d3viz.formatting import pydot_imported, pydot_imported_msg from tests.d3viz import models @@ -41,7 +41,7 @@ def test_mlp_profiled(self): if config.mode in ("DebugMode", "DEBUG_MODE"): pytest.skip("Can't profile in DebugMode") m = models.Mlp() - profile = compile.profiling.ProfileStats(False) + profile = ProfileStats() f = function(m.inputs, m.outputs, profile=profile) x_val = self.rng.normal(0, 1, (1000, m.nfeatures)) f(x_val) diff --git a/tests/test_printing.py b/tests/test_printing.py index ac64024152..e29400fc23 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -9,6 +9,7 @@ import aesara from aesara.compile.mode import get_mode from aesara.compile.ops import deep_copy_op +from aesara.compile.profiling import ProfileStats from aesara.printing import ( PatternPrinter, PPrinter, @@ -81,7 +82,7 @@ def test_pydotprint_long_name(): ) def test_pydotprint_profile(): A = matrix() - prof = aesara.compile.ProfileStats(atexit_print=False, gpu_checks=False) + prof = ProfileStats() f = aesara.function([A], A + 1, profile=prof) pydotprint(f, print_output_file=False) f([[1]])