From 945d8bcbf67032edd7bdd201cf9f88c75b3464f7 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@appier.com>
Date: Sun, 26 Jul 2015 22:13:37 -0700
Subject: [PATCH] [SPARK-9306] [SQL] Don't use SortMergeJoin when joining on
 unsortable columns

JIRA: https://issues.apache.org/jira/browse/SPARK-9306

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #7645 from viirya/smj_unsortable and squashes the following commits:

a240707 [Liang-Chi Hsieh] Use forall instead of exists for readability.
55221fa [Liang-Chi Hsieh] Shouldn't use SortMergeJoin when joining on unsortable columns.
---
 .../sql/catalyst/planning/patterns.scala      |  2 +-
 .../spark/sql/execution/SparkStrategies.scala | 19 +++++++++++++++----
 .../org/apache/spark/sql/JoinSuite.scala      | 12 ++++++++++++
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index b8e3b0d53a..1e7b2a536a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -184,7 +184,7 @@ object PartialAggregation {
  * A pattern that finds joins with equality conditions that can be evaluated using equi-join.
  */
 object ExtractEquiJoinKeys extends Logging with PredicateHelper {
-  /** (joinType, rightKeys, leftKeys, condition, leftChild, rightChild) */
+  /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */
   type ReturnType =
     (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index deeea3900c..306bbfec62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -35,9 +35,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   object LeftSemiJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
-          right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
+      case ExtractEquiJoinKeys(
+             LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
         joins.BroadcastLeftSemiJoinHash(
           leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
@@ -90,6 +89,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
     }
 
+    private[this] def isValidSort(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression]): Boolean = {
+      leftKeys.zip(rightKeys).forall { keys =>
+        (keys._1.dataType, keys._2.dataType) match {
+          case (l: AtomicType, r: AtomicType) => true
+          case (NullType, NullType) => true
+          case _ => false
+        }
+      }
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
@@ -100,7 +111,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
       // for now let's support inner join first, then add outer join
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if sqlContext.conf.sortMergeJoinEnabled =>
+        if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) =>
         val mergeJoin =
           joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8953889d1f..dfb2a7e099 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -108,6 +108,18 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     }
   }
 
+  test("SortMergeJoin shouldn't work on unsortable columns") {
+    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
+    try {
+      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+      Seq(
+        ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    } finally {
+      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
+    }
+  }
+
   test("broadcasted hash join operator selection") {
     ctx.cacheManager.clearCache()
     ctx.sql("CACHE TABLE testData")
-- 
GitLab