
Commit fa53dcb

Merge pull request #164 from inferno-pytorch/gradient_callback
add gradient logging callback
2 parents b0ed9bd + ef2de4e commit fa53dcb

2 files changed: +51 −1 lines changed

inferno/trainers/callbacks/__init__.py

+2 −1

@@ -1,9 +1,10 @@
-__all__ = ['CallbackEngine','Callback', 'Console','essentials','scheduling']
+__all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients']
 
 from .base import CallbackEngine, Callback
 from .console import Console
 from . import essentials
 from . import scheduling
+from . import gradients
 
 try:
     from .tqdm import TQDMProgressBar
inferno/trainers/callbacks/gradients.py

+49 −0
@@ -0,0 +1,49 @@
+from ...utils.train_utils import Frequency
+from ...utils.exceptions import assert_, FrequencyValueError
+from .base import Callback
+
+
+class LogOutputGradients(Callback):
+    """Logs the gradient of the network output"""
+
+    def __init__(self, frequency):
+        super(LogOutputGradients, self).__init__()
+        self.log_every = frequency
+        self.registered = False
+        self.hook_handle = None
+
+    @property
+    def log_every(self):
+        return self._log_every
+
+    @log_every.setter
+    def log_every(self, value):
+        self._log_every = Frequency(value, 'iterations')
+        assert_(self.log_every.is_consistent,
+                "Log frequency is not consistent.",
+                FrequencyValueError)
+
+    def add_hook(self):
+        def hook(module, grad_input, grad_output):
+            if self.log_every.match(iteration_count=self.trainer.iteration_count,
+                                    epoch_count=self.trainer.epoch_count,
+                                    persistent=True, match_zero=True):
+                self.trainer.update_state('output_gradient', grad_output[0].detach().cpu())
+
+        self.hook_handle = self.trainer.model.register_backward_hook(hook)
+
+    def begin_of_fit(self, **kwargs):
+        self._trainer.logger.observe_state("output_gradient",
+                                           observe_while='training')
+        self.add_hook()
+
+    def begin_of_save(self, **_):
+        # remove hook from model, because you can't pickle it.
+        if self.hook_handle is not None:
+            self.hook_handle.remove()
+            self.hook_handle = None
+
+
+    def end_of_save(self, **_):
+        # add hook after model save
+        self.add_hook()
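For context, a minimal usage sketch of the new callback. Only LogOutputGradients is added by this commit; the Trainer construction, build_logger call, and the chosen frequency below are illustrative assumptions about the usual inferno Trainer wiring, not part of the change:

# Illustrative usage (not part of this commit). Assumes inferno's Trainer exposes
# register_callback and build_logger; adjust names to your setup.
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.gradients import LogOutputGradients

trainer = Trainer(model)                    # `model`: any torch.nn.Module
trainer.build_logger(log_directory='logs')  # a logger must exist, since the callback
                                            # calls logger.observe_state('output_gradient')
trainer.register_callback(LogOutputGradients(frequency=10))  # log every 10 iterations
# ... bind loaders as usual, then trainer.fit()

Because the backward hook is registered on the top-level model, grad_output[0] in the hook is the gradient flowing into the network's output. The hook handle is removed in begin_of_save and re-registered in end_of_save since a live hook handle cannot be pickled along with the model.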
