Skip to content

Commit 8a5c348

Browse files
authored
[r0.30.0 cherry-pick] Use jax2tf batch polymorphism for the Penguin Flax experimental example. (#3975)
1 parent 4b68118 commit 8a5c348

File tree

3 files changed

+12
-15
lines changed

3 files changed

+12
-15
lines changed

tfx/dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def make_extra_packages_examples():
189189
# examples/penguin/experimental/penguin_pipeline_sklearn_gcp.py)
190190
# Required for the experimental tfx/examples using Flax, e.g.,
191191
# tfx/examples/penguin.
192-
'jax>=0.2.12,<0.3',
192+
'jax>=0.2.13,<0.3',
193193
'jaxlib>=0.1.64,<0.2',
194194
'flax>=0.3.3,<0.4',
195195
# Required for tfx/examples/penguin/penguin_utils_cloud_tuner.py

tfx/examples/penguin/penguin_utils_base.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Flax.
1818
"""
1919

20-
from typing import List, Optional, Text
20+
from typing import List, Text
2121
import tensorflow as tf
2222
import tensorflow_transform as tft
2323

@@ -39,21 +39,19 @@ def transformed_name(key):
3939

4040

4141
def make_serving_signatures(model,
42-
tf_transform_features: tft.TFTransformOutput,
43-
serving_batch_size: Optional[int] = None):
42+
tf_transform_features: tft.TFTransformOutput):
4443
"""Returns the serving signatures.
4544
4645
Args:
4746
model: the model function to apply to the transformed features.
4847
tf_transform_features: The transformation to apply to the serialized
4948
tf.Example.
50-
serving_batch_size: an optional specification for a concrete serving batch
51-
size.
5249
5350
Returns:
5451
The signatures to use for saving the mode. The 'serving_default' signature
55-
will be a concrete function that takes a serialized tf.Example, parses it,
56-
transformes the features and then applies the model.
52+
will be a concrete function that takes a batch of unspecified length of
53+
serialized tf.Example, parses them, transformes the features and
54+
then applies the model.
5755
"""
5856

5957
model.tft_layer = tf_transform_features.transform_features_layer()
@@ -72,8 +70,7 @@ def serve_tf_examples_fn(serialized_tf_examples):
7270
return {
7371
'serving_default':
7472
serve_tf_examples_fn.get_concrete_function(
75-
tf.TensorSpec(
76-
shape=[serving_batch_size], dtype=tf.string, name='examples'))
73+
tf.TensorSpec(shape=[None], dtype=tf.string, name='examples'))
7774
}
7875

7976

tfx/examples/penguin/penguin_utils_flax_experimental.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,10 @@ def predict(params: _Params, inputs: _InputBatch):
131131

132132
trained_params = optimizer.target
133133

134-
# Convert the prediction function to TF.
135-
tf_fn = jax2tf.convert(predict, with_gradient=False, enable_xla=True)
134+
# Convert the prediction function to TF, with a variable batch dimension
135+
# for all inputs.
136+
tf_fn = jax2tf.convert(predict, with_gradient=False, enable_xla=True,
137+
polymorphic_shapes=(None, '(b, 1)'))
136138

137139
# Create tf.Variables for the parameters. If you want more useful variable
138140
# names, you can use `tree.map_structure_with_path` from the `dm-tree`
@@ -327,8 +329,6 @@ def run_fn(fn_args: tfx.components.FnArgs):
327329
steps_per_epoch=fn_args.train_steps,
328330
eval_steps_per_epoch=fn_args.eval_steps,
329331
tensorboard_log_dir=fn_args.model_run_dir)
330-
# TODO(b/180721874): batch polymorphic model not yet supported.
331332

332-
signatures = base.make_serving_signatures(model, tf_transform_output,
333-
serving_batch_size=1)
333+
signatures = base.make_serving_signatures(model, tf_transform_output)
334334
tf.saved_model.save(model, fn_args.serving_model_dir, signatures=signatures)

0 commit comments

Comments
 (0)