-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
77 lines (67 loc) · 2.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Entry-point script: parses --cfg plus command-line config overrides, builds
# the classification model named by config.MODEL.NAME, and calls train() on it.
# (Started from the PyCharm sample template; in PyCharm, Shift+F10 runs it.)
import argparse
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import models
from cls_function import train
# from utils import get_optimizer
from models.mdeq import MDEQClsNet
from config import config
from models.mdeq import get_cls_net
from config import update_config
# In PyCharm, the green gutter button next to the `if __name__ == '__main__':` guard runs this script.
def parse_args():
    """Read the training command line and apply it to the shared ``config``.

    Options registered (in this order):
      --cfg        experiment config file name (required, str)
      --modelDir, --logDir, --dataDir, --testModel
                   optional strings, each defaulting to ''
      --percent    float, default 1.0
      opts         every remaining token (argparse.REMAINDER), intended as
                   config overrides

    After parsing, the namespace is handed to ``update_config(config, ...)``,
    which mutates the module-level ``config`` object.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description='Train classification network')
    parser.add_argument('--cfg', type=str, required=True,
                        help='experiment configure file name')

    # Four optional string options sharing the same type and default;
    # registered in a fixed order so --help output is unchanged.
    string_options = (
        ('--modelDir', 'model directory'),
        ('--logDir', 'log directory'),
        ('--dataDir', 'data directory'),
        ('--testModel', 'testModel'),
    )
    for flag, description in string_options:
        parser.add_argument(flag, type=str, default='', help=description)

    parser.add_argument('--percent', type=float, default=1.0,
                        help='percentage of training data to use')
    parser.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                        help="Modify config options using the command-line")

    namespace = parser.parse_args()
    update_config(config, namespace)
    return namespace
def main():
    """Build the configured classification model and run a training smoke test.

    Steps: parse the command line (which updates the shared ``config`` via
    ``parse_args``), build ``models.<config.MODEL.NAME>.get_cls_net(config)``,
    print the model's total parameter count, create its optimizer and a
    cosine-annealing LR scheduler, then call ``train`` on a random batch.
    """
    # Imported here because the module-level import is commented out; without
    # it the get_optimizer call below raises NameError.
    import operator
    from utils import get_optimizer

    # Random stand-in batch. NOTE(review): assumes the model takes (N, C, H, W)
    # images such as 32x32 RGB -- confirm against the experiment config.
    dummy_input = torch.randn(64, 3, 32, 32)

    # parse_args() calls update_config(), so config.MODEL.NAME must be read
    # afterwards to show the model that is actually built.
    parse_args()
    print(config.MODEL.NAME)

    # Resolve models.<MODEL.NAME>.get_cls_net by attribute lookup instead of
    # eval(), so a config string cannot execute arbitrary code.
    build_model = operator.attrgetter(config.MODEL.NAME + '.get_cls_net')(models)
    model = build_model(config)
    print(sum(torch.numel(parameter) for parameter in model.parameters()))

    optimizer = get_optimizer(config, model)

    # No real data loader yet, so this placeholder is empty. CosineAnnealingLR's
    # schedule divides by T_max and raises ZeroDivisionError once stepped with
    # T_max == 0, so clamp it to at least 1.
    train_loader = []
    total_steps = max(1, len(train_loader) * config.TRAIN.END_EPOCH)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, total_steps, eta_min=1e-6)

    train(config, dummy_input, model, lr_scheduler)
    print(" ")


if __name__ == '__main__':
    main()