Add caching during protobuf generation (GoogleCloudDataproc#1181)
agrawal-siddharth authored Feb 14, 2024
1 parent 02fd850 commit f61b3cc
Showing 4 changed files with 249 additions and 62 deletions.
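The change in a nutshell: before this commit, buildSingleRowMessage re-derived each field's metadata (Spark DataType, nullability, nested type descriptor, protobuf FieldDescriptor, and the optional TypeConverter and SupportedCustomDataType) for every row it converted. The commit introduces ProtobufSchemaFieldCacheEntry plus a Map<Integer, ProtobufSchemaFieldCacheEntry> keyed by field index, so that work happens at most once per field for a batch of rows. Below is a minimal, self-contained sketch of the caching pattern only; FieldEntry and computeFieldEntry are hypothetical stand-ins, not the connector's API.

```java
import java.util.HashMap;
import java.util.Map;

public class FieldCacheSketch {

  // Hypothetical stand-in for ProtobufSchemaFieldCacheEntry: per-field metadata that is
  // identical for every row, so it only needs to be computed once per field index.
  static final class FieldEntry {
    final String sparkTypeName;
    final boolean nullable;

    FieldEntry(String sparkTypeName, boolean nullable) {
      this.sparkTypeName = sparkTypeName;
      this.nullable = nullable;
    }
  }

  // Hypothetical stand-in for computeProtobufSchemaFieldEntry; the real method also resolves
  // the nested type descriptor, the FieldDescriptor, and the optional type converters.
  static FieldEntry computeFieldEntry(String[][] schema, int fieldIndex) {
    return new FieldEntry(schema[fieldIndex][0], Boolean.parseBoolean(schema[fieldIndex][1]));
  }

  public static void main(String[] args) {
    String[][] schema = {{"StringType", "true"}, {"LongType", "false"}};
    String[][] rows = {{"a", "1"}, {"b", "2"}, {"c", "3"}};

    // One cache per batch: entries are created lazily on the first row and reused afterwards.
    Map<Integer, FieldEntry> fieldIndexToEntryMap = new HashMap<>();

    for (String[] row : rows) {
      for (int fieldIndex = 0; fieldIndex < row.length; fieldIndex++) {
        FieldEntry entry =
            fieldIndexToEntryMap.computeIfAbsent(fieldIndex, k -> computeFieldEntry(schema, k));
        System.out.println(entry.sparkTypeName + " nullable=" + entry.nullable + " value=" + row[fieldIndex]);
      }
    }
  }
}
```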
CHANGES.md (2 additions, 0 deletions)
@@ -2,6 +2,8 @@

## Next

* PR #1181: Add caching during protobuf generation

## 0.36.1 - 2024-01-31

* PR #1176: fix timestamp filter translation issue
ProtobufUtils.java
@@ -34,7 +34,9 @@
import com.google.protobuf.DynamicMessage;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.apache.spark.sql.Row;
@@ -72,6 +74,70 @@

public class ProtobufUtils {

public static final class ProtobufSchemaFieldCacheEntry {
private DataType sparkType;
private boolean nullable;
private Descriptors.Descriptor nestedTypeDescriptor;
private Descriptors.FieldDescriptor fieldDescriptor;
private Optional<TypeConverter> typeConverterOptional;
private Optional<SupportedCustomDataType> customDataTypeOptional;

public ProtobufSchemaFieldCacheEntry(
DataType sparkType,
boolean nullable,
Descriptors.Descriptor nestedTypeDescriptor,
Descriptors.FieldDescriptor fieldDescriptor,
Optional<TypeConverter> typeConverterOptional,
Optional<SupportedCustomDataType> customDataTypeOptional) {
this.sparkType = sparkType;
this.nullable = nullable;
this.nestedTypeDescriptor = nestedTypeDescriptor;
this.fieldDescriptor = fieldDescriptor;
this.typeConverterOptional = typeConverterOptional;
this.customDataTypeOptional = customDataTypeOptional;
}

public DataType getSparkType() {
return sparkType;
}

public boolean getNullable() {
return nullable;
}

public Descriptors.Descriptor getNestedTypeDescriptor() {
return nestedTypeDescriptor;
}

public Descriptors.FieldDescriptor getFieldDescriptor() {
return fieldDescriptor;
}

public Optional<TypeConverter> getTypeConverterOptional() {
return typeConverterOptional;
}

public Optional<SupportedCustomDataType> getCustomDataTypeOptional() {
return customDataTypeOptional;
}
}

private static ProtobufSchemaFieldCacheEntry computeProtobufSchemaFieldEntry(
StructType schema, Descriptors.Descriptor schemaDescriptor, int fieldIndex) {
StructField sparkField = schema.fields()[fieldIndex];
int protoFieldNumber = fieldIndex + 1;
return new ProtobufSchemaFieldCacheEntry(
sparkField.dataType(),
sparkField.nullable(),
schemaDescriptor.findNestedTypeByName(
ProtobufUtils.RESERVED_NESTED_TYPE_NAME + (protoFieldNumber)),
schemaDescriptor.findFieldByNumber(protoFieldNumber),
SparkBigQueryUtil.getTypeConverterStream()
.filter(tc -> tc.supportsSparkType(sparkField.dataType()))
.findFirst(),
SupportedCustomDataType.of(sparkField.dataType()));
}

static final Logger logger = LoggerFactory.getLogger(ProtobufUtils.class);
// The maximum nesting depth of a BigQuery RECORD:
private static final int MAX_BIGQUERY_NESTED_DEPTH = 15;
@@ -287,8 +353,17 @@ public static ProtoRows toProtoRows(StructType sparkSchema, InternalRow[] rows)
try {
Descriptors.Descriptor schemaDescriptor = toDescriptor(sparkSchema);
ProtoRows.Builder protoRows = ProtoRows.newBuilder();
Map<Integer, ProtobufSchemaFieldCacheEntry> fieldIndexToEntryMap = new HashMap<>();
DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(schemaDescriptor);

for (InternalRow row : rows) {
DynamicMessage rowMessage = buildSingleRowMessage(sparkSchema, schemaDescriptor, row);
DynamicMessage rowMessage =
buildSingleRowMessage(
sparkSchema,
schemaDescriptor,
row,
Optional.of(fieldIndexToEntryMap),
messageBuilder);
protoRows.addSerializedRows(rowMessage.toByteString());
}
return protoRows.build();
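toProtoRows also allocates a single DynamicMessage.Builder for the whole batch and hands it to buildSingleRowMessage, which now calls clear() on it instead of creating a fresh builder per row. A small sketch of that builder-reuse pattern against the protobuf-java API; the one-field Row descriptor below is invented purely for illustration:

```java
import com.google.protobuf.DescriptorProtos.DescriptorProto;
import com.google.protobuf.DescriptorProtos.FieldDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;

public class BuilderReuseSketch {
  public static void main(String[] args) throws Exception {
    // Build a tiny descriptor with a single int64 field, just for the sketch.
    DescriptorProto rowProto =
        DescriptorProto.newBuilder()
            .setName("Row")
            .addField(
                FieldDescriptorProto.newBuilder()
                    .setName("value")
                    .setNumber(1)
                    .setType(FieldDescriptorProto.Type.TYPE_INT64))
            .build();
    FileDescriptorProto fileProto =
        FileDescriptorProto.newBuilder().setName("row.proto").addMessageType(rowProto).build();
    Descriptors.Descriptor descriptor =
        Descriptors.FileDescriptor.buildFrom(fileProto, new Descriptors.FileDescriptor[0])
            .findMessageTypeByName("Row");

    // One builder for the whole batch, cleared before each row, as in the commit.
    DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor);
    Descriptors.FieldDescriptor field = descriptor.findFieldByNumber(1);

    long[] batch = {1L, 2L, 3L};
    for (long value : batch) {
      messageBuilder.clear();
      messageBuilder.setField(field, value);
      DynamicMessage rowMessage = messageBuilder.build();
      System.out.println(rowMessage.toByteString().size());
    }
  }
}
```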
@@ -298,27 +373,41 @@ public static ProtoRows toProtoRows(StructType sparkSchema, InternalRow[] rows)
}

public static DynamicMessage buildSingleRowMessage(
StructType schema, Descriptors.Descriptor schemaDescriptor, InternalRow row) {
DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(schemaDescriptor);
StructType schema,
Descriptors.Descriptor schemaDescriptor,
InternalRow row,
Optional<Map<Integer, ProtobufSchemaFieldCacheEntry>> fieldIndexToEntryMap,
DynamicMessage.Builder messageBuilder) {
messageBuilder.clear();

for (int fieldIndex = 0; fieldIndex < schemaDescriptor.getFields().size(); fieldIndex++) {
int protoFieldNumber = fieldIndex + 1;

StructField sparkField = schema.fields()[fieldIndex];
DataType sparkType = sparkField.dataType();

Object sparkValue = row.get(fieldIndex, sparkType);
boolean nullable = sparkField.nullable();
Descriptors.Descriptor nestedTypeDescriptor =
schemaDescriptor.findNestedTypeByName(RESERVED_NESTED_TYPE_NAME + (protoFieldNumber));
int fieldIndexCopy = fieldIndex;
ProtobufSchemaFieldCacheEntry protobufSchemaFieldEntry =
fieldIndexToEntryMap.isPresent()
? fieldIndexToEntryMap
.get()
.computeIfAbsent(
fieldIndexCopy,
k -> {
return computeProtobufSchemaFieldEntry(
schema, schemaDescriptor, fieldIndexCopy);
})
: computeProtobufSchemaFieldEntry(schema, schemaDescriptor, fieldIndexCopy);
Object sparkValue = row.get(fieldIndex, protobufSchemaFieldEntry.getSparkType());
Object protoValue =
convertSparkValueToProtoRowValue(sparkType, sparkValue, nullable, nestedTypeDescriptor);
convertSparkValueToProtoRowValue(
protobufSchemaFieldEntry.getSparkType(),
sparkValue,
protobufSchemaFieldEntry.getNullable(),
protobufSchemaFieldEntry.getNestedTypeDescriptor(),
protobufSchemaFieldEntry.getTypeConverterOptional(),
protobufSchemaFieldEntry.getCustomDataTypeOptional());

if (protoValue == null) {
continue;
}

messageBuilder.setField(schemaDescriptor.findFieldByNumber(protoFieldNumber), protoValue);
messageBuilder.setField(protobufSchemaFieldEntry.getFieldDescriptor(), protoValue);
}

return messageBuilder.build();
@@ -343,7 +432,9 @@ private static Object convertSparkValueToProtoRowValue(
DataType sparkType,
Object sparkValue,
boolean nullable,
Descriptors.Descriptor nestedTypeDescriptor) {
Descriptors.Descriptor nestedTypeDescriptor,
Optional<TypeConverter> typeConverterOptional,
Optional<SupportedCustomDataType> customDataTypeOptional) {
if (sparkValue == null) {
if (!nullable) {
throw new IllegalArgumentException("Non-nullable field was null.");
@@ -352,43 +443,52 @@ private static Object convertSparkValueToProtoRowValue(
}
}

DataType finalSparkType = sparkType;
Optional<Object> protoValueFromConverter =
SparkBigQueryUtil.getTypeConverterStream()
.filter(tc -> tc.supportsSparkType(finalSparkType))
.map(tc -> tc.sparkToProtoValue(sparkValue))
.findFirst();
if (protoValueFromConverter.isPresent()) {
return protoValueFromConverter.get();
if (typeConverterOptional.isPresent()) {
return typeConverterOptional.get().sparkToProtoValue(sparkValue);
}

// UDT support
Optional<SupportedCustomDataType> customDataType = SupportedCustomDataType.of(sparkType);
sparkType = customDataType.map(SupportedCustomDataType::getSqlType).orElse(sparkType);
if (customDataType.isPresent()) {
if (customDataTypeOptional.isPresent()) {
sparkType = customDataTypeOptional.map(SupportedCustomDataType::getSqlType).orElse(sparkType);
InternalRow internalRow;
if (sparkValue instanceof InternalRow) {
// Spark 2.4
internalRow = (InternalRow) sparkValue;
} else {
// spark 3.x
internalRow = customDataType.get().serialize(sparkValue);
internalRow = customDataTypeOptional.get().serialize(sparkValue);
}
return buildSingleRowMessage((StructType) sparkType, nestedTypeDescriptor, internalRow);
return buildSingleRowMessage(
(StructType) sparkType,
nestedTypeDescriptor,
internalRow,
/*fieldIndexToEntryMap*/ Optional.empty(),
DynamicMessage.newBuilder(nestedTypeDescriptor));
}

if (sparkType instanceof ArrayType) {
ArrayType arrayType = (ArrayType) sparkType;
DataType elementType = arrayType.elementType();
boolean containsNull = arrayType.containsNull();
Optional<TypeConverter> typeConverterElementOptional =
SparkBigQueryUtil.getTypeConverterStream()
.filter(tc -> tc.supportsSparkType(elementType))
.findFirst();
Optional<SupportedCustomDataType> customDataTypeElementOptional =
SupportedCustomDataType.of(elementType);
List<Object> protoValue = new ArrayList<>();
// Converting a WrappedArray to Object[] from Java is problematic
if (sparkValue instanceof ArrayData) {
Object[] sparkArrayData = ((ArrayData) sparkValue).toObjectArray(elementType);
for (Object sparkElement : sparkArrayData) {
Object converted =
convertSparkValueToProtoRowValue(
elementType, sparkElement, containsNull, nestedTypeDescriptor);
elementType,
sparkElement,
containsNull,
nestedTypeDescriptor,
typeConverterElementOptional,
customDataTypeElementOptional);
if (converted == null) {
continue;
}
@@ -400,7 +500,12 @@ private static Object convertSparkValueToProtoRowValue(
for (int i = 0; i < sparkArrayDataLength; i++) {
Object converted =
convertSparkValueToProtoRowValue(
elementType, sparkArrayData.apply(i), containsNull, nestedTypeDescriptor);
elementType,
sparkArrayData.apply(i),
containsNull,
nestedTypeDescriptor,
typeConverterElementOptional,
customDataTypeElementOptional);
if (converted == null) {
continue;
}
@@ -417,7 +522,12 @@ private static Object convertSparkValueToProtoRowValue(
} else {
internalRow = (InternalRow) sparkValue;
}
return buildSingleRowMessage((StructType) sparkType, nestedTypeDescriptor, internalRow);
return buildSingleRowMessage(
(StructType) sparkType,
nestedTypeDescriptor,
internalRow,
/*fieldIndexToEntryMap*/ Optional.empty(),
DynamicMessage.newBuilder(nestedTypeDescriptor));
}

if (sparkType instanceof ByteType
@@ -486,23 +596,53 @@ public <A> Function1<Tuple2, A> andThen(Function1<Object, A> g) {
MapData map = (MapData) sparkValue;
Object[] keys = map.keyArray().toObjectArray(mapType.keyType());
Object[] values = map.valueArray().toObjectArray(mapType.valueType());
Optional<TypeConverter> typeConverterMapKeyOptional =
SparkBigQueryUtil.getTypeConverterStream()
.filter(tc -> tc.supportsSparkType(mapType.keyType()))
.findFirst();
Optional<SupportedCustomDataType> customDataTypeMapKeyOptional =
SupportedCustomDataType.of(mapType.keyType());
Optional<TypeConverter> typeConverterMapValueOptional =
SparkBigQueryUtil.getTypeConverterStream()
.filter(tc -> tc.supportsSparkType(mapType.valueType()))
.findFirst();
Optional<SupportedCustomDataType> customDataTypeMapValueOptional =
SupportedCustomDataType.of(mapType.valueType());
for (int i = 0; i < map.numElements(); i++) {
Object key =
convertSparkValueToProtoRowValue(
mapType.keyType(), keys[i], /* nullable */ false, nestedTypeDescriptor);
mapType.keyType(),
keys[i],
/* nullable */ false,
nestedTypeDescriptor,
typeConverterMapKeyOptional,
customDataTypeMapKeyOptional);
Object value =
convertSparkValueToProtoRowValue(
mapType.valueType(),
values[i],
mapType.valueContainsNull(),
nestedTypeDescriptor);
nestedTypeDescriptor,
typeConverterMapValueOptional,
customDataTypeMapValueOptional);
entries.add(new GenericInternalRow(new Object[] {key, value}));
}
}
ArrayData resultArray = ArrayData.toArrayData(entries.stream().toArray());
ArrayType resultArrayType = ArrayType.apply(mapStructType, /* containsNull */ false);
Optional<TypeConverter> typeConverterArrayOptional =
SparkBigQueryUtil.getTypeConverterStream()
.filter(tc -> tc.supportsSparkType(resultArrayType))
.findFirst();
Optional<SupportedCustomDataType> customDataTypeArrayOptional =
SupportedCustomDataType.of(resultArrayType);
return convertSparkValueToProtoRowValue(
resultArrayType, resultArray, nullable, nestedTypeDescriptor);
resultArrayType,
resultArray,
nullable,
nestedTypeDescriptor,
typeConverterArrayOptional,
customDataTypeArrayOptional);
}

throw new IllegalStateException("Unexpected type: " + sparkType);
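The array and map branches now resolve the optional TypeConverter and SupportedCustomDataType for element, key, and value types once, before iterating, rather than inside every recursive convertSparkValueToProtoRowValue call. A generic sketch of that hoisting, using a hypothetical simplified TypeConverter interface rather than the connector's:

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

public class HoistedConverterLookupSketch {

  // Hypothetical, simplified analogue of the connector's TypeConverter extension point.
  interface TypeConverter {
    boolean supportsSparkType(String sparkTypeName);

    Object sparkToProtoValue(Object sparkValue);
  }

  static Optional<TypeConverter> findConverter(List<TypeConverter> converters, String typeName) {
    return converters.stream().filter(tc -> tc.supportsSparkType(typeName)).findFirst();
  }

  static Object convertElement(Object value, Optional<TypeConverter> converter) {
    // With the lookup hoisted, each element only pays for the conversion itself.
    return converter.map(tc -> tc.sparkToProtoValue(value)).orElse(value);
  }

  public static void main(String[] args) {
    List<TypeConverter> converters =
        Arrays.asList(
            new TypeConverter() {
              @Override
              public boolean supportsSparkType(String sparkTypeName) {
                return "StringType".equals(sparkTypeName);
              }

              @Override
              public Object sparkToProtoValue(Object sparkValue) {
                return String.valueOf(sparkValue).toUpperCase();
              }
            });

    String elementType = "StringType";
    Object[] elements = {"a", "b", "c"};

    // Resolve the converter once per array, not once per element, mirroring the commit.
    Optional<TypeConverter> elementConverter = findConverter(converters, elementType);
    for (Object element : elements) {
      System.out.println(convertElement(element, elementConverter));
    }
  }
}
```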
BigQueryDirectDataWriterContext.java
@@ -15,6 +15,7 @@
*/
package com.google.cloud.spark.bigquery.write.context;

import static com.google.cloud.spark.bigquery.ProtobufUtils.ProtobufSchemaFieldCacheEntry;
import static com.google.cloud.spark.bigquery.ProtobufUtils.buildSingleRowMessage;
import static com.google.cloud.spark.bigquery.ProtobufUtils.toDescriptor;

@@ -26,7 +27,10 @@
import com.google.common.base.Optional;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
@@ -41,6 +45,8 @@ public class BigQueryDirectDataWriterContext implements DataWriterContext<Intern
private final String tablePath;
private final StructType sparkSchema;
private final Descriptors.Descriptor schemaDescriptor;
private final Map<Integer, ProtobufSchemaFieldCacheEntry> fieldIndexToEntryMap;
private final DynamicMessage.Builder messageBuilder;

/**
* A helper object to assist the BigQueryDataWriter with all the writing: essentially does all the
@@ -70,6 +76,8 @@ public BigQueryDirectDataWriterContext(
throw new BigQueryConnectorException.InvalidSchemaException(
"Could not convert spark-schema to descriptor object", e);
}
this.fieldIndexToEntryMap = new HashMap<>();
this.messageBuilder = DynamicMessage.newBuilder(this.schemaDescriptor);

this.writerHelper =
new BigQueryDirectDataWriterHelper(
@@ -85,7 +93,13 @@ public BigQueryDirectDataWriterContext(
@Override
public void write(InternalRow record) throws IOException {
ByteString message =
buildSingleRowMessage(sparkSchema, schemaDescriptor, record).toByteString();
buildSingleRowMessage(
sparkSchema,
schemaDescriptor,
record,
java.util.Optional.of(fieldIndexToEntryMap),
messageBuilder)
.toByteString();
writerHelper.addRow(message);
}
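In BigQueryDirectDataWriterContext the cache map and the reusable message builder are instance fields initialized in the constructor, so successive write(InternalRow) calls on the same writer share them. A rough sketch of that lifetime, with hypothetical names and a StringBuilder standing in for DynamicMessage.Builder:

```java
import java.util.HashMap;
import java.util.Map;

public class PerWriterCacheSketch {

  static final class FieldEntry {
    final int protoFieldNumber;

    FieldEntry(int protoFieldNumber) {
      this.protoFieldNumber = protoFieldNumber;
    }
  }

  // Created once per writer (in the real code, in the constructor) and shared by every write().
  private final Map<Integer, FieldEntry> fieldIndexToEntryMap = new HashMap<>();
  private final StringBuilder messageBuilder = new StringBuilder();

  public String write(String[] record) {
    messageBuilder.setLength(0); // analogous to clearing the protobuf builder per row
    for (int fieldIndex = 0; fieldIndex < record.length; fieldIndex++) {
      FieldEntry entry = fieldIndexToEntryMap.computeIfAbsent(fieldIndex, k -> new FieldEntry(k + 1));
      messageBuilder.append(entry.protoFieldNumber).append('=').append(record[fieldIndex]).append(' ');
    }
    return messageBuilder.toString().trim();
  }

  public static void main(String[] args) {
    PerWriterCacheSketch writer = new PerWriterCacheSketch();
    System.out.println(writer.write(new String[] {"a", "b"}));
    // The second call reuses both the cached field entries and the builder.
    System.out.println(writer.write(new String[] {"c", "d"}));
  }
}
```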

