@@ -288,13 +288,39 @@ def _ExtractOutput( # pylint: disable=invalid-name
288
288
main = _ExtractOutputDoFn .OUTPUT_TAG_METRICS )
289
289
290
290
291
def PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
                     shared_handle, desired_batch_size):
  """Builds the 'Predict' extraction stage around TFMAPredict.

  The wrapped transform is a map function which loads and runs the
  eval_saved_model against every example, yielding a
  types.ExampleAndExtracts containing a FeaturesPredictionsLabels value
  (where key is 'fpl').

  Args:
    eval_saved_model_path: Forwarded unchanged to
      predict_extractor.TFMAPredict as `eval_saved_model_path`.
    add_metrics_callbacks: Forwarded unchanged to
      predict_extractor.TFMAPredict as `add_metrics_callbacks`.
    shared_handle: Forwarded unchanged to predict_extractor.TFMAPredict as
      `shared_handle`.
    desired_batch_size: Forwarded unchanged to predict_extractor.TFMAPredict
      as `desired_batch_size`.

  Returns:
    A types.Extractor whose stage_name is 'Predict' and whose ptransform is
    the configured predict_extractor.TFMAPredict.
  """
  predict_ptransform = predict_extractor.TFMAPredict(
      eval_saved_model_path=eval_saved_model_path,
      add_metrics_callbacks=add_metrics_callbacks,
      shared_handle=shared_handle,
      desired_batch_size=desired_batch_size)
  return types.Extractor(stage_name='Predict', ptransform=predict_ptransform)
303
+
304
+
305
@beam.ptransform_fn
def Extract(examples, extractors):
  """Runs every extractor over the examples, one after another.

  Args:
    examples: The input that the first extractor's transform is applied to.
    extractors: Iterable of objects exposing `stage_name` and `ptransform`
      attributes. They are applied in iteration order, each one consuming
      the output of the previous one.

  Returns:
    The result of chaining all extractor transforms onto `examples`; this is
    `examples` itself when `extractors` is empty.
  """
  result = examples
  for stage in extractors:
    # Label the transform with its stage name before applying it.
    labelled_transform = stage.stage_name >> stage.ptransform
    result = result | labelled_transform
  return result
314
+
315
+
291
316
@beam .ptransform_fn
292
317
# No typehint for output type, since it's a multi-output DoFn result that
293
318
# Beam doesn't support typehints for yet (BEAM-3280).
294
319
def Evaluate (
295
320
# pylint: disable=invalid-name
296
321
examples ,
297
322
eval_saved_model_path ,
323
+ extractors = None ,
298
324
add_metrics_callbacks = None ,
299
325
slice_spec = None ,
300
326
desired_batch_size = None ,
@@ -309,6 +335,8 @@ def Evaluate(
309
335
(e.g. string containing CSV row, TensorFlow.Example, etc).
310
336
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
311
337
the saved_model.pb file.
338
+ extractors: Optional list of Extractors to execute prior to slicing and
339
+ aggregating the metrics. If not provided, a default set will be run.
312
340
add_metrics_callbacks: Optional list of callbacks for adding additional
313
341
metrics to the graph. The names of the metrics added by the callbacks
314
342
should not conflict with existing metrics, or metrics added by other
@@ -349,24 +377,22 @@ def add_metrics_callback(features_dict, predictions_dict, labels):
349
377
350
378
shared_handle = shared .Shared ()
351
379
380
+ if not extractors :
381
+ extractors = [
382
+ PredictExtractor (eval_saved_model_path , add_metrics_callbacks ,
383
+ shared_handle , desired_batch_size ),
384
+ ]
385
+
352
386
# pylint: disable=no-value-for-parameter
353
387
return (
354
388
examples
355
389
# Our diagnostic outputs, pass types.ExampleAndExtracts throughout,
356
390
# however our aggregating functions do not use this interface.
357
391
| 'ToExampleAndExtracts' >>
358
392
beam .Map (lambda x : types .ExampleAndExtracts (example = x , extracts = {}))
393
+ | Extract (extractors = extractors )
359
394
360
- # Map function which loads and runs the eval_saved_model against every
361
- # example, yielding an types.ExampleAndExtracts containing a
362
- # FeaturesPredictionsLabels value (where key is 'fpl').
363
- | 'Predict' >> predict_extractor .TFMAPredict (
364
- eval_saved_model_path = eval_saved_model_path ,
365
- add_metrics_callbacks = add_metrics_callbacks ,
366
- shared_handle = shared_handle ,
367
- desired_batch_size = desired_batch_size )
368
-
369
- # Input: one example fpl at a time
395
+ # Input: one example at a time
370
396
# Output: one fpl example per slice key (notice that the example turns
371
397
# into n, replicated once per applicable slice key)
372
398
| 'Slice' >> slice_api .Slice (slice_spec )
@@ -395,6 +421,7 @@ def BuildDiagnosticTable(
395
421
# pylint: disable=invalid-name
396
422
examples ,
397
423
eval_saved_model_path ,
424
+ extractors = None ,
398
425
desired_batch_size = None ):
399
426
"""Build diagnostics for the spacified EvalSavedModel and example collection.
400
427
@@ -403,18 +430,24 @@ def BuildDiagnosticTable(
403
430
(e.g. string containing CSV row, TensorFlow.Example, etc).
404
431
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
405
432
the saved_model.pb file.
433
+ extractors: Optional list of Extractors to execute prior to slicing and
434
+ aggregating the metrics. If not provided, a default set will be run.
406
435
desired_batch_size: Optional batch size for batching in Predict and
407
436
Aggregate.
408
437
409
438
Returns:
410
439
PCollection of ExampleAndExtracts
411
440
"""
441
+
442
+ if not extractors :
443
+ extractors = [
444
+ PredictExtractor (eval_saved_model_path , None , shared .Shared (),
445
+ desired_batch_size ),
446
+ types .Extractor (
447
+ stage_name = 'ExtractFeatures' ,
448
+ ptransform = feature_extractor .ExtractFeatures ()),
449
+ ]
412
450
return (examples
413
451
| 'ToExampleAndExtracts' >>
414
452
beam .Map (lambda x : types .ExampleAndExtracts (example = x , extracts = {}))
415
- | 'Predict' >> predict_extractor .TFMAPredict (
416
- eval_saved_model_path ,
417
- add_metrics_callbacks = None ,
418
- shared_handle = shared .Shared (),
419
- desired_batch_size = desired_batch_size )
420
- | 'ExtractFeatures' >> feature_extractor .ExtractFeatures ())
453
+ | Extract (extractors = extractors ))
0 commit comments