from __future__ import print_function

import math
+from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import util
import numpy as np
from tensorflow_model_analysis.metrics import metric_util


-class FairnessIndicatorsTest(testutil.TensorflowModelAnalysisTest):
+class FairnessIndicatorsTest(testutil.TensorflowModelAnalysisTest,
+                             parameterized.TestCase):

-  def testFairessIndicatorsMetrics(self):
+  def testFairessIndicatorsMetricsGeneral(self):
    computations = fairness_indicators.FairnessIndicators(
        thresholds=[0.3, 0.7]).computations()
    histogram = computations[0]
@@ -79,76 +81,48 @@ def check_result(got):
          self.assertDictElementsAlmostEqual(
              got_metrics, {
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/false_positive_rate@0.3',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/false_positive_rate@0.3'
+                 ):
                      0.5,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/false_negative_rate@0.3',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/false_negative_rate@0.3'
+                 ):
                      0.0,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/true_positive_rate@0.3',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/true_positive_rate@0.3'
+                 ):
                      1.0,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/true_negative_rate@0.3',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/true_negative_rate@0.3'
+                 ):
                      0.5,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/positive_rate@0.3',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/positive_rate@0.3'):
                      0.75,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/negative_rate@0.3',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/negative_rate@0.3'):
                      0.25,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/false_positive_rate@0.7',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/false_positive_rate@0.7'
+                 ):
                      0.0,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/false_negative_rate@0.7',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/false_negative_rate@0.7'
+                 ):
                      0.5,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/true_positive_rate@0.7',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/true_positive_rate@0.7'
+                 ):
                      0.5,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/true_negative_rate@0.7',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/true_negative_rate@0.7'
+                 ):
                      1.0,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/positive_rate@0.7',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/positive_rate@0.7'):
                      0.25,
                  metric_types.MetricKey(
-                     name='fairness_indicators_metrics/negative_rate@0.7',
-                     model_name='',
-                     output_name='',
-                     sub_key=None):
+                     name='fairness_indicators_metrics/negative_rate@0.7'):
                      0.75
              })
        except AssertionError as err:
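# Note (illustrative sketch, not part of this change): the simplification in
# the hunk above relies on MetricKey's keyword defaults. As this diff implies,
# model_name and output_name default to '' and sub_key defaults to None, so a
# key built with the explicit defaults equals one built from the name alone.
from tensorflow_model_analysis.metrics import metric_types

explicit_key = metric_types.MetricKey(
    name='fairness_indicators_metrics/false_positive_rate@0.3',
    model_name='',
    output_name='',
    sub_key=None)
implicit_key = metric_types.MetricKey(
    name='fairness_indicators_metrics/false_positive_rate@0.3')
assert explicit_key == implicit_key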
@@ -184,7 +158,6 @@ def testFairessIndicatorsMetricsWithNanValue(self):
              lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
          | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))
      )  # pyformat: ignore
-
      # pylint: enable=no-value-for-parameter

      def check_result(got):
@@ -195,22 +168,166 @@ def check_result(got):
          self.assertLen(got_metrics, 6)  # 1 threshold * 6 metrics
          self.assertTrue(
              math.isnan(got_metrics[metric_types.MetricKey(
-                 name='fairness_indicators_metrics/false_negative_rate@0.5',
-                 model_name='',
-                 output_name='',
-                 sub_key=None)]))
+                 name='fairness_indicators_metrics/false_negative_rate@0.5')]))
          self.assertTrue(
              math.isnan(got_metrics[metric_types.MetricKey(
-                 name='fairness_indicators_metrics/true_positive_rate@0.5',
-                 model_name='',
-                 output_name='',
-                 sub_key=None)]))
+                 name='fairness_indicators_metrics/true_positive_rate@0.5')]))
+
+        except AssertionError as err:
+          raise util.BeamAssertException(err)
+
+      util.assert_that(result, check_result, label='result')
+
+  @parameterized.named_parameters(
+      ('_default_threshold', {}, 54, ()),
+      ('_thresholds_with_different_digits', {
+          'thresholds': [0.1, 0.22, 0.333]
+      }, 18,
+       (metric_types.MetricKey(
+           name='fairness_indicators_metrics/false_positive_rate@0.100'),
+        metric_types.MetricKey(
+            name='fairness_indicators_metrics/false_positive_rate@0.220'),
+        metric_types.MetricKey(
+            name='fairness_indicators_metrics/false_positive_rate@0.333'))))
+  def testFairessIndicatorsMetricsWithThresholds(self, kwargs,
+                                                 expected_metrics_nums,
+                                                 expected_metrics_keys):
+    computations = fairness_indicators.FairnessIndicators(
+        **kwargs).computations()
+    histogram = computations[0]
+    matrices = computations[1]
+    metrics = computations[2]
+    examples = [{
+        'labels': np.array([0.0]),
+        'predictions': np.array([0.1]),
+        'example_weights': np.array([1.0]),
+    }, {
+        'labels': np.array([0.0]),
+        'predictions': np.array([0.7]),
+        'example_weights': np.array([3.0]),
+    }]
+
+    with beam.Pipeline() as pipeline:
+      # pylint: disable=no-value-for-parameter
+      result = (
+          pipeline
+          | 'Create' >> beam.Create(examples)
+          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
+          | 'AddSlice' >> beam.Map(lambda x: ((), x))
+          | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
+          | 'ComputeMatrices' >> beam.Map(
+              lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
+          | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))
+      )  # pyformat: ignore
+
+      # pylint: enable=no-value-for-parameter
+
+      def check_result(got):
+        try:
+          self.assertLen(got, 1)
+          got_slice_key, got_metrics = got[0]
+          self.assertEqual(got_slice_key, ())
+          self.assertLen(got_metrics, expected_metrics_nums)
+          for metrics_key in expected_metrics_keys:
+            self.assertIn(metrics_key, got_metrics)
+        except AssertionError as err:
+          raise util.BeamAssertException(err)
+
+      util.assert_that(result, check_result, label='result')

+  @parameterized.named_parameters(('_has_weight', [{
+      'labels': np.array([0.0]),
+      'predictions': np.array([0.1]),
+      'example_weights': np.array([1.0]),
+  }, {
+      'labels': np.array([0.0]),
+      'predictions': np.array([0.7]),
+      'example_weights': np.array([3.0]),
+  }], {}, {
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/negative_rate@0.5'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/positive_rate@0.5'):
+          0.75,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/true_negative_rate@0.5'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/false_positive_rate@0.5'):
+          0.75
+  }), ('_has_model_name', [{
+      'labels': np.array([0.0]),
+      'predictions': {
+          'model1': np.array([0.1]),
+      },
+      'example_weights': np.array([1.0]),
+  }, {
+      'labels': np.array([0.0]),
+      'predictions': {
+          'model1': np.array([0.7]),
+      },
+      'example_weights': np.array([3.0]),
+  }], {
+      'model_names': ['model1']
+  }, {
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/negative_rate@0.5',
+          model_name='model1'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/positive_rate@0.5',
+          model_name='model1'):
+          0.75,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/true_negative_rate@0.5',
+          model_name='model1'):
+          0.25,
+      metric_types.MetricKey(
+          name='fairness_indicators_metrics/false_positive_rate@0.5',
+          model_name='model1'):
+          0.75
+  }))
+  def testFairessIndicatorsMetricsWithInput(self, input_examples,
+                                            computations_kwargs,
+                                            expected_result):
+    computations = fairness_indicators.FairnessIndicators(
+        thresholds=[0.5]).computations(**computations_kwargs)
+    histogram = computations[0]
+    matrices = computations[1]
+    metrics = computations[2]
+
+    with beam.Pipeline() as pipeline:
+      # pylint: disable=no-value-for-parameter
+      result = (
+          pipeline
+          | 'Create' >> beam.Create(input_examples)
+          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
+          | 'AddSlice' >> beam.Map(lambda x: ((), x))
+          | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
+          | 'ComputeMatrices' >> beam.Map(
+              lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
+          | 'ComputeMetrics' >> beam.Map(lambda x: (x[0], metrics.result(x[1])))
+      )  # pyformat: ignore
+
+      # pylint: enable=no-value-for-parameter
+
+      def check_result(got):
+        try:
+          self.assertLen(got, 1)
+          got_slice_key, got_metrics = got[0]
+          self.assertEqual(got_slice_key, ())
+          self.assertLen(got_metrics, 6)  # 1 threshold * 6 metrics
+          for metrics_key in expected_result:
+            self.assertEqual(got_metrics[metrics_key],
+                             expected_result[metrics_key])
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')


+# Todo(b/147497357): Add counter test once we have counter setup.
+
if __name__ == '__main__':
  tf.test.main()
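# Note (minimal standalone sketch, not part of this change): the added tests
# use absl's parameterized.named_parameters, where each tuple's first element
# names the generated case and the remaining elements become positional
# arguments of the test method. The case names and values below are
# illustrative only; the 6-metrics-per-threshold convention mirrors the
# comments in the tests above.
from absl.testing import absltest
from absl.testing import parameterized


class NamedParametersSketchTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('_one_threshold', [0.5], 6),
      ('_three_thresholds', [0.1, 0.22, 0.333], 18),
  )
  def testExpectedMetricCount(self, thresholds, expected_num_metrics):
    # 6 fairness metrics are produced per threshold.
    self.assertEqual(len(thresholds) * 6, expected_num_metrics)


if __name__ == '__main__':
  absltest.main()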