-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_submission.py
54 lines (38 loc) · 1.61 KB
/
create_submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import tensorflow as tf
# =================================================================================================
def dataset_input_fn(filenames):
    """Build a one-pass TFRecord pipeline and return its next-element tensors.

    Each serialized example is parsed for two int64 scalar features, ``pid``
    and ``category`` (each defaulting to 0 when absent), cast to int32 and
    wrapped in a length-1 list. Records are batched one at a time and the
    dataset is traversed exactly once.

    Args:
        filenames: Passed straight to ``tf.data.TFRecordDataset``.

    Returns:
        A ``(pid, category)`` pair of tensors produced by a one-shot
        iterator's ``get_next()``. Evaluating them after the last record
        raises ``tf.errors.OutOfRangeError``.
    """
    def _decode(serialized):
        # Two-feature schema: int64 scalars, default 0 when missing.
        schema = {
            name: tf.FixedLenFeature((), tf.int64, default_value=0)
            for name in ('pid', 'category')
        }
        example = tf.parse_single_example(serialized, schema)
        # The list wrapping gives each element a length-1 leading dimension.
        pid_vec = [tf.cast(example["pid"], tf.int32)]
        category_vec = [tf.cast(example["category"], tf.int32)]
        return pid_vec, category_vec

    records = tf.data.TFRecordDataset(filenames)
    pipeline = records.map(_decode).batch(1).repeat(1)
    pids, categories = pipeline.make_one_shot_iterator().get_next()
    return pids, categories
# =================================================================================================
# Build the input pipeline once; each sess.run() on these tensors yields
# the next parsed record.
next_example, next_label = dataset_input_fn('data/prediction.tfrecords')

# A context manager guarantees the CSV is flushed and closed even if the
# session raises. Text mode with encoding='UTF-8' and newline='' writes '\n'
# verbatim, producing the same bytes as the previous per-fragment
# binary-mode writes.
with open('submission.csv', mode='w', encoding='UTF-8', newline='') as submission:
    submission.write('_id,category_id\n')
    with tf.Session() as sess:
        while True:
            try:
                pid, cat = sess.run([next_example, next_label])
            except tf.errors.OutOfRangeError:
                # The one-shot iterator is exhausted: every record was written.
                print("run finished")
                break
            # Batch size is 1 and each element is a length-1 vector, so the
            # single value sits at [0][0].
            submission.write(str(pid[0][0]) + ',' + str(cat[0][0]) + '\n')