|
27 | 27 | from .callbacks import Console
|
28 | 28 | from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError
|
29 | 29 |
|
| 30 | +# NOTE: for distributed training we might also need |
| 31 | +# from apex.parallel import DistributedDataParallel as DDP, |
| 32 | +# but it is not yet clear where to integrate it. |
| 33 | +try: |
| 34 | + from apex import amp |
| 35 | +except ImportError: |
| 36 | + amp = None |
| 37 | + |
30 | 38 |
|
31 | 39 | class Trainer(object):
|
32 | 40 | """A basic trainer.
|
@@ -126,10 +134,44 @@ def __init__(self, model=None):
|
126 | 134 | # Print console
|
127 | 135 | self._console = Console()
|
128 | 136 |
|
| 137 | + # Train with mixed precision; this only works |
| 138 | + # if apex is available |
| 139 | + self._mixed_precision = False |
| 140 | + self._apex_opt_level = 'O1' |
| 141 | + |
129 | 142 | # Public
|
130 | 143 | if model is not None:
|
131 | 144 | self.model = model
|
132 | 145 |
|
| 146 | + @property |
| 147 | + def mixed_precision(self): |
| 148 | + return self._mixed_precision |
| 149 | + |
| 150 | + # This setter must be called after the model and optimizer are set |
| 151 | + @mixed_precision.setter |
| 152 | + def mixed_precision(self, mp): |
| 153 | + if mp: |
| 154 | + assert_(amp is not None, "Cannot use mixed precision training without the apex library", RuntimeError) |
| 155 | + assert_(self._model is not None and self._optimizer is not None, |
| 156 | + "Model and optimizer need to be set before activating mixed precision", RuntimeError) |
| 157 | + # Keep sigmoid in fp32 to support BCE loss |
| 158 | + amp.register_float_function(torch, 'sigmoid') |
| 159 | + # For now, we don't allow setting 'keep_batchnorm_fp32' and 'loss_scale' |
| 160 | + self.model, self._optimizer = amp.initialize(self.model, self._optimizer, |
| 161 | + opt_level=self._apex_opt_level, |
| 162 | + keep_batchnorm_fp32=None) |
| 163 | + self._mixed_precision = mp |
| 164 | + |
| 165 | + @property |
| 166 | + def apex_opt_level(self): |
| 167 | + return self._apex_opt_level |
| 168 | + |
| 169 | + @apex_opt_level.setter |
| 170 | + def apex_opt_level(self, opt_level): |
| 171 | + assert_(opt_level in ('O0', 'O1', 'O2', 'O3'), |
| 172 | + "Invalid optimization level", ValueError) |
| 173 | + self._apex_opt_level = opt_level |
| 174 | + |
133 | 175 | @property
|
134 | 176 | def console(self):
|
135 | 177 | """Get the current console."""
|
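For reference, a minimal usage sketch of the new mixed-precision switches (hypothetical setup; the build_criterion/build_optimizer calls and their arguments are assumed from the surrounding inferno-style Trainer API):

    # Enable apex mixed precision only after the model and optimizer exist,
    # because the mixed_precision setter calls amp.initialize on both.
    trainer = Trainer(model)
    trainer.build_criterion('CrossEntropyLoss')   # assumed inferno helper
    trainer.build_optimizer('Adam', lr=1e-4)      # assumed inferno helper
    trainer.apex_opt_level = 'O1'                 # one of 'O0', 'O1', 'O2', 'O3'
    trainer.mixed_precision = True                # runs amp.initialize internally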
@@ -1368,17 +1410,21 @@ def apply_model_and_loss(self, inputs, target, backward=True, mode=None):
|
1368 | 1410 | kwargs['trainer'] = self
|
1369 | 1411 | if mode == 'train':
|
1370 | 1412 | loss = self.criterion(prediction, target, **kwargs) \
|
1371 | | - if len(target) != 0 else self.criterion(prediction, **kwargs) |
| 1413 | + if len(target) != 0 else self.criterion(prediction, **kwargs) |
1372 | 1414 | elif mode == 'eval':
|
1373 | 1415 | loss = self.validation_criterion(prediction, target, **kwargs) \
|
1374 | | - if len(target) != 0 else self.validation_criterion(prediction, **kwargs) |
| 1416 | + if len(target) != 0 else self.validation_criterion(prediction, **kwargs) |
1375 | 1417 | else:
|
1376 | 1418 | raise ValueError
|
1377 | 1419 | if backward:
|
1378 | 1420 | # Backprop if required
|
1379 | 1421 | # retain_graph option is needed for some custom
|
1380 | 1422 | # loss functions like malis, False per default
|
1381 | | - loss.backward(retain_graph=self.retain_graph) |
| 1423 | + if self.mixed_precision: |
| 1424 | + with amp.scale_loss(loss, self.optimizer) as scaled_loss: |
| 1425 | + scaled_loss.backward(retain_graph=self.retain_graph) |
| 1426 | + else: |
| 1427 | + loss.backward(retain_graph=self.retain_graph) |
1382 | 1428 | return prediction, loss
|
1383 | 1429 |
|
1384 | 1430 | def train_for(self, num_iterations=None, break_callback=None):
|
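The backward branch above follows the standard apex amp recipe: initialize the model and optimizer once, then wrap every backward pass in amp.scale_loss so gradients are computed on a loss-scaled value. A standalone sketch of that recipe outside the trainer (requires the apex package; model, data, and hyperparameters are illustrative):

    import torch
    from apex import amp

    model = torch.nn.Linear(16, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Patch model and optimizer for mixed precision once, up front
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    x = torch.randn(8, 16, device='cuda')
    target = torch.randint(0, 2, (8,), device='cuda')
    loss = torch.nn.functional.cross_entropy(model(x), target)

    optimizer.zero_grad()
    # Scale the loss before backward; amp unscales the gradients afterwards
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()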
@@ -1676,7 +1722,7 @@ def load(self, from_directory=None, best=False, filename=None, map_location=None
|
1676 | 1722 | 'best_checkpoint.pytorch'.
|
1677 | 1723 | filename : str
|
1678 | 1724 | Overrides the default filename.
|
1679 | | - device : function, torch.device, string or a dict |
| 1725 | + map_location : function, torch.device, string or a dict |
1680 | 1726 | Specify how to remap storage locations.
|
1681 | 1727 |
|
1682 | 1728 | Returns
|
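The renamed map_location argument mirrors the parameter of the same name in torch.load, which accepts a torch.device, a string, a dict remapping storage locations, or a function. For example (illustrative checkpoint path):

    import torch

    # Load a GPU-trained checkpoint onto the CPU
    checkpoint = torch.load('checkpoint.pytorch', map_location='cpu')

    # Or remap storages from cuda:1 to cuda:0
    checkpoint = torch.load('checkpoint.pytorch',
                            map_location={'cuda:1': 'cuda:0'})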