Skip to content
Snippets Groups Projects
Commit 55cc1c99 authored by Wenchen Fan's avatar Wenchen Fan Committed by Cheng Lian
Browse files

[SPARK-14139][SQL] RowEncoder should preserve schema nullability

## What changes were proposed in this pull request?

The problem is: in `RowEncoder`, we use `Invoke` to get the field of an external row, which loses the nullability information. This PR creates a `GetExternalRowField` expression so that we can preserve the nullability info.

TODO: simplify the null handling logic in `RowEncoder` to remove the many `if` branches, in a follow-up PR.

## How was this patch tested?

new tests in `RowEncoderSuite`

Note that this PR takes over https://github.com/apache/spark/pull/11980, with a little simplification, so all credit should go to koertkuipers.

Author: Wenchen Fan <wenchen@databricks.com>
Author: Koert Kuipers <koert@tresata.com>

Closes #12364 from cloud-fan/nullable.
parent 77361a43
No related branches found
No related tags found
No related merge requests found
......@@ -35,9 +35,8 @@ import org.apache.spark.unsafe.types.UTF8String
object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
// We use an If expression to wrap extractorsFor result of StructType
val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
val serializer = serializerFor(inputObject, schema)
val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
......@@ -130,21 +129,28 @@ object RowEncoder {
case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
val method = if (f.dataType.isInstanceOf[StructType]) {
"getStruct"
val fieldValue = serializerFor(
GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
f.dataType
)
if (f.nullable) {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
fieldValue
)
} else {
"get"
fieldValue
}
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
serializerFor(
Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
If(IsNull(inputObject),
Literal.create(null, inputType),
CreateStruct(convertedFields))
if (inputObject.nullable) {
If(IsNull(inputObject),
Literal.create(null, inputType),
CreateStruct(convertedFields))
} else {
CreateStruct(convertedFields)
}
}
/**
......
......@@ -688,3 +688,45 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
ev.copy(code = code, isNull = "false", value = childGen.value)
}
}
/**
 * Returns the value of the field at index `index` from the external row `child`.
 * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s.
 *
 * The input row and the requested field are both required to be non-null. The generated code
 * checks this at runtime and throws a `RuntimeException` if either the row or the field is
 * null, which is why this expression always reports itself as non-nullable.
 *
 * Only code-generated evaluation is supported; calling `eval` throws
 * `UnsupportedOperationException`.
 *
 * @param child expression that produces the external row to read from
 * @param index ordinal of the field to read from that row
 * @param dataType type of the value returned by this expression; when it is the
 *                 [[ObjectType]] wrapping `classOf[Row]` the field is read with `getStruct`,
 *                 otherwise with `get` followed by a cast
 */
case class GetExternalRowField(
child: Expression,
index: Int,
dataType: DataType) extends UnaryExpression with NonSQLExpression {
// Never null: the generated code throws instead of producing a null result.
override def nullable: Boolean = false
// Interpreted evaluation is intentionally not implemented.
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
// Emits Java code that evaluates `child`, fails fast on a null row or a null field, and then
// assigns the field value to `ev.value`.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val row = child.genCode(ctx)
// Java expression that reads the field: nested rows go through `getStruct`; every other
// type goes through the untyped `get` and is cast to `ctx.boxedType(dataType)`.
val getField = dataType match {
case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
}
// Order matters in the template below: the child's code runs first, then the row null check,
// then the field null check, and only then is the field actually read.
val code = s"""
${row.code}
if (${row.isNull}) {
throw new RuntimeException("The input external row cannot be null.");
}
if (${row.value}.isNullAt($index)) {
throw new RuntimeException("The ${index}th field of input row cannot be null.");
}
final ${ctx.javaType(dataType)} ${ev.value} = $getField;
"""
// `isNull` is the constant "false" because both null cases have already thrown above.
ev.copy(code = code, isNull = "false")
}
}
......@@ -160,6 +160,14 @@ class RowEncoderSuite extends SparkFunSuite {
.compareTo(convertedBack.getDecimal(3)) == 0)
}
test("RowEncoder should preserve schema nullability") {
  // A schema with a single int column that is declared non-nullable.
  val nonNullableSchema = new StructType().add("int", IntegerType, nullable = false)
  val serializerExprs = RowEncoder(nonNullableSchema).serializer

  // Exactly one serializer expression is expected; check that before looking at its head.
  assert(serializerExprs.length == 1)

  // The expression must keep both the declared type and the declared non-nullability.
  val intFieldSerializer = serializerExprs.head
  assert(intFieldSerializer.dataType == IntegerType)
  assert(!intFieldSerializer.nullable)
}
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema)
......
......@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.postfixOps
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
......@@ -658,6 +658,22 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val dataset = Seq(1, 2, 3).toDS()
checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
}
test("runtime null check for RowEncoder") {
  // Column "i" is declared non-nullable, so encoding a row that holds null must fail at runtime.
  val nonNullableSchema = new StructType().add("i", IntegerType, nullable = false)

  // Every value divisible by 5 becomes a row containing null; all others carry the value.
  // NOTE(review): the non-null branch places the range value `l` into a column declared as
  // IntegerType - confirm its runtime type is acceptable there. This test only asserts on
  // the failure caused by the null rows.
  val rowsWithNulls = sqlContext.range(10).map(l =>
    if (l % 5 == 0) Row(null) else Row(l)
  )(RowEncoder(nonNullableSchema))

  val failure = intercept[Exception] {
    rowsWithNulls.collect()
  }
  val failureMessage = failure.getMessage
  assert(failureMessage.contains("The 0th field of input row cannot be null"))
}
}
// Two-field case class whose members use the tuple-style names `_1` and `_2`.
// NOTE(review): presumably a test fixture standing in for a (String, Int) tuple - confirm
// against its usages in the suite.
case class OtherTuple(_1: String, _2: Int)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment