Skip to content

Commit ce78a79

Browse files
author
tf-model-analysis-team
committed
Raise error when batch size does not match the first dimension of the shape.
PiperOrigin-RevId: 381062461
1 parent ad1a9eb commit ce78a79

File tree

2 files changed

+97
-4
lines changed

2 files changed

+97
-4
lines changed

tensorflow_model_analysis/model_util.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,12 @@ def maybe_expand_dims(arr):
857857
def to_dense(t):
  """Returns a dense version of `t` if it is a tf.SparseTensor, else `t` as is."""
  if isinstance(t, tf.SparseTensor):
    return tf.sparse.to_dense(t)
  return t
859859

860+
def check_shape(t, batch_size, key=None):
  """Validates that the first dimension of `t` equals the batch size.

  Args:
    t: Object exposing a `.shape` that can be indexed (the call sites pass
      densified signature outputs).
    batch_size: Expected size of the first dimension.
    key: Optional output key, only used to make the error message more
      helpful.

  Raises:
    ValueError: If `t` has no first dimension (rank 0) or if its first
      dimension differs from `batch_size`.
  """
  try:
    mismatch = t.shape[0] != batch_size
  except IndexError:
    # A rank-0 (scalar) output has no batch dimension at all. Report it with
    # the same descriptive error instead of leaking a bare IndexError.
    mismatch = True
  if mismatch:
    raise ValueError(
        'First dimension does not correspond with batch size. '
        f'Batch size: {batch_size}, Dimensions: {t.shape}, Key: {key}.')
865+
860866
result = copy.copy(batched_extract)
861867
record_batch = batched_extract[constants.ARROW_RECORD_BATCH_KEY]
862868
serialized_examples = batched_extract[constants.INPUT_KEY]
@@ -895,16 +901,26 @@ def to_dense(t):
895901
outputs = signature(inputs)
896902
else:
897903
outputs = signature(tf.constant(inputs, dtype=tf.string))
904+
905+
dense_outputs = {}
906+
if isinstance(outputs, dict):
907+
for k, v in outputs.items():
908+
dense_outputs[k] = to_dense(v)
909+
check_shape(dense_outputs[k], record_batch.num_rows, key=k)
910+
else:
911+
dense_outputs = to_dense(outputs)
912+
check_shape(dense_outputs, record_batch.num_rows)
913+
898914
for i in range(record_batch.num_rows):
899-
if isinstance(outputs, dict):
915+
if isinstance(dense_outputs, dict):
900916
output = {
901-
k: maybe_expand_dims(to_dense(v)[i].numpy())
902-
for k, v in outputs.items()
917+
k: maybe_expand_dims(v[i].numpy())
918+
for k, v in dense_outputs.items()
903919
}
904920
else:
905921
output = {
906922
signature_name:
907-
maybe_expand_dims(np.asarray(to_dense(outputs))[i])
923+
maybe_expand_dims(np.asarray(dense_outputs)[i])
908924
}
909925
if result[extracts_key][i] is None:
910926
result[extracts_key][i] = collections.defaultdict(dict)

tensorflow_model_analysis/model_util_test.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,39 @@ def custom_multi_output(features):
172172
tf.saved_model.save(model, export_path, signatures=signatures)
173173
return export_path
174174

175+
def createModelWithInvalidOutputShape(self):
  """Saves a Keras model whose output is flattened to a single dimension.

  Two float inputs of shape (1,) are concatenated and fed through a 2-unit
  sigmoid Dense layer; the result is then reshaped with `[-1]`, so the first
  dimension of the output no longer corresponds with the batch size. The
  model is saved with a 'serving_default' signature that parses serialized
  tf.Examples.

  Returns:
    Path of the temporary directory the model was saved to.
  """
  feature_names = ('input_1', 'input_2')
  model_inputs = [
      tf.keras.layers.Input(shape=(1,), name=name) for name in feature_names
  ]
  merged = tf.keras.layers.concatenate(model_inputs)
  dense_output = tf.keras.layers.Dense(
      2, activation=tf.nn.sigmoid, name='output')(
          merged)
  # Flatten the layer such that the first dimension no longer corresponds
  # with the batch size.
  flattened = tf.keras.layers.Lambda(
      lambda x: tf.reshape(x, [-1]), name='reshape')(
          dense_output)
  model = tf.keras.models.Model(model_inputs, flattened)

  feature_spec = {
      name: tf.io.FixedLenFeature([1], dtype=tf.float32)
      for name in feature_names
  }

  @tf.function
  def serving_default(serialized_tf_examples):
    return model(tf.io.parse_example(serialized_tf_examples, feature_spec))

  examples_spec = tf.TensorSpec(shape=(None,), dtype=tf.string, name='examples')
  signatures = {
      'serving_default': serving_default.get_concrete_function(examples_spec),
  }

  export_path = tempfile.mkdtemp()
  model.save(export_path, save_format='tf', signatures=signatures)
  return export_path
207+
175208
def createModelWithMultipleMixedInputs(self, save_as_keras):
176209
dense_input = tf.keras.layers.Input(
177210
shape=(2,), name='input_1', dtype=tf.int64)
@@ -821,6 +854,50 @@ def check_result(got):
821854

822855
util.assert_that(result, check_result, label='result')
823856

857+
def testModelSignaturesDoFnError(self):
  """Expects ModelSignaturesDoFn to fail on a model with an invalid output shape.

  Runs a batch of three examples through the model exported by
  createModelWithInvalidOutputShape and asserts that the pipeline raises a
  ValueError mentioning the batch-size mismatch.
  """
  saved_model_path = self.createModelWithInvalidOutputShape()
  signature_names = {constants.PREDICTIONS_KEY: {'': [None]}}
  shared_model = self.createTestEvalSharedModel(
      eval_saved_model_path=saved_model_path, tags=[tf.saved_model.SERVING])
  eval_shared_models = {'': shared_model}
  eval_config = config.EvalConfig(model_specs=[config.ModelSpec()])
  tfx_io = tf_example_record.TFExampleBeamRecord(
      physical_format='text',
      schema=self.createDenseInputsSchema(),
      raw_record_column_name=constants.ARROW_INPUT_COLUMN)
  tensor_adapter_config = tensor_adapter.TensorAdapterConfig(
      arrow_schema=tfx_io.ArrowSchema(),
      tensor_representations=tfx_io.TensorRepresentations())

  feature_values = ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0))
  examples = [
      self._makeExample(input_1=first, input_2=second)
      for first, second in feature_values
  ]
  serialized_examples = [example.SerializeToString() for example in examples]

  signatures_fn = model_util.ModelSignaturesDoFn(
      eval_config=eval_config,
      eval_shared_models=eval_shared_models,
      signature_names=signature_names,
      default_signature_names=None,
      prefer_dict_outputs=False,
      tensor_adapter_config=tensor_adapter_config)

  # The pipeline runs when its context exits, so the error must surface
  # inside the assertRaisesRegex context.
  with self.assertRaisesRegex(
      ValueError, 'First dimension does not correspond with batch size.'):
    with beam.Pipeline() as pipeline:
      # pylint: disable=no-value-for-parameter
      _ = (
          pipeline
          | 'Create' >> beam.Create(serialized_examples)
          | 'BatchExamples' >> tfx_io.BeamSource(batch_size=3)
          | 'ToExtracts' >> beam.Map(_record_batch_to_extracts)
          | 'ModelSignatures' >> beam.ParDo(signatures_fn))
900+
824901
def testHasRubberStamp(self):
825902
# Model agnostic.
826903
self.assertFalse(model_util.has_rubber_stamp(None))

0 commit comments

Comments
 (0)