diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 51399e18301162e3021256d5acdda42b09139735..b0a6b8f28a4670e5f3713c1ed58dfbcf07fb5197 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -161,7 +161,8 @@ case class BroadcastHashJoinExec( */ private def getJoinCondition( ctx: CodegenContext, - input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { + input: Seq[ExprCode], + anti: Boolean = false): (String, String, Seq[ExprCode]) = { /* anti = true flips the generated check: the emitted code runs continue when the join condition holds, and unconditionally when no condition is defined */ val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) val checkCondition = if (condition.isDefined) { @@ -172,11 +173,18 @@ case class BroadcastHashJoinExec( ctx.currentVars = input ++ buildVars val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + val skipRow = if (!anti) { /* default: continue when the condition evaluates to null or false */ + s"${ev.isNull} || !${ev.value}" + } else { /* anti: continue only when the condition is non-null and true */ + s"!${ev.isNull} && ${ev.value}" + } s""" |$eval |${ev.code} - |if (${ev.isNull} || !${ev.value}) continue; + |if ($skipRow) continue; """.stripMargin + } else if (anti) { /* no join condition defined: with anti set, the generated check is a bare continue */ + "continue;" } else { "" } @@ -351,11 +359,12 @@ case class BroadcastHashJoinExec( */ private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique /* evaluated once: passed as the anti flag to getJoinCondition and reused to select the code path below */ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input, uniqueKeyCodePath) val numOutput = metricTerm(ctx, "numOutputRows") - if (broadcastRelation.value.keyIsUnique) { + if (uniqueKeyCodePath) { s""" |// generate join key for stream side |${keyEv.code} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index c7c10abe9aff537e362d9e5ebefcc633fbd55c3c..b32b6444b6d9a39a6fb604ee4ea2be778beef7f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -53,12 +53,23 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) - private lazy val condition = { + private lazy val rightUniqueKey = sqlContext.createDataFrame( /* build-side fixture in which every value of key column c (2, 3, 4, null, 6) occurs exactly once */ + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val singleConditionEQ = (left.col("a") === right.col("c")).expr /* equality on the join key only, no extra predicate */ + + private lazy val composedConditionEQ = { And((left.col("a") === right.col("c")).expr, LessThan(left.col("b").expr, right.col("d").expr)) } - private lazy val conditionNEQ = { + private lazy val composedConditionNEQ = { And((left.col("a") < right.col("c")).expr, LessThan(left.col("b").expr, right.col("d").expr)) } @@ -138,34 +149,67 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } testExistenceJoin( - "basic test for left semi join", + "test single condition (equal) for left semi join", + LeftSemi, + left, + right, + singleConditionEQ, + Seq(Row(2, 1.0), Row(2, 1.0), Row(3, 3.0), Row(6, null))) + + testExistenceJoin( + "test composed condition (equal & non-equal) for left semi join", LeftSemi, left, right, - condition, + composedConditionEQ, Seq(Row(2, 1.0), Row(2, 1.0))) testExistenceJoin( - "basic test for left semi non equal join", + "test composed condition (both non-equal) for left semi join", LeftSemi, left, right, - conditionNEQ, + composedConditionNEQ, Seq(Row(1, 
2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0))) testExistenceJoin( - "basic test for anti join", + "test single condition (equal) for left anti join", LeftAnti, left, right, - condition, + singleConditionEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0))) + + + testExistenceJoin( + "test single unique condition (equal) for left anti join", + LeftAnti, + left, + right.select(right.col("c")).distinct(), /* Trigger the BHJ unique key code path! */ + singleConditionEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(null, null), Row(null, 5.0))) + + testExistenceJoin( + "test composed condition (equal & non-equal) for anti join", + LeftAnti, + left, + right, + composedConditionEQ, Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) testExistenceJoin( - "basic test for anti non equal join", + "test composed condition (both non-equal) for anti join", LeftAnti, left, right, - conditionNEQ, + composedConditionNEQ, Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) + + testExistenceJoin( + "test composed unique condition (equal & non-equal) for anti join", + LeftAnti, + left, + rightUniqueKey, /* key column c holds no duplicate values in this fixture */ + (left.col("a") === rightUniqueKey.col("c") && left.col("b") < rightUniqueKey.col("d")).expr, + Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(null, null), Row(null, 5.0), Row(6, null))) }