Commit 07fa1910 authored by wangxiaojing, committed by Michael Armbrust

[SPARK-4570][SQL]add BroadcastLeftSemiJoinHash

JIRA issue: [SPARK-4570](https://issues.apache.org/jira/browse/SPARK-4570)
We add a `BroadcastLeftSemiJoinHash` operator to implement a broadcast join for `left semi join`.
In a left semi join:
if the size of the data on the right side is smaller than the user-settable threshold `AUTO_BROADCASTJOIN_THRESHOLD`,
the planner marks it as the `broadcast` relation and marks the other relation as the stream side. The broadcast table is sent to every executor involved in the join as an `org.apache.spark.broadcast.Broadcast` object, and the join is executed with `joins.BroadcastLeftSemiJoinHash`; otherwise it falls back to `joins.LeftSemiJoinHash`.
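For illustration, here is a minimal sketch of toggling that threshold from user code (hedged: `hiveContext` refers to the `HiveContext` created in the benchmark below, and the values are arbitrary examples; `-1` disables the broadcast path):
<pre><code>
// Sketch only: raise the threshold so a small right side qualifies for broadcast.
hiveContext.sql("SET spark.sql.autoBroadcastJoinThreshold=1000000000")
// ... run the semi join; the planner should pick joins.BroadcastLeftSemiJoinHash ...

// Disable broadcast joins entirely; the planner falls back to joins.LeftSemiJoinHash.
hiveContext.sql("SET spark.sql.autoBroadcastJoinThreshold=-1")
</code></pre>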

The benchmark below suggests the optimization makes the `left semi join` roughly 4x faster:
<pre><code>
Original:
left semi join : 9288 ms
Optimized:
left semi join : 1963 ms
</code></pre>
The micro benchmark loads `/data1/kv3.txt` into a normal Hive table.
Benchmark code:
<pre><code>
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

def benchmark(f: => Unit) = {
  val begin = System.currentTimeMillis()
  f
  val end = System.currentTimeMillis()
  end - begin
}

val sc = new SparkContext(
  new SparkConf()
    .setMaster("local")
    .setAppName(getClass.getSimpleName.stripSuffix("$")))
val hiveContext = new HiveContext(sc)
import hiveContext._

sql("drop table if exists left_table")
sql("drop table if exists right_table")
sql("""create table left_table (key int, value string)""")
sql("""load data local inpath "/data1/kv3.txt" into table left_table""")
sql("""create table right_table (key int, value string)""")
sql(
  """
    |from left_table
    |insert overwrite table right_table
    |select left_table.key, left_table.value
  """.stripMargin)

val leftSemiJoin = sql(
  """select a.key from left_table a
    |left semi join right_table b on a.key = b.key""".stripMargin)
val leftSemiJoinDuration = benchmark(leftSemiJoin.count())
println(s"left semi join : $leftSemiJoinDuration ms")
</code></pre>
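As a hedged aside, the physical operator the planner actually chose can be confirmed with the same pattern the test suites below use (this assumes the `leftSemiJoin` SchemaRDD from the benchmark above):
<pre><code>
// Sketch: collect BroadcastLeftSemiJoinHash nodes from the chosen physical plan.
import org.apache.spark.sql.execution.joins.BroadcastLeftSemiJoinHash

val usedBroadcast = leftSemiJoin.queryExecution.sparkPlan.collect {
  case j: BroadcastLeftSemiJoinHash => j
}.nonEmpty
println(s"broadcast semi join planned: $usedBroadcast")
</code></pre>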

Author: wangxiaojing <u9jing@gmail.com>

Closes #3442 from wangxiaojing/SPARK-4570 and squashes the following commits:

a4a43c9 [wangxiaojing] rebase
f103983 [wangxiaojing] change style
fbe4887 [wangxiaojing] change style
ff2e618 [wangxiaojing] add testsuite
1a8da2a [wangxiaojing] add BroadcastLeftSemiJoinHash
parent 8f29b7ca
...@@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  object LeftSemiJoin extends Strategy with PredicateHelper {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
          if sqlContext.autoBroadcastJoinThreshold > 0 &&
            right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
        val semiJoin = joins.BroadcastLeftSemiJoinHash(
          leftKeys, rightKeys, planLater(left), planLater(right))
        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
      // Find left semi joins where at least some predicates can be evaluated by matching join keys
      case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
        val semiJoin = joins.LeftSemiJoinHash(
...
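Note on the change above: the new broadcast case is deliberately placed before the existing `LeftSemiJoinHash` case, so the pattern match prefers the broadcast plan whenever the threshold condition holds and falls through to the shuffled hash semi join otherwise.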
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.joins

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}

/**
 * :: DeveloperApi ::
 * Build the right table's join keys into a HashSet, and iteratively go through the left
 * table, checking whether each row's join keys are in the hash set.
 */
@DeveloperApi
case class BroadcastLeftSemiJoinHash(
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    left: SparkPlan,
    right: SparkPlan) extends BinaryNode with HashJoin {

  override val buildSide = BuildRight

  override def output = left.output

  override def execute() = {
    // Collect the build-side (right) rows to the driver.
    val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
    val hashSet = new java.util.HashSet[Row]()
    var currentRow: Row = null

    // Create a hash set of the build-side join keys, skipping rows with null keys.
    while (buildIter.hasNext) {
      currentRow = buildIter.next()
      val rowKey = buildSideKeyGenerator(currentRow)
      if (!rowKey.anyNull) {
        val keyExists = hashSet.contains(rowKey)
        if (!keyExists) {
          hashSet.add(rowKey)
        }
      }
    }

    // Broadcast the key set and filter the streamed (left) side against it.
    val broadcastedRelation = sparkContext.broadcast(hashSet)

    streamedPlan.execute().mapPartitions { streamIter =>
      val joinKeys = streamSideKeyGenerator()
      streamIter.filter(current => {
        !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
      })
    }
  }
}
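The doc comment above summarizes the algorithm; as a self-contained illustration, here is the same hash-set semi-join idea on plain Scala collections (a sketch only: `streamSide`, `buildSide`, and the sample data are made up, not part of the patch):
<pre><code>
object SemiJoinSketch {
  def main(args: Array[String]): Unit = {
    val streamSide = Seq((1, "a"), (2, "b"), (3, "c"), (4, "d")) // left-table rows
    val buildSide = Seq(2, 4, 6) // right-side join keys, small enough to "broadcast"

    // Build phase: collect the right-side keys into a hash set.
    val keySet = new java.util.HashSet[Int]()
    buildSide.foreach(k => keySet.add(k))

    // Stream phase: keep each left row whose join key appears in the set.
    val result = streamSide.filter { case (key, _) => keySet.contains(key) }
    println(result) // List((2,b), (4,d))
  }
}
</code></pre>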
...@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
      case j: LeftSemiJoinBNL => j
      case j: CartesianProduct => j
      case j: BroadcastNestedLoopJoin => j
      case j: BroadcastLeftSemiJoinHash => j
    }

    assert(operators.size === 1)
...@@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
        """.stripMargin),
      (null, 10) :: Nil)
  }

  test("broadcasted left semi join operator selection") {
    clearCache()
    sql("CACHE TABLE testData")

    val tmp = autoBroadcastJoinThreshold
    sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
    Seq(
      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
        classOf[BroadcastLeftSemiJoinHash])
    ).foreach {
      case (query, joinClass) => assertJoin(query, joinClass)
    }

    sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
    Seq(
      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
    ).foreach {
      case (query, joinClass) => assertJoin(query, joinClass)
    }

    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
    sql("UNCACHE TABLE testData")
  }

  test("left semi join") {
    val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
    checkAnswer(rdd,
      (1, 1) ::
      (1, 2) ::
      (2, 1) ::
      (2, 2) ::
      (3, 1) ::
      (3, 2) :: Nil)
  }
}
...@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll
import scala.reflect.ClassTag

import org.apache.spark.sql.{SQLConf, QueryTest}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.execution._
...@@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
    )
  }

  test("auto converts to broadcast left semi join, by size estimate of a relation") {
    val leftSemiJoinQuery =
      """SELECT * FROM src a
        |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
    val answer = (86, "val_86") :: Nil

    var rdd = sql(leftSemiJoinQuery)

    // Assert src has a size smaller than the threshold.
    val sizes = rdd.queryExecution.analyzed.collect {
      case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass
        .isAssignableFrom(r.getClass) =>
        r.statistics.sizeInBytes
    }
    assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold
      && sizes(0) <= autoBroadcastJoinThreshold,
      s"query should contain two relations, each of which has size smaller than autoConvertSize")

    // Using `sparkPlan` because for relevant patterns in HashJoin to be
    // matched, other strategies need to be applied.
    var bhj = rdd.queryExecution.sparkPlan.collect {
      case j: BroadcastLeftSemiJoinHash => j
    }
    assert(bhj.size === 1,
      s"actual query plans do not contain broadcast join: ${rdd.queryExecution}")

    checkAnswer(rdd, answer) // check correctness of output

    TestHive.settings.synchronized {
      val tmp = autoBroadcastJoinThreshold

      sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
      rdd = sql(leftSemiJoinQuery)
      bhj = rdd.queryExecution.sparkPlan.collect {
        case j: BroadcastLeftSemiJoinHash => j
      }
      assert(bhj.isEmpty,
        "BroadcastLeftSemiJoinHash still planned even though it is switched off")

      val shj = rdd.queryExecution.sparkPlan.collect {
        case j: LeftSemiJoinHash => j
      }
      assert(shj.size === 1,
        "LeftSemiJoinHash should be planned when broadcast join is turned off")

      sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp")
    }
  }
}