Commit 39ccabac authored by Reynold Xin, committed by Michael Armbrust

[SPARK-3861][SQL] Avoid rebuilding hash tables for broadcast joins on each partition

Author: Reynold Xin <rxin@apache.org>

Closes #2727 from rxin/SPARK-3861-broadcast-hash-2 and squashes the following commits:

9c7b1a2 [Reynold Xin] Revert "Reuse CompactBuffer in UniqueKeyHashedRelation."
97626a1 [Reynold Xin] Reuse CompactBuffer in UniqueKeyHashedRelation.
7fcffb5 [Reynold Xin] Make UniqueKeyHashedRelation private[joins].
18eb214 [Reynold Xin] Merge branch 'SPARK-3861-broadcast-hash' into SPARK-3861-broadcast-hash-1
4b9d0c9 [Reynold Xin] UniqueKeyHashedRelation.get should return null if the value is null.
e0ebdd1 [Reynold Xin] Added a test case.
90b58c0 [Reynold Xin] [SPARK-3861] Avoid rebuilding hash tables on each partition
0c0082b [Reynold Xin] Fix line length.
cbc664c [Reynold Xin] Rename join -> joins package.
a070d44 [Reynold Xin] Fix line length in HashJoin
a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators.
parent 942847fd
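For context, the core of the change: BroadcastHashJoin previously broadcast the raw build-side rows, and every streamed partition rebuilt the hash table inside joinIterators; after this patch the driver builds a HashedRelation once, broadcasts that, and each partition only probes it (ShuffledHashJoin still builds per partition, but through the same HashedRelation path). Below is a minimal stand-alone sketch of the before/after pattern in plain Scala, with a Map standing in for the hashed relation and a Seq of Seqs standing in for the streamed partitions; the names are illustrative only, not Spark APIs.

object BuildOncePattern {
  type Row = (Int, String)

  // Before: every "partition" rebuilds the lookup table from the broadcast rows.
  def joinRebuildingPerPartition(build: Seq[Row], streamPartitions: Seq[Seq[Row]]): Seq[(Row, Row)] =
    streamPartitions.flatMap { partition =>
      val table = build.groupBy(_._1)          // rebuilt once per partition
      partition.flatMap(s => table.getOrElse(s._1, Nil).map(b => (s, b)))
    }

  // After: build the lookup table once, then only probe it inside each partition.
  def joinWithPrebuiltTable(build: Seq[Row], streamPartitions: Seq[Seq[Row]]): Seq[(Row, Row)] = {
    val table = build.groupBy(_._1)            // built once, shared by all partitions
    streamPartitions.flatMap { partition =>
      partition.flatMap(s => table.getOrElse(s._1, Nil).map(b => (s, b)))
    }
  }

  def main(args: Array[String]): Unit = {
    val build = Seq(1 -> "a", 2 -> "b")
    val stream = Seq(Seq(1 -> "x"), Seq(2 -> "y", 3 -> "z"))
    assert(joinRebuildingPerPartition(build, stream) == joinWithPrebuiltTable(build, stream))
  }
}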
@@ -22,7 +22,7 @@ import scala.concurrent.duration._
 import scala.concurrent.ExecutionContext.Implicits.global
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -49,14 +49,16 @@ case class BroadcastHashJoin(
   @transient
   private val broadcastFuture = future {
-    sparkContext.broadcast(buildPlan.executeCollect())
+    val input: Array[Row] = buildPlan.executeCollect()
+    val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+    sparkContext.broadcast(hashed)
   }
 
   override def execute() = {
     val broadcastRelation = Await.result(broadcastFuture, 5.minute)
 
     streamedPlan.execute().mapPartitions { streamedIter =>
-      joinIterators(broadcastRelation.value.iterator, streamedIter)
+      hashJoin(streamedIter, broadcastRelation.value)
     }
   }
 }
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.collection.CompactBuffer
@@ -43,34 +43,14 @@ trait HashJoin {
   override def output = left.output ++ right.output
 
-  @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
-  @transient protected lazy val streamSideKeyGenerator =
+  @transient protected lazy val buildSideKeyGenerator: Projection =
+    newProjection(buildKeys, buildPlan.output)
+
+  @transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
     newMutableProjection(streamedKeys, streamedPlan.output)
 
-  protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] =
+  protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] =
   {
-    // TODO: Use Spark's HashMap implementation.
-    val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
-    var currentRow: Row = null
-
-    // Create a mapping of buildKeys -> rows
-    while (buildIter.hasNext) {
-      currentRow = buildIter.next()
-      val rowKey = buildSideKeyGenerator(currentRow)
-      if (!rowKey.anyNull) {
-        val existingMatchList = hashTable.get(rowKey)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[Row]()
-          hashTable.put(rowKey, newMatchList)
-          newMatchList
-        } else {
-          existingMatchList
-        }
-        matchList += currentRow.copy()
-      }
-    }
-
     new Iterator[Row] {
       private[this] var currentStreamedRow: Row = _
       private[this] var currentHashMatches: CompactBuffer[Row] = _
@@ -107,7 +87,7 @@ trait HashJoin {
         while (currentHashMatches == null && streamIter.hasNext) {
           currentStreamedRow = streamIter.next()
           if (!joinKeys(currentStreamedRow).anyNull) {
-            currentHashMatches = hashTable.get(joinKeys.currentValue)
+            currentHashMatches = hashedRelation.get(joinKeys.currentValue)
           }
         }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins

import java.util.{HashMap => JavaHashMap}

import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
import org.apache.spark.util.collection.CompactBuffer


/**
 * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
 * object.
 */
private[joins] sealed trait HashedRelation {
  def get(key: Row): CompactBuffer[Row]
}


/**
 * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
 */
private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]])
  extends HashedRelation with Serializable {

  override def get(key: Row) = hashTable.get(key)
}


/**
 * A specialized [[HashedRelation]] that maps key into a single value. This implementation
 * assumes the key is unique.
 */
private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row])
  extends HashedRelation with Serializable {

  override def get(key: Row) = {
    val v = hashTable.get(key)
    if (v eq null) null else CompactBuffer(v)
  }

  def getValue(key: Row): Row = hashTable.get(key)
}


// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.


private[joins] object HashedRelation {

  def apply(
      input: Iterator[Row],
      keyGenerator: Projection,
      sizeEstimate: Int = 64): HashedRelation = {

    // TODO: Use Spark's HashMap implementation.
    val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate)
    var currentRow: Row = null

    // Whether the join key is unique. If the key is unique, we can convert the underlying
    // hash map into one specialized for this.
    var keyIsUnique = true

    // Create a mapping of buildKeys -> rows
    while (input.hasNext) {
      currentRow = input.next()
      val rowKey = keyGenerator(currentRow)
      if (!rowKey.anyNull) {
        val existingMatchList = hashTable.get(rowKey)
        val matchList = if (existingMatchList == null) {
          val newMatchList = new CompactBuffer[Row]()
          hashTable.put(rowKey, newMatchList)
          newMatchList
        } else {
          keyIsUnique = false
          existingMatchList
        }
        matchList += currentRow.copy()
      }
    }

    if (keyIsUnique) {
      val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size)
      val iter = hashTable.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        uniqHashTable.put(entry.getKey, entry.getValue()(0))
      }
      new UniqueKeyHashedRelation(uniqHashTable)
    } else {
      new GeneralHashedRelation(hashTable)
    }
  }
}
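HashedRelation.apply chooses the implementation while scanning the build side: the first time it has to append to an existing match list it marks the key as non-unique and ends up with a GeneralHashedRelation; otherwise it repacks the map into the single-value UniqueKeyHashedRelation. A stand-alone sketch of that detection pattern, using plain Scala collections rather than Spark's Row and Projection types (all names here are illustrative, not Spark code):

import scala.collection.mutable

object UniqueKeyDetection {
  // Returns Right(key -> single value) when every key is unique, otherwise
  // Left(key -> all values); this mirrors the keyIsUnique bookkeeping above.
  def buildIndex[K, V](input: Iterator[(K, V)]): Either[Map[K, Vector[V]], Map[K, V]] = {
    val table = mutable.HashMap.empty[K, Vector[V]]
    var keyIsUnique = true
    for ((k, v) <- input) {
      table.get(k) match {
        case Some(existing) =>
          keyIsUnique = false               // key seen before: relation is not unique-keyed
          table(k) = existing :+ v
        case None =>
          table(k) = Vector(v)
      }
    }
    if (keyIsUnique) Right(table.map { case (k, vs) => k -> vs.head }.toMap)
    else Left(table.toMap)
  }

  def main(args: Array[String]): Unit = {
    assert(buildIndex(Iterator(1 -> "a", 2 -> "b")).isRight)   // unique keys
    assert(buildIndex(Iterator(1 -> "a", 1 -> "b")).isLeft)    // duplicate key
  }
}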
@@ -42,8 +42,9 @@ case class ShuffledHashJoin(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   override def execute() = {
-    buildPlan.execute().zipPartitions(streamedPlan.execute()) {
-      (buildIter, streamIter) => joinIterators(buildIter, streamIter)
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+      hashJoin(streamIter, hashed)
     }
   }
 }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins

import org.scalatest.FunSuite

import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
import org.apache.spark.util.collection.CompactBuffer


class HashedRelationSuite extends FunSuite {

  // Key is simply the record itself
  private val keyProjection = new Projection {
    override def apply(row: Row): Row = row
  }

  test("GeneralHashedRelation") {
    val data = Array(Row(0), Row(1), Row(2), Row(2))
    val hashed = HashedRelation(data.iterator, keyProjection)
    assert(hashed.isInstanceOf[GeneralHashedRelation])

    assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
    assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
    assert(hashed.get(Row(10)) === null)

    val data2 = CompactBuffer[Row](data(2))
    data2 += data(2)
    assert(hashed.get(data(2)) == data2)
  }

  test("UniqueKeyHashedRelation") {
    val data = Array(Row(0), Row(1), Row(2))
    val hashed = HashedRelation(data.iterator, keyProjection)
    assert(hashed.isInstanceOf[UniqueKeyHashedRelation])

    assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
    assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
    assert(hashed.get(data(2)) == CompactBuffer[Row](data(2)))
    assert(hashed.get(Row(10)) === null)

    val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
    assert(uniqHashed.getValue(data(0)) == data(0))
    assert(uniqHashed.getValue(data(1)) == data(1))
    assert(uniqHashed.getValue(data(2)) == data(2))
    assert(uniqHashed.getValue(Row(10)) == null)
  }
}