Skip to content

Commit b503fbd

Browse files
tf-model-analysis-team and tfx-copybara
tf-model-analysis-team
authored and committed
Pass a do_fn argument to make the TFLitePredictionDoFn and use _make_interpreter to make Interpreters.
PiperOrigin-RevId: 523199658
1 parent e03557b commit b503fbd

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

tensorflow_model_analysis/extractors/tflite_predict_extractor.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import collections
1717
import copy
18-
from typing import Dict, Sequence, Union
18+
from typing import Callable, Dict, Sequence, Union
1919

2020
from absl import logging
2121
import apache_beam as beam
@@ -28,13 +28,14 @@
2828
from tensorflow_model_analysis.utils import model_util
2929
from tensorflow_model_analysis.utils import util
3030

31+
_OpResolverType = tf.lite.experimental.OpResolverType
3132
_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME = 'ExtractTFLitePredictions'
3233

3334

3435
# TODO(b/149981535) Determine if we should merge with RunInference.
3536
@beam.typehints.with_input_types(types.Extracts)
3637
@beam.typehints.with_output_types(types.Extracts)
37-
class _TFLitePredictionDoFn(model_util.BatchReducibleBatchedDoFnWithModels):
38+
class TFLitePredictionDoFn(model_util.BatchReducibleBatchedDoFnWithModels):
3839
"""A DoFn that loads tflite models and predicts."""
3940

4041
def __init__(self, eval_config: config_pb2.EvalConfig,
@@ -47,17 +48,19 @@ def setup(self):
4748
self._interpreters = {}
4849

4950
major, minor, _ = tf.version.VERSION.split('.')
51+
op_resolver_type = _OpResolverType.AUTO
52+
# TODO(b/207600661): drop BUILTIN_WITHOUT_DEFAULT_DELEGATES once the issue
53+
# is fixed.
54+
if int(major) > 2 or (int(major) == 2 and int(minor) >= 5):
55+
op_resolver_type = _OpResolverType.BUILTIN_WITHOUT_DEFAULT_DELEGATES
5056
for model_name, model_contents in self._loaded_models.items():
51-
# TODO(b/207600661): drop BUILTIN_WITHOUT_DEFAULT_DELEGATES once the issue
52-
# is fixed.
53-
if int(major) > 2 or (int(major) == 2 and int(minor) >= 5):
54-
self._interpreters[model_name] = tf.lite.Interpreter(
55-
model_content=model_contents.contents,
56-
experimental_op_resolver_type=tf.lite.experimental.OpResolverType
57-
.BUILTIN_WITHOUT_DEFAULT_DELEGATES)
58-
else:
59-
self._interpreters[model_name] = tf.lite.Interpreter(
60-
model_content=model_contents.contents)
57+
self._interpreters[model_name] = self._make_interpreter(
58+
model_content=model_contents.contents,
59+
experimental_op_resolver_type=op_resolver_type,
60+
)
61+
62+
def _make_interpreter(self, **kwargs) -> tf.lite.Interpreter:
63+
return tf.lite.Interpreter(**kwargs)
6164

6265
def _get_input_name_from_input_detail(self, input_detail):
6366
"""Get input name from input detail.
@@ -157,31 +160,37 @@ def _batch_reducible_process(
157160
@beam.typehints.with_input_types(types.Extracts)
158161
@beam.typehints.with_output_types(types.Extracts)
159162
def _ExtractTFLitePredictions( # pylint: disable=invalid-name
160-
extracts: beam.pvalue.PCollection, eval_config: config_pb2.EvalConfig,
161-
eval_shared_models: Dict[str,
162-
types.EvalSharedModel]) -> beam.pvalue.PCollection:
163+
extracts: beam.pvalue.PCollection,
164+
eval_config: config_pb2.EvalConfig,
165+
eval_shared_models: Dict[str, types.EvalSharedModel],
166+
do_fn: Callable[..., TFLitePredictionDoFn],
167+
) -> beam.pvalue.PCollection:
163168
"""A PTransform that adds predictions and possibly other tensors to extracts.
164169
165170
Args:
166171
extracts: PCollection of extracts containing model inputs keyed by
167172
tfma.FEATURES_KEY.
168173
eval_config: Eval config.
169174
eval_shared_models: Shared model parameters keyed by model name.
175+
do_fn: Constructor for TFLitePredictionDoFn.
170176
171177
Returns:
172178
PCollection of Extracts updated with the predictions.
173179
"""
174-
return (
175-
extracts
176-
| 'Predict' >> beam.ParDo(
177-
_TFLitePredictionDoFn(
178-
eval_config=eval_config, eval_shared_models=eval_shared_models)))
180+
return extracts | 'Predict' >> beam.ParDo(
181+
do_fn(
182+
eval_config=eval_config,
183+
eval_shared_models=eval_shared_models,
184+
)
185+
)
179186

180187

181188
def TFLitePredictExtractor(
182189
eval_config: config_pb2.EvalConfig,
183-
eval_shared_model: Union[types.EvalSharedModel, Dict[str,
184-
types.EvalSharedModel]]
190+
eval_shared_model: Union[
191+
types.EvalSharedModel, Dict[str, types.EvalSharedModel]
192+
],
193+
do_fn: Callable[..., TFLitePredictionDoFn] = TFLitePredictionDoFn,
185194
) -> extractor.Extractor:
186195
"""Creates an extractor for performing predictions on tflite models.
187196
@@ -195,6 +204,7 @@ def TFLitePredictExtractor(
195204
eval_config: Eval config.
196205
eval_shared_model: Shared model (single-model evaluation) or dict of shared
197206
models keyed by model name (multi-model evaluation).
207+
do_fn: Constructor for TFLitePredictionDoFn.
198208
199209
Returns:
200210
Extractor for extracting predictions.
@@ -207,4 +217,7 @@ def TFLitePredictExtractor(
207217
stage_name=_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME,
208218
ptransform=_ExtractTFLitePredictions(
209219
eval_config=eval_config,
210-
eval_shared_models={m.model_name: m for m in eval_shared_models}))
220+
eval_shared_models={m.model_name: m for m in eval_shared_models},
221+
do_fn=do_fn,
222+
),
223+
)

0 commit comments

Comments (0)