diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b108017c4c482720419c91a8ddba6f1b447769b5..e67f2be6d237ea554756947fca8dfa5d2fa62631 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -24,6 +24,15 @@ import org.apache.spark.sql.types.{DataType, StructType}
 abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
   self: PlanType =>
 
+  /**
+   * Override [[TreeNode.apply]] so we can return a narrower type.
+   *
+   * Note that this cannot return BaseType because a logical plan node's innerChildren
+   * can contain physical plan nodes, e.g. the in-memory relation logical plan node
+   * keeps a reference to the physical plan whose results it caches.
+   */
+  override def apply(number: Int): QueryPlan[_] = super.apply(number).asInstanceOf[QueryPlan[_]]
+
   def output: Seq[Attribute]
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 0f33e1dae944e6ffabe8276ef1775e9e37f6092b..b4358c2ef2e620df1b39e0a8b1149191840f96d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -412,7 +412,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)])
     s"CTE $cteAliases"
   }
 
-  override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2)
+  override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
 }
 
 case class WithWindowDefinition(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index ea8d8fef7bdf19753671733f14c5453130a1be20..670fa2bc8de8eeef84b2c91979759d40fd58722a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.trees
 import java.util.UUID
 
 import scala.collection.Map
-import scala.collection.mutable.Stack
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -28,12 +27,9 @@ import org.json4s.JsonAST._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkContext
-import org.apache.spark.rdd.{EmptyRDD, RDD}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.ScalaReflection._
-import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -493,7 +489,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   /**
    * Returns a string representation of the nodes in this tree, where each operator is numbered.
-   * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
+   * The numbers can be used with [[TreeNode.apply]] to easily access specific subtrees.
+   *
+   * The numbers are based on a depth-first traversal of the tree, with innerChildren
+   * traversed before children.
    */
   def numberedTreeString: String =
     treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
 
@@ -501,17 +500,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   /**
    * Returns the tree node at the specified number.
    * Numbers for each node can be found in the [[numberedTreeString]].
+   *
+   * Note that this cannot return BaseType because a logical plan node's innerChildren
+   * can contain physical plan nodes, e.g. the in-memory relation logical plan node
+   * keeps a reference to the physical plan whose results it caches.
    */
-  def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number))
+  def apply(number: Int): TreeNode[_] = getNodeNumbered(new MutableInt(number)).orNull
 
-  protected def getNodeNumbered(number: MutableInt): BaseType = {
+  private def getNodeNumbered(number: MutableInt): Option[TreeNode[_]] = {
     if (number.i < 0) {
-      null.asInstanceOf[BaseType]
+      None
     } else if (number.i == 0) {
-      this
+      Some(this)
     } else {
       number.i -= 1
-      children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType])
+      // Note that this traversal order must be the same as numberedTreeString.
+      innerChildren.map(_.getNodeNumbered(number)).find(_ != None).getOrElse {
+        children.map(_.getNodeNumbered(number)).find(_ != None).flatten
+      }
     }
   }
 
@@ -527,6 +533,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
    * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and
    * `lastChildren` for the root node should be empty.
+   *
+   * Note that this traversal (numbering) order must be the same as [[getNodeNumbered]].
    */
   def generateTreeString(
       depth: Int,
@@ -533,20 +541,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       lastChildren: Seq[Boolean],
       builder: StringBuilder,
       verbose: Boolean,
       prefix: String = ""): StringBuilder = {
+
     if (depth > 0) {
       lastChildren.init.foreach { isLast =>
-        val prefixFragment = if (isLast) "   " else ":  "
-        builder.append(prefixFragment)
+        builder.append(if (isLast) "   " else ":  ")
       }
-
-      val branch = if (lastChildren.last) "+- " else ":- "
-      builder.append(branch)
+      builder.append(if (lastChildren.last) "+- " else ":- ")
     }
 
     builder.append(prefix)
-    val headline = if (verbose) verboseString else simpleString
-    builder.append(headline)
+    builder.append(if (verbose) verboseString else simpleString)
     builder.append("\n")
 
     if (innerChildren.nonEmpty) {
@@ -557,9 +562,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     }
 
     if (children.nonEmpty) {
-      children.init.foreach(
-        _.generateTreeString(depth + 1, lastChildren :+ false, builder, verbose, prefix))
-      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, verbose, prefix)
+      children.init.foreach(_.generateTreeString(
+        depth + 1, lastChildren :+ false, builder, verbose, prefix))
+      children.last.generateTreeString(
+        depth + 1, lastChildren :+ true, builder, verbose, prefix)
     }
 
     builder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 56bd5c1891e8d03df7d328cc83063240dfa58cda..03cc04659bd55526400b62fa6d94bb583d52900b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
@@ -64,7 +63,7 @@ case class InMemoryRelation(
     val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
   extends logical.LeafNode with MultiInstanceRelation {
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
+  override protected def innerChildren: Seq[SparkPlan] = Seq(child)
 
   override def producedAttributes: AttributeSet = outputSet
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5a4b1cfe95e270e19fa60237cb9085608431a24d..2ef8b18c046124e8eb4d68ae975206d4e3b2fcf5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -54,6 +54,24 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     t.createOrReplaceTempView("t")
   }
 
+  test("SPARK-18854 numberedTreeString for subquery") {
+    val df = sql("select * from range(10) where id not in " +
+      "(select id from range(2) union all select id from range(2))")
+
+    // The expected node names, in depth-first traversal order of the plan tree
+    val dfs = Seq("Project", "Filter", "Union", "Project", "Range", "Project", "Range", "Range")
+    val numbered = df.queryExecution.analyzed.numberedTreeString.split("\n")
+
+    // There should be 8 plan nodes in total
+    assert(numbered.size == dfs.size)
+
+    for (i <- dfs.indices) {
+      val node = df.queryExecution.analyzed(i)
+      assert(node.nodeName == dfs(i))
+      assert(numbered(i).contains(node.nodeName))
+    }
+  }
+
   test("rdd deserialization does not crash [SPARK-15791]") {
     sql("select (select 1 as b) as b").rdd.count()
   }
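For reviewers, a rough sketch of how the fixed behavior can be exercised interactively. This mirrors the SubquerySuite test above; the `spark` session name and the printed output shape are illustrative assumptions, not part of the patch:

```scala
// Illustrative only: assumes a SparkSession bound to `spark`, as in spark-shell.
val df = spark.sql(
  "select * from range(10) where id not in " +
    "(select id from range(2) union all select id from range(2))")

// numberedTreeString prefixes each node with its depth-first number, e.g.
// 00 Project [id#0L]
// 01 +- Filter ...
println(df.queryExecution.analyzed.numberedTreeString)

// After this patch, apply(i) follows the same depth-first order, visiting
// innerChildren (here, the subquery plan) before children, so index 2 is the
// Union inside the subquery rather than the Filter's Range child.
val node = df.queryExecution.analyzed(2)
assert(node.nodeName == "Union")
```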