-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain_rads.py
39 lines (30 loc) · 1.04 KB
/
main_rads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# import necessary modules
import argparse
import json
import utils
from data import CADRAData
from model import CADResNet
from torch.utils.data import DataLoader
def test(config, dset="IDR_CADRADS"):
    """Run inference on the test split and print the resulting confusion matrix.

    Args:
        config: Parsed configuration dict. ``config["MODEL"]`` is passed to
            ``CADResNet`` and ``config["DATA"]`` to ``CADRAData``;
            ``config["DATA"]["batch_size"]`` sets the loader batch size.
        dset: Dataset identifier forwarded to ``CADRAData``. Defaults to
            ``"IDR_CADRADS"``, the value that was previously hard-coded.

    Returns:
        The ``(conf_matrix, y_pred, y_true)`` tuple produced by
        ``model.infer`` so callers can compute further metrics instead of
        only seeing the printed confusion matrix.
    """
    # NOTE(review): the model is constructed fresh here; confirm that
    # CADResNet loads trained weights from config["MODEL"].
    model = CADResNet(config["MODEL"])
    data_test = CADRAData(config["DATA"], mode="test", dset=dset)
    # shuffle=False keeps a fixed sample order across runs; num_workers=0
    # makes DataLoader load batches in the main process.
    loader_test = DataLoader(
        data_test,
        batch_size=config["DATA"]["batch_size"],
        shuffle=False,
        num_workers=0,
    )
    conf_matrix, y_pred, y_true = model.infer(loader_test)
    print(conf_matrix)
    return conf_matrix, y_pred, y_true
if __name__ == "__main__":
    # Parse the path to the JSON configuration file from the command line.
    parser = argparse.ArgumentParser()
    # A plain positional argument is always required, so argparse never used
    # the former ``default='None'``; it has been dropped.
    parser.add_argument(
        'config',
        metavar='config_json_file',
        help='The configuration file for training/testing the FanCNN')
    args = parser.parse_args()

    # Close the file deterministically; JSON text is UTF-8 (RFC 8259), so
    # do not depend on the platform's default encoding.
    with open(args.config, encoding="utf-8") as config_file:
        config = json.load(config_file)
    config_g = config["GENERAL"]
    # Echo the full configuration so it appears in the run's output.
    print(json.dumps(config, sort_keys=True, indent=4))

    # One-time setup through the project's utils module, using the
    # "seed" and "gpu" entries of the GENERAL section.
    utils.seed_everything(config_g["seed"])
    utils.set_gpu(config_g["gpu"])

    mode = config_g["mode"]
    if mode == "test":
        test(config)
    else:
        # Any other mode used to fall through and exit "successfully" without
        # doing anything; report it instead of failing silently.
        raise SystemExit(
            f"Unsupported mode {mode!r} in GENERAL.mode: "
            "this script only supports 'test'."
        )