Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 54b01b4

Browse files
author
Mesh TensorFlow Team
committed
Modified the eval_model function in mesh_tensorflow/transformer/utils.py to accept Summary protos in addition to tag-to-scalar dicts.
PiperOrigin-RevId: 375981279
1 parent da90793 commit 54b01b4

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,10 +2268,15 @@ def eval_model(estimator,
22682268
summary = tf.Summary()
22692269
targets = cached_targets[eval_dataset.name]
22702270
metric_result = metric_fn(targets, predictions)
2271-
for metric_name, metric_value in metric_result.items():
2272-
tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
2273-
tf.logging.info("%s at step %d: %.3f", tag, global_step, metric_value)
2274-
summary.value.add(tag=tag, simple_value=metric_value)
2271+
if isinstance(metric_result, tf.Summary):
2272+
tf.logging.info("Precomputed summary at step %d", global_step)
2273+
summary_writer.add_summary(metric_result, global_step)
2274+
else:
2275+
for metric_name, metric_value in metric_result.items():
2276+
tag = "eval/{}/{}".format(eval_dataset.name, metric_name)
2277+
tf.logging.info("%s at step %d: %.3f", tag, global_step,
2278+
metric_value)
2279+
summary.value.add(tag=tag, simple_value=metric_value)
22752280
summary_writer.add_summary(summary, global_step)
22762281
summary_writer.flush()
22772282

0 commit comments

Comments
 (0)