Skip to content
This repository was archived by the owner on Apr 19, 2023. It is now read-only.

Commit e92dad3

Browse files
committed Feb 1, 2019
fix gradient callback (make hook a member to allow pickling)
1 parent fa53dcb commit e92dad3

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed
 

‎inferno/trainers/callbacks/gradients.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ def log_every(self, value):
2323
"Log frequency is not consistent.",
2424
FrequencyValueError)
2525

26+
def hook(self, module, grad_input, grad_output):
27+
if self.log_every.match(iteration_count=self.trainer.iteration_count,
28+
epoch_count=self.trainer.epoch_count,
29+
persistent=True, match_zero=True):
30+
self.trainer.update_state('output_gradient', grad_output[0].detach().float().clone().cpu())
31+
2632
def add_hook(self):
27-
def hook(module, grad_input, grad_output):
28-
if self.log_every.match(iteration_count=self.trainer.iteration_count,
29-
epoch_count=self.trainer.epoch_count,
30-
persistent=True, match_zero=True):
31-
self.trainer.update_state('output_gradient', grad_output[0].detach().cpu())
32-
33-
self.hook_handle = self.trainer.model.register_backward_hook(hook)
33+
self.hook_handle = self.trainer.model.register_backward_hook(self.hook)
3434

3535
def begin_of_fit(self, **kwargs):
3636
self._trainer.logger.observe_state("output_gradient",
@@ -43,7 +43,7 @@ def begin_of_save(self, **_):
4343
self.hook_handle.remove()
4444
self.hook_handle = None
4545

46-
4746
def end_of_save(self, **_):
4847
# add hook after model save
4948
self.add_hook()
49+

0 commit comments

Comments
 (0)