diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 07f4d2946c1b52ec22ff18c892a733fb1f551eec..8b4cf5bac01872d1f2c7a22e40b1203a1dcc3c11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -19,17 +19,13 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
 
 class JoinSuite extends QueryTest with BeforeAndAfterEach {
-
   // Ensures tables are loaded.
   TestData
 
@@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     assert(planned.size === 1)
   }
 
-  test("join operator selection") {
-    def assertJoin(sqlString: String, c: Class[_]): Any = {
-      val rdd = sql(sqlString)
-      val physical = rdd.queryExecution.sparkPlan
-      val operators = physical.collect {
-        case j: ShuffledHashJoin => j
-        case j: HashOuterJoin => j
-        case j: LeftSemiJoinHash => j
-        case j: BroadcastHashJoin => j
-        case j: LeftSemiJoinBNL => j
-        case j: CartesianProduct => j
-        case j: BroadcastNestedLoopJoin => j
-      }
-
-      assert(operators.size === 1)
-      if (operators(0).getClass() != c) {
-        fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
-      }
+  def assertJoin(sqlString: String, c: Class[_]): Any = {
+    val rdd = sql(sqlString)
+    val physical = rdd.queryExecution.sparkPlan
+    val operators = physical.collect {
+      case j: ShuffledHashJoin => j
+      case j: HashOuterJoin => j
+      case j: LeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
+      case j: LeftSemiJoinBNL => j
+      case j: CartesianProduct => j
+      case j: BroadcastNestedLoopJoin => j
+    }
+
+    assert(operators.size === 1)
+    if (operators(0).getClass() != c) {
+      fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
     }
+  }
 
-    val cases1 = Seq(
-      ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
-      ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
-      ("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
-      ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
-      ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
-      ("SELECT * FROM testData right join testData2 ON key = a where key=2", 
+  test("join operator selection") {
+    clearCache()
+
+    Seq(
+      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+      ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+      ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+      ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+      ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
+      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
         classOf[HashOuterJoin]),
-      ("SELECT * FROM testData right join testData2 ON key = a and key=2", 
+      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
         classOf[HashOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin])
-    // TODO add BroadcastNestedLoopJoin
-    )
-    cases1.foreach { c => assertJoin(c._1, c._2) }
+      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
+      // TODO add BroadcastNestedLoopJoin
+    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+  }
+
+  test("broadcasted hash join operator selection") {
+    clearCache()
+    sql("CACHE TABLE testData")
+
+    Seq(
+      ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
+      ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
+      ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
+    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+
+    sql("UNCACHE TABLE testData")
   }
 
   test("multiple-key equi-join is hash-join") {
@@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
-    
+
     checkAnswer(
       upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
       (1, "A", null, null) ::
@@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
-    
+
     checkAnswer(
       upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
       (1, "A", null, null) ::
@@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "d") ::
       (5, "E", null, null) ::
       (6, "F", null, null) :: Nil)
-    
+
     checkAnswer(
       upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
       (1, "A", 1, "a") ::
@@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "D") ::
       (null, null, 5, "E") ::
       (null, null, 6, "F") :: Nil)
-    
+
     checkAnswer(
       left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
       (1, "A", null, null) ::
@@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       (4, "D", 4, "D") ::
       (null, null, 5, "E") ::
       (null, null, 6, "F") :: Nil)
-    
+
     checkAnswer(
       left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
       (1, "A", null, null) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 6c38575b13a2d9078b1c31f75df1ce2215e43db6..c4dd3e860f5fd78f2f2d698abd6a810e2fb9c3e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -80,7 +80,7 @@ object TestData {
       UpperCaseData(3, "C") ::
       UpperCaseData(4, "D") ::
       UpperCaseData(5, "E") ::
-      UpperCaseData(6, "F") :: Nil)
+      UpperCaseData(6, "F") :: Nil).toSchemaRDD
   upperCaseData.registerTempTable("upperCaseData")
 
   case class LowerCaseData(n: Int, l: String)
@@ -89,7 +89,7 @@ object TestData {
       LowerCaseData(1, "a") ::
       LowerCaseData(2, "b") ::
       LowerCaseData(3, "c") ::
-      LowerCaseData(4, "d") :: Nil)
+      LowerCaseData(4, "d") :: Nil).toSchemaRDD
   lowerCaseData.registerTempTable("lowerCaseData")
 
   case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])