Skip to content

Commit 5341e91

Browse files
genehwung and tfx-copybara
authored and committed
Add a new model type "materialized_prediction" that explicitly selects the materialized predictions extractor.
PiperOrigin-RevId: 553366496
1 parent a642b9e commit 5341e91

File tree

4 files changed

+115
-76
lines changed

4 files changed

+115
-76
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
eval_saved_model, allowing signature='eval' to now be used with other model
2525
types.
2626

27+
* Add "materialized_prediction" model type to allow users to bypass model
28+
inference explicitly.
29+
2730
## Breaking Changes
2831

2932
* Depend on PIL for image related metrics.

tensorflow_model_analysis/api/model_eval_lib.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,18 @@ def default_extractors( # pylint: disable=invalid-name
573573
eval_config=eval_config, materialize=materialize)
574574
])
575575

576+
extract_features = features_extractor.FeaturesExtractor(
577+
eval_config=eval_config, tensor_representations=tensor_representations
578+
)
579+
extract_labels = labels_extractor.LabelsExtractor(eval_config=eval_config)
580+
extract_example_weights = example_weights_extractor.ExampleWeightsExtractor(
581+
eval_config=eval_config
582+
)
583+
extract_materialized_predictions = (
584+
materialized_predictions_extractor.MaterializedPredictionsExtractor(
585+
eval_config=eval_config
586+
)
587+
)
576588
if eval_shared_model:
577589
model_types = _model_types(eval_shared_models)
578590
logging.info('eval_shared_models have model_types: %s', model_types)
@@ -582,21 +594,29 @@ def default_extractors( # pylint: disable=invalid-name
582594
'either a custom_predict_extractor must be used or model type must '
583595
'be one of: {}. evalconfig={}'.format(
584596
str(constants.VALID_TF_MODEL_TYPES), eval_config))
585-
if model_types == {constants.TF_LITE}:
597+
if model_types == {constants.MATERIALIZED_PREDICTION}:
598+
return [
599+
extract_features,
600+
extract_labels,
601+
extract_example_weights,
602+
extract_materialized_predictions,
603+
] + slicing_extractors
604+
elif model_types == {constants.TF_LITE}:
586605
# TODO(b/163889779): Convert TFLite extractor to operate on batched
587606
# extracts. Then we can remove the input extractor.
588607
return [
589-
features_extractor.FeaturesExtractor(
590-
eval_config=eval_config,
591-
tensor_representations=tensor_representations),
608+
extract_features,
592609
transformed_features_extractor.TransformedFeaturesExtractor(
593-
eval_config=eval_config, eval_shared_model=eval_shared_model),
594-
labels_extractor.LabelsExtractor(eval_config=eval_config),
595-
example_weights_extractor.ExampleWeightsExtractor(
596-
eval_config=eval_config),
597-
(custom_predict_extractor or
598-
tflite_predict_extractor.TFLitePredictExtractor(
599-
eval_config=eval_config, eval_shared_model=eval_shared_model))
610+
eval_config=eval_config, eval_shared_model=eval_shared_model
611+
),
612+
extract_labels,
613+
extract_example_weights,
614+
(
615+
custom_predict_extractor
616+
or tflite_predict_extractor.TFLitePredictExtractor(
617+
eval_config=eval_config, eval_shared_model=eval_shared_model
618+
)
619+
),
600620
] + slicing_extractors
601621
elif constants.TF_LITE in model_types:
602622
raise NotImplementedError(
@@ -605,15 +625,15 @@ def default_extractors( # pylint: disable=invalid-name
605625

606626
if model_types == {constants.TF_JS}:
607627
return [
608-
features_extractor.FeaturesExtractor(
609-
eval_config=eval_config,
610-
tensor_representations=tensor_representations),
611-
labels_extractor.LabelsExtractor(eval_config=eval_config),
612-
example_weights_extractor.ExampleWeightsExtractor(
613-
eval_config=eval_config),
614-
(custom_predict_extractor or
615-
tfjs_predict_extractor.TFJSPredictExtractor(
616-
eval_config=eval_config, eval_shared_model=eval_shared_model))
628+
extract_features,
629+
extract_labels,
630+
extract_example_weights,
631+
(
632+
custom_predict_extractor
633+
or tfjs_predict_extractor.TFJSPredictExtractor(
634+
eval_config=eval_config, eval_shared_model=eval_shared_model
635+
)
636+
),
617637
] + slicing_extractors
618638
elif constants.TF_JS in model_types:
619639
raise NotImplementedError(
@@ -646,35 +666,29 @@ def default_extractors( # pylint: disable=invalid-name
646666
'implemented: eval_config={}'.format(eval_config)
647667
)
648668
else:
649-
extractors = [
650-
features_extractor.FeaturesExtractor(
651-
eval_config=eval_config,
652-
tensor_representations=tensor_representations)
653-
]
669+
extractors = [extract_features]
654670
if not custom_predict_extractor:
655671
extractors.append(
656672
transformed_features_extractor.TransformedFeaturesExtractor(
657673
eval_config=eval_config, eval_shared_model=eval_shared_model))
658674
extractors.extend([
659-
labels_extractor.LabelsExtractor(eval_config=eval_config),
660-
example_weights_extractor.ExampleWeightsExtractor(
661-
eval_config=eval_config),
662-
(custom_predict_extractor or
663-
predictions_extractor.PredictionsExtractor(
664-
eval_config=eval_config, eval_shared_model=eval_shared_model)),
675+
extract_labels,
676+
extract_example_weights,
677+
(
678+
custom_predict_extractor
679+
or predictions_extractor.PredictionsExtractor(
680+
eval_config=eval_config, eval_shared_model=eval_shared_model
681+
)
682+
),
665683
])
666684
extractors.extend(slicing_extractors)
667685
return extractors
668686
else:
669687
return [
670-
features_extractor.FeaturesExtractor(
671-
eval_config=eval_config,
672-
tensor_representations=tensor_representations),
673-
labels_extractor.LabelsExtractor(eval_config=eval_config),
674-
example_weights_extractor.ExampleWeightsExtractor(
675-
eval_config=eval_config),
676-
materialized_predictions_extractor.MaterializedPredictionsExtractor(
677-
eval_config),
688+
extract_features,
689+
extract_labels,
690+
extract_example_weights,
691+
extract_materialized_predictions,
678692
] + slicing_extractors
679693

680694

tensorflow_model_analysis/api/model_eval_lib_test.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,13 @@ def testRunModelAnalysisMultipleModels(self):
714714
self.assertMetricsAlmostEqual(eval_result_2.slicing_metrics,
715715
expected_result_2)
716716

717-
def testRunModelAnalysisWithModelAgnosticPredictions(self):
717+
@parameterized.named_parameters(
718+
('no_model', False, None),
719+
('has_a_model', True, constants.MATERIALIZED_PREDICTION),
720+
)
721+
def testRunModelAnalysisWithExplicitModelAgnosticPredictions(
722+
self, has_model, model_type
723+
):
718724
examples = [
719725
self._makeExample(
720726
age=3.0, language='english', label=1.0, prediction=0.9),
@@ -725,13 +731,6 @@ def testRunModelAnalysisWithModelAgnosticPredictions(self):
725731
self._makeExample(
726732
age=5.0, language='chinese', label=1.0, prediction=0.2)
727733
]
728-
data_location = self._writeTFExamplesToTFRecords(examples)
729-
model_specs = [
730-
config_pb2.ModelSpec(
731-
prediction_key='prediction',
732-
label_key='label',
733-
example_weight_key='age')
734-
]
735734
metrics_specs = [
736735
config_pb2.MetricsSpec(
737736
metrics=[config_pb2.MetricConfig(class_name='ExampleCount')],
@@ -746,41 +745,56 @@ def testRunModelAnalysisWithModelAgnosticPredictions(self):
746745
example_weights=config_pb2.ExampleWeightOptions(weighted=True))
747746
]
748747
slicing_specs = [config_pb2.SlicingSpec(feature_keys=['language'])]
748+
model_spec = config_pb2.ModelSpec(
749+
prediction_key='prediction',
750+
label_key='label',
751+
example_weight_key='age',
752+
)
753+
if model_type is not None:
754+
model_spec.model_type = model_type
749755
eval_config = config_pb2.EvalConfig(
750-
model_specs=model_specs,
756+
model_specs=[model_spec],
751757
metrics_specs=metrics_specs,
752-
slicing_specs=slicing_specs)
753-
eval_result = model_eval_lib.run_model_analysis(
754-
eval_config=eval_config,
755-
data_location=data_location,
756-
output_path=self._getTempDir())
758+
slicing_specs=slicing_specs,
759+
)
760+
data_location = self._writeTFExamplesToTFRecords(examples)
761+
if has_model:
762+
model_location = self._exportEvalSavedModel(
763+
linear_classifier.simple_linear_classifier
764+
)
765+
model = model_eval_lib.default_eval_shared_model(
766+
eval_saved_model_path=model_location,
767+
eval_config=eval_config,
768+
)
769+
eval_result = model_eval_lib.run_model_analysis(
770+
eval_shared_model=model,
771+
eval_config=eval_config,
772+
data_location=data_location,
773+
output_path=self._getTempDir(),
774+
)
775+
else:
776+
eval_result = model_eval_lib.run_model_analysis(
777+
eval_config=eval_config,
778+
data_location=data_location,
779+
output_path=self._getTempDir(),
780+
)
757781
expected = {
758782
(('language', 'chinese'),): {
759-
'binary_accuracy': {
760-
'doubleValue': 0.375
761-
},
762-
'weighted_example_count': {
763-
'doubleValue': 8.0
764-
},
765-
'example_count': {
766-
'doubleValue': 2.0
767-
},
783+
'binary_accuracy': {'doubleValue': 0.375},
784+
'weighted_example_count': {'doubleValue': 8.0},
785+
'example_count': {'doubleValue': 2.0},
768786
},
769787
(('language', 'english'),): {
770-
'binary_accuracy': {
771-
'doubleValue': 1.0
772-
},
773-
'weighted_example_count': {
774-
'doubleValue': 7.0
775-
},
776-
'example_count': {
777-
'doubleValue': 2.0
778-
},
779-
}
788+
'binary_accuracy': {'doubleValue': 1.0},
789+
'weighted_example_count': {'doubleValue': 7.0},
790+
'example_count': {'doubleValue': 2.0},
791+
},
780792
}
781793
self.assertEqual(eval_result.data_location, data_location)
782-
self.assertEqual(eval_result.config.slicing_specs[0],
783-
config_pb2.SlicingSpec(feature_keys=['language']))
794+
self.assertEqual(
795+
eval_result.config.slicing_specs[0],
796+
config_pb2.SlicingSpec(feature_keys=['language']),
797+
)
784798
self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected)
785799

786800
@parameterized.named_parameters(

tensorflow_model_analysis/constants.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,16 @@
3232
TF_GENERIC = 'tf_generic'
3333
TF_LITE = 'tf_lite'
3434
TF_JS = 'tf_js'
35-
VALID_TF_MODEL_TYPES = (TFMA_EVAL, TF_GENERIC, TF_ESTIMATOR, TF_KERAS, TF_LITE,
36-
TF_JS)
35+
MATERIALIZED_PREDICTION = 'materialized_prediction'
36+
VALID_TF_MODEL_TYPES = (
37+
TFMA_EVAL,
38+
TF_GENERIC,
39+
TF_ESTIMATOR,
40+
TF_KERAS,
41+
TF_LITE,
42+
TF_JS,
43+
MATERIALIZED_PREDICTION,
44+
)
3745

3846
# This constant is only used for telemetry
3947
MODEL_AGNOSTIC = 'model_agnostic'

0 commit comments

Comments
 (0)