Commit 323d51f1 authored by Davies Liu, committed by Davies Liu

[SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin

Currently, SortMergeJoin and BroadcastHashJoin do not support a join condition; they need a trailing Filter for it. The result projection that generates UnsafeRows can be very expensive when the join produces many rows that are then mostly filtered out by the condition.

This PR adds support for a condition to SortMergeJoin and BroadcastHashJoin, just as the outer joins already have.

This could improve the performance of TPC-DS Q72 by 7x (from 120s to 16.5s).

Author: Davies Liu <davies@databricks.com>

Closes #10653 from davies/filter_join.
parent 39ac56fc
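For illustration (this example is not part of the commit; table and column names are hypothetical): a query whose join condition combines an equality with a non-equi predicate hits this code path. The equality becomes the equi-join keys, and the inequality becomes the `condition` that previously forced a Filter above the join.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object EmbeddedConditionExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val orders = sc.parallelize(Seq((1, 10), (2, 3))).toDF("item_id", "qty")
    val items = sc.parallelize(Seq((1, 5), (2, 5))).toDF("id", "min_qty")

    // item_id === id supplies the join keys; qty > min_qty is the embedded condition.
    val joined = orders.join(items,
      orders("item_id") === items("id") && orders("qty") > items("min_qty"))
    // Before this commit the plan was Filter(qty > min_qty) over the join;
    // after it, the predicate appears inside SortMergeJoin/BroadcastHashJoin.
    joined.explain()
    sc.stop()
  }
}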
SparkStrategies.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.{execution, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object EquiJoinSelection extends Strategy with PredicateHelper {
 
-    private[this] def makeBroadcastHashJoin(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression],
-        left: LogicalPlan,
-        right: LogicalPlan,
-        condition: Option[Expression],
-        side: joins.BuildSide): Seq[SparkPlan] = {
-      val broadcastHashJoin = execution.joins.BroadcastHashJoin(
-        leftKeys, rightKeys, side, planLater(left), planLater(right))
-      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        val mergeJoin =
-          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+        joins.SortMergeJoin(
+          leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Outer joins --------------------------------------------------------------------------
 ...
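The gain comes from evaluating the condition before the result projection rather than after it. A minimal standalone sketch of that cost ordering (toy types, not Spark's code):

object FilterOrderSketch {
  def main(args: Array[String]): Unit = {
    // A million joined rows, of which roughly 1% satisfy the condition.
    val joined = (1 to 1000000).iterator.map(i => (i, i % 100))
    def project(r: (Int, Int)): Array[Int] = Array(r._1, r._2) // stand-in for the UnsafeRow projection
    def condition(r: (Int, Int)): Boolean = r._2 == 0

    // Old plan shape: project every row, then filter (1,000,000 projections):
    //   joined.map(project).filter(a => a(1) == 0)
    // New plan shape: filter first, project only the survivors (~10,000 projections).
    val out = joined.filter(condition).map(project)
    println(out.size)
  }
}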
BroadcastHashJoin.scala
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     buildSide: BuildSide,
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
   extends BinaryNode with HashJoin {
 ...
HashJoin.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution.joins
 
+import java.util.NoSuchElementException
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
   val buildSide: BuildSide
+  val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
@@ -50,6 +53,12 @@ trait HashJoin {
   protected def streamSideKeyGenerator: Projection =
     UnsafeProjection.create(streamedKeys, streamedPlan.output)
 
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (r: InternalRow) => true
+  }
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@ trait HashJoin {
       private[this] val joinKeys = streamSideKeyGenerator
 
-      override final def hasNext: Boolean =
-        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
-          (streamIter.hasNext && fetchNext())
-
-      override final def next(): InternalRow = {
-        val ret = buildSide match {
-          case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
-          case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
-        }
-        currentMatchPosition += 1
-        numOutputRows += 1
-        resultProjection(ret)
-      }
-
-      /**
-       * Searches the streamed iterator for the next row that has at least one match in hashtable.
-       *
-       * @return true if the search is successful, and false if the streamed iterator runs out of
-       *         tuples.
-       */
-      private final def fetchNext(): Boolean = {
-        currentHashMatches = null
-        currentMatchPosition = -1
-        while (currentHashMatches == null && streamIter.hasNext) {
-          currentStreamedRow = streamIter.next()
-          numStreamRows += 1
-          val key = joinKeys(currentStreamedRow)
-          if (!key.anyNull) {
-            currentHashMatches = hashedRelation.get(key)
-          }
-        }
-
-        if (currentHashMatches == null) {
-          false
-        } else {
-          currentMatchPosition = 0
-          true
-        }
-      }
+      override final def hasNext: Boolean = {
+        while (true) {
+          // check if it's end of current matches
+          if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+            currentHashMatches = null
+            currentMatchPosition = -1
+          }
+
+          // find the next match
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            numStreamRows += 1
+            val key = joinKeys(currentStreamedRow)
+            if (!key.anyNull) {
+              currentHashMatches = hashedRelation.get(key)
+              if (currentHashMatches != null) {
+                currentMatchPosition = 0
+              }
+            }
+          }
+          if (currentHashMatches == null) {
+            return false
+          }
+
+          // found some matches
+          buildSide match {
+            case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+            case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+          }
+          if (boundCondition(joinRow)) {
+            return true
+          } else {
+            currentMatchPosition += 1
+          }
+        }
+        false // unreachable
+      }
+
+      override final def next(): InternalRow = {
+        // next() could be called without calling hasNext()
+        if (hasNext) {
+          currentMatchPosition += 1
+          numOutputRows += 1
+          resultProjection(joinRow)
+        } else {
+          throw new NoSuchElementException
+        }
      }
 ...
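A self-contained sketch of the iterator shape introduced above (simplified types, not Spark's HashJoin): hasNext walks forward until it parks on a match that passes the condition, and next() consumes that match, so the condition runs before any result projection.

import java.util.NoSuchElementException

class FilteredMatchIterator[K, V](
    stream: Iterator[(K, V)],
    hashed: Map[K, Seq[V]],
    condition: (V, V) => Boolean) extends Iterator[(V, V)] {

  private[this] var matches: Seq[V] = null
  private[this] var pos = 0
  private[this] var streamed: V = _

  override def hasNext: Boolean = {
    while (true) {
      // end of the current row's matches: reset and look for the next streamed row
      if (matches != null && pos == matches.length) {
        matches = null
      }
      while (matches == null && stream.hasNext) {
        val (key, value) = stream.next()
        hashed.get(key).foreach { m =>
          matches = m
          pos = 0
          streamed = value
        }
      }
      if (matches == null) return false
      // park on the first match that passes the condition
      if (condition(streamed, matches(pos))) return true
      pos += 1
    }
    false // unreachable
  }

  override def next(): (V, V) = {
    if (!hasNext) throw new NoSuchElementException
    val result = (streamed, matches(pos))
    pos += 1
    result
  }
}

As in the commit, a failing match only advances the position (`pos` here, `currentMatchPosition` above); the streamed row is not re-fetched until its matches are exhausted.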
HashOuterJoin.scala
@@ -78,8 +78,11 @@ trait HashOuterJoin {
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition =
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (row: InternalRow) => true
+  }
 
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
 ...
SortMergeJoin.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
@@ -64,6 +65,13 @@ case class SortMergeJoin(
     val numOutputRows = longMetric("numOutputRows")
 
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
       new RowIterator {
         // The projection used to extract keys from input rows of the left child.
         private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@ case class SortMergeJoin(
         private[this] val resultProjection: (InternalRow) => InternalRow =
           UnsafeProjection.create(schema)
 
+        if (smjScanner.findNextInnerJoinRows()) {
+          currentRightMatches = smjScanner.getBufferedMatches
+          currentLeftRow = smjScanner.getStreamedRow
+          currentMatchIdx = 0
+        }
+
         override def advanceNext(): Boolean = {
-          if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
-            } else {
-              currentRightMatches = null
-              currentLeftRow = null
-              currentMatchIdx = -1
+          while (currentMatchIdx >= 0) {
+            if (currentMatchIdx == currentRightMatches.length) {
+              if (smjScanner.findNextInnerJoinRows()) {
+                currentRightMatches = smjScanner.getBufferedMatches
+                currentLeftRow = smjScanner.getStreamedRow
+                currentMatchIdx = 0
+              } else {
+                currentRightMatches = null
+                currentLeftRow = null
+                currentMatchIdx = -1
+                return false
+              }
             }
-          }
-          if (currentLeftRow != null) {
             joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
             currentMatchIdx += 1
-            numOutputRows += 1
-            true
-          } else {
-            false
+            if (boundCondition(joinRow)) {
+              numOutputRows += 1
+              return true
+            }
           }
+          false
         }
 
         override def getRow: InternalRow = resultProjection(joinRow)
 ...
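Both HashJoin and SortMergeJoin bind the optional condition the same way: compile it once per partition and fall back to a constant-true function, so the per-row loop never inspects the Option. A minimal sketch of that pattern under simplified types (the expression and row types here are stand-ins, not Spark's):

object BoundConditionSketch {
  type Row = Array[Int]

  // stand-in for newPredicate(cond, left.output ++ right.output)
  def compile(expr: String): Row => Boolean =
    row => row(0) > row(1)

  def bind(condition: Option[String]): Row => Boolean =
    condition.map(compile).getOrElse((_: Row) => true)

  def main(args: Array[String]): Unit = {
    println(bind(Some("a > b"))(Array(3, 1))) // true: condition evaluated per row
    println(bind(None)(Array(0, 9)))          // true: no condition, accept every match
  }
}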
InnerJoinSuite.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 
 class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      val broadcastHashJoin =
-        execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
-      boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
-        execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
-      val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-      EnsureRequirements(sqlContext).apply(filteredJoin)
+        joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+      EnsureRequirements(sqlContext).apply(sortMergeJoin)
     }
 
     test(s"$testName using BroadcastHashJoin (build=left)") {
 ...