@@ -573,6 +573,18 @@ def default_extractors( # pylint: disable=invalid-name
573
573
eval_config = eval_config , materialize = materialize )
574
574
])
575
575
576
+ extract_features = features_extractor .FeaturesExtractor (
577
+ eval_config = eval_config , tensor_representations = tensor_representations
578
+ )
579
+ extract_labels = labels_extractor .LabelsExtractor (eval_config = eval_config )
580
+ extract_example_weights = example_weights_extractor .ExampleWeightsExtractor (
581
+ eval_config = eval_config
582
+ )
583
+ extract_materialized_predictions = (
584
+ materialized_predictions_extractor .MaterializedPredictionsExtractor (
585
+ eval_config = eval_config
586
+ )
587
+ )
576
588
if eval_shared_model :
577
589
model_types = _model_types (eval_shared_models )
578
590
logging .info ('eval_shared_models have model_types: %s' , model_types )
@@ -582,21 +594,29 @@ def default_extractors( # pylint: disable=invalid-name
582
594
'either a custom_predict_extractor must be used or model type must '
583
595
'be one of: {}. evalconfig={}' .format (
584
596
str (constants .VALID_TF_MODEL_TYPES ), eval_config ))
585
- if model_types == {constants .TF_LITE }:
597
+ if model_types == {constants .MATERIALIZED_PREDICTION }:
598
+ return [
599
+ extract_features ,
600
+ extract_labels ,
601
+ extract_example_weights ,
602
+ extract_materialized_predictions ,
603
+ ] + slicing_extractors
604
+ elif model_types == {constants .TF_LITE }:
586
605
# TODO(b/163889779): Convert TFLite extractor to operate on batched
587
606
# extracts. Then we can remove the input extractor.
588
607
return [
589
- features_extractor .FeaturesExtractor (
590
- eval_config = eval_config ,
591
- tensor_representations = tensor_representations ),
608
+ extract_features ,
592
609
transformed_features_extractor .TransformedFeaturesExtractor (
593
- eval_config = eval_config , eval_shared_model = eval_shared_model ),
594
- labels_extractor .LabelsExtractor (eval_config = eval_config ),
595
- example_weights_extractor .ExampleWeightsExtractor (
596
- eval_config = eval_config ),
597
- (custom_predict_extractor or
598
- tflite_predict_extractor .TFLitePredictExtractor (
599
- eval_config = eval_config , eval_shared_model = eval_shared_model ))
610
+ eval_config = eval_config , eval_shared_model = eval_shared_model
611
+ ),
612
+ extract_labels ,
613
+ extract_example_weights ,
614
+ (
615
+ custom_predict_extractor
616
+ or tflite_predict_extractor .TFLitePredictExtractor (
617
+ eval_config = eval_config , eval_shared_model = eval_shared_model
618
+ )
619
+ ),
600
620
] + slicing_extractors
601
621
elif constants .TF_LITE in model_types :
602
622
raise NotImplementedError (
@@ -605,15 +625,15 @@ def default_extractors( # pylint: disable=invalid-name
605
625
606
626
if model_types == {constants .TF_JS }:
607
627
return [
608
- features_extractor . FeaturesExtractor (
609
- eval_config = eval_config ,
610
- tensor_representations = tensor_representations ) ,
611
- labels_extractor . LabelsExtractor ( eval_config = eval_config ),
612
- example_weights_extractor . ExampleWeightsExtractor (
613
- eval_config = eval_config ),
614
- ( custom_predict_extractor or
615
- tfjs_predict_extractor . TFJSPredictExtractor (
616
- eval_config = eval_config , eval_shared_model = eval_shared_model ))
628
+ extract_features ,
629
+ extract_labels ,
630
+ extract_example_weights ,
631
+ (
632
+ custom_predict_extractor
633
+ or tfjs_predict_extractor . TFJSPredictExtractor (
634
+ eval_config = eval_config , eval_shared_model = eval_shared_model
635
+ )
636
+ ),
617
637
] + slicing_extractors
618
638
elif constants .TF_JS in model_types :
619
639
raise NotImplementedError (
@@ -646,35 +666,29 @@ def default_extractors( # pylint: disable=invalid-name
646
666
'implemented: eval_config={}' .format (eval_config )
647
667
)
648
668
else :
649
- extractors = [
650
- features_extractor .FeaturesExtractor (
651
- eval_config = eval_config ,
652
- tensor_representations = tensor_representations )
653
- ]
669
+ extractors = [extract_features ]
654
670
if not custom_predict_extractor :
655
671
extractors .append (
656
672
transformed_features_extractor .TransformedFeaturesExtractor (
657
673
eval_config = eval_config , eval_shared_model = eval_shared_model ))
658
674
extractors .extend ([
659
- labels_extractor .LabelsExtractor (eval_config = eval_config ),
660
- example_weights_extractor .ExampleWeightsExtractor (
661
- eval_config = eval_config ),
662
- (custom_predict_extractor or
663
- predictions_extractor .PredictionsExtractor (
664
- eval_config = eval_config , eval_shared_model = eval_shared_model )),
675
+ extract_labels ,
676
+ extract_example_weights ,
677
+ (
678
+ custom_predict_extractor
679
+ or predictions_extractor .PredictionsExtractor (
680
+ eval_config = eval_config , eval_shared_model = eval_shared_model
681
+ )
682
+ ),
665
683
])
666
684
extractors .extend (slicing_extractors )
667
685
return extractors
668
686
else :
669
687
return [
670
- features_extractor .FeaturesExtractor (
671
- eval_config = eval_config ,
672
- tensor_representations = tensor_representations ),
673
- labels_extractor .LabelsExtractor (eval_config = eval_config ),
674
- example_weights_extractor .ExampleWeightsExtractor (
675
- eval_config = eval_config ),
676
- materialized_predictions_extractor .MaterializedPredictionsExtractor (
677
- eval_config ),
688
+ extract_features ,
689
+ extract_labels ,
690
+ extract_example_weights ,
691
+ extract_materialized_predictions ,
678
692
] + slicing_extractors
679
693
680
694
0 commit comments