Commit 6bcbf9b7 authored by Josh Rosen, committed by Herman van Hovell

[SPARK-17351] Refactor JDBCRDD to expose ResultSet -> Seq[Row] utility methods

This patch refactors the internals of the JDBC data source in order to allow some of its code to be re-used in an automated comparison testing harness. Here are the key changes:

- Move the JDBC `ResultSetMetaData` to `StructType` conversion logic from `JDBCRDD.resolveTable()` to the `JdbcUtils` object (as a new `getSchema(ResultSet, JdbcDialect)` method), allowing it to be applied to `ResultSet`s that are created elsewhere.
- Move the `ResultSet` to `InternalRow` conversion methods from `JDBCRDD` to `JdbcUtils`:
  - It makes sense to move the `JDBCValueGetter` type and `makeGetter` functions here given that their write-path counterparts (`JDBCValueSetter`) are already in `JdbcUtils`.
  - Add an internal `resultSetToSparkInternalRows` method which takes a `ResultSet` and a schema and returns an `Iterator[InternalRow]`. This effectively extracts the main loop of `JDBCRDD.compute()` into its own method.
  - Add a public `resultSetToRows` method to `JdbcUtils`, which wraps the minimal machinery around `resultSetToSparkInternalRows` needed to allow it to be called from outside of a Spark job.
- Make `JdbcDialects.get` into a `DeveloperApi` (`JdbcDialect` itself is already a `DeveloperApi`). The new and changed signatures are outlined in the sketch after this list.
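
For orientation, here is a rough outline of the signatures this patch adds or changes. This is a hypothetical `JdbcUtilsOutline` object with `???` placeholder bodies, not the actual source; the real implementations live in `JdbcUtils` and `JdbcDialects` in the diff below.

```scala
import java.sql.ResultSet

import org.apache.spark.executor.InputMetrics
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.types.StructType

object JdbcUtilsOutline {
  // Schema inference moved out of JDBCRDD.resolveTable(): asks the dialect first,
  // then falls back to the built-in java.sql.Types -> Catalyst mapping.
  def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = ???

  // Public wrapper: decodes InternalRows into external Rows with a bound RowEncoder,
  // so it can be called from outside of a Spark job.
  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = ???

  // The extracted main loop of JDBCRDD.compute(); private[spark] in the real code.
  def resultSetToSparkInternalRows(
      resultSet: ResultSet,
      schema: StructType,
      inputMetrics: InputMetrics): Iterator[InternalRow] = ???
}
```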

Put together, these changes enable the following testing pattern:

```scala
val jdbcResultSet: ResultSet = conn.prepareStatement(query).executeQuery()
val resultSchema: StructType = JdbcUtils.getSchema(jdbcResultSet, JdbcDialects.get("jdbc:postgresql"))
val jdbcRows: Seq[Row] = JdbcUtils.resultSetToRows(jdbcResultSet, resultSchema).toSeq
checkAnswer(sparkResult, jdbcRows) // in a test case
```
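
Inside a Spark job the same helper is reused: per the diff below, the tail of the rewritten `JDBCRDD.compute()` delegates to `resultSetToSparkInternalRows` and wraps the result in a `CompletionIterator` so that `close()` runs once the partition is exhausted. Roughly (a fragment excerpted from the diff, with `rs`, `schema`, `inputMetrics`, and `close()` defined earlier in the method):

```scala
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
```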

Author: Josh Rosen <joshrosen@databricks.com>

Closes #14907 from JoshRosen/modularize-jdbc-internals.
parent 806d8a8e
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
import java.util.Properties
import scala.util.control.NonFatal
@@ -28,12 +28,10 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.CompletionIterator
/**
* Data corresponding to one partition of a JDBCRDD.
@@ -44,68 +42,6 @@ case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
object JDBCRDD extends Logging {
/**
* Maps a JDBC type to a Catalyst type. This function is called only when
* the JdbcDialect class corresponding to your database driver returns null.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
*/
private def getCatalystType(
sqlType: Int,
precision: Int,
scale: Int,
signed: Boolean): DataType = {
val answer = sqlType match {
// scalastyle:off
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
case java.sql.Types.BINARY => BinaryType
case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
case java.sql.Types.CLOB => StringType
case java.sql.Types.DATALINK => null
case java.sql.Types.DATE => DateType
case java.sql.Types.DECIMAL
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.DISTINCT => null
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
case java.sql.Types.JAVA_OBJECT => null
case java.sql.Types.LONGNVARCHAR => StringType
case java.sql.Types.LONGVARBINARY => BinaryType
case java.sql.Types.LONGVARCHAR => StringType
case java.sql.Types.NCHAR => StringType
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
case java.sql.Types.REF => StringType
case java.sql.Types.ROWID => LongType
case java.sql.Types.SMALLINT => IntegerType
case java.sql.Types.SQLXML => StringType
case java.sql.Types.STRUCT => StringType
case java.sql.Types.TIME => TimestampType
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
case _ => null
// scalastyle:on
}
if (answer == null) throw new SQLException("Unsupported type " + sqlType)
answer
}
/**
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
@@ -126,37 +62,7 @@ object JDBCRDD extends Logging {
try {
val rs = statement.executeQuery()
try {
val rsmd = rs.getMetaData
val ncols = rsmd.getColumnCount
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val isSigned = {
try {
rsmd.isSigned(i + 1)
} catch {
// Workaround for HIVE-14684:
case e: SQLException if
e.getMessage == "Method not supported" &&
rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
}
}
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder()
.putString("name", columnName)
.putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
return new StructType(fields)
return JdbcUtils.getSchema(rs, dialect)
} finally {
rs.close()
}
@@ -331,195 +237,15 @@ private[jdbc] class JDBCRDD(
}
}
// A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
// for `MutableRow`. The last argument `Int` means the index for the value to be set in
// the row and also used for the value in `ResultSet`.
private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
/**
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[MutableRow]] correctly.
*/
def makeGetters(schema: StructType): Array[JDBCValueGetter] =
schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
case DateType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos + 1)
if (dateVal != null) {
row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
} else {
row.update(pos, null)
}
// When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
// object returned by ResultSet.getBigDecimal is not correctly matched to the table
// schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
// If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
// a BigDecimal object with scale as 0. But the dataframe schema has correct type as
// DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
// retrieve it, you will get wrong result 199.99.
// So it is needed to set precision and scale for Decimal based on JDBC metadata.
case DecimalType.Fixed(p, s) =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val decimal =
nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
row.update(pos, decimal)
case DoubleType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setDouble(pos, rs.getDouble(pos + 1))
case FloatType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setFloat(pos, rs.getFloat(pos + 1))
case IntegerType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setInt(pos, rs.getInt(pos + 1))
case LongType if metadata.contains("binarylong") =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val bytes = rs.getBytes(pos + 1)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1
}
row.setLong(pos, ans)
case LongType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setLong(pos, rs.getLong(pos + 1))
case ShortType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setShort(pos, rs.getShort(pos + 1))
case StringType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
case TimestampType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
} else {
row.update(pos, null)
}
case BinaryType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.update(pos, rs.getBytes(pos + 1))
case ArrayType(et, _) =>
val elementConversion = et match {
case TimestampType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringType =>
(array: Object) =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case dt: DecimalType =>
(array: Object) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](
decimal, d => Decimal(d, dt.precision, dt.scale))
}
case LongType if metadata.contains("binarylong") =>
throw new IllegalArgumentException(s"Unsupported array element " +
s"type ${dt.simpleString} based on binary")
case ArrayType(_, _) =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => (array: Object) => array.asInstanceOf[Array[Any]]
}
(rs: ResultSet, row: MutableRow, pos: Int) =>
val array = nullSafeConvert[Object](
rs.getArray(pos + 1).getArray,
array => new GenericArrayData(elementConversion.apply(array)))
row.update(pos, array)
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
/**
* Runs the SQL query against the JDBC driver.
*
*/
override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
new Iterator[InternalRow] {
override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
var closed = false
var finished = false
var gotNext = false
var nextValue: InternalRow = null
context.addTaskCompletionListener{ context => close() }
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
val conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, properties.asScala.toMap)
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.
val myWhereClause = getWhereClause(part)
val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
val stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
require(fetchSize >= 0,
s"Invalid value `${fetchSize.toString}` for parameter " +
s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
"the JDBC driver ignores the value and does the estimates.")
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()
val getters: Array[JDBCValueGetter] = makeGetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
while (i < getters.length) {
getters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
mutableRow
} else {
finished = true
null.asInstanceOf[InternalRow]
}
}
var rs: ResultSet = null
var stmt: PreparedStatement = null
var conn: Connection = null
def close() {
if (closed) return
@@ -555,33 +281,33 @@ private[jdbc] class JDBCRDD(
closed = true
}
override def hasNext: Boolean = {
if (!finished) {
if (!gotNext) {
nextValue = getNext()
if (finished) {
close()
}
gotNext = true
}
}
!finished
}
context.addTaskCompletionListener{ context => close() }
override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException("End of stream")
}
gotNext = false
nextValue
}
}
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, properties.asScala.toMap)
private def nullSafeConvert[T](input: T, f: T => Any): Any = {
if (input == null) {
null
} else {
f(input)
}
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
// talk about a table in a completely portable way.
val myWhereClause = getWhereClause(part)
val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
require(fetchSize >= 0,
s"Invalid value `${fetchSize.toString}` for parameter " +
s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
"the JDBC driver ignores the value and does the estimates.")
stmt.setFetchSize(fetchSize)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
}
}
@@ -17,17 +17,25 @@
package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException}
import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
import java.util.Properties
import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal
import org.apache.spark.TaskContext
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.NextIterator
/**
* Util functions for JDBC tables.
@@ -127,6 +135,7 @@ object JdbcUtils extends Logging {
/**
* Retrieve standard jdbc types.
*
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The default JdbcType for this DataType
*/
@@ -154,6 +163,297 @@ object JdbcUtils extends Logging {
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}
/**
* Maps a JDBC type to a Catalyst type. This function is called only when
* the JdbcDialect class corresponding to your database driver returns null.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
*/
private def getCatalystType(
sqlType: Int,
precision: Int,
scale: Int,
signed: Boolean): DataType = {
val answer = sqlType match {
// scalastyle:off
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
case java.sql.Types.BINARY => BinaryType
case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
case java.sql.Types.CLOB => StringType
case java.sql.Types.DATALINK => null
case java.sql.Types.DATE => DateType
case java.sql.Types.DECIMAL
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.DISTINCT => null
case java.sql.Types.DOUBLE => DoubleType
case java.sql.Types.FLOAT => FloatType
case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
case java.sql.Types.JAVA_OBJECT => null
case java.sql.Types.LONGNVARCHAR => StringType
case java.sql.Types.LONGVARBINARY => BinaryType
case java.sql.Types.LONGVARCHAR => StringType
case java.sql.Types.NCHAR => StringType
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC
if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
case java.sql.Types.REF => StringType
case java.sql.Types.ROWID => LongType
case java.sql.Types.SMALLINT => IntegerType
case java.sql.Types.SQLXML => StringType
case java.sql.Types.STRUCT => StringType
case java.sql.Types.TIME => TimestampType
case java.sql.Types.TIMESTAMP => TimestampType
case java.sql.Types.TINYINT => IntegerType
case java.sql.Types.VARBINARY => BinaryType
case java.sql.Types.VARCHAR => StringType
case _ => null
// scalastyle:on
}
if (answer == null) throw new SQLException("Unsupported type " + sqlType)
answer
}
/**
* Takes a [[ResultSet]] and returns its Catalyst schema.
*
* @return A [[StructType]] giving the Catalyst schema.
* @throws SQLException if the schema contains an unsupported type.
*/
def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = {
val rsmd = resultSet.getMetaData
val ncols = rsmd.getColumnCount
val fields = new Array[StructField](ncols)
var i = 0
while (i < ncols) {
val columnName = rsmd.getColumnLabel(i + 1)
val dataType = rsmd.getColumnType(i + 1)
val typeName = rsmd.getColumnTypeName(i + 1)
val fieldSize = rsmd.getPrecision(i + 1)
val fieldScale = rsmd.getScale(i + 1)
val isSigned = {
try {
rsmd.isSigned(i + 1)
} catch {
// Workaround for HIVE-14684:
case e: SQLException if
e.getMessage == "Method not supported" &&
rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
}
}
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder()
.putString("name", columnName)
.putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
new StructType(fields)
}
/**
* Convert a [[ResultSet]] into an iterator of Catalyst Rows.
*/
def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
val inputMetrics =
Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
val encoder = RowEncoder(schema).resolveAndBind()
val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
internalRows.map(encoder.fromRow)
}
private[spark] def resultSetToSparkInternalRows(
resultSet: ResultSet,
schema: StructType,
inputMetrics: InputMetrics): Iterator[InternalRow] = {
new NextIterator[InternalRow] {
private[this] val rs = resultSet
private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
private[this] val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
override protected def close(): Unit = {
try {
rs.close()
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
}
override protected def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
while (i < getters.length) {
getters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
mutableRow
} else {
finished = true
null.asInstanceOf[InternalRow]
}
}
}
}
// A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
// for `MutableRow`. The last argument `Int` means the index for the value to be set in
// the row and also used for the value in `ResultSet`.
private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
/**
* Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[MutableRow]] correctly.
*/
private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
case DateType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
val dateVal = rs.getDate(pos + 1)
if (dateVal != null) {
row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
} else {
row.update(pos, null)
}
// When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
// object returned by ResultSet.getBigDecimal is not correctly matched to the table
// schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
// If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
// a BigDecimal object with scale as 0. But the dataframe schema has correct type as
// DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
// retrieve it, you will get wrong result 199.99.
// So it is needed to set precision and scale for Decimal based on JDBC metadata.
case DecimalType.Fixed(p, s) =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val decimal =
nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
row.update(pos, decimal)
case DoubleType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setDouble(pos, rs.getDouble(pos + 1))
case FloatType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setFloat(pos, rs.getFloat(pos + 1))
case IntegerType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setInt(pos, rs.getInt(pos + 1))
case LongType if metadata.contains("binarylong") =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val bytes = rs.getBytes(pos + 1)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1
}
row.setLong(pos, ans)
case LongType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setLong(pos, rs.getLong(pos + 1))
case ShortType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setShort(pos, rs.getShort(pos + 1))
case StringType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
// TODO(davies): use getBytes for better performance, if the encoding is UTF-8
row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
case TimestampType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
val t = rs.getTimestamp(pos + 1)
if (t != null) {
row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
} else {
row.update(pos, null)
}
case BinaryType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.update(pos, rs.getBytes(pos + 1))
case ArrayType(et, _) =>
val elementConversion = et match {
case TimestampType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringType =>
(array: Object) =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateType =>
(array: Object) =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case dt: DecimalType =>
(array: Object) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](
decimal, d => Decimal(d, dt.precision, dt.scale))
}
case LongType if metadata.contains("binarylong") =>
throw new IllegalArgumentException(s"Unsupported array element " +
s"type ${dt.simpleString} based on binary")
case ArrayType(_, _) =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => (array: Object) => array.asInstanceOf[Array[Any]]
}
(rs: ResultSet, row: MutableRow, pos: Int) =>
val array = nullSafeConvert[Object](
rs.getArray(pos + 1).getArray,
array => new GenericArrayData(elementConversion.apply(array)))
row.update(pos, array)
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
private def nullSafeConvert[T](input: T, f: T => Any): Any = {
if (input == null) {
null
} else {
f(input)
}
}
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.
@@ -162,7 +162,7 @@ object JdbcDialects {
/**
* Fetch the JdbcDialect class corresponding to a given database url.
*/
private[sql] def get(url: String): JdbcDialect = {
def get(url: String): JdbcDialect = {
val matchingDialects = dialects.filter(_.canHandle(url))
matchingDialects.length match {
case 0 => NoopDialect