
Commit 78fe3b6

Enhancement: Support batch_read implementation based on the Spark DataSource V2 API.
1 parent 139d320 commit 78fe3b6
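
For orientation, a rough usage sketch of the read path this commit adds. The "enable-legacy_batch_reader" key is taken from the dispatch in OceanBaseTable.newScanBuilder below; the connector format name and connection option keys used here are illustrative assumptions, not documented API.

import org.apache.spark.sql.SparkSession

object OceanBaseBatchReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("ob-dsv2-read-sketch").getOrCreate()

    // With this commit, newScanBuilder returns the DataSource V2 OBJdbcScanBuilder by default;
    // rows are then produced by OBJdbcReader, the PartitionReader added in this commit.
    val df = spark.read
      .format("oceanbase") // assumed short name for the connector
      .option("url", "jdbc:mysql://host:2881/db") // assumed connection option keys
      .option("dbtable", "test.orders")
      .load()
    df.show(10)

    // Setting "enable-legacy_batch_reader" to "true" keeps the old JDBCLimitScanBuilder path.
    // The key name comes from this commit; note that the dispatch reads it from the table-level
    // jdbcOptions, so where it must be configured (catalog config vs. read option) is not
    // shown in this diff.
    // .option("enable-legacy_batch_reader", "true")
  }
}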

File tree

7 files changed: +677 −4 lines changed


spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/catalog/OceanBaseTable.scala

Lines changed: 7 additions & 2 deletions
@@ -18,11 +18,13 @@ package com.oceanbase.spark.catalog
 import com.oceanbase.spark.config.OceanBaseConfig
 import com.oceanbase.spark.dialect.OceanBaseDialect
 import com.oceanbase.spark.read.JDBCLimitScanBuilder
+import com.oceanbase.spark.read.v2.OBJdbcScanBuilder
 import com.oceanbase.spark.writer.v2.{DirectLoadWriteBuilderV2, JDBCWriteBuilder}

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
 import org.apache.spark.sql.types.StructType
@@ -47,10 +49,13 @@ case class OceanBaseTable(
     util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE)
   }

-  override def newScanBuilder(options: CaseInsensitiveStringMap): JDBCLimitScanBuilder = {
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     val mergedOptions = new JDBCOptions(
       jdbcOptions.parameters ++ options.asCaseSensitiveMap().asScala)
-    JDBCLimitScanBuilder(SparkSession.active, schema, mergedOptions)
+    jdbcOptions.parameters.get("enable-legacy_batch_reader").map(_.toBoolean) match {
+      case Some(true) => JDBCLimitScanBuilder(SparkSession.active, schema, mergedOptions)
+      case _ => OBJdbcScanBuilder(schema, mergedOptions, dialect)
+    }
   }

   override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {

spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/dialect/OceanBaseDialect.scala

Lines changed: 35 additions & 1 deletion
@@ -18,14 +18,18 @@ package com.oceanbase.spark.dialect

 import com.oceanbase.spark.utils.OBJdbcUtils.executeStatement

+import org.apache.commons.lang3.StringUtils
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.types.StructType

-import java.sql.Connection
+import java.sql.{Connection, Date, Timestamp}
+import java.time.{Instant, LocalDate}

 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -150,6 +154,36 @@ abstract class OceanBaseDialect extends Logging with Serializable {
       tableName: String,
       schema: StructType,
       priKeyColumnInfo: ArrayBuffer[PriKeyColumnInfo]): String
+
+  /**
+   * Escape special characters in SQL string literals.
+   * @param value
+   *   The string to be escaped.
+   * @return
+   *   Escaped string.
+   */
+  def escapeSql(value: String): String =
+    if (value == null) null else StringUtils.replace(value, "'", "''")
+
+  /**
+   * Converts value to SQL expression.
+   * @param value
+   *   The value to be converted.
+   * @return
+   *   Converted value.
+   */
+  def compileValue(value: Any): Any = value match {
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "'" + timestampValue + "'"
+    case timestampValue: Instant =>
+      val timestampFormatter = TimestampFormatter.getFractionFormatter(
+        DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone))
+      s"'${timestampFormatter.format(timestampValue)}'"
+    case dateValue: Date => "'" + dateValue + "'"
+    case dateValue: LocalDate => s"'${DateFormatter().format(dateValue)}'"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }
 }

 case class PriKeyColumnInfo(columnName: String, columnType: String, columnKey: String)
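
To make the new helpers concrete, a small self-contained sketch of the literals they produce. The local escapeSql below mirrors the helper added above (same StringUtils.replace call); the date/timestamp/array results in the comments follow the compileValue match arms and are expectations, not output captured from the connector.

import org.apache.commons.lang3.StringUtils

object CompileValueSketch {
  // Mirrors OceanBaseDialect.escapeSql / the String arm of compileValue.
  private def escapeSql(value: String): String =
    if (value == null) null else StringUtils.replace(value, "'", "''")

  def main(args: Array[String]): Unit = {
    val escaped = escapeSql("O'Brien")
    println(escaped)       // O''Brien (single quotes doubled)
    println(s"'$escaped'") // 'O''Brien' -- what compileValue emits for a String
    // Per the match arms above (not executed here):
    //   java.sql.Date.valueOf("2024-05-01")               -> '2024-05-01'
    //   java.sql.Timestamp.valueOf("2024-05-01 12:30:00") -> '2024-05-01 12:30:00.0'
    //   Array[Any]("a", "b")                              -> 'a', 'b' (useful for IN (...) lists)
    //   any other value                                   -> passed through unchanged
  }
}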

spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/dialect/OceanBaseOracleDialect.scala

Lines changed: 14 additions & 1 deletion
@@ -20,7 +20,7 @@ import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
 import org.apache.spark.sql.types.StructType

-import java.sql.Connection
+import java.sql.{Connection, Date, Timestamp}
 import java.util

 import scala.collection.mutable.ArrayBuffer
@@ -78,4 +78,17 @@ class OceanBaseOracleDialect extends OceanBaseDialect {
       priKeyColumnInfo: ArrayBuffer[PriKeyColumnInfo]): String = {
     throw new UnsupportedOperationException("Not currently supported in oracle mode")
   }
+
+  override def compileValue(value: Any): Any = value match {
+    // The JDBC drivers support date literals in SQL statements written in the
+    // format: {d 'yyyy-mm-dd'} and timestamp literals in SQL statements written
+    // in the format: {ts 'yyyy-mm-dd hh:mm:ss.f...'}. For details, see
+    // 'Oracle Database JDBC Developer’s Guide and Reference, 11g Release 1 (11.1)'
+    // Appendix A Reference Information.
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
+    case timestampValue: Timestamp => "{ts '" + timestampValue + "'}"
+    case dateValue: Date => "{d '" + dateValue + "'}"
+    case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+    case _ => value
+  }
 }
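
For comparison, a short sketch of the literals the Oracle-mode override produces, using the OceanBaseOracleDialect class from this diff; the expected outputs in the comments follow the match arms above (JDBC escape syntax for dates and timestamps) and are not captured output.

import com.oceanbase.spark.dialect.OceanBaseOracleDialect

import java.sql.{Date, Timestamp}

object OracleCompileValueSketch {
  def main(args: Array[String]): Unit = {
    val dialect = new OceanBaseOracleDialect

    // Strings are escaped the same way as in the base dialect.
    println(dialect.compileValue("O'Brien")) // 'O''Brien'
    // Dates and timestamps use the JDBC escape syntax described in the comment above.
    println(dialect.compileValue(Date.valueOf("2024-05-01"))) // {d '2024-05-01'}
    println(dialect.compileValue(Timestamp.valueOf("2024-05-01 12:30:00"))) // {ts '2024-05-01 12:30:00.0'}
  }
}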

spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/read/v2/OBJdbcReader.scala

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
+/*
+ * Copyright 2024 OceanBase.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.oceanbase.spark.read.v2
+
+import com.oceanbase.spark.dialect.OceanBaseDialect
+import com.oceanbase.spark.read.v2.OBJdbcReader.{makeGetters, OBValueGetter}
+import com.oceanbase.spark.utils.OBJdbcUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.ExprUtils.compileFilter
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, Metadata, ShortType, StringType, StructType, TimestampType}
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.sql.{PreparedStatement, ResultSet}
+import java.util.Objects
+import java.util.concurrent.TimeUnit
+
+class OBJdbcReader(
+    schema: StructType,
+    options: JDBCOptions,
+    partition: InputPartition,
+    pushedFilter: Array[Filter],
+    dialect: OceanBaseDialect)
+  extends PartitionReader[InternalRow]
+  with SQLConfHelper
+  with Logging {
+
+  private val getters: Array[OBValueGetter] = makeGetters(schema)
+  private val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
+  private lazy val conn = OBJdbcUtils.getConnection(options)
+  private lazy val stmt: PreparedStatement =
+    conn.prepareStatement(buildQuerySql(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+  private lazy val rs: ResultSet = {
+    stmt.setFetchSize(options.fetchSize)
+    stmt.setQueryTimeout(options.queryTimeout)
+    stmt.executeQuery()
+  }
+
+  private var currentRecord: InternalRow = _
+
+  override def next(): Boolean = {
+    val hasNext = rs.next()
+    if (hasNext) currentRecord = {
+      var i = 0
+      while (i < getters.length) {
+        getters(i)(rs, mutableRow, i)
+        if (rs.wasNull) mutableRow.setNullAt(i)
+        i = i + 1
+      }
+      mutableRow
+    }
+    hasNext
+  }
+
+  override def get(): InternalRow = currentRecord
+
+  override def close(): Unit = {
+    if (Objects.nonNull(rs)) {
+      rs.close()
+    }
+    if (Objects.nonNull(stmt)) {
+      stmt.close()
+    }
+    if (Objects.nonNull(conn)) {
+      conn.close()
+    }
+  }
+
+  private def buildQuerySql(): String = {
+    val columns = schema.map(col => dialect.quoteIdentifier(col.name)).toArray
+    val columnStr: String = if (columns.isEmpty) "1" else columns.mkString(",")
+
+    val filterWhereClause: String =
+      pushedFilter
+        .flatMap(compileFilter(_, dialect))
+        .map(p => s"($p)")
+        .mkString(" AND ")
+
+    val whereClause: String = {
+      if (filterWhereClause.nonEmpty) {
+        "WHERE " + filterWhereClause
+      } else {
+        ""
+      }
+    }
+    val part: OBMySQLPartition = partition.asInstanceOf[OBMySQLPartition]
+    s"""
+       |SELECT $columnStr FROM ${options.tableOrQuery} ${part.partitionClause}
+       |$whereClause ${part.limitOffsetClause}
+       |""".stripMargin
+  }
+}
+
+object OBJdbcReader extends SQLConfHelper {
+
+  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
+  // the row and also used for the value in `ResultSet`.
+  type OBValueGetter = (ResultSet, InternalRow, Int) => Unit
+
+  /**
+   * Creates `JDBCValueGetter`s according to [[StructType]], which can set each value from
+   * `ResultSet` to each field of [[InternalRow]] correctly.
+   */
+  def makeGetters(schema: StructType): Array[OBValueGetter] =
+    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+
+  private def makeGetter(dt: DataType, metadata: Metadata): OBValueGetter = dt match {
+    case BooleanType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+    case DateType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+        val dateVal = rs.getDate(pos + 1)
+        if (dateVal != null) {
+          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+        } else {
+          row.update(pos, null)
+        }
+
+    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
+    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
+    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
+    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
+    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
+    // retrieve it, you will get wrong result 199.99.
+    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
+    case _: DecimalType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val decimal =
+          nullSafeConvert[java.math.BigDecimal](
+            rs.getBigDecimal(pos + 1),
+            d => Decimal(d, d.precision(), d.scale()))
+        row.update(pos, decimal)
+
+    case DoubleType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setDouble(pos, rs.getDouble(pos + 1))
+
+    case FloatType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setFloat(pos, rs.getFloat(pos + 1))
+
+    case IntegerType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1))
+
+    case LongType if metadata.contains("binarylong") =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val bytes = rs.getBytes(pos + 1)
+        var ans = 0L
+        var j = 0
+        while (j < bytes.length) {
+          ans = 256 * ans + (255 & bytes(j))
+          j = j + 1
+        }
+        row.setLong(pos, ans)
+
+    case LongType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setLong(pos, rs.getLong(pos + 1))
+
+    case ShortType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setShort(pos, rs.getShort(pos + 1))
+
+    case ByteType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.setByte(pos, rs.getByte(pos + 1))
+
+    case StringType if metadata.contains("rowid") =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        row.update(pos, UTF8String.fromString(rs.getRowId(pos + 1).toString))
+
+    case StringType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+    // SPARK-34357 - sql TIME type represents as zero epoch timestamp.
+    // It is mapped as Spark TimestampType but fixed at 1970-01-01 for day,
+    // time portion is time of day, with no reference to a particular calendar,
+    // time zone or date, with a precision till microseconds.
+    // It stores the number of milliseconds after midnight, 00:00:00.000000
+    case TimestampType if metadata.contains("logical_time_type") =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => {
+        val rawTime = rs.getTime(pos + 1)
+        if (rawTime != null) {
+          val localTimeMicro = TimeUnit.NANOSECONDS.toMicros(rawTime.toLocalTime.toNanoOfDay)
+          val utcTimeMicro = DateTimeUtils.toUTCTime(localTimeMicro, conf.sessionLocalTimeZone)
+          row.setLong(pos, utcTimeMicro)
+        } else {
+          row.update(pos, null)
+        }
+      }
+
+    case TimestampType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val t = rs.getTimestamp(pos + 1)
+        if (t != null) {
+          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+        } else {
+          row.update(pos, null)
+        }
+
+    case BinaryType =>
+      (rs: ResultSet, row: InternalRow, pos: Int) => row.update(pos, rs.getBytes(pos + 1))
+
+    case ArrayType(et, _) =>
+      val elementConversion = et match {
+        case TimestampType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Timestamp]].map {
+              timestamp => nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+            }
+
+        case StringType =>
+          (array: Object) =>
+            // some underling types are not String such as uuid, inet, cidr, etc.
+            array
+              .asInstanceOf[Array[java.lang.Object]]
+              .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))
+
+        case DateType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Date]].map {
+              date => nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+            }
+
+        case dt: DecimalType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.math.BigDecimal]].map {
+              decimal =>
+                nullSafeConvert[java.math.BigDecimal](
+                  decimal,
+                  d => Decimal(d, dt.precision, dt.scale))
+            }
+
+        case LongType if metadata.contains("binarylong") =>
+          throw new UnsupportedOperationException(
+            s"unsupportedArrayElementTypeBasedOnBinaryError ${dt.catalogString}")
+
+        case ArrayType(_, _) =>
+          throw new UnsupportedOperationException(s"Not support Array data-type now")
+
+        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+      }
+
+      (rs: ResultSet, row: InternalRow, pos: Int) =>
+        val array = nullSafeConvert[java.sql.Array](
+          input = rs.getArray(pos + 1),
+          array => new GenericArrayData(elementConversion.apply(array.getArray)))
+        row.update(pos, array)
+
+    case _ =>
+      throw new UnsupportedOperationException(s"unsupportedJdbcTypeError ${dt.catalogString}")
+  }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
+}
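
To show where OBJdbcReader sits in the DataSource V2 read path: Spark calls newScanBuilder(...).build() on the table, turns the resulting Scan into a Batch, asks the Batch for input partitions and a reader factory, and the factory creates one PartitionReader per partition. The commit's actual OBJdbcScanBuilder and partition planning (OBMySQLPartition) are not shown in this section, so the classes below are illustrative stand-ins under that assumption; only OBJdbcReader and the Spark interfaces are taken from the diff.

import com.oceanbase.spark.dialect.OceanBaseDialect
import com.oceanbase.spark.read.v2.OBJdbcReader

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Illustrative stand-in for the Batch implementation added elsewhere in this commit.
// The partitions handed out here must be OBMySQLPartition instances, since
// OBJdbcReader.buildQuerySql casts to that type for its partition and LIMIT/OFFSET clauses.
class SketchJdbcBatch(
    schema: StructType,
    options: JDBCOptions,
    partitions: Array[InputPartition],
    pushedFilters: Array[Filter],
    dialect: OceanBaseDialect)
  extends Batch {

  override def planInputPartitions(): Array[InputPartition] = partitions

  override def createReaderFactory(): PartitionReaderFactory =
    new SketchJdbcReaderFactory(schema, options, pushedFilters, dialect)
}

// One OBJdbcReader is created per input partition; each reader opens its own JDBC
// connection lazily, runs the partition's query and streams InternalRow values to Spark.
class SketchJdbcReaderFactory(
    schema: StructType,
    options: JDBCOptions,
    pushedFilters: Array[Filter],
    dialect: OceanBaseDialect)
  extends PartitionReaderFactory {

  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    new OBJdbcReader(schema, options, partition, pushedFilters, dialect)
}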
