diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3ff37fffbd40a0a974ee38a59b0a1b2ed1f28718..0e0453b517d9284c3af654355f034d6380a7e78c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -229,8 +229,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
   override def simpleString: String = statePrefix + super.simpleString
 
-  override def treeChildren: Seq[PlanType] = {
-    val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
-    children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
+  /**
+   * All the subqueries of the current plan.
+   */
+  def subqueries: Seq[PlanType] = {
+    expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
   }
+
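+  // Subquery plans are rendered as nested (inner) children in the generated tree string.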
+  override def innerChildren: Seq[PlanType] = subqueries
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 2d0bf6b375d3051fba8413e9b0cd2ec32d7431cb..6b7997e903a9986a04dbbd004c13f125a22b3f72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -447,9 +447,52 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
   /**
    * All the nodes that will be used to generate tree string.
+   *
+   * For example:
+   *
+   *   WholeStageCodegen
+   *   +- SortMergeJoin
+   *      :- InputAdapter
+   *      :  +- Sort
+   *      +- InputAdapter
+   *         +- Sort
+   *
+   * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), and the generated tree string
+   * will look like this:
+   *
+   *   WholeStageCodegen
+   *   :  +- SortMergeJoin
+   *   :     :- INPUT
+   *   :     +- INPUT
+   *   :- Sort
+   *   +- Sort
    */
   protected def treeChildren: Seq[BaseType] = children
 
+  /**
+   * All the nodes that are part of this node and are rendered nested under it in the tree string.
+   *
+   * For example:
+   *
+   *   WholeStageCodegen
+   *   +- SortMergeJoin
+   *      :- InputAdapter
+   *      :  +- Sort
+   *      +- InputAdapter
+   *         +- Sort
+   *
+   * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), and the generated tree
+   * string will look like this:
+   *
+   *   WholeStageCodegen
+   *   :  +- SortMergeJoin
+   *   :     :- INPUT
+   *   :     +- INPUT
+   *   :- Sort
+   *   +- Sort
+   */
+  protected def innerChildren: Seq[BaseType] = Nil
+
   /**
    * Appends the string represent of this node and its children to the given StringBuilder.
    *
@@ -472,6 +515,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     builder.append(simpleString)
     builder.append("\n")
 
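+    // Inner children are printed first, nested two levels below this node, before the regular
+    // tree children.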
+    if (innerChildren.nonEmpty) {
+      innerChildren.init.foreach(_.generateTreeString(
+        depth + 2, lastChildren :+ false :+ false, builder))
+      innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+    }
+
     if (treeChildren.nonEmpty) {
       treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
       treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 4dd9928244197eafdf2a9d5b332dbc7131243ed6..9019e5dfd66c61a31d4e3ba3f6484efd6b20ef69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -36,11 +36,8 @@ class SparkPlanInfo(
 private[sql] object SparkPlanInfo {
 
   def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
-    val children = plan match {
-      case WholeStageCodegen(child, _) => child :: Nil
-      case InputAdapter(child) => child :: Nil
-      case plan => plan.children
-    }
+
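+    // Include subquery plans as children so that they show up in SparkPlanInfo and the UI graph.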
+    val children = plan.children ++ plan.subqueries
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
       new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
         Utils.getFormattedClassName(metric.param))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index cb68ca6ada366dc932b807e573eeb4ede5449c4a..6d231bf74a0e9b21ec54cbbc7cf1f0c938201b1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 
 /**
@@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan {
   * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
   * an RDD iterator of InternalRow.
   */
-case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
-
   override def doExecute(): RDD[InternalRow] = {
     child.execute()
   }
@@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
     child.doExecuteBroadcast()
   }
 
-  override def supportCodegen: Boolean = false
-
   override def upstreams(): Seq[RDD[InternalRow]] = {
     child.execute() :: Nil
   }
@@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
   }
 
   override def simpleString: String = "INPUT"
+
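+  // Hide the child from the tree string; WholeStageCodegen collects the plans below each
+  // InputAdapter and renders them as its own tree children.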
+  override def treeChildren: Seq[SparkPlan] = Nil
 }
 
 /**
@@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
   * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
   * used to generated code for BoundReference.
   */
-case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
-  extends SparkPlan with CodegenSupport {
-
-  override def supportCodegen: Boolean = false
-
-  override def output: Seq[Attribute] = plan.output
-  override def outputPartitioning: Partitioning = plan.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
 
-  override def doPrepare(): Unit = {
-    plan.prepare()
-  }
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 
   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val code = plan.produce(ctx, this)
+    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
@@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       }
 
       /** Codegened pipeline for:
-        * ${toCommentSafeString(plan.treeString.trim)}
+        * ${toCommentSafeString(child.treeString.trim)}
         */
       class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
 
@@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     // println(s"${CodeFormatter.format(cleanedSource)}")
     CodeGenerator.compile(cleanedSource)
 
-    val rdds = plan.upstreams()
+    val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
       rdds.head.mapPartitions { iter =>
@@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }
 
-  private[sql] override def resetMetrics(): Unit = {
-    plan.foreach(_.resetMetrics())
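+  // The codegened subtree is rendered nested under this node in the tree string.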
+  override def innerChildren: Seq[SparkPlan] = {
+    child :: Nil
   }
 
-  override def generateTreeString(
-      depth: Int,
-      lastChildren: Seq[Boolean],
-      builder: StringBuilder): StringBuilder = {
-    if (depth > 0) {
-      lastChildren.init.foreach { isLast =>
-        val prefixFragment = if (isLast) "   " else ":  "
-        builder.append(prefixFragment)
-      }
-
-      val branch = if (lastChildren.last) "+- " else ":- "
-      builder.append(branch)
-    }
-
-    builder.append(simpleString)
-    builder.append("\n")
-
-    plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
-    if (children.nonEmpty) {
-      children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
-      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
-    }
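+  // Collects the input plans of this codegen stage, i.e. the plans directly below each
+  // InputAdapter in the codegened subtree.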
+  private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
+    case InputAdapter(c) => c :: Nil
+    case other => other.children.flatMap(collectInputs)
+  }
 
-    builder
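+  // The collected inputs are shown as the tree children of this node.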
+  override def treeChildren: Seq[SparkPlan] = {
+    collectInputs(child)
   }
 
   override def simpleString: String = "WholeStageCodegen"
@@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
     case _ => false
   }
 
+  /**
+   * Inserts an InputAdapter on top of those plans that do not support codegen.
+   */
+  private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
+    case j @ SortMergeJoin(_, _, _, left, right) =>
+      // The children of SortMergeJoin should do codegen separately.
+      j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+        right = InputAdapter(insertWholeStageCodegen(right)))
+    case p if !supportCodegen(p) =>
+      // This plan does not support codegen; keep collapsing the plans below it recursively.
+      InputAdapter(insertWholeStageCodegen(p))
+    case p =>
+      p.withNewChildren(p.children.map(insertInputAdapter))
+  }
+
+  /**
+   * Inserts a WholeStageCodegen on top of those plans that support codegen.
+   */
+  private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
+    case plan: CodegenSupport if supportCodegen(plan) =>
+      WholeStageCodegen(insertInputAdapter(plan))
+    case other =>
+      other.withNewChildren(other.children.map(insertWholeStageCodegen))
+  }
+
   def apply(plan: SparkPlan): SparkPlan = {
     if (sqlContext.conf.wholeStageEnabled) {
-      plan.transform {
-        case plan: CodegenSupport if supportCodegen(plan) =>
-          var inputs = ArrayBuffer[SparkPlan]()
-          val combined = plan.transform {
-            // The build side can't be compiled together
-            case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
-              b.copy(left = apply(left))
-            case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
-              b.copy(right = apply(right))
-            case j @ SortMergeJoin(_, _, _, left, right) =>
-              // The children of SortMergeJoin should do codegen separately.
-              j.copy(left = apply(left), right = apply(right))
-            case p if !supportCodegen(p) =>
-              val input = apply(p)  // collapse them recursively
-              inputs += input
-              InputAdapter(input)
-          }.asInstanceOf[CodegenSupport]
-          WholeStageCodegen(combined, inputs)
-      }
+      insertWholeStageCodegen(plan)
     } else {
       plan
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 95d033bc575489493d0e3adb3cbe94bd837473a0..fed88b8c0a1179f0678c6f9f6e0720ff600b7543 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.internal.SQLConf
 
@@ -68,7 +69,7 @@ package object debug {
     }
   }
 
-  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
+  private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
     def output: Seq[Attribute] = child.output
 
     implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
@@ -86,10 +87,11 @@ package object debug {
     /**
      * A collection of metrics for each column of output.
      * @param elementTypes the actual runtime types for the output.  Useful when there are bugs
-     *        causing the wrong data to be projected.
+     *                     causing the wrong data to be projected.
      */
     case class ColumnMetrics(
-        elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+      elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+
     val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
 
     val numColumns: Int = child.output.size
@@ -98,7 +100,7 @@ package object debug {
     def dumpStats(): Unit = {
       logDebug(s"== ${child.simpleString} ==")
       logDebug(s"Tuples output: ${tupleCount.value}")
-      child.output.zip(columnStats).foreach { case(attr, metric) =>
+      child.output.zip(columnStats).foreach { case (attr, metric) =>
         val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
         logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
       }
@@ -108,6 +110,7 @@ package object debug {
       child.execute().mapPartitions { iter =>
         new Iterator[InternalRow] {
           def hasNext: Boolean = iter.hasNext
+
           def next(): InternalRow = {
             val currentRow = iter.next()
             tupleCount += 1
@@ -124,5 +127,17 @@ package object debug {
         }
       }
     }
+
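+    // Whole-stage codegen support: delegate produce/upstreams to the child and pass the rows
+    // straight through, so wrapping a plan in DebugNode does not break codegen.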
+    override def upstreams(): Seq[RDD[InternalRow]] = {
+      child.asInstanceOf[CodegenSupport].upstreams()
+    }
+
+    override def doProduce(ctx: CodegenContext): String = {
+      child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    }
+
+    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+      consume(ctx, input)
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 4eb248569b2810d8ee2f1893d76361b5b5643ef9..12e586ada59765e565be610a3be267236bde393b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen}
+import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
@@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph {
       edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
       parent: SparkPlanGraphNode,
       subgraph: SparkPlanGraphCluster): Unit = {
-    if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) {
-      val cluster = new SparkPlanGraphCluster(
-        nodeIdGenerator.getAndIncrement(),
-        planInfo.nodeName,
-        planInfo.simpleString,
-        mutable.ArrayBuffer[SparkPlanGraphNode]())
-      nodes += cluster
-      buildSparkPlanGraphNode(
-        planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
-    } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) {
-      buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
-    } else {
-      val metrics = planInfo.metrics.map { metric =>
-        SQLPlanMetric(metric.name, metric.accumulatorId,
-          SQLMetrics.getMetricParam(metric.metricParam))
-      }
-      val node = new SparkPlanGraphNode(
-        nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
-        planInfo.simpleString, planInfo.metadata, metrics)
-      if (subgraph == null) {
-        nodes += node
-      } else {
-        subgraph.nodes += node
-      }
-
-      if (parent != null) {
-        edges += SparkPlanGraphEdge(node.id, parent.id)
-      }
-      planInfo.children.foreach(
-        buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+    planInfo.nodeName match {
+      case "WholeStageCodegen" =>
+        val cluster = new SparkPlanGraphCluster(
+          nodeIdGenerator.getAndIncrement(),
+          planInfo.nodeName,
+          planInfo.simpleString,
+          mutable.ArrayBuffer[SparkPlanGraphNode]())
+        nodes += cluster
+        buildSparkPlanGraphNode(
+          planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+      case "InputAdapter" =>
+        buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+      case "Subquery" if subgraph != null =>
+        // A Subquery plan should not be drawn inside a WholeStageCodegen cluster
+        buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
+      case _ =>
+        val metrics = planInfo.metrics.map { metric =>
+          SQLPlanMetric(metric.name, metric.accumulatorId,
+            SQLMetrics.getMetricParam(metric.metricParam))
+        }
+        val node = new SparkPlanGraphNode(
+          nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
+          planInfo.simpleString, planInfo.metadata, metrics)
+        if (subgraph == null) {
+          nodes += node
+        } else {
+          subgraph.nodes += node
+        }
+
+        if (parent != null) {
+          edges += SparkPlanGraphEdge(node.id, parent.id)
+        }
+        planInfo.children.foreach(
+          buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index de371d85d9fd70dc39fa2c6f60027264212b7152..e00c762c670548c8396c51e5801e3eaf036113a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
     val plan = df.queryExecution.executedPlan
     assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
-
-    checkThatPlansAgree(
-      sqlContext.range(100),
-      (p: SparkPlan) =>
-        WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
-      (p: SparkPlan) => Filter('a == 1, p),
-      sortAnswers = false
-    )
+    assert(df.collect() === Array(Row(2)))
   }
 
   test("Aggregate should be included in WholeStageCodegen") {
@@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(9, 4.5)))
   }
 
@@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
   }
 
@@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
     assert(df.queryExecution.executedPlan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
 
@@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
     val plan = df.queryExecution.executedPlan
     assert(plan.find(p =>
       p.isInstanceOf[WholeStageCodegen] &&
-        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+        p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
     assert(df.collect() === Array(Row(1), Row(2), Row(3)))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 4358c7c76da844c292de623cfe3baa291be6de03..b0d64aa7bf875fca4db018f0c96f3572ddc5709b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -210,8 +210,8 @@ class JDBCSuite extends SparkFunSuite
       // the plan only has PhysicalRDD to scan JDBCRelation.
       assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
       val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
-      assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
-      assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
+      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
+      assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
       df
     }
     assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
@@ -248,7 +248,7 @@ class JDBCSuite extends SparkFunSuite
       // cannot compile given predicates.
       assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
       val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
-      assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter])
+      assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter])
       df
     }
     assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 15a95623d1e5cc442d4adfeaa3c6f8c189fce112..e7d2b5ad96821f3f0f9bfb7ad0583de7c8c48006 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
         val metric = qe.executedPlan match {
-          case w: WholeStageCodegen => w.plan.longMetric("numOutputRows")
+          case w: WholeStageCodegen => w.child.longMetric("numOutputRows")
           case other => other.longMetric("numOutputRows")
         }
         metrics += metric.value.value