From 3822f33f3ce1428703a4796d7a119b40a6b32259 Mon Sep 17 00:00:00 2001
From: Yin Huai <huai@cse.ohio-state.edu>
Date: Fri, 1 Aug 2014 18:52:01 -0700
Subject: [PATCH] [SPARK-2212][SQL] Hash Outer Join (follow-up bug fix).

We need to carefully set the outputPartitioning of the HashOuterJoin operator. Otherwise, we may not correctly handle nulls.

Author: Yin Huai <huai@cse.ohio-state.edu>

Closes #1721 from yhuai/SPARK-2212-BugFix and squashes the following commits:

ed5eef7 [Yin Huai] Correctly choosing outputPartitioning for the HashOuterJoin operator.
---
 .../apache/spark/sql/execution/joins.scala    |  9 +-
 .../org/apache/spark/sql/JoinSuite.scala      | 99 +++++++++++++++++++
 .../scala/org/apache/spark/sql/TestData.scala |  8 ++
 3 files changed, 114 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 82f0a74b63..cc138c7499 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -158,7 +158,12 @@ case class HashOuterJoin(
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
-  override def outputPartitioning: Partitioning = left.outputPartitioning
+  override def outputPartitioning: Partitioning = joinType match {
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
+  }
 
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
@@ -309,7 +314,7 @@ case class HashOuterJoin(
             leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), 
             rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST))
         }
-        case x => throw new Exception(s"Need to add implementation for $x")
+        case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 037890682f..2fc8058818 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -197,6 +197,31 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
+
+    // Make sure we are choosing left.outputPartitioning as the
+    // outputPartitioning for the outer join operator.
+    checkAnswer(
+      sql(
+        """
+          |SELECT l.N, count(*)
+          |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
+          |GROUP BY l.N
+        """.stripMargin),
+      (1, 1) ::
+      (2, 1) ::
+      (3, 1) ::
+      (4, 1) ::
+      (5, 1) ::
+      (6, 1) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT r.a, count(*)
+          |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
+          |GROUP BY r.a
+        """.stripMargin),
+      (null, 6) :: Nil)
   }
 
   test("right outer join") {
@@ -232,6 +257,31 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "d", 4, "D") ::
       (null, null, 5, "E") ::
       (null, null, 6, "F") :: Nil)
+
+    // Make sure we are choosing right.outputPartitioning as the
+    // outputPartitioning for the outer join operator.
+    checkAnswer(
+      sql(
+        """
+          |SELECT l.a, count(*)
+          |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
+          |GROUP BY l.a
+        """.stripMargin),
+      (null, 6) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT r.N, count(*)
+          |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
+          |GROUP BY r.N
+        """.stripMargin),
+      (1, 1) ::
+      (2, 1) ::
+      (3, 1) ::
+      (4, 1) ::
+      (5, 1) ::
+      (6, 1) :: Nil)
   }
 
   test("full outer join") {
@@ -269,5 +319,54 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "D") ::
       (null, null, 5, "E") ::
       (null, null, 6, "F") :: Nil)
+
+    // Make sure we are choosing UnknownPartitioning as the outputPartitioning for the outer join operator.
+    checkAnswer(
+      sql(
+        """
+          |SELECT l.a, count(*)
+          |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
+          |GROUP BY l.a
+        """.stripMargin),
+      (null, 10) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT r.N, count(*)
+          |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
+          |GROUP BY r.N
+        """.stripMargin),
+      (1, 1) ::
+      (2, 1) ::
+      (3, 1) ::
+      (4, 1) ::
+      (5, 1) ::
+      (6, 1) ::
+      (null, 4) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT l.N, count(*)
+          |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
+          |GROUP BY l.N
+        """.stripMargin),
+      (1, 1) ::
+      (2, 1) ::
+      (3, 1) ::
+      (4, 1) ::
+      (5, 1) ::
+      (6, 1) ::
+      (null, 4) :: Nil)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT r.a, count(*)
+          |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
+          |GROUP BY r.a
+        """.stripMargin),
+      (null, 10) :: Nil)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 213190e812..58cee21e8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -118,6 +118,14 @@ object TestData {
     )
   nullInts.registerAsTable("nullInts")
 
+  val allNulls =
+    TestSQLContext.sparkContext.parallelize(
+      NullInts(null) ::
+      NullInts(null) ::
+      NullInts(null) ::
+      NullInts(null) :: Nil)
+  allNulls.registerAsTable("allNulls")
+
   case class NullStrings(n: Int, s: String)
   val nullStrings =
     TestSQLContext.sparkContext.parallelize(
-- 
GitLab