15
15
16
16
import collections
17
17
import copy
18
- from typing import Dict , Sequence , Union
18
+ from typing import Callable , Dict , Sequence , Union
19
19
20
20
from absl import logging
21
21
import apache_beam as beam
28
28
from tensorflow_model_analysis .utils import model_util
29
29
from tensorflow_model_analysis .utils import util
30
30
31
+ _OpResolverType = tf .lite .experimental .OpResolverType
31
32
_TFLITE_PREDICT_EXTRACTOR_STAGE_NAME = 'ExtractTFLitePredictions'
32
33
33
34
34
35
# TODO(b/149981535) Determine if we should merge with RunInference.
35
36
@beam .typehints .with_input_types (types .Extracts )
36
37
@beam .typehints .with_output_types (types .Extracts )
37
- class _TFLitePredictionDoFn (model_util .BatchReducibleBatchedDoFnWithModels ):
38
+ class TFLitePredictionDoFn (model_util .BatchReducibleBatchedDoFnWithModels ):
38
39
"""A DoFn that loads tflite models and predicts."""
39
40
40
41
def __init__ (self , eval_config : config_pb2 .EvalConfig ,
@@ -47,17 +48,19 @@ def setup(self):
47
48
self ._interpreters = {}
48
49
49
50
major , minor , _ = tf .version .VERSION .split ('.' )
51
+ op_resolver_type = _OpResolverType .AUTO
52
+ # TODO(b/207600661): drop BUILTIN_WITHOUT_DEFAULT_DELEGATES once the issue
53
+ # is fixed.
54
+ if int (major ) > 2 or (int (major ) == 2 and int (minor ) >= 5 ):
55
+ op_resolver_type = _OpResolverType .BUILTIN_WITHOUT_DEFAULT_DELEGATES
50
56
for model_name , model_contents in self ._loaded_models .items ():
51
- # TODO(b/207600661): drop BUILTIN_WITHOUT_DEFAULT_DELEGATES once the issue
52
- # is fixed.
53
- if int (major ) > 2 or (int (major ) == 2 and int (minor ) >= 5 ):
54
- self ._interpreters [model_name ] = tf .lite .Interpreter (
55
- model_content = model_contents .contents ,
56
- experimental_op_resolver_type = tf .lite .experimental .OpResolverType
57
- .BUILTIN_WITHOUT_DEFAULT_DELEGATES )
58
- else :
59
- self ._interpreters [model_name ] = tf .lite .Interpreter (
60
- model_content = model_contents .contents )
57
+ self ._interpreters [model_name ] = self ._make_interpreter (
58
+ model_content = model_contents .contents ,
59
+ experimental_op_resolver_type = op_resolver_type ,
60
+ )
61
+
62
+ def _make_interpreter (self , ** kwargs ) -> tf .lite .Interpreter :
63
+ return tf .lite .Interpreter (** kwargs )
61
64
62
65
def _get_input_name_from_input_detail (self , input_detail ):
63
66
"""Get input name from input detail.
@@ -157,31 +160,37 @@ def _batch_reducible_process(
157
160
@beam .typehints .with_input_types (types .Extracts )
158
161
@beam .typehints .with_output_types (types .Extracts )
159
162
def _ExtractTFLitePredictions ( # pylint: disable=invalid-name
160
- extracts : beam .pvalue .PCollection , eval_config : config_pb2 .EvalConfig ,
161
- eval_shared_models : Dict [str ,
162
- types .EvalSharedModel ]) -> beam .pvalue .PCollection :
163
+ extracts : beam .pvalue .PCollection ,
164
+ eval_config : config_pb2 .EvalConfig ,
165
+ eval_shared_models : Dict [str , types .EvalSharedModel ],
166
+ do_fn : Callable [..., TFLitePredictionDoFn ],
167
+ ) -> beam .pvalue .PCollection :
163
168
"""A PTransform that adds predictions and possibly other tensors to extracts.
164
169
165
170
Args:
166
171
extracts: PCollection of extracts containing model inputs keyed by
167
172
tfma.FEATURES_KEY.
168
173
eval_config: Eval config.
169
174
eval_shared_models: Shared model parameters keyed by model name.
175
+ do_fn: Constructor for TFLitePredictionDoFn.
170
176
171
177
Returns:
172
178
PCollection of Extracts updated with the predictions.
173
179
"""
174
- return (
175
- extracts
176
- | 'Predict' >> beam .ParDo (
177
- _TFLitePredictionDoFn (
178
- eval_config = eval_config , eval_shared_models = eval_shared_models )))
180
+ return extracts | 'Predict' >> beam .ParDo (
181
+ do_fn (
182
+ eval_config = eval_config ,
183
+ eval_shared_models = eval_shared_models ,
184
+ )
185
+ )
179
186
180
187
181
188
def TFLitePredictExtractor (
182
189
eval_config : config_pb2 .EvalConfig ,
183
- eval_shared_model : Union [types .EvalSharedModel , Dict [str ,
184
- types .EvalSharedModel ]]
190
+ eval_shared_model : Union [
191
+ types .EvalSharedModel , Dict [str , types .EvalSharedModel ]
192
+ ],
193
+ do_fn : Callable [..., TFLitePredictionDoFn ] = TFLitePredictionDoFn ,
185
194
) -> extractor .Extractor :
186
195
"""Creates an extractor for performing predictions on tflite models.
187
196
@@ -195,6 +204,7 @@ def TFLitePredictExtractor(
195
204
eval_config: Eval config.
196
205
eval_shared_model: Shared model (single-model evaluation) or dict of shared
197
206
models keyed by model name (multi-model evaluation).
207
+ do_fn: Constructor for TFLitePredictionDoFn.
198
208
199
209
Returns:
200
210
Extractor for extracting predictions.
@@ -207,4 +217,7 @@ def TFLitePredictExtractor(
207
217
stage_name = _TFLITE_PREDICT_EXTRACTOR_STAGE_NAME ,
208
218
ptransform = _ExtractTFLitePredictions (
209
219
eval_config = eval_config ,
210
- eval_shared_models = {m .model_name : m for m in eval_shared_models }))
220
+ eval_shared_models = {m .model_name : m for m in eval_shared_models },
221
+ do_fn = do_fn ,
222
+ ),
223
+ )
0 commit comments