Commit f213f4f

Enhancement: Support batch_read implementation based on Spark DataSource V2 API (#25)
* Enhancement: Support batch_read implementation based on Spark DataSource V2 API.
* Enhancement: Fix compatibility with Spark 3.1.
1 parent 139d320 commit f213f4f
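
For context, a batch read through the new DataSource V2 path would look roughly like the sketch below. The format short name and option keys are illustrative assumptions, not taken from this commit; consult the connector documentation for the actual configuration.

// Minimal sketch of a batch read over the DataSource V2 path (illustrative only;
// the format name and option keys are assumptions, not defined in this commit).
import org.apache.spark.sql.SparkSession

object BatchReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("ob-batch-read-demo").getOrCreate()

    val df = spark.read
      .format("oceanbase") // assumed connector short name
      .option("url", "jdbc:mysql://localhost:2881/test") // assumed option key
      .option("username", "root@test") // assumed option key
      .option("password", "******") // assumed option key
      .option("table-name", "orders") // assumed option key
      .load()

    // Column pruning and filter pushdown are handled by OBJdbcScanBuilder below;
    // partition planning and row decoding happen in OBJdbcBatch and OBJdbcReader.
    df.select("id", "amount").where("amount > 100").show()

    spark.stop()
  }
}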

File tree

13 files changed (+1081, -11 lines changed)

@@ -0,0 +1,282 @@
/*
 * Copyright 2024 OceanBase.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.oceanbase.spark.reader.v2

import com.oceanbase.spark.dialect.OceanBaseDialect
import com.oceanbase.spark.reader.v2.OBJdbcReader.{makeGetters, OBValueGetter}
import com.oceanbase.spark.utils.OBJdbcUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.ExprUtils.compileFilter
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.sql.{PreparedStatement, ResultSet}
import java.util.Objects
import java.util.concurrent.TimeUnit

class OBJdbcReader(
    schema: StructType,
    options: JDBCOptions,
    partition: InputPartition,
    pushedFilter: Array[Filter],
    dialect: OceanBaseDialect)
  extends PartitionReader[InternalRow]
  with SQLConfHelper
  with Logging {

  private val getters: Array[OBValueGetter] = makeGetters(schema)
  private val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
  private lazy val conn = OBJdbcUtils.getConnection(options)
  private lazy val stmt: PreparedStatement =
    conn.prepareStatement(buildQuerySql(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
  private lazy val rs: ResultSet = {
    stmt.setFetchSize(options.fetchSize)
    stmt.setQueryTimeout(options.queryTimeout)
    stmt.executeQuery()
  }

  private var currentRecord: InternalRow = _

  override def next(): Boolean = {
    val hasNext = rs.next()
    if (hasNext) currentRecord = {
      var i = 0
      while (i < getters.length) {
        getters(i)(rs, mutableRow, i)
        if (rs.wasNull) mutableRow.setNullAt(i)
        i = i + 1
      }
      mutableRow
    }
    hasNext
  }

  override def get(): InternalRow = currentRecord

  override def close(): Unit = {
    if (Objects.nonNull(rs)) {
      rs.close()
    }
    if (Objects.nonNull(stmt)) {
      stmt.close()
    }
    if (Objects.nonNull(conn)) {
      conn.close()
    }
  }

  private def buildQuerySql(): String = {
    val columns = schema.map(col => dialect.quoteIdentifier(col.name)).toArray
    val columnStr: String = if (columns.isEmpty) "1" else columns.mkString(",")

    val filterWhereClause: String =
      pushedFilter
        .flatMap(compileFilter(_, dialect))
        .map(p => s"($p)")
        .mkString(" AND ")

    val whereClause: String = {
      if (filterWhereClause.nonEmpty) {
        "WHERE " + filterWhereClause
      } else {
        ""
      }
    }
    val part: OBMySQLPartition = partition.asInstanceOf[OBMySQLPartition]
    s"""
       |SELECT $columnStr FROM ${options.tableOrQuery} ${part.partitionClause}
       |$whereClause ${part.limitOffsetClause}
       |""".stripMargin
  }
}

object OBJdbcReader extends SQLConfHelper {

  // An `OBValueGetter` is responsible for copying a value from a `ResultSet` into a field
  // of the mutable `InternalRow`. The last `Int` argument is the index of the field to set
  // in the row; the same index (offset by one) is used to read the value from the `ResultSet`.
  type OBValueGetter = (ResultSet, InternalRow, Int) => Unit

  /**
   * Creates `OBValueGetter`s according to [[StructType]], which can set each value from
   * `ResultSet` to each field of [[InternalRow]] correctly.
   */
  def makeGetters(schema: StructType): Array[OBValueGetter] =
    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

  private def makeGetter(dt: DataType, metadata: Metadata): OBValueGetter = dt match {
    case BooleanType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1))

    case DateType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
        val dateVal = rs.getDate(pos + 1)
        if (dateVal != null) {
          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
        } else {
          row.update(pos, null)
        }

    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
    // retrieve it, you will get wrong result 199.99.
    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
    case _: DecimalType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val decimal =
          nullSafeConvert[java.math.BigDecimal](
            rs.getBigDecimal(pos + 1),
            d => Decimal(d, d.precision(), d.scale()))
        row.update(pos, decimal)

    case DoubleType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setDouble(pos, rs.getDouble(pos + 1))

    case FloatType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setFloat(pos, rs.getFloat(pos + 1))

    case IntegerType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1))

    case LongType if metadata.contains("binarylong") =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val bytes = rs.getBytes(pos + 1)
        var ans = 0L
        var j = 0
        while (j < bytes.length) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)

    case LongType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setLong(pos, rs.getLong(pos + 1))

    case ShortType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setShort(pos, rs.getShort(pos + 1))

    case ByteType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setByte(pos, rs.getByte(pos + 1))

    case StringType if metadata.contains("rowid") =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.update(pos, UTF8String.fromString(rs.getRowId(pos + 1).toString))

    case StringType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

    // SPARK-34357 - sql TIME type represents as zero epoch timestamp.
    // It is mapped as Spark TimestampType but fixed at 1970-01-01 for day,
    // time portion is time of day, with no reference to a particular calendar,
    // time zone or date, with a precision till microseconds.
    // It stores the number of milliseconds after midnight, 00:00:00.000000
    case TimestampType if metadata.contains("logical_time_type") =>
      (rs: ResultSet, row: InternalRow, pos: Int) => {
        val rawTime = rs.getTime(pos + 1)
        if (rawTime != null) {
          val localTimeMicro = TimeUnit.NANOSECONDS.toMicros(rawTime.toLocalTime.toNanoOfDay)
          val utcTimeMicro = DateTimeUtils.toUTCTime(localTimeMicro, conf.sessionLocalTimeZone)
          row.setLong(pos, utcTimeMicro)
        } else {
          row.update(pos, null)
        }
      }

    case TimestampType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val t = rs.getTimestamp(pos + 1)
        if (t != null) {
          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
        } else {
          row.update(pos, null)
        }

    case BinaryType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.update(pos, rs.getBytes(pos + 1))

    case ArrayType(et, _) =>
      val elementConversion = et match {
        case TimestampType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.sql.Timestamp]].map {
              timestamp => nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
            }

        case StringType =>
          (array: Object) =>
            // some underlying types are not String, such as uuid, inet, cidr, etc.
            array
              .asInstanceOf[Array[java.lang.Object]]
              .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))

        case DateType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.sql.Date]].map {
              date => nullSafeConvert(date, DateTimeUtils.fromJavaDate)
            }

        case dt: DecimalType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.math.BigDecimal]].map {
              decimal =>
                nullSafeConvert[java.math.BigDecimal](
                  decimal,
                  d => Decimal(d, dt.precision, dt.scale))
            }

        case LongType if metadata.contains("binarylong") =>
          throw new UnsupportedOperationException(
            s"unsupportedArrayElementTypeBasedOnBinaryError ${dt.catalogString}")

        case ArrayType(_, _) =>
          throw new UnsupportedOperationException(s"Not support Array data-type now")

        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
      }

      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val array = nullSafeConvert[java.sql.Array](
          input = rs.getArray(pos + 1),
          array => new GenericArrayData(elementConversion.apply(array.getArray)))
        row.update(pos, array)

    case _ =>
      throw new UnsupportedOperationException(s"unsupportedJdbcTypeError ${dt.catalogString}")
  }

  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
    if (input == null) {
      null
    } else {
      f(input)
    }
  }
}
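
To make the generated statement concrete: for a two-column schema and one pushed filter, buildQuerySql() above assembles a query of roughly the following shape. The table name, PARTITION clause, and LIMIT/OFFSET fragment are illustrative assumptions standing in for JDBCOptions.tableOrQuery and OBMySQLPartition's partitionClause and limitOffsetClause, which are defined outside this diff.

// Illustrative only: approximate output of OBJdbcReader.buildQuerySql() for a schema of
// (`id`, `name`) and a pushed filter `id > 10`. The PARTITION(...) and LIMIT/OFFSET parts
// are placeholders for OBMySQLPartition.partitionClause and limitOffsetClause.
val exampleQuery: String =
  """
    |SELECT `id`,`name` FROM test.orders PARTITION(p0)
    |WHERE (`id` > 10) LIMIT 100000 OFFSET 0
    |""".stripMargin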
@@ -0,0 +1,116 @@
/*
 * Copyright 2024 OceanBase.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.oceanbase.spark.reader.v2

import com.oceanbase.spark.dialect.OceanBaseDialect

import org.apache.spark.internal.Logging
import org.apache.spark.sql.ExprUtils.compileFilter
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

/**
 * This is for compatibility with Spark 3.1, which does not support the SupportsPushDownAggregates
 * feature.
 */
case class OBJdbcScanBuilder(
    schema: StructType,
    jdbcOptions: JDBCOptions,
    dialect: OceanBaseDialect)
  extends ScanBuilder
  with SupportsPushDownFilters
  with SupportsPushDownRequiredColumns
  with Logging {
  private var finalSchema = schema
  private var pushedFilter = Array.empty[Filter]

  /** TODO: support org.apache.spark.sql.connector.read.SupportsPushDownV2Filters */
  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (pushed, unSupported) = filters.partition(f => compileFilter(f, dialect).isDefined)
    this.pushedFilter = pushed
    unSupported
  }

  override def pushedFilters(): Array[Filter] = pushedFilter

  override def pruneColumns(requiredSchema: StructType): Unit = {
    val requiredCols = requiredSchema.map(_.name)
    this.finalSchema = StructType(finalSchema.filter(field => requiredCols.contains(field.name)))
  }

  override def build(): Scan =
    OBJdbcBatchScan(
      finalSchema: StructType,
      jdbcOptions: JDBCOptions,
      pushedFilter: Array[Filter],
      dialect: OceanBaseDialect)
}

case class OBJdbcBatchScan(
    schema: StructType,
    jdbcOptions: JDBCOptions,
    pushedFilter: Array[Filter],
    dialect: OceanBaseDialect)
  extends Scan {

  override def readSchema(): StructType = schema

  override def toBatch: Batch =
    new OBJdbcBatch(
      schema: StructType,
      jdbcOptions: JDBCOptions,
      pushedFilter: Array[Filter],
      dialect: OceanBaseDialect)
}

class OBJdbcBatch(
    schema: StructType,
    jdbcOptions: JDBCOptions,
    pushedFilter: Array[Filter],
    dialect: OceanBaseDialect)
  extends Batch {
  private lazy val inputPartitions: Array[InputPartition] =
    OBMySQLPartition.columnPartition(jdbcOptions)

  override def planInputPartitions(): Array[InputPartition] = inputPartitions

  override def createReaderFactory(): PartitionReaderFactory = new OBJdbcReaderFactory(
    schema: StructType,
    jdbcOptions: JDBCOptions,
    pushedFilter: Array[Filter],
    dialect: OceanBaseDialect)
}

class OBJdbcReaderFactory(
    schema: StructType,
    jdbcOptions: JDBCOptions,
    pushedFilter: Array[Filter],
    dialect: OceanBaseDialect)
  extends PartitionReaderFactory {

  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    new OBJdbcReader(
      schema: StructType,
      jdbcOptions: JDBCOptions,
      partition: InputPartition,
      pushedFilter: Array[Filter],
      dialect: OceanBaseDialect)
}
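
To show how these classes fit together, the sketch below mirrors, in simplified form, the sequence Spark itself drives during a DataSource V2 batch read (ScanBuilder → Scan → Batch → PartitionReaderFactory → PartitionReader). It is illustrative and not part of this commit; Spark performs these calls internally.

// Simplified driver/executor sequence for a DataSource V2 batch read (illustrative).
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader}

object DataSourceV2ReadFlow {
  def driveBatchRead(builder: OBJdbcScanBuilder): Unit = {
    // Driver side: after pushFilters/pruneColumns, the optimizer builds the scan.
    val batch: Batch = builder.build().toBatch
    val partitions: Array[InputPartition] = batch.planInputPartitions()
    val factory = batch.createReaderFactory()

    // Executor side: each task receives one partition and drains its reader.
    partitions.foreach { partition =>
      val reader: PartitionReader[InternalRow] = factory.createReader(partition)
      try {
        while (reader.next()) {
          val row: InternalRow = reader.get() // each row feeds Spark's scan operator
        }
      } finally {
        reader.close()
      }
    }
  }
}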
