
Commit 0cbaf1d

Merge pull request #160 from DerThorsten/unet
minor unet changes, added more examples
2 parents: d8287f5 + 7489290

File tree

3 files changed: +248 -4


examples/plot_cheap_unet.py

+241
@@ -0,0 +1,241 @@
"""
UNet Tutorial
================================
A UNet example which can be run without a GPU
"""


##############################################################################
# Preface
# --------------
# We start with some unspectacular multi-purpose imports needed for this example
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy


##############################################################################
# determine whether we have a gpu
# and should use cuda
USE_CUDA = torch.cuda.is_available()


##############################################################################
# Dataset
# --------------
# For simplicity we will use a toy dataset where we need to perform
# a binary segmentation task.
from inferno.io.box.binary_blobs import get_binary_blob_loaders

# convert labels from long to float as needed by
# the binary cross entropy loss
def label_transform(x):
    return torch.from_numpy(x).float()
# alternatively as a lambda:
#label_transform = lambda x: torch.from_numpy(x).float()
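
##############################################################################
# Note: the float conversion matters because :code:`BCEWithLogitsLoss`
# (built by the trainer below) expects float targets; long tensors raise a
# dtype error. A tiny sanity check with made-up values (not part of the
# original example):
example_logits = torch.tensor([0.5, -1.0])
example_targets = torch.tensor([1.0, 0.0])  # float targets, as label_transform produces
print(nn.BCEWithLogitsLoss()(example_logits, example_targets))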

train_loader, test_loader, validate_loader = get_binary_blob_loaders(
    size=8,                    # how many images per {train, test, validate}
    train_batch_size=2,
    length=256,                # <= size of the images
    gaussian_noise_sigma=1.4,  # <= how noisy the images are
    train_label_transform=label_transform,
    validate_label_transform=label_transform
)

image_channels = 1  # <-- number of channels of the image
pred_channels = 1   # <-- number of channels needed for the prediction

##############################################################################
# Visualize Dataset
# ~~~~~~~~~~~~~~~~~~~~~~
# (set the flag below to True to look at a sample before training)
if False:
    fig = plt.figure()

    for image, target in train_loader:
        ax = fig.add_subplot(1, 2, 1)
        ax.imshow(image[0, 0, ...])
        ax.set_title('raw data')
        ax = fig.add_subplot(1, 2, 2)
        ax.imshow(target[0, ...])
        ax.set_title('ground truth')
        break
    fig.tight_layout()
    plt.show()


##############################################################################
# Training
# ----------------------------
# To train the UNet, we use inferno's Trainer class.
# Since we train many models later on in this example, we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for
# an example dedicated to the trainer itself).
from inferno.trainers import Trainer
from inferno.utils.python_utils import ensure_dir

def train_model(model, loaders, **kwargs):

    trainer = Trainer(model)
    trainer.build_criterion('BCEWithLogitsLoss')
    trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))
    #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
    #trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
    #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dir')))
    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20))

    # bind the loaders
    trainer.bind_loader('train', loaders[0])
    trainer.bind_loader('validate', loaders[1])

    if USE_CUDA:
        trainer.cuda()

    # do the training
    trainer.fit()

    return trainer


##############################################################################
# Prediction
# ----------------------------
# The trainer contains the trained model, so we can use it for predictions.
# We use :code:`unwrap` to convert the results to numpy arrays.
# Since we predict more than once in this example, we encapsulate
# the prediction in a function.
from inferno.utils.torch_utils import unwrap

def predict(trainer, test_loader, save_dir=None):

    trainer.eval_mode()
    for image, target in test_loader:

        # transfer image to gpu
        image = image.cuda() if USE_CUDA else image

        # get batch size from image
        batch_size = image.size()[0]

        # predict on the whole batch at once and squash
        # the logits to [0, 1] with a sigmoid
        prediction = trainer.apply_model(image)
        prediction = torch.sigmoid(prediction)

        image = unwrap(image, as_numpy=True, to_cpu=True)
        prediction = unwrap(prediction, as_numpy=True, to_cpu=True)
        target = unwrap(target, as_numpy=True, to_cpu=True)

        # plot each sample of the batch individually
        for b in range(batch_size):
            fig = plt.figure()

            ax = fig.add_subplot(2, 2, 1)
            ax.imshow(image[b, 0, ...])
            ax.set_title('raw data')

            ax = fig.add_subplot(2, 2, 2)
            ax.imshow(target[b, ...])
            ax.set_title('ground truth')

            ax = fig.add_subplot(2, 2, 4)
            ax.imshow(prediction[b, ...])
            ax.set_title('prediction')

            fig.tight_layout()
            plt.show()


##############################################################################
# Custom UNet
# ----------------------------
# Often one needs a UNet with custom layers.
# Here we show how to implement such a customized UNet.
# To this end we derive from :code:`UNetBase`.
# For the sake of this example we will create
# a UNet which uses depthwise separable convolutions and can be trained on a CPU.
from inferno.extensions.models import UNetBase
from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D, Conv2D, ConvActivation


class CheapConv(nn.Module):
    def __init__(self, in_channels, out_channels, activated):
        super(CheapConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if activated:
            # a depthwise 3x3 convolution followed by a 1x1 convolution
            # is much cheaper than a dense 3x3 convolution
            self.convs = torch.nn.Sequential(
                ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2),
                ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
            )
        else:
            self.convs = torch.nn.Sequential(
                ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2),
                Conv2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
            )

    def forward(self, x):
        assert x.shape[1] == self.in_channels, "input has wrong number of channels"
        x = self.convs(x)
        assert x.shape[1] == self.out_channels, "output has wrong number of channels"
        return x
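
##############################################################################
# To see why these convolutions are "cheap", compare the parameter count of
# a dense 3x3 convolution with the depthwise-separable factorization used
# above. This is a minimal illustration in plain PyTorch with hypothetical
# channel counts; it is not part of the original example.
def n_parameters(module):
    return sum(p.numel() for p in module.parameters())

dense = nn.Conv2d(32, 64, kernel_size=3, padding=1)
separable = nn.Sequential(
    nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=32),  # depthwise 3x3
    nn.Conv2d(32, 64, kernel_size=1)                         # pointwise 1x1
)
print(n_parameters(dense))      # 18496
print(n_parameters(separable))  # 2432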


class CheapConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activated):
        super(CheapConvBlock, self).__init__()
        self.activated = activated
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels != out_channels:
            # project to the right number of channels for the residual addition
            self.start = ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
        else:
            self.start = None
        self.conv_a = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=True)
        self.conv_b = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=False)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        x_input = x
        if self.start is not None:
            x_input = self.start(x_input)

        x = self.conv_a(x_input)
        x = self.conv_b(x)

        # residual connection
        x = x + x_input

        if self.activated:
            x = self.activation(x)
        return x


class MySimple2DCpUnet(UNetBase):
    def __init__(self, in_channels, out_channels, depth=3, residual=False, **kwargs):
        super(MySimple2DCpUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,
                                               dim=2, depth=depth, **kwargs)

    def conv_op_factory(self, in_channels, out_channels, part, index):

        # the very last block of the up-path should not be activated
        last = part == 'up' and index == 0
        return CheapConvBlock(in_channels=in_channels, out_channels=out_channels, activated=not last), False


from inferno.extensions.layers import RemoveSingletonDimension

model_b = torch.nn.Sequential(
    CheapConv(in_channels=image_channels, out_channels=4, activated=True),
    MySimple2DCpUnet(in_channels=4, out_channels=pred_channels),
    RemoveSingletonDimension(dim=1)
)

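###################################################
# a quick shape check before training: the spatial size must be divisible
# by 2**depth for the downsampling steps to work out. The input size and the
# expected output shape here are illustrative assumptions, not part of the
# original example.
with torch.no_grad():
    x = torch.rand(1, image_channels, 64, 64)
    print(model_b(x).shape)  # expected: torch.Size([1, 64, 64])
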
###################################################
# do the training (with the same function as before)
trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001)

###################################################
# do the prediction (with the same function as before)
predict(trainer=trainer, test_loader=test_loader)

examples/plot_unet_tutorial.py

+3 -2
@@ -201,6 +201,7 @@ def predict(trainer, test_loader, save_dir=None):
 # of the unet
 from inferno.extensions.models import UNetBase
 from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D
+from inferno.extensions.layers.sampling import Upsample

 class MySimple2DUnet(UNetBase):
     def __init__(self, in_channels, out_channels, depth=3, **kwargs):
@@ -221,7 +222,7 @@ def conv_op_factory(self, in_channels, out_channels, part, index):
             ), False
         elif part == 'up':
             # are we in the very last block?
-            if index + 1 == self.depth:
+            if index == 0:
                 return torch.nn.Sequential(
                     ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
                     Conv2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
@@ -243,7 +244,7 @@ def downsample_op_factory(self, index):

     # this function CAN be implemented, if not, Upsampling is used by default
     def upsample_op_factory(self, index):
-        return torch.nn.Upsample(mode='bilinear', align_corners=False, scale_factor=2)
+        return Upsample(mode='bilinear', align_corners=False, scale_factor=2)

 model_b = torch.nn.Sequential(
     ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),

inferno/extensions/models/unet.py

+4 -2
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from ..layers.identity import Identity
 from ..layers.convolutional import ConvELU2D, ConvELU3D, Conv2D, Conv3D
+from ..layers.sampling import Upsample as InfernoUpsample
 from ...utils.math_utils import max_allowed_ds_steps


@@ -306,8 +307,9 @@ def downsample_op_factory(self, index):
         C = nn.MaxPool2d if self.dim == 2 else nn.MaxPool3d
         return C(kernel_size=2, stride=2)

-    def upsample_op_factory(self, index):
-        return nn.Upsample(**self._upsample_kwargs)
+    def upsample_op_factory(self, index):
+        return InfernoUpsample(**self._upsample_kwargs)
+        #return nn.Upsample(**self._upsample_kwargs)

     def pre_conv_op_regularizer_factory(self, in_channels, out_channels, part, index):
         if self.use_dropout and in_channels > 2:
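
The swap from nn.Upsample to an inferno-level Upsample wrapper plausibly exists because newer PyTorch versions deprecate nn.Upsample in favor of the functional interpolate interface. Below is a minimal sketch of what such a wrapper could look like; this is an assumption for illustration, the actual implementation lives in inferno/extensions/layers/sampling.py and may differ.

import torch.nn as nn
import torch.nn.functional as F

class Upsample(nn.Module):
    # assumed constructor arguments, mirroring nn.Upsample
    def __init__(self, scale_factor=2, mode='nearest', align_corners=None):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        # the functional interface avoids the deprecation warning that
        # nn.Upsample emits on newer PyTorch versions
        return F.interpolate(x, scale_factor=self.scale_factor,
                             mode=self.mode, align_corners=self.align_corners)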
