@@ -558,26 +558,34 @@ void SetOutputT(TfLiteContext* context, int ragged_rank,
558
558
}
559
559
}
560
560
561
- void SetOutput (TfLiteContext* context, int ragged_rank,
562
- const std::vector<int >& output_index,
563
- const TfLiteTensor& values_tensor,
564
- const TfLiteTensor& default_value_tensor,
565
- TfLiteTensor* output_tensor) {
561
+ bool IsSupportedTensorType (TfLiteType type) {
562
+ // Should reflect SetOutput capabilities.
563
+ return type == kTfLiteInt32 || type == kTfLiteInt64 || type == kTfLiteFloat32 ;
564
+ }
565
+
566
+ TfLiteStatus SetOutput (TfLiteContext* context, int ragged_rank,
567
+ const std::vector<int >& output_index,
568
+ const TfLiteTensor& values_tensor,
569
+ const TfLiteTensor& default_value_tensor,
570
+ TfLiteTensor* output_tensor) {
566
571
switch (output_tensor->type ) {
567
572
case kTfLiteInt32 :
568
573
SetOutputT<int32_t >(context, ragged_rank, output_index, values_tensor,
569
574
default_value_tensor, output_tensor);
570
- break ;
575
+ return kTfLiteOk ;
571
576
case kTfLiteInt64 :
572
577
SetOutputT<int64_t >(context, ragged_rank, output_index, values_tensor,
573
578
default_value_tensor, output_tensor);
574
- break ;
579
+ return kTfLiteOk ;
575
580
case kTfLiteFloat32 :
576
581
SetOutputT<float >(context, ragged_rank, output_index, values_tensor,
577
582
default_value_tensor, output_tensor);
578
- break ;
583
+ return kTfLiteOk ;
579
584
default :
585
+ // Should not happen, checked in Prepare.
586
+ // Left as a defensive programming artifact for future updates.
580
587
context->ReportError (context, " Not supported values type" );
588
+ return kTfLiteError ;
581
589
}
582
590
}
583
591
@@ -624,17 +632,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
624
632
context->ReportError (context, " Attributes are not initialized" );
625
633
return kTfLiteError ;
626
634
}
627
- // The output tensor need to be set to dynamic because it can have different
628
- // size.
629
635
TfLiteTensor& output_tensor =
630
636
context->tensors [node->outputs ->data [kOutputTensor ]];
637
+ if (!IsSupportedTensorType (output_tensor.type )) {
638
+ context->ReportError (context, " Unsupported ragged tensor type" );
639
+ return kTfLiteError ;
640
+ }
641
+ // The output tensor needs to be set to dynamic because it can have different
642
+ // size.
631
643
SetTensorToDynamic (&output_tensor);
632
644
633
645
// Check that input shape tensor is int32 or int64
634
646
TfLiteTensor& input_shape = context->tensors [node->inputs ->data [kShapeInput ]];
635
647
if (input_shape.type != kTfLiteInt32 && input_shape.type != kTfLiteInt64 ) {
636
648
context->ReportError (context,
637
- " Input form tensor could be only int32 or int64" );
649
+ " Input shape tensor could be only int32 or int64" );
638
650
return kTfLiteError ;
639
651
}
640
652
return kTfLiteOk ;
@@ -704,8 +716,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
704
716
new_output_index.clear ();
705
717
}
706
718
707
- SetOutput (context, attributes->ragged_rank , output_index, input_values,
708
- default_value, &output_tensor);
719
+ TF_LITE_ENSURE_OK (context,
720
+ SetOutput (context, attributes->ragged_rank , output_index,
721
+ input_values, default_value, &output_tensor));
709
722
}
710
723
return kTfLiteOk ;
711
724
}
0 commit comments