File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
tensorflow_model_analysis Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -781,6 +781,9 @@ def maybe_expand_dims(arr):
781
781
else :
782
782
return arr
783
783
784
+ def to_dense (t ):
785
+ return tf .sparse .to_dense (t ) if isinstance (t , tf .SparseTensor ) else t
786
+
784
787
result = copy .copy (batched_extract )
785
788
record_batch = batched_extract [constants .ARROW_RECORD_BATCH_KEY ]
786
789
serialized_examples = batched_extract [constants .INPUT_KEY ]
@@ -841,12 +844,13 @@ def maybe_expand_dims(arr):
841
844
for i in range (record_batch .num_rows ):
842
845
if isinstance (outputs , dict ):
843
846
output = {
844
- k : maybe_expand_dims (v [i ].numpy ())
847
+ k : maybe_expand_dims (to_dense ( v ) [i ].numpy ())
845
848
for k , v in outputs .items ()
846
849
}
847
850
else :
848
851
output = {
849
- signature_name : maybe_expand_dims (np .asarray (outputs )[i ])
852
+ signature_name :
853
+ maybe_expand_dims (np .asarray (to_dense (outputs ))[i ])
850
854
}
851
855
if result [extracts_key ][i ] is None :
852
856
result [extracts_key ][i ] = collections .defaultdict (dict )
You can’t perform that action at this time.
0 commit comments