
Commit d883320

adding an option to specify default values for nullable fields in BigQuery connector (#1583)
* adding option to specify default values for nullable fields
* fix when default_values arg is set
* linter fixes
* fixing tests for MacOS and Windows
* making tests pass for both Linux and MacOS
1 parent d58cf29 commit d883320

File tree

7 files changed: +315 −127 lines

tensorflow_io/bigquery.md

Lines changed: 24 additions & 15 deletions
````diff
@@ -57,7 +57,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow_io.bigquery import BigQueryClient
 from tensorflow_io.bigquery import BigQueryReadSession
 
-GCP_PROJECT_ID = '<FILL_ME_IN>'
+GCP_PROJECT_ID = "<FILL_ME_IN>"
 DATASET_GCP_PROJECT_ID = "bigquery-public-data"
 DATASET_ID = "samples"
 TABLE_ID = "wikipedia"
@@ -68,20 +68,29 @@ def main():
   read_session = client.read_session(
       "projects/" + GCP_PROJECT_ID,
       DATASET_GCP_PROJECT_ID, TABLE_ID, DATASET_ID,
-      ["title",
+      selected_fields=["title",
        "id",
        "num_characters",
        "language",
        "timestamp",
        "wp_namespace",
        "contributor_username"],
-      [dtypes.string,
+      output_types=[dtypes.string,
        dtypes.int64,
        dtypes.int64,
        dtypes.string,
        dtypes.int64,
        dtypes.int64,
        dtypes.string],
+      default_values=[
+          "",
+          0,
+          0,
+          "",
+          0,
+          0,
+          ""
+      ],
       requested_streams=2,
       row_restriction="num_characters > 1000",
       data_format=BigQueryClient.DataFormat.AVRO)
@@ -98,8 +107,8 @@ def main():
     print("row %d: %s" % (row_index, row))
     row_index += 1
 
-if __name__ == '__main__':
-  app.run(main)
+if __name__ == "__main__":
+  main()
 
 ```
 
@@ -127,10 +136,10 @@ dataset = streams_ds.interleave(
 The connector also supports reading BigQuery columns with repeated mode (each field contains an array of values of a primitive type: Integer, Float, Boolean, or String; RECORD is not supported). In this case, selected_fields needs to be a dictionary of the form:
 
 ```python
-{ "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, output_type: dtypes.int64},
-  "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, output_type: dtypes.string},
+{ "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, "output_type": dtypes.int64},
+  "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, "output_type": dtypes.string, "default_value": "<default_value>"},
   ...
-  "field_x_name": {"mode": BigQueryClient.FieldMode.REQUIRED, output_type: dtypes.string}
+  "field_x_name": {"mode": BigQueryClient.FieldMode.REQUIRED, "output_type": dtypes.string}
 }
 ```
 "mode" is a BigQuery column attribute concept; it can be 'repeated', 'nullable' or 'required' (enum BigQueryClient.FieldMode.REPEATED, NULLABLE, REQUIRED). The output field order is unrelated to the order of fields in
@@ -144,7 +153,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow_io.bigquery import BigQueryClient
 from tensorflow_io.bigquery import BigQueryReadSession
 
-GCP_PROJECT_ID = '<FILL_ME_IN>'
+GCP_PROJECT_ID = "<FILL_ME_IN>"
 DATASET_GCP_PROJECT_ID = "bigquery-public-data"
 DATASET_ID = "certain_dataset"
 TABLE_ID = "certain_table_with_repeated_field"
@@ -156,10 +165,10 @@ def main():
       "projects/" + GCP_PROJECT_ID,
       DATASET_GCP_PROJECT_ID, TABLE_ID, DATASET_ID,
       selected_fields={
-          "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, output_type: dtypes.int64},
-          "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, output_type: dtypes.string},
-          "field_c_name": {"mode": BigQueryClient.FieldMode.REQUIRED, output_type: dtypes.string}
-          "field_d_name": {"mode": BigQueryClient.FieldMode.REPEATED, output_type: dtypes.string}
+          "field_a_name": {"mode": BigQueryClient.FieldMode.REPEATED, "output_type": dtypes.int64},
+          "field_b_name": {"mode": BigQueryClient.FieldMode.NULLABLE, "output_type": dtypes.string, "default_value": ""},
+          "field_c_name": {"mode": BigQueryClient.FieldMode.REQUIRED, "output_type": dtypes.string},
+          "field_d_name": {"mode": BigQueryClient.FieldMode.REPEATED, "output_type": dtypes.string}
       },
       requested_streams=2,
       row_restriction="num_characters > 1000",
@@ -171,8 +180,8 @@ def main():
     print("row %d: %s" % (row_index, row))
     row_index += 1
 
-if __name__ == '__main__':
-  app.run(main)
+if __name__ == "__main__":
+  main()
 ```
 
 Then each field of a repeated column becomes a rank-1 variable length Tensor. If you want to
````
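
For context, a minimal end-to-end sketch of the updated API, assuming configured GCP credentials, a filled-in project ID, and the `parallel_read_rows` convenience method for consuming the streams:

```python
# Minimal usage sketch of the new default_values argument (illustrative;
# assumes GCP credentials are set up and "<FILL_ME_IN>" is replaced).
from tensorflow.python.framework import dtypes
from tensorflow_io.bigquery import BigQueryClient

GCP_PROJECT_ID = "<FILL_ME_IN>"

client = BigQueryClient()
read_session = client.read_session(
    "projects/" + GCP_PROJECT_ID,
    "bigquery-public-data", "wikipedia", "samples",
    selected_fields=["title", "num_characters"],
    output_types=[dtypes.string, dtypes.int64],
    # NULL cells in nullable columns now yield the matching entry here,
    # instead of the previously hard-coded ""/0 defaults.
    default_values=["", 0],
    requested_streams=2,
    row_restriction="num_characters > 1000",
    data_format=BigQueryClient.DataFormat.AVRO)

for row in read_session.parallel_read_rows().take(3):
    print(row)
```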

tensorflow_io/core/BUILD

Lines changed: 2 additions & 0 deletions
```diff
@@ -168,6 +168,7 @@ cc_library(
         "@arrow",
         "@avro",
         "@com_github_grpc_grpc//:grpc++",
+        "@com_google_absl//absl/types:any",
         "@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc",
         "@local_config_tf//:libtensorflow_framework",
         "@local_config_tf//:tf_header_lib",
@@ -190,6 +191,7 @@ cc_library(
         "@com_google_absl//absl/algorithm",
         "@com_google_absl//absl/container:fixed_array",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/types:any",
         "@com_google_absl//absl/types:variant",
         "@com_google_googleapis//google/cloud/bigquery/storage/v1beta1:storage_cc_grpc",
         "@local_config_tf//:libtensorflow_framework",
```

tensorflow_io/core/kernels/bigquery/bigquery_dataset_op.cc

Lines changed: 44 additions & 2 deletions
```diff
@@ -16,6 +16,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/types/any.h"
 #include "arrow/buffer.h"
 #include "arrow/ipc/api.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -30,6 +31,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
   explicit BigQueryDatasetOp(OpKernelConstruction *ctx) : DatasetOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("selected_fields", &selected_fields_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("default_values", &default_values_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("offset", &offset_));
     string data_format_str;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
@@ -54,20 +56,53 @@ class BigQueryDatasetOp : public DatasetOpKernel {
     output_shapes.reserve(num_outputs);
     DataTypeVector output_types_vector;
     output_types_vector.reserve(num_outputs);
+    typed_default_values_.reserve(num_outputs);
     for (uint64 i = 0; i < num_outputs; ++i) {
       output_shapes.push_back({});
       output_types_vector.push_back(output_types_[i]);
+      const DataType &output_type = output_types_[i];
+      const string &default_value = default_values_[i];
+      switch (output_type) {
+        case DT_FLOAT:
+          typed_default_values_.push_back(absl::any(std::stof(default_value)));
+          break;
+        case DT_DOUBLE:
+          typed_default_values_.push_back(absl::any(std::stod(default_value)));
+          break;
+        case DT_INT32:
+          int32_t value_int32_t;
+          strings::safe_strto32(default_value, &value_int32_t);
+          typed_default_values_.push_back(absl::any(value_int32_t));
+          break;
+        case DT_INT64:
+          int64_t value_int64_t;
+          strings::safe_strto64(default_value, &value_int64_t);
+          typed_default_values_.push_back(absl::any(value_int64_t));
+          break;
+        case DT_BOOL:
+          typed_default_values_.push_back(absl::any(default_value == "True"));
+          break;
+        case DT_STRING:
+          typed_default_values_.push_back(absl::any(default_value));
+          break;
+        default:
+          ctx->CtxFailure(
+              errors::InvalidArgument("Unsupported output_type:", output_type));
+          break;
+      }
     }
 
     *output = new Dataset(ctx, client_resource, output_types_vector,
                           std::move(output_shapes), std::move(stream),
                           std::move(schema), selected_fields_, output_types_,
-                          offset_, data_format_);
+                          typed_default_values_, offset_, data_format_);
   }
 
  private:
  std::vector<string> selected_fields_;
  std::vector<DataType> output_types_;
+ std::vector<string> default_values_;
+ std::vector<absl::any> typed_default_values_;
  int64 offset_;
  apiv1beta1::DataFormat data_format_;
 
@@ -79,7 +114,8 @@ class BigQueryDatasetOp : public DatasetOpKernel {
             std::vector<PartialTensorShape> output_shapes,
             string stream, string schema,
             std::vector<string> selected_fields,
-            std::vector<DataType> output_types, int64 offset_,
+            std::vector<DataType> output_types,
+            std::vector<absl::any> typed_default_values, int64 offset_,
             apiv1beta1::DataFormat data_format)
         : DatasetBase(DatasetContext(ctx)),
           client_resource_(client_resource),
@@ -88,6 +124,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
           stream_(stream),
           selected_fields_(selected_fields),
           output_types_(output_types),
+          typed_default_values_(typed_default_values),
           offset_(offset_),
           avro_schema_(absl::make_unique<avro::ValidSchema>()),
           data_format_(data_format) {
@@ -147,6 +184,10 @@ class BigQueryDatasetOp : public DatasetOpKernel {
 
     const std::vector<DataType> &output_types() const { return output_types_; }
 
+    const std::vector<absl::any> &typed_default_values() const {
+      return typed_default_values_;
+    }
+
     const std::unique_ptr<avro::ValidSchema> &avro_schema() const {
       return avro_schema_;
     }
@@ -180,6 +221,7 @@ class BigQueryDatasetOp : public DatasetOpKernel {
     const string stream_;
     const std::vector<string> selected_fields_;
    const std::vector<DataType> output_types_;
+   const std::vector<absl::any> typed_default_values_;
    const std::unique_ptr<avro::ValidSchema> avro_schema_;
    const int64 offset_;
    std::shared_ptr<::arrow::Schema> arrow_schema_;
```
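
The heart of the change is the switch above: `default_values` arrives as one string per selected field and is converted once, at op-construction time, into a typed `absl::any` matching the column's `output_type`. A rough Python equivalent of that conversion, for illustration only (the helper and converter table are ours, not part of the connector):

```python
# Illustrative re-statement of the kernel's string-to-typed-default switch.
from tensorflow.python.framework import dtypes

_CONVERTERS = {
    dtypes.float32: float,               # DT_FLOAT:  std::stof
    dtypes.float64: float,               # DT_DOUBLE: std::stod
    dtypes.int32: int,                   # DT_INT32:  strings::safe_strto32
    dtypes.int64: int,                   # DT_INT64:  strings::safe_strto64
    dtypes.bool: lambda s: s == "True",  # exact match, so "true" parses as False
    dtypes.string: str,                  # DT_STRING: kept as-is
}

def parse_default_values(output_types, default_values):
    """Converts per-column default strings into typed values."""
    typed = []
    for dtype, raw in zip(output_types, default_values):
        if dtype not in _CONVERTERS:
            raise ValueError("Unsupported output_type: %s" % dtype)
        typed.append(_CONVERTERS[dtype](raw))
    return typed

print(parse_default_values([dtypes.string, dtypes.int64], ["", "0"]))  # ['', 0]
```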

tensorflow_io/core/kernels/bigquery/bigquery_lib.h

Lines changed: 24 additions & 14 deletions
```diff
@@ -26,6 +26,7 @@ limitations under the License.
 #include <Windows.h>
 #undef OPTIONAL
 #endif
+#include "absl/types/any.h"
 #include "api/Compiler.hh"
 #include "api/DataFile.hh"
 #include "api/Decoder.hh"
@@ -127,7 +128,8 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
 
     auto status =
         ReadRecord(ctx, out_tensors, this->dataset()->selected_fields(),
-                   this->dataset()->output_types());
+                   this->dataset()->output_types(),
+                   this->dataset()->typed_default_values());
     current_row_index_++;
     return status;
   }
@@ -181,10 +183,11 @@ class BigQueryReaderDatasetIteratorBase : public DatasetIterator<Dataset> {
   }
 
   virtual Status EnsureHasRow(bool *end_of_sequence) = 0;
-  virtual Status ReadRecord(IteratorContext *ctx,
-                            std::vector<Tensor> *out_tensors,
-                            const std::vector<string> &columns,
-                            const std::vector<DataType> &output_types) = 0;
+  virtual Status ReadRecord(
+      IteratorContext *ctx, std::vector<Tensor> *out_tensors,
+      const std::vector<string> &columns,
+      const std::vector<DataType> &output_types,
+      const std::vector<absl::any> &typed_default_values) = 0;
   int current_row_index_ = 0;
   mutex mu_;
   std::unique_ptr<::grpc::ClientContext> read_rows_context_ TF_GUARDED_BY(mu_);
@@ -245,15 +248,15 @@ class BigQueryReaderArrowDatasetIterator
 
   Status ReadRecord(IteratorContext *ctx, std::vector<Tensor> *out_tensors,
                     const std::vector<string> &columns,
-                    const std::vector<DataType> &output_types)
+                    const std::vector<DataType> &output_types,
+                    const std::vector<absl::any> &typed_default_values)
       TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
     out_tensors->clear();
     out_tensors->reserve(columns.size());
 
     if (this->current_row_index_ == 0 && this->column_indices_.empty()) {
       this->column_indices_.resize(columns.size());
       for (size_t i = 0; i < columns.size(); ++i) {
-        DataType output_type = output_types[i];
         auto column_name = this->record_batch_->column_name(i);
         auto it = std::find(columns.begin(), columns.end(), column_name);
         if (it == columns.end()) {
@@ -337,7 +340,8 @@ class BigQueryReaderAvroDatasetIterator
 
   Status ReadRecord(IteratorContext *ctx, std::vector<Tensor> *out_tensors,
                     const std::vector<string> &columns,
-                    const std::vector<DataType> &output_types)
+                    const std::vector<DataType> &output_types,
+                    const std::vector<absl::any> &typed_default_values)
       TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
     avro::decode(*this->decoder_, *this->datum_);
     if (this->datum_->type() != avro::AVRO_RECORD) {
@@ -521,22 +525,28 @@ class BigQueryReaderAvroDatasetIterator
       case avro::AVRO_NULL:
         switch (output_types[i]) {
           case DT_BOOL:
-            ((*out_tensors)[i]).scalar<bool>()() = false;
+            ((*out_tensors)[i]).scalar<bool>()() =
+                absl::any_cast<bool>(typed_default_values[i]);
             break;
           case DT_INT32:
-            ((*out_tensors)[i]).scalar<int32>()() = 0;
+            ((*out_tensors)[i]).scalar<int32>()() =
+                absl::any_cast<int32_t>(typed_default_values[i]);
            break;
          case DT_INT64:
-            ((*out_tensors)[i]).scalar<int64>()() = 0l;
+            ((*out_tensors)[i]).scalar<int64>()() =
+                absl::any_cast<int64_t>(typed_default_values[i]);
            break;
          case DT_FLOAT:
-            ((*out_tensors)[i]).scalar<float>()() = 0.0f;
+            ((*out_tensors)[i]).scalar<float>()() =
+                absl::any_cast<float>(typed_default_values[i]);
            break;
          case DT_DOUBLE:
-            ((*out_tensors)[i]).scalar<double>()() = 0.0;
+            ((*out_tensors)[i]).scalar<double>()() =
+                absl::any_cast<double>(typed_default_values[i]);
            break;
          case DT_STRING:
-            ((*out_tensors)[i]).scalar<tstring>()() = "";
+            ((*out_tensors)[i]).scalar<tstring>()() =
+                absl::any_cast<string>(typed_default_values[i]);
            break;
          default:
            return errors::InvalidArgument(
```
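
With that plumbing in place, the Avro iterator's `AVRO_NULL` branch substitutes the pre-parsed typed default instead of the hard-coded `false`/`0`/`0.0`/`""` it used before. A behavioral sketch, with `rows` and `defaults` as hypothetical stand-ins for decoded Avro data:

```python
# Behavioral sketch of the AVRO_NULL handling above: a NULL cell takes the
# column's typed default rather than a fixed zero/empty value.
def substitute_defaults(rows, defaults):
    """Replaces None (i.e. AVRO_NULL) cells with the per-column default."""
    return [
        [default if cell is None else cell
         for cell, default in zip(row, defaults)]
        for row in rows
    ]

rows = [["alpha", None], [None, 42]]
defaults = ["<missing>", 0]
print(substitute_defaults(rows, defaults))
# [['alpha', 0], ['<missing>', 42]]
```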

tensorflow_io/core/ops/bigquery_ops.cc

Lines changed: 2 additions & 0 deletions
```diff
@@ -32,6 +32,7 @@ REGISTER_OP("IO>BigQueryReadSession")
     .Attr("dataset_id: string")
     .Attr("selected_fields: list(string) >= 1")
     .Attr("output_types: list(type) >= 1")
+    .Attr("default_values: list(string) >= 1")
     .Attr("requested_streams: int")
     .Attr("data_format: string")
     .Attr("row_restriction: string = ''")
@@ -53,6 +54,7 @@ REGISTER_OP("IO>BigQueryDataset")
     .Attr("data_format: string")
     .Attr("selected_fields: list(string) >= 1")
     .Attr("output_types: list(type) >= 1")
+    .Attr("default_values: list(string) >= 1")
     .Output("handle: variant")
     .SetIsStateful()  // TODO(b/123753214): Source dataset ops must be marked
                       // stateful to inhibit constant folding.
```
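
Because TensorFlow op attrs are homogeneous lists, the defaults cross the op boundary as `list(string)` and are re-typed inside the kernel (the switch in `bigquery_dataset_op.cc` above). A hypothetical caller-side sketch of the stringification this implies; the helper is illustrative, not part of the connector:

```python
# Hypothetical wrapper-side step: render mixed-type Python defaults as the
# strings the "default_values" attr registered above expects.
def stringify_defaults(default_values):
    return [str(v) for v in default_values]

print(stringify_defaults(["", 0, 0.5, True]))
# ['', '0', '0.5', 'True'] -- note bool renders as "True", matching the
# kernel's `default_value == "True"` comparison.
```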

0 commit comments
