Commit 295747e5 authored by Wenchen Fan, committed by Cheng Lian

[SPARK-19716][SQL] support by-name resolution for struct type elements in array

## What changes were proposed in this pull request?

Previously, when we constructed the deserializer expression for an array type, we would first cast the corresponding field to the expected array type and then apply `MapObjects`.

However, by doing that we lose the opportunity to do by-name resolution for struct types inside array types. In this PR, I introduce an `UnresolvedMapObjects` expression to hold the lambda function and the input array expression. During analysis, after the input array expression is resolved, we get the actual array element type and apply by-name resolution. As a result, we no longer need to add a `Cast` for array types when constructing the deserializer expression, because the element type is determined later by the analyzer.
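
For illustration, this is roughly the user-facing behavior the change enables, paraphrasing the new `DatasetSuite` regression test below (a minimal sketch for spark-shell, assuming `spark.implicits._` and `org.apache.spark.sql.functions._` are in scope):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

case class ClassData(a: String, b: Int)

// The struct inside the array declares its fields as (b, a), i.e. in the
// opposite order of the case class ClassData(a, b).
val df = spark.range(3)
  .select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))

// With UnresolvedMapObjects, the array element struct is resolved by name once
// the input array type is known, so the reordered fields still map correctly.
val ds = df.as[Seq[ClassData]]
ds.collect()  // Array(Seq(ClassData("a", 0)), Seq(ClassData("a", 1)), Seq(ClassData("a", 2)))
```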

## How was this patch tested?

New regression tests in `EncoderResolutionSuite` and `DatasetSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17398 from cloud-fan/dataset.
parent 402bf2a5
@@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection {
* Array[T]. Special handling is performed for primitive types to map them back to their raw
* JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
val cls = tpe match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection {
* is [a: int, b: long], then we will hit runtime error and say that we can't construct class
* `Data` with int and long, because we lost the information that `b` should be a string.
*
* This method help us "remember" the required data type by adding a `UpCast`. Note that we
* don't need to cast struct type because there must be `UnresolvedExtractValue` or
* `GetStructField` wrapping it, thus we only need to handle leaf type.
* This method helps us "remember" the required data type by adding an `UpCast`. Note that we
* only need to do this for leaf nodes.
*/
def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
// TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
// it's not trivial to support by-name resolution for StructType inside MapType.
case _ => UpCast(expr, expected, walkedTypePath)
}
@@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(_, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
// TODO: add runtime null check for primitive array
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
case t if t <:< definitions.FloatTpe => Some("toFloatArray")
case t if t <:< definitions.ShortTpe => Some("toShortArray")
case t if t <:< definitions.ByteTpe => Some("toByteArray")
case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
case _ => None
val mapFunction: Expression => Expression = p => {
val converter = deserializerFor(elementType, Some(p), newTypePath)
if (elementNullable) {
converter
} else {
AssertNotNull(converter, newTypePath)
}
}
primitiveMethod.map { method =>
Invoke(getPath, method, arrayClassFor(elementType))
}.getOrElse {
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
p => deserializerFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
arrayClassFor(elementType))
val arrayData = UnresolvedMapObjects(mapFunction, getPath)
val arrayCls = arrayClassFor(elementType)
if (elementNullable) {
Invoke(arrayData, "array", arrayCls)
} else {
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => "toIntArray"
case t if t <:< definitions.LongTpe => "toLongArray"
case t if t <:< definitions.DoubleTpe => "toDoubleArray"
case t if t <:< definitions.FloatTpe => "toFloatArray"
case t if t <:< definitions.ShortTpe => "toShortArray"
case t if t <:< definitions.ByteTpe => "toByteArray"
case t if t <:< definitions.BooleanTpe => "toBooleanArray"
case other => throw new IllegalStateException("expect primitive array element type " +
"but got " + other)
}
Invoke(arrayData, primitiveMethod, arrayCls)
}
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
val Schema(_, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val mapFunction: Expression => Expression = p => {
val converter = deserializerFor(elementType, Some(p), newTypePath)
if (nullable) {
if (elementNullable) {
converter
} else {
AssertNotNull(converter, newTypePath)
@@ -312,7 +320,7 @@ object ScalaReflection extends ScalaReflection {
case NoSymbol => classOf[Seq[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
MapObjects(mapFunction, getPath, dataType, Some(cls))
UnresolvedMapObjects(mapFunction, getPath, Some(cls))
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
......
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects}
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
@@ -2227,8 +2227,21 @@ class Analyzer(
validateTopLevelTupleFields(deserializer, inputs)
val resolved = resolveExpression(
deserializer, LocalRelation(inputs), throws = true)
validateNestedTupleFields(resolved)
resolved
val result = resolved transformDown {
case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
inputData.dataType match {
case ArrayType(et, _) =>
val expr = MapObjects(func, inputData, et, cls) transformUp {
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
expr
case other =>
throw new AnalysisException("need an array field but got " + other.simpleString)
}
}
validateNestedTupleFields(result)
result
}
}
......
@@ -68,7 +68,7 @@ object ExtractValue {
case StructType(_) =>
s"Field name should be String Literal, but it's $extraction"
case other =>
s"Can't extract value from $child"
s"Can't extract value from $child: need struct type but got ${other.simpleString}"
}
throw new AnalysisException(errorMsg)
}
......
@@ -448,6 +448,17 @@ object MapObjects {
}
}
case class UnresolvedMapObjects(
function: Expression => Expression,
child: Expression,
customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
override lazy val resolved = false
override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
throw new UnsupportedOperationException("not resolved")
}
}
/**
* Applies the given expression to every element of a collection of items, returning the result
* as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
@@ -581,17 +592,24 @@ case class MapObjects private(
// collection
val collObjectName = s"${cls.getName}$$.MODULE$$"
val getBuilderVar = s"$collObjectName.newBuilder()"
(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);""",
(
s"""
${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);
""",
genValue => s"$builderValue.$$plus$$eq($genValue);",
s"(${cls.getName}) $builderValue.result();")
s"(${cls.getName}) $builderValue.result();"
)
case None =>
// array
(s"""$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;""",
(
s"""
$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;
""",
genValue => s"$convertedArray[$loopIndex] = $genValue;",
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
s"new ${classOf[GenericArrayData].getName}($convertedArray);"
)
}
val code = s"""
......
@@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int)
case class ComplexClass(a: Long, b: StringLongClass)
case class ArrayClass(arr: Seq[StringIntClass])
case class NestedArrayClass(nestedArr: Array[ArrayClass])
class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")
@@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest {
encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
}
test("real type doesn't match encoder schema but they are compatible: array") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
}
test("real type doesn't match encoder schema but they are compatible: nested array") {
val encoder = ExpressionEncoder[NestedArrayClass]
val et = new StructType().add("arr", ArrayType(
new StructType().add("a", "int").add("b", "int").add("c", "int")))
val attrs = Seq('nestedArr.array(et))
val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
}
test("the real type is not compatible with encoder schema: non-array field") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.int)
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"need an array field but got int")
}
test("the real type is not compatible with encoder schema: array element type") {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("c", "int")))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"No such struct field a in c")
}
test("the real type is not compatible with encoder schema: nested array element type") {
val encoder = ExpressionEncoder[NestedArrayClass]
withClue("inner element is not array") {
val attrs = Seq('nestedArr.array(new StructType().add("arr", "int")))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"need an array field but got int")
}
withClue("nested array element type is not compatible") {
val attrs = Seq('nestedArr.array(new StructType()
.add("arr", ArrayType(new StructType().add("c", "int")))))
assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"No such struct field a in c")
}
}
test("nullability of array type element should not fail analysis") {
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil
......
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
extends Aggregator[T, (Boolean, T), T] {
private val encoder = implicitly[Encoder[T]]
@transient private val encoder = implicitly[Encoder[T]]
override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
......
@@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2)))
}
test("as seq of case class - reorder fields by name") {
val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a"))))
val ds = df.as[Seq[ClassData]]
assert(ds.collect() === Array(
Seq(ClassData("a", 0)),
Seq(ClassData("a", 1)),
Seq(ClassData("a", 2))))
}
test("map") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDataset(
......