/*
 * Copyright 2024 OceanBase.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.oceanbase.spark.reader.v2

import com.oceanbase.spark.dialect.OceanBaseDialect
import com.oceanbase.spark.reader.v2.OBJdbcReader.{makeGetters, OBValueGetter}
import com.oceanbase.spark.utils.OBJdbcUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.ExprUtils.compileFilter
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.sql.{PreparedStatement, ResultSet}
import java.util.Objects
import java.util.concurrent.TimeUnit

class OBJdbcReader(
    schema: StructType,
    options: JDBCOptions,
    partition: InputPartition,
    pushedFilter: Array[Filter],
    dialect: OceanBaseDialect)
  extends PartitionReader[InternalRow]
  with SQLConfHelper
  with Logging {

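  // One getter per output column; every call to next() overwrites the single
  // reusable mutable row below instead of allocating a new row per record.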
  private val getters: Array[OBValueGetter] = makeGetters(schema)
  private val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType))
  private lazy val conn = OBJdbcUtils.getConnection(options)
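  // Forward-only, read-only statement; the fetch size and query timeout are
  // applied lazily, right before the query is first executed.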
  private lazy val stmt: PreparedStatement =
    conn.prepareStatement(buildQuerySql(), ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
  private lazy val rs: ResultSet = {
    stmt.setFetchSize(options.fetchSize)
    stmt.setQueryTimeout(options.queryTimeout)
    stmt.executeQuery()
  }

  private var currentRecord: InternalRow = _

  override def next(): Boolean = {
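    // Advance the cursor; when a row is available, copy every column into the
    // reusable mutable row, translating SQL NULLs via rs.wasNull.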
    val hasNext = rs.next()
    if (hasNext) currentRecord = {
      var i = 0
      while (i < getters.length) {
        getters(i)(rs, mutableRow, i)
        if (rs.wasNull) mutableRow.setNullAt(i)
        i = i + 1
      }
      mutableRow
    }
    hasNext
  }

  override def get(): InternalRow = currentRecord

  override def close(): Unit = {
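    // Close in reverse order of acquisition. Note that rs, stmt and conn are
    // lazy vals, so referencing one that was never used initializes it here.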
    if (Objects.nonNull(rs)) {
      rs.close()
    }
    if (Objects.nonNull(stmt)) {
      stmt.close()
    }
    if (Objects.nonNull(conn)) {
      conn.close()
    }
  }

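  /**
   * Builds the partition-scoped SELECT statement. Illustrative example with hypothetical
   * names: given columns `id` and `name`, a pushed filter `id > 10`, and a partition whose
   * clauses are `PARTITION(p0)` and `LIMIT 100000 OFFSET 0`, the generated SQL is roughly:
   *
   *   SELECT `id`,`name` FROM t PARTITION(p0) WHERE (`id` > 10) LIMIT 100000 OFFSET 0
   */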
  private def buildQuerySql(): String = {
    val columns = schema.map(col => dialect.quoteIdentifier(col.name)).toArray
    val columnStr: String = if (columns.isEmpty) "1" else columns.mkString(",")

    val filterWhereClause: String =
      pushedFilter
        .flatMap(compileFilter(_, dialect))
        .map(p => s"($p)")
        .mkString(" AND ")

    val whereClause: String = {
      if (filterWhereClause.nonEmpty) {
        "WHERE " + filterWhereClause
      } else {
        ""
      }
    }
    val part: OBMySQLPartition = partition.asInstanceOf[OBMySQLPartition]
    s"""
       |SELECT $columnStr FROM ${options.tableOrQuery} ${part.partitionClause}
       |$whereClause ${part.limitOffsetClause}
       |""".stripMargin
  }
}

object OBJdbcReader extends SQLConfHelper {

  // An `OBValueGetter` is responsible for getting a value from `ResultSet` into a field
  // of a mutable `InternalRow`. The last argument `Int` is the index of the field to set
  // in the row; the same index plus one is used to read the value from the `ResultSet`,
  // since JDBC columns are 1-based.
  type OBValueGetter = (ResultSet, InternalRow, Int) => Unit
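  // For instance, the IntegerType getter defined below is simply:
  //   (rs: ResultSet, row: InternalRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1))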

  /**
   * Creates `OBValueGetter`s according to [[StructType]], which can set each value from
   * `ResultSet` to each field of [[InternalRow]] correctly.
   */
  def makeGetters(schema: StructType): Array[OBValueGetter] =
    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))

  private def makeGetter(dt: DataType, metadata: Metadata): OBValueGetter = dt match {
    case BooleanType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setBoolean(pos, rs.getBoolean(pos + 1))

    case DateType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
        val dateVal = rs.getDate(pos + 1)
        if (dateVal != null) {
          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
        } else {
          row.update(pos, null)
        }

    // When connecting to Oracle through JDBC, the precision and scale of the BigDecimal
    // returned by ResultSet.getBigDecimal are not correctly matched to the table schema
    // reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
    // If a value like 19999 is inserted into a column of type NUMBER(12, 2), the returned
    // BigDecimal has scale 0, while the dataframe schema has the correct type
    // DecimalType(12, 2). After saving the dataframe to a parquet file and reading it
    // back, you would get the wrong result 199.99.
    // So the precision and scale of the Decimal are taken from the schema rather than
    // from the returned value.
    case dt: DecimalType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val decimal =
          nullSafeConvert[java.math.BigDecimal](
            rs.getBigDecimal(pos + 1),
            d => Decimal(d, dt.precision, dt.scale))
        row.update(pos, decimal)

    case DoubleType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setDouble(pos, rs.getDouble(pos + 1))

    case FloatType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setFloat(pos, rs.getFloat(pos + 1))

    case IntegerType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1))

    case LongType if metadata.contains("binarylong") =>
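      // Interprets the raw bytes as a big-endian unsigned integer, folding one
      // byte at a time into the accumulator.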
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val bytes = rs.getBytes(pos + 1)
        var ans = 0L
        var j = 0
        while (j < bytes.length) {
          ans = 256 * ans + (255 & bytes(j))
          j = j + 1
        }
        row.setLong(pos, ans)

    case LongType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setLong(pos, rs.getLong(pos + 1))

    case ShortType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setShort(pos, rs.getShort(pos + 1))

    case ByteType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.setByte(pos, rs.getByte(pos + 1))

    case StringType if metadata.contains("rowid") =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        row.update(pos, UTF8String.fromString(rs.getRowId(pos + 1).toString))

    case StringType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))

    // SPARK-34357 - the SQL TIME type is represented as a zero-epoch timestamp.
    // It is mapped to Spark's TimestampType, but the day is fixed at 1970-01-01;
    // the time portion is a time of day, with no reference to a particular calendar,
    // time zone or date, with microsecond precision.
    // It stores the number of microseconds after midnight, 00:00:00.000000.
    case TimestampType if metadata.contains("logical_time_type") =>
      (rs: ResultSet, row: InternalRow, pos: Int) => {
        val rawTime = rs.getTime(pos + 1)
        if (rawTime != null) {
          val localTimeMicro = TimeUnit.NANOSECONDS.toMicros(rawTime.toLocalTime.toNanoOfDay)
          val utcTimeMicro = DateTimeUtils.toUTCTime(localTimeMicro, conf.sessionLocalTimeZone)
          row.setLong(pos, utcTimeMicro)
        } else {
          row.update(pos, null)
        }
      }

    case TimestampType =>
      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val t = rs.getTimestamp(pos + 1)
        if (t != null) {
          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
        } else {
          row.update(pos, null)
        }

    case BinaryType =>
      (rs: ResultSet, row: InternalRow, pos: Int) => row.update(pos, rs.getBytes(pos + 1))

    case ArrayType(et, _) =>
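      // Builds a per-element conversion from the JDBC array's Java values to
      // Catalyst's internal representation, chosen by the element type.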
      val elementConversion = et match {
        case TimestampType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.sql.Timestamp]].map {
              timestamp => nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
            }

        case StringType =>
          (array: Object) =>
            // Some underlying types are not String, such as uuid, inet, cidr, etc.
            array
              .asInstanceOf[Array[java.lang.Object]]
              .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))

        case DateType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.sql.Date]].map {
              date => nullSafeConvert(date, DateTimeUtils.fromJavaDate)
            }

        case dt: DecimalType =>
          (array: Object) =>
            array.asInstanceOf[Array[java.math.BigDecimal]].map {
              decimal =>
                nullSafeConvert[java.math.BigDecimal](
                  decimal,
                  d => Decimal(d, dt.precision, dt.scale))
            }

        case LongType if metadata.contains("binarylong") =>
          throw new UnsupportedOperationException(
            s"Unsupported array element type based on binary: ${dt.catalogString}")

        case ArrayType(_, _) =>
          throw new UnsupportedOperationException("Nested arrays are not supported yet")

        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
      }

      (rs: ResultSet, row: InternalRow, pos: Int) =>
        val array = nullSafeConvert[java.sql.Array](
          input = rs.getArray(pos + 1),
          array => new GenericArrayData(elementConversion.apply(array.getArray)))
        row.update(pos, array)

    case _ =>
      throw new UnsupportedOperationException(s"Unsupported JDBC type: ${dt.catalogString}")
  }

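  /** Applies `f` to `input`, propagating null when `input` is null. */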
  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
    if (input == null) {
      null
    } else {
      f(input)
    }
  }
}