From 321b4f03bc983c582a3c6259019c077cdfac9d26 Mon Sep 17 00:00:00 2001
From: wangzhenhua <wangzhenhua@huawei.com>
Date: Tue, 18 Apr 2017 20:12:21 +0800
Subject: [PATCH] [SPARK-20366][SQL] Fix recursive join reordering: inside
 joins are not reordered

## What changes were proposed in this pull request?

If a plan has multi-level successive joins, e.g.:
```
         Join
         /   \
     Union   t5
      /   \
    Join  t4
    /   \
  Join  t3
  /  \
 t1   t2
```
Currently we fail to reorder the inside joins, i.e. t1, t2, t3.

In join reorder, we use `OrderedJoin` to indicate a join has been ordered, such that when transforming down the plan, these joins don't need to be rerodered again.

But there's a problem in the definition of `OrderedJoin`:
The real join node is a parameter, but not a child. This breaks the transform procedure because `mapChildren` applies transform function on parameters which should be children.

In this patch, we change `OrderedJoin` to a class having the same structure as a join node.

## How was this patch tested?

Add a corresponding test case.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #17668 from wzhfy/recursiveReorder.
---
 .../optimizer/CostBasedJoinReorder.scala      | 22 +++++----
 .../catalyst/optimizer/JoinReorderSuite.scala | 49 +++++++++++++++++--
 2 files changed, 58 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index c704c2e6d3..51eca6ca33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper}
-import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike}
+import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
@@ -47,7 +47,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
       }
       // After reordering is finished, convert OrderedJoin back to Join
       result transformDown {
-        case oj: OrderedJoin => oj.join
+        case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
       }
     }
   }
@@ -87,22 +87,24 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
   }
 
   private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match {
-    case j @ Join(left, right, _: InnerLike, Some(cond)) =>
+    case j @ Join(left, right, jt: InnerLike, Some(cond)) =>
       val replacedLeft = replaceWithOrderedJoin(left)
       val replacedRight = replaceWithOrderedJoin(right)
-      OrderedJoin(j.copy(left = replacedLeft, right = replacedRight))
+      OrderedJoin(replacedLeft, replacedRight, jt, Some(cond))
     case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond))) =>
       p.copy(child = replaceWithOrderedJoin(j))
     case _ =>
       plan
   }
+}
 
-  /** This is a wrapper class for a join node that has been ordered. */
-  private case class OrderedJoin(join: Join) extends BinaryNode {
-    override def left: LogicalPlan = join.left
-    override def right: LogicalPlan = join.right
-    override def output: Seq[Attribute] = join.output
-  }
+/** This is a mimic class for a join node that has been ordered. */
+case class OrderedJoin(
+    left: LogicalPlan,
+    right: LogicalPlan,
+    joinType: JoinType,
+    condition: Option[Expression]) extends BinaryNode {
+  override def output: Seq[Attribute] = left.output ++ right.output
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 1922eb30fd..71db4e2e0e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -25,13 +25,12 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED, JOIN_REORDER_ENABLED}
+import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}
 
 
 class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
 
-  override val conf = new SQLConf().copy(
-    CASE_SENSITIVE -> true, CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true)
+  override val conf = new SQLConf().copy(CBO_ENABLED -> true, JOIN_REORDER_ENABLED -> true)
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -212,6 +211,50 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
     }
   }
 
+  test("reorder recursively") {
+    // Original order:
+    //          Join
+    //          / \
+    //      Union  t5
+    //       / \
+    //     Join t4
+    //     / \
+    //   Join t3
+    //   / \
+    //  t1  t2
+    val bottomJoins =
+      t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) &&
+        (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .select(nameToAttr("t1.v-1-10"))
+
+    val originalPlan = bottomJoins
+      .union(t4.select(nameToAttr("t4.v-1-10")))
+      .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5")))
+
+    // Should be able to reorder the bottom part.
+    // Best order:
+    //          Join
+    //          / \
+    //      Union  t5
+    //       / \
+    //     Join t4
+    //     / \
+    //   Join t2
+    //   / \
+    //  t1  t3
+    val bestBottomPlan =
+      t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
+        .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10"))
+        .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+        .select(nameToAttr("t1.v-1-10"))
+
+    val bestPlan = bestBottomPlan
+      .union(t4.select(nameToAttr("t4.v-1-10")))
+      .join(t5, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t5.v-1-5")))
+
+    assertEqualPlans(originalPlan, bestPlan)
+  }
+
   private def assertEqualPlans(
       originalPlan: LogicalPlan,
       groundTruthBestPlan: LogicalPlan): Unit = {
-- 
GitLab