13
13
# limitations under the License.
14
14
"""Counterfactual predictions extractor."""
15
15
16
- from typing import Dict , Iterable , Mapping , Optional , Sequence
16
+ from typing import Dict , Iterable , Mapping , Optional
17
17
18
18
import apache_beam as beam
19
19
import numpy as np
30
30
# The extracts key under which the non-CF INPUT_KEY value is temporarily stored,
31
31
# when invoking one or more PredictionsExtractors on modified inputs.
32
32
_TEMP_ORIG_INPUT_KEY = 'non_counterfactual_input'
33
- # The extracts key under which the accumulated PREDICTIONS_KEY value is
34
- # temporarily stored, when invoking a PredictionsExtractor more than once. This
35
- # is necessary because the PredictionsExtractor clobbers incoming values under
36
- # the PREDICITION_KEY.
37
- _TEMP_PREV_PREDICTIONS_KEY = 'counterfactual_prev_predictions'
38
33
CounterfactualConfig = Dict [str , str ]
39
34
40
35
@@ -92,26 +87,33 @@ def CounterfactualPredictionsExtractor( # pylint: disable=invalid-name
92
87
# TODO(b/258850519): Refactor default_extractors logic to expose new api
93
88
# for constructing the default predictions extractor and call it here.
94
89
predictions_ptransform = predictions_extractor .PredictionsExtractor (
95
- eval_shared_model = model , eval_config = cf_eval_config ).ptransform
90
+ eval_shared_model = model ,
91
+ eval_config = cf_eval_config ,
92
+ output_keypath = (constants .PREDICTIONS_KEY , model .model_name ),
93
+ ).ptransform
96
94
cf_ptransforms [model .model_name ] = _ExtractCounterfactualPredictions ( # pylint: disable=no-value-for-parameter
97
- model_name = model .model_name ,
98
95
config = cf_config ,
99
96
predictions_ptransform = predictions_ptransform )
100
97
else :
101
98
non_cf_models .append (model )
102
99
non_cf_eval_config = _filter_model_specs (eval_config , non_cf_models )
103
100
if non_cf_models :
101
+ output_keypath = (constants .PREDICTIONS_KEY ,)
102
+ if len (non_cf_models ) == 1 :
103
+ output_keypath = output_keypath + (non_cf_models [0 ].model_name ,)
104
104
non_cf_ptransform = predictions_extractor .PredictionsExtractor (
105
105
eval_shared_model = non_cf_models ,
106
- eval_config = non_cf_eval_config ).ptransform
106
+ eval_config = non_cf_eval_config ,
107
+ output_keypath = output_keypath ,
108
+ ).ptransform
107
109
else :
108
110
non_cf_ptransform = None
109
111
return extractor .Extractor (
110
112
stage_name = _COUNTERFACTUAL_PREDICTIONS_EXTRACTOR_NAME ,
111
113
ptransform = _ExtractPredictions ( # pylint: disable=no-value-for-parameter
112
- cf_ptransforms = cf_ptransforms ,
113
- non_cf_ptransform = non_cf_ptransform ,
114
- non_cf_model_names = [ model . model_name for model in non_cf_models ]) )
114
+ cf_ptransforms = cf_ptransforms , non_cf_ptransform = non_cf_ptransform
115
+ ) ,
116
+ )
115
117
116
118
117
119
def _validate_and_update_models_and_configs (
@@ -216,18 +218,10 @@ def _cf_preprocess(
216
218
cf_inputs .append (cf_example .SerializeToString ())
217
219
cf_inputs = np .array (cf_inputs , dtype = object )
218
220
result [constants .INPUT_KEY ] = cf_inputs
219
- # We stash pre-existing predictions because most predictions extractors
220
- # overwrite rather than update any existing value under PREDICTIONS_KEY.
221
- if constants .PREDICTIONS_KEY in result :
222
- result [_TEMP_PREV_PREDICTIONS_KEY ] = result [constants .PREDICTIONS_KEY ]
223
- del result [constants .PREDICTIONS_KEY ]
224
221
return result
225
222
226
223
227
- def _cf_postprocess (
228
- extracts : types .Extracts ,
229
- model_name : str ,
230
- ) -> types .Extracts :
224
+ def _cf_postprocess (extracts : types .Extracts ) -> types .Extracts :
231
225
"""Postprocesses the result of applying a CF prediction ptransform.
232
226
233
227
This method takes in an Extracts instance that has been prepocessed by
@@ -237,57 +231,38 @@ def _cf_postprocess(
237
231
Args:
238
232
extracts: An Extracts instance which has been preprocessed by _preprocess_cf
239
233
and gone through a prediction PTransform.
240
- model_name: The name of the model being post-processed. This is used to key
241
- the CF predictions by a model name before merging with previous
242
- predictions.
243
234
244
235
Returns:
245
236
An Extracts instance which appears to have been produced by a standard
246
237
predictions PTransform.
247
238
"""
248
239
extracts = extracts .copy ()
249
- if _TEMP_PREV_PREDICTIONS_KEY in extracts :
250
- prev_predictions = extracts [_TEMP_PREV_PREDICTIONS_KEY ]
251
- del extracts [_TEMP_PREV_PREDICTIONS_KEY ]
252
- else :
253
- prev_predictions = {}
254
- cf_predictions = {model_name : extracts [constants .PREDICTIONS_KEY ]}
255
- prev_predictions .update (cf_predictions )
256
- extracts [constants .PREDICTIONS_KEY ] = prev_predictions
257
240
extracts [constants .INPUT_KEY ] = extracts [_TEMP_ORIG_INPUT_KEY ]
258
241
del extracts [_TEMP_ORIG_INPUT_KEY ]
259
242
return extracts
260
243
261
244
262
- def _key_predictions_by_model_name (
263
- extracts : types .Extracts , model_names : Sequence [str ]) -> types .Extracts :
264
- if len (model_names ) == 1 :
265
- extracts [constants .PREDICTIONS_KEY ] = {
266
- model_names [0 ]: extracts [constants .PREDICTIONS_KEY ]
267
- }
268
- return extracts
269
-
270
-
271
245
@beam .ptransform_fn
272
246
def _ExtractCounterfactualPredictions ( # pylint: disable=invalid-name
273
- extracts : beam .PCollection [types .Extracts ], model_name : str ,
247
+ extracts : beam .PCollection [types .Extracts ],
274
248
config : CounterfactualConfig ,
275
- predictions_ptransform : beam .PTransform ) -> beam . PCollection [
276
- types .Extracts ]:
249
+ predictions_ptransform : beam .PTransform ,
250
+ ) -> beam . PCollection [ types .Extracts ]:
277
251
"""Computes counterfactual predictions for a single model."""
278
- return (extracts
279
- | 'PreprocessInputs' >> beam .Map (_cf_preprocess , config = config )
280
- | 'Predict' >> predictions_ptransform
281
- | 'PostProcessPredictions' >> beam .Map (
282
- _cf_postprocess , model_name = model_name ))
252
+ return (
253
+ extracts
254
+ | 'PreprocessInputs' >> beam .Map (_cf_preprocess , config = config )
255
+ | 'Predict' >> predictions_ptransform
256
+ | 'PostProcessPredictions' >> beam .Map (_cf_postprocess )
257
+ )
283
258
284
259
285
260
@beam .ptransform_fn
286
261
def _ExtractPredictions ( # pylint: disable=invalid-name
287
262
extracts : beam .PCollection [types .Extracts ],
288
263
cf_ptransforms : Dict [str , beam .PTransform ],
289
264
non_cf_ptransform : Optional [beam .PTransform ],
290
- non_cf_model_names : Sequence [ str ] ) -> beam .PCollection [types .Extracts ]:
265
+ ) -> beam .PCollection [types .Extracts ]:
291
266
"""Applies both CF and non-CF prediction ptransforms and merges results.
292
267
293
268
Args:
@@ -296,22 +271,13 @@ def _ExtractPredictions( # pylint: disable=invalid-name
296
271
_ExtractCounterfactualPredictions ptransforms
297
272
non_cf_ptransform: Optionally, a ptransform responsible for computing the
298
273
non-counterfactual predictions.
299
- non_cf_model_names: The names of the models for which predictions will be
300
- generated by the non_cf_ptransform. This is only used to restore a model
301
- to the result of the the non_cf_ptransform in the event that only a single
302
- model was provided, in which case extracts[constants.PREDICTIONS_KEY] will
303
- either be a single tensor, or a dict of per-output tensors.
304
274
305
275
Returns:
306
276
A PCollection of extracts containing merged predictions from both
307
277
counterfactual and non-counterfactual models.
308
278
"""
309
279
if non_cf_ptransform :
310
- extracts = (
311
- extracts
312
- | 'PredictNonCF' >> non_cf_ptransform
313
- | 'KeyNonCFPredictionsByModelName' >> beam .Map (
314
- _key_predictions_by_model_name , model_names = non_cf_model_names ))
280
+ extracts = extracts | 'PredictNonCF' >> non_cf_ptransform
315
281
for model_name , cf_ptransform in cf_ptransforms .items ():
316
282
extracts = extracts | f'PredictCF[{ model_name } ]' >> cf_ptransform
317
283
return extracts
0 commit comments