Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 878832b

Browse files
afrozenator and Mesh TensorFlow Team
authored and committed
[T5/MTF] Fix predict_fn flow.
PiperOrigin-RevId: 360801432
1 parent e22cc2b commit 878832b

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

mesh_tensorflow/transformer/utils.py

Lines changed: 31 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -280,7 +280,7 @@ def variable_filter_max_size(v, max_size=1e7):
280280
return v.size <= max_size
281281

282282

283-
@gin.configurable
283+
@gin.configurable(denylist=["predict_fn"]) # pass `predict_fn` through `run`
284284
def tpu_estimator_model_fn(model_type,
285285
transformer_model,
286286
vocabulary,
@@ -484,42 +484,44 @@ def _maybe_detokenize(ids, vocab):
484484
return ids
485485
if mode == "score":
486486
# compute log-likelihoods per sequence
487+
targets = mtf_features["targets"]
487488
if predict_fn:
488489
# predict_fn contains a custom scoring function
489-
# this code-path has not been tested
490490
scores = predict_fn(
491491
model=transformer_model,
492492
features=mtf_features,
493493
variable_dtype=get_variable_dtype())
494-
targets = mtf_features["targets"]
495-
if isinstance(transformer_model, transformer.Unitransformer):
496-
length_dim = targets.shape.dims[-1]
497-
inputs = transformer.autoregressive_inputs(
498-
mtf_features["targets"])
499-
elif isinstance(transformer_model,
500-
(transformer.Bitransformer,
501-
transformer.StudentTeacher)):
502-
inputs = mtf_features["inputs"]
503494
else:
504-
raise ValueError("unrecognized class")
505-
logits, _ = transformer_model.call_simple(
506-
inputs=inputs,
507-
targets=targets,
508-
compute_loss=False,
509-
mode=mode,
510-
variable_dtype=get_variable_dtype())
511-
logits = mtf.cast(logits, tf.float32)
512-
_, length_dim, vocab_dim = logits.shape.dims
495+
if isinstance(transformer_model, transformer.Unitransformer):
496+
length_dim = targets.shape.dims[-1]
497+
inputs = transformer.autoregressive_inputs(
498+
mtf_features["targets"])
499+
elif isinstance(transformer_model,
500+
(transformer.Bitransformer,
501+
transformer.StudentTeacher)):
502+
inputs = mtf_features["inputs"]
503+
else:
504+
raise ValueError("unrecognized class")
505+
logits, _ = transformer_model.call_simple(
506+
inputs=inputs,
507+
targets=targets,
508+
compute_loss=False,
509+
mode=mode,
510+
variable_dtype=get_variable_dtype())
511+
logits = mtf.cast(logits, tf.float32)
512+
_, length_dim, vocab_dim = logits.shape.dims
513+
514+
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
515+
logits, mtf_features["targets"], vocab_dim)
516+
# 0=padding and negative targets are a hack to indicate no loss
517+
cross_entropy *= mtf.cast(
518+
mtf.greater(targets, 0), cross_entropy.dtype)
519+
if model_type == "delimited_lm":
520+
cross_entropy *= mtf.cast(mtf.logical_not(
521+
transformer.delimited_lm_inputs_mask(targets)),
522+
cross_entropy.dtype)
523+
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
513524

514-
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
515-
logits, mtf_features["targets"], vocab_dim)
516-
# 0=padding and negative targets are a hack to indicate no loss
517-
cross_entropy *= mtf.cast(
518-
mtf.greater(targets, 0), cross_entropy.dtype)
519-
if model_type == "delimited_lm":
520-
cross_entropy *= mtf.cast(mtf.logical_not(
521-
transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
522-
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
523525
scores = mtf.anonymize(scores)
524526
targets = mtf.anonymize(targets)
525527
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

0 commit comments

Comments (0)