modelcut.py

import os, torch, glob
from lib.models.exot import build_exotst, build_exotst1

# cfg is assumed to be the experiment configuration object defined elsewhere
# in the project before this module-level call.
net = build_exotst(cfg)


def save_checkpoint(self):
    """Saves a checkpoint of the network and other variables."""
    net = self.actor.net
    actor_type = type(self.actor).__name__
    net_type = type(net).__name__
    state = {
        'epoch': self.epoch,
        'actor_type': actor_type,
        'net_type': net_type,
        'net': net.state_dict(),
        'net_info': getattr(net, 'info', None),
        'constructor': getattr(net, 'constructor', None),
        'optimizer': self.optimizer.state_dict(),
        'stats': self.stats,
        'settings': self.settings
    }

    directory = '{}/{}'.format(self._checkpoint_dir, self.settings.project_path)
    print(directory)
    if not os.path.exists(directory):
        print("directory doesn't exist. creating...")
        os.makedirs(directory)

    # First save to a temporary file
    tmp_file_path = '{}/{}_ep{:04d}.tmp'.format(directory, net_type, self.epoch)
    torch.save(state, tmp_file_path)

    file_path = '{}/{}_ep{:04d}.pth.tar'.format(directory, net_type, self.epoch)

    # Now rename to the actual checkpoint name. On POSIX, os.rename is atomic
    # when both paths are on the same filesystem, so a half-written file never
    # replaces a valid checkpoint.
    os.rename(tmp_file_path, file_path)
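
# Aside (a sketch, not from the original file): os.replace() would be a
# drop-in alternative to os.rename() above. It has the same
# atomic-on-same-filesystem behaviour, but it also silently overwrites an
# existing destination on every platform, so re-saving the same epoch does
# not fail on Windows:
#
#     os.replace(tmp_file_path, file_path)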

def load_state_dict(self, checkpoint=None, distill=False):
    """Loads a network checkpoint file.

    Can be called in three different ways:
        load_checkpoint():
            Loads the latest epoch from the workspace. Use this to continue training.
        load_checkpoint(epoch_num):
            Loads the network at the given epoch number (int).
        load_checkpoint(path_to_checkpoint):
            Loads the file from the given absolute path (str).
    """
    net = self.actor.net
    net_type = type(net).__name__

    if isinstance(checkpoint, str):
        # checkpoint is a path: either a directory (load the latest epoch in it)
        # or a direct path to a checkpoint file
        if os.path.isdir(checkpoint):
            checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
            if checkpoint_list:
                checkpoint_path = checkpoint_list[-1]
            else:
                raise Exception('No checkpoint found')
        else:
            checkpoint_path = os.path.expanduser(checkpoint)
    else:
        raise TypeError('checkpoint must be a path string')

    # Load network
    print("Loading pretrained model from ", checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')

    assert net_type == checkpoint_dict['net_type'], 'Network is not of correct type.'

    missing_k, unexpected_k = net.load_state_dict(checkpoint_dict["net"], strict=False)
    print("previous checkpoint is loaded.")
    print("missing keys: ", missing_k)
    print("unexpected keys:", unexpected_k)
    return True
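
# Usage sketch (a minimal illustration, not part of the original file): both
# functions are written as methods of a trainer object, so they would normally
# be bound to a class providing `actor`, `optimizer`, `stats`, `settings` and
# `_checkpoint_dir`; the trainer construction below is an assumption.
#
#     trainer = Trainer(actor, loaders, optimizer, settings)     # hypothetical trainer class
#     trainer.save_checkpoint()                                  # writes <net_type>_ep{epoch:04d}.pth.tar
#     trainer.load_state_dict('/path/to/checkpoint_dir')         # latest *_ep*.pth.tar in a directory
#     trainer.load_state_dict('/path/to/net_ep0040.pth.tar')     # or one specific checkpoint file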