From eb0c45558ae7f6afa18f60f5c5fef29c0d322359 Mon Sep 17 00:00:00 2001
From: Mesh TensorFlow Team
Date: Wed, 27 Jan 2021 18:34:35 -0800
Subject: [PATCH] Decode Unicode strings in inference mode.

This is already done in eval mode. This CL applies the same logic for inference.

PiperOrigin-RevId: 354219241
---
 mesh_tensorflow/transformer/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index 355c0e7f..4b92ca7c 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -1220,8 +1220,11 @@ def input_fn(params):
     return dataset
 
   checkpoint_step = get_step_from_checkpoint_path(checkpoint_path)
-  decodes = decode(
-      estimator, input_fn, vocabulary, checkpoint_path=checkpoint_path)
+  decodes = [
+      d.decode("utf-8") if isinstance(d, bytes) else d
+      for d in decode(estimator, input_fn, vocabulary, checkpoint_path)
+  ]
+
   # Remove any padded examples
   dataset_size = len(inputs) * repeats
   decodes = decodes[:dataset_size]