diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d82d19185be045b7c1d695bf437df1e0e9825323..e8ee64756d5d0dca4800332b4ca60bef42638cd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -27,6 +27,8 @@ abstract class BaseMutableProjection extends MutableProjection
 
 /**
  * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
  * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ * It exposes a `target` method, which is used to set the row that will be updated.
+ * The internally created [[MutableRow]] object is used only when `target` is not used.
  */
 object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index ea09e029da901b956e4e56350b3cfb8e90030dd3..9873630937d313deb5882c79c0175e1498764057 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.types._
 
 
 /**
- * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
- * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update
+ * itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]].
  */
 object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
new file mode 100644
index 0000000000000000000000000000000000000000..52dcb9e43c4e8906ee32c6b83d62370040e6f546
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
+
+/**
+ * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of
+ * `buildSide`. The actual work of this node is defined in [[HashJoinNode]].
+ */
+case class BinaryHashJoinNode(
+    conf: SQLConf,
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    buildSide: BuildSide,
+    left: LocalNode,
+    right: LocalNode)
+  extends BinaryLocalNode(conf) with HashJoinNode {
+
+  protected override val (streamedNode, streamedKeys) = buildSide match {
+    case BuildLeft => (right, rightKeys)
+    case BuildRight => (left, leftKeys)
+  }
+
+  private val (buildNode, buildKeys) = buildSide match {
+    case BuildLeft => (left, leftKeys)
+    case BuildRight => (right, rightKeys)
+  }
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  private def buildSideKeyGenerator: Projection = {
+    // We expect the data types of buildKeys and streamedKeys to be the same.
+    assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType))
+    if (isUnsafeMode) {
+      UnsafeProjection.create(buildKeys, buildNode.output)
+    } else {
+      newMutableProjection(buildKeys, buildNode.output)()
+    }
+  }
+
+  protected override def doOpen(): Unit = {
+    buildNode.open()
+    val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+    // We have built the HashedRelation, so buildNode can be closed now.
+    buildNode.close()
+
+    streamedNode.open()
+    // Set the HashedRelation used by the HashJoinNode.
+    withHashedRelation(hashedRelation)
+  }
+
+  override def close(): Unit = {
+    // Note that we do not need to close buildNode here: it was already
+    // closed in doOpen().
+    streamedNode.close()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cd1c86516ec5f73cfb4da20ebe88efc4851da91e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation}
+
+/**
+ * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast
+ * [[HashedRelation]]. The actual work of this node is defined in [[HashJoinNode]].
+ */
+case class BroadcastHashJoinNode(
+    conf: SQLConf,
+    streamedKeys: Seq[Expression],
+    streamedNode: LocalNode,
+    buildSide: BuildSide,
+    buildOutput: Seq[Attribute],
+    hashedRelation: Broadcast[HashedRelation])
+  extends UnaryLocalNode(conf) with HashJoinNode {
+
+  override val child = streamedNode
+
+  // Because we do not pass in the build-side node itself, we use its output
+  // (`buildOutput`) to create the inputSet properly.
+  override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)
+
+  override def output: Seq[Attribute] = buildSide match {
+    case BuildRight => streamedNode.output ++ buildOutput
+    case BuildLeft => buildOutput ++ streamedNode.output
+  }
+
+  protected override def doOpen(): Unit = {
+    streamedNode.open()
+    // Set the HashedRelation used by the HashJoinNode.
+    withHashedRelation(hashedRelation.value)
+  }
+
+  override def close(): Unit = {
+    streamedNode.close()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
index e7b24e3fca2b44dfd61ae084bb2edf8caf280813..b1dc719ca85087464a808bba2058b91ea6440f14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -17,27 +17,23 @@
 
 package org.apache.spark.sql.execution.local
 
-import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
+ * An abstract node for sharing common functionality among different implementations of
+ * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]].
+ *
  * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
  */
-case class HashJoinNode(
-    conf: SQLConf,
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    buildSide: BuildSide,
-    left: LocalNode,
-    right: LocalNode) extends BinaryLocalNode(conf) {
-
-  private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
-    case BuildLeft => (left, leftKeys, right, rightKeys)
-    case BuildRight => (right, rightKeys, left, leftKeys)
-  }
+trait HashJoinNode {
+
+  self: LocalNode =>
+
+  protected def streamedKeys: Seq[Expression]
+  protected def streamedNode: LocalNode
+  protected def buildSide: BuildSide
 
   private[this] var currentStreamedRow: InternalRow = _
   private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -49,23 +45,14 @@ case class HashJoinNode(
   private[this] var hashed: HashedRelation = _
   private[this] var joinKeys: Projection = _
 
-  override def output: Seq[Attribute] = left.output ++ right.output
-
-  private[this] def isUnsafeMode: Boolean = {
-    (codegenEnabled && unsafeEnabled
-      && UnsafeProjection.canSupport(buildKeys)
-      && UnsafeProjection.canSupport(schema))
-  }
-
-  private[this] def buildSideKeyGenerator: Projection = {
-    if (isUnsafeMode) {
-      UnsafeProjection.create(buildKeys, buildNode.output)
-    } else {
-      newMutableProjection(buildKeys, buildNode.output)()
-    }
+  protected def isUnsafeMode: Boolean = {
+    (codegenEnabled &&
+      unsafeEnabled &&
+      UnsafeProjection.canSupport(schema) &&
+      UnsafeProjection.canSupport(streamedKeys))
   }
 
-  private[this] def streamSideKeyGenerator: Projection = {
+  private def streamSideKeyGenerator: Projection = {
     if (isUnsafeMode) {
       UnsafeProjection.create(streamedKeys, streamedNode.output)
     } else {
@@ -73,10 +60,21 @@ case class HashJoinNode(
     }
   }
 
+  /**
+   * Sets the HashedRelation used by this node. This method needs to be called
+   * before the first `next` gets called.
+   */
+  protected def withHashedRelation(hashedRelation: HashedRelation): Unit = {
+    hashed = hashedRelation
+  }
+
+  /**
+   * Custom open implementation to be overridden by subclasses.
+   */
+  protected def doOpen(): Unit
+
   override def open(): Unit = {
-    buildNode.open()
-    hashed = HashedRelation(buildNode, buildSideKeyGenerator)
-    streamedNode.open()
+    doOpen()
     joinRow = new JoinedRow
     resultProjection = {
       if (isUnsafeMode) {
@@ -128,9 +126,4 @@ case class HashJoinNode(
     }
     resultProjection(ret)
   }
-
-  override def close(): Unit = {
-    left.close()
-    right.close()
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index 5c1bdb088eeedd93768203f5060bb0185cce7382..8c2e78b2a9db7b8c4ae55dc3532ed7ee05684220 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution.local
 
+import org.mockito.Mockito.{mock, when}
+
+import org.apache.spark.broadcast.TorrentBroadcast
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
-
+import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression}
+import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
 
 class HashJoinNodeSuite extends LocalNodeTest {
 
@@ -33,6 +36,35 @@ class HashJoinNodeSuite extends LocalNodeTest {
     }
   }
 
+  /**
+   * Builds a [[HashedRelation]] based on a resolved `buildKeys`
+   * and a resolved `buildNode`.
+   */
+  private def buildHashedRelation(
+      conf: SQLConf,
+      buildKeys: Seq[Expression],
+      buildNode: LocalNode): HashedRelation = {
+
+    val isUnsafeMode =
+      conf.codegenEnabled &&
+        conf.unsafeEnabled &&
+        UnsafeProjection.canSupport(buildKeys)
+
+    val buildSideKeyGenerator =
+      if (isUnsafeMode) {
+        UnsafeProjection.create(buildKeys, buildNode.output)
+      } else {
+        new InterpretedMutableProjection(buildKeys, buildNode.output)
+      }
+
+    buildNode.prepare()
+    buildNode.open()
+    val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+    buildNode.close()
+
+    hashedRelation
+  }
+
   /**
    * Test inner hash join with varying degrees of matches.
    */
@@ -51,20 +83,51 @@ class HashJoinNodeSuite extends LocalNodeTest {
     val rightInputMap = rightInput.toMap
     val leftNode = new DummyNode(joinNameAttributes, leftInput)
     val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
-    val makeNode = (node1: LocalNode, node2: LocalNode) => {
-      resolveExpressions(new HashJoinNode(
-        conf, Seq('id1), Seq('id2), buildSide, node1, node2))
+    val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => {
+      val binaryHashJoinNode =
+        BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2)
+      resolveExpressions(binaryHashJoinNode)
+    }
+    val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => {
+      val leftKeys = Seq('id1.attr)
+      val rightKeys = Seq('id2.attr)
+      // Figure out the build side and stream side.
+      val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
+        case BuildLeft => (node1, leftKeys, node2, rightKeys)
+        case BuildRight => (node2, rightKeys, node1, leftKeys)
+      }
+      // Resolve the expressions of the build side and then create a HashedRelation.
+      val resolvedBuildNode = resolveExpressions(buildNode)
+      val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
+      val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)
+      val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]])
+      when(broadcastHashedRelation.value).thenReturn(hashedRelation)
+
+      val hashJoinNode =
+        BroadcastHashJoinNode(
+          conf,
+          streamedKeys,
+          streamedNode,
+          buildSide,
+          resolvedBuildNode.output,
+          broadcastHashedRelation)
+      resolveExpressions(hashJoinNode)
     }
-    val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
-    val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
     val expectedOutput = leftInput
       .filter { case (k, _) => rightInputMap.contains(k) }
       .map { case (k, v) => (k, v, k, rightInputMap(k)) }
-    val actualOutput = hashJoinNode.collect().map { row =>
-      // (id, name, id, nickname)
-      (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+
+    Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode =>
+      val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+      val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
+      val actualOutput = hashJoinNode.collect().map { row =>
+        // (id, name, id, nickname)
+        (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+      }
+      assert(actualOutput === expectedOutput)
     }
-    assert(actualOutput === expectedOutput)
   }
 
   test(s"$testNamePrefix: empty") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 098050bcd22368a3a05a1fbf1d10ff745cb70111..615c4170936129a018df35337685f737d0625913 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
@@ -67,4 +67,22 @@ class LocalNodeTest extends SparkFunSuite {
     }
   }
 
+  /**
+   * Resolves all expressions in `expressions` based on the `output` of `localNode`.
+   * It assumes that all expressions in `localNode` are resolved.
+   */
+  protected def resolveExpressions(
+      expressions: Seq[Expression],
+      localNode: LocalNode): Seq[Expression] = {
+    require(localNode.expressions.forall(_.resolved))
+    val inputMap = localNode.output.map { a => (a.name, a) }.toMap
+    expressions.map { expression =>
+      expression.transformUp {
+        case UnresolvedAttribute(Seq(u)) =>
+          inputMap.getOrElse(u,
+            sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+      }
+    }
+  }
+
 }
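
To make the shape of the refactoring easier to follow, here is a minimal, self-contained Scala sketch of the contract that the patch gives `HashJoinNode`: shared probe-side logic lives in a mixin trait with a self-type, and each concrete node supplies its hashed relation through `doOpen()`/`withHashedRelation()`. All names below (`Node`, `HashJoinLike`, the two join classes, and the `Map` standing in for `HashedRelation`) are simplified stand-ins invented for illustration, not Spark APIs.

// Minimal sketch (not Spark code): HashJoinLike mirrors the HashJoinNode trait,
// LocalBuildJoin mirrors BinaryHashJoinNode, PrebuiltJoin mirrors BroadcastHashJoinNode.

trait Node {
  def open(): Unit
  def close(): Unit
}

trait HashJoinLike { self: Node =>

  // The build-side relation; a plain Map stands in for HashedRelation.
  private var hashed: Map[Int, String] = _

  // Subclasses must call this from doOpen(), before the first probe.
  protected def withHashedRelation(relation: Map[Int, String]): Unit = {
    hashed = relation
  }

  // Custom open implementation to be overridden by subclasses.
  protected def doOpen(): Unit

  // Shared logic: open() defers relation acquisition to doOpen().
  def open(): Unit = {
    doOpen()
    require(hashed != null, "doOpen() must call withHashedRelation")
  }

  // Probe-side lookup shared by all implementations.
  def probe(key: Int): Option[String] = hashed.get(key)
}

// Builds the relation itself from its build-side rows, like BinaryHashJoinNode.
class LocalBuildJoin(buildRows: Seq[(Int, String)]) extends Node with HashJoinLike {
  protected def doOpen(): Unit = withHashedRelation(buildRows.toMap)
  def close(): Unit = ()
}

// Receives an already-built relation, like BroadcastHashJoinNode.
class PrebuiltJoin(prebuilt: Map[Int, String]) extends Node with HashJoinLike {
  protected def doOpen(): Unit = withHashedRelation(prebuilt)
  def close(): Unit = ()
}

object HashJoinSketch extends App {
  val local = new LocalBuildJoin(Seq(1 -> "a", 2 -> "b"))
  local.open()
  assert(local.probe(2) == Some("b"))

  val prebuilt = new PrebuiltJoin(Map(1 -> "a"))
  prebuilt.open()
  assert(prebuilt.probe(1) == Some("a"))
}

The point of the split is visible in the two concrete classes: one builds the relation from its own build-side input (as BinaryHashJoinNode does from buildNode), while the other receives a relation built elsewhere (as BroadcastHashJoinNode does from hashedRelation.value), and both reuse the same open() and probe-side logic unchanged.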