Skip to content

Commit 5645dc7

Browse files
genehwung authored and tfx-copybara committed
Refactor counterfactual prediction extractor to use the new output_keypath in prediction_extractor.
PiperOrigin-RevId: 554686228
1 parent 5341e91 commit 5645dc7

File tree

2 files changed

+26
-63
lines changed

2 files changed

+26
-63
lines changed

tensorflow_model_analysis/extractors/counterfactual_predictions_extractor.py

Lines changed: 26 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Counterfactual predictions extractor."""
1515

16-
from typing import Dict, Iterable, Mapping, Optional, Sequence
16+
from typing import Dict, Iterable, Mapping, Optional
1717

1818
import apache_beam as beam
1919
import numpy as np
@@ -30,11 +30,6 @@
3030
# The extracts key under which the non-CF INPUT_KEY value is temporarily stored,
3131
# when invoking one or more PredictionsExtractors on modified inputs.
3232
_TEMP_ORIG_INPUT_KEY = 'non_counterfactual_input'
33-
# The extracts key under which the accumulated PREDICTIONS_KEY value is
34-
# temporarily stored, when invoking a PredictionsExtractor more than once. This
35-
# is necessary because the PredictionsExtractor clobbers incoming values under
36-
# the PREDICTIONS_KEY.
37-
_TEMP_PREV_PREDICTIONS_KEY = 'counterfactual_prev_predictions'
3833
CounterfactualConfig = Dict[str, str]
3934

4035

@@ -92,26 +87,33 @@ def CounterfactualPredictionsExtractor( # pylint: disable=invalid-name
9287
# TODO(b/258850519): Refactor default_extractors logic to expose new api
9388
# for constructing the default predictions extractor and call it here.
9489
predictions_ptransform = predictions_extractor.PredictionsExtractor(
95-
eval_shared_model=model, eval_config=cf_eval_config).ptransform
90+
eval_shared_model=model,
91+
eval_config=cf_eval_config,
92+
output_keypath=(constants.PREDICTIONS_KEY, model.model_name),
93+
).ptransform
9694
cf_ptransforms[model.model_name] = _ExtractCounterfactualPredictions( # pylint: disable=no-value-for-parameter
97-
model_name=model.model_name,
9895
config=cf_config,
9996
predictions_ptransform=predictions_ptransform)
10097
else:
10198
non_cf_models.append(model)
10299
non_cf_eval_config = _filter_model_specs(eval_config, non_cf_models)
103100
if non_cf_models:
101+
output_keypath = (constants.PREDICTIONS_KEY,)
102+
if len(non_cf_models) == 1:
103+
output_keypath = output_keypath + (non_cf_models[0].model_name,)
104104
non_cf_ptransform = predictions_extractor.PredictionsExtractor(
105105
eval_shared_model=non_cf_models,
106-
eval_config=non_cf_eval_config).ptransform
106+
eval_config=non_cf_eval_config,
107+
output_keypath=output_keypath,
108+
).ptransform
107109
else:
108110
non_cf_ptransform = None
109111
return extractor.Extractor(
110112
stage_name=_COUNTERFACTUAL_PREDICTIONS_EXTRACTOR_NAME,
111113
ptransform=_ExtractPredictions( # pylint: disable=no-value-for-parameter
112-
cf_ptransforms=cf_ptransforms,
113-
non_cf_ptransform=non_cf_ptransform,
114-
non_cf_model_names=[model.model_name for model in non_cf_models]))
114+
cf_ptransforms=cf_ptransforms, non_cf_ptransform=non_cf_ptransform
115+
),
116+
)
115117

116118

117119
def _validate_and_update_models_and_configs(
@@ -216,18 +218,10 @@ def _cf_preprocess(
216218
cf_inputs.append(cf_example.SerializeToString())
217219
cf_inputs = np.array(cf_inputs, dtype=object)
218220
result[constants.INPUT_KEY] = cf_inputs
219-
# We stash pre-existing predictions because most predictions extractors
220-
# overwrite rather than update any existing value under PREDICTIONS_KEY.
221-
if constants.PREDICTIONS_KEY in result:
222-
result[_TEMP_PREV_PREDICTIONS_KEY] = result[constants.PREDICTIONS_KEY]
223-
del result[constants.PREDICTIONS_KEY]
224221
return result
225222

226223

227-
def _cf_postprocess(
228-
extracts: types.Extracts,
229-
model_name: str,
230-
) -> types.Extracts:
224+
def _cf_postprocess(extracts: types.Extracts) -> types.Extracts:
231225
"""Postprocesses the result of applying a CF prediction ptransform.
232226
233227
This method takes in an Extracts instance that has been preprocessed by
@@ -237,57 +231,38 @@ def _cf_postprocess(
237231
Args:
238232
extracts: An Extracts instance which has been preprocessed by _preprocess_cf
239233
and gone through a prediction PTransform.
240-
model_name: The name of the model being post-processed. This is used to key
241-
the CF predictions by a model name before merging with previous
242-
predictions.
243234
244235
Returns:
245236
An Extracts instance which appears to have been produced by a standard
246237
predictions PTransform.
247238
"""
248239
extracts = extracts.copy()
249-
if _TEMP_PREV_PREDICTIONS_KEY in extracts:
250-
prev_predictions = extracts[_TEMP_PREV_PREDICTIONS_KEY]
251-
del extracts[_TEMP_PREV_PREDICTIONS_KEY]
252-
else:
253-
prev_predictions = {}
254-
cf_predictions = {model_name: extracts[constants.PREDICTIONS_KEY]}
255-
prev_predictions.update(cf_predictions)
256-
extracts[constants.PREDICTIONS_KEY] = prev_predictions
257240
extracts[constants.INPUT_KEY] = extracts[_TEMP_ORIG_INPUT_KEY]
258241
del extracts[_TEMP_ORIG_INPUT_KEY]
259242
return extracts
260243

261244

262-
def _key_predictions_by_model_name(
263-
extracts: types.Extracts, model_names: Sequence[str]) -> types.Extracts:
264-
if len(model_names) == 1:
265-
extracts[constants.PREDICTIONS_KEY] = {
266-
model_names[0]: extracts[constants.PREDICTIONS_KEY]
267-
}
268-
return extracts
269-
270-
271245
@beam.ptransform_fn
272246
def _ExtractCounterfactualPredictions( # pylint: disable=invalid-name
273-
extracts: beam.PCollection[types.Extracts], model_name: str,
247+
extracts: beam.PCollection[types.Extracts],
274248
config: CounterfactualConfig,
275-
predictions_ptransform: beam.PTransform) -> beam.PCollection[
276-
types.Extracts]:
249+
predictions_ptransform: beam.PTransform,
250+
) -> beam.PCollection[types.Extracts]:
277251
"""Computes counterfactual predictions for a single model."""
278-
return (extracts
279-
| 'PreprocessInputs' >> beam.Map(_cf_preprocess, config=config)
280-
| 'Predict' >> predictions_ptransform
281-
| 'PostProcessPredictions' >> beam.Map(
282-
_cf_postprocess, model_name=model_name))
252+
return (
253+
extracts
254+
| 'PreprocessInputs' >> beam.Map(_cf_preprocess, config=config)
255+
| 'Predict' >> predictions_ptransform
256+
| 'PostProcessPredictions' >> beam.Map(_cf_postprocess)
257+
)
283258

284259

285260
@beam.ptransform_fn
286261
def _ExtractPredictions( # pylint: disable=invalid-name
287262
extracts: beam.PCollection[types.Extracts],
288263
cf_ptransforms: Dict[str, beam.PTransform],
289264
non_cf_ptransform: Optional[beam.PTransform],
290-
non_cf_model_names: Sequence[str]) -> beam.PCollection[types.Extracts]:
265+
) -> beam.PCollection[types.Extracts]:
291266
"""Applies both CF and non-CF prediction ptransforms and merges results.
292267
293268
Args:
@@ -296,22 +271,13 @@ def _ExtractPredictions( # pylint: disable=invalid-name
296271
_ExtractCounterfactualPredictions ptransforms
297272
non_cf_ptransform: Optionally, a ptransform responsible for computing the
298273
non-counterfactual predictions.
299-
non_cf_model_names: The names of the models for which predictions will be
300-
generated by the non_cf_ptransform. This is only used to restore a model
301-
to the result of the non_cf_ptransform in the event that only a single
302-
model was provided, in which case extracts[constants.PREDICTIONS_KEY] will
303-
either be a single tensor, or a dict of per-output tensors.
304274
305275
Returns:
306276
A PCollection of extracts containing merged predictions from both
307277
counterfactual and non-counterfactual models.
308278
"""
309279
if non_cf_ptransform:
310-
extracts = (
311-
extracts
312-
| 'PredictNonCF' >> non_cf_ptransform
313-
| 'KeyNonCFPredictionsByModelName' >> beam.Map(
314-
_key_predictions_by_model_name, model_names=non_cf_model_names))
280+
extracts = extracts | 'PredictNonCF' >> non_cf_ptransform
315281
for model_name, cf_ptransform in cf_ptransforms.items():
316282
extracts = extracts | f'PredictCF[{model_name}]' >> cf_ptransform
317283
return extracts

tensorflow_model_analysis/extractors/counterfactual_predictions_extractor_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,6 @@ def check_result(got):
245245
f'expected:{expected_predictions}'))
246246
self.assertNotIn(
247247
counterfactual_predictions_extractor._TEMP_ORIG_INPUT_KEY, got[0])
248-
self.assertNotIn(
249-
counterfactual_predictions_extractor._TEMP_PREV_PREDICTIONS_KEY,
250-
got[0])
251248
except AssertionError as err:
252249
raise util.BeamAssertException(err)
253250

0 commit comments

Comments
 (0)