Skip to content
Snippets Groups Projects
Commit 1d7bcc88, authored by Cheng Lian and committed by Michael Armbrust
Browse files

[SQL] Fixes caching related JoinSuite failure

PR #2860 refines in-memory table statistics and enables broader broadcast hash join optimization for in-memory tables. As a result, `JoinSuite` fails when another test suite caches the test table `testData` and is executed before `JoinSuite`, because the expected `ShuffledHashJoin`s are then optimized to `BroadcastHashJoin`s according to the collected in-memory table statistics.

This PR fixes the issue by clearing the cache before testing join operator selection. A separate test case is also added that caches `testData` and verifies that `BroadcastHashJoin` is selected for the same equi-join queries.

Author: Cheng Lian <lian@databricks.com>

Closes #2960 from liancheng/fix-join-suite and squashes the following commits:

715b2de [Cheng Lian] Fixes caching related JoinSuite failure
parent dea302dd
No related branches found
No related tags found
No related merge requests found
...@@ -19,17 +19,13 @@ package org.apache.spark.sql ...@@ -19,17 +19,13 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.TestData._ import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi} import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext._
class JoinSuite extends QueryTest with BeforeAndAfterEach { class JoinSuite extends QueryTest with BeforeAndAfterEach {
// Ensures tables are loaded. // Ensures tables are loaded.
TestData TestData
...@@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ...@@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
assert(planned.size === 1) assert(planned.size === 1)
} }
test("join operator selection") { def assertJoin(sqlString: String, c: Class[_]): Any = {
def assertJoin(sqlString: String, c: Class[_]): Any = { val rdd = sql(sqlString)
val rdd = sql(sqlString) val physical = rdd.queryExecution.sparkPlan
val physical = rdd.queryExecution.sparkPlan val operators = physical.collect {
val operators = physical.collect { case j: ShuffledHashJoin => j
case j: ShuffledHashJoin => j case j: HashOuterJoin => j
case j: HashOuterJoin => j case j: LeftSemiJoinHash => j
case j: LeftSemiJoinHash => j case j: BroadcastHashJoin => j
case j: BroadcastHashJoin => j case j: LeftSemiJoinBNL => j
case j: LeftSemiJoinBNL => j case j: CartesianProduct => j
case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j
case j: BroadcastNestedLoopJoin => j }
}
assert(operators.size === 1)
assert(operators.size === 1) if (operators(0).getClass() != c) {
if (operators(0).getClass() != c) { fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
}
} }
}
val cases1 = Seq( test("join operator selection") {
("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]), clearCache()
("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
("SELECT * FROM testData join testData2", classOf[CartesianProduct]), Seq(
("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
("SELECT * FROM testData left join testData2", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
("SELECT * FROM testData right join testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]), ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]), ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]), ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a where key=2", ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[HashOuterJoin]), classOf[HashOuterJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key=2", ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[HashOuterJoin]), classOf[HashOuterJoin]),
("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]), ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]), // TODO add BroadcastNestedLoopJoin
("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]), ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]) }
// TODO add BroadcastNestedLoopJoin
) test("broadcasted hash join operator selection") {
cases1.foreach { c => assertJoin(c._1, c._2) } clearCache()
sql("CACHE TABLE testData")
Seq(
("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
sql("UNCACHE TABLE testData")
} }
test("multiple-key equi-join is hash-join") { test("multiple-key equi-join is hash-join") {
...@@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ...@@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "d") :: (4, "D", 4, "d") ::
(5, "E", null, null) :: (5, "E", null, null) ::
(6, "F", null, null) :: Nil) (6, "F", null, null) :: Nil)
checkAnswer( checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
(1, "A", null, null) :: (1, "A", null, null) ::
...@@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ...@@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "d") :: (4, "D", 4, "d") ::
(5, "E", null, null) :: (5, "E", null, null) ::
(6, "F", null, null) :: Nil) (6, "F", null, null) :: Nil)
checkAnswer( checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
(1, "A", null, null) :: (1, "A", null, null) ::
...@@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ...@@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "d") :: (4, "D", 4, "d") ::
(5, "E", null, null) :: (5, "E", null, null) ::
(6, "F", null, null) :: Nil) (6, "F", null, null) :: Nil)
checkAnswer( checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
(1, "A", 1, "a") :: (1, "A", 1, "a") ::
...@@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ...@@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "D") :: (4, "D", 4, "D") ::
(null, null, 5, "E") :: (null, null, 5, "E") ::
(null, null, 6, "F") :: Nil) (null, null, 6, "F") :: Nil)
checkAnswer( checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
(1, "A", null, null) :: (1, "A", null, null) ::
...@@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ...@@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "D") :: (4, "D", 4, "D") ::
(null, null, 5, "E") :: (null, null, 5, "E") ::
(null, null, 6, "F") :: Nil) (null, null, 6, "F") :: Nil)
checkAnswer( checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
(1, "A", null, null) :: (1, "A", null, null) ::
......
...@@ -80,7 +80,7 @@ object TestData { ...@@ -80,7 +80,7 @@ object TestData {
UpperCaseData(3, "C") :: UpperCaseData(3, "C") ::
UpperCaseData(4, "D") :: UpperCaseData(4, "D") ::
UpperCaseData(5, "E") :: UpperCaseData(5, "E") ::
UpperCaseData(6, "F") :: Nil) UpperCaseData(6, "F") :: Nil).toSchemaRDD
upperCaseData.registerTempTable("upperCaseData") upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String) case class LowerCaseData(n: Int, l: String)
...@@ -89,7 +89,7 @@ object TestData { ...@@ -89,7 +89,7 @@ object TestData {
LowerCaseData(1, "a") :: LowerCaseData(1, "a") ::
LowerCaseData(2, "b") :: LowerCaseData(2, "b") ::
LowerCaseData(3, "c") :: LowerCaseData(3, "c") ::
LowerCaseData(4, "d") :: Nil) LowerCaseData(4, "d") :: Nil).toSchemaRDD
lowerCaseData.registerTempTable("lowerCaseData") lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment