Skip to content
This repository was archived by the owner on Apr 19, 2023. It is now read-only.

Commit 2d657e6

Browse files
msimonovskymsimonovsky
msimonovsky
authored and
msimonovsky
committed
added tests for custom channel
1 parent b80d9a8 commit 2d657e6

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,58 @@
11
import unittest
22
import torch.cuda as cuda
3-
from inferno.utils.model_utils import ModelTester
3+
from inferno.utils.model_utils import ModelTester, MultiscaleModelTester
4+
from inferno.extensions.models import UNet
5+
6+
class _MultiscaleUNet(UNet):
7+
def conv_op_factory(self, in_channels, out_channels, part, index):
8+
return super(_MultiscaleUNet, self).conv_op_factory(in_channels, out_channels, part, index)[0], True
9+
10+
def forward(self, input):
11+
x = self._initial_conv(input)
12+
x = list(super(UNet, self).forward(x))
13+
x[-1] = self._output(x[-1])
14+
return tuple(x)
415

516

617
class UNetTest(unittest.TestCase):
718
def test_unet_2d(self):
8-
from inferno.extensions.models import UNet
919
tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))
1020
if cuda.is_available():
1121
tester.cuda()
1222
tester(UNet(1, 1, dim=2, initial_features=32))
1323

1424
def test_unet_3d(self):
15-
from inferno.extensions.models import UNet
1625
tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))
1726
if cuda.is_available():
1827
tester.cuda()
1928
# test default unet 3d
2029
tester(UNet(1, 1, dim=3, initial_features=8))
2130

31+
def test_monochannel_unet_3d(self):
32+
nc = 2
33+
class _UNetMonochannel(_MultiscaleUNet):
34+
def _get_num_channels(self, depth):
35+
return nc
36+
37+
shapes = [(1, nc, 16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16), (1, nc, 2, 8, 8), (1, nc, 1, 4, 4),
38+
(1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32), (1, 1, 16, 64, 64)]
39+
tester = MultiscaleModelTester((1, 1, 16, 64, 64), shapes)
40+
if cuda.is_available():
41+
tester.cuda()
42+
tester(_UNetMonochannel(1, 1, dim=3, initial_features=8))
43+
44+
def test_inverse_pyramid_unet_2d(self):
45+
class _UNetInversePyramid(_MultiscaleUNet):
46+
def _get_num_channels(self, depth):
47+
return [13, 12, 11][depth - 1]
48+
49+
shapes = [(1, 13, 16, 64), (1, 12, 8, 32), (1, 11, 4, 16), (1, 11, 2, 8),
50+
(1, 12, 4, 16), (1, 13, 8, 32), (1, 1, 16, 64)]
51+
tester = MultiscaleModelTester((1, 1, 16, 64), shapes)
52+
if cuda.is_available():
53+
tester.cuda()
54+
tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8))
55+
2256

2357
if __name__ == '__main__':
2458
unittest.main()

0 commit comments

Comments
 (0)