diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 5eb8fdf048e288533b76e90e15665fddb52113bc..e8e2a7bbabcd448a14f9d84429e036c1bd43474b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -467,50 +467,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   }
 
   /**
-   * All the nodes that will be used to generate tree string.
-   *
-   * For example:
-   *
-   *   WholeStageCodegen
-   *   +-- SortMergeJoin
-   *       |-- InputAdapter
-   *       |   +-- Sort
-   *       +-- InputAdapter
-   *           +-- Sort
-   *
-   * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string
-   * like this:
-   *
-   *   WholeStageCodegen
-   *   : +- SortMergeJoin
-   *   :    :- INPUT
-   *   :    :- INPUT
-   *   :-  Sort
-   *   :-  Sort
-   */
-  protected def treeChildren: Seq[BaseType] = children
-
-  /**
-   * All the nodes that are parts of this node.
-   *
-   * For example:
-   *
-   *   WholeStageCodegen
-   *   +- SortMergeJoin
-   *      |-- InputAdapter
-   *      |   +-- Sort
-   *      +-- InputAdapter
-   *          +-- Sort
-   *
-   * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree
-   * string like this:
-   *
-   *   WholeStageCodegen
-   *   : +- SortMergeJoin
-   *   :    :- INPUT
-   *   :    :- INPUT
-   *   :-  Sort
-   *   :-  Sort
+   * All the nodes that are parts of this node, this is used by subquries.
    */
   protected def innerChildren: Seq[BaseType] = Nil
 
@@ -522,7 +479,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * `lastChildren` for the root node should be empty.
    */
   def generateTreeString(
-      depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = {
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      prefix: String = ""): StringBuilder = {
     if (depth > 0) {
       lastChildren.init.foreach { isLast =>
         val prefixFragment = if (isLast) "   " else ":  "
@@ -533,6 +493,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       builder.append(branch)
     }
 
+    builder.append(prefix)
     builder.append(simpleString)
     builder.append("\n")
 
@@ -542,9 +503,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
       innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
     }
 
-    if (treeChildren.nonEmpty) {
-      treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
-      treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
+    if (children.nonEmpty) {
+      children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder, prefix))
+      children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, prefix)
     }
 
     builder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 37fdc362b5e4ec9e8eabda31f363fb1737aae31c..2a1ce735b74ea6ffd7a0106c70f0db2a535abee5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -245,9 +245,13 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
      """.stripMargin
   }
 
-  override def simpleString: String = "INPUT"
-
-  override def treeChildren: Seq[SparkPlan] = Nil
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      prefix: String = ""): StringBuilder = {
+    child.generateTreeString(depth, lastChildren, builder, "")
+  }
 }
 
 object WholeStageCodegenExec {
@@ -398,20 +402,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
      """.stripMargin.trim
   }
 
-  override def innerChildren: Seq[SparkPlan] = {
-    child :: Nil
-  }
-
-  private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
-    case InputAdapter(c) => c :: Nil
-    case other => other.children.flatMap(collectInputs)
+  override def generateTreeString(
+      depth: Int,
+      lastChildren: Seq[Boolean],
+      builder: StringBuilder,
+      prefix: String = ""): StringBuilder = {
+    child.generateTreeString(depth, lastChildren, builder, "*")
   }
-
-  override def treeChildren: Seq[SparkPlan] = {
-    collectInputs(child)
-  }
-
-  override def simpleString: String = "WholeStageCodegen"
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 9da9df617405d8ea58a6c0b9328f42af742360d8..9a9597d3733e0aa3f859a58f2806e9d7e8068394 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -60,9 +60,6 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
   override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
     child.executeBroadcast()
   }
-
-  // Do not repeat the same tree in explain.
-  override def treeChildren: Seq[SparkPlan] = Nil
 }
 
 /**