Skip to content

Commit 1b5cabe

Browse files
tf-model-analysis-teamtfx-copybara
tf-model-analysis-team
authored andcommitted
Enable the sql_slice_key extractor when evaluating a model.
PiperOrigin-RevId: 423119962
1 parent b31f630 commit 1b5cabe

File tree

2 files changed

+30
-29
lines changed

2 files changed

+30
-29
lines changed

RELEASE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
* Fixed issue with aggregation type not being set properly in keys associated
1010
with confusion matrix metrics.
11+
* Enabled the sql_slice_key extractor when evaluating a model.
1112
* Depends on `numpy>=1.16,<2`.
1213
* Depends on `absl-py>=0.9,<2.0.0`.
1314
* Depends on

tensorflow_model_analysis/api/model_eval_lib.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
# TODO(b/149126671): Put ValidationResultsWriter in a separate file.
1717

18-
1918
import os
2019
import tempfile
2120
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
@@ -108,9 +107,8 @@ def _is_legacy_eval(
108107
return ((config_version is not None and config_version == 1) or
109108
(eval_shared_model and not isinstance(eval_shared_model, dict) and
110109
not isinstance(eval_shared_model, list) and
111-
(not eval_shared_model.model_loader.tags or
112-
eval_constants.EVAL_TAG in eval_shared_model.model_loader.tags) and
113-
not eval_config))
110+
(not eval_shared_model.model_loader.tags or eval_constants.EVAL_TAG
111+
in eval_shared_model.model_loader.tags) and not eval_config))
114112

115113

116114
def _default_eval_config(eval_shared_models: List[types.EvalSharedModel],
@@ -441,6 +439,14 @@ def default_eval_shared_model(
441439
is_baseline=is_baseline)
442440

443441

442+
def _has_sql_slices(eval_config: Optional[config_pb2.EvalConfig]) -> bool:
443+
if eval_config:
444+
for spec in eval_config.slicing_specs:
445+
if spec.slice_keys_sql:
446+
return True
447+
return False
448+
449+
444450
def default_extractors( # pylint: disable=invalid-name
445451
eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels] = None,
446452
eval_config: Optional[config_pb2.EvalConfig] = None,
@@ -495,7 +501,16 @@ def default_extractors( # pylint: disable=invalid-name
495501
slice_key_extractor.SliceKeyExtractor(
496502
eval_config=eval_config, materialize=materialize)
497503
]
498-
elif eval_shared_model:
504+
slicing_extractors = []
505+
if _has_sql_slices(eval_config):
506+
slicing_extractors.append(
507+
sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config))
508+
slicing_extractors.extend([
509+
unbatch_extractor.UnbatchExtractor(),
510+
slice_key_extractor.SliceKeyExtractor(
511+
eval_config=eval_config, materialize=materialize)
512+
])
513+
if eval_shared_model:
499514
model_types = _model_types(eval_shared_model)
500515
eval_shared_models = model_util.verify_and_update_eval_shared_models(
501516
eval_shared_model)
@@ -520,11 +535,8 @@ def default_extractors( # pylint: disable=invalid-name
520535
eval_config=eval_config),
521536
(custom_predict_extractor or
522537
tflite_predict_extractor.TFLitePredictExtractor(
523-
eval_config=eval_config, eval_shared_model=eval_shared_model)),
524-
unbatch_extractor.UnbatchExtractor(),
525-
slice_key_extractor.SliceKeyExtractor(
526-
eval_config=eval_config, materialize=materialize)
527-
]
538+
eval_config=eval_config, eval_shared_model=eval_shared_model))
539+
] + slicing_extractors
528540
elif constants.TF_LITE in model_types:
529541
raise NotImplementedError(
530542
'support for mixing tf_lite and non-tf_lite models is not '
@@ -538,11 +550,8 @@ def default_extractors( # pylint: disable=invalid-name
538550
eval_config=eval_config),
539551
(custom_predict_extractor or
540552
tfjs_predict_extractor.TFJSPredictExtractor(
541-
eval_config=eval_config, eval_shared_model=eval_shared_model)),
542-
unbatch_extractor.UnbatchExtractor(),
543-
slice_key_extractor.SliceKeyExtractor(
544-
eval_config=eval_config, materialize=materialize)
545-
]
553+
eval_config=eval_config, eval_shared_model=eval_shared_model))
554+
] + slicing_extractors
546555
elif constants.TF_JS in model_types:
547556
raise NotImplementedError(
548557
'support for mixing tf_js and non-tf_js models is not '
@@ -555,11 +564,8 @@ def default_extractors( # pylint: disable=invalid-name
555564
custom_predict_extractor or legacy_predict_extractor.PredictExtractor(
556565
eval_shared_model,
557566
materialize=materialize,
558-
eval_config=eval_config),
559-
unbatch_extractor.UnbatchExtractor(),
560-
slice_key_extractor.SliceKeyExtractor(
561-
eval_config=eval_config, materialize=materialize)
562-
]
567+
eval_config=eval_config)
568+
] + slicing_extractors
563569
elif (eval_config and constants.TF_ESTIMATOR in model_types and
564570
any(eval_constants.EVAL_TAG in m.model_loader.tags
565571
for m in eval_shared_models)):
@@ -585,23 +591,17 @@ def default_extractors( # pylint: disable=invalid-name
585591
eval_config=eval_config,
586592
eval_shared_model=eval_shared_model,
587593
tensor_adapter_config=tensor_adapter_config)),
588-
unbatch_extractor.UnbatchExtractor(),
589-
slice_key_extractor.SliceKeyExtractor(
590-
eval_config=eval_config, materialize=materialize)
591594
])
595+
extractors.extend(slicing_extractors)
592596
return extractors
593597
else:
594598
return [
595599
features_extractor.FeaturesExtractor(eval_config=eval_config),
596600
labels_extractor.LabelsExtractor(eval_config=eval_config),
597601
example_weights_extractor.ExampleWeightsExtractor(
598602
eval_config=eval_config),
599-
predictions_extractor.PredictionsExtractor(eval_config=eval_config),
600-
sql_slice_key_extractor.SqlSliceKeyExtractor(eval_config),
601-
unbatch_extractor.UnbatchExtractor(),
602-
slice_key_extractor.SliceKeyExtractor(
603-
eval_config=eval_config, materialize=materialize)
604-
]
603+
predictions_extractor.PredictionsExtractor(eval_config=eval_config)
604+
] + slicing_extractors
605605

606606

607607
def default_evaluators( # pylint: disable=invalid-name

0 commit comments

Comments
 (0)