Commit 30c4774f authored by Wenchen Fan, committed by Cheng Lian

[SPARK-15657][SQL] RowEncoder should validate the data type of input object

## What changes were proposed in this pull request?

This PR improves the error handling of `RowEncoder`. When we create a `RowEncoder` with a given schema, we should validate the data type of the input object, e.g. throw an exception when a field holds a boolean but is declared as a string column.
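For example (a minimal sketch mirroring the new `RowEncoderSuite` tests below):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType().add("a", StringType)
val encoder = RowEncoder(schema)
// The column is declared as a string but the value is a boolean, so toRow
// now fails fast instead of producing a corrupt InternalRow:
// java.lang.RuntimeException: java.lang.Boolean is not a valid external
// type for schema of string
encoder.toRow(Row(true))
```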

This PR also removes support for using `Product` as a valid external type for struct type. That support was added in https://github.com/apache/spark/pull/9712, but it is incomplete: nested products and products inside arrays, for example, do not work. However, we never officially supported this feature, and I think it is OK to ban it.
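Concretely, code like the removed `encode/decode: Product` test below used to pass a tuple where a nested `Row` is expected; after this change it fails with a clear error instead (same imports as the sketch above):

```scala
val schema = new StructType().add("structAsProduct",
  new StructType().add("int", IntegerType).add("string", StringType))
// Previously accepted (the tuple was silently treated as a struct); now throws:
// "scala.Tuple2 is not a valid external type for schema of struct<...>"
RowEncoder(schema).toRow(Row((100, "test")))
```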

## How was this patch tested?

New tests in `RowEncoderSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13401 from cloud-fan/bug.
parent 8a911051
@@ -304,15 +304,7 @@ trait Row extends Serializable {
    *
    * @throws ClassCastException when data type does not match.
    */
-  def getStruct(i: Int): Row = {
-    // Product and Row both are recognized as StructType in a Row
-    val t = get(i)
-    if (t.isInstanceOf[Product]) {
-      Row.fromTuple(t.asInstanceOf[Product])
-    } else {
-      t.asInstanceOf[Row]
-    }
-  }
+  def getStruct(i: Int): Row = getAs[Row](i)

   /**
    * Returns the value at position i.
...
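With `Product` no longer a legal external representation of a struct, `getStruct` loses its tuple-unwrapping branch and becomes a plain typed accessor:

```scala
import org.apache.spark.sql.Row

val row = Row(Row(1, "a"))
row.getStruct(0)  // Row(1, "a") -- now simply getAs[Row](0)
// Row((1, "a")).getStruct(0) previously went through Row.fromTuple;
// it now throws ClassCastException, per the documented contract.
```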
@@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
- *   StructType -> org.apache.spark.sql.Row or Product
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 object RowEncoder {
@@ -121,11 +121,15 @@ object RowEncoder {
     case t @ ArrayType(et, _) => et match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+        // TODO: validate input type for primitive array.
         NewInstance(
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-      case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
+      case _ => MapObjects(
+        element => serializerFor(ValidateExternalType(element, et), et),
+        inputObject,
+        ObjectType(classOf[Object]))
     }

     case t @ MapType(kt, vt, valueNullable) =>
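For non-primitive element types, the serializer now wraps every element in `ValidateExternalType` before converting it. A conceptual Scala equivalent of the generated per-element loop, with `isValidExternalType` and `convert` as hypothetical stand-ins for the generated check and for `serializerFor`:

```scala
def serializeArray(input: Seq[AnyRef], et: DataType): Seq[Any] =
  input.map { element =>
    // ValidateExternalType(element, et): fail fast on a mistyped element
    if (!isValidExternalType(element, et)) {
      throw new RuntimeException(s"${element.getClass.getName} is not a " +
        s"valid external type for schema of ${et.simpleString}")
    }
    convert(element, et) // what serializerFor(_, et) compiles to
  }
```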
@@ -151,8 +155,9 @@ object RowEncoder {
     case StructType(fields) =>
       val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
         val fieldValue = serializerFor(
-          GetExternalRowField(
-            inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+          ValidateExternalType(
+            GetExternalRowField(inputObject, index, field.name),
+            field.dataType),
           field.dataType)
         val convertedField = if (field.nullable) {
           If(
@@ -183,7 +188,7 @@ object RowEncoder {
    *   can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
    *   `org.apache.spark.sql.types.Decimal`.
    */
-  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+  def externalDataTypeForInput(dt: DataType): DataType = dt match {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
@@ -192,7 +197,7 @@ object RowEncoder {
     case _ => externalDataTypeFor(dt)
   }

-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
...
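Both methods drop their `private` modifier because `ValidateExternalType` (added in the expressions file below) derives its `dataType` from `RowEncoder.externalDataTypeForInput`. The mapping itself is unchanged, e.g.:

```scala
// Illustrative results (a sketch of existing behavior, not new logic):
RowEncoder.externalDataTypeForInput(DecimalType(38, 18))
// => ObjectType(classOf[java.lang.Object])  // Decimal or java/scala BigDecimal
RowEncoder.externalDataTypeFor(TimestampType)
// => ObjectType(classOf[java.sql.Timestamp])
```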
@@ -26,6 +26,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -692,22 +693,17 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
-    fieldName: String,
-    dataType: DataType) extends UnaryExpression with NonSQLExpression {
+    fieldName: String) extends UnaryExpression with NonSQLExpression {

   override def nullable: Boolean = false

+  override def dataType: DataType = ObjectType(classOf[Object])
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val row = child.genCode(ctx)
-    val getField = dataType match {
-      case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
-      case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
-    }
     val code = s"""
       ${row.code}
@@ -720,8 +716,55 @@ case class GetExternalRowField(
           "cannot be null.");
       }
-      final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+      final Object ${ev.value} = ${row.value}.get($index);
     """
     ev.copy(code = code, isNull = "false")
   }
 }
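`GetExternalRowField` no longer carries a `dataType` parameter: it always surfaces the field as a boxed `Object`, and the typed cast (including the old `getStruct` special case for nested rows) moves into `ValidateExternalType`. An interpreted-mode sketch of what the generated code does for one field; the null-check message is paraphrased from the context lines above:

```scala
def getExternalRowField(row: Row, index: Int, fieldName: String): AnyRef = {
  if (row.isNullAt(index)) {
    // mirrors the generated null check ending in "cannot be null."
    throw new RuntimeException(
      s"The ${index}th field '$fieldName' of input row cannot be null.")
  }
  row.get(index).asInstanceOf[AnyRef] // always an untyped Object now
}
```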
+/**
+ * Validates the actual data type of input expression at runtime.  If it doesn't match the
+ * expectation, throw an exception.
+ */
+case class ValidateExternalType(child: Expression, expected: DataType)
+  extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
+
+  override def nullable: Boolean = child.nullable
+
+  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val input = child.genCode(ctx)
+    val obj = input.value
+
+    val typeCheck = expected match {
+      case _: DecimalType =>
+        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
+          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+      case _: ArrayType =>
+        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+      case _ =>
+        s"$obj instanceof ${ctx.boxedType(dataType)}"
+    }
+
+    val code = s"""
+      ${input.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${input.isNull}) {
+        if ($typeCheck) {
+          ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+        } else {
+          throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
+            "external type for schema of ${expected.simpleString}");
+        }
+      }
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+}
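In plain Scala, the type check that `doGenCode` compiles is roughly the following (a hedged sketch; `boxedClassFor` is a hypothetical stand-in for `ctx.boxedType` applied to the external data type):

```scala
def isValidExternalType(obj: AnyRef, expected: DataType): Boolean = expected match {
  case _: DecimalType =>
    // decimal columns accept all three external decimal representations
    obj.isInstanceOf[java.math.BigDecimal] ||
      obj.isInstanceOf[scala.math.BigDecimal] ||
      obj.isInstanceOf[Decimal]
  case _: ArrayType =>
    // arrays may arrive as a Scala Seq or a plain Java array
    obj.isInstanceOf[Seq[_]] || obj.getClass.isArray
  case _ =>
    // e.g. java.lang.Integer for IntegerType, java.lang.String for StringType
    boxedClassFor(expected).isInstance(obj)
}
```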
@@ -127,22 +127,6 @@ class RowEncoderSuite extends SparkFunSuite {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))

-  test(s"encode/decode: Product") {
-    val schema = new StructType()
-      .add("structAsProduct",
-        new StructType()
-          .add("int", IntegerType)
-          .add("string", StringType)
-          .add("double", DoubleType))
-
-    val encoder = RowEncoder(schema).resolveAndBind()
-
-    val input: Row = Row((100, "test", 0.123))
-    val row = encoder.toRow(input)
-    val convertedBack = encoder.fromRow(row)
-    assert(input.getStruct(0) == convertedBack.getStruct(0))
-  }
-
   test("encode/decode decimal type") {
     val schema = new StructType()
       .add("int", IntegerType)
@@ -232,6 +216,37 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(e.getMessage.contains("top level row object"))
   }

+  test("RowEncoder should validate external type") {
+    val e1 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", IntegerType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1.toShort))
+    }
+    assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))
+
+    val e2 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", StringType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1))
+    }
+    assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
+
+    val e3 = intercept[RuntimeException] {
+      val schema = new StructType().add("a",
+        new StructType().add("b", IntegerType).add("c", StringType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1 -> "a"))
+    }
+    assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
+
+    val e4 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", ArrayType(TimestampType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(Array("a")))
+    }
+    assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()
...