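"""Train and evaluate a PoseNet-style camera-relocalization network.

Builds a MobileNet-backed PoseNet via net_builder.build_posenet, restores
weights from either a previous run or a pretrained MobileNet checkpoint,
trains with the combined position/orientation loss from
net_builder.add_pose_loss, and periodically reports the median position error
(metres) and orientation error (degrees) on the test split. The dataset and
checkpoint paths below are examples for a local setup (here, a scene in the
Cambridge Landmarks layout) and must be adapted.
"""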
import math
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.training import training_util
from tqdm import tqdm

from data import get_data, gen_data_batch
from net_builder import build_posenet, add_pose_loss
from posenet import GoogLeNet as PoseNet  # used only by the GoogLeNet variant commented out below
# -1 means use all available images.
max_num_train_images = -1
max_num_test_images = -1
batch_size = 48
max_iterations = 30000
# max_iterations = 1  # handy override for a quick smoke test
display_interval = 20
save_interval = 1000
test_interval = 1000

# Set these paths to your dataset and pretrained-checkpoint locations.
data_dir = '/home/user/Datasets/camera_relocalization/KingsCollege'
train_data_file = 'dataset_train.txt'
test_data_file = 'dataset_test.txt'
model_path = '/home/user/Datasets/tensorflow/models/mobilenet/mobilenet_v1_1.0_224_2017_06_14/mobilenet_v1_1.0_224.ckpt'
checkpoint_dir = 'checkpoint'
output_checkpoint_dir = 'checkpoint'
checkpoint_file = 'posenet_mobilenet.ckpt'

train = True
test = True
test_first = True  # run a validation pass before the first training step
debug = False

def should_load(name):
    """Return False for variables that should not be restored from a
    pretrained checkpoint: the pose-regression heads and the original
    classification Logits/Predictions layers."""
    if name.startswith('cls') and '_fc_pose_' in name:
        return False
    if 'Logits' in name or 'Predictions' in name:
        return False
    return True

def load_data(data_dir, data_file, max_num_images=-1):
    """Load a data source and report how many batches it yields."""
    data_path = os.path.join(data_dir, data_file)
    if max_num_images >= 0:
        data_source = get_data(data_path, data_dir, max_num_images)
    else:
        data_source = get_data(data_path, data_dir)
    num_images = len(data_source.images)
    num_batches = (num_images + batch_size - 1) // batch_size
    print('num_images', num_images, 'batch_size', batch_size, 'num_batches',
          num_batches)
    return data_source
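
# Note: get_data (from data.py) is assumed to return an object exposing
# .images and .poses, where each pose row packs position followed by a wpqr
# quaternion, i.e. [x, y, z, w, p, q, r]; the [:, 0:3] and [:, 3:7] slices in
# the validation loop below depend on that layout.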

def main():
    images = tf.placeholder(tf.float32, [batch_size, 224, 224, 3])
    poses_x = tf.placeholder(tf.float32, [batch_size, 3])
    poses_q = tf.placeholder(tf.float32, [batch_size, 4])

    print('build_posenet')
    net = build_posenet(images, 'mobilenet')
    # net = PoseNet({'data': images})
    loss = add_pose_loss(net, poses_x, poses_q)
    print('loss', loss)

    global_step = training_util.create_global_step()
    opt = tf.train.AdamOptimizer(
        learning_rate=0.0001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        use_locking=False,
        name='Adam').minimize(loss, global_step=global_step)

    # Set GPU options.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6833)
    init = tf.global_variables_initializer()

    # Final pose-regression outputs: position (xyz) and orientation (wpqr).
    p3_x = net['cls3_fc_pose_xyz']
    p3_q = net['cls3_fc_pose_wpqr']

    variables_to_restore = tf.global_variables()
    if debug:
        print('\n variables_to_restore', variables_to_restore)
    variables_to_save = tf.global_variables()
    if debug:
        print('\n variables_to_save', variables_to_save)
    restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(variables_to_save)

    output_checkpoint = os.path.join(output_checkpoint_dir, checkpoint_file)
    # Resume from the latest local checkpoint if one exists; otherwise fall
    # back to the pretrained model.
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint is None:
        checkpoint = model_path
    print('checkpoint', checkpoint)

    if train:
        train_data_source = load_data(data_dir, train_data_file,
                                      max_num_train_images)
    if test:
        test_data_source = load_data(data_dir, test_data_file,
                                     max_num_test_images)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        sess.run(init)
        # Restore model weights from a previously saved model. If the full
        # restore fails (e.g. when starting from the ImageNet MobileNet
        # checkpoint, which has no pose heads), retry with only the variables
        # that should_load() accepts.
        try:
            restorer.restore(sess, checkpoint)
        except Exception:
            print('Failed to restore all variables from:', checkpoint)
            variables_to_restore = [
                x for x in tf.global_variables() if should_load(x.name)
            ]
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, checkpoint)
        print('Model restored from file: %s' % checkpoint)
        if train:
            train_data_batch_generator = gen_data_batch(train_data_source,
                                                        batch_size)
        if test:
            test_data_batch_generator = gen_data_batch(test_data_source,
                                                       batch_size)
            num_test_images = len(test_data_source.images)
            num_test_batches = (num_test_images + batch_size - 1) // batch_size

        last_iteration = -1
        for i in range(max_iterations):
            if test and (i > 0 or test_first) and i % test_interval == 0:
                print('Validating')
                results = np.zeros((num_test_images, 2))
                for j in tqdm(range(num_test_batches)):
                    np_image, np_poses_x, np_poses_q = next(
                        test_data_batch_generator)
                    if debug:
                        print('np_image', np_image.shape, np_poses_x.shape,
                              np_poses_q.shape)
                    feed = {images: np_image}
                    predicted_x, predicted_q = sess.run([p3_x, p3_q],
                                                        feed_dict=feed)
                    batch_start = batch_size * j
                    batch_end = min(batch_start + batch_size, num_test_images)
                    pose_q = np.asarray(
                        test_data_source.poses[batch_start:batch_end])[:, 3:7]
                    pose_x = np.asarray(
                        test_data_source.poses[batch_start:batch_end])[:, 0:3]
                    predicted_q = predicted_q[:batch_end - batch_start]
                    predicted_x = predicted_x[:batch_end - batch_start]
                    # Per-sample errors: position error is the Euclidean
                    # distance in metres; orientation error is the angle
                    # between unit quaternions,
                    # theta = 2 * arccos(|<q, q_hat>|), in degrees.
                    pose_q /= np.linalg.norm(pose_q, axis=1, keepdims=True)
                    predicted_q /= np.linalg.norm(
                        predicted_q, axis=1, keepdims=True)
                    d = np.abs(np.sum(pose_q * predicted_q, axis=1))
                    theta = 2 * np.arccos(np.clip(d, -1.0, 1.0)) * 180 / math.pi
                    error_x = np.linalg.norm(pose_x - predicted_x, axis=1)
                    results[batch_start:batch_end, :] = np.column_stack(
                        (error_x, theta))
                median_result = np.median(results, axis=0)
                print('Median error', median_result[0], 'm and',
                      median_result[1], 'degrees.')
            if train:
                np_images, np_poses_x, np_poses_q = next(
                    train_data_batch_generator)
                feed = {
                    images: np_images,
                    poses_x: np_poses_x,
                    poses_q: np_poses_q
                }
                sess.run(opt, feed_dict=feed)
                np_loss = sess.run(loss, feed_dict=feed)
                if i > 0 and i % display_interval == 0:
                    print('Iteration: ' + str(i) + '\n\t' + 'Loss is: ' +
                          str(np_loss))
                if i > 0 and i % save_interval == 0:
                    saver.save(sess, output_checkpoint,
                               global_step=global_step)
                    print('Intermediate file saved at: ' + output_checkpoint)
            last_iteration = i

        # Save a final checkpoint if the last iteration did not land on a
        # save boundary.
        if last_iteration > 0 and last_iteration % save_interval != 0:
            saver.save(sess, output_checkpoint, global_step=global_step)
            print('Intermediate file saved at: ' + output_checkpoint)

if __name__ == '__main__':
    main()
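
# Example invocation, assuming a TF 1.x environment and that the dataset and
# checkpoint paths configured at the top of this file exist locally:
#
#   python run_posenet.py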