Skip to content

Commit 355a83c

Browse files
author
tf-model-analysis-team
committed
CrossSliceMetricThreshold updates for config.
PiperOrigin-RevId: 317525264
1 parent ff51128 commit 355a83c

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

tensorflow_model_analysis/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
# Define types here to avoid type errors between OSS and internal code.
2929
ModelSpec = config_pb2.ModelSpec
3030
SlicingSpec = config_pb2.SlicingSpec
31+
CrossSlicingSpec = config_pb2.CrossSlicingSpec
3132
BinarizationOptions = config_pb2.BinarizationOptions
3233
ConfidenceIntervalOptions = config_pb2.ConfidenceIntervalOptions
3334
AggregationOptions = config_pb2.AggregationOptions
@@ -40,6 +41,8 @@
4041
Options = config_pb2.Options
4142
PerSliceMetricThreshold = config_pb2.PerSliceMetricThreshold
4243
PerSliceMetricThresholds = config_pb2.PerSliceMetricThresholds
44+
CrossSliceMetricThreshold = config_pb2.CrossSliceMetricThreshold
45+
CrossSliceMetricThresholds = config_pb2.CrossSliceMetricThresholds
4346
EvalConfig = config_pb2.EvalConfig
4447

4548

tensorflow_model_analysis/proto/config.proto

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ message SlicingSpec {
8989
map<string, string> feature_values = 2;
9090
}
9191

92+
// Cross slicing specification.
93+
message CrossSlicingSpec {
94+
SlicingSpec baseline_spec = 1;
95+
repeated SlicingSpec slicing_specs = 2;
96+
}
97+
9298
// Options for aggregating multi-class / multi-label outputs.
9399
//
94100
// When used the associated MetricSpec metrics must be binary classification
@@ -211,6 +217,17 @@ message PerSliceMetricThresholds {
211217
repeated PerSliceMetricThreshold thresholds = 1;
212218
}
213219

220+
// Cross slice metric threshold.
221+
message CrossSliceMetricThreshold {
222+
// A list of cross slicing specs to apply threshold to.
223+
repeated CrossSlicingSpec cross_slicing_specs = 1;
224+
MetricThreshold threshold = 2;
225+
}
226+
227+
message CrossSliceMetricThresholds {
228+
repeated CrossSliceMetricThreshold thresholds = 1;
229+
}
230+
214231
// Metric configuration.
215232
message MetricConfig {
216233
// Name of a class derived for either tf.keras.metrics.Metric or
@@ -231,6 +248,8 @@ message MetricConfig {
231248
MetricThreshold threshold = 4;
232249
// Optional thresholds for model validation using specific slices.
233250
repeated PerSliceMetricThreshold per_slice_thresholds = 5;
251+
// Optional thresholds for model validation across slices.
252+
repeated CrossSliceMetricThreshold cross_slice_thresholds = 6;
234253
}
235254

236255
// Metrics specification.
@@ -266,6 +285,9 @@ message MetricsSpec {
266285
// Optional thresholds for model validation using specific slices (keyed by
267286
// the associated metric name - e.g. 'auc', etc).
268287
map<string, PerSliceMetricThresholds> per_slice_thresholds = 8;
288+
// Optional thresholds for model validation across slices (keyed by the
289+
// associated metric name - e.g. 'auc', etc).
290+
map<string, CrossSliceMetricThresholds> cross_slice_thresholds = 9;
269291
}
270292

271293
// Additional configuration options.
@@ -319,6 +341,11 @@ message EvalConfig {
319341
// Slices for all values in feature "country" crossed with value
320342
// "age:20".
321343
repeated SlicingSpec slicing_specs = 4;
344+
// A list of cross slicing specs where each spec represents a pair of slices
345+
// whose associated outputs should be compared. By default slices will be
346+
// created for both slicing_spec and baseline_spec if they do not already
347+
// exist in slicing_specs.
348+
repeated CrossSlicingSpec cross_slicing_specs = 8;
322349
// Metrics specifications.
323350
repeated MetricsSpec metrics_specs = 5;
324351
// Additional configuration options.

0 commit comments

Comments
 (0)