Skip to content

Commit ecb6ea4

Browse files
Switches bert_clu_annotator build rules to cc_library_with_tflite.
PiperOrigin-RevId: 452161028
1 parent 9aa3b1a commit ecb6ea4

File tree

4 files changed

+44
-23
lines changed

4 files changed

+44
-23
lines changed

tensorflow_lite_support/cc/task/text/BUILD

+10-6
Original file line numberDiff line numberDiff line change
@@ -138,32 +138,36 @@ cc_library_with_tflite(
138138
],
139139
)
140140

141-
cc_library(
141+
cc_library_with_tflite(
142142
name = "clu_annotator",
143143
hdrs = [
144144
"clu_annotator.h",
145145
],
146-
deps = [
146+
tflite_deps = [
147147
"//tensorflow_lite_support/cc/task/core:base_task_api",
148148
"//tensorflow_lite_support/cc/task/core:tflite_engine",
149+
],
150+
deps = [
149151
"//tensorflow_lite_support/cc/task/text/proto:clu_proto_inc",
150152
],
151153
)
152154

153-
cc_library(
155+
cc_library_with_tflite(
154156
name = "bert_clu_annotator",
155157
srcs = [
156158
"bert_clu_annotator.cc",
157159
],
158160
hdrs = [
159161
"bert_clu_annotator.h",
160162
],
161-
deps = [
163+
tflite_deps = [
162164
":clu_annotator",
163-
"//tensorflow_lite_support/cc/port:status_macros",
164165
"//tensorflow_lite_support/cc/task/core:task_api_factory",
165-
"//tensorflow_lite_support/cc/task/core:task_utils",
166166
"//tensorflow_lite_support/cc/task/text/clu_lib:tflite_modules",
167+
],
168+
deps = [
169+
"//tensorflow_lite_support/cc/port:status_macros",
170+
"//tensorflow_lite_support/cc/task/core:task_utils",
167171
"//tensorflow_lite_support/cc/task/text/proto:bert_clu_annotator_options_proto_inc",
168172
"//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer",
169173
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",

tensorflow_lite_support/cc/task/text/clu_lib/BUILD

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1+
load(
2+
"@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
3+
"cc_library_with_tflite",
4+
)
5+
16
package(
27
default_visibility = ["//visibility:public"],
38
licenses = ["notice"], # Apache 2.0
49
)
510

6-
cc_library(
11+
cc_library_with_tflite(
712
name = "tflite_modules",
813
srcs = ["tflite_modules.cc"],
914
hdrs = ["tflite_modules.h"],
15+
tflite_deps = [
16+
"//tensorflow_lite_support/cc/task/core:tflite_engine",
17+
],
1018
deps = [
1119
":bert_utils",
1220
":constants",

tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc

+14-9
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ absl::Status PopulateInputTextTensorForBERT(
4242
const CluRequest& request, int token_id_tensor_idx,
4343
int token_mask_tensor_idx, int token_type_id_tensor_idx,
4444
const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
45-
size_t max_seq_len, int max_history_turns, tflite::Interpreter* interpreter,
46-
Artifacts* artifacts) {
45+
size_t max_seq_len, int max_history_turns,
46+
core::TfLiteEngine::Interpreter* interpreter, Artifacts* artifacts) {
4747
size_t seq_len;
4848
int64_t* tokens_tensor =
4949
interpreter->typed_input_tensor<int64_t>(token_id_tensor_idx);
@@ -116,8 +116,9 @@ absl::Status PopulateInputTextTensorForBERT(
116116
return absl::OkStatus();
117117
}
118118

119-
absl::StatusOr<int> GetInputSeqDimSize(const size_t input_idx,
120-
const tflite::Interpreter* interpreter) {
119+
absl::StatusOr<int> GetInputSeqDimSize(
120+
const size_t input_idx,
121+
const core::TfLiteEngine::Interpreter* interpreter) {
121122
if (input_idx >= interpreter->inputs().size()) {
122123
return absl::InternalError(absl::StrCat(
123124
"input_idx should be less than interpreter input numbers. ", input_idx,
@@ -132,14 +133,15 @@ absl::StatusOr<int> GetInputSeqDimSize(const size_t input_idx,
132133
return tflite::SizeOfDimension(tensor, 1);
133134
}
134135

135-
absl::Status AbstractModule::Init(tflite::Interpreter* interpreter,
136+
absl::Status AbstractModule::Init(core::TfLiteEngine::Interpreter* interpreter,
136137
const BertCluAnnotatorOptions* options) {
137138
interpreter_ = interpreter;
138139
return absl::OkStatus();
139140
}
140141

141142
absl::StatusOr<std::unique_ptr<AbstractModule>> UtteranceSeqModule::Create(
142-
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
143+
core::TfLiteEngine::Interpreter* interpreter,
144+
const TensorIndexMap* tensor_index_map,
143145
const BertCluAnnotatorOptions* options,
144146
const tflite::support::text::tokenizer::BertTokenizer* tokenizer) {
145147
auto out = std::make_unique<UtteranceSeqModule>();
@@ -187,7 +189,8 @@ AbstractModule::NamesAndConfidencesFromOutput(int names_tensor_idx,
187189
}
188190

189191
absl::StatusOr<std::unique_ptr<AbstractModule>> DomainModule::Create(
190-
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
192+
core::TfLiteEngine::Interpreter* interpreter,
193+
const TensorIndexMap* tensor_index_map,
191194
const BertCluAnnotatorOptions* options) {
192195
auto out = std::make_unique<DomainModule>();
193196
out->tensor_index_map_ = tensor_index_map;
@@ -215,7 +218,8 @@ absl::Status DomainModule::Postprocess(Artifacts* artifacts,
215218
}
216219

217220
absl::StatusOr<std::unique_ptr<AbstractModule>> IntentModule::Create(
218-
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
221+
core::TfLiteEngine::Interpreter* interpreter,
222+
const TensorIndexMap* tensor_index_map,
219223
const BertCluAnnotatorOptions* options) {
220224
auto out = std::make_unique<IntentModule>();
221225
out->tensor_index_map_ = tensor_index_map;
@@ -261,7 +265,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
261265
}
262266

263267
absl::StatusOr<std::unique_ptr<AbstractModule>> SlotModule::Create(
264-
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
268+
core::TfLiteEngine::Interpreter* interpreter,
269+
const TensorIndexMap* tensor_index_map,
265270
const BertCluAnnotatorOptions* options) {
266271
auto out = std::make_unique<SlotModule>();
267272
out->tensor_index_map_ = tensor_index_map;

tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h

+11-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818

1919
#include "absl/status/statusor.h" // from @com_google_absl
2020
#include "absl/strings/string_view.h" // from @com_google_absl
21-
#include "tensorflow/lite/interpreter.h"
21+
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
2222
#include "tensorflow_lite_support/cc/task/text/proto/bert_clu_annotator_options_proto_inc.h"
2323
#include "tensorflow_lite_support/cc/task/text/proto/clu_proto_inc.h"
2424
#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
@@ -76,7 +76,7 @@ class AbstractModule {
7676
protected:
7777
AbstractModule() = default;
7878

79-
absl::Status Init(Interpreter* interpreter,
79+
absl::Status Init(core::TfLiteEngine::Interpreter* interpreter,
8080
const BertCluAnnotatorOptions* options);
8181

8282
using NamesAndConfidences =
@@ -88,7 +88,7 @@ class AbstractModule {
8888
int names_tensor_idx, int scores_tensor_idx) const;
8989

9090
// TFLite interpreter
91-
Interpreter* interpreter_ = nullptr;
91+
core::TfLiteEngine::Interpreter* interpreter_ = nullptr;
9292

9393
const TensorIndexMap* tensor_index_map_ = nullptr;
9494
};
@@ -98,7 +98,8 @@ class AbstractModule {
9898
class UtteranceSeqModule : public AbstractModule {
9999
public:
100100
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
101-
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
101+
core::TfLiteEngine::Interpreter* interpreter,
102+
const TensorIndexMap* tensor_index_map,
102103
const BertCluAnnotatorOptions* options,
103104
const tflite::support::text::tokenizer::BertTokenizer* tokenizer);
104105

@@ -116,7 +117,8 @@ class UtteranceSeqModule : public AbstractModule {
116117
class DomainModule : public AbstractModule {
117118
public:
118119
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
119-
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
120+
core::TfLiteEngine::Interpreter* interpreter,
121+
const TensorIndexMap* tensor_index_map,
120122
const BertCluAnnotatorOptions* options);
121123

122124
absl::Status Postprocess(Artifacts* artifacts,
@@ -130,7 +132,8 @@ class DomainModule : public AbstractModule {
130132
class IntentModule : public AbstractModule {
131133
public:
132134
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
133-
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
135+
core::TfLiteEngine::Interpreter* interpreter,
136+
const TensorIndexMap* tensor_index_map,
134137
const BertCluAnnotatorOptions* options);
135138

136139
absl::Status Postprocess(Artifacts* artifacts,
@@ -145,7 +148,8 @@ class IntentModule : public AbstractModule {
145148
class SlotModule : public AbstractModule {
146149
public:
147150
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
148-
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
151+
core::TfLiteEngine::Interpreter* interpreter,
152+
const TensorIndexMap* tensor_index_map,
149153
const BertCluAnnotatorOptions* options);
150154

151155
absl::Status Postprocess(Artifacts* artifacts,

0 commit comments

Comments
 (0)