Commit b373a888 authored by Davies Liu, committed by Yin Huai

[SPARK-13415][SQL] Visualize subquery in SQL web UI

## What changes were proposed in this pull request?

This PR adds visualization for subqueries in the SQL web UI and also improves the explain output for subqueries, especially when they are used together with whole-stage codegen.

For example:
```python
>>> sqlContext.range(100).registerTempTable("range")
>>> sqlContext.sql("select id / (select sum(id) from range) from range where id > (select id from range limit 1)").explain(True)
== Parsed Logical Plan ==
'Project [unresolvedalias(('id / subquery#9), None)]
:  +- 'SubqueryAlias subquery#9
:     +- 'Project [unresolvedalias('sum('id), None)]
:        +- 'UnresolvedRelation `range`, None
+- 'Filter ('id > subquery#8)
   :  +- 'SubqueryAlias subquery#8
   :     +- 'GlobalLimit 1
   :        +- 'LocalLimit 1
   :           +- 'Project [unresolvedalias('id, None)]
   :              +- 'UnresolvedRelation `range`, None
   +- 'UnresolvedRelation `range`, None

== Analyzed Logical Plan ==
(id / scalarsubquery()): double
Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11]
:  +- SubqueryAlias subquery#9
:     +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L]
:        +- SubqueryAlias range
:           +- Range 0, 100, 1, 4, [id#0L]
+- Filter (id#0L > subquery#8)
   :  +- SubqueryAlias subquery#8
   :     +- GlobalLimit 1
   :        +- LocalLimit 1
   :           +- Project [id#0L]
   :              +- SubqueryAlias range
   :                 +- Range 0, 100, 1, 4, [id#0L]
   +- SubqueryAlias range
      +- Range 0, 100, 1, 4, [id#0L]

== Optimized Logical Plan ==
Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11]
:  +- SubqueryAlias subquery#9
:     +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L]
:        +- Range 0, 100, 1, 4, [id#0L]
+- Filter (id#0L > subquery#8)
   :  +- SubqueryAlias subquery#8
   :     +- GlobalLimit 1
   :        +- LocalLimit 1
   :           +- Project [id#0L]
   :              +- Range 0, 100, 1, 4, [id#0L]
   +- Range 0, 100, 1, 4, [id#0L]

== Physical Plan ==
WholeStageCodegen
:  +- Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11]
:     :  +- Subquery subquery#9
:     :     +- WholeStageCodegen
:     :        :  +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Final,isDistinct=false)], output=[sum(id)#10L])
:     :        :     +- INPUT
:     :        +- Exchange SinglePartition, None
:     :           +- WholeStageCodegen
:     :              :  +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Partial,isDistinct=false)], output=[sum#14L])
:     :              :     +- Range 0, 1, 4, 100, [id#0L]
:     +- Filter (id#0L > subquery#8)
:        :  +- Subquery subquery#8
:        :     +- CollectLimit 1
:        :        +- WholeStageCodegen
:        :           :  +- Project [id#0L]
:        :           :     +- Range 0, 1, 4, 100, [id#0L]
:        +- Range 0, 1, 4, 100, [id#0L]
```

The web UI looks like:

![subquery](https://cloud.githubusercontent.com/assets/40902/13377963/932bcbae-dda7-11e5-82f7-03c9be85d77c.png)

This PR also changes the tree structure of WholeStageCodegen to make it consistent with other operators. Before this change, both WholeStageCodegen and InputAdapter held references to the same plans; one reference could be updated without notifying the other, causing problems. This was discovered by #11403.
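As a rough sketch of that structural change (stand-in case classes only, not the real SparkPlan operators; only the shape of the trees is meant to mirror the diff below):

```scala
// Simplified stand-ins for physical plan nodes; everything except the names
// WholeStageCodegen and InputAdapter is illustrative only.
sealed trait Plan { def children: Seq[Plan] }
case class Scan(table: String) extends Plan { def children: Seq[Plan] = Nil }
case class Filter(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }
case class InputAdapter(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }

// Before: the codegened sub-tree and the input plans lived in two separate fields,
// so the same node could be reached (and rewritten) through either one.
case class OldWholeStageCodegen(plan: Plan, inputs: Seq[Plan]) extends Plan {
  def children: Seq[Plan] = inputs
}

// After: a single child, with InputAdapter marking where the generated code stops,
// so every node appears exactly once in the tree.
case class NewWholeStageCodegen(child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
}

object TreeShapeDemo extends App {
  val input = Scan("range")
  println(OldWholeStageCodegen(Filter(InputAdapter(input)), Seq(input))) // `input` held twice
  println(NewWholeStageCodegen(Filter(InputAdapter(input))))             // single ownership
}
```

With the single-child shape, SparkPlanInfo and the web UI graph builder can walk `children` (plus `subqueries`) uniformly, which is what the diff below relies on.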

## How was this patch tested?

Existing tests, plus manual tests with the example query, checking the explain output and the web UI.

Author: Davies Liu <davies@databricks.com>

Closes #11417 from davies/viz_subquery.
parent ad0de99f
Showing with 166 additions and 127 deletions
@@ -229,8 +229,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   override def simpleString: String = statePrefix + super.simpleString

-  override def treeChildren: Seq[PlanType] = {
-    val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
-    children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
+  /**
+   * All the subqueries of current plan.
+   */
+  def subqueries: Seq[PlanType] = {
+    expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
   }
+
+  override def innerChildren: Seq[PlanType] = subqueries
 }
@@ -447,9 +447,52 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

   /**
    * All the nodes that will be used to generate tree string.
+   *
+   * For example:
+   *
+   *   WholeStageCodegen
+   *   +-- SortMergeJoin
+   *       |-- InputAdapter
+   *       |   +-- Sort
+   *       +-- InputAdapter
+   *           +-- Sort
+   *
+   * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string
+   * like this:
+   *
+   *   WholeStageCodegen
+   *   : +- SortMergeJoin
+   *   :    :- INPUT
+   *   :    :- INPUT
+   *   :- Sort
+   *   :- Sort
    */
   protected def treeChildren: Seq[BaseType] = children

+  /**
+   * All the nodes that are parts of this node.
+   *
+   * For example:
+   *
+   *   WholeStageCodegen
+   *   +- SortMergeJoin
+   *      |-- InputAdapter
+   *      |   +-- Sort
+   *      +-- InputAdapter
+   *          +-- Sort
+   *
+   * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree
+   * string like this:
+   *
+   *   WholeStageCodegen
+   *   : +- SortMergeJoin
+   *   :    :- INPUT
+   *   :    :- INPUT
+   *   :- Sort
+   *   :- Sort
+   */
+  protected def innerChildren: Seq[BaseType] = Nil
+
   /**
    * Appends the string represent of this node and its children to the given StringBuilder.
    *
@@ -472,6 +515,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     builder.append(simpleString)
     builder.append("\n")

+    if (innerChildren.nonEmpty) {
+      innerChildren.init.foreach(_.generateTreeString(
+        depth + 2, lastChildren :+ false :+ false, builder))
+      innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+    }
+
     if (treeChildren.nonEmpty) {
       treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
       treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -36,11 +36,8 @@ class SparkPlanInfo(
 private[sql] object SparkPlanInfo {

   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
-    val children = plan match {
-      case WholeStageCodegen(child, _) => child :: Nil
-      case InputAdapter(child) => child :: Nil
-      case plan => plan.children
-    }
+    val children = plan.children ++ plan.subqueries
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
       new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
         Utils.getFormattedClassName(metric.param))
@@ -17,8 +17,6 @@
 package org.apache.spark.sql.execution

-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue

 /**
@@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan {
  * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
  * an RDD iterator of InternalRow.
  */
-case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {

   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
-
   override def doExecute(): RDD[InternalRow] = {
     child.execute()
   }
@@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
     child.doExecuteBroadcast()
   }

-  override def supportCodegen: Boolean = false
-
   override def upstreams(): Seq[RDD[InternalRow]] = {
     child.execute() :: Nil
   }
@@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
   }

   override def simpleString: String = "INPUT"
+
+  override def treeChildren: Seq[SparkPlan] = Nil
 }

 /**
@@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
  * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
  * used to generated code for BoundReference.
  */
-case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
-  extends SparkPlan with CodegenSupport {
-
-  override def supportCodegen: Boolean = false
-
-  override def output: Seq[Attribute] = plan.output
-  override def outputPartitioning: Partitioning = plan.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {

-  override def doPrepare(): Unit = {
-    plan.prepare()
-  }
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val code = plan.produce(ctx, this)
+    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
@@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       }

       /** Codegened pipeline for:
-        * ${toCommentSafeString(plan.treeString.trim)}
+        * ${toCommentSafeString(child.treeString.trim)}
         */
       class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
@@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)

-    val rdds = plan.upstreams()
+    val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
       rdds.head.mapPartitions { iter =>
@@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }

-  private[sql] override def resetMetrics(): Unit = {
-    plan.foreach(_.resetMetrics())
+  override def innerChildren: Seq[SparkPlan] = {
+    child :: Nil
   }

-  override def generateTreeString(
-      depth: Int,
-      lastChildren: Seq[Boolean],
-      builder: StringBuilder): StringBuilder = {
-    if (depth > 0) {
-      lastChildren.init.foreach { isLast =>
-        val prefixFragment = if (isLast) " " else ": "
-        builder.append(prefixFragment)
-      }
-
-      val branch = if (lastChildren.last) "+- " else ":- "
-      builder.append(branch)
-    }
-
-    builder.append(simpleString)
-    builder.append("\n")
-
-    plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
-    if (children.nonEmpty) {
-      children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
-      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
-    }
+  private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
+    case InputAdapter(c) => c :: Nil
+    case other => other.children.flatMap(collectInputs)
+  }

-    builder
+  override def treeChildren: Seq[SparkPlan] = {
+    collectInputs(child)
   }

   override def simpleString: String = "WholeStageCodegen"
@@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
     case _ => false
   }

+  /**
+   * Inserts a InputAdapter on top of those that do not support codegen.
+   */
+  private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
+    case j @ SortMergeJoin(_, _, _, left, right) =>
+      // The children of SortMergeJoin should do codegen separately.
+      j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+        right = InputAdapter(insertWholeStageCodegen(right)))
+    case p if !supportCodegen(p) =>
+      // collapse them recursively
+      InputAdapter(insertWholeStageCodegen(p))
+    case p =>
+      p.withNewChildren(p.children.map(insertInputAdapter))
+  }
+
+  /**
+   * Inserts a WholeStageCodegen on top of those that support codegen.
+   */
+  private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
+    case plan: CodegenSupport if supportCodegen(plan) =>
+      WholeStageCodegen(insertInputAdapter(plan))
+    case other =>
+      other.withNewChildren(other.children.map(insertWholeStageCodegen))
+  }
+
   def apply(plan: SparkPlan): SparkPlan = {
     if (sqlContext.conf.wholeStageEnabled) {
-      plan.transform {
-        case plan: CodegenSupport if supportCodegen(plan) =>
-          var inputs = ArrayBuffer[SparkPlan]()
-          val combined = plan.transform {
-            // The build side can't be compiled together
-            case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
-              b.copy(left = apply(left))
-            case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
-              b.copy(right = apply(right))
-            case j @ SortMergeJoin(_, _, _, left, right) =>
-              // The children of SortMergeJoin should do codegen separately.
-              j.copy(left = apply(left), right = apply(right))
-            case p if !supportCodegen(p) =>
-              val input = apply(p) // collapse them recursively
-              inputs += input
-              InputAdapter(input)
-          }.asInstanceOf[CodegenSupport]
-          WholeStageCodegen(combined, inputs)
-      }
+      insertWholeStageCodegen(plan)
     } else {
       plan
     }
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
@@ -68,7 +69,7 @@ package object debug {
     }
   }

-  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
+  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
     def output: Seq[Attribute] = child.output

     implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
@@ -86,10 +87,11 @@ package object debug {
     /**
      * A collection of metrics for each column of output.
+     *
      * @param elementTypes the actual runtime types for the output. Useful when there are bugs
      *                     causing the wrong data to be projected.
      */
     case class ColumnMetrics(
       elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))

     val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)

     val numColumns: Int = child.output.size
@@ -98,7 +100,7 @@ package object debug {
     def dumpStats(): Unit = {
       logDebug(s"== ${child.simpleString} ==")
       logDebug(s"Tuples output: ${tupleCount.value}")
-      child.output.zip(columnStats).foreach { case(attr, metric) =>
+      child.output.zip(columnStats).foreach { case (attr, metric) =>
         val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
         logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
@@ -108,6 +110,7 @@ package object debug {
       child.execute().mapPartitions { iter =>
         new Iterator[InternalRow] {
           def hasNext: Boolean = iter.hasNext
+
           def next(): InternalRow = {
             val currentRow = iter.next()
             tupleCount += 1
@@ -124,5 +127,17 @@ package object debug {
         }
       }
     }
+
+    override def upstreams(): Seq[RDD[InternalRow]] = {
+      child.asInstanceOf[CodegenSupport].upstreams()
+    }
+
+    override def doProduce(ctx: CodegenContext): String = {
+      child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    }
+
+    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+      consume(ctx, input)
+    }
   }
 }
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong

 import scala.collection.mutable

-import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen}
+import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.metric.SQLMetrics

 /**
@@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph {
       edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
       parent: SparkPlanGraphNode,
       subgraph: SparkPlanGraphCluster): Unit = {
-    if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) {
-      val cluster = new SparkPlanGraphCluster(
-        nodeIdGenerator.getAndIncrement(),
-        planInfo.nodeName,
-        planInfo.simpleString,
-        mutable.ArrayBuffer[SparkPlanGraphNode]())
-      nodes += cluster
-      buildSparkPlanGraphNode(
-        planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
-    } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) {
-      buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
-    } else {
-      val metrics = planInfo.metrics.map { metric =>
-        SQLPlanMetric(metric.name, metric.accumulatorId,
-          SQLMetrics.getMetricParam(metric.metricParam))
-      }
-      val node = new SparkPlanGraphNode(
-        nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
-        planInfo.simpleString, planInfo.metadata, metrics)
-      if (subgraph == null) {
-        nodes += node
-      } else {
-        subgraph.nodes += node
-      }
-
-      if (parent != null) {
-        edges += SparkPlanGraphEdge(node.id, parent.id)
-      }
-      planInfo.children.foreach(
-        buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+    planInfo.nodeName match {
+      case "WholeStageCodegen" =>
+        val cluster = new SparkPlanGraphCluster(
+          nodeIdGenerator.getAndIncrement(),
+          planInfo.nodeName,
+          planInfo.simpleString,
+          mutable.ArrayBuffer[SparkPlanGraphNode]())
+        nodes += cluster
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+      case "InputAdapter" =>
+        buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+      case "Subquery" if subgraph != null =>
+        // Subquery should not be included in WholeStageCodegen
+        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
+      case _ =>
+        val metrics = planInfo.metrics.map { metric =>
+          SQLPlanMetric(metric.name, metric.accumulatorId,
+            SQLMetrics.getMetricParam(metric.metricParam))
+        }
+        val node = new SparkPlanGraphNode(
+          nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
+          planInfo.simpleString, planInfo.metadata, metrics)
+        if (subgraph == null) {
+          nodes += node
+        } else {
+          subgraph.nodes += node
+        }
+
+        if (parent != null) {
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        }
+        planInfo.children.foreach(
+          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
     }
   }
 }
@@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
     val plan = df.queryExecution.executedPlan
     assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
-
-    checkThatPlansAgree(
-      sqlContext.range(100),
-      (p: SparkPlan) =>
-        WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
-      (p: SparkPlan) => Filter('a == 1, p),
-      sortAnswers = false
-    )
+    assert(df.collect() === Array(Row(2)))
   }

   test("Aggregate should be included in WholeStageCodegen") {
@@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(9, 4.5)))
   }
@@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
   }
@@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
     assert(df.queryExecution.executedPlan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
@@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
     assert(df.collect() === Array(Row(1), Row(2), Row(3)))
   }
 }
@@ -210,8 +210,8 @@ class JDBCSuite extends SparkFunSuite
       // the plan only has PhysicalRDD to scan JDBCRelation.
       assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
       val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
-      assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
-      assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
+      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
+      assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
       df
     }
     assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
@@ -248,7 +248,7 @@ class JDBCSuite extends SparkFunSuite
       // cannot compile given predicates.
       assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
       val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
-      assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter])
+      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter])
       df
     }
     assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
@@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         val metric = qe.executedPlan match {
-          case w: WholeStageCodegen => w.plan.longMetric("numOutputRows")
+          case w: WholeStageCodegen => w.child.longMetric("numOutputRows")
           case other => other.longMetric("numOutputRows")
         }
         metrics += metric.value.value