# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Example for exporting a MNIST classifier in SavedModel v2.0 format.
The module has as a single signature, accepting a batch of images with shape
[None, 28, 28, 1] and returning a prediction vector.
In this example, we are loading the MNIST Dataset from TFDS and training a
simple digit classifier.
"""
import argparse
import logging
import sys

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

FLAGS = None

class MNIST(tf.keras.models.Model):
  """Model representing an MNIST classifier."""

  def __init__(self, output_activation="softmax"):
    super(MNIST, self).__init__()
    self.layer_1 = tf.keras.layers.Dense(64)
    self.layer_2 = tf.keras.layers.Dense(10, activation=output_activation)

  def call(self, inputs):
    # Cast the uint8 image pixels to float32 before doing any arithmetic.
    casted = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.float32))(inputs)
    flatten = tf.keras.layers.Flatten()(casted)

    def normalize_fn(x):
      # Scale pixel values into [0, 1], using the maximum value of the first
      # example in the batch (255 for standard uint8 MNIST images).
      return x / tf.reduce_max(tf.gather(x, 0))

    normalize = tf.keras.layers.Lambda(normalize_fn)(flatten)
    x = self.layer_1(normalize)
    output = self.layer_2(x)
    return output
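
# For illustration only (not part of the original script), the forward pass
# maps a batch of images to a batch of class probabilities:
#
#   model = MNIST()
#   images = tf.random.uniform([2, 28, 28, 1], maxval=256, dtype=tf.int32)
#   probs = model(images)  # Shape [2, 10]; rows sum to ~1 under softmax.
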
def train_step(model, loss_fn, optimizer_fn, metric, image, label):
  """Performs one training step for the model.

  Args:
    model: Keras model to train.
    loss_fn: Loss function to use.
    optimizer_fn: Optimizer to use.
    metric: keras.metrics.Metric used to accumulate the training loss.
    image: Tensor of training images of shape [batch_size, 28, 28, 1].
    label: Tensor of class labels of shape [batch_size].
  """
  with tf.GradientTape() as tape:
    preds = model(image)
    label_onehot = tf.one_hot(label, 10)
    loss_ = loss_fn(label_onehot, preds)
  grads = tape.gradient(loss_, model.trainable_variables)
  optimizer_fn.apply_gradients(zip(grads, model.trainable_variables))
  metric(loss_)
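
# A minimal sketch of a single training step on random data (all values
# below are hypothetical and only for illustration):
#
#   model = MNIST()
#   loss_fn = tf.keras.losses.CategoricalCrossentropy()
#   optimizer_fn = tf.keras.optimizers.Adam(learning_rate=1e-3)
#   metric = tf.keras.metrics.Mean()
#   images = tf.random.uniform([32, 28, 28, 1], maxval=256, dtype=tf.int32)
#   labels = tf.random.uniform([32], maxval=10, dtype=tf.int32)
#   train_step(model, loss_fn, optimizer_fn, metric, images, labels)
#   print(metric.result().numpy())  # Mean loss accumulated so far.
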
def train_and_export(export_path,
                     buffer_size=1000,
                     batch_size=32,
                     learning_rate=1e-3,
                     epoch=10,
                     dataset=None):
  """Trains and exports the model as a SavedModel 2.0.

  Args:
    export_path: (str) Path to export the trained model to.
    buffer_size: (int) Size of buffer to use while shuffling.
    batch_size: (int) Size of each training batch.
    learning_rate: (float) Learning rate to use for the optimizer.
    epoch: (int) Number of epochs to train for.
    dataset: (tf.data.Dataset) Dataset object. Defaults to the TFDS MNIST
      training split.
  """
  model = MNIST()
  if not dataset:
    dataset = tfds.load(
        "mnist", split="train", batch_size=batch_size,
        shuffle_files=True).shuffle(
            buffer_size, reshuffle_each_iteration=True)
  optimizer_fn = tf.keras.optimizers.Adam(learning_rate=learning_rate)
  # The model ends in a softmax layer, so its outputs are probabilities,
  # not logits.
  loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
  metric = tf.keras.metrics.Mean()
  model.compile(optimizer_fn, loss=loss_fn)
  # Training loop.
  for epoch_num in range(epoch):
    for step, data in enumerate(dataset):
      train_step(model, loss_fn, optimizer_fn, metric, data["image"],
                 data["label"])
      print("\rEpoch: #{}\tStep: #{}\tLoss: {}".format(
          epoch_num, step, metric.result().numpy()), end="")
  # We have to call either predict or fit once so the model gets built and
  # can be exported with tf.saved_model.save.
  model.predict(next(iter(dataset))["image"])
  # Export the model as a SavedModel 2.0.
  tf.saved_model.save(model, export_path)
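
# A minimal programmatic use of train_and_export (the export path below is
# an illustrative assumption):
#
#   train_and_export("/tmp/mnist_model", epoch=1)
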
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--export_path",
      type=str,
      default=None,
      help="Path to export the module to")
  parser.add_argument(
      "--buffer_size",
      type=int,
      default=1000,
      help="Buffer size to use while shuffling the dataset")
  parser.add_argument(
      "--batch_size", type=int, default=32, help="Size of each training batch")
  parser.add_argument(
      "--learning_rate", type=float, default=1e-3, help="Learning rate")
  parser.add_argument(
      "--epoch", type=int, default=10, help="Number of epochs to train for")
  FLAGS, unparsed = parser.parse_known_args()

  if not FLAGS.export_path:
    logging.error("Must set flag --export_path.")
    sys.exit(1)

  train_and_export(**vars(FLAGS))
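
# Example invocation from the command line (the export path is only an
# illustration):
#
#   python export.py --export_path=/tmp/mnist_model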