From ece00566e4d5f38585f2810bef38e526cae7d25e Mon Sep 17 00:00:00 2001
From: Andrew Or <andrew@databricks.com>
Date: Fri, 14 Aug 2015 12:37:21 -0700
Subject: [PATCH] [SPARK-9561] Re-enable BroadcastJoinSuite

We can do this now that SPARK-9580 is resolved.

Author: Andrew Or <andrew@databricks.com>

Closes #8208 from andrewor14/reenable-sql-tests.
---
 .../execution/joins/BroadcastJoinSuite.scala  | 145 +++++++++---------
 1 file changed, 69 insertions(+), 76 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 0554e11d25..53a0e53fd7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -15,80 +15,73 @@
 * limitations under the License.
 */
 
-// TODO: uncomment the test here! It is currently failing due to
-// bad interaction with org.apache.spark.sql.test.TestSQLContext.
+package org.apache.spark.sql.execution.joins
 
-// scalastyle:off
-//package org.apache.spark.sql.execution.joins
-//
-//import scala.reflect.ClassTag
-//
-//import org.scalatest.BeforeAndAfterAll
-//
-//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-//import org.apache.spark.sql.functions._
-//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
-//
-///**
-// * Test various broadcast join operators with unsafe enabled.
-// *
-// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs
-// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
-// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in
-// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without
-// * serializing the hashed relation, which does not happen in local mode.
-// */
-//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
-//  private var sc: SparkContext = null
-//  private var sqlContext: SQLContext = null
-//
-//  /**
-//   * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
-//   */
-//  override def beforeAll(): Unit = {
-//    super.beforeAll()
-//    val conf = new SparkConf()
-//      .setMaster("local-cluster[2,1,1024]")
-//      .setAppName("testing")
-//    sc = new SparkContext(conf)
-//    sqlContext = new SQLContext(sc)
-//    sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
-//    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
-//  }
-//
-//  override def afterAll(): Unit = {
-//    sc.stop()
-//    sc = null
-//    sqlContext = null
-//  }
-//
-//  /**
-//   * Test whether the specified broadcast join updates the peak execution memory accumulator.
-//   */
-//  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
-//    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
-//      val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
-//      val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
-//      // Comparison at the end is for broadcast left semi join
-//      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
-//      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
-//      val plan = df3.queryExecution.executedPlan
-//      assert(plan.collect { case p: T => p }.size === 1)
-//      plan.executeCollect()
-//    }
-//  }
-//
-//  test("unsafe broadcast hash join updates peak execution memory") {
-//    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
-//  }
-//
-//  test("unsafe broadcast hash outer join updates peak execution memory") {
-//    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
-//  }
-//
-//  test("unsafe broadcast left semi join updates peak execution memory") {
-//    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
-//  }
-//
-//}
-// scalastyle:on
+import scala.reflect.ClassTag
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
+
+/**
+ * Test various broadcast join operators with unsafe enabled.
+ *
+ * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of
+ * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered
+ * without serializing the hashed relation, which does not happen in local mode.
+ */
+class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
+  private var sc: SparkContext = null
+  private var sqlContext: SQLContext = null
+
+  /**
+   * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled.
+   */
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val conf = new SparkConf()
+      .setMaster("local-cluster[2,1,1024]")
+      .setAppName("testing")
+    sc = new SparkContext(conf)
+    sqlContext = new SQLContext(sc)
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+    sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    sc.stop()
+    sc = null
+    sqlContext = null
+  }
+
+  /**
+   * Test whether the specified broadcast join updates the peak execution memory accumulator.
+   */
+  private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = {
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) {
+      val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value")
+      val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value")
+      // Comparison at the end is for broadcast left semi join
+      val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
+      val df3 = df1.join(broadcast(df2), joinExpression, joinType)
+      val plan = df3.queryExecution.executedPlan
+      assert(plan.collect { case p: T => p }.size === 1)
+      plan.executeCollect()
+    }
+  }
+
+  test("unsafe broadcast hash join updates peak execution memory") {
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner")
+  }
+
+  test("unsafe broadcast hash outer join updates peak execution memory") {
+    testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer")
+  }
+
+  test("unsafe broadcast left semi join updates peak execution memory") {
+    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+  }
+
+}
-- 
GitLab