From 7dd01d9c019ee8d015a82fcda5c85f66bf8a3673 Mon Sep 17 00:00:00 2001
From: Herman van Hovell <hvanhovell@questtec.nl>
Date: Wed, 27 Apr 2016 19:15:17 +0200
Subject: [PATCH] [SPARK-14950][SQL] Fix BroadcastHashJoin's unique key
 Anti-Joins

### What changes were proposed in this pull request?
Anti-Joins using BroadcastHashJoin's unique key code path are broken; it currently returns Semi Join results . This PR fixes this bug.

### How was this patch tested?
Added tests cases to `ExistenceJoinSuite`.

cc davies gatorsmile

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12730 from hvanhovell/SPARK-14950.
---
 .../joins/BroadcastHashJoinExec.scala         | 17 +++--
 .../execution/joins/ExistenceJoinSuite.scala  | 64 ++++++++++++++++---
 2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 51399e1830..b0a6b8f28a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -161,7 +161,8 @@ case class BroadcastHashJoinExec(
    */
   private def getJoinCondition(
       ctx: CodegenContext,
-      input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+      input: Seq[ExprCode],
+      anti: Boolean = false): (String, String, Seq[ExprCode]) = {
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
     val checkCondition = if (condition.isDefined) {
@@ -172,11 +173,18 @@ case class BroadcastHashJoinExec(
       ctx.currentVars = input ++ buildVars
       val ev =
         BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+      val skipRow = if (!anti) {
+        s"${ev.isNull} || !${ev.value}"
+      } else {
+        s"!${ev.isNull} && ${ev.value}"
+      }
       s"""
          |$eval
          |${ev.code}
-         |if (${ev.isNull} || !${ev.value}) continue;
+         |if ($skipRow) continue;
        """.stripMargin
+    } else if (anti) {
+      "continue;"
     } else {
       ""
     }
@@ -351,11 +359,12 @@ case class BroadcastHashJoinExec(
    */
   private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
-    val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+    val (matched, checkCondition, _) = getJoinCondition(ctx, input, uniqueKeyCodePath)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    if (broadcastRelation.value.keyIsUnique) {
+    if (uniqueKeyCodePath) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index c7c10abe9a..b32b6444b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -53,12 +53,23 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       Row(6, null)
     )), new StructType().add("c", IntegerType).add("d", DoubleType))
 
-  private lazy val condition = {
+  private lazy val rightUniqueKey = sqlContext.createDataFrame(
+    sparkContext.parallelize(Seq(
+      Row(2, 3.0),
+      Row(3, 2.0),
+      Row(4, 1.0),
+      Row(null, 5.0),
+      Row(6, null)
+    )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+  private lazy val singleConditionEQ = (left.col("a") === right.col("c")).expr
+
+  private lazy val composedConditionEQ = {
     And((left.col("a") === right.col("c")).expr,
       LessThan(left.col("b").expr, right.col("d").expr))
   }
 
-  private lazy val conditionNEQ = {
+  private lazy val composedConditionNEQ = {
     And((left.col("a") < right.col("c")).expr,
       LessThan(left.col("b").expr, right.col("d").expr))
   }
@@ -138,34 +149,67 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
   }
 
   testExistenceJoin(
-    "basic test for left semi join",
+    "test single condition (equal) for left semi join",
+    LeftSemi,
+    left,
+    right,
+    singleConditionEQ,
+    Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null)))
+
+  testExistenceJoin(
+    "test composed condition (equal & non-equal) for left semi join",
     LeftSemi,
     left,
     right,
-    condition,
+    composedConditionEQ,
     Seq(Row(2, 1.0), Row(2, 1.0)))
 
   testExistenceJoin(
-    "basic test for left semi non equal join",
+    "test composed condition (both non-equal) for left semi join",
     LeftSemi,
     left,
     right,
-    conditionNEQ,
+    composedConditionNEQ,
     Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0)))
 
   testExistenceJoin(
-    "basic test for anti join",
+    "test single condition (equal) for left Anti join",
     LeftAnti,
     left,
     right,
-    condition,
+    singleConditionEQ,
+    Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0)))
+
+
+  testExistenceJoin(
+    "test single unique condition (equal) for left Anti join",
+    LeftAnti,
+    left,
+    right.select(right.col("c")).distinct(), /* Trigger BHJs unique key code path! */
+    singleConditionEQ,
+    Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0)))
+
+  testExistenceJoin(
+    "test composed condition (equal & non-equal) test for anti join",
+    LeftAnti,
+    left,
+    right,
+    composedConditionEQ,
     Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
 
   testExistenceJoin(
-    "basic test for anti non equal join",
+    "test composed condition (both non-equal) for anti join",
     LeftAnti,
     left,
     right,
-    conditionNEQ,
+    composedConditionNEQ,
     Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
+
+  testExistenceJoin(
+    "test composed unique condition (both non-equal) for anti join",
+    LeftAnti,
+    left,
+    rightUniqueKey,
+    (left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")).expr,
+    Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null)))
 }
-- 
GitLab