This repository was archived by the owner on Apr 19, 2023. It is now read-only.

Commit 8e055b0
Merge pull request #198 from inferno-pytorch/apex: Mixed precision training with apex
2 parents: a75888e + 4c732db

File tree: 13 files changed, +141 −25 lines

inferno/extensions/containers/graph.py (+3 −3)

@@ -130,7 +130,7 @@ def is_node_in_graph(self, name):
         -------
         bool
         """
-        return name in self.graph.node
+        return name in self.graph.nodes
 
     def is_source_node(self, name):
         """
@@ -187,7 +187,7 @@ def output_nodes(self):
         list
             A list of names (str) of the output nodes.
         """
-        return [name for name, node_attributes in self.graph.node.items()
+        return [name for name, node_attributes in self.graph.nodes.items()
                 if node_attributes.get('is_output_node', False)]
 
     @property
@@ -201,7 +201,7 @@ def input_nodes(self):
         list
             A list of names (str) of the input nodes.
         """
-        return [name for name, node_attributes in self.graph.node.items()
+        return [name for name, node_attributes in self.graph.nodes.items()
                 if node_attributes.get('is_input_node', False)]
 
     @property
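
The graph.py hunks track the networkx 2.x API: the Graph.node attribute view was deprecated and then removed (networkx 2.4), with Graph.nodes as the drop-in replacement. A minimal standalone sketch of the new-style access (plain networkx, not code from this repository):

    import networkx as nx

    g = nx.DiGraph()
    g.add_node('input_0', is_input_node=True)
    g.add_node('conv_0')
    g.add_node('output_0', is_output_node=True)

    # on networkx >= 2.4 only g.nodes exists; g.node raises AttributeError
    input_nodes = [name for name, attrs in g.nodes.items()
                   if attrs.get('is_input_node', False)]
    print(input_nodes)  # ['input_0']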
+2 −1

@@ -1,2 +1,3 @@
 from .adam import Adam
-from .annealed_adam import AnnealedAdam
+from .annealed_adam import AnnealedAdam
+from .ranger import Ranger, RangerQH, RangerVA
+8

@@ -0,0 +1,8 @@
+# easy support for additional ranger optimizers from
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+try:
+    from ranger import Ranger, RangerVA, RangerQH
+except ImportError:
+    Ranger = None
+    RangerVA = None
+    RangerQH = None
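
Because the ranger dependency is optional, the re-exported names are None whenever the package is missing, so callers should check before instantiating. A hedged usage sketch (the import path inferno.extensions.optimizers is assumed from the surrounding relative imports):

    import torch
    import torch.nn as nn
    from inferno.extensions.optimizers import Ranger

    model = nn.Linear(16, 2)
    if Ranger is not None:
        optimizer = Ranger(model.parameters(), lr=1e-3)
    else:
        # fall back to a built-in optimizer when the ranger package is not installed
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)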

inferno/io/transform/base.py (+12 −2)

@@ -58,7 +58,7 @@ def __call__(self, *tensors, **transform_function_kwargs):
             transformed = self.batch_function(tensors, **transform_function_kwargs)
             return pyu.from_iterable(transformed)
         elif hasattr(self, 'tensor_function'):
-            transformed = [self.tensor_function(tensor, **transform_function_kwargs)
+            transformed = [self._apply_tensor_function(tensor, **transform_function_kwargs)
                            if tensor_index in apply_to else tensor
                            for tensor_index, tensor in enumerate(tensors)]
             return pyu.from_iterable(transformed)
@@ -77,9 +77,17 @@ def __call__(self, *tensors, **transform_function_kwargs):
         else:
             raise NotImplementedError
 
+    # noinspection PyUnresolvedReferences
+    def _apply_tensor_function(self, tensor, **transform_function_kwargs):
+        if isinstance(tensor, list):
+            return [self._apply_tensor_function(tens) for tens in tensor]
+        return self.tensor_function(tensor)
+
     # noinspection PyUnresolvedReferences
     def _apply_image_function(self, tensor, **transform_function_kwargs):
         assert pyu.has_callable_attr(self, 'image_function')
+        if isinstance(tensor, list):
+            return [self._apply_image_function(tens) for tens in tensor]
         # 2D case
         if tensor.ndim == 4:
             return np.array([np.array([self.image_function(image, **transform_function_kwargs)
@@ -106,6 +114,8 @@ def _apply_image_function(self, tensor, **transform_function_kwargs):
     # noinspection PyUnresolvedReferences
     def _apply_volume_function(self, tensor, **transform_function_kwargs):
         assert pyu.has_callable_attr(self, 'volume_function')
+        if isinstance(tensor, list):
+            return [self._apply_volume_function(tens) for tens in tensor]
         # 3D case
         if tensor.ndim == 5:
             # tensor is bczyx
@@ -125,7 +135,7 @@ def _apply_volume_function(self, tensor, **transform_function_kwargs):
             # We're applying the volume function on the volume itself
             return self.volume_function(tensor, **transform_function_kwargs)
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Volume function not implemented for ndim %i" % tensor.ndim)
 
 
 class Compose(object):
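
The new _apply_tensor_function helper, together with the matching checks in _apply_image_function and _apply_volume_function, lets a transform recurse into lists of arrays instead of assuming a single ndarray. A rough illustration with a hypothetical Transform subclass (note that, as committed, the recursive calls do not forward transform_function_kwargs):

    import numpy as np
    from inferno.io.transform.base import Transform

    class Negate(Transform):
        # tensor_function is applied to each tensor, and now also to each
        # element when a "tensor" is actually a list of arrays
        def tensor_function(self, tensor):
            return -tensor

    t = Negate()
    single = t(np.ones((2, 3)))                      # plain ndarray, as before
    nested = t([np.ones((2, 3)), np.zeros((2, 3))])  # list input recurses per element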

inferno/io/transform/image.py (+2 −2)

@@ -596,5 +596,5 @@ def batch_function(self, image):
         pad_r = image_shape - new_shape - pad_l
         padding = [(0,0)] + list(zip(pad_l, pad_r))
         img = np.pad(img, padding, 'constant', constant_values=self.pad_const)
-        seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
-        return img, seg
+        seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
+        return img, seg

inferno/io/transform/volume.py (+2 −1)

@@ -63,6 +63,7 @@ def volume_function(self, volume):
         return volume
 
 
+# TODO this is obsolete
 class AdditiveRandomNoise3D(Transform):
     """ Add gaussian noise to 3d volume
 
@@ -105,7 +106,7 @@ def __init__(self, sigma, mode='gaussian', **super_kwargs):
         self.sigma = sigma
 
     # TODO check if volume is tensor and use torch functions in that case
-    def volume_function(self, volume):
+    def tensor_function(self, volume):
         volume += np.random.normal(loc=0, scale=self.sigma, size=volume.shape)
         return volume
 

inferno/io/volumetric/lazy_volume_loader.py (+49 −3)

@@ -1,5 +1,7 @@
 import numpy as np
 import os
+import pickle
+from concurrent import futures
 
 # try to load io libraries (h5py and z5py)
 try:
@@ -20,10 +22,39 @@
 from ...utils import python_utils as pyu
 
 
+# TODO support h5py as well
+def filter_base_sequence(input_path, input_key,
+                         window_size, stride,
+                         filter_function, n_threads):
+    with z5py.File(input_path, 'r') as f:
+        ds = f[input_key]
+        shape = list(ds.shape)
+        sequence = vu.slidingwindowslices(shape=shape,
+                                          window_size=window_size,
+                                          strides=stride,
+                                          shuffle=True,
+                                          add_overhanging=True)
+
+        def check_slice(slice_id, slice_):
+            print("Checking slice_id", slice_id)
+            data = ds[slice_]
+            if filter_function(data):
+                return None
+            else:
+                return slice_
+
+        with futures.ThreadPoolExecutor(n_threads) as tp:
+            tasks = [tp.submit(check_slice, slice_id, slice_) for slice_id, slice_ in enumerate(sequence)]
+            filtered_sequence = [t.result() for t in tasks]
+
+    filtered_sequence = [seq for seq in filtered_sequence if seq is not None]
+    return filtered_sequence
+
+
 class LazyVolumeLoaderBase(SyncableDataset):
     def __init__(self, dataset, window_size, stride, downsampling_ratio=None, padding=None,
                  padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
-                 data_slice=None):
+                 data_slice=None, base_sequence=None):
         super(LazyVolumeLoaderBase, self).__init__()
         assert len(window_size) == dataset.ndim, "%i, %i" % (len(window_size), dataset.ndim)
         assert len(stride) == dataset.ndim
@@ -58,7 +89,22 @@ def __init__(self, dataset, window_size, stride, downsampling_ratio=None, paddin
         else:
             raise NotImplementedError
 
-        self.base_sequence = self.make_sliding_windows()
+        if base_sequence is None:
+            self.base_sequence = self.make_sliding_windows()
+        else:
+            self.base_sequence = self.load_base_sequence(base_sequence)
+
+    @staticmethod
+    def load_base_sequence(base_sequence):
+        if isinstance(base_sequence, (list, tuple)):
+            return base_sequence
+        elif isinstance(base_sequence, str):
+            assert os.path.exists(base_sequence)
+            with open(base_sequence, 'rb') as f:
+                base_sequence = pickle.load(f)
+            return base_sequence
+        else:
+            raise ValueError("Unsupported base_sequence format, must be either listlike or str")
 
     def normalize_slice(self, data_slice):
         if data_slice is None:
@@ -185,7 +231,7 @@ def __init__(self, file_impl, path,
             assert os.path.exists(path), path
             self.path = path
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Not implemented for type %s" % type(path))
 
         if isinstance(path_in_file, dict):
             assert name is not None
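
A hedged usage sketch of the new pre-filtering path: filter_base_sequence slides over an n5/zarr dataset, drops the windows rejected by a user-supplied filter, and the surviving slices can be pickled and handed to the loader via the new base_sequence argument (either as the list itself or as a path to the pickle file). The file names, dataset key and filter below are made up for illustration:

    import pickle
    from inferno.io.volumetric.lazy_volume_loader import filter_base_sequence

    def mostly_background(data, threshold=0.01):
        # hypothetical filter: reject windows with almost no foreground voxels
        return (data > 0).mean() < threshold

    filtered = filter_base_sequence('raw.n5', 'volumes/raw',   # hypothetical paths
                                    window_size=[32, 256, 256],
                                    stride=[16, 128, 128],
                                    filter_function=mostly_background,
                                    n_threads=8)

    with open('base_sequence.pkl', 'wb') as f:
        pickle.dump(filtered, f)

    # later: pass base_sequence='base_sequence.pkl' (or the list itself) to the loader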

inferno/io/volumetric/volume.py (+2 −1)

@@ -100,6 +100,7 @@ def pad_volume(self, padding=None):
         assert_(all(isinstance(pad, (int, tuple, list)) for pad in self.padding),\
                 "Expect int or iterable", TypeError)
         self.padding = [[pad, pad] if isinstance(pad, int) else pad for pad in self.padding]
+        print(self.volume.shape)
         self.volume = np.pad(self.volume,
                              pad_width=self.padding,
                              mode=self.padding_mode)
@@ -228,7 +229,7 @@ def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=No
         if self.data_slice is not None and slicing_config_for_name.get('is_multichannel', False):
             self.data_slice = (slice(None),) + self.data_slice
 
-        assert 'window_size' in slicing_config_for_name
+        assert 'window_size' in slicing_config_for_name, str(slicing_config_for_name)
         assert 'stride' in slicing_config_for_name
 
         # Read in volume from file (can be hdf5, n5 or zarr)

inferno/io/volumetric/volumetric_utils.py (+6 −5)

@@ -42,12 +42,13 @@ def dimension_window(start, stop, wsize, stride, dimsize, ds_dim):
     # otherwise predict the whole volume
     if dataslice is not None:
         assert len(dataslice) == dim, "Dataslice must be a tuple with len = data dimension."
-        starts = [sl.start for sl in dataslice]
-        stops = [sl.stop - wsize for sl, wsize in zip(dataslice, window_size)]
+        starts = [0 if sl.start is None else sl.start for sl in dataslice]
+        stops = [sh - wsize if sl.stop is None else sl.stop - wsize
+                 for sl, wsize, sh in zip(dataslice, window_size, shape)]
     else:
         starts = dim * [0]
-        stops = [dimsize - wsize if wsize != dimsize else dimsize
-                for dimsize, wsize in zip(shape, window_size)]
+        stops = [dimsize - wsize if wsize != dimsize else dimsize
+                 for dimsize, wsize in zip(shape, window_size)]
 
     assert all(stp > strt for strt, stp in zip(starts, stops)),\
         "%s, %s" % (str(starts), str(stops))
@@ -128,7 +129,7 @@ def _to_list(x):
     nslices = [_1Dwindow(startmin, startmax, nhoodsiz, st, dsample, datalen, shuffle) if windowspec == 'x'
                else [slice(ws, ws + 1) for ws in _to_list(windowspec)]
                for startmin, startmax, datalen, nhoodsiz, st, windowspec, dsample in zip(startmins, startmaxs, shape,
-                                                                   nhoodsize, stride, window, ds)]
+                                                                                         nhoodsize, stride, window, ds)]
 
     return it.product(*nslices)
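
The first hunk makes open-ended data slices usable: a missing start defaults to 0 and a missing stop defaults to the dimension size before the window size is subtracted. A small worked sketch of that arithmetic, with made-up shapes:

    # shape of the full dataset and of the sliding window (illustrative values)
    shape = (100, 200)
    window_size = (10, 20)
    dataslice = (slice(None), slice(50, None))   # open-ended in both dimensions

    starts = [0 if sl.start is None else sl.start for sl in dataslice]
    stops = [sh - ws if sl.stop is None else sl.stop - ws
             for sl, ws, sh in zip(dataslice, window_size, shape)]
    print(starts, stops)   # [0, 50] [90, 180]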

inferno/trainers/basic.py (+50 −4)

@@ -27,6 +27,14 @@
 from .callbacks import Console
 from ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError
 
+# NOTE for distributed training, we might also need
+# from apex.parallel import DistributedDataParallel as DDP
+# but I don't know where exactly to put it.
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
 
 class Trainer(object):
     """A basic trainer.
@@ -126,10 +134,44 @@ def __init__(self, model=None):
         # Print console
         self._console = Console()
 
+        # Train with mixed precision, only works
+        # if we have apex
+        self._mixed_precision = False
+        self._apex_opt_level = 'O1'
+
         # Public
         if model is not None:
             self.model = model
 
+    @property
+    def mixed_precision(self):
+        return self._mixed_precision
+
+    # this needs to be called after model and optimizer are set
+    @mixed_precision.setter
+    def mixed_precision(self, mp):
+        if mp:
+            assert_(amp is not None, "Cannot use mixed precision training without apex library", RuntimeError)
+            assert_(self.model is not None and self._optimizer is not None,
+                    "Model and optimizer need to be set before activating mixed precision", RuntimeError)
+            # in order to support BCE loss
+            amp.register_float_function(torch, 'sigmoid')
+            # For now, we don't allow to set 'keep_batchnorm' and 'loss_scale'
+            self.model, self._optimizer = amp.initialize(self.model, self._optimizer,
+                                                         opt_level=self._apex_opt_level,
+                                                         keep_batchnorm_fp32=None)
+        self._mixed_precision = mp
+
+    @property
+    def apex_opt_level(self):
+        return self._apex_opt_level
+
+    @apex_opt_level.setter
+    def apex_opt_level(self, opt_level):
+        assert_(opt_level in ('O0', 'O1', 'O2', 'O3'),
+                "Invalid optimization level", ValueError)
+        self._apex_opt_level = opt_level
+
     @property
     def console(self):
         """Get the current console."""
@@ -1368,17 +1410,21 @@ def apply_model_and_loss(self, inputs, target, backward=True, mode=None):
         kwargs['trainer'] = self
         if mode == 'train':
             loss = self.criterion(prediction, target, **kwargs) \
-                if len(target) != 0 else self.criterion(prediction, **kwargs)
+                   if len(target) != 0 else self.criterion(prediction, **kwargs)
         elif mode == 'eval':
             loss = self.validation_criterion(prediction, target, **kwargs) \
-                if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
+                   if len(target) != 0 else self.validation_criterion(prediction, **kwargs)
         else:
             raise ValueError
         if backward:
             # Backprop if required
             # retain_graph option is needed for some custom
             # loss functions like malis, False per default
-            loss.backward(retain_graph=self.retain_graph)
+            if self.mixed_precision:
+                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                    scaled_loss.backward(retain_graph=self.retain_graph)
+            else:
+                loss.backward(retain_graph=self.retain_graph)
         return prediction, loss
 
     def train_for(self, num_iterations=None, break_callback=None):
@@ -1676,7 +1722,7 @@ def load(self, from_directory=None, best=False, filename=None, map_location=None
             'best_checkpoint.pytorch'.
         filename : str
             Overrides the default filename.
-        device : function, torch.device, string or a dict
+        map_location : function, torch.device, string or a dict
             Specify how to remap storage locations.
 
         Returns
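
A hedged end-to-end sketch of the new trainer options: mixed precision may only be switched on after both the model and the optimizer exist, because the setter runs amp.initialize() on them. The model, the optimizer settings and the assumption that apex and a GPU are available are illustrative, not taken from this commit:

    import torch.nn as nn
    from inferno.trainers.basic import Trainer

    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 1, 3))

    trainer = Trainer(model)
    trainer.build_optimizer('Adam', lr=1e-4)   # optimizer must exist before mixed_precision is set
    trainer.cuda()                             # apex O1/O2 expects model and optimizer on the GPU
    trainer.apex_opt_level = 'O1'              # default; one of 'O0'..'O3'
    trainer.mixed_precision = True             # runs amp.initialize(model, optimizer)

    # loss scaling is then applied automatically inside apply_model_and_loss()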

inferno/trainers/callbacks/essentials.py (+2 −1)

@@ -277,9 +277,10 @@ def norm_or_value(self):
     def after_model_and_loss_is_applied(self, **_):
         tu.clip_gradients_(self.trainer.model.parameters(), self.mode, self.norm_or_value)
 
+
 class GarbageCollection(Callback):
     """
-    Callback that triggers garbage collection at the end of every
+    Callback that triggers garbage collection at the end of every
    training iteration in order to reduce the memory footprint of training
     """

inferno/trainers/callbacks/scheduling.py (+2 −1)

@@ -301,9 +301,10 @@ def end_of_validation_run(self, **_):
 
     @staticmethod
     def is_significantly_less_than(x, y, min_relative_delta):
+        eps = 1.e-6
         if x > y:
             return False
-        relative_delta = abs(y - x) / abs(y)
+        relative_delta = abs(y - x) / (abs(y) + eps)
         return relative_delta > min_relative_delta
 
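
The eps term guards against a zero denominator when the reference score y is exactly 0. A standalone sketch of the guarded comparison:

    def is_significantly_less_than(x, y, min_relative_delta, eps=1.e-6):
        if x > y:
            return False
        relative_delta = abs(y - x) / (abs(y) + eps)
        return relative_delta > min_relative_delta

    print(is_significantly_less_than(0.0, 0.0, 0.05))   # False (previously ZeroDivisionError)
    print(is_significantly_less_than(0.8, 1.0, 0.1))    # True: x is ~20% below y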

inferno/utils/io_utils.py (+1 −1)

@@ -49,7 +49,7 @@ def yaml2dict(path):
         # Forgivable mistake that path is a dict already
         return path
     with open(path, 'r') as f:
-        readict = yaml.load(f)
+        readict = yaml.load(f, Loader=yaml.FullLoader)
     return readict
 
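
PyYAML 5.1 made yaml.load() warn when called without an explicit Loader, since the default loader can construct arbitrary Python objects; FullLoader keeps full YAML support while avoiding that. A minimal standalone sketch (the file name is made up):

    import yaml

    with open('train_config.yml', 'r') as f:            # hypothetical config file
        config = yaml.load(f, Loader=yaml.FullLoader)

    # for plain configuration files, yaml.safe_load(f) is an equivalent safe choice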
