From 3fa3d121dfec60f9768d3859e8450ee482b2d4e8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust <michael@databricks.com> Date: Tue, 24 Mar 2015 12:28:01 -0700 Subject: [PATCH] [SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes Due to a recent change that made `StructType` a `Seq` we started inadvertently turning `StructType`s into generic `Traversable` when attempting nested tree transformations. In this PR we explicitly avoid descending into `DataType`s to avoid this bug. Author: Michael Armbrust <michael@databricks.com> Closes #5157 from marmbrus/udfFix and squashes the following commits: 26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes --- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 ++ .../spark/sql/catalyst/trees/TreeNode.scala | 20 ++++++++++++++++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 6 ++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 48191f3119..bd9291e9ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) case other => other @@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f84ffe4e17..0ae9f6b296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.types.DataType /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) @@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) @@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + val defaultCtor = + getClass.getConstructors + .find(_.getParameterTypes.size != 0) + .headOption + .getOrElse(sys.error(s"No valid constructor for $nodeName")) + try { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. - val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head if (otherCopyArgs.isEmpty) { defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] } else { @@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " - + s"Exception message: ${e.getMessage}.") + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |args: ${newArgs.mkString(", ")} + """.stripMargin) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index be105c6e83..d615542ab5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -50,4 +50,10 @@ class UDFSuite extends QueryTest { .select($"ret.f1").head().getString(0) assert(result === "test") } + + test("udf that is transformed") { + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } } -- GitLab