Skip to content

Commit b6e2f42

Browse files
mdrevestf-model-analysis-team
authored and
tf-model-analysis-team
committed
Add support for SparseTensor outputs.
PiperOrigin-RevId: 347874515
1 parent 888b982 commit b6e2f42

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tensorflow_model_analysis/model_util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,9 @@ def maybe_expand_dims(arr):
781781
else:
782782
return arr
783783

784+
def to_dense(t):
785+
return tf.sparse.to_dense(t) if isinstance(t, tf.SparseTensor) else t
786+
784787
result = copy.copy(batched_extract)
785788
record_batch = batched_extract[constants.ARROW_RECORD_BATCH_KEY]
786789
serialized_examples = batched_extract[constants.INPUT_KEY]
@@ -841,12 +844,13 @@ def maybe_expand_dims(arr):
841844
for i in range(record_batch.num_rows):
842845
if isinstance(outputs, dict):
843846
output = {
844-
k: maybe_expand_dims(v[i].numpy())
847+
k: maybe_expand_dims(to_dense(v)[i].numpy())
845848
for k, v in outputs.items()
846849
}
847850
else:
848851
output = {
849-
signature_name: maybe_expand_dims(np.asarray(outputs)[i])
852+
signature_name:
853+
maybe_expand_dims(np.asarray(to_dense(outputs))[i])
850854
}
851855
if result[extracts_key][i] is None:
852856
result[extracts_key][i] = collections.defaultdict(dict)

0 commit comments

Comments
 (0)