Commit b0c3fd34 authored by Liang-Chi Hsieh, committed by Davies Liu

[SPARK-11743] [SQL] Add UserDefinedType support to RowEncoder

JIRA: https://issues.apache.org/jira/browse/SPARK-11743

RowEncoder does not support UserDefinedType yet. This patch adds that support.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #9712 from viirya/rowencoder-udt.
parent 06f1fdba
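For orientation before the diff, here is a minimal sketch of the round trip this patch enables, modelled on the RowEncoderSuite test added below. It is illustrative only; ExamplePoint and ExamplePointUDT are the example UDT classes defined in that test suite, and the field name "point" is arbitrary.

// Sketch only: assumes the ExamplePoint / ExamplePointUDT classes defined in
// the RowEncoderSuite changes below are on the classpath.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType().add("point", new ExamplePointUDT, false)
val encoder = RowEncoder(schema)

// The external Row holds the user-space object ...
val external: Row = Row(new ExamplePoint(0.1, 0.2))
// ... toRow converts it to an InternalRow through the UDT's serialize(),
val internal = encoder.toRow(external)
// ... and fromRow rebuilds the user class through deserialize().
val decoded = encoder.fromRow(internal)
assert(external.getAs[ExamplePoint](0) == decoded.getAs[ExamplePoint](0))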
@@ -152,7 +152,7 @@ trait Row extends Serializable {
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
  *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- *   StructType -> org.apache.spark.sql.Row
+ *   StructType -> org.apache.spark.sql.Row (or Product)
  * }}}
  */
 def apply(i: Int): Any = get(i)
@@ -177,7 +177,7 @@ trait Row extends Serializable {
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
  *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- *   StructType -> org.apache.spark.sql.Row
+ *   StructType -> org.apache.spark.sql.Row (or Product)
  * }}}
  */
 def get(i: Int): Any
@@ -306,7 +306,15 @@ trait Row extends Serializable {
  *
  * @throws ClassCastException when data type does not match.
  */
- def getStruct(i: Int): Row = getAs[Row](i)
+ def getStruct(i: Int): Row = {
+   // Product and Row are both recognized as StructType in a Row
+   val t = get(i)
+   if (t.isInstanceOf[Product]) {
+     Row.fromTuple(t.asInstanceOf[Product])
+   } else {
+     t.asInstanceOf[Row]
+   }
+ }

 /**
  * Returns the value at position i.
@@ -50,6 +50,14 @@ object RowEncoder {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType => inputObject

+    case udt: UserDefinedType[_] =>
+      val obj = NewInstance(
+        udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+        Nil,
+        false,
+        dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+      Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
     case TimestampType =>
       StaticInvoke(
         DateTimeUtils,
@@ -109,11 +117,16 @@ object RowEncoder {
     case StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+        val method = if (f.dataType.isInstanceOf[StructType]) {
+          "getStruct"
+        } else {
+          "get"
+        }
         If(
           Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, f.dataType),
           extractorsFor(
-            Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+            Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
             f.dataType))
       }
       CreateStruct(convertedFields)
@@ -137,6 +150,7 @@ object RowEncoder {
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
+    case udt: UserDefinedType[_] => ObjectType(udt.userClass)
   }

   private def constructorFor(schema: StructType): Expression = {
@@ -155,6 +169,14 @@ object RowEncoder {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType => input

+    case udt: UserDefinedType[_] =>
+      val obj = NewInstance(
+        udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+        Nil,
+        false,
+        dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+      Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
+
     case TimestampType =>
       StaticInvoke(
         DateTimeUtils,
@@ -113,7 +113,7 @@ case class Invoke(
     arguments: Seq[Expression] = Nil) extends Expression {

   override def nullable: Boolean = true
-  override def children: Seq[Expression] = targetObject :: Nil
+  override def children: Seq[Expression] = arguments.+:(targetObject)

   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -343,33 +343,35 @@ case class MapObjects(
   private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
   private lazy val completeFunction = function(loopAttribute)

+  private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
+    case IntegerType => (i: String) => s".getInt($i)"
+    case LongType => (i: String) => s".getLong($i)"
+    case FloatType => (i: String) => s".getFloat($i)"
+    case DoubleType => (i: String) => s".getDouble($i)"
+    case ByteType => (i: String) => s".getByte($i)"
+    case ShortType => (i: String) => s".getShort($i)"
+    case BooleanType => (i: String) => s".getBoolean($i)"
+    case StringType => (i: String) => s".getUTF8String($i)"
+    case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
+    case a: ArrayType => (i: String) => s".getArray($i)"
+    case _: MapType => (i: String) => s".getMap($i)"
+    case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+  }
+
   private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
     case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
       (".size()", (i: String) => s".apply($i)", false)
     case ObjectType(cls) if cls.isArray =>
       (".length", (i: String) => s"[$i]", false)
-    case ArrayType(s: StructType, _) =>
-      (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
-    case ArrayType(a: ArrayType, _) =>
-      (".numElements()", (i: String) => s".getArray($i)", true)
-    case ArrayType(IntegerType, _) =>
-      (".numElements()", (i: String) => s".getInt($i)", true)
-    case ArrayType(LongType, _) =>
-      (".numElements()", (i: String) => s".getLong($i)", true)
-    case ArrayType(FloatType, _) =>
-      (".numElements()", (i: String) => s".getFloat($i)", true)
-    case ArrayType(DoubleType, _) =>
-      (".numElements()", (i: String) => s".getDouble($i)", true)
-    case ArrayType(ByteType, _) =>
-      (".numElements()", (i: String) => s".getByte($i)", true)
-    case ArrayType(ShortType, _) =>
-      (".numElements()", (i: String) => s".getShort($i)", true)
-    case ArrayType(BooleanType, _) =>
-      (".numElements()", (i: String) => s".getBoolean($i)", true)
-    case ArrayType(StringType, _) =>
-      (".numElements()", (i: String) => s".getUTF8String($i)", false)
-    case ArrayType(_: MapType, _) =>
-      (".numElements()", (i: String) => s".getMap($i)", false)
+    case ArrayType(t, _) =>
+      val (sqlType, primitiveElement) = t match {
+        case m: MapType => (m, false)
+        case s: StructType => (s, false)
+        case s: StringType => (s, false)
+        case udt: UserDefinedType[_] => (udt.sqlType, false)
+        case o => (o, true)
+      }
+      (".numElements()", itemAccessorMethod(sqlType), primitiveElement)
   }

   override def nullable: Boolean = true
@@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+class ExamplePoint(val x: Double, val y: Double) extends Serializable {
+  override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
+  override def equals(that: Any): Boolean = {
+    if (that.isInstanceOf[ExamplePoint]) {
+      val e = that.asInstanceOf[ExamplePoint]
+      (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
+        (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
+    } else {
+      false
+    }
+  }
+}
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+  override def sqlType: DataType = ArrayType(DoubleType, false)
+
+  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
+
+  override def serialize(obj: Any): GenericArrayData = {
+    obj match {
+      case p: ExamplePoint =>
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
+    }
+  }
+
+  override def deserialize(datum: Any): ExamplePoint = {
+    datum match {
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
+    }
+  }
+
+  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+
+  private[spark] override def asNullable: ExamplePointUDT = this
+}
+
 class RowEncoderSuite extends SparkFunSuite {

   private val structOfString = new StructType().add("str", StringType)
+  private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
   private val arrayOfString = ArrayType(StringType)
   private val mapOfString = MapType(StringType, StringType)
+  private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)

   encodeDecodeTest(
     new StructType()
@@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("string", StringType)
       .add("binary", BinaryType)
       .add("date", DateType)
-      .add("timestamp", TimestampType))
+      .add("timestamp", TimestampType)
+      .add("udt", new ExamplePointUDT, false))

   encodeDecodeTest(
     new StructType()
@@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("structOfArray", new StructType().add("array", arrayOfString))
       .add("structOfMap", new StructType().add("map", mapOfString))
       .add("structOfArrayAndMap",
-        new StructType().add("array", arrayOfString).add("map", mapOfString)))
+        new StructType().add("array", arrayOfString).add("map", mapOfString))
+      .add("structOfUDT", structOfUDT))
+
+  test(s"encode/decode: arrayOfUDT") {
+    val schema = new StructType()
+      .add("arrayOfUDT", arrayOfUDT)
+
+    val encoder = RowEncoder(schema)
+
+    val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
+  }
+
+  test(s"encode/decode: Product") {
+    val schema = new StructType()
+      .add("structAsProduct",
+        new StructType()
+          .add("int", IntegerType)
+          .add("string", StringType)
+          .add("double", DoubleType))
+
+    val encoder = RowEncoder(schema)
+
+    val input: Row = Row((100, "test", 0.123))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    assert(input.getStruct(0) == convertedBack.getStruct(0))
+  }

   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {