Commit 20637818 authored by Xiu Guo, committed by Reynold Xin

[SPARK-13422][SQL] Use HashedRelation instead of HashSet in Left Semi Joins

Use HashedRelation, which is a more optimized data structure, and reduce code complexity.

Author: Xiu Guo <xguo27@gmail.com>

Closes #11291 from xguo27/SPARK-13422.
parent 173aa949
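For context: a left semi join emits each left-side row at most once, when at least one right-side row matches its join key. The old code built a java.util.HashSet of the right side's keys, which could answer "does this key exist?" but not "which rows carry it?", so a separate HashedRelation path was needed whenever an extra join condition had to be evaluated. HashedRelation maps each key to the matching right rows and covers both cases. A minimal sketch of the semantics on plain Scala collections (Row, buildHashedRelation, and leftSemiJoin are illustrative stand-ins, not Spark's internal classes):

// Illustrative stand-in: a "hashed relation" is just a key -> rows map here.
case class Row(key: Int, value: String)

def buildHashedRelation(build: Iterator[Row]): Map[Int, Seq[Row]] =
  build.toSeq.groupBy(_.key)

// Left semi join: keep a stream row iff some build row with the same key
// satisfies the (optional) extra condition.
def leftSemiJoin(
    stream: Iterator[Row],
    relation: Map[Int, Seq[Row]],
    condition: Option[(Row, Row) => Boolean]): Iterator[Row] =
  stream.filter { s =>
    relation.get(s.key).exists { matches =>
      condition.isEmpty || matches.exists(b => condition.get(s, b))
    }
  }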
@@ -26,8 +26,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class BroadcastLeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -40,29 +40,18 @@ case class BroadcastLeftSemiJoinHash(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = if (condition.isEmpty) {
-      HashSetBroadcastMode(rightKeys, right.output)
-    } else {
-      HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    }
+    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
     UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
-    if (condition.isEmpty) {
-      val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
-      }
-    } else {
-      val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-      left.execute().mapPartitionsInternal { streamIter =>
-        val hashedRelation = broadcastedRelation.value
-        TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-        hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-      }
-    }
+    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
+    left.execute().mapPartitionsInternal { streamIter =>
+      val hashedRelation = broadcastedRelation.value
+      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
+      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
+    }
   }
 }
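With HashSetBroadcastMode gone, both the conditional and unconditional paths broadcast the same HashedRelation, so doExecute no longer branches: the relation is broadcast once and every partition of the left side filters against it. The shape of that pattern, sketched with the public RDD API rather than Spark's internal operators (spark, large, and small are hypothetical stand-ins for the session and the two plan children):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("broadcastSemi").getOrCreate()
val sc = spark.sparkContext

// Hypothetical left (streamed) and right (build) sides, keyed by join key.
val large = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
val small = sc.parallelize(Seq((2, "x"), (3, "y")))

// Build the lookup structure once and broadcast it, mirroring
// executeBroadcast[HashedRelation]().
val relation = sc.broadcast(small.collect().groupBy(_._1))

// Each partition filters its stream rows against the broadcast relation,
// mirroring mapPartitionsInternal { streamIter => hashSemiJoin(...) }.
val semiJoined = large.mapPartitions { iter =>
  iter.filter { case (k, _) => relation.value.contains(k) }
}
semiJoined.collect().foreach(println)  // (2,b) and (3,c) survive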
@@ -43,24 +43,6 @@ trait HashSemiJoin {
   @transient private lazy val boundCondition =
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
 
-  protected def buildKeyHashSet(
-      buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
-  }
-
-  protected def hashSemiJoin(
-    streamIter: Iterator[InternalRow],
-    hashSet: java.util.Set[InternalRow],
-    numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    streamIter.filter(current => {
-      val key = joinKeys(current)
-      val r = !key.anyNull && hashSet.contains(key)
-      if (r) numOutputRows += 1
-      r
-    })
-  }
-
   protected def hashSemiJoin(
     streamIter: Iterator[InternalRow],
     hashedRelation: HashedRelation,
@@ -70,44 +52,11 @@ trait HashSemiJoin {
     streamIter.filter { current =>
       val key = joinKeys(current)
       lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && rowBuffer.exists {
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
         (row: InternalRow) => boundCondition(joinedRow(current, row))
-      }
+      })
       if (r) numOutputRows += 1
       r
     }
   }
 }
-
-private[execution] object HashSemiJoin {
-  def buildKeyHashSet(
-      keys: Seq[Expression],
-      attributes: Seq[Attribute],
-      rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
-    val hashSet = new java.util.HashSet[InternalRow]()
-    // Create a Hash set of buildKeys
-    val key = UnsafeProjection.create(keys, attributes)
-    while (rows.hasNext) {
-      val currentRow = rows.next()
-      val rowKey = key(currentRow)
-      if (!rowKey.anyNull) {
-        val keyExists = hashSet.contains(rowKey)
-        if (!keyExists) {
-          hashSet.add(rowKey.copy())
-        }
-      }
-    }
-    hashSet
-  }
-}
-
-/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
-private[execution] case class HashSetBroadcastMode(
-    keys: Seq[Expression],
-    attributes: Seq[Attribute]) extends BroadcastMode {
-  override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
-    HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
-  }
-}
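The rewritten predicate folds the two former code paths into one expression: a non-null row buffer already proves a key match, and the boundCondition scan over the buffered rows only runs when an extra condition exists. In isolation the logic reads like this (a simplified paraphrase; the generic types and the nullable lookup are stand-ins for InternalRow and HashedRelation.get):

// Keep a stream row iff its key is non-null, the build side has rows for
// that key, and either there is no extra condition or some buffered row
// satisfies it.
def semiJoinMatch[K, R](
    key: K,
    keyHasNull: Boolean,
    lookup: K => Seq[R],            // stand-in for HashedRelation.get (may return null)
    condition: Option[R => Boolean] // stand-in for the compiled boundCondition
): Boolean = {
  lazy val rowBuffer = lookup(key)
  !keyHasNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists(condition.get))
}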
@@ -25,8 +25,8 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Build the right table's join keys into a HashedRelation, and iteratively go through the left
+ * table, to find if the join keys are in the HashedRelation.
  */
 case class LeftSemiJoinHash(
     leftKeys: Seq[Expression],
@@ -47,13 +47,8 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")
 
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      if (condition.isEmpty) {
-        val hashSet = buildKeyHashSet(buildIter)
-        hashSemiJoin(streamIter, hashSet, numOutputRows)
-      } else {
-        val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
-        hashSemiJoin(streamIter, hashRelation, numOutputRows)
-      }
+      val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+      hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }
 }
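LeftSemiJoinHash is the shuffle variant: both children are required to be hash-partitioned on the join keys, so each task can build a HashedRelation from just its own slice of the right side inside zipPartitions. A rough analogue with the public RDD API (self-contained; large and small are hypothetical inputs as in the earlier sketches):

import org.apache.spark.HashPartitioner
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("shuffleSemi").getOrCreate()
val sc = spark.sparkContext

val large = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
val small = sc.parallelize(Seq((2, "x"), (3, "y")))

// Co-partition both sides on the key so matching keys land in the same task,
// analogous to the plan's required ClusteredDistribution on the join keys.
val part = new HashPartitioner(4)
val largeP = large.partitionBy(part)
val smallP = small.partitionBy(part)

// zipPartitions pairs co-located partitions; each task builds its lookup
// from its slice of the build side, mirroring HashedRelation(buildIter, rightKeyGenerator).
val result = smallP.zipPartitions(largeP) { (buildIter, streamIter) =>
  val rel = buildIter.toSeq.groupBy(_._1)
  streamIter.filter { case (k, _) => rel.contains(k) }
}
result.collect().foreach(println)  // (2,b) and (3,c) survive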