diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 1f99fbedde411d8f9ecc738818702007a5144869..d3bcfad7c3de0c88188caa079fbfad978e580825 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = if (condition.isEmpty) {
-      HashSetBroadcastMode(rightKeys, right.output)
-    } else {
-      HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    }
+    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
     UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    if (condition.isEmpty) {
-      val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
-      }
-    } else {
-      val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        val hashedRelation = broadcastedRelation.value
-        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-        hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-      }
+    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
+    left.execute().mapPartitionsInternal { streamIter =>
+      val hashedRelation = broadcastedRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
+      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 1cb6a00617c5eb521ff7805deafa5361af83320f..3eed6e3e11131c593f63c694dd7615c5006fb45e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -43,24 +43,6 @@ trait HashSemiJoin {
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
 
-  protected def buildKeyHashSet(
-      buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
-  }
-
-  protected def hashSemiJoin(
-      streamIter: Iterator[InternalRow],
-      hashSet: java.util.Set[InternalRow],
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    streamIter.filter(current => {
-      val key = joinKeys(current)
-      val r = !key.anyNull && hashSet.contains(key)
-      if (r) numOutputRows += 1
-      r
-    })
-  }
-
   protected def hashSemiJoin(
       streamIter: Iterator[InternalRow],
       hashedRelation: HashedRelation,
@@ -70,44 +52,11 @@ trait HashSemiJoin {
     streamIter.filter { current =>
       val key = joinKeys(current)
       lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
         (row: InternalRow) => boundCondition(joinedRow(current, row))
-      }
+      })
       if (r) numOutputRows += 1
       r
     }
   }
 }
-
-private[execution] object HashSemiJoin {
-  def buildKeyHashSet(
-      keys: Seq[Expression],
-      attributes: Seq[Attribute],
-      rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
-    val hashSet = new java.util.HashSet[InternalRow]()
-
-    // Create a Hash set of buildKeys
-    val key = UnsafeProjection.create(keys, attributes)
-    while (rows.hasNext) {
-      val currentRow = rows.next()
-      val rowKey = key(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-    hashSet
-  }
-}
-
-/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
-private[execution] case class HashSetBroadcastMode(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute]) extends BroadcastMode {
-
-  override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index d8d3045ccf5c310516693515c2f0ae8644372048..242ed612327cc35324a193abbd716e7219c995cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")
 
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      if (condition.isEmpty) {
-        val hashSet = buildKeyHashSet(buildIter)
-        hashSemiJoin(streamIter, hashSet, numOutputRows)
-      } else {
-        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
-        hashSemiJoin(streamIter, hashRelation, numOutputRows)
-      }
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
 }