-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrial_nas.py
142 lines (103 loc) · 4.84 KB
/
trial_nas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os
import argparse
import yaml
import torch
# Default to the CPU and upgrade to the first CUDA GPU when torch can see one.
# NOTE(review): `device` is only reported here; nothing visible in this script
# passes it on — confirm whether Niche_YOLO_NAS chooses its own device.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
print(f"Using device: {device}")
# local imports
from models.nas import Niche_YOLO_NAS
from data.splitter.yolo import YOLO_Splitter
# constants
ROOT = os.path.dirname(__file__)
DIR_OUT = os.path.join(ROOT, "out")
DIR_MODEL = os.path.join(ROOT, "models")
DIR_DATA = os.path.join(ROOT, "data")
# Adjusted to match your directory structure
DIR_COW200 = os.path.join(DIR_DATA, "cow200", "yolov5" )
#DIR_COW200 = os.path.join(DIR_DATA, "2_light")
#DIR_COW200 = os.path.join(DIR_DATA, "1a_angle_t2s", "tv" )
# model configuration
BATCH = 16
EPOCHS = 2 #100
def _batch_size(n_train):
    """Return the batch size to train with for ``n_train`` training images.

    Training sets of at most 10, 50 and 100 images get batches of 2, 5 and 10
    respectively; anything larger uses the module-level default ``BATCH``.
    """
    if n_train <= 10:
        return 2
    if n_train <= 50:
        return 5
    if n_train <= 100:
        return 10
    return BATCH


def main(args):
    """Run one trial: split the cow dataset, train YOLO-NAS, evaluate, prune checkpoints.

    Args:
        args: parsed command-line namespace providing ``iter`` (iteration
            index), ``n_train`` (number of training images), ``yolo_base``
            (model identifier handed to ``Niche_YOLO_NAS``) and ``suffix``
            (tag prefixed to the split folder name).
    """
    i = args.iter
    n_train = args.n_train
    yolo_base = args.yolo_base
    suffix = args.suffix

    # Build the split folder name once: it is passed to the splitter as its
    # suffix, and every data path below is built under DIR_COW200/<split_name>.
    # NOTE(review): assumes YOLO_Splitter writes its output to
    # DIR_COW200/<suffix> — confirm against data.splitter.yolo.
    split_name = f"{suffix}_{yolo_base}_{n_train}_{i}"
    dir_split = os.path.join(DIR_COW200, split_name)

    # shuffle dataset and write the resulting split to disk
    splitter = YOLO_Splitter(DIR_COW200, classes=["cow"], suffix=split_name)
    splitter.shuffle_train_val(n_included=n_train, k=5)
    splitter.write_dataset()

    # log
    rule = "----------------------------------------------------------------------------"
    print(rule)
    print(f"n_train: {n_train}, yolo_base: {yolo_base}, i: {i}, {suffix}")
    print(rule)

    # Smaller training subsets get smaller batches.
    batch = _batch_size(n_train)

    # The task name is also the checkpoint folder name (see checkpoint_dir).
    name_task = f"n{n_train}_{yolo_base}_i{i}"
    print('DIR_current', dir_split)
    print(rule)

    # configure model
    yolo_nas = Niche_YOLO_NAS(
        path_model=yolo_base,
        dir_train=os.path.join(dir_split, "train"),
        dir_val=os.path.join(dir_split, "val"),
        dir_test=os.path.join(dir_split, "test"),
        name_task=name_task,
    )

    # ROOT/checkpoints/n<n_train>_<yolo_base>_i<i>
    # NOTE(review): presumably where Niche_YOLO_NAS saves checkpoints for
    # name_task — confirm.
    checkpoint_dir = os.path.join(ROOT, "checkpoints", name_task)

    # split files and dataset description written for this trial
    path_train_txt = os.path.join(dir_split, "train.txt")
    path_val_txt = os.path.join(dir_split, "val.txt")
    path_yaml = os.path.join(dir_split, "data.yaml")
    print(rule)
    print('path_train_txt', path_train_txt)
    print('path_val_txt', path_val_txt)
    print(rule)
    print('path_yaml', path_yaml)

    # train
    yolo_nas.train(path_yaml, path_train_txt, path_val_txt, batch, EPOCHS)
    # Drop the 'latest' checkpoint after training (intent: keep only best_ckpt.pth).
    yolo_nas.remove_ckpt(checkpoint_dir, 'latest')

    # Evaluate on the test split using this trial's data yaml, suffix, size
    # and iteration.
    # NOTE(review): argument mapping (config -> path_yaml, exp_name -> suffix,
    # n -> n_train, iteration -> i) is inferred from parameter names — confirm
    # against Niche_YOLO_NAS.evaluate_test_set.
    yolo_nas.evaluate_test_set(ROOT, yolo_base, path_yaml, suffix, n_train, i)

    # Remove the 'best' checkpoint as well once evaluation has run.
    # NOTE(review): presumably to free disk space — confirm this is intended.
    yolo_nas.remove_ckpt(checkpoint_dir, 'best')
if __name__ == "__main__":
# parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--iter", type=int, help="iteration number")
parser.add_argument("--n_train", type=int,
help="number of images in training set")
parser.add_argument("--yolo_base", type=str,
help="e.g., yolo8n, yolo8m, yolo8x")
parser.add_argument("--suffix", type=str, help="suffix for folder name")
#parser.add_argument("--dataset", type=str, help="e.g., 1a_angle_t2s, 1b_angle_s2t, 2_light, 3_breed, 4_all")
args = parser.parse_args()
main(args)