-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval.py
93 lines (69 loc) · 2.47 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
'''
Evaluate the model

usage: eval.py [-h] [-m MODEL] [-d DATA_DIR] [--batch BATCH]

optional arguments:
  -h, --help            show this help message and exit
  -m MODEL, --model MODEL
                        Path to trained model.
  -d DATA_DIR, --data_dir DATA_DIR
                        Image data folder.
  --batch BATCH         Batch size.
'''
import argparse
import torch
from torchvision import transforms
from utils import load_data
def create_dataloader(data_dir, batch_size):
    '''
    Build the DataLoader that serves the 'test' split under data_dir.

    Args:
        data_dir: Image data folder, forwarded to load_data.
        batch_size: Number of samples per batch.

    Returns:
        A torch.utils.data.DataLoader over the 'test' dataset, shuffled,
        using 4 worker processes.
    '''
    # Per-channel statistics used to normalise the image tensors
    # (these look like the usual ImageNet values -- confirm against training).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    pipeline = transforms.Compose([
        transforms.Resize((120, 100)),
        transforms.ToTensor(),
        normalize,
    ])

    # load_data returns a pair; only the second element (the datasets,
    # indexable by split name) is needed here.
    _unused, datasets = load_data(data_dir, transform=pipeline)
    test_split = datasets['test']

    return torch.utils.data.DataLoader(
        test_split,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
def evaluate(model_path, data_dir, batch_size):
    '''
    Evaluate a saved model on the test split and print its accuracy.

    The object stored at model_path is loaded with torch.load and used
    directly as the model (it is moved to the device, put in eval mode and
    called on each batch), so the file must contain a whole pickled model
    rather than a state dict.

    Args:
        model_path: Path to the trained model file.
        data_dir: Image data folder, forwarded to create_dataloader.
        batch_size: Number of samples per batch.

    Returns:
        None. The accuracy (correct samples / dataset size) is printed.
    '''
    # Detect if we have a GPU available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Create dataloader
    dataloader = create_dataloader(data_dir, batch_size)

    # map_location remaps the stored tensors onto the detected device, so a
    # checkpoint that was saved on a GPU still loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()

    running_corrects = 0
    # Inference only: no gradients are needed for any batch.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, predict = torch.max(outputs, 1)

            # Compare sample by sample with torch.equal so that a sample
            # counts as correct only when its whole prediction matches its
            # whole label, whatever the per-sample shape is.
            running_corrects += sum(
                torch.equal(guess, truth)
                for guess, truth in zip(predict, labels)
            )

    data_len = len(dataloader.dataset)
    print("Test accuracy: ", float(running_corrects) / data_len)
if __name__ == "__main__":
    # Command-line entry point: read the model path, data folder and batch
    # size from the arguments, then run the evaluation.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-m', '--model',
        type=str,
        default='model.pt',
        help='Path to trained model.',
    )
    cli.add_argument(
        '-d', '--data_dir',
        type=str,
        default='images/char-4-epoch-6',
        help='Image data folder.',
    )
    cli.add_argument(
        '--batch',
        type=int,
        default=16,
        help='Batch size.',
    )
    options = cli.parse_args()
    evaluate(options.model, options.data_dir, options.batch)