from ...utils.train_utils import Frequency
from ...utils.exceptions import assert_, FrequencyValueError
from .base import Callback


class LogOutputGradients(Callback):
    """Logs the gradient of the network output at a given frequency."""

    def __init__(self, frequency):
        super(LogOutputGradients, self).__init__()
        # Logging interval; normalized to a `Frequency` in iterations by the
        # setter below.
        self.log_every = frequency
        self.registered = False
        self.hook_handle = None

    @property
    def log_every(self):
        return self._log_every

    @log_every.setter
    def log_every(self, value):
        self._log_every = Frequency(value, 'iterations')
        assert_(self.log_every.is_consistent,
                "Log frequency is not consistent.",
                FrequencyValueError)
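
    # Example: `LogOutputGradients(100)` normalizes to
    # `Frequency(100, 'iterations')`; an inconsistent specification fails the
    # `assert_` above with a FrequencyValueError.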

    def add_hook(self):
        # Register a backward hook on the model: it fires on every backward
        # pass and, at the configured frequency, stashes the gradient of the
        # network output in the trainer state.
        def hook(module, grad_input, grad_output):
            if self.log_every.match(iteration_count=self.trainer.iteration_count,
                                    epoch_count=self.trainer.epoch_count,
                                    persistent=True, match_zero=True):
                self.trainer.update_state('output_gradient',
                                          grad_output[0].detach().cpu())

        self.hook_handle = self.trainer.model.register_backward_hook(hook)
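
    # For reference, a self-contained sketch of the same mechanism in plain
    # PyTorch, independent of the trainer machinery above (illustration only;
    # newer PyTorch versions prefer `register_full_backward_hook`):
    #
    #     import torch, torch.nn as nn
    #     model = nn.Linear(4, 2)
    #     grads = []
    #     handle = model.register_backward_hook(
    #         lambda module, grad_in, grad_out:
    #             grads.append(grad_out[0].detach().cpu()))
    #     model(torch.randn(3, 4)).sum().backward()  # hook fires here
    #     handle.remove()
    #     # grads[0] holds the gradient of the loss w.r.t. the model output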

    def begin_of_fit(self, **_):
        # Ask the logger to observe the state entry that the hook populates.
        self.trainer.logger.observe_state("output_gradient",
                                          observe_while='training')
        self.add_hook()

    def begin_of_save(self, **_):
        # Remove the hook before saving: hook handles cannot be pickled.
        if self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None

    def end_of_save(self, **_):
        # Re-register the hook once the model has been saved.
        self.add_hook()
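
# Minimal usage sketch, assuming an inferno-style `Trainer` (the `Trainer`
# import path and `register_callback` call are assumptions, not part of this
# file):
#
#     from inferno.trainers.basic import Trainer
#
#     trainer = Trainer(model)
#     trainer.register_callback(LogOutputGradients(frequency=100))
#     trainer.fit()
#
# Every 100 training iterations the backward hook stores the gradient of the
# network output in the trainer state under 'output_gradient', where the
# logger (set up in `begin_of_fit`) picks it up.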