import logging
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from SudokuDataset import SudokuDataset
from unet_module import UNET
from utils import dice_coeff, dice_loss, multiclass_dice_coeff

dir_checkpoint = Path('./checkpoints/')


def train_net(net, dataset, epochs, batch_size, learning_rate, device, writer,
              save_checkpoint=True, val_percent=0.1, amp=True):
    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val],
                                      generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=1, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch[0], batch[1]
                # Inputs stay float32; autocast handles the half-precision casts.
                images = images.to(device=device, dtype=torch.float32)
                # CrossEntropyLoss expects integer class indices as targets.
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks)
                    # Optionally combine with a Dice term:
                    # loss = loss + dice_loss(F.softmax(masks_pred, dim=1).float(),
                    #                         F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                    #                         multiclass=True)

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                writer.add_scalar('training loss', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round: validate roughly four times per epoch
                division_step = n_train // (4 * batch_size)
                if division_step > 0 and global_step % division_step == 0:
                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)
                    logging.info('Validation Dice score: {}'.format(val_score))
                    writer.add_scalar('validation_loss', val_score, global_step)

        if save_checkpoint:
            dir_checkpoint.mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
            logging.info(f'Checkpoint {epoch + 1} saved!')


def evaluate(net, dataloader, device):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # Iterate over the validation set
    for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=True):
        image, mask_true = batch[0], batch[1]
        # Move images and labels to the correct device and type
        image = image.to(device=device, dtype=torch.float32)
        mask_true = mask_true.to(device=device, dtype=torch.long)

        with torch.no_grad():
            # Predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
                # Compute the Dice score
                dice_score += dice_coeff(mask_pred, mask_true.float(), reduce_batch_first=False)
            else:
                # Convert target and prediction to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # Compute the Dice score, ignoring the background channel
                dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...],
                                                    reduce_batch_first=False)

    net.train()
    # Guard against division by zero on an empty validation loader
    if num_val_batches == 0:
        return dice_score
    return dice_score / num_val_batches

if __name__ == "__main__":
sys.exit('')
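
# A minimal sketch of how train_net could be wired up (the entry point above is
# stubbed out on purpose). The SudokuDataset and UNET constructor arguments are
# assumptions here, since their signatures live in modules not shown; fill them
# in before uncommenting.
#
# if __name__ == "__main__":
#     logging.basicConfig(level=logging.INFO)
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     dataset = SudokuDataset(...)   # construct with your data location / transforms
#     net = UNET(...).to(device)     # channel and class counts depend on unet_module
#     writer = SummaryWriter()       # TensorBoard logs go to ./runs/ by default
#     train_net(net, dataset, epochs=5, batch_size=8, learning_rate=1e-5,
#               device=device, writer=writer)
#     writer.close()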