@@ -280,7 +280,7 @@ def variable_filter_max_size(v, max_size=1e7):
   return v.size <= max_size
 
 
-@gin.configurable
+@gin.configurable(denylist=["predict_fn"])  # pass `predict_fn` through `run`
 def tpu_estimator_model_fn(model_type,
                            transformer_model,
                            vocabulary,
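
`denylist` keeps gin from binding `predict_fn` in a config file; it must instead be passed explicitly in Python (per the new inline comment, through `run`). A minimal sketch of the pattern, using a hypothetical `model_fn` in place of `tpu_estimator_model_fn`:

import gin

@gin.configurable(denylist=["predict_fn"])
def model_fn(model_type, predict_fn=None):
  # `model_type` may be set from a .gin config; `predict_fn` may not --
  # gin raises an error if a config tries to bind it.
  return model_type, predict_fn

gin.parse_config('model_fn.model_type = "bitransformer"')
model_fn(predict_fn=lambda **kw: None)  # supplied at the call site instead
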
@@ -484,42 +484,44 @@ def _maybe_detokenize(ids, vocab):
       return ids
     if mode == "score":
       # compute log-likelihoods per sequence
+      targets = mtf_features["targets"]
       if predict_fn:
         # predict_fn contains a custom scoring function
-        # this code-path has not been tested
         scores = predict_fn(
             model=transformer_model,
             features=mtf_features,
             variable_dtype=get_variable_dtype())
-      targets = mtf_features["targets"]
-      if isinstance(transformer_model, transformer.Unitransformer):
-        length_dim = targets.shape.dims[-1]
-        inputs = transformer.autoregressive_inputs(
-            mtf_features["targets"])
-      elif isinstance(transformer_model,
-                      (transformer.Bitransformer,
-                       transformer.StudentTeacher)):
-        inputs = mtf_features["inputs"]
       else:
-        raise ValueError("unrecognized class")
-      logits, _ = transformer_model.call_simple(
-          inputs=inputs,
-          targets=targets,
-          compute_loss=False,
-          mode=mode,
-          variable_dtype=get_variable_dtype())
-      logits = mtf.cast(logits, tf.float32)
-      _, length_dim, vocab_dim = logits.shape.dims
+        if isinstance(transformer_model, transformer.Unitransformer):
+          length_dim = targets.shape.dims[-1]
+          inputs = transformer.autoregressive_inputs(
+              mtf_features["targets"])
+        elif isinstance(transformer_model,
+                        (transformer.Bitransformer,
+                         transformer.StudentTeacher)):
+          inputs = mtf_features["inputs"]
+        else:
+          raise ValueError("unrecognized class")
+        logits, _ = transformer_model.call_simple(
+            inputs=inputs,
+            targets=targets,
+            compute_loss=False,
+            mode=mode,
+            variable_dtype=get_variable_dtype())
+        logits = mtf.cast(logits, tf.float32)
+        _, length_dim, vocab_dim = logits.shape.dims
+
+        cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
+            logits, mtf_features["targets"], vocab_dim)
+        # 0=padding and negative targets are a hack to indicate no loss
+        cross_entropy *= mtf.cast(
+            mtf.greater(targets, 0), cross_entropy.dtype)
+        if model_type == "delimited_lm":
+          cross_entropy *= mtf.cast(mtf.logical_not(
+              transformer.delimited_lm_inputs_mask(targets)),
+                                    cross_entropy.dtype)
+        scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
 
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      # 0=padding and negative targets are a hack to indicate no loss
-      cross_entropy *= mtf.cast(
-          mtf.greater(targets, 0), cross_entropy.dtype)
-      if model_type == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
       scores = mtf.anonymize(scores)
       targets = mtf.anonymize(targets)
     lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
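
With this restructure the two scoring paths are mutually exclusive: before, a custom `predict_fn`'s scores were computed and then silently overwritten by the default cross-entropy path, which now sits under `else:`. A sketch of what a custom scoring `predict_fn` could look like, mirroring the keyword arguments at the call site above; the name, the length normalization, and the encoder-decoder assumption are illustrative, not part of this commit:

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

def length_normalized_score_fn(model=None, features=None, variable_dtype=None):
  """Hypothetical predict_fn: per-sequence log-likelihood over target length."""
  targets = features["targets"]
  # Same forward pass as the default path, assuming a Bitransformer-style
  # model with explicit encoder inputs.
  logits, _ = model.call_simple(
      inputs=features["inputs"],
      targets=targets,
      compute_loss=False,
      mode="score",
      variable_dtype=variable_dtype)
  logits = mtf.cast(logits, tf.float32)
  _, length_dim, vocab_dim = logits.shape.dims
  cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
      logits, targets, vocab_dim)
  # Mask out padding (0) and negative sentinel targets, as in the default path.
  nonpadding = mtf.cast(mtf.greater(targets, 0), cross_entropy.dtype)
  cross_entropy *= nonpadding
  log_likelihood = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
  # Normalize by the number of real target tokens.
  return log_likelihood / mtf.reduce_sum(nonpadding, reduced_dim=length_dim)

Such a function would be forwarded through `run` rather than bound by gin, per the `denylist` change above.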