From 3fa3d121dfec60f9768d3859e8450ee482b2d4e8 Mon Sep 17 00:00:00 2001
From: Michael Armbrust <michael@databricks.com>
Date: Tue, 24 Mar 2015 12:28:01 -0700
Subject: [PATCH] [SPARK-6054][SQL] Fix transformations of TreeNodes that hold
 StructTypes

Due to a recent change that made `StructType` a `Seq` we started inadvertently turning `StructType`s into generic `Traversable` when attempting nested tree transformations.  In this PR we explicitly avoid descending into `DataType`s to avoid this bug.

Author: Michael Armbrust <michael@databricks.com>

Closes #5157 from marmbrus/udfFix and squashes the following commits:

26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes
---
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  2 ++
 .../spark/sql/catalyst/trees/TreeNode.scala   | 20 ++++++++++++++++---
 .../scala/org/apache/spark/sql/UDFSuite.scala |  6 ++++++
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 48191f3119..bd9291e9ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case e: Expression => transformExpressionDown(e)
       case Some(e: Expression) => Some(transformExpressionDown(e))
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map {
         case e: Expression => transformExpressionDown(e)
         case other => other
@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
       case e: Expression => transformExpressionUp(e)
       case Some(e: Expression) => Some(transformExpressionUp(e))
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map {
         case e: Expression => transformExpressionUp(e)
         case other => other
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f84ffe4e17..0ae9f6b296 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.types.DataType
 
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
 private class MutableInt(var i: Int)
@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
           Some(arg)
         }
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if children contains arg =>
           val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
           Some(arg)
         }
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if children contains arg =>
           val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
    * @param newArgs the new product arguments.
    */
   def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+    val defaultCtor =
+      getClass.getConstructors
+        .find(_.getParameterTypes.size != 0)
+        .headOption
+        .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+
     try {
       CurrentOrigin.withOrigin(origin) {
         // Skip no-arg constructors that are just there for kryo.
-        val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
         if (otherCopyArgs.isEmpty) {
           defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
         } else {
@@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
     } catch {
       case e: java.lang.IllegalArgumentException =>
         throw new TreeNodeException(
-          this, s"Failed to copy node.  Is otherCopyArgs specified correctly for $nodeName? "
-            + s"Exception message: ${e.getMessage}.")
+          this,
+          s"""
+             |Failed to copy node.
+             |Is otherCopyArgs specified correctly for $nodeName.
+             |Exception message: ${e.getMessage}
+             |ctor: $defaultCtor?
+             |args: ${newArgs.mkString(", ")}
+           """.stripMargin)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index be105c6e83..d615542ab5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
         .select($"ret.f1").head().getString(0)
     assert(result === "test")
   }
+
+  test("udf that is transformed") {
+    udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+    // 1 + 1 is constant folded causing a transformation.
+    assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+  }
 }
-- 
GitLab