Commit 8e78aa8

huanmingf authored and tf-model-analysis-team committed
Added default thresholds to the metrics.
PiperOrigin-RevId: 292575370
1 parent df00139 commit 8e78aa8

File tree

2 files changed: +188 -62 lines changed


tensorflow_model_analysis/addons/fairness/metrics/fairness_indicators.py

Lines changed: 12 additions & 3 deletions
@@ -37,12 +37,14 @@
     'negative_rate',
 )
 
+DEFAULT_THERSHOLDS = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
+
 
 class FairnessIndicators(metric_types.Metric):
   """Fairness indicators metrics."""
 
   def __init__(self,
-               thresholds: List[float],
+               thresholds: List[float] = DEFAULT_THERSHOLDS,
                name: Text = FAIRNESS_INDICATORS_METRICS_NAME):
     """Initializes fairness indicators metrics.
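The practical effect of this hunk is that FairnessIndicators can now be constructed with no arguments. A minimal usage sketch, not part of the commit itself, assuming TFMA is installed and exposes the fairness addon at the path shown in the diff:

# Usage sketch (illustrative): the module path is taken from the diff above.
from tensorflow_model_analysis.addons.fairness.metrics import fairness_indicators

# Before this commit, thresholds was a required argument. It now defaults
# to DEFAULT_THERSHOLDS = (0.1, 0.2, ..., 0.9).
metric = fairness_indicators.FairnessIndicators()
computations = metric.computations()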
@@ -57,6 +59,11 @@ def __init__(self,
         name=name)
 
 
+def calculate_digits(thresholds):
+  digits = [len(str(t)) - 2 for t in thresholds]
+  return max(max(digits), 1)
+
+
 def _fairness_indicators_metrics_at_thresholds(
     thresholds: List[float],
     name: Text = FAIRNESS_INDICATORS_METRICS_NAME,
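calculate_digits works off the string form of each threshold: len(str(t)) - 2 strips the leading '0.', leaving the number of decimal digits, and the max(..., 1) floor guarantees at least one decimal place. A standalone copy of the same logic with worked inputs (the assertions are illustrative, not from the commit):

def calculate_digits(thresholds):
  # str(0.333) == '0.333'; len - 2 drops the leading '0.' -> 3 digits.
  # Assumes thresholds whose str() form is '0.<digits>'.
  digits = [len(str(t)) - 2 for t in thresholds]
  return max(max(digits), 1)

assert calculate_digits([0.1, 0.22, 0.333]) == 3  # widest threshold wins
assert calculate_digits([0.5]) == 1               # floor of one digit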
@@ -69,11 +76,13 @@ def _fairness_indicators_metrics_at_thresholds(
   """Returns computations for fairness metrics at thresholds."""
   metric_key_by_name_by_threshold = collections.defaultdict(dict)
   keys = []
+  digits_num = calculate_digits(thresholds)
   for t in thresholds:
     for m in FAIRNESS_INDICATORS_SUB_METRICS:
       key = metric_types.MetricKey(
-          name='%s/%s@%s' %
-          (name, m, t),  # e.g. "fairness_indicators_metrics/positive_rate@0.5"
+          name='%s/%s@%.*f' %
+          (name, m, digits_num,
+           t),  # e.g. "fairness_indicators_metrics/positive_rate@0.5"
           model_name=model_name,
           output_name=output_name,
           sub_key=sub_key)
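The '%.*f' conversion consumes digits_num as the precision, so every threshold in a run is rendered with the same number of decimal places instead of whatever str() happens to produce. A quick illustration of just the formatting (names chosen to mirror the comment in the diff):

name, m, digits_num = 'fairness_indicators_metrics', 'positive_rate', 3
for t in [0.1, 0.22, 0.333]:
  print('%s/%s@%.*f' % (name, m, digits_num, t))
# fairness_indicators_metrics/positive_rate@0.100
# fairness_indicators_metrics/positive_rate@0.220
# fairness_indicators_metrics/positive_rate@0.333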

tensorflow_model_analysis/addons/fairness/metrics/fairness_indicators_test.py

Lines changed: 176 additions & 59 deletions
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 import math
+from absl.testing import parameterized
 import apache_beam as beam
 from apache_beam.testing import util
 import numpy as np
@@ -29,9 +30,10 @@
 from tensorflow_model_analysis.metrics import metric_util
 
 
-class FairnessIndicatorsTest(testutil.TensorflowModelAnalysisTest):
+class FairnessIndicatorsTest(testutil.TensorflowModelAnalysisTest,
+                             parameterized.TestCase):
 
-  def testFairessIndicatorsMetrics(self):
+  def testFairessIndicatorsMetricsGeneral(self):
     computations = fairness_indicators.FairnessIndicators(
         thresholds=[0.3, 0.7]).computations()
     histogram = computations[0]
@@ -79,76 +81,48 @@ def check_result(got):
         self.assertDictElementsAlmostEqual(
             got_metrics, {
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/false_positive_rate@0.3',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/false_positive_rate@0.3'
+                ):
                     0.5,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/false_negative_rate@0.3',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/false_negative_rate@0.3'
+                ):
                     0.0,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/true_positive_rate@0.3',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/true_positive_rate@0.3'
+                ):
                     1.0,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/true_negative_rate@0.3',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/true_negative_rate@0.3'
+                ):
                     0.5,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/positive_rate@0.3',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/positive_rate@0.3'):
                     0.75,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/negative_rate@0.3',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/negative_rate@0.3'):
                     0.25,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/false_positive_rate@0.7',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/false_positive_rate@0.7'
+                ):
                     0.0,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/false_negative_rate@0.7',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/false_negative_rate@0.7'
+                ):
                     0.5,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/true_positive_rate@0.7',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/true_positive_rate@0.7'
+                ):
                     0.5,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/true_negative_rate@0.7',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/true_negative_rate@0.7'
+                ):
                     1.0,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/positive_rate@0.7',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/positive_rate@0.7'):
                     0.25,
                 metric_types.MetricKey(
-                    name='fairness_indicators_metrics/negative_rate@0.7',
-                    model_name='',
-                    output_name='',
-                    sub_key=None):
+                    name='fairness_indicators_metrics/negative_rate@0.7'):
                     0.75
             })
       except AssertionError as err:
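The deletions in this hunk drop only arguments that already match MetricKey's defaults (the diff implies model_name='', output_name='', and sub_key=None are those defaults), so the shortened keys compare equal to the verbose ones. A sketch of that equivalence, under that assumption:

# Sketch assuming MetricKey defaults of model_name='', output_name='',
# sub_key=None, as the deleted lines imply.
from tensorflow_model_analysis.metrics import metric_types

short = metric_types.MetricKey(
    name='fairness_indicators_metrics/positive_rate@0.3')
verbose = metric_types.MetricKey(
    name='fairness_indicators_metrics/positive_rate@0.3',
    model_name='',
    output_name='',
    sub_key=None)
assert short == verbose  # holds if the defaults are as the diff implies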
@@ -184,7 +158,6 @@ def testFairessIndicatorsMetricsWithNanValue(self):
              lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
          | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))
      )  # pyformat: ignore
-
      # pylint: enable=no-value-for-parameter
 
     def check_result(got):
@@ -195,22 +168,166 @@ def check_result(got):
         self.assertLen(got_metrics, 6)  # 1 threshold * 6 metrics
         self.assertTrue(
             math.isnan(got_metrics[metric_types.MetricKey(
-                name='fairness_indicators_metrics/false_negative_rate@0.5',
-                model_name='',
-                output_name='',
-                sub_key=None)]))
+                name='fairness_indicators_metrics/false_negative_rate@0.5')]))
         self.assertTrue(
             math.isnan(got_metrics[metric_types.MetricKey(
-                name='fairness_indicators_metrics/true_positive_rate@0.5',
-                model_name='',
-                output_name='',
-                sub_key=None)]))
+                name='fairness_indicators_metrics/true_positive_rate@0.5')]))
+
+      except AssertionError as err:
+        raise util.BeamAssertException(err)
+
+    util.assert_that(result, check_result, label='result')
+
+  @parameterized.named_parameters(
+      ('_default_threshold', {}, 54, ()),
+      ('_thresholds_with_different_digits', {
+          'thresholds': [0.1, 0.22, 0.333]
+      }, 18,
+       (metric_types.MetricKey(
+           name='fairness_indicators_metrics/false_positive_rate@0.100'),
+        metric_types.MetricKey(
+            name='fairness_indicators_metrics/false_positive_rate@0.220'),
+        metric_types.MetricKey(
+            name='fairness_indicators_metrics/false_positive_rate@0.333'))))
+  def testFairessIndicatorsMetricsWithThresholds(self, kwargs,
+                                                 expected_metrics_nums,
+                                                 expected_metrics_keys):
+    computations = fairness_indicators.FairnessIndicators(
+        **kwargs).computations()
+    histogram = computations[0]
+    matrices = computations[1]
+    metrics = computations[2]
+    examples = [{
+        'labels': np.array([0.0]),
+        'predictions': np.array([0.1]),
+        'example_weights': np.array([1.0]),
+    }, {
+        'labels': np.array([0.0]),
+        'predictions': np.array([0.7]),
+        'example_weights': np.array([3.0]),
+    }]
+
+    with beam.Pipeline() as pipeline:
+      # pylint: disable=no-value-for-parameter
+      result = (
+          pipeline
+          | 'Create' >> beam.Create(examples)
+          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
+          | 'AddSlice' >> beam.Map(lambda x: ((), x))
+          | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
+          | 'ComputeMatrices' >> beam.Map(
+              lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
+          | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))
+      )  # pyformat: ignore
+
+      # pylint: enable=no-value-for-parameter
+
+    def check_result(got):
+      try:
+        self.assertLen(got, 1)
+        got_slice_key, got_metrics = got[0]
+        self.assertEqual(got_slice_key, ())
+        self.assertLen(got_metrics, expected_metrics_nums)
+        for metrics_key in expected_metrics_keys:
+          self.assertIn(metrics_key, got_metrics)
+      except AssertionError as err:
+        raise util.BeamAssertException(err)
+
+    util.assert_that(result, check_result, label='result')
 
+  @parameterized.named_parameters(('_has_weight', [{
+      'labels': np.array([0.0]),
+      'predictions': np.array([0.1]),
+      'example_weights': np.array([1.0]),
+  }, {
+      'labels': np.array([0.0]),
+      'predictions': np.array([0.7]),
+      'example_weights': np.array([3.0]),
+  }], {}, {
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/negative_rate@0.5'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/positive_rate@0.5'):
+          0.75,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/true_negative_rate@0.5'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/false_positive_rate@0.5'):
+          0.75
+  }), ('_has_model_name', [{
+      'labels': np.array([0.0]),
+      'predictions': {
+          'model1': np.array([0.1]),
+      },
+      'example_weights': np.array([1.0]),
+  }, {
+      'labels': np.array([0.0]),
+      'predictions': {
+          'model1': np.array([0.7]),
+      },
+      'example_weights': np.array([3.0]),
+  }], {
+      'model_names': ['model1']
+  }, {
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/negative_rate@0.5',
+          model_name='model1'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/positive_rate@0.5',
+          model_name='model1'):
+          0.75,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/true_negative_rate@0.5',
+          model_name='model1'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/false_positive_rate@0.5',
+          model_name='model1'):
+          0.75
+  }))
+  def testFairessIndicatorsMetricsWithInput(self, input_examples,
+                                            computations_kwargs,
+                                            expected_result):
+    computations = fairness_indicators.FairnessIndicators(
+        thresholds=[0.5]).computations(**computations_kwargs)
+    histogram = computations[0]
+    matrices = computations[1]
+    metrics = computations[2]
+
+    with beam.Pipeline() as pipeline:
+      # pylint: disable=no-value-for-parameter
+      result = (
+          pipeline
+          | 'Create' >> beam.Create(input_examples)
+          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
+          | 'AddSlice' >> beam.Map(lambda x: ((), x))
+          | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
+          | 'ComputeMatrices' >> beam.Map(
+              lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
+          | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))
+      )  # pyformat: ignore
+
+      # pylint: enable=no-value-for-parameter
+
+    def check_result(got):
+      try:
+        self.assertLen(got, 1)
+        got_slice_key, got_metrics = got[0]
+        self.assertEqual(got_slice_key, ())
+        self.assertLen(got_metrics, 6)  # 1 threshold * 6 metrics
+        for metrics_key in expected_result:
+          self.assertEqual(got_metrics[metrics_key],
+                           expected_result[metrics_key])
       except AssertionError as err:
         raise util.BeamAssertException(err)
 
     util.assert_that(result, check_result, label='result')
 
 
+# Todo(b/147497357): Add counter test once we have counter setup.
+
 if __name__ == '__main__':
   tf.test.main()
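The expected_metrics_nums values in the parameterized test above follow directly from thresholds times sub-metrics. A sanity-check sketch of that arithmetic (the count of six comes from the sub-metric tuple in the first file):

num_sub_metrics = 6  # false/true positive/negative rates, positive/negative rate
assert 9 * num_sub_metrics == 54  # default: nine thresholds 0.1 .. 0.9
assert 3 * num_sub_metrics == 18  # thresholds [0.1, 0.22, 0.333]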
