Commit 81bad33

mdreves authored and tf-model-analysis-team committed

Wrap binarize and disabled_outputs options into an additional proto message
so that we can better add default values when no value is given.

PiperOrigin-RevId: 292483034

1 parent 0bfe88e commit 81bad33
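The rationale: proto3 gives presence semantics to singular message fields but not to repeated fields, so wrapping the repeated values in a message lets the library distinguish "user supplied an empty list" from "user supplied nothing" and apply a default only in the latter case. A minimal sketch of that distinction, assuming the `config_pb2` module generated from this commit's `config.proto`:

```python
from tensorflow_model_analysis.proto import config_pb2

options = config_pb2.Options()

# A bare `repeated string disabled_outputs` has no notion of "unset":
# an empty list and a never-touched field are indistinguishable. With
# the RepeatedStringValue wrapper, presence is explicit via HasField.
if options.HasField('disabled_outputs'):
  disabled = list(options.disabled_outputs.values)
else:
  disabled = ['hypothetical_default_output']  # illustrative default only
```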

File tree
9 files changed: +56 -35 lines changed

RELEASE.md
Lines changed: 4 additions & 0 deletions

@@ -12,6 +12,10 @@
 
 ## Breaking changes
 
+*   `tfma.BinarizationOptions.class_ids`, `tfma.BinarizationOptions.k_list`,
+    `tfma.BinarizationOptions.top_k_list`, and `tfma.Options.disabled_outputs`
+    are now wrapped in an additional proto message.
+
 ## Deprecations
 
 # Release 0.21.0
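For upgrading users the breaking change is one extra nesting level. A minimal migration sketch (the dict form mirrors the updated tests below; the explicit form is the standard protobuf Python API):

```python
import tensorflow_model_analysis as tfma

# Before this commit:
#   binarize = tfma.BinarizationOptions(class_ids=[0, 1, 2])

# After: the ids live under the wrapper's `values` field. The protobuf
# Python constructor accepts a dict for the nested message:
binarize = tfma.BinarizationOptions(class_ids={'values': [0, 1, 2]})

# Equivalent imperative form:
binarize = tfma.BinarizationOptions()
binarize.class_ids.values.extend([0, 1, 2])
```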

g3doc/metrics.md
Lines changed: 9 additions & 6 deletions

@@ -192,7 +192,7 @@ from google.protobuf import text_format
 
 metrics_specs = text_format.Parse("""
 metrics_specs {
-  binarize: { class_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] }
+  binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
   // Metrics to binarize
   metrics { class_name: "AUC" }
   ...

@@ -209,7 +209,8 @@ metrics = [
   ...
 ]
 metrics_specs = tfma.metrics.specs_from_metrics(
-    metrics, binarize=tfma.BinarizationOptions(class_ids=[0,1,2,3,4,5,6,7,8,9]))
+    metrics, binarize=tfma.BinarizationOptions(
+        class_ids={'values': [0,1,2,3,4,5,6,7,8,9]}))
 ```
 
 ### Multi-class/Multi-label Aggregate Metrics

@@ -259,7 +260,7 @@ from google.protobuf import text_format
 
 metrics_specs = text_format.Parse("""
 metrics_specs {
-  binarize: { class_ids: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] }
+  binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
   aggregate: { macro_average: true }
   // Metrics to both binarize and aggregate
   metrics { class_name: "AUC" }

@@ -278,7 +279,8 @@ metrics = [
 ]
 metrics_specs = tfma.metrics.specs_from_metrics(
     metrics,
-    binarize=tfma.BinarizationOptions(class_ids=[0,1,2,3,4,5,6,7,8,9]),
+    binarize=tfma.BinarizationOptions(
+        class_ids={'values': [0,1,2,3,4,5,6,7,8,9]}),
     aggregate=tfma.AggregationOptions(macro_average=True))
 ```
 

@@ -293,7 +295,7 @@ from google.protobuf import text_format
 metrics_specs = text_format.Parse("""
 metrics_specs {
   query_key: "doc_id"
-  binarize { top_k: [1, 2] }
+  binarize { top_k_list: { values: [1, 2] } }
   metrics { class_name: "NDCG" config: '"gain_key": "gain"' }
 }
 metrics_specs {

@@ -310,7 +312,8 @@ metrics = [
   tfma.metrics.NDCG(name='ndcg', gain_key='gain'),
 ]
 metrics_specs = tfma.metrics.specs_from_metrics(
-    metrics, query_key='doc_id', binarize=tfma.BinarizationOptions(top_k=[1,2]))
+    metrics, query_key='doc_id', binarize=tfma.BinarizationOptions(
+        top_k_list={'values': [1,2]}))
 
 metrics = [
   tfma.metrics.MinLabelPosition(name='min_label_position')
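A quick sanity check of the updated text_format syntax shown above, as a sketch against the generated `config_pb2` module:

```python
from google.protobuf import text_format

from tensorflow_model_analysis.proto import config_pb2

spec = text_format.Parse(
    """
    binarize: { class_ids: { values: [0, 1, 2] } }
    metrics { class_name: "AUC" }
    """, config_pb2.MetricsSpec())
assert list(spec.binarize.class_ids.values) == [0, 1, 2]
```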

tensorflow_model_analysis/api/model_eval_lib.py
Lines changed: 7 additions & 7 deletions

@@ -363,7 +363,6 @@ def default_evaluators(  # pylint: disable=invalid-name
     desired_batch_size: Optional[int] = None,
     serialize: bool = False,
     random_seed_for_testing: Optional[int] = None) -> List[evaluator.Evaluator]:
-
   """Returns the default evaluators for use in ExtractAndEvaluate.
 
   Args:

@@ -378,8 +377,8 @@ def default_evaluators(  # pylint: disable=invalid-name
     random_seed_for_testing: Provide for deterministic tests only.
   """
   disabled_outputs = []
-  if eval_config and eval_config.options:
-    disabled_outputs = eval_config.options.disabled_outputs
+  if eval_config:
+    disabled_outputs = eval_config.options.disabled_outputs.values
   if (constants.METRICS_KEY in disabled_outputs and
       constants.PLOTS_KEY in disabled_outputs):
     return []

@@ -702,7 +701,7 @@ def ExtractEvaluateAndWriteResults(  # pylint: disable=invalid-name
   options.compute_confidence_intervals.value = compute_confidence_intervals
   options.k_anonymization_count.value = k_anonymization_count
   if not write_config:
-    options.disabled_outputs.append(_EVAL_CONFIG_FILE)
+    options.disabled_outputs.values.append(_EVAL_CONFIG_FILE)
   eval_config = config.EvalConfig(
       model_specs=model_specs, slicing_specs=slicing_specs, options=options)
 

@@ -744,7 +743,7 @@ def ExtractEvaluateAndWriteResults(  # pylint: disable=invalid-name
             extractors=extractors, evaluators=evaluators)
         | 'WriteResults' >> WriteResults(writers=writers))
 
-  if _EVAL_CONFIG_FILE not in eval_config.options.disabled_outputs:
+  if _EVAL_CONFIG_FILE not in eval_config.options.disabled_outputs.values:
     data_location = '<user provided PCollection>'
     if display_only_data_location is not None:
       data_location = display_only_data_location

@@ -781,7 +780,8 @@ def run_model_analysis(
     compute_confidence_intervals: Optional[bool] = False,
     k_anonymization_count: int = 1,
     desired_batch_size: Optional[int] = None,
-    random_seed_for_testing: Optional[int] = None) -> Union[EvalResult, EvalResults]:
+    random_seed_for_testing: Optional[int] = None
+) -> Union[EvalResult, EvalResults]:
   """Runs TensorFlow model analysis.
 
   It runs a Beam pipeline to compute the slicing metrics exported in TensorFlow

@@ -856,7 +856,7 @@ def run_model_analysis(
   options.compute_confidence_intervals.value = compute_confidence_intervals
   options.k_anonymization_count.value = k_anonymization_count
   if not write_config:
-    options.disabled_outputs.append(_EVAL_CONFIG_FILE)
+    options.disabled_outputs.values.append(_EVAL_CONFIG_FILE)
   eval_config = config.EvalConfig(
       model_specs=model_specs, slicing_specs=slicing_specs, options=options)
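Note the simplified guard in `default_evaluators`: `eval_config.options` no longer needs its own truthiness check because reading an unset singular message field returns an immutable default instance, and the wrapper's `values` on that default is simply empty. A sketch of the semantics relied on:

```python
from tensorflow_model_analysis.proto import config_pb2

eval_config = config_pb2.EvalConfig()  # `options` never set

# Reading through unset message fields is safe and yields defaults:
assert not eval_config.HasField('options')
assert list(eval_config.options.disabled_outputs.values) == []
```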

tensorflow_model_analysis/api/model_eval_lib_test.py
Lines changed: 2 additions & 2 deletions

@@ -402,7 +402,7 @@ def testRunModelAnalysisWithKerasModel(self):
           config.MetricConfig(
               class_name=cfg['class_name'], config=json.dumps(cfg['config'])))
     for class_id in (0, 5, 9):
-      metrics_spec.binarize.class_ids.append(class_id)
+      metrics_spec.binarize.class_ids.values.append(class_id)
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec(label_key='label')],
         metrics_specs=[metrics_spec])

@@ -470,7 +470,7 @@ def testRunModelAnalysisWithQueryBasedMetrics(self):
         slicing_specs=slicing_specs,
         metrics_specs=metric_specs.specs_from_metrics(
             [ndcg.NDCG(gain_key='age', name='ndcg')],
-            binarize=config.BinarizationOptions(top_k_list=[1]),
+            binarize=config.BinarizationOptions(top_k_list={'values': [1]}),
             query_key='language'))
     eval_shared_model = model_eval_lib.default_eval_shared_model(
         eval_saved_model_path=model_location, tags=[tf.saved_model.SERVING])

tensorflow_model_analysis/evaluators/metrics_and_plots_evaluator_v2_test.py
Lines changed: 3 additions & 2 deletions

@@ -598,7 +598,8 @@ def testEvaluateWithMultiClassModel(self):
         slicing_specs=[config.SlicingSpec()],
         metrics_specs=metric_specs.specs_from_metrics(
             [calibration.MeanLabel('mean_label')],
-            binarize=config.BinarizationOptions(class_ids=range(n_classes))))
+            binarize=config.BinarizationOptions(
+                class_ids={'values': range(n_classes)})))
     eval_shared_model = self.createTestEvalSharedModel(
         eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

@@ -898,7 +899,7 @@ def testEvaluateWithQueryBasedMetrics(self):
         ],
         metrics_specs=metric_specs.specs_from_metrics(
             [ndcg.NDCG(gain_key='fixed_float', name='ndcg')],
-            binarize=config.BinarizationOptions(top_k_list=[1, 2]),
+            binarize=config.BinarizationOptions(top_k_list={'values': [1, 2]}),
             query_key='fixed_string'))
     eval_shared_model = self.createTestEvalSharedModel(
         eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

tensorflow_model_analysis/metrics/metric_specs.py
Lines changed: 8 additions & 8 deletions

@@ -305,13 +305,13 @@ def default_multi_class_classification_specs(
     metrics.append(
         multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot())
   if binarize is not None:
-    for top_k in binarize.top_k_list:
+    for top_k in binarize.top_k_list.values:
       metrics.extend([
           tf.keras.metrics.Precision(name='precision', top_k=top_k),
          tf.keras.metrics.Recall(name='recall', top_k=top_k)
      ])
     binarize = config.BinarizationOptions().CopyFrom(binarize)
-    binarize.ClearField('top_k')
+    binarize.ClearField('top_k_list')
   multi_class_metrics = specs_from_metrics(
       metrics, model_names=model_names, output_names=output_names)
   if aggregate is None:

@@ -526,14 +526,14 @@ def _create_sub_keys(
   sub_keys = None
   if spec.HasField('binarize'):
     sub_keys = []
-    if spec.binarize.class_ids:
-      for v in spec.binarize.class_ids:
+    if spec.binarize.class_ids.values:
+      for v in spec.binarize.class_ids.values:
         sub_keys.append(metric_types.SubKey(class_id=v))
-    if spec.binarize.k_list:
-      for v in spec.binarize.k_list:
+    if spec.binarize.k_list.values:
+      for v in spec.binarize.k_list.values:
         sub_keys.append(metric_types.SubKey(k=v))
-    if spec.binarize.top_k_list:
-      for v in spec.binarize.top_k_list:
+    if spec.binarize.top_k_list.values:
+      for v in spec.binarize.top_k_list.values:
         sub_keys.append(metric_types.SubKey(top_k=v))
   if spec.aggregate.micro_average:
     # Micro averaging is performed by flattening the labels and predictions
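The `_create_sub_keys` guards keep their old behavior: an untouched wrapper reads back with an empty (falsy) `values` sequence, so only populated lists produce sub-keys. A small sketch with a hypothetical spec:

```python
from tensorflow_model_analysis.proto import config_pb2

spec = config_pb2.MetricsSpec()
spec.binarize.class_ids.values.extend([0, 1])

# Only the populated list is truthy; the other wrappers stay empty:
assert list(spec.binarize.class_ids.values) == [0, 1]
assert not spec.binarize.k_list.values
assert not spec.binarize.top_k_list.values
```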

tensorflow_model_analysis/metrics/metric_specs_test.py
Lines changed: 4 additions & 4 deletions

@@ -42,7 +42,7 @@ def testSpecsFromMetrics(self):
             ]
         },
         model_names=['model_name1', 'model_name2'],
-        binarize=config.BinarizationOptions(class_ids=[0, 1]),
+        binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
         aggregate=config.AggregationOptions(macro_average=True))
 
     self.assertLen(metrics_specs, 5)

@@ -80,7 +80,7 @@ def testSpecsFromMetrics(self):
             ],
             model_names=['model_name1', 'model_name2'],
             output_names=['output_name1'],
-            binarize=config.BinarizationOptions(class_ids=[0, 1]),
+            binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
             aggregate=config.AggregationOptions(macro_average=True)))
     self.assertProtoEquals(
         metrics_specs[3],

@@ -109,7 +109,7 @@ def testSpecsFromMetrics(self):
             ],
             model_names=['model_name1', 'model_name2'],
             output_names=['output_name2'],
-            binarize=config.BinarizationOptions(class_ids=[0, 1]),
+            binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
             aggregate=config.AggregationOptions(macro_average=True)))
 
   def testToComputations(self):

@@ -122,7 +122,7 @@ def testToComputations(self):
             ]
         },
         model_names=['model_name'],
-        binarize=config.BinarizationOptions(class_ids=[0, 1]),
+        binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
         aggregate=config.AggregationOptions(macro_average=True)),
         config.EvalConfig())
 
tensorflow_model_analysis/proto/config.proto
Lines changed: 17 additions & 5 deletions

@@ -119,17 +119,19 @@ message AggregationOptions {
 message BinarizationOptions {
   // Creates binary classification metrics based on one-vs-rest for each
   // value of class_id provided.
-  repeated int32 class_ids = 1;
+  RepeatedInt32Value class_ids = 4;
   // Creates binary classification metrics based on the kth predicted value
   // for each value of k provided.
-  repeated int32 k_list = 2;
+  RepeatedInt32Value k_list = 5;
   // Creates binary classification metrics based on the top k predicted values
   // for each value of top_k provided. When used to create calibration plots
   // the histogram will contain a mix of all labels and predictions in the top
   // k predictions. Note that precision@k and recall@k can also be configured
   // directly as multi-class classification metrics by setting top_k on the
   // metric itself.
-  repeated int32 top_k_list = 3;
+  RepeatedInt32Value top_k_list = 6;
+
+  reserved 1, 2, 3;
 }
 
 // Metric configuration.

@@ -183,9 +185,9 @@ message Options {
   google.protobuf.Int32Value k_anonymization_count = 3;
   // List of outputs that should not be written (e.g. 'metrics', 'plots',
   // 'analysis', 'eval_config.json').
-  repeated string disabled_outputs = 6;
+  RepeatedStringValue disabled_outputs = 7;
 
-  reserved 4, 5;
+  reserved 4, 5, 6;
 }
 
 // Tensorflow model analysis config settings.

@@ -219,6 +221,16 @@ message EvalConfig {
   reserved 1, 3, 7;
 }
 
+// Repeated string value. Used to allow a default if no values are given.
+message RepeatedStringValue {
+  repeated string values = 1;
+}
+
+// Repeated int32 value. Used to allow a default if no values are given.
+message RepeatedInt32Value {
+  repeated int32 values = 1;
+}
+
 // Config and version.
 message EvalConfigAndVersion {
   EvalConfig eval_config = 1;
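Reserving the old tags (1, 2 and 3 here, and 6 in `Options`) rather than reusing them keeps old serialized configs from being misread: they still parse, but the legacy values land in unknown fields instead of being reinterpreted as the new wrapper messages. A sketch of that behavior, with hand-written wire bytes for illustration only:

```python
from tensorflow_model_analysis.proto import config_pb2

# Hypothetical payload from a pre-change BinarizationOptions that set
# `repeated int32 class_ids = 1` (field 1, varint wire type, values 0, 1).
legacy_bytes = b'\x08\x00\x08\x01'

opts = config_pb2.BinarizationOptions.FromString(legacy_bytes)
# Parsing succeeds, but the legacy data is not migrated into the new
# wrapper field (tag 4); it is preserved only as an unknown field.
assert not opts.HasField('class_ids')
```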

tensorflow_model_analysis/writers/metrics_and_plots_writer_test.py
Lines changed: 2 additions & 1 deletion

@@ -59,7 +59,8 @@ def testWriteMetricsAndPlots(self):
                                                  None, temp_eval_export_dir))
     eval_config = config.EvalConfig(
         model_specs=[config.ModelSpec()],
-        options=config.Options(disabled_outputs=['eval_config.json']))
+        options=config.Options(
+            disabled_outputs={'values': ['eval_config.json']}))
     eval_shared_model = self.createTestEvalSharedModel(
         eval_saved_model_path=eval_export_dir,
         add_metrics_callbacks=[
