Skip to content
Snippets Groups Projects
Commit 3fa3d121 authored by Michael Armbrust's avatar Michael Armbrust
Browse files

[SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes

Due to a recent change that made `StructType` a `Seq` we started inadvertently turning `StructType`s into generic `Traversable` when attempting nested tree transformations.  In this PR we explicitly avoid descending into `DataType`s to avoid this bug.

Author: Michael Armbrust <michael@databricks.com>

Closes #5157 from marmbrus/udfFix and squashes the following commits:

26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes
parent 26c6ce3d
No related branches found
No related tags found
No related merge requests found
...@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy ...@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionDown(e) case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e)) case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_,_] => m case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map { case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionDown(e) case e: Expression => transformExpressionDown(e)
case other => other case other => other
...@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy ...@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionUp(e) case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e)) case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_,_] => m case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map { case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionUp(e) case e: Expression => transformExpressionUp(e)
case other => other case other => other
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.trees package org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.types.DataType
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int) private class MutableInt(var i: Int)
...@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { ...@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg) Some(arg)
} }
case m: Map[_,_] => m case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map { case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg => case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule) val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
...@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { ...@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg) Some(arg)
} }
case m: Map[_,_] => m case m: Map[_,_] => m
case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map { case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg => case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule) val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
...@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { ...@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments. * @param newArgs the new product arguments.
*/ */
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
val defaultCtor =
getClass.getConstructors
.find(_.getParameterTypes.size != 0)
.headOption
.getOrElse(sys.error(s"No valid constructor for $nodeName"))
try { try {
CurrentOrigin.withOrigin(origin) { CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo. // Skip no-arg constructors that are just there for kryo.
val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) { if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else { } else {
...@@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { ...@@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} catch { } catch {
case e: java.lang.IllegalArgumentException => case e: java.lang.IllegalArgumentException =>
throw new TreeNodeException( throw new TreeNodeException(
this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " this,
+ s"Exception message: ${e.getMessage}.") s"""
|Failed to copy node.
|Is otherCopyArgs specified correctly for $nodeName.
|Exception message: ${e.getMessage}
|ctor: $defaultCtor?
|args: ${newArgs.mkString(", ")}
""".stripMargin)
} }
} }
......
...@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest { ...@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
.select($"ret.f1").head().getString(0) .select($"ret.f1").head().getString(0)
assert(result === "test") assert(result === "test")
} }
test("udf that is transformed") {
udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment