From 95e1ab223e87fc216f3256d404fe3be50d111a9d Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Thu, 18 Feb 2016 15:15:06 -0800
Subject: [PATCH] [SPARK-13237] [SQL] generated broadcast outer join

This PR adds codegen support for broadcast outer join.

To reduce code duplication, this PR merges HashJoin and HashOuterJoin
(and likewise BroadcastHashJoin and BroadcastHashOuterJoin).

Author: Davies Liu <davies@databricks.com>

Closes #11130 from davies/gen_out.
---
 .../spark/sql/execution/SparkStrategies.scala |  16 +-
 .../sql/execution/WholeStageCodegen.scala     |   8 +-
 .../execution/joins/BroadcastHashJoin.scala   | 253 ++++++++++++++----
 .../joins/BroadcastHashOuterJoin.scala        | 121 ---------
 .../spark/sql/execution/joins/HashJoin.scala  | 111 +++++++-
 .../sql/execution/joins/HashOuterJoin.scala   | 153 -----------
 .../org/apache/spark/sql/JoinSuite.scala      |   5 +-
 .../BenchmarkWholeStageCodegen.scala          | 131 ++++++++-
 .../execution/joins/BroadcastJoinSuite.scala  |   2 +-
 .../sql/execution/joins/InnerJoinSuite.scala  |   2 +-
 .../sql/execution/joins/OuterJoinSuite.scala  |   9 +-
 .../execution/metric/SQLMetricsSuite.scala    |   8 +-
 12 files changed, 448 insertions(+), 371 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 042c99db4d..382654afac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case ExtractEquiJoinKeys(
           LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(
           RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
 
       case
ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) if RowOrdering.isOrderable(leftKeys) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index f35efb5b24..8626f54eb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric} +import org.apache.spark.sql.execution.metric.LongSQLMetricValue /** * An interface for those physical operators that support codegen. @@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { case _: TungstenAggregate => "agg" - case _: BroadcastHashJoin => "bhj" + case _: BroadcastHashJoin => "join" case _ => nodeName.toLowerCase } @@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru var inputs = ArrayBuffer[SparkPlan]() val combined = plan.transform { // The build side can't be compiled together - case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) => + case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) => b.copy(left = apply(left)) - case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) => + case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => b.copy(right = apply(right)) case p if !supportCodegen(p) => val input = apply(p) // collapse them recursively diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 985e74011d..a64da22580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -24,8 +24,9 @@ import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -41,6 +42,7 @@ import org.apache.spark.util.collection.CompactBuffer case class BroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + joinType: JoinType, buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, @@ -105,75 +107,144 @@ case class BroadcastHashJoin( val broadcastRelation = Await.result(broadcastFuture, timeout) streamedPlan.execute().mapPartitions { streamedIter => - 
val hashedRelation = broadcastRelation.value - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize) - hashJoin(streamedIter, hashedRelation, numOutputRows) + val joinedRow = new JoinedRow() + val hashTable = broadcastRelation.value + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) + val keyGenerator = streamSideKeyGenerator + val resultProj = createResultProjection + + joinType match { + case Inner => + hashJoin(streamedIter, hashTable, numOutputRows) + + case LeftOuter => + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) + } + + case RightOuter => + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) + } + + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } } } - private var broadcastRelation: Broadcast[HashedRelation] = _ - // the term for hash relation - private var relationTerm: String = _ - override def upstream(): RDD[InternalRow] = { streamedPlan.asInstanceOf[CodegenSupport].upstream() } override def doProduce(ctx: CodegenContext): String = { + streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (joinType == Inner) { + codegenInner(ctx, input) + } else { + // LeftOuter and RightOuter + codegenOuter(ctx, input) + } + } + + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ + private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation - broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcastRelation = Await.result(broadcastFuture, timeout) val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - relationTerm = ctx.freshName("relation") + val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = ($clsName) $broadcast.value(); | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) - - s""" - | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} - """.stripMargin + (broadcastRelation, relationTerm) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // generate the key as UnsafeRow or Long + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. 
+ */ + private def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + if (canJoinKeyFitWithinLong) { + // generate the join key as Long val expr = rewriteKeyExpr(streamedKeys).head val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) (ev, ev.isNull) } else { + // generate the join key as UnsafeRow val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) (ev, s"${ev.value}.anyNull()") } + } - // find the matches from HashedRelation - val matched = ctx.freshName("matched") - - // create variables for output + /** + * Generates the code for variable of build side. + */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = matched - val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => - BoundReference(i, a.dataType, a.nullable).gen(ctx) + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) + if (joinType == Inner) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) + } } + } + + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) val resultVars = buildSide match { - case BuildLeft => buildColumns ++ input - case BuildRight => input ++ buildColumns + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars } - val numOutput = metricTerm(ctx, "numOutputRows") + val outputCode = if (condition.isDefined) { // filter the output via condition ctx.currentVars = resultVars val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) s""" - | ${ev.code} - | if (!${ev.isNull} && ${ev.value}) { - | $numOutput.add(1); - | ${consume(ctx, resultVars)} - | } + |${ev.code} + |if (!${ev.isNull} && ${ev.value}) { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} """.stripMargin } else { s""" @@ -184,36 +255,110 @@ case class BroadcastHashJoin( if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashedRelation - | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); - | if ($matched != null) { - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - """.stripMargin + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? 
null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched != null) { + | ${buildVars.map(_.code).mkString("\n")} + | $outputCode + |} + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + | ${buildVars.map(_.code).mkString("\n")} + | $outputCode + | } + |} + """.stripMargin + } + } + + + /** + * Generates the code for left or right outer join. + */ + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + val numOutput = metricTerm(ctx, "numOutputRows") + + // filter the output via condition + val conditionPassed = ctx.freshName("conditionPassed") + val checkCondition = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + |boolean $conditionPassed = true; + |if ($matched != null) { + | ${ev.code} + | $conditionPassed = !${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $conditionPassed = true;" + } + + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |${buildVars.map(_.code).mkString("\n")} + |${checkCondition.trim} + |if (!$conditionPassed) { + | // reset to null + | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")} + |} + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin } else { val matches = ctx.freshName("matches") val bufferType = classOf[CompactBuffer[UnsafeRow]].getName val i = ctx.freshName("i") val size = ctx.freshName("size") + val found = ctx.freshName("found") s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashRelation - | $bufferType $matches = ${anyNull} ? null : - | ($bufferType) $relationTerm.get(${keyVal.value}); - | if ($matches != null) { - | int $size = $matches.size(); - | for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - | } - """.stripMargin + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |int $size = $matches != null ? $matches.size() : 0; + |boolean $found = false; + |// the last iteration of this loop is to emit an empty row if there is no matched rows. + |for (int $i = 0; $i <= $size; $i++) { + | UnsafeRow $matched = $i < $size ? 
(UnsafeRow) $matches.apply($i) : null; + | ${buildVars.map(_.code).mkString("\n")} + | ${checkCondition.trim} + | if ($conditionPassed && ($i < $size || !$found)) { + | $found = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala deleted file mode 100644 index 5e8c8ca043..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.concurrent._ -import scala.concurrent.duration._ - -import org.apache.spark.{InternalAccumulator, TaskContext} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Performs a outer hash join for two child relations. When the output RDD of this operator is - * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed - * relation is not shuffled. - */ -case class BroadcastHashOuterJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) extends BinaryNode with HashOuterJoin { - - override private[sql] lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - val timeout = { - val timeoutValue = sqlContext.conf.broadcastTimeout - if (timeoutValue < 0) { - Duration.Inf - } else { - timeoutValue.seconds - } - } - - override def requiredChildDistribution: Seq[Distribution] = - UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - - // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value - // for the same query. - @transient - private lazy val broadcastFuture = { - // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here. 
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - Future { - // This will run in another thread. Set the execution id so that we can connect these jobs - // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { - // Note that we use .execute().collect() because we don't want to convert data to Scala - // types - val input: Array[InternalRow] = buildPlan.execute().map { row => - row.copy() - }.collect() - val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) - sparkContext.broadcast(hashed) - } - }(BroadcastHashJoin.broadcastHashJoinExecutionContext) - } - - protected override def doPrepare(): Unit = { - broadcastFuture - } - - override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - val broadcastRelation = Await.result(broadcastFuture, timeout) - - streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - val keyGenerator = streamedKeyGenerator - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) - - val resultProj = resultProjection - joinType match { - case LeftOuter => - streamedIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - streamedIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 332a748d3b..2fe9c06cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -21,20 +21,38 @@ import java.util.NoSuchElementException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.sql.types.{IntegralType, LongType} +import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} +import org.apache.spark.util.collection.CompactBuffer trait HashJoin { self: SparkPlan => val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] + val joinType: JoinType val buildSide: BuildSide val condition: Option[Expression] val left: SparkPlan val right: SparkPlan + override def output: Seq[Attribute] = { + joinType match { + case Inner => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") + } + } + protected lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) @@ -45,8 +63,6 @@ trait HashJoin { case BuildRight => 
(rightKeys, leftKeys) } - override def output: Seq[Attribute] = left.output ++ right.output - /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. * @@ -67,8 +83,17 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 + // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same + // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys + // with two same ints have hash code 0, we rotate the bits of second one. + val rotated = if (e.dataType == IntegerType) { + // (e >>> 15) | (e << 17) + BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) + } else { + e + } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -97,11 +122,13 @@ trait HashJoin { (r: InternalRow) => true } + protected def createResultProjection: (InternalRow) => InternalRow = + UnsafeProjection.create(self.schema) + protected def hashJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = - { + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ private[this] var currentHashMatches: Seq[InternalRow] = _ @@ -109,8 +136,7 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(self.schema) + private[this] val resultProjection = createResultProjection private[this] val joinKeys = streamSideKeyGenerator @@ -163,4 +189,73 @@ trait HashJoin { } } } + + @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() + + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) + + protected[this] def leftOuterIterator( + key: InternalRow, + joinedRow: JoinedRow, + rightIter: Iterable[InternalRow], + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { + if (!key.anyNull) { + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } + } + } else { + List.empty + } + if (temp.isEmpty) { + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } else { + temp + } + } else { + numOutputRows += 1 + resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + } + } + ret.iterator + } + + protected[this] def rightOuterIterator( + key: InternalRow, + leftIter: Iterable[InternalRow], + joinedRow: JoinedRow, + resultProjection: InternalRow => InternalRow, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + val ret: Iterable[InternalRow] = { + if (!key.anyNull) { + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => { + numOutputRows += 1 + resultProjection(joinedRow).copy() + } + } + } else { + List.empty + } + if (temp.isEmpty) { + numOutputRows += 1 + resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } else { + temp + } + } else { + numOutputRows += 1 + 
resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil + } + } + ret.iterator + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala deleted file mode 100644 index 9e614309de..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.util.collection.CompactBuffer - - -trait HashOuterJoin { - self: SparkPlan => - - val leftKeys: Seq[Expression] - val rightKeys: Seq[Expression] - val joinType: JoinType - val condition: Option[Expression] - val left: SparkPlan - val right: SparkPlan - - override def output: Seq[Attribute] = { - joinType match { - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - } - - protected[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"HashOuterJoin should not take $x as the JoinType") - } - - protected def buildKeyGenerator: Projection = - UnsafeProjection.create(buildKeys, buildPlan.output) - - protected[this] def streamedKeyGenerator: Projection = - UnsafeProjection.create(streamedKeys, streamedPlan.output) - - protected[this] def resultProjection: InternalRow => InternalRow = - UnsafeProjection.create(output, output) - - @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) - @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - @transient private[this] lazy val boundCondition = if 
(condition.isDefined) { - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - } else { - (row: InternalRow) => true - } - - // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. - - protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } - } - ret.iterator - } - - protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp - } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } - } - ret.iterator - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 9a3c262e94..92ff7e73fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext { val operators = physical.collect { case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j - case j: BroadcastHashOuterJoin => j case j: LeftSemiJoinBNL => j case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j @@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", - classOf[BroadcastHashOuterJoin]), + classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", - classOf[BroadcastHashOuterJoin]) + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 4a151179bf..bcac660a35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.HashMap + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, 
TaskMemoryManager} import org.apache.spark.sql.SQLContext @@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { ignore("broadcast hash join") { val N = 100 << 20 - val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) + val M = 1 << 16 + val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long", N) { - sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count() + sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 10174 / 10317 10.0 100.0 1.0X - Join w long codegen=true 1069 / 1107 98.0 10.2 9.5X + Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X + Join w long codegen=true 735 / 853 142.7 7.0 7.8X */ - val dim2 = broadcast(sqlContext.range(1 << 16) + val dim2 = broadcast(sqlContext.range(M) .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) runBenchmark("Join w 2 ints", N) { sqlContext.range(N).join(dim2, - (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1") - && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count() + (col("id") bitwiseAND M).cast(IntegerType) === col("k1") + && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count() } /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - BroadcastHashJoin: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 11435 / 11530 9.0 111.1 1.0X - Join w 2 ints codegen=true 1265 / 1424 82.0 12.2 9.0X + Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X + Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X */ + val dim3 = broadcast(sqlContext.range(M) + .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 longs", N) { + sqlContext.range(N).join(dim3, + (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + .count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w 2 longs codegen=false 7877 / 8358 13.3 75.1 1.0X + Join w 2 longs codegen=true 3877 / 3937 27.0 37.0 2.0X + */ + runBenchmark("outer join w long", N) { + sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X + outer join w long codegen=true 769 / 796 136.3 7.3 19.9X + */ } ignore("rube") { @@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 50 << 20 + val N = 10 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) @@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("Java HashMap (Long)") { iter => + var i = 
0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + map.put(i.toLong, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.get(i % 100000) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (two ints) ") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + val key = (i.toLong << 32) + Integer.rotateRight(i, 15) + map.put(key, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (UnsafeRow)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[UnsafeRow, UnsafeRow]() + while (i < 65536) { + key.setInt(0, i) + value.setInt(0, i) + map.put(key, value.copy()) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + key.setInt(0, i % 100000) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { hash 651 / 678 80.0 12.5 1.0X fast hash 336 / 343 155.9 6.4 1.9X arrayEqual 417 / 428 125.0 8.0 1.6X + Java HashMap (Long) 145 / 168 72.2 13.8 0.8X + Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X + Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index aee8e84db5..e25b5e0610 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { } test("unsafe broadcast hash outer join updates peak execution memory") { - testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer") } test("unsafe broadcast left semi join updates peak execution memory") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 149f34dbd7..e22a810a6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftPlan: SparkPlan, rightPlan: SparkPlan, 
side: BuildSide) = { - joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan) + joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan) } def makeSortMergeJoin( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 3d3e9a7b90..f4b01fbad0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } if (joinType != FullOuter) { - test(s"$testName using BroadcastHashOuterJoin") { + test(s"$testName using BroadcastHashJoin") { + val buildSide = joinType match { + case LeftOuter => BuildRight + case RightOuter => BuildLeft + } extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right), + BroadcastHashJoin( + leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index f4bc9e501c..46bb699b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("BroadcastHashOuterJoin metrics") { + test("BroadcastHashJoin(outer) metrics") { val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value") val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value") // Assume the execution plan is - // ... -> BroadcastHashOuterJoin(nodeId = 0) + // ... -> BroadcastHashJoin(nodeId = 0) val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer") testSparkPlanMetrics(df, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( + 0L -> ("BroadcastHashJoin", Map( "number of output rows" -> 5L))) ) val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer") testSparkPlanMetrics(df3, 2, Map( - 0L -> ("BroadcastHashOuterJoin", Map( + 0L -> ("BroadcastHashJoin", Map( "number of output rows" -> 6L))) ) } -- GitLab
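
Editorial note on the rewriteKeyExpr change in HashJoin.scala: when two int keys are packed into one Long, Java's Long hash ((l >>> 32) ^ l, truncated to an Int) is 0 whenever the two halves carry the same bits, so every key of the form (i, i) would land in the same bucket. The patch therefore rotates the second int by 15 bits ((e >>> 15) | (e << 17), i.e. rotateRight by 15) before packing it into the low half. The sketch below is plain Scala, not the Catalyst BitwiseOr/ShiftLeft expressions the patch builds; PackedKeySketch, packNaive and packRotated are hypothetical names used only for illustration.

```scala
object PackedKeySketch {
  // Java's Long hash: (l >>> 32) ^ l, truncated to an Int.
  def longHash(l: Long): Int = ((l >>> 32) ^ l).toInt

  // Naive packing: identical int keys put the same bits in both halves, so longHash is always 0.
  def packNaive(k1: Int, k2: Int): Long = (k1.toLong << 32) | (k2 & 0xFFFFFFFFL)

  // Packing with the rotation the patch applies to the second key before casting to Long.
  def packRotated(k1: Int, k2: Int): Long =
    (k1.toLong << 32) | (Integer.rotateRight(k2, 15) & 0xFFFFFFFFL)

  def main(args: Array[String]): Unit = {
    // Every (i, i) key collides into hash bucket 0 without the rotation:
    println((1 to 5).map(i => longHash(packNaive(i, i))))    // Vector(0, 0, 0, 0, 0)
    println((1 to 5).map(i => longHash(packRotated(i, i))))  // distinct, well-spread hashes
  }
}
```

The same trick appears in the new "Java HashMap (two ints)" benchmark case, which builds its map key as (i.toLong << 32) + Integer.rotateRight(i, 15).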
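Editorial note on codegenOuter in BroadcastHashJoin.scala: for the non-unique HashedRelation case the generated Java loops from 0 to size inclusive, and the extra final iteration emits the null-padded row when no match survived the join condition (the `found` flag prevents emitting it twice). The Scala below is a hand-written analogue of that loop shape, not the generated code; outerJoinOneRow is a hypothetical helper and Option[R] stands in for a possibly-null build-side row.

```scala
def outerJoinOneRow[R](matches: IndexedSeq[R], conditionPassed: R => Boolean): Seq[Option[R]] = {
  val size = if (matches == null) 0 else matches.size
  var found = false
  val out = Seq.newBuilder[Option[R]]
  var i = 0
  // note <=: the last iteration is the "no match" case
  while (i <= size) {
    val matched = if (i < size) Some(matches(i)) else None
    val passed = matched.forall(conditionPassed)   // the condition only filters real matches
    if (passed && (i < size || !found)) {
      found = true
      out += matched                               // None stands in for the null build-side row
    }
    i += 1
  }
  out.result()
}
```

For example, outerJoinOneRow(Vector(1, 2, 3), (_: Int) > 5) yields Seq(None) (one null-padded row), while outerJoinOneRow(Vector(1, 2, 3), (_: Int) > 1) yields Seq(Some(2), Some(3)).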
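Editorial note on the interpreted outer-join path that the patch moves from HashOuterJoin into HashJoin (leftOuterIterator / rightOuterIterator): for each stream row it emits every build-side match that passes the condition, and falls back to exactly one null-padded row when the join key contains a null or when nothing matches. The sketch below is a simplified analogue over plain Scala collections, not the Spark code; leftOuterRows and its parameters are hypothetical, and Option[R] again stands in for the null build-side row. rightOuterIterator is symmetric, padding the left side instead.

```scala
def leftOuterRows[L, R](
    streamRow: L,
    matches: Iterable[R],                     // null when the key is absent from the relation
    keyHasNull: Boolean,
    condition: (L, R) => Boolean): Iterator[(L, Option[R])] = {
  val ret: Iterable[(L, Option[R])] =
    if (!keyHasNull) {
      val passed = Option(matches).getOrElse(Nil).collect {
        case r if condition(streamRow, r) => (streamRow, Some(r))
      }
      // no surviving match: emit one row padded with nulls on the build side
      if (passed.isEmpty) (streamRow, None) :: Nil else passed
    } else {
      // a null join key never matches; still emit the stream row exactly once
      (streamRow, None) :: Nil
    }
  ret.iterator
}
```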