Skip to content
Snippets Groups Projects
Commit 426004a9 authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Michael Armbrust
Browse files

[SPARK-11908][SQL] Add NullType support to RowEncoder

JIRA: https://issues.apache.org/jira/browse/SPARK-11908

We should add NullType support to RowEncoder.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #9891 from viirya/rowencoder-nulltype.
parent ff442bbc
No related branches found
No related tags found
No related merge requests found
......@@ -48,7 +48,7 @@ object RowEncoder {
private def extractorsFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => inputObject
case udt: UserDefinedType[_] =>
......@@ -143,6 +143,7 @@ object RowEncoder {
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
case udt: UserDefinedType[_] => ObjectType(udt.userClass)
case _: NullType => ObjectType(classOf[java.lang.Object])
}
private def constructorFor(schema: StructType): Expression = {
......@@ -158,7 +159,7 @@ object RowEncoder {
}
private def constructorFor(input: Expression): Expression = input.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input
case udt: UserDefinedType[_] =>
......
......@@ -369,6 +369,9 @@ case class MapObjects(
private lazy val completeFunction = function(loopAttribute)
private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
case NullType =>
val nullTypeClassName = NullType.getClass.getName + ".MODULE$"
(i: String) => s".get($i, $nullTypeClassName)"
case IntegerType => (i: String) => s".getInt($i)"
case LongType => (i: String) => s".getLong($i)"
case FloatType => (i: String) => s".getFloat($i)"
......
......@@ -80,11 +80,13 @@ class RowEncoderSuite extends SparkFunSuite {
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
private val arrayOfNull = ArrayType(NullType)
private val mapOfString = MapType(StringType, StringType)
private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
encodeDecodeTest(
new StructType()
.add("null", NullType)
.add("boolean", BooleanType)
.add("byte", ByteType)
.add("short", ShortType)
......@@ -101,6 +103,7 @@ class RowEncoderSuite extends SparkFunSuite {
encodeDecodeTest(
new StructType()
.add("arrayOfNull", arrayOfNull)
.add("arrayOfString", arrayOfString)
.add("arrayOfArrayOfString", ArrayType(arrayOfString))
.add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment