Commit ece00566 authored by Andrew Or, committed by Michael Armbrust

[SPARK-9561] Re-enable BroadcastJoinSuite

We can do this now that SPARK-9580 is resolved.

Author: Andrew Or <andrew@databricks.com>

Closes #8208 from andrewor14/reenable-sql-tests.
parent 3bc55287
@@ -15,80 +15,73 @@
  * limitations under the License.
  */
 
-// TODO: uncomment the test here! It is currently failing due to
-// bad interaction with org.apache.spark.sql.test.TestSQLContext.
+package org.apache.spark.sql.execution.joins
-// scalastyle:off
-//package org.apache.spark.sql.execution.joins
-//
-//import scala.reflect.ClassTag
-//
-//import org.scalatest.BeforeAndAfterAll
-//
-//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-//import org.apache.spark.sql.functions._
-//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
-//
-///**
-// * Test various broadcast join operators with unsafe enabled.
-// *
-// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs
-// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
-// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in
-// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without
-// * serializing the hashed relation, which does not happen in local mode.
-// */
-//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
-//  private var sc: SparkContext = null
-//  private var sqlContext: SQLContext = null
-//
-//  /**
-//   * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
-//   */
-//  override def beforeAll(): Unit = {
-//    super.beforeAll()
-//    val conf = new SparkConf()
-//      .setMaster("local-cluster[2,1,1024]")
-//      .setAppName("testing")
-//    sc = new SparkContext(conf)
-//    sqlContext = new SQLContext(sc)
-//    sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
-//    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
-//  }
-//
-//  override def afterAll(): Unit = {
-//    sc.stop()
-//    sc = null
-//    sqlContext = null
-//  }
-//
-//  /**
-//   * Test whether the specified broadcast join updates the peak execution memory accumulator.
-//   */
-//  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
-//    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
-//      val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-//      val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-//      // Comparison at the end is for broadcast left semi join
-//      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
-//      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-//      val plan = df3.queryExecution.executedPlan
-//      assert(plan.collect { case p: T => p }.size === 1)
-//      plan.executeCollect()
-//    }
-//  }
-//
-//  test("unsafe broadcast hash join updates peak execution memory") {
-//    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
-//  }
-//
-//  test("unsafe broadcast hash outer join updates peak execution memory") {
-//    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
-//  }
-//
-//  test("unsafe broadcast left semi join updates peak execution memory") {
-//    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
-//  }
-//
-//}
-// scalastyle:on
+import scala.reflect.ClassTag
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
+
+/**
+ * Test various broadcast join operators with unsafe enabled.
+ *
+ * Tests in this suite need to run Spark in local-cluster mode. In particular, the use of
+ * the unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not
+ * triggered without serializing the hashed relation, which does not happen in local mode.
+ */
+class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+  private var sc: SparkContext = null
+  private var sqlContext: SQLContext = null
+
+  /**
+   * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
+   */
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val conf = new SparkConf()
+      .setMaster("local-cluster[2,1,1024]")
+      .setAppName("testing")
+    sc = new SparkContext(conf)
+    sqlContext = new SQLContext(sc)
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    sc.stop()
+    sc = null
+    sqlContext = null
+  }
+
+  /**
+   * Test whether the specified broadcast join updates the peak execution memory accumulator.
+   */
+  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+      val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+      val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      // Comparison at the end is for broadcast left semi join
+      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
+      val plan = df3.queryExecution.executedPlan
+      assert(plan.collect { case p: T => p }.size === 1)
+      plan.executeCollect()
+    }
+  }
+
+  test("unsafe broadcast hash join updates peak execution memory") {
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+  }
+
+  test("unsafe broadcast hash outer join updates peak execution memory") {
+    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+  }
+
+  test("unsafe broadcast left semi join updates peak execution memory") {
+    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+  }
+}
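For context, the broadcast hint this suite exercises can also be tried outside the test harness. The following is a minimal, self-contained sketch, not part of the commit: the object name BroadcastHintExample is hypothetical, and it assumes the Spark 1.5-era SQLContext API used above.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.broadcast

// Hypothetical standalone example (not part of this commit).
object BroadcastHintExample {
  def main(args: Array[String]): Unit = {
    // Plain local mode suffices to inspect the plan; the suite above needs
    // local-cluster mode only to force serialization of the hashed relation.
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("broadcast-hint-example"))
    val sqlContext = new SQLContext(sc)

    val large = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("key", "value")
    val small = sqlContext.createDataFrame(Seq((1, "x"), (2, "y"))).toDF("key", "ref")

    // broadcast() marks `small` for replication to every executor, steering the
    // planner toward a broadcast join instead of a shuffle-based one.
    val joined = large.join(broadcast(small), "key")
    joined.explain() // the physical plan should contain a BroadcastHashJoin node
    joined.show()

    sc.stop()
  }
}

The tests in the suite assert on the physical plan in the same spirit, collecting the BroadcastHashJoin (or outer/semi variants) nodes from df3.queryExecution.executedPlan.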