From 295747e59739ee8a697ac3eba485d3439e4a04c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Tue, 4 Apr 2017 16:38:32 -0700 Subject: [PATCH] [SPARK-19716][SQL] support by-name resolution for struct type elements in array ## What changes were proposed in this pull request? Previously when we construct deserializer expression for array type, we will first cast the corresponding field to expected array type and then apply `MapObjects`. However, by doing that, we lose the opportunity to do by-name resolution for struct type inside array type. In this PR, I introduce a `UnresolvedMapObjects` to hold the lambda function and the input array expression. Then during analysis, after the input array expression is resolved, we get the actual array element type and apply by-name resolution. Then we don't need to add `Cast` for array type when constructing the deserializer expression, as the element type is determined later at analyzer. ## How was this patch tested? new regression test Author: Wenchen Fan <wenchen@databricks.com> Closes #17398 from cloud-fan/dataset. --- .../spark/sql/catalyst/ScalaReflection.scala | 66 +++++++++++-------- .../sql/catalyst/analysis/Analyzer.scala | 19 +++++- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/objects/objects.scala | 32 +++++++-- .../encoders/EncoderResolutionSuite.scala | 52 +++++++++++++++ .../sql/expressions/ReduceAggregator.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 9 +++ 7 files changed, 141 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index da37eb00dc..206ae2f0e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -92,7 +92,7 @@ object ScalaReflection extends ScalaReflection { * Array[T]. Special handling is performed for primitive types to map them back to their raw * JVM form instead of the Scala Array that handles auto boxing. */ - private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized { + private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized { val cls = tpe match { case t if t <:< definitions.IntTpe => classOf[Array[Int]] case t if t <:< definitions.LongTpe => classOf[Array[Long]] @@ -178,15 +178,17 @@ object ScalaReflection extends ScalaReflection { * is [a: int, b: long], then we will hit runtime error and say that we can't construct class * `Data` with int and long, because we lost the information that `b` should be a string. * - * This method help us "remember" the required data type by adding a `UpCast`. Note that we - * don't need to cast struct type because there must be `UnresolvedExtractValue` or - * `GetStructField` wrapping it, thus we only need to handle leaf type. + * This method help us "remember" the required data type by adding a `UpCast`. Note that we + * only need to do this for leaf nodes. */ def upCastToExpectedType( expr: Expression, expected: DataType, walkedTypePath: Seq[String]): Expression = expected match { case _: StructType => expr + case _: ArrayType => expr + // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and + // it's not trivial to support by-name resolution for StructType inside MapType. case _ => UpCast(expr, expected, walkedTypePath) } @@ -265,42 +267,48 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t + val Schema(_, elementNullable) = schemaFor(elementType) + val className = getClassNameFromType(elementType) + val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - // TODO: add runtime null check for primitive array - val primitiveMethod = elementType match { - case t if t <:< definitions.IntTpe => Some("toIntArray") - case t if t <:< definitions.LongTpe => Some("toLongArray") - case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") - case t if t <:< definitions.FloatTpe => Some("toFloatArray") - case t if t <:< definitions.ShortTpe => Some("toShortArray") - case t if t <:< definitions.ByteTpe => Some("toByteArray") - case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") - case _ => None + val mapFunction: Expression => Expression = p => { + val converter = deserializerFor(elementType, Some(p), newTypePath) + if (elementNullable) { + converter + } else { + AssertNotNull(converter, newTypePath) + } } - primitiveMethod.map { method => - Invoke(getPath, method, arrayClassFor(elementType)) - }.getOrElse { - val className = getClassNameFromType(elementType) - val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath - Invoke( - MapObjects( - p => deserializerFor(elementType, Some(p), newTypePath), - getPath, - schemaFor(elementType).dataType), - "array", - arrayClassFor(elementType)) + val arrayData = UnresolvedMapObjects(mapFunction, getPath) + val arrayCls = arrayClassFor(elementType) + + if (elementNullable) { + Invoke(arrayData, "array", arrayCls) + } else { + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => "toIntArray" + case t if t <:< definitions.LongTpe => "toLongArray" + case t if t <:< definitions.DoubleTpe => "toDoubleArray" + case t if t <:< definitions.FloatTpe => "toFloatArray" + case t if t <:< definitions.ShortTpe => "toShortArray" + case t if t <:< definitions.ByteTpe => "toByteArray" + case t if t <:< definitions.BooleanTpe => "toBooleanArray" + case other => throw new IllegalStateException("expect primitive array element type " + + "but got " + other) + } + Invoke(arrayData, primitiveMethod, arrayCls) } case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) + val Schema(_, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = p => { val converter = deserializerFor(elementType, Some(p), newTypePath) - if (nullable) { + if (elementNullable) { converter } else { AssertNotNull(converter, newTypePath) @@ -312,7 +320,7 @@ object ScalaReflection extends ScalaReflection { case NoSymbol => classOf[Seq[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } - MapObjects(mapFunction, getPath, dataType, Some(cls)) + UnresolvedMapObjects(mapFunction, getPath, Some(cls)) case t if t <:< localTypeOf[Map[_, _]] => // TODO: add walked type path for map diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2d53d2424a..c698ca6a83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.expressions.objects.{MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ @@ -2227,8 +2227,21 @@ class Analyzer( validateTopLevelTupleFields(deserializer, inputs) val resolved = resolveExpression( deserializer, LocalRelation(inputs), throws = true) - validateNestedTupleFields(resolved) - resolved + val result = resolved transformDown { + case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => + inputData.dataType match { + case ArrayType(et, _) => + val expr = MapObjects(func, inputData, et, cls) transformUp { + case UnresolvedExtractValue(child, fieldName) if child.resolved => + ExtractValue(child, fieldName, resolver) + } + expr + case other => + throw new AnalysisException("need an array field but got " + other.simpleString) + } + } + validateNestedTupleFields(result) + result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index de1594d119..ef88cfb543 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -68,7 +68,7 @@ object ExtractValue { case StructType(_) => s"Field name should be String Literal, but it's $extraction" case other => - s"Can't extract value from $child" + s"Can't extract value from $child: need struct type but got ${other.simpleString}" } throw new AnalysisException(errorMsg) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bb584f7d08..00e2ac91e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -448,6 +448,17 @@ object MapObjects { } } +case class UnresolvedMapObjects( + function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + /** * Applies the given expression to every element of a collection of items, returning the result * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda @@ -581,17 +592,24 @@ case class MapObjects private( // collection val collObjectName = s"${cls.getName}$$.MODULE$$" val getBuilderVar = s"$collObjectName.newBuilder()" - - (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength);""", + ( + s""" + ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; + $builderValue.sizeHint($dataLength); + """, genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${cls.getName}) $builderValue.result();") + s"(${cls.getName}) $builderValue.result();" + ) case None => // array - (s"""$convertedType[] $convertedArray = null; - $convertedArray = $arrayConstructor;""", + ( + s""" + $convertedType[] $convertedArray = null; + $convertedArray = $arrayConstructor; + """, genValue => s"$convertedArray[$loopIndex] = $genValue;", - s"new ${classOf[GenericArrayData].getName}($convertedArray);") + s"new ${classOf[GenericArrayData].getName}($convertedArray);" + ) } val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 802397d50e..e5a3e1fd37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -33,6 +33,10 @@ case class StringIntClass(a: String, b: Int) case class ComplexClass(a: Long, b: StringLongClass) +case class ArrayClass(arr: Seq[StringIntClass]) + +case class NestedArrayClass(nestedArr: Array[ArrayClass]) + class EncoderResolutionSuite extends PlanTest { private val str = UTF8String.fromString("hello") @@ -62,6 +66,54 @@ class EncoderResolutionSuite extends PlanTest { encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } + test("real type doesn't match encoder schema but they are compatible: array") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val array = new GenericArrayData(Array(InternalRow(1, 2, 3))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(array)) + } + + test("real type doesn't match encoder schema but they are compatible: nested array") { + val encoder = ExpressionEncoder[NestedArrayClass] + val et = new StructType().add("arr", ArrayType( + new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val attrs = Seq('nestedArr.array(et)) + val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3))) + val outerArr = new GenericArrayData(Array(InternalRow(innerArr))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr)) + } + + test("the real type is not compatible with encoder schema: non-array field") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.int) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + test("the real type is not compatible with encoder schema: array element type") { + val encoder = ExpressionEncoder[ArrayClass] + val attrs = Seq('arr.array(new StructType().add("c", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + + test("the real type is not compatible with encoder schema: nested array element type") { + val encoder = ExpressionEncoder[NestedArrayClass] + + withClue("inner element is not array") { + val attrs = Seq('nestedArr.array(new StructType().add("arr", "int"))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "need an array field but got int") + } + + withClue("nested array element type is not compatible") { + val attrs = Seq('nestedArr.array(new StructType() + .add("arr", ArrayType(new StructType().add("c", "int"))))) + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == + "No such struct field a in c") + } + } + test("nullability of array type element should not fail analysis") { val encoder = ExpressionEncoder[Seq[Int]] val attrs = 'a.array(IntegerType) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala index 174378304d..e266ae55cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) extends Aggregator[T, (Boolean, T), T] { - private val encoder = implicitly[Encoder[T]] + @transient private val encoder = implicitly[Encoder[T]] override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 68e071a1a6..5b5cd28ad0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -142,6 +142,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) } + test("as seq of case class - reorder fields by name") { + val df = spark.range(3).select(array(struct($"id".cast("int").as("b"), lit("a").as("a")))) + val ds = df.as[Seq[ClassData]] + assert(ds.collect() === Array( + Seq(ClassData("a", 0)), + Seq(ClassData("a", 1)), + Seq(ClassData("a", 2)))) + } + test("map") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() checkDataset( -- GitLab