Commit 50fa978

Merge pull request #165 from mys007/unet_nchannels_freedom
UNet: Allow to freely define the number of channels per depth in subclasses
2 parents 365dfd5 + 2d657e6 commit 50fa978
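
In short: `UNetBase` previously hard-coded the channel count at each depth to `in_channels * gain**depth`; this merge routes that computation through a single overridable hook, `_get_num_channels(depth)`, so subclasses can choose any channel schedule. A minimal sketch of what this enables (hypothetical subclass, not code from this commit; it assumes the constructor signature used in the tests below):

```python
# Hypothetical subclass, a sketch only; not part of this commit.
from inferno.extensions.models import UNet

class CustomChannelUNet(UNet):
    """Pick an arbitrary channel count per depth instead of the
    default in_channels * gain**depth progression."""
    def _get_num_channels(self, depth):
        # depth runs from 1 (first encoder level) to self.depth (bottom)
        return [24, 48, 96][depth - 1]

model = CustomChannelUNet(1, 1, dim=2, depth=3, initial_features=8)
```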

File tree: 2 files changed, +48 -21 lines

inferno/extensions/models/unet.py (+11 -18)

```diff
@@ -119,14 +119,18 @@ def __init__(self, in_channels, dim, out_channels=None, depth=3,
         assert len(self.n_channels_per_output) == self._store_conv_down.count(True) + \
             self._store_conv_up.count(True) + int(self._store_conv_bottom)
 
+    def _get_num_channels(self, depth):
+        assert depth > 0
+        return self.in_channels * self.gain**depth
+
     def _init__downstream(self):
         conv_down_ops = []
         self._store_conv_down = []
 
         current_in_channels = self.in_channels
 
         for i in range(self.depth):
-            out_channels = current_in_channels * self.gain
+            out_channels = self._get_num_channels(i + 1)
             op, return_op_res = self.conv_op_factory(in_channels=current_in_channels,
                                                      out_channels=out_channels,
                                                      part='down', index=i)
@@ -138,7 +142,7 @@ def _init__downstream(self):
             self._store_conv_down.append(False)
 
             # increase the number of channels
-            current_in_channels *= self.gain
+            current_in_channels = out_channels
 
         # store as proper torch ModuleList
         self._conv_down_ops = nn.ModuleList(conv_down_ops)
@@ -147,9 +151,7 @@ def _init__downstream(self):
 
     def _init__bottom(self):
 
-        conv_up_ops = []
-
-        current_in_channels = self.in_channels* self.gain**self.depth
+        current_in_channels = self._get_num_channels(self.depth)
 
         factory_res = self.conv_op_factory(in_channels=current_in_channels,
                                            out_channels=current_in_channels, part='bottom', index=0)
@@ -163,12 +165,12 @@ def _init__bottom(self):
 
     def _init__upstream(self):
         conv_up_ops = []
-        current_in_channels = self.in_channels * self.gain**self.depth
+        current_in_channels = self._get_num_channels(self.depth)
 
         for i in range(self.depth):
             # the number of out channels (set to self.out_channels for last decoder)
-            out_channels = self.out_channels if i +1 == self.depth else\
-                current_in_channels // self.gain
+            out_channels = self.out_channels if i + 1 == self.depth else \
+                self._get_num_channels(self.depth - i - 1)
 
             # if not residual we concat which needs twice as many channels
             fac = 1 if self.residual else 2
@@ -186,7 +188,7 @@ def _init__upstream(self):
             self._store_conv_up.append(False)
 
             # decrease the number of input_channels
-            current_in_channels //= self.gain
+            current_in_channels = out_channels
 
         # store as proper torch ModuleLis
         self._conv_up_ops = nn.ModuleList(conv_up_ops)
@@ -311,15 +313,6 @@ def upsample_op_factory(self, index):
         return InfernoUpsample(**self._upsample_kwargs)
         #return nn.Upsample(**self._upsample_kwargs)
 
-    def pre_conv_op_regularizer_factory(self, in_channels, out_channels, part, index):
-        if self.use_dropout and in_channels > 2:
-            return self._channel_dropout_op(x)
-        else:
-            return Identity()
-
-    def post_conv_op_regularizer_factory(self, in_channels, out_channels, part, index):
-        return Identity()
-
     def conv_op_factory(self, in_channels, out_channels, part, index):
         raise NotImplementedError("conv_op_factory need to be implemented by deriving class")
```
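For the default case the refactor is behavior-preserving: `_get_num_channels` reproduces the old geometric progression, so with `gain=2` and 8 input features a depth-3 encoder still produces 16, 32, 64 channels. A standalone arithmetic check (example values assumed; `gain=2` matches the old `current_in_channels *= self.gain` loop):

```python
# Standalone check that the new hook matches the old in-place updates.
in_channels, gain, depth = 8, 2, 3  # assumed example values

def get_num_channels(d):
    # mirrors the new UNetBase._get_num_channels
    assert d > 0
    return in_channels * gain ** d

new_schedule = [get_num_channels(d) for d in range(1, depth + 1)]

# old behavior: multiply by gain once per encoder level
old_schedule, c = [], in_channels
for _ in range(depth):
    c *= gain
    old_schedule.append(c)

assert new_schedule == old_schedule == [16, 32, 64]
```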
UNet test suite (+37 -3)

```diff
@@ -1,24 +1,58 @@
 import unittest
 import torch.cuda as cuda
-from inferno.utils.model_utils import ModelTester
+from inferno.utils.model_utils import ModelTester, MultiscaleModelTester
+from inferno.extensions.models import UNet
+
+class _MultiscaleUNet(UNet):
+    def conv_op_factory(self, in_channels, out_channels, part, index):
+        return super(_MultiscaleUNet, self).conv_op_factory(in_channels, out_channels, part, index)[0], True
+
+    def forward(self, input):
+        x = self._initial_conv(input)
+        x = list(super(UNet, self).forward(x))
+        x[-1] = self._output(x[-1])
+        return tuple(x)
 
 
 class UNetTest(unittest.TestCase):
     def test_unet_2d(self):
-        from inferno.extensions.models import UNet
         tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))
         if cuda.is_available():
             tester.cuda()
         tester(UNet(1, 1, dim=2, initial_features=32))
 
     def test_unet_3d(self):
-        from inferno.extensions.models import UNet
         tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))
         if cuda.is_available():
             tester.cuda()
         # test default unet 3d
         tester(UNet(1, 1, dim=3, initial_features=8))
 
+    def test_monochannel_unet_3d(self):
+        nc = 2
+        class _UNetMonochannel(_MultiscaleUNet):
+            def _get_num_channels(self, depth):
+                return nc
+
+        shapes = [(1, nc, 16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16), (1, nc, 2, 8, 8), (1, nc, 1, 4, 4),
+                  (1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32), (1, 1, 16, 64, 64)]
+        tester = MultiscaleModelTester((1, 1, 16, 64, 64), shapes)
+        if cuda.is_available():
+            tester.cuda()
+        tester(_UNetMonochannel(1, 1, dim=3, initial_features=8))
+
+    def test_inverse_pyramid_unet_2d(self):
+        class _UNetInversePyramid(_MultiscaleUNet):
+            def _get_num_channels(self, depth):
+                return [13, 12, 11][depth - 1]
+
+        shapes = [(1, 13, 16, 64), (1, 12, 8, 32), (1, 11, 4, 16), (1, 11, 2, 8),
+                  (1, 12, 4, 16), (1, 13, 8, 32), (1, 1, 16, 64)]
+        tester = MultiscaleModelTester((1, 1, 16, 64), shapes)
+        if cuda.is_available():
+            tester.cuda()
+        tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8))
+
 
 if __name__ == '__main__':
     unittest.main()
```
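
The expected-shape lists in the new tests follow directly from how the refactored `_init__*` methods consult the hook: the encoder emits `_get_num_channels(1..depth)`, the bottom stays at `_get_num_channels(depth)`, and the decoder walks back through `_get_num_channels(depth-1..1)` before ending at `out_channels`. A sketch of that bookkeeping for the inverse-pyramid test (my reading of the diff above, not code from this commit):

```python
# Channel bookkeeping behind the shapes list in test_inverse_pyramid_unet_2d.
depth, out_channels = 3, 1

def get_num_channels(d):
    return [13, 12, 11][d - 1]  # the test's override

down = [get_num_channels(i + 1) for i in range(depth)]            # encoder outputs
bottom = [get_num_channels(depth)]                                # bottleneck output
up = [get_num_channels(depth - i - 1) for i in range(depth - 1)]  # decoder, except last
channels = down + bottom + up + [out_channels]
assert channels == [13, 12, 11, 11, 12, 13, 1]  # matches the shapes above
```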
