# train.py -- forked from YanchaoYang/FDA
import torch.nn.functional as F
import numpy as np
from options.train_options import TrainOptions
from utils.timer import Timer
import os
from data import CreateSrcDataLoader
from data import CreateTrgDataLoader
from model import CreateModel
#import tensorboardX
import torch.backends.cudnn as cudnn
import torch
from torch.autograd import Variable
from utils import FDA_source_to_target
import scipy.io as sio
# Per-channel mean subtracted from every input batch, stored as a float32
# tensor of shape (1, 3, 1, 1). The active triple is the "blender" set; the
# alternative "target" set is (59.11354771, 65.17001789, 46.51190912).
IMG_MEAN = torch.from_numpy(
    np.array((115.11354771, 86.17001789, 76.51190912), dtype=np.float32)
).reshape(1, 3, 1, 1)

# Weights handed to the model as ``weight=`` for the segmentation loss:
# two entries, both 1.0.
CS_weights = torch.from_numpy(np.array((1.0, 1.0), dtype=np.float32))
def _next_batch(loader, loader_iter):
    """Return ``(batch, iterator)``, restarting *loader* once it is exhausted.

    A fresh iterator is created when the current one raises ``StopIteration``,
    so training can run for more steps than the loader has batches. If the
    loader is empty the second ``next`` still raises ``StopIteration``.
    """
    try:
        return next(loader_iter), loader_iter
    except StopIteration:
        loader_iter = iter(loader)
        return next(loader_iter), loader_iter


def main():
    """Train the segmentation model on FDA-translated source images.

    Options come from ``TrainOptions``. Each step:

    1. draws one source and one target batch,
    2. maps the source batch with ``FDA_source_to_target(..., L=args.LB)``,
    3. subtracts ``IMG_MEAN`` from both batches,
    4. runs the model on both and optimises
       ``loss_seg_src + args.entW * loss_ent_trg`` (the entropy term is only
       switched on once ``i > args.switch2entropy``).

    The target segmentation loss is only accumulated for logging; it is not
    part of the optimised objective.

    Side effects: sets ``CUDA_VISIBLE_DEVICES``, creates
    ``args.snapshot_dir/logs``, writes ``.pth`` snapshots every
    ``args.save_pred_every`` steps, and every ``args.print_freq`` steps prints
    a status line and writes ``args.tempdata`` / ``args.matname`` ``.mat``
    files.
    """
    opt = TrainOptions()
    args = opt.initialize()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU

    # exist_ok=True also creates a missing 'logs' folder inside an already
    # existing snapshot directory and avoids the exists()/makedirs() race.
    os.makedirs(os.path.join(args.snapshot_dir, 'logs'), exist_ok=True)
    opt.print_options(args)

    sourceloader, targetloader = CreateSrcDataLoader(args), CreateTrgDataLoader(args)
    sourceloader_iter, targetloader_iter = iter(sourceloader), iter(targetloader)

    model, optimizer = CreateModel(args)

    start_iter = 0
    cudnn.enabled = True
    cudnn.benchmark = True

    model.train()
    model.cuda()

    # Running sums of the source ("train") and target ("val") segmentation
    # losses over the current print_freq window, plus their history.
    loss_train = 0.0
    loss_val = 0.0
    loss_train_list = []
    loss_val_list = []

    class_weights = CS_weights.cuda()

    iter_timer = Timer()
    iter_timer.tic()
    for i in range(start_iter, args.num_steps):
        model.adjust_learning_rate(args, optimizer, i)
        optimizer.zero_grad()

        (src_img, src_lbl, _, _), sourceloader_iter = _next_batch(sourceloader, sourceloader_iter)
        (trg_img, trg_lbl, _, _), targetloader_iter = _next_batch(targetloader, targetloader_iter)

        # 1. source to target style; the target batch is used unchanged
        src_in_trg = FDA_source_to_target(src_img, trg_img, L=args.LB)

        # 2. subtract mean -- IMG_MEAN is (1, 3, 1, 1) and broadcasts over the
        #    batch and spatial axes, so no per-shape copy has to be cached.
        src_img = src_in_trg - IMG_MEAN
        trg_img = trg_img - IMG_MEAN

        # Source pass: the model exposes its losses as attributes after the call.
        src_img, src_lbl = src_img.cuda(), src_lbl.long().cuda()
        model(src_img, lbl=src_lbl, weight=class_weights, ita=args.ita)
        loss_seg_src = model.loss_seg

        # Target pass: only the entropy loss is back-propagated; the
        # segmentation loss is kept for logging.
        trg_img, trg_lbl = trg_img.cuda(), trg_lbl.long().cuda()
        model(trg_img, lbl=trg_lbl, weight=class_weights, ita=args.ita)
        loss_seg_trg = model.loss_seg
        loss_ent_trg = model.loss_ent

        triger_ent = 1.0 if i > args.switch2entropy else 0.0
        loss_all = loss_seg_src + triger_ent * args.entW * loss_ent_trg
        loss_all.backward()
        optimizer.step()

        loss_train += loss_seg_src.item()
        loss_val += loss_seg_trg.item()

        if (i + 1) % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                os.path.join(args.snapshot_dir, '%s_' % (args.source) + str(i + 1) + '.pth'),
            )

        if (i + 1) % args.print_freq == 0:
            iter_timer.toc(average=False)
            print('[it %d][src seg loss %.4f][trg seg loss %.4f][lr %.4f][%.2fs]' % (
                i + 1,
                loss_seg_src.item(),
                loss_seg_trg.item(),
                optimizer.param_groups[0]['lr'] * 10000,
                iter_timer.diff,
            ))

            # Dump the most recent (mean-subtracted) input batches.
            sio.savemat(args.tempdata, {'src_img': src_img.cpu().numpy(), 'trg_img': trg_img.cpu().numpy()})

            # Average the window, append to the history and persist the curves.
            loss_train /= args.print_freq
            loss_val /= args.print_freq
            loss_train_list.append(loss_train)
            loss_val_list.append(loss_val)
            sio.savemat(args.matname, {'loss_train': loss_train_list, 'loss_val': loss_val_list})
            loss_train = 0.0
            loss_val = 0.0

            # NOTE(review): the stop check and timer restart are kept inside
            # the logging branch so tic() pairs with the toc() above; confirm
            # against the intended layout.
            if i + 1 > args.num_steps_stop:
                print('finish training')
                break
            iter_timer.tic()


if __name__ == '__main__':
    main()