Skip to content

Commit ff51128

Browse files
mdreves authored and tf-model-analysis-team committed
Update tfma.default_eval_shared_model and tfma.default_extractors to better support custom model types.
PiperOrigin-RevId: 317473286
1 parent 6c223c6 commit ff51128

File tree

11 files changed

+220
-117
lines changed

11 files changed

+220
-117
lines changed

RELEASE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
* Add `tfma.slicer.stringify_slice_key()`.
4141
* Deprecated external use of tfma.slicer.SingleSliceSpec (tfma.SlicingSpec
4242
should be used instead).
43+
* Updated tfma.default_eval_shared_model and tfma.default_extractors to better
44+
support custom model types.
4345

4446
## Breaking changes
4547

@@ -48,6 +50,8 @@
4850
* Refactored confidence interval methodology field. The old path under
4951
`Options.confidence_interval_methodology` is now at
5052
`Options.confidence_intervals.methodology`.
53+
* Removed model_load_time_callback from ModelLoader construct_fn (timing is
54+
now handled by load). Removed access to shared_handle from ModelLoader.
5155

5256
## Deprecations
5357

tensorflow_model_analysis/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@
9797

9898
from tensorflow_model_analysis.post_export_metrics import post_export_metrics
9999

100+
from tensorflow_model_analysis.model_util import CombineFnWithModels
101+
from tensorflow_model_analysis.model_util import DoFnWithModels
100102
from tensorflow_model_analysis.model_util import get_model_type
101103
from tensorflow_model_analysis.model_util import model_construct_fn
102104
from tensorflow_model_analysis.model_util import verify_and_update_eval_shared_models

tensorflow_model_analysis/api/model_eval_lib.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,9 @@ def default_eval_shared_model(
291291
blacklist_feature_fetches: Optional[List[Text]] = None,
292292
tags: Optional[List[Text]] = None,
293293
model_name: Text = '',
294-
eval_config: Optional[config.EvalConfig] = None) -> types.EvalSharedModel:
294+
eval_config: Optional[config.EvalConfig] = None,
295+
custom_model_loader: Optional[types.ModelLoader] = None
296+
) -> types.EvalSharedModel:
295297
"""Returns default EvalSharedModel.
296298
297299
Args:
@@ -318,6 +320,7 @@ def default_eval_shared_model(
318320
ModelSpecs.name). The name should only be provided if multiple models are
319321
being evaluated.
320322
eval_config: Eval config. Only used for setting default tags.
323+
custom_model_loader: Optional custom model loader for non-TF models.
321324
"""
322325
if not eval_config:
323326
model_type = constants.TF_ESTIMATOR
@@ -360,6 +363,19 @@ def default_eval_shared_model(
360363
add_metrics_callbacks.append(example_weight_callback)
361364
# pytype: enable=module-attr
362365

366+
model_loader = custom_model_loader
367+
if not model_loader and model_type in constants.VALID_TF_MODEL_TYPES:
368+
model_loader = types.ModelLoader(
369+
construct_fn=model_util.model_construct_fn(
370+
eval_saved_model_path=eval_saved_model_path,
371+
add_metrics_callbacks=add_metrics_callbacks,
372+
include_default_metrics=include_default_metrics,
373+
additional_fetches=additional_fetches,
374+
blacklist_feature_fetches=blacklist_feature_fetches,
375+
model_type=model_type,
376+
tags=tags),
377+
tags=tags)
378+
363379
return types.EvalSharedModel(
364380
model_name=model_name,
365381
model_type=model_type,
@@ -368,16 +384,7 @@ def default_eval_shared_model(
368384
include_default_metrics=include_default_metrics,
369385
example_weight_key=example_weight_key,
370386
additional_fetches=additional_fetches,
371-
model_loader=types.ModelLoader(
372-
tags=tags,
373-
construct_fn=model_util.model_construct_fn(
374-
eval_saved_model_path=eval_saved_model_path,
375-
add_metrics_callbacks=add_metrics_callbacks,
376-
include_default_metrics=include_default_metrics,
377-
additional_fetches=additional_fetches,
378-
blacklist_feature_fetches=blacklist_feature_fetches,
379-
model_type=model_type,
380-
tags=tags)))
387+
model_loader=model_loader)
381388

382389

383390
def default_extractors( # pylint: disable=invalid-name
@@ -387,6 +394,7 @@ def default_extractors( # pylint: disable=invalid-name
387394
materialize: Optional[bool] = True,
388395
enable_batched_extractors: Optional[bool] = False,
389396
tensor_adapter_config: Optional[tensor_adapter.TensorAdapterConfig] = None,
397+
custom_predict_extractor: Optional[extractor.Extractor] = None
390398
) -> List[extractor.Extractor]:
391399
"""Returns the default extractors for use in ExtractAndEvaluate.
392400
@@ -401,6 +409,8 @@ def default_extractors( # pylint: disable=invalid-name
401409
tensor_adapter_config: Tensor adapter config which specifies how to obtain
402410
tensors from the Arrow RecordBatch. If None, we feed the raw examples to
403411
the model.
412+
custom_predict_extractor: Optional custom predict extractor for non-TF
413+
models.
404414
405415
Raises:
406416
NotImplementedError: If eval_config contains mixed serving and eval models.
@@ -417,7 +427,7 @@ def default_extractors( # pylint: disable=invalid-name
417427
eval_config = config.EvalConfig(
418428
slicing_specs=[s.to_proto() for s in slice_spec])
419429
return [
420-
predict_extractor.PredictExtractor(
430+
custom_predict_extractor or predict_extractor.PredictExtractor(
421431
eval_shared_model, materialize=materialize),
422432
slice_key_extractor.SliceKeyExtractor(
423433
eval_config=eval_config, materialize=materialize)
@@ -427,15 +437,18 @@ def default_extractors( # pylint: disable=invalid-name
427437
eval_shared_models = model_util.verify_and_update_eval_shared_models(
428438
eval_shared_model)
429439

430-
if not model_types.issubset(constants.VALID_MODEL_TYPES):
440+
if (not model_types.issubset(constants.VALID_TF_MODEL_TYPES) and
441+
not custom_predict_extractor):
431442
raise NotImplementedError(
432-
'model type must be one of: {}. evalconfig={}'.format(
433-
str(constants.VALID_MODEL_TYPES), eval_config))
443+
'either a custom_predict_extractor must be used or model type must '
444+
'be one of: {}. evalconfig={}'.format(
445+
str(constants.VALID_TF_MODEL_TYPES), eval_config))
434446
if model_types == set([constants.TF_LITE]):
435447
return [
436448
input_extractor.InputExtractor(eval_config=eval_config),
437-
tflite_predict_extractor.TFLitePredictExtractor(
438-
eval_config=eval_config, eval_shared_model=eval_shared_model),
449+
(custom_predict_extractor or
450+
tflite_predict_extractor.TFLitePredictExtractor(
451+
eval_config=eval_config, eval_shared_model=eval_shared_model)),
439452
slice_key_extractor.SliceKeyExtractor(
440453
eval_config=eval_config, materialize=materialize)
441454
]
@@ -448,7 +461,7 @@ def default_extractors( # pylint: disable=invalid-name
448461
all(eval_constants.EVAL_TAG in m.model_loader.tags
449462
for m in eval_shared_models)):
450463
return [
451-
predict_extractor.PredictExtractor(
464+
custom_predict_extractor or predict_extractor.PredictExtractor(
452465
eval_shared_model,
453466
materialize=materialize,
454467
eval_config=eval_config),
@@ -466,18 +479,19 @@ def default_extractors( # pylint: disable=invalid-name
466479
return [
467480
batched_input_extractor.BatchedInputExtractor(
468481
eval_config=eval_config),
469-
batched_predict_extractor_v2.BatchedPredictExtractor(
470-
eval_config=eval_config,
471-
eval_shared_model=eval_shared_model,
472-
tensor_adapter_config=tensor_adapter_config),
482+
(custom_predict_extractor or
483+
batched_predict_extractor_v2.BatchedPredictExtractor(
484+
eval_config=eval_config,
485+
eval_shared_model=eval_shared_model,
486+
tensor_adapter_config=tensor_adapter_config)),
473487
unbatch_extractor.UnbatchExtractor(),
474488
slice_key_extractor.SliceKeyExtractor(
475489
eval_config=eval_config, materialize=materialize)
476490
]
477491
else:
478492
return [
479493
input_extractor.InputExtractor(eval_config=eval_config),
480-
predict_extractor_v2.PredictExtractor(
494+
custom_predict_extractor or predict_extractor_v2.PredictExtractor(
481495
eval_config=eval_config, eval_shared_model=eval_shared_model),
482496
slice_key_extractor.SliceKeyExtractor(
483497
eval_config=eval_config, materialize=materialize)

tensorflow_model_analysis/api/model_eval_lib_test.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,89 @@ def testRunModelAnalysis(self):
367367
self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected)
368368
self.assertFalse(eval_result.plots)
369369

370+
def testRunModelAnalysisWithCustomizations(self):
371+
model_location = self._exportEvalSavedModel(
372+
linear_classifier.simple_linear_classifier)
373+
examples = [
374+
self._makeExample(age=3.0, language='english', label=1.0),
375+
self._makeExample(age=3.0, language='chinese', label=0.0),
376+
self._makeExample(age=4.0, language='english', label=1.0),
377+
self._makeExample(age=5.0, language='chinese', label=1.0),
378+
self._makeExample(age=5.0, language='hindi', label=1.0)
379+
]
380+
data_location = self._writeTFExamplesToTFRecords(examples)
381+
slicing_specs = [config.SlicingSpec(feature_keys=['language'])]
382+
options = config.Options()
383+
options.min_slice_size.value = 2
384+
eval_config = config.EvalConfig(
385+
model_specs=[config.ModelSpec(model_type='my_model_type')],
386+
slicing_specs=slicing_specs,
387+
options=options)
388+
# Use default model_loader for testing passing custom_model_loader
389+
model_loader = model_eval_lib.default_eval_shared_model(
390+
eval_saved_model_path=model_location,
391+
example_weight_key='age').model_loader
392+
eval_shared_model = model_eval_lib.default_eval_shared_model(
393+
eval_saved_model_path=model_location, custom_model_loader=model_loader)
394+
# Use PredictExtractor for testing passing custom_predict_extractor
395+
extractors = model_eval_lib.default_extractors(
396+
eval_shared_model=eval_shared_model,
397+
eval_config=eval_config,
398+
custom_predict_extractor=predict_extractor.PredictExtractor(
399+
eval_shared_model=eval_shared_model, eval_config=eval_config))
400+
eval_result = model_eval_lib.run_model_analysis(
401+
eval_config=eval_config,
402+
eval_shared_model=eval_shared_model,
403+
data_location=data_location,
404+
output_path=self._getTempDir(),
405+
extractors=extractors)
406+
# We only check some of the metrics to ensure that the end-to-end
407+
# pipeline works.
408+
expected = {
409+
(('language', 'hindi'),): {
410+
u'__ERROR__': {
411+
'debugMessage':
412+
u'Example count for this slice key is lower than the '
413+
u'minimum required value: 2. No data is aggregated for '
414+
u'this slice.'
415+
},
416+
},
417+
(('language', 'chinese'),): {
418+
'accuracy': {
419+
'doubleValue': 0.5
420+
},
421+
'my_mean_label': {
422+
'doubleValue': 0.5
423+
},
424+
metric_keys.EXAMPLE_WEIGHT: {
425+
'doubleValue': 8.0
426+
},
427+
metric_keys.EXAMPLE_COUNT: {
428+
'doubleValue': 2.0
429+
},
430+
},
431+
(('language', 'english'),): {
432+
'accuracy': {
433+
'doubleValue': 1.0
434+
},
435+
'my_mean_label': {
436+
'doubleValue': 1.0
437+
},
438+
metric_keys.EXAMPLE_WEIGHT: {
439+
'doubleValue': 7.0
440+
},
441+
metric_keys.EXAMPLE_COUNT: {
442+
'doubleValue': 2.0
443+
},
444+
}
445+
}
446+
self.assertEqual(eval_result.model_location, model_location.decode())
447+
self.assertEqual(eval_result.data_location, data_location)
448+
self.assertEqual(eval_result.config.slicing_specs[0],
449+
config.SlicingSpec(feature_keys=['language']))
450+
self.assertMetricsAlmostEqual(eval_result.slicing_metrics, expected)
451+
self.assertFalse(eval_result.plots)
452+
370453
def testRunModelAnalysisMultipleModels(self):
371454
examples = [
372455
self._makeExample(age=3.0, language='english', label=1.0),

tensorflow_model_analysis/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
PLACEHOLDER = 'placeholder'
2828
SPARSE_PLACEHOLDER = 'sparse_placeholder'
2929

30-
# Types of models
30+
# Types of TF models
3131
TF_ESTIMATOR = 'tf_estimator'
3232
TF_KERAS = 'tf_keras'
3333
TF_GENERIC = 'tf_generic'
3434
TF_LITE = 'tf_lite'
35-
VALID_MODEL_TYPES = ('', TF_GENERIC, TF_ESTIMATOR, TF_KERAS, TF_LITE)
35+
VALID_TF_MODEL_TYPES = (TF_GENERIC, TF_ESTIMATOR, TF_KERAS, TF_LITE)
3636

3737
# LINT.IfChange
3838
METRICS_NAMESPACE = 'tfx.ModelAnalysis'

tensorflow_model_analysis/evaluators/keras_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def metrics_specs_from_keras(
3131
model_loader: types.ModelLoader,
3232
) -> List[config.MetricsSpec]:
3333
"""Returns metrics specs for metrics and losses associated with the model."""
34-
model = model_loader.construct_fn(lambda x: None)()
34+
model = model_loader.construct_fn()
3535
if model is None:
3636
return []
3737

tensorflow_model_analysis/evaluators/keras_util_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def testMetricSpecsFromKeras(self):
6969

7070
# TODO(b/149995449): Keras does not support re-loading metrics with the new
7171
# API. Re-enable after this is fixed.
72-
model = eval_shared_model.model_loader.construct_fn(lambda x: None)()
72+
model = eval_shared_model.model_loader.construct_fn()
7373
if not hasattr(model, 'loss_functions'):
7474
return
7575

@@ -122,7 +122,7 @@ def testMetricSpecsFromKerasSequential(self):
122122

123123
# TODO(b/149995449): Keras does not support re-loading metrics with the new
124124
# API. Re-enable after this is fixed.
125-
model = eval_shared_model.model_loader.construct_fn(lambda x: None)()
125+
model = eval_shared_model.model_loader.construct_fn()
126126
if not hasattr(model, 'loss_functions'):
127127
return
128128

@@ -184,7 +184,7 @@ def testMetricSpecsFromKerasWithMultipleOutputs(self):
184184

185185
# TODO(b/149995449): Keras does not support re-loading metrics with the new
186186
# API. Re-enable after this is fixed.
187-
model = eval_shared_model.model_loader.construct_fn(lambda x: None)()
187+
model = eval_shared_model.model_loader.construct_fn()
188188
if not hasattr(model, 'loss_functions'):
189189
return
190190

tensorflow_model_analysis/evaluators/metrics_and_plots_evaluator_v2_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,8 @@ def check_validations(got):
307307
# TODO(b/149995449): Keras does not support re-loading metrics with
308308
# its new API so the loss added at compile time will be missing.
309309
# Re-enable after this is fixed.
310-
if hasattr(
311-
eval_shared_model.model_loader.construct_fn(lambda x: None)(),
312-
'compiled_metrics'):
310+
if hasattr(eval_shared_model.model_loader.construct_fn(),
311+
'compiled_metrics'):
313312
expected_metric_validations_per_slice = (
314313
expected_metric_validations_per_slice[:3])
315314
self.assertLen(got.metric_validations_per_slice[0].failures,

tensorflow_model_analysis/model_agnostic_eval/model_agnostic_evaluate_graph.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,24 @@
2525
# Standard __future__ imports
2626
from __future__ import print_function
2727

28-
import datetime
28+
from typing import List, Optional # pytype: disable=not-supported-yet
29+
2930
# Standard Imports
3031
import tensorflow as tf
3132

3233
from tensorflow_model_analysis import types
3334
from tensorflow_model_analysis.eval_metrics_graph import eval_metrics_graph
3435
from tensorflow_model_analysis.model_agnostic_eval import model_agnostic_predict
3536

36-
from typing import Callable, List, Optional # pytype: disable=not-supported-yet
37-
3837

3938
def make_construct_fn( # pylint: disable=invalid-name
4039
add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]],
4140
config: model_agnostic_predict.ModelAgnosticConfig):
4241
"""Returns a construct fn for constructing the model agnostic eval graph."""
4342

44-
def construct_fn(model_load_seconds_callback: Callable[[int], None]):
45-
"""Thin wrapper for the actual construct to allow for metrics."""
46-
47-
def construct(): # pylint: disable=invalid-name
48-
"""Function for constructing a model agnostic eval graph."""
49-
start_time = datetime.datetime.now()
50-
model_agnostic_eval = ModelAgnosticEvaluateGraph(add_metrics_callbacks,
51-
config)
52-
end_time = datetime.datetime.now()
53-
model_load_seconds_callback(int((end_time - start_time).total_seconds()))
54-
return model_agnostic_eval
55-
56-
return construct
43+
def construct_fn(): # pylint: disable=invalid-name
44+
"""Function for constructing a model agnostic eval graph."""
45+
return ModelAgnosticEvaluateGraph(add_metrics_callbacks, config)
5746

5847
return construct_fn
5948

0 commit comments

Comments
 (0)