Skip to content

Commit e673985

Browse files
Report unsupported tensor type in RaggedTensorToTensor in Prepare.
PiperOrigin-RevId: 572528631
1 parent 404b38b commit e673985

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite.cc

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -558,26 +558,34 @@ void SetOutputT(TfLiteContext* context, int ragged_rank,
558558
}
559559
}
560560

561-
void SetOutput(TfLiteContext* context, int ragged_rank,
562-
const std::vector<int>& output_index,
563-
const TfLiteTensor& values_tensor,
564-
const TfLiteTensor& default_value_tensor,
565-
TfLiteTensor* output_tensor) {
561+
bool IsSupportedTensorType(TfLiteType type) {
562+
// Should reflect SetOutput capabilities.
563+
return type == kTfLiteInt32 || type == kTfLiteInt64 || type == kTfLiteFloat32;
564+
}
565+
566+
TfLiteStatus SetOutput(TfLiteContext* context, int ragged_rank,
567+
const std::vector<int>& output_index,
568+
const TfLiteTensor& values_tensor,
569+
const TfLiteTensor& default_value_tensor,
570+
TfLiteTensor* output_tensor) {
566571
switch (output_tensor->type) {
567572
case kTfLiteInt32:
568573
SetOutputT<int32_t>(context, ragged_rank, output_index, values_tensor,
569574
default_value_tensor, output_tensor);
570-
break;
575+
return kTfLiteOk;
571576
case kTfLiteInt64:
572577
SetOutputT<int64_t>(context, ragged_rank, output_index, values_tensor,
573578
default_value_tensor, output_tensor);
574-
break;
579+
return kTfLiteOk;
575580
case kTfLiteFloat32:
576581
SetOutputT<float>(context, ragged_rank, output_index, values_tensor,
577582
default_value_tensor, output_tensor);
578-
break;
583+
return kTfLiteOk;
579584
default:
585+
// Should not happen, checked in Prepare.
586+
// Left as a defensive programming artifact for future updates.
580587
context->ReportError(context, "Not supported values type");
588+
return kTfLiteError;
581589
}
582590
}
583591

@@ -624,17 +632,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
624632
context->ReportError(context, "Attributes are not initialized");
625633
return kTfLiteError;
626634
}
627-
// The output tensor need to be set to dynamic because it can have different
628-
// size.
629635
TfLiteTensor& output_tensor =
630636
context->tensors[node->outputs->data[kOutputTensor]];
637+
if (!IsSupportedTensorType(output_tensor.type)) {
638+
context->ReportError(context, "Unsupported ragged tensor type");
639+
return kTfLiteError;
640+
}
641+
// The output tensor needs to be set to dynamic because it can have different
642+
// size.
631643
SetTensorToDynamic(&output_tensor);
632644

633645
// Check that input shape tensor is int32 or int64
634646
TfLiteTensor& input_shape = context->tensors[node->inputs->data[kShapeInput]];
635647
if (input_shape.type != kTfLiteInt32 && input_shape.type != kTfLiteInt64) {
636648
context->ReportError(context,
637-
"Input form tensor could be only int32 or int64");
649+
"Input shape tensor could be only int32 or int64");
638650
return kTfLiteError;
639651
}
640652
return kTfLiteOk;
@@ -704,8 +716,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
704716
new_output_index.clear();
705717
}
706718

707-
SetOutput(context, attributes->ragged_rank, output_index, input_values,
708-
default_value, &output_tensor);
719+
TF_LITE_ENSURE_OK(context,
720+
SetOutput(context, attributes->ragged_rank, output_index,
721+
input_values, default_value, &output_tensor));
709722
}
710723
return kTfLiteOk;
711724
}

tensorflow_text/core/kernels/ragged_tensor_to_tensor_tflite_test.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
6060
partition_tensors_shapes,
6161
std::vector<std::string> partition_types,
6262
TensorType value_type = TensorType_FLOAT32,
63-
TensorType index_type = TensorType_INT32) {
63+
TensorType index_type = TensorType_INT32,
64+
bool allocate_and_delegate = true) {
6465
// A structure to collect shapes for the input.
6566
std::vector<std::vector<int>> shapes;
6667
input_shape_ = AddInput(index_type);
@@ -89,7 +90,11 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
8990
fbb.Finish();
9091
SetCustomOp("RaggedTensorToTensor", fbb.GetBuffer(),
9192
ops::custom::text::Register_RAGGED_TENSOR_TO_TENSOR);
92-
BuildInterpreter(shapes);
93+
BuildInterpreter(shapes, /*num_threads=*/-1,
94+
/*allow_fp32_relax_to_fp16=*/false,
95+
/*apply_delegate=*/true,
96+
/*allocate_and_delegate=*/allocate_and_delegate,
97+
/*use_simple_allocator=*/false);
9398
}
9499

95100
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
@@ -119,6 +124,7 @@ class RaggedTensorToTensorOpModel : public SingleOpModel {
119124
}
120125
SingleOpModel::Invoke();
121126
}
127+
TfLiteStatus TryAllocateTensors() { return interpreter_->AllocateTensors(); }
122128

123129
private:
124130
int input_shape_;
@@ -295,5 +301,16 @@ TEST(RaggedTensorToTensorTest, RaggedTensorToTensorContractExpandedDense) {
295301
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, //
296302
.4, 1.4, .5, 1.5, .6, 1.6, .7, 1.7, 1.5, 1.5}));
297303
}
304+
305+
TEST(RaggedTensorToTensorTest, StringType) {
306+
RaggedTensorToTensorOpModel model(
307+
2, // output_shape_dims
308+
{9}, // values_shape
309+
{{1}, {9}}, // partition_tensors_shapes
310+
std::vector<std::string>({"FIRST_DIM_SIZE", "VALUE_ROWIDS"}),
311+
TensorType_STRING, TensorType_INT32, /*allocate_and_delegate=*/false);
312+
EXPECT_EQ(model.TryAllocateTensors(), kTfLiteError);
313+
}
314+
298315
} // namespace
299316
} // namespace tflite

0 commit comments

Comments
 (0)