Skip to content

Commit e8bcd90

Browse files
authored
Merge pull request #25 from xinzha623/master
Project import generated by Copybara.
2 parents 556302d + 1ea7b29 commit e8bcd90

File tree

2 files changed

+54
-16
lines changed

2 files changed

+54
-16
lines changed

tensorflow_model_analysis/api/impl/evaluate.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,39 @@ def _ExtractOutput( # pylint: disable=invalid-name
288288
main=_ExtractOutputDoFn.OUTPUT_TAG_METRICS)
289289

290290

291+
def PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
292+
shared_handle, desired_batch_size):
293+
# Map function which loads and runs the eval_saved_model against every
294+
# example, yielding an types.ExampleAndExtracts containing a
295+
# FeaturesPredictionsLabels value (where key is 'fpl').
296+
return types.Extractor(
297+
stage_name='Predict',
298+
ptransform=predict_extractor.TFMAPredict(
299+
eval_saved_model_path=eval_saved_model_path,
300+
add_metrics_callbacks=add_metrics_callbacks,
301+
shared_handle=shared_handle,
302+
desired_batch_size=desired_batch_size))
303+
304+
305+
@beam.ptransform_fn
306+
def Extract(examples, extractors):
307+
"""Performs Extractions serially in provided order."""
308+
augmented = examples
309+
310+
for extractor in extractors:
311+
augmented = augmented | extractor.stage_name >> extractor.ptransform
312+
313+
return augmented
314+
315+
291316
@beam.ptransform_fn
292317
# No typehint for output type, since it's a multi-output DoFn result that
293318
# Beam doesn't support typehints for yet (BEAM-3280).
294319
def Evaluate(
295320
# pylint: disable=invalid-name
296321
examples,
297322
eval_saved_model_path,
323+
extractors = None,
298324
add_metrics_callbacks = None,
299325
slice_spec = None,
300326
desired_batch_size = None,
@@ -309,6 +335,8 @@ def Evaluate(
309335
(e.g. string containing CSV row, TensorFlow.Example, etc).
310336
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
311337
the saved_model.pb file.
338+
extractors: Optional list of Extractors to execute prior to slicing and
339+
aggregating the metrics. If not provided, a default set will be run.
312340
add_metrics_callbacks: Optional list of callbacks for adding additional
313341
metrics to the graph. The names of the metrics added by the callbacks
314342
should not conflict with existing metrics, or metrics added by other
@@ -349,24 +377,22 @@ def add_metrics_callback(features_dict, predictions_dict, labels):
349377

350378
shared_handle = shared.Shared()
351379

380+
if not extractors:
381+
extractors = [
382+
PredictExtractor(eval_saved_model_path, add_metrics_callbacks,
383+
shared_handle, desired_batch_size),
384+
]
385+
352386
# pylint: disable=no-value-for-parameter
353387
return (
354388
examples
355389
# Our diagnostic outputs, pass types.ExampleAndExtracts throughout,
356390
# however our aggregating functions do not use this interface.
357391
| 'ToExampleAndExtracts' >>
358392
beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
393+
| Extract(extractors=extractors)
359394

360-
# Map function which loads and runs the eval_saved_model against every
361-
# example, yielding an types.ExampleAndExtracts containing a
362-
# FeaturesPredictionsLabels value (where key is 'fpl').
363-
| 'Predict' >> predict_extractor.TFMAPredict(
364-
eval_saved_model_path=eval_saved_model_path,
365-
add_metrics_callbacks=add_metrics_callbacks,
366-
shared_handle=shared_handle,
367-
desired_batch_size=desired_batch_size)
368-
369-
# Input: one example fpl at a time
395+
# Input: one example at a time
370396
# Output: one fpl example per slice key (notice that the example turns
371397
# into n, replicated once per applicable slice key)
372398
| 'Slice' >> slice_api.Slice(slice_spec)
@@ -395,6 +421,7 @@ def BuildDiagnosticTable(
395421
# pylint: disable=invalid-name
396422
examples,
397423
eval_saved_model_path,
424+
extractors = None,
398425
desired_batch_size = None):
399426
"""Build diagnostics for the spacified EvalSavedModel and example collection.
400427
@@ -403,18 +430,24 @@ def BuildDiagnosticTable(
403430
(e.g. string containing CSV row, TensorFlow.Example, etc).
404431
eval_saved_model_path: Path to EvalSavedModel. This directory should contain
405432
the saved_model.pb file.
433+
extractors: Optional list of Extractors to execute prior to slicing and
434+
aggregating the metrics. If not provided, a default set will be run.
406435
desired_batch_size: Optional batch size for batching in Predict and
407436
Aggregate.
408437
409438
Returns:
410439
PCollection of ExampleAndExtracts
411440
"""
441+
442+
if not extractors:
443+
extractors = [
444+
PredictExtractor(eval_saved_model_path, None, shared.Shared(),
445+
desired_batch_size),
446+
types.Extractor(
447+
stage_name='ExtractFeatures',
448+
ptransform=feature_extractor.ExtractFeatures()),
449+
]
412450
return (examples
413451
| 'ToExampleAndExtracts' >>
414452
beam.Map(lambda x: types.ExampleAndExtracts(example=x, extracts={}))
415-
| 'Predict' >> predict_extractor.TFMAPredict(
416-
eval_saved_model_path,
417-
add_metrics_callbacks=None,
418-
shared_handle=shared.Shared(),
419-
desired_batch_size=desired_batch_size)
420-
| 'ExtractFeatures' >> feature_extractor.ExtractFeatures())
453+
| Extract(extractors=extractors))

tensorflow_model_analysis/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import copy
2222

23+
import apache_beam as beam
2324
import numpy as np
2425
import tensorflow as tf
2526

@@ -66,6 +67,10 @@ def is_tensor(obj):
6667
DictOfExtractedValues = Dict[Text, Any]
6768

6869

70+
Extractor = NamedTuple('Extractor', [('stage_name', bytes),
71+
('ptransform', beam.PTransform)])
72+
73+
6974
class ExampleAndExtracts(
7075
NamedTuple('ExampleAndExtracts', [('example', bytes),
7176
('extracts', DictOfExtractedValues)])):

0 commit comments

Comments
 (0)