# -*- coding: utf-8 -*-
"""unet-run.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/gist/Jeremy26/9ffc8efe458d5e139db86b9d50b00939/unet-run.ipynb

[Image reference](https://arxiv.org/abs/1505.04597)
## Imports
"""
# basic imports
import random
import numpy as np
from tqdm import tqdm
# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
# libraries for loading image, plotting
import cv2
import matplotlib.pyplot as plt
# try to import segmentation_models_pytorch; if it is already
# present we are good to go, else install it and then import it
try:
    import segmentation_models_pytorch as smp
except ImportError:
    !pip install segmentation-models-pytorch
    import segmentation_models_pytorch as smp
"""## 1. Dataset : Download and use BDD100k dataset"""
# load images and label data
images = np.load("dataset/image_180_320.npy")
labels = np.load("dataset/label_180_320.npy")
# plot sample image
_, (ax0, ax1) = plt.subplots(1,2, figsize=(20,40))
ax0.imshow(images[458])  # arbitrarily chosen sample index
ax0.set_title("Image")
ax1.imshow(labels[458])
ax1.set_title("Label")
plt.show()
"""### Class label to standard color maps"""
# Constants for Standard color mapping
# reference : https://github.com/bdd100k/bdd100k/blob/master/bdd100k/label/label.py
from collections import namedtuple
# Each label is a tuple with name, class id and color
Label = namedtuple( "Label", [ "name", "train_id", "color"])
drivables = [
    Label("direct", 0, (171, 44, 236)),      # purple
    Label("alternative", 1, (86, 211, 19)),  # green
    Label("background", 2, (0, 0, 0)),       # black
]
train_id_to_color = [c.color for c in drivables if (c.train_id != -1 and c.train_id != 255)]
train_id_to_color = np.array(train_id_to_color)
print(f"train_id_to_color = \n {train_id_to_color}")
# plot sample image using defined color mappings
fig, axes = plt.subplots(1,2, figsize=(20,10))
axes[0].imshow(images[458]);
axes[0].set_title("RGB Image");
axes[0].axis('off');
axes[1].imshow(train_id_to_color[labels[458]]);  # same index as the RGB image
axes[1].set_title("Label");
axes[1].axis('off');
"""### Build Datasets & DataLoaders
Every PyTorch model is built and train using 3 elements:
* Dataset
* DataLoader
* Model
"""
from torch.utils.data import Dataset, DataLoader
class BDD100k_dataset(Dataset):
    def __init__(self, images, labels, tf):
        self.images = images
        self.labels = labels
        self.tf = tf

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, index):
        # read source RGB image and apply transform
        rgb_image = self.images[index]
        if self.tf is not None:
            rgb_image = self.tf(rgb_image)
        # read label image and convert to torch tensor
        label_image = torch.from_numpy(self.labels[index]).long()
        return rgb_image, label_image
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.ToTensor(),
    # standard ImageNet normalization statistics
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
# Function to split data into train, validation and test sets
def get_datasets(images, labels):
    data = BDD100k_dataset(images, labels, tf=preprocess)
    total_count = len(data)
    train_count = int(0.7 * total_count)
    valid_count = int(0.2 * total_count)
    test_count = total_count - train_count - valid_count
    # fixed seed so the split is reproducible across runs
    train_set, val_set, test_set = torch.utils.data.random_split(
        data, (train_count, valid_count, test_count),
        generator=torch.Generator().manual_seed(1))
    return train_set, val_set, test_set
"""#### Dataloaders
- Dataloaders help load data in batches
- We'll need to define separate dataloaders for training, validation and test sets
"""
def get_dataloaders(train_set, val_set, test_set):
    # shuffle the training set each epoch; validation and
    # test sets are evaluated in a fixed order
    train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_set, batch_size=8)
    test_dataloader = DataLoader(test_set, batch_size=8)
    return train_dataloader, val_dataloader, test_dataloader
train_set, val_set, test_set = get_datasets(images, labels)
sample_image, sample_label = train_set[0]
print(f"There are {len(train_set)} train images, {len(val_set)} validation images, {len(test_set)} test images")
print(f"Input shape = {sample_image.shape}, output label shape = {sample_label.shape}")
train_dataloader, val_dataloader, test_dataloader = get_dataloaders(train_set, val_set, test_set)
"""### Show Sample images from dataset"""
from utils import inverse_transform
rgb_image, label = train_set[np.random.choice(len(train_set))]
rgb_image = inverse_transform(rgb_image).permute(1, 2, 0).cpu().detach().numpy()
label = label.cpu().detach().numpy()
# plot sample image
fig, axes = plt.subplots(1,2, figsize=(20,10))
axes[0].imshow(rgb_image);
axes[0].set_title("Image");
axes[0].axis('off');
axes[1].imshow(train_id_to_color[label]);
axes[1].set_title("Label");
axes[1].axis('off');
"""## 2. Network: Define a UNet Encoder-Decoder
Pay close attention to the image, this is what we are going to code:

Notice the 3 main elements:
* The Encoder
* The Decoder
* The Skip-Connections
### The Encoder
Notice how the encoder is made of these "double convolutions".
We have an input image, and then:
* Two 3x3 Convolutions with ReLU
* A 2x2 Max Pooling
That operation is repeated 4 times, as shown.
"""
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
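# A quick sanity check (a minimal sketch, not part of the original
# notebook): double_conv keeps the spatial size unchanged (3x3 convs
# with padding=1) and only changes the channel count.
_check = double_conv(3, 64)(torch.randn(1, 3, 180, 320))
print(_check.shape)  # expected: torch.Size([1, 64, 180, 320])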
class UNetEncoder(nn.Module):
    def __init__(self, in_channels, layer_channels):
        super(UNetEncoder, self).__init__()
        self.encoder = nn.ModuleList()
        # Double Convolution blocks
        for num_channels in layer_channels:
            self.encoder.append(double_conv(in_channels, num_channels))
            in_channels = num_channels
        # Pooling
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # Pass input image through Encoder
        for down in self.encoder:
            x = down(x)
            x = self.pool(x)
        return x
"""### The Decoder
The upsampling part is made through a series of 4:
* **One Transposed 2x2 Convolution**
* **A Double 3x3 Convolution with ReLU** (as in the encoder)
Finally, the model has **a final 1x1 convolution that classifies every single pixel**!
"""
import torchvision.transforms.functional as TF
class UNetDecoder(nn.Module):
def __init__(self, layer_channels):
super(UNetDecoder, self).__init__()
self.decoder = nn.ModuleList()
# Decoder layer Double Convolution blocks
# and upsampling blocks
self.decoder = nn.ModuleList()
for num_channels in reversed(layer_channels):
# upsample output and reduce channels by 2
self.decoder.append(nn.ConvTranspose2d(num_channels*2, num_channels, kernel_size=2, stride=2))
self.decoder.append(double_conv(num_channels, num_channels))
def forward(self, x):
for up in self.decoder:
x = up(x)
return x
# Now the full Encoder and Decoder with skip connections: the encoder
# also returns its intermediate outputs, and the decoder consumes them.
class UNetEncoder(nn.Module):
    def __init__(self, in_channels, layer_channels):
        super(UNetEncoder, self).__init__()
        self.encoder = nn.ModuleList()
        # Double Convolution blocks
        for num_channels in layer_channels:
            self.encoder.append(double_conv(in_channels, num_channels))
            in_channels = num_channels
        # Max Pooling
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        # Pass input image through Encoder blocks
        # and return outputs at each stage
        skip_connections = []
        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        return x, skip_connections
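# A minimal sketch (not in the original notebook) showing why the
# decoder may need to resize: max pooling floors odd sizes (45 -> 22),
# so the later upsampled tensor (44) no longer matches its skip
# connection (45) for a 180x320 input.
_enc = UNetEncoder(3, [64, 128, 256, 512])
_out, _skips = _enc(torch.randn(1, 3, 180, 320))
print([s.shape[2:] for s in _skips])  # (180,320), (90,160), (45,80), (22,40)
print(_out.shape[2:])                 # (11, 20)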
class UNetDecoder(nn.Module):
    def __init__(self, layer_channels):
        super(UNetDecoder, self).__init__()
        # Decoder upsampling blocks and
        # Double Convolution blocks
        self.decoder = nn.ModuleList()
        for num_channels in reversed(layer_channels):
            self.decoder.append(nn.ConvTranspose2d(num_channels*2, num_channels, kernel_size=2, stride=2))
            # after concatenating the skip connection, the input to
            # the double_conv block has num_channels*2 channels
            self.decoder.append(double_conv(num_channels*2, num_channels))

    def forward(self, x, skip_connections):
        for idx in range(0, len(self.decoder), 2):
            # upsample output and halve the channel count
            x = self.decoder[idx](x)
            # if the skip connection shape doesn't match, resize
            skip_connection = skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])
            # concatenate and pass through double_conv block
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoder[idx+1](concat_skip)
        return x
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, layer_channels):
        super(UNet, self).__init__()
        # Encoder and decoder modules
        self.encoder = UNetEncoder(in_channels, layer_channels)
        self.decoder = UNetDecoder(layer_channels)
        # conv layer to transition from encoder to decoder and
        # 1x1 convolution to reduce num channels to out_channels
        self.bottleneck = double_conv(layer_channels[-1], layer_channels[-1]*2)
        self.final_conv = nn.Conv2d(layer_channels[0], out_channels, kernel_size=1)
        # initialize conv weights (the double_conv layers are built
        # with bias=False, so only weights need initialization)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)

    def forward(self, x):
        # Encoder blocks
        encoder_output, skip_connections = self.encoder(x)
        # transition between encoder and decoder
        x = self.bottleneck(encoder_output)
        # we need the last skip connection first,
        # so reverse the list
        skip_connections = skip_connections[::-1]
        # Decoder blocks
        x = self.decoder(x, skip_connections)
        # final 1x1 conv maps features to per-pixel class scores
        return self.final_conv(x)
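# A quick sanity check (a sketch, not in the original notebook): the
# full model should map a (N, 3, H, W) image to (N, num_classes, H, W).
_unet = UNet(in_channels=3, out_channels=3, layer_channels=[64, 128, 256, 512])
with torch.no_grad():
    _pred = _unet(torch.randn(1, 3, 180, 320))
print(_pred.shape)  # expected: torch.Size([1, 3, 180, 320])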
"""## 3. Training : Train and validate model on the custom dataset"""
from utils import meanIoU # metric class
from utils import plot_training_results # function to plot training curves
from utils import evaluate_model # evaluation function
from utils import train_validate_model # train validate function
"""### Hyperparameters"""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
# reference : https://smp.readthedocs.io/en/latest/losses.html
criterion = smp.losses.DiceLoss('multiclass', classes=[0, 1, 2], log_loss=True, smooth=1.0)
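# With log_loss=True the criterion returns -log(Dice) rather than
# 1 - Dice. A minimal sketch of the expected tensor shapes (my reading
# of the smp docs, not part of the notebook): logits of shape
# (N, C, H, W) against integer targets of shape (N, H, W).
_loss = criterion(torch.randn(2, 3, 180, 320), torch.randint(0, 3, (2, 180, 320)))
print(_loss.item())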
# MODEL HYPERPARAMETERS
N_EPOCHS = 5
NUM_CLASSES = 3
MAX_LR = 3e-4
MODEL_NAME = 'UNet_baseline'
# create model, optimizer, lr_scheduler and pass to training function
model = UNet(in_channels=3, out_channels=3, layer_channels=[64, 128, 256, 512]).to(device)
optimizer = optim.Adam(model.parameters(), lr=MAX_LR)
scheduler = OneCycleLR(optimizer, max_lr=MAX_LR, epochs=N_EPOCHS,
                       steps_per_epoch=len(train_dataloader),
                       pct_start=0.3, div_factor=10, anneal_strategy='cos')
output_path = "dataset"
_ = train_validate_model(model, N_EPOCHS, MODEL_NAME, criterion, optimizer,
                         device, train_dataloader, val_dataloader, meanIoU, 'meanIoU',
                         NUM_CLASSES, lr_scheduler=scheduler, output_path=output_path)
"""# 4. Evaluate : Evaluate the model on Test Data and visualize results"""
model.load_state_dict(torch.load(f'{output_path}/{MODEL_NAME}.pt', map_location=device))
_, test_metric = evaluate_model(model, test_dataloader, criterion, meanIoU, NUM_CLASSES, device)
print(f"\nModel has {test_metric} mean IoU in test set")
from utils import visualize_predictions
num_test_samples = 2
_, axes = plt.subplots(num_test_samples, 3, figsize=(3*6, num_test_samples * 4))
visualize_predictions(model, test_set, axes, device, numTestSamples=num_test_samples, id_to_color=train_id_to_color)
"""## Test on sample video"""
from utils import predict_video
predict_video(model, "Unet5epochs", "highway_1241_376.avi", "segmentation", 1241, 376, "cuda", train_id_to_color)