From 7dd01d9c019ee8d015a82fcda5c85f66bf8a3673 Mon Sep 17 00:00:00 2001 From: Herman van Hovell <hvanhovell@questtec.nl> Date: Wed, 27 Apr 2016 19:15:17 +0200 Subject: [PATCH] [SPARK-14950][SQL] Fix BroadcastHashJoin's unique key Anti-Joins ### What changes were proposed in this pull request? Anti-Joins using BroadcastHashJoin's unique key code path are broken; that code path currently returns Semi Join results. This PR fixes this bug. ### How was this patch tested? Added test cases to `ExistenceJoinSuite`. cc davies gatorsmile Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #12730 from hvanhovell/SPARK-14950. --- .../joins/BroadcastHashJoinExec.scala | 17 +++-- .../execution/joins/ExistenceJoinSuite.scala | 64 ++++++++++++++++--- 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 51399e1830..b0a6b8f28a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -161,7 +161,8 @@ case class BroadcastHashJoinExec( */ private def getJoinCondition( ctx: CodegenContext, - input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { + input: Seq[ExprCode], + anti: Boolean = false): (String, String, Seq[ExprCode]) = { val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) val checkCondition = if (condition.isDefined) { @@ -172,11 +173,18 @@ case class BroadcastHashJoinExec( ctx.currentVars = input ++ buildVars val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + val skipRow = if (!anti) { + s"${ev.isNull} || !${ev.value}" + } else { + s"!${ev.isNull} && ${ev.value}" + } s""" |$eval |${ev.code} - |if (${ev.isNull} || !${ev.value}) continue; + |if ($skipRow) continue; 
""".stripMargin + } else if (anti) { + "continue;" } else { "" } @@ -351,11 +359,12 @@ case class BroadcastHashJoinExec( */ private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input, uniqueKeyCodePath) val numOutput = metricTerm(ctx, "numOutputRows") - if (broadcastRelation.value.keyIsUnique) { + if (uniqueKeyCodePath) { s""" |// generate join key for stream side |${keyEv.code} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index c7c10abe9a..b32b6444b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -53,12 +53,23 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) - private lazy val condition = { + private lazy val rightUniqueKey = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val singleConditionEQ = (left.col("a") === right.col("c")).expr + + private lazy val composedConditionEQ = { And((left.col("a") === right.col("c")).expr, LessThan(left.col("b").expr, right.col("d").expr)) } - private lazy val conditionNEQ = { + private lazy val composedConditionNEQ = { And((left.col("a") < right.col("c")).expr, LessThan(left.col("b").expr, right.col("d").expr)) } @@ -138,34 +149,67 @@ class 
ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } testExistenceJoin( - "basic test for left semi join", + "test single condition (equal) for left semi join", + LeftSemi, + left, + right, + singleConditionEQ, + Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null))) + + testExistenceJoin( + "test composed condition (equal & non-equal) for left semi join", LeftSemi, left, right, - condition, + composedConditionEQ, Seq(Row(2, 1.0), Row(2, 1.0))) testExistenceJoin( - "basic test for left semi non equal join", + "test composed condition (both non-equal) for left semi join", LeftSemi, left, right, - conditionNEQ, + composedConditionNEQ, Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0))) testExistenceJoin( - "basic test for anti join", + "test single condition (equal) for left Anti join", LeftAnti, left, right, - condition, + singleConditionEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0))) + + + testExistenceJoin( + "test single unique condition (equal) for left Anti join", + LeftAnti, + left, + right.select(right.col("c")).distinct(), /* Trigger BHJs unique key code path! 
*/ + singleConditionEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0))) + + testExistenceJoin( + "test composed condition (equal & non-equal) test for anti join", + LeftAnti, + left, + right, + composedConditionEQ, Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) testExistenceJoin( - "basic test for anti non equal join", + "test composed condition (both non-equal) for anti join", LeftAnti, left, right, - conditionNEQ, + composedConditionNEQ, Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) + + testExistenceJoin( + "test composed unique condition (both non-equal) for anti join", + LeftAnti, + left, + rightUniqueKey, + (left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")).expr, + Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null))) } -- GitLab