-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathpytorch_train_val.py
157 lines (119 loc) · 5.57 KB
/
pytorch_train_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# train, test on MEBeauty and plot the results
import pytorch_mebeauty_dataset
import argparse
from progiter import ProgIter
import torch
import time
from torchvision import models
import matplotlib.pyplot as plt
from torch import nn
import torch.optim as optim
def fit(model, loss_func, opt, train_dl, valid_dl, device = "cpu", epochs = 10):
    """Train a regression model and validate it after every epoch.

    Args:
        model: the network to train (outputs one score per sample).
        loss_func: loss callable, e.g. ``nn.MSELoss()``.
        opt: optimizer over ``model.parameters()``.
        train_dl: training DataLoader yielding ``(images, scores)`` batches.
        valid_dl: validation DataLoader with the same batch structure.
        device: torch device string/object the batches are moved to.
        epochs: number of full passes over ``train_dl``.

    Returns:
        (train_losses, val_losses): per-epoch mean losses.

    Side effects: prints per-epoch losses and saves the whole model to
    ``./pytorch_trained_models/`` after every epoch.
    """
    import copy
    import os
    train_losses = []
    val_losses = []
    # Deep-copy: state_dict() returns references to the live tensors, so
    # without the copy "best" weights would silently track the current ones.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    # Make sure the checkpoint directory exists before the first save.
    os.makedirs('./pytorch_trained_models', exist_ok=True)
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        for xb, yb in ProgIter(train_dl):
            xb, yb = xb.to(device), yb.to(device)
            # Model emits shape (batch, 1); flatten to match the target vector.
            loss = loss_func(model(xb).reshape(-1), yb.float())
            loss_sum += loss.item()
            loss.backward()
            opt.step()
            opt.zero_grad()
        print('train loss {:.3f}'.format(loss_sum / len(train_dl)))
        train_losses.append(loss_sum / len(train_dl))
        model.eval()
        loss_sum = 0.0
        with torch.no_grad():
            for xb, yb in ProgIter(valid_dl):
                xb, yb = xb.to(device), yb.to(device)
                preds = model(xb).reshape(-1)
                # Cast targets to float as in training: MSELoss raises on
                # integer targets, and this keeps the two losses comparable.
                loss_sum += loss_func(preds, yb.float()).item()
        val_loss = loss_sum / len(valid_dl)
        print('val loss {:.3f}'.format(val_loss))
        val_losses.append(val_loss)
        # Remember the weights of the best epoch so far.
        # NOTE(review): best_model_wts is tracked but never returned or saved
        # separately — the per-epoch torch.save below stores the *current*
        # model; confirm whether the best snapshot should be persisted too.
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
        else:
            print("Loss was better in the previous epoch")
        torch.save(model, './pytorch_trained_models/model_'+ str(epoch) + time.strftime("%Y%m%d-%H%M%S") +'.pht')
    return train_losses, val_losses
def model_preparation(base_model, device):
    """Build a pretrained backbone fine-tuned for single-score regression.

    Args:
        base_model: one of ``'densenet'``, ``'mobilenet'``, ``'alexnet'``;
            any other value falls back to VGG16.
        device: torch device the model is moved to.

    Returns:
        (model, criterion, optimizer): the model with a fresh regression
        head, an ``MSELoss`` criterion, and an Adam optimizer.
    """
    if base_model == 'densenet':
        model = models.densenet161(pretrained = True, progress = False)
        in_features = 2208  # densenet161 feature size entering the classifier
        print("\n The model is fine-tuned on DenseNet \n")
    elif base_model == 'mobilenet':
        model = models.mobilenet_v2(pretrained = True, progress = False)
        in_features = 1280  # mobilenet_v2 feature size entering the classifier
        print("\n The model is fine-tuned on MobileNet \n")
    elif base_model == 'alexnet':
        model = models.alexnet(pretrained = True, progress = False)
        in_features = 9216  # alexnet flattened feature size (256 * 6 * 6)
        print("\n The model is fine-tuned on AlexNet \n")
    else:
        model = models.vgg16(pretrained = True, progress = False)
        in_features = 25088  # vgg16 flattened feature size (512 * 7 * 7)
        print("\n The model is fine-tuned on VGG16 \n")
    # Replace the ImageNet classifier with a head that outputs one score.
    model.classifier = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=4096, bias=True),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(in_features=4096, out_features=4096, bias=True),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(in_features=4096, out_features=1, bias=True)
    )
    # Freeze the pretrained convolutional backbone; only the new head trains.
    for param in model.features.parameters():
        param.requires_grad = False
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    return model, criterion, optimizer
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a CLI boolean. ``type=bool`` is a known argparse trap: any
        non-empty string — including "False" — is truthy, so flags could
        never be turned off from the command line."""
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', type=str, help='base model',
                        default = "vgg16")
    parser.add_argument('--train_augmentation', type=_str2bool, help='train augmentation?',
                        default = False)
    parser.add_argument('--train_scores', type=str, help='csv file with scores for training',
                        default = 'scores/train_crop.csv')
    parser.add_argument('--test_scores', type=str, help='csv file with scores for validation',
                        default = 'scores/test_crop.csv')
    parser.add_argument('--batch_size', type=int, help='batch size',
                        default = 16)
    parser.add_argument('--epochs', type=int, help='number of epochs',
                        default = 25)
    parser.add_argument('--num_workers', type=int, help='number of workers',
                        default = 8)
    parser.add_argument('--pin_memory', type=_str2bool, help='pin_memory',
                        default = True)
    args = parser.parse_args()
    base_model = args.base_model
    train_aug = args.train_augmentation
    train_scores = args.train_scores
    test_scores = args.test_scores
    batch = args.batch_size
    epochs = args.epochs
    num_workers = args.num_workers
    pin_memory = args.pin_memory
    # Prefer the GPU when one is available.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print('The network is training on',device)
    # Train and validation dataloaders built by the project dataset module.
    traindata, testdata = pytorch_mebeauty_dataset.train_test_data(train_scores, test_scores, train_augmentation = train_aug,
                                                                   batch = batch, num_workers = num_workers, pin_memory = pin_memory)
    model, criterion, optimizer = model_preparation(base_model, device)
    train_loss, val_loss = fit(model, criterion, optimizer, traindata, testdata, device, epochs)
    #mebeauty_dataset.plot_training(train_loss, val_loss)