Skip to content

Commit bc32ec1

Browse files
mdrevestf-model-analysis-team
authored and
tf-model-analysis-team
committed
Updated pipeline to validate the config at the beginning of the pipeline. Fix typo in plot name. Release notes formatting.
PiperOrigin-RevId: 290375740
1 parent cd18c1e commit bc32ec1

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

RELEASE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
- Adding an option to "Select all" metrics in UI.
4949

5050
## Breaking changes
51+
5152
* Updated proto config to remove input/output data specs in favor of passing
5253
them directly to the run_eval.
5354

tensorflow_model_analysis/api/model_eval_lib.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,8 @@ def ExtractEvaluateAndWriteResults( # pylint: disable=invalid-name
658658
desired_batch_size: Optional batch size for batching in Predict.
659659
660660
Raises:
661-
ValueError: If matching Extractor not found for an Evaluator.
661+
ValueError: If EvalConfig invalid or matching Extractor not found for an
662+
Evaluator.
662663
663664
Returns:
664665
PDone.
@@ -691,6 +692,16 @@ def ExtractEvaluateAndWriteResults( # pylint: disable=invalid-name
691692
eval_config = config.EvalConfig(
692693
model_specs=model_specs, slicing_specs=slicing_specs, options=options)
693694

695+
# Add default ModelSpec if empty.
696+
if (eval_shared_models and len(eval_shared_models) == 1 and
697+
not eval_config.model_specs):
698+
tmp_config = config.EvalConfig()
699+
tmp_config.CopyFrom(eval_config)
700+
eval_config = tmp_config
701+
eval_config.model_specs.add()
702+
703+
config.verify_eval_config(eval_config)
704+
694705
if not extractors:
695706
extractors = default_extractors(
696707
eval_config=eval_config,

tensorflow_model_analysis/metrics/metric_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def default_binary_classification_specs(
261261
calibration.MeanLabel(name='mean_label'),
262262
calibration.MeanPrediction(name='mean_prediction'),
263263
calibration.Calibration(name='calibration'),
264-
confusion_matrix_plot.AUCPlot(name='confusion_matrix_plot'),
264+
confusion_matrix_plot.ConfusionMatrixPlot(name='confusion_matrix_plot'),
265265
calibration_plot.CalibrationPlot(name='calibration_plot')
266266
]
267267
if include_loss:

0 commit comments

Comments
 (0)