From ba8dc1592eb9155640fcec35e97be566fb3c4af6 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 4 May 2022 22:26:31 +0000
Subject: [PATCH 1/3] Pass backend-related ctx to TorchDynamo Optimize Context

---
 torchdynamo/eval_frame.py             | 35 ++++++++++++++++++++-------
 torchdynamo/optimizations/training.py | 12 ++++++++-
 2 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py
index fca066eb18..877ba92247 100644
--- a/torchdynamo/eval_frame.py
+++ b/torchdynamo/eval_frame.py
@@ -1,3 +1,4 @@
+import contextlib
 import functools
 import logging
 import threading
@@ -24,26 +25,31 @@ def nothing():
     pass
 
 
+null_context = contextlib.nullcontext()
+
 unset = object()
 
 compile_lock = threading.Lock()
 
 
 class _TorchDynamoContext:
-    def __init__(self, callback, on_enter=nothing):
+    def __init__(self, callback, on_enter=nothing, extra_ctx=null_context):
         super().__init__()
         assert callable(callback) or callback is False or callback is None
         self.callback = callback
         self.prior = unset
         self.on_enter = on_enter
+        self.extra_ctx = extra_ctx
 
     def __enter__(self):
         self.on_enter()
         self.prior = set_eval_frame(self.callback)
+        self.extra_ctx.__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         set_eval_frame(self.prior)
         self.prior = unset
+        self.extra_ctx.__exit__(exc_type, exc_val, exc_tb)
 
     def __call__(self, fn):
         assert callable(fn)
@@ -69,8 +75,12 @@ def _fn(*args, **kwargs):
 
 
 class OptimizeContext(_TorchDynamoContext):
-    def __init__(self, callback):
-        super().__init__(callback=callback, on_enter=install_generation_tagging_new)
+    def __init__(self, callback, extra_ctx):
+        super().__init__(
+            callback=callback,
+            on_enter=install_generation_tagging_new,
+            extra_ctx=extra_ctx,
+        )
 
 
 class RunOnlyContext(_TorchDynamoContext):
@@ -107,8 +117,8 @@ def catch_errors(frame, cache_size):
     return catch_errors
 
 
-def _optimize_catch_errors(compile_fn):
-    return OptimizeContext(catch_errors_wrapper(compile_fn))
+def _optimize_catch_errors(compile_fn, extra_ctx=null_context):
+    return OptimizeContext(catch_errors_wrapper(compile_fn), extra_ctx=extra_ctx)
 
 
 def optimize(backend, nopython=False):
@@ -136,16 +146,23 @@ def toy_example(a, b):
         with torchdynamo.optimize(my_compiler):
             ...
     """
+
+    extra_ctx = null_context
+    if hasattr(backend, "extra_ctx"):
+        extra_ctx = getattr(backend, "extra_ctx")
+
     if nopython:
-        return optimize_assert(backend)
-    return _optimize_catch_errors(convert_frame.convert_frame(backend))
+        return optimize_assert(backend, extra_ctx)
+    return _optimize_catch_errors(convert_frame.convert_frame(backend), extra_ctx)
 
 
-def optimize_assert(backend):
+def optimize_assert(backend, extra_ctx=null_context):
     """
     The same as `torchdynamo.optimize(backend, nopython=True)`
     """
-    return _optimize_catch_errors(convert_frame.convert_frame_assert(backend))
+    return _optimize_catch_errors(
+        convert_frame.convert_frame_assert(backend), extra_ctx
+    )
 
 
 def run(fn=None):
diff --git a/torchdynamo/optimizations/training.py b/torchdynamo/optimizations/training.py
index f03ea787dd..a3a5337bf0 100644
--- a/torchdynamo/optimizations/training.py
+++ b/torchdynamo/optimizations/training.py
@@ -143,4 +143,14 @@ def candidate(self):
         return BACKENDS["aot_autograd"](self.gm, self.example_inputs)
 
 
-aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn
+class AOTAutogradMemoryEfficientFusionWithContext:
+    """Pass nvfuser context to TorchDynamo"""
+
+    def __init__(self):
+        self.extra_ctx = torch.jit.fuser("fuser2")
+
+    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
+        return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs)
+
+
+aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()
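
Why the follow-up patch is needed: PATCH 1/3 stores a single context-manager
instance (AOTAutogradMemoryEfficientFusionWithContext keeps the object
returned by torch.jit.fuser("fuser2")) and re-enters it on every
_TorchDynamoContext.__enter__. Generator-based context managers are
single-use, so a second entry fails. A minimal sketch of that failure mode,
using a made-up stand-in context manager rather than the real fuser:

    import contextlib

    @contextlib.contextmanager
    def fake_fuser_ctx():
        # Stand-in for the object returned by torch.jit.fuser("fuser2").
        print("enter fuser")
        yield
        print("exit fuser")

    ctx = fake_fuser_ctx()
    with ctx:
        pass  # first entry works
    with ctx:
        pass  # RuntimeError: generator didn't yield

PATCH 2/3 below therefore passes a context *constructor* instead of a
context instance, and builds a fresh context on every entry.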
""" + + extra_ctx = null_context + if hasattr(backend, "extra_ctx"): + extra_ctx = getattr(backend, "extra_ctx") + if nopython: - return optimize_assert(backend) - return _optimize_catch_errors(convert_frame.convert_frame(backend)) + return optimize_assert(backend, extra_ctx) + return _optimize_catch_errors(convert_frame.convert_frame(backend), extra_ctx) -def optimize_assert(backend): +def optimize_assert(backend, extra_ctx=null_context): """ The same as `torchdynamo.optimize(backend, nopython=True)` """ - return _optimize_catch_errors(convert_frame.convert_frame_assert(backend)) + return _optimize_catch_errors( + convert_frame.convert_frame_assert(backend), extra_ctx + ) def run(fn=None): diff --git a/torchdynamo/optimizations/training.py b/torchdynamo/optimizations/training.py index f03ea787dd..a3a5337bf0 100644 --- a/torchdynamo/optimizations/training.py +++ b/torchdynamo/optimizations/training.py @@ -143,4 +143,14 @@ def candidate(self): return BACKENDS["aot_autograd"](self.gm, self.example_inputs) -aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn +class AOTAutogradMemoryEfficientFusionWithContext: + """Pass nvfuser context to TorchDynamo""" + + def __init__(self): + self.extra_ctx = torch.jit.fuser("fuser2") + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs) + + +aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext() From 42f68e331a903a75624bba9b3069d33c1c684f14 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 4 May 2022 23:51:29 +0000 Subject: [PATCH 2/3] Reinit the backend ctx for every frame --- torchdynamo/eval_frame.py | 37 +++++++++++++++------------ torchdynamo/optimizations/training.py | 2 +- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py index 877ba92247..530c54db56 100644 --- a/torchdynamo/eval_frame.py +++ b/torchdynamo/eval_frame.py @@ -25,7 +25,7 @@ def nothing(): pass -null_context = contextlib.nullcontext() +null_context = contextlib.nullcontext unset = object() @@ -33,23 +33,24 @@ def nothing(): class _TorchDynamoContext: - def __init__(self, callback, on_enter=nothing, extra_ctx=null_context): + def __init__(self, callback, on_enter=nothing, backend_ctx_ctor=null_context): super().__init__() assert callable(callback) or callback is False or callback is None self.callback = callback self.prior = unset self.on_enter = on_enter - self.extra_ctx = extra_ctx + self.extra_ctx_ctor = backend_ctx_ctor def __enter__(self): self.on_enter() self.prior = set_eval_frame(self.callback) - self.extra_ctx.__enter__() + self.backend_ctx = self.extra_ctx_ctor() + self.backend_ctx.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): set_eval_frame(self.prior) self.prior = unset - self.extra_ctx.__exit__(exc_type, exc_val, exc_tb) + self.backend_ctx.__exit__(exc_type, exc_val, exc_tb) def __call__(self, fn): assert callable(fn) @@ -75,11 +76,11 @@ def _fn(*args, **kwargs): class OptimizeContext(_TorchDynamoContext): - def __init__(self, callback, extra_ctx): + def __init__(self, callback, backend_ctx_ctor): super().__init__( callback=callback, on_enter=install_generation_tagging_new, - extra_ctx=extra_ctx, + backend_ctx_ctor=backend_ctx_ctor, ) @@ -117,8 +118,10 @@ def catch_errors(frame, cache_size): return catch_errors -def _optimize_catch_errors(compile_fn, extra_ctx=null_context): - return OptimizeContext(catch_errors_wrapper(compile_fn), extra_ctx=extra_ctx) +def 
From 637124d260360995d57e7afce8d66f9440ae1f34 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Thu, 5 May 2022 00:00:39 +0000
Subject: [PATCH 3/3] Doc

---
 torchdynamo/eval_frame.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py
index 530c54db56..809002cb07 100644
--- a/torchdynamo/eval_frame.py
+++ b/torchdynamo/eval_frame.py
@@ -130,10 +130,13 @@ def optimize(backend, nopython=False):
     backend() to optimize extracted graphs.
 
     Args:
-        backend: One of two things:
-            - Either, a function taking a torch.fx.GraphModule and
+        backend: One of the two things:
+            - Either, a function/callable taking a torch.fx.GraphModule and
             example_inputs and returning a python callable that runs the
             graph faster.
+            One can also provide additional context for the backend, like
+            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
+            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torchdynamo.list_backends()`
         nopython: If True, graph breaks will be errors and there will be a
             single whole-program graph.
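
Taken together, the series lets any backend opt into extra per-frame context
by exposing a backend_ctx_ctor attribute. A hypothetical end-to-end usage
sketch; NvFuserBackend and toy_example are made-up names for illustration,
and the only contract assumed is the backend_ctx_ctor attribute that
optimize() checks above:

    import torch
    import torchdynamo

    class NvFuserBackend:
        def __init__(self):
            # A constructor, not an instance: TorchDynamo calls it to build
            # a fresh torch.jit.fuser("fuser2") context for every frame.
            self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")

        def __call__(self, gm: torch.fx.GraphModule, example_inputs):
            # Return a callable that runs the graph; here, the graph as-is.
            return gm.forward

    def toy_example(a, b):
        return torch.relu(a) + b

    with torchdynamo.optimize(NvFuserBackend()):
        toy_example(torch.randn(10), torch.randn(10))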