diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index c646dcfa11811b83e4fcd684f3b5528d4708170c..e01f69f81359ee42f8e98f9f343f3e1a2a77a77a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -31,5 +31,6 @@ trait BroadcastMode {
  * IdentityBroadcastMode requires that rows are broadcasted in their original form.
  */
 case object IdentityBroadcastMode extends BroadcastMode {
+  // TODO: pack the UnsafeRows into a single byte array.
   override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5fdf38c733efca3885986a47193780aeb58c64a7..dd8c96d5fa1d649d5753478fbd8410cea815e6b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -253,22 +253,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object BroadcastNestedLoop extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Join(
-             CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi =>
+      case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) =>
         execution.joins.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
-      case logical.Join(
-             left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi =>
+          planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil
+      case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) =>
         execution.joins.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
+          planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil
       case _ => Nil
     }
   }

   object CartesianProduct extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // TODO CartesianProduct doesn't support the Left Semi Join
-      case logical.Join(left, right, joinType, None) if joinType != LeftSemi =>
+      case logical.Join(left, right, Inner, None) =>
         execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
       case logical.Join(left, right, Inner, Some(condition)) =>
         execution.Filter(condition,
@@ -286,6 +283,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         } else {
           joins.BuildLeft
         }
+        // This join could be very slow or even hang forever.
         joins.BroadcastNestedLoopJoin(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
       case _ => Nil
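Taken together, the two strategies above tighten the dispatch rules: BroadcastNestedLoop now builds the broadcast side only for the join types it can answer from that side, CartesianProduct handles inner joins only, and everything else falls through to the slow BroadcastNestedLoopJoin fallback. Passing j.joinType works around the fact that a pattern alternation like Inner | RightOuter cannot bind the join type to a variable. A standalone sketch of the resulting decision table (the JoinType names mirror Catalyst's; the BroadcastSide enum and the chooseOperator helper are illustrative only, not Spark APIs):

// A minimal, self-contained sketch of the dispatch rules; illustrative only.
object JoinDispatchSketch {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType
  case object FullOuter extends JoinType
  case object LeftSemi extends JoinType

  sealed trait BroadcastSide
  case object BroadcastLeft extends BroadcastSide
  case object BroadcastRight extends BroadcastSide
  case object NoBroadcast extends BroadcastSide

  // Which operator the rewritten strategies would pick for a join
  // that has no equi-join keys.
  def chooseOperator(joinType: JoinType, side: BroadcastSide): String =
    (joinType, side) match {
      // BroadcastNestedLoop: stream the big side, build the broadcast side.
      case (Inner | RightOuter, BroadcastLeft) => "BroadcastNestedLoopJoin(BuildLeft)"
      case (Inner | LeftOuter | LeftSemi, BroadcastRight) => "BroadcastNestedLoopJoin(BuildRight)"
      // CartesianProduct now handles only inner joins.
      case (Inner, NoBroadcast) => "CartesianProduct"
      // Everything else falls through to the fallback the patch flags as
      // "could be very slow or even hang forever".
      case _ => "BroadcastNestedLoopJoin (fallback)"
    }

  def main(args: Array[String]): Unit = {
    println(chooseOperator(LeftOuter, BroadcastRight)) // BroadcastNestedLoopJoin(BuildRight)
    println(chooseOperator(FullOuter, NoBroadcast))    // BroadcastNestedLoopJoin (fallback)
  }
}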
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index e8bd7f69dbab9512d8a5aec805c1e38785a14ab5..d83486df02c87d6a8ef0126cd106c8aeb7e7c7a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.joins

+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +27,6 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}

-
 case class BroadcastNestedLoopJoin(
     left: SparkPlan,
     right: SparkPlan,
@@ -51,125 +51,266 @@ case class BroadcastNestedLoopJoin(
   }

   private[this] def genResultProjection: InternalRow => InternalRow = {
-    UnsafeProjection.create(schema)
+    if (joinType == LeftSemi) {
+      UnsafeProjection.create(output, output)
+    } else {
+      // Always put the stream side on the left to simplify the implementation.
+      UnsafeProjection.create(output, streamed.output ++ broadcast.output)
+    }
   }

   override def outputPartitioning: Partitioning = streamed.outputPartitioning

   override def output: Seq[Attribute] = {
     joinType match {
+      case Inner =>
+        left.output ++ right.output
       case LeftOuter =>
         left.output ++ right.output.map(_.withNullability(true))
       case RightOuter =>
         left.output.map(_.withNullability(true)) ++ right.output
       case FullOuter =>
         left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
-      case Inner =>
-        // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case
-        left.output ++ right.output
-      case x => // TODO support the Left Semi Join
+      case LeftSemi =>
+        left.output
+      case x =>
         throw new IllegalArgumentException(
           s"BroadcastNestedLoopJoin should not take $x as the JoinType")
     }
   }

-  @transient private lazy val boundCondition =
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  @transient private lazy val boundCondition = {
+    if (condition.isDefined) {
+      newPredicate(condition.get, streamed.output ++ broadcast.output)
+    } else {
+      (r: InternalRow) => true
+    }
+  }

-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
+  /**
+   * The implementation for InnerJoin.
+   */
+  private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow

-    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+      streamedIter.flatMap { streamedRow =>
+        val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
+        if (condition.isDefined) {
+          joinedRows.filter(boundCondition)
+        } else {
+          joinedRows
+        }
+      }
+    }
+  }
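innerJoin above is the nested loop written as a flatMap over the stream side: every stream row is paired with every broadcast row, and the optional condition is applied once per pair (the condition.isDefined check skips the filter entirely when there is none). A plain-collections analogue, a sketch only: Spark's version runs per partition over InternalRows and reuses one mutable JoinedRow instead of allocating tuples.

object InnerLoopSketch {
  // Pair every stream element with every build element, keep those that
  // satisfy the condition.
  def innerJoin[L, R](stream: Iterator[L], build: Array[R])(
      condition: (L, R) => Boolean): Iterator[(L, R)] =
    stream.flatMap { l =>
      build.iterator.map(r => (l, r)).filter { case (a, b) => condition(a, b) }
    }

  def main(args: Array[String]): Unit = {
    val out = innerJoin(Iterator(1, 2, 3), Array(2, 3, 4))(_ == _)
    println(out.toList) // List((2,2), (3,3))
  }
}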
-    /** All rows that either match both-way, or rows from streamed joined with nulls. */
-    val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
-      val relation = broadcastedRelation.value
+  /**
+   * The implementation for these joins:
+   *
+   *   LeftOuter with BuildRight
+   *   RightOuter with BuildLeft
+   */
+  private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+      val nulls = new GenericMutableRow(broadcast.output.size)
+
+      // Returns an iterator to avoid copying the rows.
+      new Iterator[InternalRow] {
+        // current row from the stream side
+        private var streamRow: InternalRow = null
+        // whether a match has been found for the current stream row
+        private var foundMatch: Boolean = false
+        // the matched result row
+        private var resultRow: InternalRow = null
+        // the next index of buildRows to try
+        private var nextIndex: Int = 0

-      val matchedRows = new CompactBuffer[InternalRow]
-      val includedBroadcastTuples = new BitSet(relation.length)
+        private def findNextMatch(): Boolean = {
+          if (streamRow == null) {
+            if (!streamedIter.hasNext) {
+              return false
+            }
+            streamRow = streamedIter.next()
+            nextIndex = 0
+            foundMatch = false
+          }
+          while (nextIndex < buildRows.length) {
+            resultRow = joinedRow(streamRow, buildRows(nextIndex))
+            nextIndex += 1
+            if (boundCondition(resultRow)) {
+              foundMatch = true
+              return true
+            }
+          }
+          if (!foundMatch) {
+            resultRow = joinedRow(streamRow, nulls)
+            streamRow = null
+            true
+          } else {
+            resultRow = null
+            streamRow = null
+            findNextMatch()
+          }
+        }
+
+        override def hasNext(): Boolean = {
+          resultRow != null || findNextMatch()
+        }
+        override def next(): InternalRow = {
+          val r = resultRow
+          resultRow = null
+          r
+        }
+      }
+    }
+  }
+
+  /**
+   * The implementation for these joins:
+   *
+   *   LeftSemi with BuildRight
+   */
+  private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    assert(buildSide == BuildRight)
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
       val joinedRow = new JoinedRow
-      val leftNulls = new GenericMutableRow(left.output.size)
-      val rightNulls = new GenericMutableRow(right.output.size)
-      val resultProj = genResultProjection

+      if (condition.isDefined) {
+        streamedIter.filter(l =>
+          buildRows.exists(r => boundCondition(joinedRow(l, r)))
+        )
+      } else {
+        streamedIter.filter(_ => buildRows.nonEmpty)
+      }
+    }
+  }
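The outerJoin iterator above streams matches lazily: it walks buildRows for one stream row at a time, and only when the scan completes without a match does it emit the stream row padded with nulls. The sketch below shows the same contract on plain collections (illustrative only; it materializes the matches per stream row for brevity, where the patch keeps everything lazy and copy-free):

object OuterIteratorSketch {
  // For each stream element, emit every matching build element,
  // or exactly one (elem, None) if nothing matched.
  def leftOuter[L, R](stream: Iterator[L], build: IndexedSeq[R])(
      matches: (L, R) => Boolean): Iterator[(L, Option[R])] =
    stream.flatMap { l =>
      val matched = build.filter(r => matches(l, r))
      if (matched.isEmpty) Iterator((l, None))
      else matched.iterator.map(r => (l, Some(r)))
    }

  def main(args: Array[String]): Unit = {
    val out = leftOuter(Iterator(1, 2, 5), Vector(2, 2, 3))(_ == _)
    println(out.toList) // List((1,None), (2,Some(2)), (2,Some(2)), (5,None))
  }
}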
+  /**
+   * The implementation for these joins:
+   *
+   *   LeftOuter with BuildLeft
+   *   RightOuter with BuildRight
+   *   FullOuter
+   *   LeftSemi with BuildLeft
+   */
+  private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    /** All rows that either match both ways, or rows from the streamed side joined with nulls. */
+    val streamRdd = streamed.execute()
+
+    val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val matched = new BitSet(buildRows.length)
+      val joinedRow = new JoinedRow
       streamedIter.foreach { streamedRow =>
         var i = 0
-        var streamRowMatched = false
-
-        while (i < relation.length) {
-          val broadcastedRow = relation(i)
-          buildSide match {
-            case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
-              matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
-              streamRowMatched = true
-              includedBroadcastTuples.set(i)
-            case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
-              matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy()
-              streamRowMatched = true
-              includedBroadcastTuples.set(i)
-            case _ =>
+        while (i < buildRows.length) {
+          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+            matched.set(i)
           }
           i += 1
         }
+      }
+      Seq(matched).toIterator
+    }

-        (streamRowMatched, joinType, buildSide) match {
-          case (false, LeftOuter | FullOuter, BuildRight) =>
-            matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
-          case (false, RightOuter | FullOuter, BuildLeft) =>
-            matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
-          case _ =>
+    val matchedBroadcastRows = matchedBuildRows.fold(
+      new BitSet(relation.value.length)
+    )(_ | _)
+
+    if (joinType == LeftSemi) {
+      assert(buildSide == BuildLeft)
+      val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+      var i = 0
+      val rel = relation.value
+      while (i < rel.length) {
+        if (matchedBroadcastRows.get(i)) {
+          buf += rel(i).copy()
         }
+        i += 1
       }
-      Iterator((matchedRows, includedBroadcastTuples))
+      return sparkContext.makeRDD(buf.toSeq)
     }

-    val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
-    val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
-      new BitSet(broadcastedRelation.value.size)
-    )(_ | _)
+    val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+      val nulls = new GenericMutableRow(broadcast.output.size)

-    val leftNulls = new GenericMutableRow(left.output.size)
-    val rightNulls = new GenericMutableRow(right.output.size)
-    val resultProj = genResultProjection
+      streamedIter.flatMap { streamedRow =>
+        var i = 0
+        var foundMatch = false
+        val matchedRows = new CompactBuffer[InternalRow]
+
+        while (i < buildRows.length) {
+          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+            matchedRows += joinedRow.copy()
+            foundMatch = true
+          }
+          i += 1
+        }
+
+        if (!foundMatch && joinType == FullOuter) {
+          matchedRows += joinedRow(streamedRow, nulls).copy()
+        }
+        matchedRows.iterator
+      }
+    }

-    /** Rows from broadcasted joined with nulls. */
-    val broadcastRowsWithNulls: Seq[InternalRow] = {
+    val notMatchedBroadcastRows: Seq[InternalRow] = {
+      val nulls = new GenericMutableRow(streamed.output.size)
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
       var i = 0
-      val rel = broadcastedRelation.value
-      (joinType, buildSide) match {
-        case (RightOuter | FullOuter, BuildRight) =>
-          val joinedRow = new JoinedRow
-          joinedRow.withLeft(leftNulls)
-          while (i < rel.length) {
-            if (!allIncludedBroadcastTuples.get(i)) {
-              buf += resultProj(joinedRow.withRight(rel(i))).copy()
-            }
-            i += 1
-          }
-        case (LeftOuter | FullOuter, BuildLeft) =>
-          val joinedRow = new JoinedRow
-          joinedRow.withRight(rightNulls)
-          while (i < rel.length) {
-            if (!allIncludedBroadcastTuples.get(i)) {
-              buf += resultProj(joinedRow.withLeft(rel(i))).copy()
-            }
-            i += 1
-          }
-        case _ =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+      joinedRow.withLeft(nulls)
+      while (i < buildRows.length) {
+        if (!matchedBroadcastRows.get(i)) {
+          buf += joinedRow.withRight(buildRows(i)).copy()
+        }
+        i += 1
       }
       buf.toSeq
     }

-    // TODO: Breaks lineage.
     sparkContext.union(
-      matchesOrStreamedRowsWithNulls.flatMap(_._1),
-      sparkContext.makeRDD(broadcastRowsWithNulls)
-    ).map { row =>
-      // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here.
-      numOutputRows += 1
-      row
+      matchedStreamRows,
+      sparkContext.makeRDD(notMatchedBroadcastRows)
+    )
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+
+    val resultRdd = (joinType, buildSide) match {
+      case (Inner, _) =>
+        innerJoin(broadcastedRelation)
+      case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
+        outerJoin(broadcastedRelation)
+      case (LeftSemi, BuildRight) =>
+        leftSemiJoin(broadcastedRelation)
+      case _ =>
+        /**
+         * LeftOuter with BuildLeft
+         * RightOuter with BuildRight
+         * FullOuter
+         * LeftSemi with BuildLeft
+         */
+        defaultJoin(broadcastedRelation)
+    }
+
+    val numOutputRows = longMetric("numOutputRows")
+    resultRdd.mapPartitionsInternal { iter =>
+      val resultProj = genResultProjection
+      iter.map { r =>
+        numOutputRows += 1
+        resultProj(r)
+      }
     }
   }
 }
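defaultJoin splits the work into two passes over the streamed RDD: the first pass computes, per partition, a BitSet of broadcast-row indices that matched, and these are OR-folded into a single set on the driver; the second pass emits the joined output, while the unmatched broadcast rows are null-extended on the driver and unioned in via makeRDD. A sketch of the BitSet bookkeeping, with java.util.BitSet standing in for org.apache.spark.util.collection.BitSet and partitions simulated by nested Seqs:

import java.util.BitSet

object MatchedBitSetSketch {
  // One BitSet per simulated partition, OR-folded into a single set of
  // build-row indices that matched at least one stream row.
  def matchedIndices[L, R](partitions: Seq[Seq[L]], build: IndexedSeq[R])(
      condition: (L, R) => Boolean): BitSet = {
    val perPartition = partitions.map { part =>
      val bits = new BitSet(build.length)
      for (l <- part; i <- build.indices if condition(l, build(i))) bits.set(i)
      bits
    }
    perPartition.fold(new BitSet(build.length)) { (a, b) => a.or(b); a }
  }

  def main(args: Array[String]): Unit = {
    val bits = matchedIndices(Seq(Seq(1, 2), Seq(3)), Vector(2, 3, 4))(_ == _)
    // Build rows 0 (=2) and 1 (=3) matched; row 2 (=4) did not and would be
    // null-extended for FullOuter.
    println(bits) // {0, 1}
  }
}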
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 41e27ec46648ff1bc520118f1c85bd994b5ad5fe..3dab848e7b0330d4a5130bc35f27df5ee03efd89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -70,13 +70,14 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
       ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
       ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2",
+        classOf[BroadcastNestedLoopJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
-        classOf[CartesianProduct]),
+        classOf[BroadcastNestedLoopJoin]),
       ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
       ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
         classOf[CartesianProduct]),
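The table above pins down which physical operator each query is expected to plan to. Outside the suite, the same check can be made by collecting operators from the planned query (a hedged sketch: it assumes a live sqlContext with the suite's testData and testData2 tables registered, and mirrors in spirit what JoinSuite's helper does internally):

import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin

// A LEFT JOIN without join keys should now plan to BroadcastNestedLoopJoin.
val df = sqlContext.sql("SELECT * FROM testData LEFT JOIN testData2")
val bnlj = df.queryExecution.sparkPlan.collect { case j: BroadcastNestedLoopJoin => j }
assert(bnlj.nonEmpty, "expected a BroadcastNestedLoopJoin in the physical plan")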
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index b748229e402f597a6f8f340760d271c05c3e6e4f..7eb15249ebbd69ae2fbeb1716267021de591e642 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -146,6 +146,33 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         }
       }
     }
+
+    test(s"$testName using CartesianProduct") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          Filter(condition(), CartesianProduct(left, right)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildLeft, Inner, Some(condition())),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build right") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildRight, Inner, Some(condition())),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }

   testInnerJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 22fe8caff265ecce2608e29ee5dc98fc5beecbca..0d1c29fe574a6adb3840f98daaf76a65352a65ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -105,6 +105,24 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
         }
       }
     }
+
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build right") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }

 // --- Basic outer joins ------------------------------------------------------------------------
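These suites feed the same rows and condition through each operator and compare against a single expected answer, so the new BroadcastNestedLoopJoin paths must agree with the existing implementations row for row. For full outer joins in particular, the expected answer is every matched pair plus each unmatched row from either side extended with nulls; a collections-level sketch of that contract (illustrative only, not the suites' code):

object FullOuterSketch {
  def fullOuter[L, R](left: Seq[L], right: Seq[R])(
      matches: (L, R) => Boolean): Seq[(Option[L], Option[R])] = {
    val matchedRight = Array.fill(right.length)(false)
    val out = scala.collection.mutable.ArrayBuffer.empty[(Option[L], Option[R])]
    for (l <- left) {
      var found = false
      for (i <- right.indices if matches(l, right(i))) {
        out += ((Some(l), Some(right(i))))
        matchedRight(i) = true
        found = true
      }
      if (!found) out += ((Some(l), None)) // left row with nulls
    }
    for (i <- right.indices if !matchedRight(i)) {
      out += ((None, Some(right(i)))) // unmatched right rows with nulls
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    println(fullOuter(Seq(1, 2), Seq(2, 3))(_ == _))
    // List((Some(1),None), (Some(2),Some(2)), (None,Some(3)))
  }
}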
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 5c982885d652f38de4495846c31e9ff36fcc4ed4..355f916a975522253b0806c6b7c85c23d3cd641f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -103,6 +103,24 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
           sortAnswers = true)
       }
     }
+
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build right") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }

   testLeftSemiJoin(