Skip to content

Commit 6c223c6

Browse files
mdrevestf-model-analysis-team
authored and
tf-model-analysis-team
committed
Deprecated external use of tfma.slicer.SingleSliceSpec (tfma.SlicingSpec should be used).
PiperOrigin-RevId: 317321141
1 parent 6e170c3 commit 6c223c6

File tree

7 files changed

+86
-103
lines changed

7 files changed

+86
-103
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
load_eval_results.
3939
* Fix typo in jupyter widgets breaking TimeSeriesView and PlotViewer.
4040
* Add `tfma.slicer.stringify_slice_key()`.
41+
* Deprecated external use of tfma.slicer.SingleSliceSpec (tfma.SlicingSpec
42+
should be used instead).
4143

4244
## Breaking changes
4345

tensorflow_model_analysis/api/model_eval_lib.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -405,20 +405,22 @@ def default_extractors( # pylint: disable=invalid-name
405405
Raises:
406406
NotImplementedError: If eval_config contains mixed serving and eval models.
407407
"""
408+
if slice_spec and eval_config:
409+
raise ValueError('slice_spec is deprecated, only use eval_config')
408410
if eval_config is not None:
409411
eval_config = _update_eval_config_with_defaults(eval_config,
410412
eval_shared_model)
411-
slice_spec = [
412-
slicer.SingleSliceSpec(spec=spec) for spec in eval_config.slicing_specs
413-
]
414413

415414
if _is_legacy_eval(eval_shared_model, eval_config):
416415
# Backwards compatibility for previous add_metrics_callbacks implementation.
416+
if not eval_config and slice_spec:
417+
eval_config = config.EvalConfig(
418+
slicing_specs=[s.to_proto() for s in slice_spec])
417419
return [
418420
predict_extractor.PredictExtractor(
419421
eval_shared_model, materialize=materialize),
420422
slice_key_extractor.SliceKeyExtractor(
421-
slice_spec, materialize=materialize)
423+
eval_config=eval_config, materialize=materialize)
422424
]
423425
elif eval_shared_model:
424426
model_types = _model_types(eval_shared_model)
@@ -435,7 +437,7 @@ def default_extractors( # pylint: disable=invalid-name
435437
tflite_predict_extractor.TFLitePredictExtractor(
436438
eval_config=eval_config, eval_shared_model=eval_shared_model),
437439
slice_key_extractor.SliceKeyExtractor(
438-
slice_spec, materialize=materialize)
440+
eval_config=eval_config, materialize=materialize)
439441
]
440442
elif constants.TF_LITE in model_types:
441443
raise NotImplementedError(
@@ -451,7 +453,7 @@ def default_extractors( # pylint: disable=invalid-name
451453
materialize=materialize,
452454
eval_config=eval_config),
453455
slice_key_extractor.SliceKeyExtractor(
454-
slice_spec, materialize=materialize)
456+
eval_config=eval_config, materialize=materialize)
455457
]
456458
elif (eval_config and constants.TF_ESTIMATOR in model_types and
457459
any(eval_constants.EVAL_TAG in m.model_loader.tags
@@ -470,15 +472,15 @@ def default_extractors( # pylint: disable=invalid-name
470472
tensor_adapter_config=tensor_adapter_config),
471473
unbatch_extractor.UnbatchExtractor(),
472474
slice_key_extractor.SliceKeyExtractor(
473-
slice_spec, materialize=materialize)
475+
eval_config=eval_config, materialize=materialize)
474476
]
475477
else:
476478
return [
477479
input_extractor.InputExtractor(eval_config=eval_config),
478480
predict_extractor_v2.PredictExtractor(
479481
eval_config=eval_config, eval_shared_model=eval_shared_model),
480482
slice_key_extractor.SliceKeyExtractor(
481-
slice_spec, materialize=materialize)
483+
eval_config=eval_config, materialize=materialize)
482484
]
483485
else:
484486
if enable_batched_extractors:
@@ -487,13 +489,13 @@ def default_extractors( # pylint: disable=invalid-name
487489
eval_config=eval_config),
488490
unbatch_extractor.UnbatchExtractor(),
489491
slice_key_extractor.SliceKeyExtractor(
490-
slice_spec, materialize=materialize)
492+
eval_config=eval_config, materialize=materialize)
491493
]
492494
else:
493495
return [
494496
input_extractor.InputExtractor(eval_config=eval_config),
495497
slice_key_extractor.SliceKeyExtractor(
496-
slice_spec, materialize=materialize)
498+
eval_config=eval_config, materialize=materialize)
497499
]
498500

499501

@@ -1117,6 +1119,7 @@ def single_model_analysis(
11171119
model_location: Text,
11181120
data_location: Text,
11191121
output_path: Text = None,
1122+
eval_config: Optional[config.EvalConfig] = None,
11201123
slice_spec: Optional[List[slicer.SingleSliceSpec]] = None
11211124
) -> view_types.EvalResult:
11221125
"""Run model analysis for a single model on a single data set.
@@ -1130,7 +1133,8 @@ def single_model_analysis(
11301133
data_location: The location of the data files.
11311134
output_path: The directory to output metrics and results to. If None, we use
11321135
a temporary directory.
1133-
slice_spec: A list of tfma.slicer.SingleSliceSpec.
1136+
eval_config: Eval config.
1137+
slice_spec: Deprecated (use EvalConfig).
11341138
11351139
Returns:
11361140
An EvalResult that can be used with the TFMA visualization functions.
@@ -1141,8 +1145,11 @@ def single_model_analysis(
11411145
if not tf.io.gfile.exists(output_path):
11421146
tf.io.gfile.makedirs(output_path)
11431147

1144-
eval_config = config.EvalConfig(
1145-
slicing_specs=[s.to_proto() for s in slice_spec])
1148+
if slice_spec and eval_config:
1149+
raise ValueError('slice_spec is deprecated, only use eval_config')
1150+
if slice_spec:
1151+
eval_config = config.EvalConfig(
1152+
slicing_specs=[s.to_proto() for s in slice_spec])
11461153

11471154
return run_model_analysis(
11481155
eval_config=eval_config,

tensorflow_model_analysis/api/model_eval_lib_test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from tensorflow_model_analysis.post_export_metrics import metric_keys
5151
from tensorflow_model_analysis.post_export_metrics import post_export_metrics
5252
from tensorflow_model_analysis.proto import validation_result_pb2
53-
from tensorflow_model_analysis.slicer import slicer_lib as slicer
5453
from tensorflow_model_analysis.view import view_types
5554

5655
from google.protobuf import text_format
@@ -230,14 +229,14 @@ def testRunModelAnalysisExtraFieldsPlusFeatureExtraction(self):
230229
eval_config = config.EvalConfig(slicing_specs=slicing_specs)
231230
eval_shared_model = model_eval_lib.default_eval_shared_model(
232231
eval_saved_model_path=model_location, example_weight_key='age')
233-
slice_spec = [slicer.SingleSliceSpec(spec=slicing_specs[0])]
234232
extractors_with_feature_extraction = [
235233
predict_extractor.PredictExtractor(
236234
eval_shared_model, desired_batch_size=3, materialize=False),
237235
feature_extractor.FeatureExtractor(
238236
extract_source=constants.INPUT_KEY,
239237
extract_dest=constants.FEATURES_PREDICTIONS_LABELS_KEY),
240-
slice_key_extractor.SliceKeyExtractor(slice_spec, materialize=False)
238+
slice_key_extractor.SliceKeyExtractor(
239+
eval_config=eval_config, materialize=False)
241240
]
242241
eval_result = model_eval_lib.run_model_analysis(
243242
eval_config=eval_config,
@@ -1171,10 +1170,13 @@ def testMultipleModelAnalysis(self):
11711170
self._makeExample(age=5.0, language='chinese', label=1.0)
11721171
]
11731172
data_location = self._writeTFExamplesToTFRecords(examples)
1173+
eval_config = config.EvalConfig(slicing_specs=[
1174+
config.SlicingSpec(feature_values={'language': 'english'})
1175+
])
11741176
eval_results = model_eval_lib.multiple_model_analysis(
11751177
[model_location_1, model_location_2],
11761178
data_location,
1177-
slice_spec=[slicer.SingleSliceSpec(features=[('language', 'english')])])
1179+
eval_config=eval_config)
11781180
# We only check some of the metrics to ensure that the end-to-end
11791181
# pipeline works.
11801182
self.assertLen(eval_results._results, 2)
@@ -1213,9 +1215,12 @@ def testMultipleDataAnalysis(self):
12131215
])
12141216
data_location_2 = self._writeTFExamplesToTFRecords(
12151217
[self._makeExample(age=4.0, language='english', label=1.0)])
1218+
eval_config = config.EvalConfig(slicing_specs=[
1219+
config.SlicingSpec(feature_values={'language': 'english'})
1220+
])
12161221
eval_results = model_eval_lib.multiple_data_analysis(
12171222
model_location, [data_location_1, data_location_2],
1218-
slice_spec=[slicer.SingleSliceSpec(features=[('language', 'english')])])
1223+
eval_config=eval_config)
12191224
self.assertLen(eval_results._results, 2)
12201225
# We only check some of the metrics to ensure that the end-to-end
12211226
# pipeline works.

0 commit comments

Comments
 (0)