Commit 7ffd99ec authored by hyukjinkwon, committed by Wenchen Fan

[SPARK-16674][SQL] Avoid per-record type dispatch in JDBC when reading

## What changes were proposed in this pull request?

Currently, `JDBCRDD.compute` performs type dispatch for every single row in order to read the appropriate values.
This does not have to happen per record, because the schema is already kept in `JDBCRDD`.

So, the appropriate converters can be created once, according to the schema, and then applied to each row.
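As a rough illustration of the pattern (a minimal, hypothetical sketch only; the names `ValueSetter`, `makeSetter`, and `readAll` are made up for this example and are not the actual `JDBCRDD` code), the type dispatch happens once per column, and the per-row loop just applies the prebuilt closures:

```scala
import java.sql.ResultSet

object SetterSketch {
  // A "setter" reads JDBC column `pos + 1` from the ResultSet and writes it into
  // slot `pos` of a row buffer. One setter is built per column, once, from the schema.
  type ValueSetter = (ResultSet, Array[Any], Int) => Unit

  // Hypothetical, simplified schema description: just a type name per column.
  def makeSetter(typeName: String): ValueSetter = typeName match {
    case "int"    => (rs, row, pos) => row(pos) = rs.getInt(pos + 1)
    case "long"   => (rs, row, pos) => row(pos) = rs.getLong(pos + 1)
    case "string" => (rs, row, pos) => row(pos) = rs.getString(pos + 1)
    case other    => throw new IllegalArgumentException(s"Unsupported type $other")
  }

  def readAll(rs: ResultSet, columnTypes: Seq[String]): Seq[Array[Any]] = {
    // Type dispatch happens here, once per column, before any row is read.
    val setters: Array[ValueSetter] = columnTypes.map(makeSetter).toArray
    val rows = Seq.newBuilder[Array[Any]]
    while (rs.next()) {
      val row = new Array[Any](setters.length)
      var i = 0
      while (i < setters.length) {
        // The per-row loop only applies the prebuilt closures; no match on types here.
        setters(i)(rs, row, i)
        i += 1
      }
      rows += row
    }
    rows.result()
  }
}
```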

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14313 from HyukjinKwon/SPARK-16674.
parent 68b4020d
@@ -28,7 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
@@ -322,43 +322,134 @@ private[sql] class JDBCRDD(
     }
   }

-  // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
-  // we don't have to potentially poke around in the Metadata once for every
-  // row.
-  // Is there a better way to do this? I'd rather be using a type that
-  // contains only the tags I define.
-  abstract class JDBCConversion
-  case object BooleanConversion extends JDBCConversion
-  case object DateConversion extends JDBCConversion
-  case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
-  case object DoubleConversion extends JDBCConversion
-  case object FloatConversion extends JDBCConversion
-  case object IntegerConversion extends JDBCConversion
-  case object LongConversion extends JDBCConversion
-  case object BinaryLongConversion extends JDBCConversion
-  case object StringConversion extends JDBCConversion
-  case object TimestampConversion extends JDBCConversion
-  case object BinaryConversion extends JDBCConversion
-  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
+  // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
+  // into a field for `MutableRow`. The last argument `Int` means the index for the
+  // value to be set in the row and also used for the value to retrieve from `ResultSet`.
+  private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit

   /**
-   * Maps a StructType to a type tag list.
+   * Creates `JDBCValueSetter`s according to [[StructType]], which can set
+   * each value from `ResultSet` to each field of [[MutableRow]] correctly.
    */
-  def getConversions(schema: StructType): Array[JDBCConversion] =
-    schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
+  def makeSetters(schema: StructType): Array[JDBCValueSetter] =
+    schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))

-  private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
-    case BooleanType => BooleanConversion
-    case DateType => DateConversion
-    case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
-    case DoubleType => DoubleConversion
-    case FloatType => FloatConversion
-    case IntegerType => IntegerConversion
-    case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
-    case StringType => StringConversion
-    case TimestampType => TimestampConversion
-    case BinaryType => BinaryConversion
-    case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+  private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
+    case BooleanType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+    case DateType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+        val dateVal = rs.getDate(pos + 1)
+        if (dateVal != null) {
+          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+        } else {
+          row.update(pos, null)
+        }
+
+    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
+    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
+    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
+    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
+    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
+    // retrieve it, you will get wrong result 199.99.
+    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
+    case DecimalType.Fixed(p, s) =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val decimal =
+          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+        row.update(pos, decimal)
+
+    case DoubleType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setDouble(pos, rs.getDouble(pos + 1))
+
+    case FloatType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setFloat(pos, rs.getFloat(pos + 1))
+
+    case IntegerType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setInt(pos, rs.getInt(pos + 1))
+
+    case LongType if metadata.contains("binarylong") =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val bytes = rs.getBytes(pos + 1)
+        var ans = 0L
+        var j = 0
+        while (j < bytes.size) {
+          ans = 256 * ans + (255 & bytes(j))
+          j = j + 1
+        }
+        row.setLong(pos, ans)
+
+    case LongType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setLong(pos, rs.getLong(pos + 1))
+
+    case StringType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+    case TimestampType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val t = rs.getTimestamp(pos + 1)
+        if (t != null) {
+          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+        } else {
+          row.update(pos, null)
+        }
+
+    case BinaryType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.update(pos, rs.getBytes(pos + 1))
+
+    case ArrayType(et, _) =>
+      val elementConversion = et match {
+        case TimestampType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+            }
+
+        case StringType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.lang.String]]
+              .map(UTF8String.fromString)
+
+        case DateType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Date]].map { date =>
+              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+            }
+
+        case dt: DecimalType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+              nullSafeConvert[java.math.BigDecimal](
+                decimal, d => Decimal(d, dt.precision, dt.scale))
+            }
+
+        case LongType if metadata.contains("binarylong") =>
+          throw new IllegalArgumentException(s"Unsupported array element " +
+            s"type ${dt.simpleString} based on binary")
+
+        case ArrayType(_, _) =>
+          throw new IllegalArgumentException("Nested arrays unsupported")
+
+        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+      }
+
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val array = nullSafeConvert[Object](
+          rs.getArray(pos + 1).getArray,
+          array => new GenericArrayData(elementConversion.apply(array)))
+        row.update(pos, array)
+
     case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
   }
@@ -398,93 +489,15 @@ private[sql] class JDBCRDD(
     stmt.setFetchSize(fetchSize)
     val rs = stmt.executeQuery()

-    val conversions = getConversions(schema)
+    val setters: Array[JDBCValueSetter] = makeSetters(schema)
     val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))

     def getNext(): InternalRow = {
       if (rs.next()) {
         inputMetrics.incRecordsRead(1)
         var i = 0
-        while (i < conversions.length) {
-          val pos = i + 1
-          conversions(i) match {
-            case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
-            case DateConversion =>
-              // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
-              val dateVal = rs.getDate(pos)
-              if (dateVal != null) {
-                mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
-              } else {
-                mutableRow.update(i, null)
-              }
-            // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
-            // object returned by ResultSet.getBigDecimal is not correctly matched to the table
-            // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
-            // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
-            // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
-            // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
-            // retrieve it, you will get wrong result 199.99.
-            // So it is needed to set precision and scale for Decimal based on JDBC metadata.
-            case DecimalConversion(p, s) =>
-              val decimalVal = rs.getBigDecimal(pos)
-              if (decimalVal == null) {
-                mutableRow.update(i, null)
-              } else {
-                mutableRow.update(i, Decimal(decimalVal, p, s))
-              }
-            case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
-            case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
-            case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
-            case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
-            // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
-            case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos)))
-            case TimestampConversion =>
-              val t = rs.getTimestamp(pos)
-              if (t != null) {
-                mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
-              } else {
-                mutableRow.update(i, null)
-              }
-            case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-            case BinaryLongConversion =>
-              val bytes = rs.getBytes(pos)
-              var ans = 0L
-              var j = 0
-              while (j < bytes.size) {
-                ans = 256 * ans + (255 & bytes(j))
-                j = j + 1
-              }
-              mutableRow.setLong(i, ans)
-            case ArrayConversion(elementConversion) =>
-              val array = rs.getArray(pos).getArray
-              if (array != null) {
-                val data = elementConversion match {
-                  case TimestampConversion =>
-                    array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
-                      nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
-                    }
-                  case StringConversion =>
-                    array.asInstanceOf[Array[java.lang.String]]
-                      .map(UTF8String.fromString)
-                  case DateConversion =>
-                    array.asInstanceOf[Array[java.sql.Date]].map { date =>
-                      nullSafeConvert(date, DateTimeUtils.fromJavaDate)
-                    }
-                  case DecimalConversion(p, s) =>
-                    array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
-                      nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
-                    }
-                  case BinaryLongConversion =>
-                    throw new IllegalArgumentException(s"Unsupported array element conversion $i")
-                  case _: ArrayConversion =>
-                    throw new IllegalArgumentException("Nested arrays unsupported")
-                  case _ => array.asInstanceOf[Array[Any]]
-                }
-                mutableRow.update(i, new GenericArrayData(data))
-              } else {
-                mutableRow.update(i, null)
-              }
-          }
+        while (i < setters.length) {
+          setters(i).apply(rs, mutableRow, i)
           if (rs.wasNull) mutableRow.setNullAt(i)
           i = i + 1
         }