Commit cafd5056 authored by Josh Rosen's avatar Josh Rosen Committed by Reynold Xin

[SPARK-7691] [SQL] Refactor CatalystTypeConverter to use type-specific row accessors

This patch significantly refactors CatalystTypeConverters to both clean up the code and enable these conversions to work with future Project Tungsten features.

At a high level, I've reorganized the code so that all functions dealing with the same type are grouped together into type-specific subclasses of `CatalystTypeConverter`. In addition, I've added new methods that allow the Catalyst Row -> Scala Row conversions to access the Catalyst row's fields through type-specific `getTYPE()` methods rather than the generic `get()` / `Row.apply` methods. This refactoring is a prerequisite for unit testing the new operators that I'm developing as part of Project Tungsten, since those operators may output `UnsafeRow` instances, which don't support the generic `get()`.
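
For illustration, here is a minimal sketch of the shape this takes (simplified, hypothetical names; the real classes are in the `CatalystTypeConverters.scala` diff below). Each converter pairs the Scala <-> Catalyst conversion for one type with a type-specific row accessor, so reading a column never has to go through the generic, boxing `get()`:

```scala
import org.apache.spark.sql.Row

// Simplified sketch, not the actual Spark classes.
abstract class ConverterSketch[CatalystType, ScalaType] extends Serializable {
  def toCatalyst(scalaValue: Any): CatalystType
  def toScala(row: Row, column: Int): ScalaType
}

object DoubleConverterSketch extends ConverterSketch[Double, Double] {
  def toCatalyst(scalaValue: Any): Double = scalaValue.asInstanceOf[Double]
  // getDouble() also works for row formats (such as UnsafeRow) that cannot
  // support the generic, boxing get().
  def toScala(row: Row, column: Int): Double = row.getDouble(column)
}
```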

The stricter use of types here has uncovered some bugs in other parts of Spark SQL:

- #6217: DescribeCommand is assigned wrong output attributes in SparkStrategies
- #6218: DataFrame.describe() should cast all aggregates to String
- #6400: Use output schema, not relation schema, for data source input conversion

Spark SQL currently has undefined behavior when you try to create a DataFrame from user-specified rows whose values don't match the declared schema. According to the `createDataFrame()` Scaladoc:

>  It is important to make sure that the structure of every [[Row]] of the provided RDD matches the provided schema. Otherwise, there will be runtime exception.

Given this, it sounds like it's technically not a breach of our API contract to fail fast when the data types don't match. However, there appear to be many cases where we don't fail even though the types don't match. For example, `JavaHashingTFSuite.hashingTF` passes a column of integer values for a "label" column that is supposed to contain doubles. This column isn't actually read or modified as part of query processing, so its concrete type doesn't seem to matter. In other cases, generic numeric aggregates may tolerate being called with numeric types other than the ones the schema specifies, which can be okay thanks to numeric conversions.

In the long run, we will probably want to come up with precise semantics for implicit type conversions / widening when converting Java / Scala rows to Catalyst rows.  Until then, though, I think that failing fast with a ClassCastException is a reasonable behavior; this is the approach taken in this patch.  Note that certain optimizations in the inbound conversion functions for primitive types mean that we'll probably preserve the old undefined behavior in a majority of cases.
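
To make the failure mode concrete, here is a hypothetical snippet (the `JavaHashingTFSuite` case from above, written in Scala): generic access happily returns the mismatched value, while a type-specific accessor fails with a ClassCastException:

```scala
import org.apache.spark.sql.Row

// The old test stored the "label" column as an Int even though the schema
// declared DoubleType. Generic access tolerates this; a typed accessor does not.
val row = Row(0, "Hi I heard about Spark")
row(0)            // returns 0 as a boxed Integer; no error
row.getDouble(0)  // throws ClassCastException, since the stored value is not a Double
```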

Author: Josh Rosen <joshrosen@databricks.com>

Closes #6222 from JoshRosen/catalyst-converters-refactoring and squashes the following commits:

740341b [Josh Rosen] Optimize method dispatch for primitive type conversions
befc613 [Josh Rosen] Add tests to document Option-handling behavior.
5989593 [Josh Rosen] Use new SparkFunSuite base in CatalystTypeConvertersSuite
6edf7f8 [Josh Rosen] Re-add convertToScala(), since a Hive test still needs it
3f7b2d8 [Josh Rosen] Initialize converters lazily so that the attributes are resolved first
6ad0ebb [Josh Rosen] Fix JavaHashingTFSuite ClassCastException
677ff27 [Josh Rosen] Fix null handling bug; add tests.
8033d4c [Josh Rosen] Fix serialization error in UserDefinedGenerator.
85bba9d [Josh Rosen] Fix wrong input data in InMemoryColumnarQuerySuite
9c0e4e1 [Josh Rosen] Remove last use of convertToScala().
ae3278d [Josh Rosen] Throw ClassCastException errors during inbound conversions.
7ca7fcb [Josh Rosen] Comments and cleanup
1e87a45 [Josh Rosen] WIP refactoring of CatalystTypeConverters
parent a86b3e9b
@@ -55,9 +55,9 @@ public class JavaHashingTFSuite {
   @Test
   public void hashingTF() {
     JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
-      RowFactory.create(0, "Hi I heard about Spark"),
-      RowFactory.create(0, "I wish Java could use case classes"),
-      RowFactory.create(1, "Logistic regression models are neat")
+      RowFactory.create(0.0, "Hi I heard about Spark"),
+      RowFactory.create(0.0, "I wish Java could use case classes"),
+      RowFactory.create(1.0, "Logistic regression models are neat")
     ));
     StructType schema = new StructType(new StructField[]{
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
@@ -18,7 +18,10 @@
 package org.apache.spark.sql.catalyst
 
 import java.lang.{Iterable => JavaIterable}
+import java.math.{BigDecimal => JavaBigDecimal}
+import java.sql.Date
 import java.util.{Map => JavaMap}
+import javax.annotation.Nullable
 
 import scala.collection.mutable.HashMap
@@ -34,197 +37,338 @@ object CatalystTypeConverters {
  // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
  import scala.collection.Map

  private def isPrimitive(dataType: DataType): Boolean = {
    dataType match {
      case BooleanType => true
      case ByteType => true
      case ShortType => true
      case IntegerType => true
      case LongType => true
      case FloatType => true
      case DoubleType => true
      case _ => false
    }
  }

  private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
    val converter = dataType match {
      case udt: UserDefinedType[_] => UDTConverter(udt)
      case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
      case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
      case structType: StructType => StructConverter(structType)
      case StringType => StringConverter
      case DateType => DateConverter
      case dt: DecimalType => BigDecimalConverter
      case BooleanType => BooleanConverter
      case ByteType => ByteConverter
      case ShortType => ShortConverter
      case IntegerType => IntConverter
      case LongType => LongConverter
      case FloatType => FloatConverter
      case DoubleType => DoubleConverter
      case _ => IdentityConverter
    }
    converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
  }

  /**
   * Converts a Scala type to its Catalyst equivalent (and vice versa).
   *
   * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst.
   * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala.
   * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type.
   */
  private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType]
    extends Serializable {

    /**
     * Converts a Scala type to its Catalyst equivalent while automatically handling nulls
     * and Options.
     */
    final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = {
      if (maybeScalaValue == null) {
        null.asInstanceOf[CatalystType]
      } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) {
        val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]]
        if (opt.isDefined) {
          toCatalystImpl(opt.get)
        } else {
          null.asInstanceOf[CatalystType]
        }
      } else {
        toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType])
      }
    }

    /**
     * Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
     */
    final def toScala(row: Row, column: Int): ScalaOutputType = {
      if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column)
    }

    /**
     * Convert a Catalyst value to its Scala equivalent.
     */
    def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType

    /**
     * Converts a Scala value to its Catalyst equivalent.
     * @param scalaValue the Scala value, guaranteed not to be null.
     * @return the Catalyst value.
     */
    protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType

    /**
     * Given a Catalyst row, convert the value at column `column` to its Scala equivalent.
     * This method will only be called on non-null columns.
     */
    protected def toScalaImpl(row: Row, column: Int): ScalaOutputType
  }

  private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] {
    override def toCatalystImpl(scalaValue: Any): Any = scalaValue
    override def toScala(catalystValue: Any): Any = catalystValue
    override def toScalaImpl(row: Row, column: Int): Any = row(column)
  }

  private case class UDTConverter(
      udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
    override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
    override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
    override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column))
  }

  /** Converter for arrays, sequences, and Java iterables. */
  private case class ArrayConverter(
      elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {

    private[this] val elementConverter = getConverterForType(elementType)

    override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
      scalaValue match {
        case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
        case s: Seq[_] => s.map(elementConverter.toCatalyst)
        case i: JavaIterable[_] =>
          val iter = i.iterator
          var convertedIterable: List[Any] = List()
          while (iter.hasNext) {
            val item = iter.next()
            convertedIterable :+= elementConverter.toCatalyst(item)
          }
          convertedIterable
      }
    }

    override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
      if (catalystValue == null) {
        null
      } else {
        catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala)
      }
    }

    override def toScalaImpl(row: Row, column: Int): Seq[Any] =
      toScala(row(column).asInstanceOf[Seq[Any]])
  }

  private case class MapConverter(
      keyType: DataType,
      valueType: DataType)
    extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] {

    private[this] val keyConverter = getConverterForType(keyType)
    private[this] val valueConverter = getConverterForType(valueType)

    override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match {
      case m: Map[_, _] =>
        m.map { case (k, v) =>
          keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v)
        }

      case jmap: JavaMap[_, _] =>
        val iter = jmap.entrySet.iterator
        val convertedMap: HashMap[Any, Any] = HashMap()
        while (iter.hasNext) {
          val entry = iter.next()
          val key = keyConverter.toCatalyst(entry.getKey)
          convertedMap(key) = valueConverter.toCatalyst(entry.getValue)
        }
        convertedMap
    }

    override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = {
      if (catalystValue == null) {
        null
      } else {
        catalystValue.map { case (k, v) =>
          keyConverter.toScala(k) -> valueConverter.toScala(v)
        }
      }
    }

    override def toScalaImpl(row: Row, column: Int): Map[Any, Any] =
      toScala(row(column).asInstanceOf[Map[Any, Any]])
  }

  private case class StructConverter(
      structType: StructType) extends CatalystTypeConverter[Any, Row, Row] {

    private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }

    override def toCatalystImpl(scalaValue: Any): Row = scalaValue match {
      case row: Row =>
        val ar = new Array[Any](row.size)
        var idx = 0
        while (idx < row.size) {
          ar(idx) = converters(idx).toCatalyst(row(idx))
          idx += 1
        }
        new GenericRowWithSchema(ar, structType)

      case p: Product =>
        val ar = new Array[Any](structType.size)
        val iter = p.productIterator
        var idx = 0
        while (idx < structType.size) {
          ar(idx) = converters(idx).toCatalyst(iter.next())
          idx += 1
        }
        new GenericRowWithSchema(ar, structType)
    }

    override def toScala(row: Row): Row = {
      if (row == null) {
        null
      } else {
        val ar = new Array[Any](row.size)
        var idx = 0
        while (idx < row.size) {
          ar(idx) = converters(idx).toScala(row, idx)
          idx += 1
        }
        new GenericRowWithSchema(ar, structType)
      }
    }

    override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row])
  }

  private object StringConverter extends CatalystTypeConverter[Any, String, Any] {
    override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match {
      case str: String => UTF8String(str)
      case utf8: UTF8String => utf8
    }
    override def toScala(catalystValue: Any): String = catalystValue match {
      case null => null
      case str: String => str
      case utf8: UTF8String => utf8.toString()
    }
    override def toScalaImpl(row: Row, column: Int): String = row(column).toString
  }

  private object DateConverter extends CatalystTypeConverter[Date, Date, Any] {
    override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue)
    override def toScala(catalystValue: Any): Date =
      if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int])
    override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column))
  }

  private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
    override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
      case d: BigDecimal => Decimal(d)
      case d: JavaBigDecimal => Decimal(d)
      case d: Decimal => d
    }
    override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
    override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match {
      case d: JavaBigDecimal => d
      case d: Decimal => d.toJavaBigDecimal
    }
  }

  private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {
    final override def toScala(catalystValue: Any): Any = catalystValue
    final override def toCatalystImpl(scalaValue: T): Any = scalaValue
  }

  private object BooleanConverter extends PrimitiveConverter[Boolean] {
    override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column)
  }

  private object ByteConverter extends PrimitiveConverter[Byte] {
    override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column)
  }

  private object ShortConverter extends PrimitiveConverter[Short] {
    override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column)
  }

  private object IntConverter extends PrimitiveConverter[Int] {
    override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column)
  }

  private object LongConverter extends PrimitiveConverter[Long] {
    override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column)
  }

  private object FloatConverter extends PrimitiveConverter[Float] {
    override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column)
  }

  private object DoubleConverter extends PrimitiveConverter[Double] {
    override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column)
  }

  /**
   * Converts Scala objects to catalyst rows / types. This method is slow, and for batch
   * conversion you should be using converter produced by createToCatalystConverter.
   * Note: This is always called after schemaFor has been called.
   * This ordering is important for UDT registration.
   */
  def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = {
    getConverterForType(dataType).toCatalyst(scalaValue)
  }

  /**
   * Creates a converter function that will convert Scala objects to the specified Catalyst type.
   * Typical use case would be converting a collection of rows that have the same schema. You will
   * call this function once to get a converter, and apply it to every row.
   */
  private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = {
    if (isPrimitive(dataType)) {
      // Although the `else` branch here is capable of handling inbound conversion of primitives,
      // we add some special-case handling for those types here. The motivation for this relates to
      // Java method invocation costs: if we have rows that consist entirely of primitive columns,
      // then returning the same conversion function for all of the columns means that the call site
      // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in
      // a measurable performance impact. Note that this optimization will be unnecessary if we
      // use code generation to construct Scala Row -> Catalyst Row converters.
      def convert(maybeScalaValue: Any): Any = {
        if (maybeScalaValue.isInstanceOf[Option[Any]]) {
          maybeScalaValue.asInstanceOf[Option[Any]].orNull
        } else {
          maybeScalaValue
        }
      }
      convert
    } else {
      getConverterForType(dataType).toCatalyst
    }
  }

  /**
   * Converts Scala objects to Catalyst rows / types.
   *
   * Note: This should be called before do evaluation on Row
   * (It does not support UDT)
   * This is used to create an RDD or test results with correct types for Catalyst.
   */
  def convertToCatalyst(a: Any): Any = a match {
    case s: String => StringConverter.toCatalyst(s)
    case d: Date => DateConverter.toCatalyst(d)
    case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
    case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
    case seq: Seq[Any] => seq.map(convertToCatalyst)
    case r: Row => Row(r.toSeq.map(convertToCatalyst): _*)
    case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray
@@ -238,33 +382,8 @@ object CatalystTypeConverters {
    * This method is slow, and for batch conversion you should be using converter
    * produced by createToScalaConverter.
    */
-  def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
-    // Check UDT first since UDTs can override other types
-    case (d, udt: UserDefinedType[_]) =>
-      udt.deserialize(d)
-
-    case (s: Seq[_], arrayType: ArrayType) =>
-      s.map(convertToScala(_, arrayType.elementType))
-
-    case (m: Map[_, _], mapType: MapType) =>
-      m.map { case (k, v) =>
-        convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
-      }
-
-    case (r: Row, s: StructType) =>
-      convertRowToScala(r, s)
-
-    case (d: Decimal, _: DecimalType) =>
-      d.toJavaBigDecimal
-
-    case (i: Int, DateType) =>
-      DateUtils.toJavaDate(i)
-
-    case (s: UTF8String, StringType) =>
-      s.toString()
-
-    case (other, _) =>
-      other
-  }
+  def convertToScala(catalystValue: Any, dataType: DataType): Any = {
+    getConverterForType(dataType).toScala(catalystValue)
+  }
 
   /**
@@ -272,82 +391,7 @@ object CatalystTypeConverters {
    * Typical use case would be converting a collection of rows that have the same schema. You will
    * call this function once to get a converter, and apply it to every row.
    */
-  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match {
-    // Check UDT first since UDTs can override other types
-    case udt: UserDefinedType[_] =>
-      (item: Any) => if (item == null) null else udt.deserialize(item)
-
-    case arrayType: ArrayType =>
-      val elementConverter = createToScalaConverter(arrayType.elementType)
-      (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter)
-
-    case mapType: MapType =>
-      val keyConverter = createToScalaConverter(mapType.keyType)
-      val valueConverter = createToScalaConverter(mapType.valueType)
-      (item: Any) => if (item == null) {
-        null
-      } else {
-        item.asInstanceOf[Map[_, _]].map { case (k, v) =>
-          keyConverter(k) -> valueConverter(v)
-        }
-      }
-
-    case s: StructType =>
-      val converters = s.fields.map(f => createToScalaConverter(f.dataType))
-      (item: Any) => {
-        if (item == null) {
-          null
-        } else {
-          convertRowWithConverters(item.asInstanceOf[Row], s, converters)
-        }
-      }
-
-    case _: DecimalType =>
-      (item: Any) => item match {
-        case d: Decimal => d.toJavaBigDecimal
-        case other => other
-      }
-
-    case DateType =>
-      (item: Any) => item match {
-        case i: Int => DateUtils.toJavaDate(i)
-        case other => other
-      }
-
-    case StringType =>
-      (item: Any) => item match {
-        case s: UTF8String => s.toString()
-        case other => other
-      }
-
-    case other =>
-      (item: Any) => item
-  }
-
-  def convertRowToScala(r: Row, schema: StructType): Row = {
-    val ar = new Array[Any](r.size)
-    var idx = 0
-    while (idx < r.size) {
-      ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType)
-      idx += 1
-    }
-    new GenericRowWithSchema(ar, schema)
-  }
-
-  /**
-   * Converts a row by applying the provided set of converter functions. It is used for both
-   * toScala and toCatalyst conversions.
-   */
-  private[sql] def convertRowWithConverters(
-      row: Row,
-      schema: StructType,
-      converters: Array[Any => Any]): Row = {
-    val ar = new Array[Any](row.size)
-    var idx = 0
-    while (idx < row.size) {
-      ar(idx) = converters(idx)(row(idx))
-      idx += 1
-    }
-    new GenericRowWithSchema(ar, schema)
-  }
+  private[sql] def createToScalaConverter(dataType: DataType): Any => Any = {
+    getConverterForType(dataType).toScala
+  }
 }
@@ -71,12 +71,23 @@ case class UserDefinedGenerator(
     children: Seq[Expression])
   extends Generator {
 
+  @transient private[this] var inputRow: InterpretedProjection = _
+  @transient private[this] var convertToScala: (Row) => Row = _
+
+  private def initializeConverters(): Unit = {
+    inputRow = new InterpretedProjection(children)
+    convertToScala = {
+      val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
+      CatalystTypeConverters.createToScalaConverter(inputSchema)
+    }.asInstanceOf[(Row => Row)]
+  }
+
   override def eval(input: Row): TraversableOnce[Row] = {
-    // TODO(davies): improve this
+    if (inputRow == null) {
+      initializeConverters()
+    }
     // Convert the objects into Scala Type before calling function, we need schema to support UDT
-    val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true)))
-    val inputRow = new InterpretedProjection(children)
-    function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row])
+    function(convertToScala(inputRow(input)))
   }
 
   override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
class CatalystTypeConvertersSuite extends SparkFunSuite {
private val simpleTypes: Seq[DataType] = Seq(
StringType,
DateType,
BooleanType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType)
test("null handling in rows") {
val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t)))
val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema)
val convertToScala = CatalystTypeConverters.createToScalaConverter(schema)
val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null))
assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow)
}
test("null handling for individual values") {
for (dataType <- simpleTypes) {
assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null)
}
}
test("option handling in convertToCatalyst") {
// convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with
// createToCatalystConverter but it may not actually matter as this is only called internally
// in a handful of places where we don't expect to receive Options.
assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123))
}
test("option handling in createToCatalystConverter") {
assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
}
}
@@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
         new Timestamp(i),
         (1 to i).toSeq,
         (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
-        Row((i - 0.25).toFloat, (1 to i).toSeq))
+        Row((i - 0.25).toFloat, Seq(true, false, null)))
     }
     createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
     // Cache the table.