Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch>1.0 and cpu compatible #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
torch>1.0 and cpu compatible
  • Loading branch information
Bart Michiels committed May 14, 2020
commit 6957951e119edd9028f70d3bd79abbf771ecb400
26 changes: 21 additions & 5 deletions tools/transfer.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from covidaid import CovidAID, CheXNet
import torch
import argparse
import re

parser = argparse.ArgumentParser()
parser.add_argument("--combine_pneumonia", action='store_true', default=False)
@@ -19,10 +20,22 @@
model = CovidAID(combine_pneumonia=args.combine_pneumonia)

def load_weights(checkpoint_pth, state_dict=True):
    """Load a checkpoint, remapping legacy DenseNet key names when needed.

    Parameters
    ----------
    checkpoint_pth : str or file-like
        Path (or buffer) handed to ``torch.load``.
    state_dict : bool, optional
        If True (default), return the checkpoint's ``'state_dict'`` with
        keys renamed for torch>=1.0 / torchvision>=0.4 compatibility;
        otherwise return the raw loaded checkpoint object.
    """
    # Fall back to CPU deserialization when no GPU is present, so
    # checkpoints saved on CUDA devices remain loadable on CPU-only hosts.
    if torch.cuda.is_available():
        checkpoint = torch.load(checkpoint_pth)
    else:
        checkpoint = torch.load(checkpoint_pth, map_location=torch.device('cpu'))

    if state_dict:
        # Old torchvision DenseNet checkpoints used keys like
        # "...denselayer1.norm.1.weight"; newer releases expect
        # "...denselayer1.norm1.weight".  Rewrite matching keys in place
        # (same remapping torchvision applies in densenet's load helper).
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))'
            r'\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        # Use a distinct local name so we don't shadow the `state_dict` flag.
        weights = checkpoint['state_dict']
        # Snapshot the keys with list(): the dict is mutated while iterating.
        for key in list(weights.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                weights[new_key] = weights[key]
                del weights[key]
        return checkpoint['state_dict']
    else:
        # BUG FIX: previously returned the module-level `model` (a CovidAID
        # instance) instead of the checkpoint that was just loaded.
        return checkpoint

@@ -43,7 +56,9 @@ def get_top_keys(model, depth=0):
t_keys = {'module.' + k for k in template.keys()}

assert len(c_keys.difference(t_keys)) == 0
assert len(t_keys.difference(c_keys)) == 0

#This will error for "*.num_batches_tracked" layers
#assert len(t_keys.difference(c_keys)) == 0


# Transfer the feature weights
@@ -59,7 +74,8 @@ def get_top_keys(model, depth=0):
else:
# print (type(template[k]), template[k].size())
# print (type(chexnet_model[chex_key]), chexnet_model[chex_key].size())
assert chexnet_model[chex_key].size() == template[k].size()
template[k] = chexnet_model[chex_key]
if not chex_key.endswith('.num_batches_tracked'):
assert chexnet_model[chex_key].size() == template[k].size()
template[k] = chexnet_model[chex_key]

torch.save(template, covidaid_model_trained_checkpoint)