diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 042c99db4dcffd23ee24973764de3c69703565d1..382654afacb896d55cddc8d57170a9351955348c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -108,12 +108,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashJoin(
-          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
@@ -124,13 +124,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case ExtractEquiJoinKeys(
          LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(
          RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        joins.BroadcastHashOuterJoin(
-          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
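Reviewer note: with this hunk, inner, left outer, and right outer broadcast equi-joins all plan to the single `BroadcastHashJoin` operator, with the broadcast side as the build side. A minimal way to see the new plans from a REPL (hypothetical data; assumes a `sqlContext` built from this branch):

```scala
import org.apache.spark.sql.functions.broadcast

val big = sqlContext.range(1 << 20).selectExpr("id AS key")
val dim = sqlContext.range(1 << 8).selectExpr("id AS k", "cast(id as string) AS v")

// Both plans should now show BroadcastHashJoin:
big.join(broadcast(dim), big("key") === dim("k")).explain()               // Inner, BuildRight
big.join(broadcast(dim), big("key") === dim("k"), "left_outer").explain() // LeftOuter, BuildRight
```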
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index f35efb5b24b1f2d8815a70032a50f64c2a962eb0..8626f54eb413cd2fcb0b849aa4695520fddff0bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
+import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 
 /**
  * An interface for those physical operators that support codegen.
@@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan {
 
   /** Prefix used in the current operator's variable names. */
   private def variablePrefix: String = this match {
     case _: TungstenAggregate => "agg"
-    case _: BroadcastHashJoin => "bhj"
+    case _: BroadcastHashJoin => "join"
     case _ => nodeName.toLowerCase
   }
 
@@ -391,9 +391,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
     var inputs = ArrayBuffer[SparkPlan]()
     val combined = plan.transform {
       // The build side can't be compiled together
-      case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+      case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
         b.copy(left = apply(left))
-      case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+      case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
         b.copy(right = apply(right))
       case p if !supportCodegen(p) =>
         val input = apply(p) // collapse them recursively
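Reviewer note: in the `CollapseCodegenStages` hunk, the build side produces the broadcast in its own job, so it is deliberately kept out of the collapsed stage and re-collapsed independently via the recursive `apply`; only the streamed side is fused. A self-contained toy sketch of that idiom (types and names made up, not Spark's):

```scala
// Toy plan tree; CodegenStage stands in for WholeStageCodegen.
sealed trait Plan
case class Scan(table: String) extends Plan
case class Join(buildRight: Boolean, left: Plan, right: Plan) extends Plan
case class CodegenStage(child: Plan) extends Plan

// Fuse the streamed side into the current stage, but restart collapsing from
// scratch on the build side, mirroring apply(left) / apply(right) above.
def collapse(plan: Plan): Plan = plan match {
  case j @ Join(true, _, build)  => CodegenStage(j.copy(right = collapse(build)))
  case j @ Join(false, build, _) => CodegenStage(j.copy(left = collapse(build)))
  case other => other
}
```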
take $x as the JoinType") + } } } - private var broadcastRelation: Broadcast[HashedRelation] = _ - // the term for hash relation - private var relationTerm: String = _ - override def upstream(): RDD[InternalRow] = { streamedPlan.asInstanceOf[CodegenSupport].upstream() } override def doProduce(ctx: CodegenContext): String = { + streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (joinType == Inner) { + codegenInner(ctx, input) + } else { + // LeftOuter and RightOuter + codegenOuter(ctx, input) + } + } + + /** + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ + private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation - broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcastRelation = Await.result(broadcastFuture, timeout) val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) - relationTerm = ctx.freshName("relation") + val relationTerm = ctx.freshName("relation") val clsName = broadcastRelation.value.getClass.getName ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = ($clsName) $broadcast.value(); | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) - - s""" - | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} - """.stripMargin + (broadcastRelation, relationTerm) } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // generate the key as UnsafeRow or Long + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + private def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - val (keyVal, anyNull) = if (canJoinKeyFitWithinLong) { + if (canJoinKeyFitWithinLong) { + // generate the join key as Long val expr = rewriteKeyExpr(streamedKeys).head val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) (ev, ev.isNull) } else { + // generate the join key as UnsafeRow val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) (ev, s"${ev.value}.anyNull()") } + } - // find the matches from HashedRelation - val matched = ctx.freshName("matched") - - // create variables for output + /** + * Generates the code for variable of build side. + */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = matched - val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => - BoundReference(i, a.dataType, a.nullable).gen(ctx) + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) + if (joinType == Inner) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val code = s""" + |boolean $isNull = true; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, isNull, value) + } } + } + + /** + * Generates the code for Inner join. 
+ */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) val resultVars = buildSide match { - case BuildLeft => buildColumns ++ input - case BuildRight => input ++ buildColumns + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars } - val numOutput = metricTerm(ctx, "numOutputRows") + val outputCode = if (condition.isDefined) { // filter the output via condition ctx.currentVars = resultVars val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) s""" - | ${ev.code} - | if (!${ev.isNull} && ${ev.value}) { - | $numOutput.add(1); - | ${consume(ctx, resultVars)} - | } + |${ev.code} + |if (!${ev.isNull} && ${ev.value}) { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + |} """.stripMargin } else { s""" @@ -184,36 +255,110 @@ case class BroadcastHashJoin( if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashedRelation - | UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyVal.value}); - | if ($matched != null) { - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - """.stripMargin + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched != null) { + | ${buildVars.map(_.code).mkString("\n")} + | $outputCode + |} + """.stripMargin + + } else { + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { + | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + | ${buildVars.map(_.code).mkString("\n")} + | $outputCode + | } + |} + """.stripMargin + } + } + + + /** + * Generates the code for left or right outer join. 
+ */ + private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + val numOutput = metricTerm(ctx, "numOutputRows") + + // filter the output via condition + val conditionPassed = ctx.freshName("conditionPassed") + val checkCondition = if (condition.isDefined) { + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + |boolean $conditionPassed = true; + |if ($matched != null) { + | ${ev.code} + | $conditionPassed = !${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $conditionPassed = true;" + } + + if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |${buildVars.map(_.code).mkString("\n")} + |${checkCondition.trim} + |if (!$conditionPassed) { + | // reset to null + | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")} + |} + |$numOutput.add(1); + |${consume(ctx, resultVars)} + """.stripMargin } else { val matches = ctx.freshName("matches") val bufferType = classOf[CompactBuffer[UnsafeRow]].getName val i = ctx.freshName("i") val size = ctx.freshName("size") + val found = ctx.freshName("found") s""" - | // generate join key - | ${keyVal.code} - | // find matches from HashRelation - | $bufferType $matches = ${anyNull} ? null : - | ($bufferType) $relationTerm.get(${keyVal.value}); - | if ($matches != null) { - | int $size = $matches.size(); - | for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); - | ${buildColumns.map(_.code).mkString("\n")} - | $outputCode - | } - | } - """.stripMargin + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |int $size = $matches != null ? $matches.size() : 0; + |boolean $found = false; + |// the last iteration of this loop is to emit an empty row if there is no matched rows. + |for (int $i = 0; $i <= $size; $i++) { + | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; + | ${buildVars.map(_.code).mkString("\n")} + | ${checkCondition.trim} + | if ($conditionPassed && ($i < $size || !$found)) { + | $found = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala deleted file mode 100644 index 5e8c8ca0436293276ae5bfcdbab59c80a0b9d71c..0000000000000000000000000000000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
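Reviewer note: the `(i < size || !found)` test in `codegenOuter` is what implements outer-join padding: the loop runs one extra iteration with `matched = null`, and that iteration emits a row only if no earlier match survived the condition. The same semantics in plain Scala, as an illustration only (generic names are made up):

```scala
// Reference semantics of the generated left-outer loop: emit every matching
// build row that passes the condition, or exactly one null-padded row if none do.
def leftOuter[K, L, R](
    stream: Iterator[(K, L)],
    build: Map[K, Seq[R]],
    cond: (L, R) => Boolean): Iterator[(L, Option[R])] = {
  stream.flatMap { case (key, l) =>
    val matched = build.getOrElse(key, Seq.empty).filter(r => cond(l, r))
    if (matched.isEmpty) Iterator((l, None)) // the "last iteration" of the codegen loop
    else matched.iterator.map(r => (l, Some(r)))
  }
}
```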
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
deleted file mode 100644
index 5e8c8ca0436293276ae5bfcdbab59c80a0b9d71c..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.concurrent._
-import scala.concurrent.duration._
-
-import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Performs a outer hash join for two child relations.  When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation.  This data is then placed in a Spark broadcast variable.  The streamed
- * relation is not shuffled.
- */
-case class BroadcastHashOuterJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with HashOuterJoin {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  val timeout = {
-    val timeoutValue = sqlContext.conf.broadcastTimeout
-    if (timeoutValue < 0) {
-      Duration.Inf
-    } else {
-      timeoutValue.seconds
-    }
-  }
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
-  override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
-  // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast
-  // value for the same query.
-  @transient
-  private lazy val broadcastFuture = {
-    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
-      // This will run in another thread. Set the execution id so that we can connect these jobs
-      // with the correct execution.
-      SQLExecution.withExecutionId(sparkContext, executionId) {
-        // Note that we use .execute().collect() because we don't want to convert data to Scala
-        // types
-        val input: Array[InternalRow] = buildPlan.execute().map { row =>
-          row.copy()
-        }.collect()
-        val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
-        sparkContext.broadcast(hashed)
-      }
-    }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
-  }
-
-  protected override def doPrepare(): Unit = {
-    broadcastFuture
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    val broadcastRelation = Await.result(broadcastFuture, timeout)
-
-    streamedPlan.execute().mapPartitions { streamedIter =>
-      val joinedRow = new JoinedRow()
-      val hashTable = broadcastRelation.value
-      val keyGenerator = streamedKeyGenerator
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
-
-      val resultProj = resultProjection
-      joinType match {
-        case LeftOuter =>
-          streamedIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
-          })
-
-        case RightOuter =>
-          streamedIter.flatMap(currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
-          })
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"BroadcastHashOuterJoin should not take $x as the JoinType")
-      }
-    }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 332a748d3bfc06f79bebddd3d16fde6a1d950918..2fe9c06cc95375427a99d9e3d518b5454eddaed6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,20 +21,38 @@ import java.util.NoSuchElementException
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegralType, LongType}
+import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
+import org.apache.spark.util.collection.CompactBuffer
 
 trait HashJoin {
   self: SparkPlan =>
 
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
+  val joinType: JoinType
   val buildSide: BuildSide
   val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case Inner =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
+    }
+  }
+
   protected lazy val (buildPlan, streamedPlan) = buildSide match {
     case BuildLeft => (left, right)
     case BuildRight => (right, left)
@@ -45,8 +63,6 @@ trait HashJoin {
     case BuildRight => (rightKeys, leftKeys)
   }
 
-  override def output: Seq[Attribute] = left.output ++ right.output
-
   /**
    * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
    *
@@ -67,8 +83,17 @@ trait HashJoin {
         width = dt.defaultSize
       } else {
         val bits = dt.defaultSize * 8
+        // The hashCode of a Long is (l >> 32) ^ l.toInt, so a long whose high 32 bits equal its
+        // low 32 bits hashes to 0. To avoid that worst case for keys built from two identical
+        // ints, we rotate the bits of the second one.
+        val rotated = if (e.dataType == IntegerType) {
+          // (e >>> 15) | (e << 17)
+          BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
+        } else {
+          e
+        }
         keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-          BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+          BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
         width -= bits
       }
       // TODO: support BooleanType, DateType and TimestampType
@@ -97,11 +122,13 @@ trait HashJoin {
     (r: InternalRow) => true
   }
 
+  protected def createResultProjection: (InternalRow) => InternalRow =
+    UnsafeProjection.create(self.schema)
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] =
-  {
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
     new Iterator[InternalRow] {
       private[this] var currentStreamedRow: InternalRow = _
       private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -109,8 +136,7 @@ trait HashJoin {
 
       // Mutable per row objects.
      private[this] val joinRow = new JoinedRow
-      private[this] val resultProjection: (InternalRow) => InternalRow =
-        UnsafeProjection.create(self.schema)
+      private[this] val resultProjection = createResultProjection
 
       private[this] val joinKeys = streamSideKeyGenerator
 
@@ -163,4 +189,73 @@ trait HashJoin {
       }
     }
   }
+
+  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
+
+  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
+  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
+
+  protected[this] def leftOuterIterator(
+      key: InternalRow,
+      joinedRow: JoinedRow,
+      rightIter: Iterable[InternalRow],
+      resultProjection: InternalRow => InternalRow,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val ret: Iterable[InternalRow] = {
+      if (!key.anyNull) {
+        val temp = if (rightIter != null) {
+          rightIter.collect {
+            case r if boundCondition(joinedRow.withRight(r)) => {
+              numOutputRows += 1
+              resultProjection(joinedRow).copy()
+            }
+          }
+        } else {
+          List.empty
+        }
+        if (temp.isEmpty) {
+          numOutputRows += 1
+          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+        } else {
+          temp
+        }
+      } else {
+        numOutputRows += 1
+        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+      }
+    }
+    ret.iterator
+  }
+
+  protected[this] def rightOuterIterator(
+      key: InternalRow,
+      leftIter: Iterable[InternalRow],
+      joinedRow: JoinedRow,
+      resultProjection: InternalRow => InternalRow,
+      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val ret: Iterable[InternalRow] = {
+      if (!key.anyNull) {
+        val temp = if (leftIter != null) {
+          leftIter.collect {
+            case l if boundCondition(joinedRow.withLeft(l)) => {
+              numOutputRows += 1
+              resultProjection(joinedRow).copy()
+            }
+          }
+        } else {
+          List.empty
+        }
+        if (temp.isEmpty) {
+          numOutputRows += 1
+          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+        } else {
+          temp
+        }
+      } else {
+        numOutputRows += 1
+        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
+      }
+    }
+    ret.iterator
+  }
 }
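Reviewer note: the int rotation added to `rewriteKeyExpr` guards against a degenerate `Long.hashCode`. The collision is easy to reproduce in plain Scala (illustration only):

```scala
// Long.hashCode is (l ^ (l >>> 32)).toInt, so packing the same int into both
// halves of the long hashes every such key to 0:
val i = 12345
val naive = (i.toLong << 32) | (i.toLong & 0xffffffffL)
assert(naive.hashCode == 0) // holds for any i packed this way

// Rotating the second int first, as the new keyExpr does ((e >>> 15) | (e << 17)),
// breaks the symmetry:
val rotated = (i.toLong << 32) | (Integer.rotateRight(i, 15).toLong & 0xffffffffL)
assert(rotated.hashCode != 0)
```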
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
deleted file mode 100644
index 9e614309de129e5d5203118685b67ffeff6d11c7..0000000000000000000000000000000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.util.collection.CompactBuffer
-
-
-trait HashOuterJoin {
-  self: SparkPlan =>
-
-  val leftKeys: Seq[Expression]
-  val rightKeys: Seq[Expression]
-  val joinType: JoinType
-  val condition: Option[Expression]
-  val left: SparkPlan
-  val right: SparkPlan
-
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
-    }
-  }
-
-  protected[this] lazy val (buildPlan, streamedPlan) = joinType match {
-    case RightOuter => (left, right)
-    case LeftOuter => (right, left)
-    case x =>
-      throw new IllegalArgumentException(
-        s"HashOuterJoin should not take $x as the JoinType")
-  }
-
-  protected[this] lazy val (buildKeys, streamedKeys) = joinType match {
-    case RightOuter => (leftKeys, rightKeys)
-    case LeftOuter => (rightKeys, leftKeys)
-    case x =>
-      throw new IllegalArgumentException(
-        s"HashOuterJoin should not take $x as the JoinType")
-  }
-
-  protected def buildKeyGenerator: Projection =
-    UnsafeProjection.create(buildKeys, buildPlan.output)
-
-  protected[this] def streamedKeyGenerator: Projection =
-    UnsafeProjection.create(streamedKeys, streamedPlan.output)
-
-  protected[this] def resultProjection: InternalRow => InternalRow =
-    UnsafeProjection.create(output, output)
-
-  @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
-  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
-  @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
-  @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-  } else {
-    (row: InternalRow) => true
-  }
-
-  // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
-  // iterator for performance purpose.
-
-  protected[this] def leftOuterIterator(
-      key: InternalRow,
-      joinedRow: JoinedRow,
-      rightIter: Iterable[InternalRow],
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (rightIter != null) {
-          rightIter.collect {
-            case r if boundCondition(joinedRow.withRight(r)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-        } else {
-          temp
-        }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
-      }
-    }
-    ret.iterator
-  }
-
-  protected[this] def rightOuterIterator(
-      key: InternalRow,
-      leftIter: Iterable[InternalRow],
-      joinedRow: JoinedRow,
-      resultProjection: InternalRow => InternalRow,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val ret: Iterable[InternalRow] = {
-      if (!key.anyNull) {
-        val temp = if (leftIter != null) {
-          leftIter.collect {
-            case l if boundCondition(joinedRow.withLeft(l)) => {
-              numOutputRows += 1
-              resultProjection(joinedRow).copy()
-            }
-          }
-        } else {
-          List.empty
-        }
-        if (temp.isEmpty) {
-          numOutputRows += 1
-          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-        } else {
-          temp
-        }
-      } else {
-        numOutputRows += 1
-        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
-      }
-    }
-    ret.iterator
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 9a3c262e9485d0542e5043b209430abccde54797..92ff7e73fad88cc11ca2972cdfc0bf0e28b0b3f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -46,7 +46,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val operators = physical.collect {
       case j: LeftSemiJoinHash => j
       case j: BroadcastHashJoin => j
-      case j: BroadcastHashOuterJoin => j
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
@@ -123,9 +122,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
         classOf[SortMergeOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[BroadcastHashOuterJoin]),
+        classOf[BroadcastHashJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[BroadcastHashOuterJoin])
+        classOf[BroadcastHashJoin])
     ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
     sql("UNCACHE TABLE testData")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 4a151179bf6f288118dab6a99faa321d729ee055..bcac660a35a6561977ee60bb1aaea8101e6e4f07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.HashMap
+
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
@@ -124,37 +126,65 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
   ignore("broadcast hash join") {
     val N = 100 << 20
-    val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+    val M = 1 << 16
+    val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
 
     runBenchmark("Join w long", N) {
-      sqlContext.range(N).join(dim, (col("id") % 60000) === col("k")).count()
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
     }
 
     /*
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w long:                        Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w long codegen=false          10174 / 10317         10.0         100.0       1.0X
-    Join w long codegen=true            1069 / 1107          98.0          10.2       9.5X
+    Join w long codegen=false           5744 / 5814          18.3          54.8       1.0X
+    Join w long codegen=true             735 /  853         142.7           7.0       7.8X
     */
 
-    val dim2 = broadcast(sqlContext.range(1 << 16)
+    val dim2 = broadcast(sqlContext.range(M)
       .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
 
     runBenchmark("Join w 2 ints", N) {
       sqlContext.range(N).join(dim2,
-        (col("id") bitwiseAND 60000).cast(IntegerType) === col("k1")
-          && (col("id") bitwiseAND 50000).cast(IntegerType) === col("k2")).count()
+        (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
+          && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
     }
 
     /**
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    BroadcastHashJoin:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w 2 ints:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    -------------------------------------------------------------------------------------------
-    Join w 2 ints codegen=false        11435 / 11530          9.0         111.1       1.0X
-    Join w 2 ints codegen=true          1265 / 1424          82.0          12.2       9.0X
+    Join w 2 ints codegen=false         7159 / 7224          14.6          68.3       1.0X
+    Join w 2 ints codegen=true          1135 / 1197          92.4          10.8       6.3X
     */
 
+    val dim3 = broadcast(sqlContext.range(M)
+      .selectExpr("id as k1", "id as k2", "cast(id as string) as v"))
+
+    runBenchmark("Join w 2 longs", N) {
+      sqlContext.range(N).join(dim3,
+        (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+        .count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Join w 2 longs:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 longs codegen=false        7877 / 8358          13.3          75.1       1.0X
+    Join w 2 longs codegen=true         3877 / 3937          27.0          37.0       2.0X
+    */
+
+    runBenchmark("outer join w long", N) {
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    outer join w long:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    outer join w long codegen=false    15280 / 16497          6.9         145.7       1.0X
+    outer join w long codegen=true       769 /  796         136.3           7.3      19.9X
+    */
   }
 
   ignore("rube") {
@@ -175,7 +205,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
   }
 
   ignore("hash and BytesToBytesMap") {
-    val N = 50 << 20
+    val N = 10 << 20
 
     val benchmark = new Benchmark("BytesToBytesMap", N)
 
@@ -227,6 +257,80 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       }
     }
 
+    benchmark.addCase("Java HashMap (Long)") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[Long, UnsafeRow]()
+      while (i < 65536) {
+        value.setInt(0, i)
+        map.put(i.toLong, value)
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        if (map.get(i % 100000) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
+    benchmark.addCase("Java HashMap (two ints) ") { iter =>
+      var i = 0
+      val valueBytes = new Array[Byte](16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[Long, UnsafeRow]()
+      while (i < 65536) {
+        value.setInt(0, i)
+        val key = (i.toLong << 32) + Integer.rotateRight(i, 15)
+        map.put(key, value)
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15)
+        if (map.get(key) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
+    benchmark.addCase("Java HashMap (UnsafeRow)") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(1)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      value.setInt(0, 555)
+      val map = new HashMap[UnsafeRow, UnsafeRow]()
+      while (i < 65536) {
+        key.setInt(0, i)
+        value.setInt(0, i)
+        map.put(key, value.copy())
+        i += 1
+      }
+      var s = 0
+      i = 0
+      while (i < N) {
+        key.setInt(0, i % 100000)
+        if (map.get(key) != null) {
+          s += 1
+        }
+        i += 1
+      }
+    }
+
     Seq("off", "on").foreach { heap =>
       benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
         val taskMemoryManager = new TaskMemoryManager(
@@ -268,6 +372,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     hash                                      651 /  678         80.0          12.5       1.0X
     fast hash                                 336 /  343        155.9           6.4       1.9X
     arrayEqual                                417 /  428        125.0           8.0       1.6X
+    Java HashMap (Long)                       145 /  168         72.2          13.8       0.8X
+    Java HashMap (two ints)                   157 /  164         66.8          15.0       0.8X
+    Java HashMap (UnsafeRow)                  538 /  573         19.5          51.3       0.2X
     BytesToBytesMap (off Heap)               2594 / 2664         20.2          49.5       0.2X
     BytesToBytesMap (on Heap)                2693 / 2989         19.5          51.4       0.2X
     */
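Reviewer note: the three new `Java HashMap` cases give an upper bound for what a JVM hash map achieves on this probe pattern, for comparison against `BytesToBytesMap`. These suites are `ignore`d by default; the harness usage follows the file's own pattern (sketch, with the case body elided):

```scala
import org.apache.spark.util.Benchmark

val N = 10 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
benchmark.addCase("Java HashMap (Long)") { _ =>
  // build the map once, then probe it N times, as in the added cases above
}
benchmark.run() // prints the Best/Avg table quoted in the comment blocks
```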
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index aee8e84db56e2ebde37c46a82cf8f08ef13b9c72..e25b5e0610ea12c4e362835e56e15290b9e8a9bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -73,7 +73,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("unsafe broadcast hash outer join updates peak execution memory") {
-    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash outer join", "left_outer")
   }
 
   test("unsafe broadcast left semi join updates peak execution memory") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 149f34dbd748fad1f6e796f3a1b933317f5b5c3d..e22a810a6b42fe205707b715bd0420a27a22fb60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -88,7 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 3d3e9a7b90928552e562e75be95768c5a9f3ecaf..f4b01fbad05859d57e8bfb4d4912a8ad6031e5d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -75,11 +75,16 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     }
 
     if (joinType != FullOuter) {
-      test(s"$testName using BroadcastHashOuterJoin") {
+      test(s"$testName using BroadcastHashJoin") {
+        val buildSide = joinType match {
+          case LeftOuter => BuildRight
+          case RightOuter => BuildLeft
+        }
         extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
+              BroadcastHashJoin(
+                leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right),
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f4bc9e501c21c77e471f049c86de2b30acdd828b..46bb699b780a9b57b67006e3af190531c326cd70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -209,20 +209,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
-  test("BroadcastHashOuterJoin metrics") {
+  test("BroadcastHashJoin(outer) metrics") {
     val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
     val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
     // Assume the execution plan is
-    // ... -> BroadcastHashOuterJoin(nodeId = 0)
+    // ... -> BroadcastHashJoin(nodeId = 0)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "left_outer")
     testSparkPlanMetrics(df, 2, Map(
-      0L -> ("BroadcastHashOuterJoin", Map(
+      0L -> ("BroadcastHashJoin", Map(
         "number of output rows" -> 5L)))
     )
 
     val df3 = df1.join(broadcast(df2), $"key" === $"key2", "right_outer")
     testSparkPlanMetrics(df3, 2, Map(
-      0L -> ("BroadcastHashOuterJoin", Map(
+      0L -> ("BroadcastHashJoin", Map(
        "number of output rows" -> 6L)))
     )
   }
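Reviewer note: end to end, the rename is easy to spot-check from a REPL; the data below mirrors the SQLMetricsSuite case above (assumes a `sqlContext` and its implicits):

```scala
import sqlContext.implicits._
import org.apache.spark.sql.functions.broadcast

val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")

val joined = df1.join(broadcast(df2), $"key" === $"key2", "left_outer")
joined.explain() // plan shows BroadcastHashJoin, no more BroadcastHashOuterJoin
assert(joined.count() == 5) // matches the "number of output rows" metric above
```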