diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 784f83c10e02353df509f9fc354042521c3e55f4..88745dc086a0401b34b8b6def6df8cdec1f973a0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -178,6 +178,16 @@ object MimaExcludes {
             // SPARK-4751 Dynamic allocation for standalone mode
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.SparkContext.supportDynamicAllocation")
+          ) ++ Seq(
+            // SPARK-9580: Remove SQL test singletons
+            ProblemFilters.exclude[MissingClassProblem](
+              "org.apache.spark.sql.test.LocalSQLContext$SQLSession"),
+            ProblemFilters.exclude[MissingClassProblem](
+              "org.apache.spark.sql.test.LocalSQLContext"),
+            ProblemFilters.exclude[MissingClassProblem](
+              "org.apache.spark.sql.test.TestSQLContext"),
+            ProblemFilters.exclude[MissingClassProblem](
+              "org.apache.spark.sql.test.TestSQLContext$")
           ) ++ Seq(
             // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs
             ProblemFilters.exclude[IncompatibleResultTypeProblem](
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 74f815f941d5bbaabdc807e44886105723d70dc0..04e0d49b178cfad36898d140017f42416b0bb9ab 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -319,6 +319,8 @@ object SQL {
   lazy val settings = Seq(
     initialCommands in console :=
       """
+        |import org.apache.spark.SparkContext
+        |import org.apache.spark.sql.SQLContext
         |import org.apache.spark.sql.catalyst.analysis._
         |import org.apache.spark.sql.catalyst.dsl._
         |import org.apache.spark.sql.catalyst.errors._
@@ -328,9 +330,14 @@ object SQL {
         |import org.apache.spark.sql.catalyst.util._
         |import org.apache.spark.sql.execution
         |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.test.TestSQLContext._
-        |import org.apache.spark.sql.types._""".stripMargin,
-    cleanupCommands in console := "sparkContext.stop()"
+        |import org.apache.spark.sql.types._
+        |
+        |val sc = new SparkContext("local[*]", "dev-shell")
+        |val sqlContext = new SQLContext(sc)
+        |import sqlContext.implicits._
+        |import sqlContext._
+      """.stripMargin,
+    cleanupCommands in console := "sc.stop()"
   )
 }
 
@@ -340,8 +347,6 @@ object Hive {
     javaOptions += "-XX:MaxPermSize=256m",
     // Specially disable assertions since some Hive tests fail them
     javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
-    // Multiple queries rely on the TestHive singleton. See comments there for more details.
-    parallelExecution in Test := false,
     // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
     // only for this subproject.
     scalacOptions <<= scalacOptions map { currentOpts: Seq[String] =>
@@ -349,6 +354,7 @@ object Hive {
     },
     initialCommands in console :=
       """
+        |import org.apache.spark.SparkContext
         |import org.apache.spark.sql.catalyst.analysis._
         |import org.apache.spark.sql.catalyst.dsl._
         |import org.apache.spark.sql.catalyst.errors._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 63b475b6366c2af81f4b46c69310999a8f0ab7c5..f60d11c988ef8471f44c51d0f4487e5eeab381b5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -17,14 +17,10 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.types._
@@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode {
   override def output: Seq[Attribute] = Nil
 }
 
-class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
+class AnalysisErrorSuite extends AnalysisTest {
   import TestRelations._
 
   def errorTest(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4bf00b3399e7ae371075a380e3ca472ec5480c6f..53de10d5fa9daa2333f36873fb36bc3fcf5fcec0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.JavaConversions._
 import scala.collection.immutable
-import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
@@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
@@ -334,97 +332,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @since 1.3.0
    */
   @Experimental
-  object implicits extends Serializable {
-    // scalastyle:on
-
-    /**
-     * Converts $"col name" into an [[Column]].
-     * @since 1.3.0
-     */
-    implicit class StringToColumn(val sc: StringContext) {
-      def $(args: Any*): ColumnName = {
-        new ColumnName(sc.s(args: _*))
-      }
-    }
-
-    /**
-     * An implicit conversion that turns a Scala `Symbol` into a [[Column]].
-     * @since 1.3.0
-     */
-    implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
-
-    /**
-     * Creates a DataFrame from an RDD of case classes or tuples.
-     * @since 1.3.0
-     */
-    implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
-      DataFrameHolder(self.createDataFrame(rdd))
-    }
-
-    /**
-     * Creates a DataFrame from a local Seq of Product.
-     * @since 1.3.0
-     */
-    implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
-    {
-      DataFrameHolder(self.createDataFrame(data))
-    }
-
-    // Do NOT add more implicit conversions. They are likely to break source compatibility by
-    // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
-    // because of [[DoubleRDDFunctions]].
-
-    /**
-     * Creates a single column DataFrame from an RDD[Int].
-     * @since 1.3.0
-     */
-    implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
-      val dataType = IntegerType
-      val rows = data.mapPartitions { iter =>
-        val row = new SpecificMutableRow(dataType :: Nil)
-        iter.map { v =>
-          row.setInt(0, v)
-          row: InternalRow
-        }
-      }
-      DataFrameHolder(
-        self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
-    }
-
-    /**
-     * Creates a single column DataFrame from an RDD[Long].
-     * @since 1.3.0
-     */
-    implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
-      val dataType = LongType
-      val rows = data.mapPartitions { iter =>
-        val row = new SpecificMutableRow(dataType :: Nil)
-        iter.map { v =>
-          row.setLong(0, v)
-          row: InternalRow
-        }
-      }
-      DataFrameHolder(
-        self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
-    }
-
-    /**
-     * Creates a single column DataFrame from an RDD[String].
-     * @since 1.3.0
-     */
-    implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
-      val dataType = StringType
-      val rows = data.mapPartitions { iter =>
-        val row = new SpecificMutableRow(dataType :: Nil)
-        iter.map { v =>
-          row.update(0, UTF8String.fromString(v))
-          row: InternalRow
-        }
-      }
-      DataFrameHolder(
-        self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
-    }
+  object implicits extends SQLImplicits with Serializable {
+    protected override def _sqlContext: SQLContext = self
   }
+  // scalastyle:on
 
   /**
    * :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5f82372700f2c9b93a42e28ec0442dce213e2e88
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
+ */
+private[sql] abstract class SQLImplicits {
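+  // Implementations supply the SQLContext used by the conversions below.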
+  protected def _sqlContext: SQLContext
+
+  /**
+   * Converts $"col name" into an [[Column]].
+   * @since 1.3.0
+   */
+  implicit class StringToColumn(val sc: StringContext) {
+    def $(args: Any*): ColumnName = {
+      new ColumnName(sc.s(args: _*))
+    }
+  }
+
+  /**
+   * An implicit conversion that turns a Scala `Symbol` into a [[Column]].
+   * @since 1.3.0
+   */
+  implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+  /**
+   * Creates a DataFrame from an RDD of case classes or tuples.
+   * @since 1.3.0
+   */
+  implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
+    DataFrameHolder(_sqlContext.createDataFrame(rdd))
+  }
+
+  /**
+   * Creates a DataFrame from a local Seq of Product.
+   * @since 1.3.0
+   */
+  implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
+  {
+    DataFrameHolder(_sqlContext.createDataFrame(data))
+  }
+
+  // Do NOT add more implicit conversions. They are likely to break source compatibility by
+  // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
+  // because of [[DoubleRDDFunctions]].
+
+  /**
+   * Creates a single column DataFrame from an RDD[Int].
+   * @since 1.3.0
+   */
+  implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
+    val dataType = IntegerType
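+    // Reuse one mutable row per partition instead of allocating a new row for every element.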
+    val rows = data.mapPartitions { iter =>
+      val row = new SpecificMutableRow(dataType :: Nil)
+      iter.map { v =>
+        row.setInt(0, v)
+        row: InternalRow
+      }
+    }
+    DataFrameHolder(
+      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+  }
+
+  /**
+   * Creates a single column DataFrame from an RDD[Long].
+   * @since 1.3.0
+   */
+  implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
+    val dataType = LongType
+    val rows = data.mapPartitions { iter =>
+      val row = new SpecificMutableRow(dataType :: Nil)
+      iter.map { v =>
+        row.setLong(0, v)
+        row: InternalRow
+      }
+    }
+    DataFrameHolder(
+      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+  }
+
+  /**
+   * Creates a single column DataFrame from an RDD[String].
+   * @since 1.3.0
+   */
+  implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
+    val dataType = StringType
+    val rows = data.mapPartitions { iter =>
+      val row = new SpecificMutableRow(dataType :: Nil)
+      iter.map { v =>
+        row.update(0, UTF8String.fromString(v))
+        row: InternalRow
+      }
+    }
+    DataFrameHolder(
+      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index e912eb835d169f5e08f2e358cb99ab40aa33a5f0..bf693c7c393f69bf12ac400ad05cb499f302b7c4 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -27,6 +27,7 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
@@ -34,7 +35,6 @@ import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
@@ -48,14 +48,17 @@ public class JavaApplySchemaSuite implements Serializable {
 
   @Before
   public void setUp() {
-    sqlContext = TestSQLContext$.MODULE$;
-    javaCtx = new JavaSparkContext(sqlContext.sparkContext());
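+    // The TestSQLContext singleton was removed; each test now creates and stops its own contexts.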
+    SparkContext context = new SparkContext("local[*]", "testing");
+    javaCtx = new JavaSparkContext(context);
+    sqlContext = new SQLContext(context);
   }
 
   @After
   public void tearDown() {
-    javaCtx = null;
+    sqlContext.sparkContext().stop();
     sqlContext = null;
+    javaCtx = null;
   }
 
   public static class Person implements Serializable {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 7302361ab9fdb1d33e6007a3aec8f05c7538c279..7abdd3db8034116030d384c70b0c0b2eff6b65c1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,44 +17,45 @@
 
 package test.org.apache.spark.sql;
 
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+
 import com.google.common.collect.ImmutableMap;
 import com.google.common.primitives.Ints;
+import org.junit.*;
 
+import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
+import static org.apache.spark.sql.functions.*;
 import org.apache.spark.sql.test.TestSQLContext;
-import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.*;
-import org.junit.*;
-
-import scala.collection.JavaConversions;
-import scala.collection.Seq;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-
-import static org.apache.spark.sql.functions.*;
 
 public class JavaDataFrameSuite {
   private transient JavaSparkContext jsc;
-  private transient SQLContext context;
+  private transient TestSQLContext context;
 
   @Before
   public void setUp() {
     // Trigger static initializer of TestData
-    TestData$.MODULE$.testData();
-    jsc = new JavaSparkContext(TestSQLContext.sparkContext());
-    context = TestSQLContext$.MODULE$;
+    SparkContext sc = new SparkContext("local[*]", "testing");
+    jsc = new JavaSparkContext(sc);
+    context = new TestSQLContext(sc);
+    context.loadTestData();
   }
 
   @After
   public void tearDown() {
-    jsc = null;
+    context.sparkContext().stop();
     context = null;
+    jsc = null;
   }
 
   @Test
@@ -230,7 +231,7 @@ public class JavaDataFrameSuite {
 
   @Test
   public void testSampleBy() {
-    DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
+    DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
     DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
     Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
     Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 79d92734ff3757bad7d3b299c2c4da29081fd168..bb02b58cca9be4e3f239f9875ddd1a6de7611beb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -23,12 +23,12 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SparkContext;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.api.java.UDF1;
 import org.apache.spark.sql.api.java.UDF2;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.DataTypes;
 
 // The test suite itself is Serializable so that anonymous Function implementations can be
@@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable {
 
   @Before
   public void setUp() {
-    sqlContext = TestSQLContext$.MODULE$;
-    sc = new JavaSparkContext(sqlContext.sparkContext());
+    SparkContext sparkContext = new SparkContext("local[*]", "testing");
+    sqlContext = new SQLContext(sparkContext);
+    sc = new JavaSparkContext(sparkContext);
   }
 
   @After
   public void tearDown() {
+    sqlContext.sparkContext().stop();
+    sqlContext = null;
+    sc = null;
   }
 
   @SuppressWarnings("unchecked")
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 2706e01bd28afc8426c0456e8d919cf08f25d63c..6f9e7f68dc39c669d4d989863d7f695788af525c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -21,13 +21,14 @@ import java.io.File;
 import java.io.IOException;
 import java.util.*;
 
+import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.*;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
@@ -52,8 +53,9 @@ public class JavaSaveLoadSuite {
 
   @Before
   public void setUp() throws IOException {
-    sqlContext = TestSQLContext$.MODULE$;
-    sc = new JavaSparkContext(sqlContext.sparkContext());
+    SparkContext sparkContext = new SparkContext("local[*]", "testing");
+    sqlContext = new SQLContext(sparkContext);
+    sc = new JavaSparkContext(sparkContext);
 
     originalDefaultSource = sqlContext.conf().defaultDataSourceName();
     path =
@@ -71,6 +73,13 @@ public class JavaSaveLoadSuite {
     df.registerTempTable("jsonTable");
   }
 
+  @After
+  public void tearDown() {
+    sqlContext.sparkContext().stop();
+    sqlContext = null;
+    sc = null;
+  }
+
   @Test
   public void saveAndLoad() {
     Map<String, String> options = new HashMap<String, String>();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index a88df91b1001cbda3ce784033182e17315ecb8a3..af7590c3d3c175bfb432557eb562a7a3165933f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,24 +18,20 @@
 package org.apache.spark.sql
 
 import scala.concurrent.duration._
-import scala.language.{implicitConversions, postfixOps}
+import scala.language.postfixOps
 
 import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.Accumulators
-import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.storage.{StorageLevel, RDDBlockId}
 
-case class BigData(s: String)
+private case class BigData(s: String)
 
-class CachedTableSuite extends QueryTest {
-  TestData // Load test tables.
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-  import ctx.sql
+class CachedTableSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   def rddIdOf(tableName: String): Int = {
     val executedPlan = ctx.table(tableName).queryExecution.executedPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 6a09a3b72c0818c83def92dcb3abbfb8b311cbd7..ee74e3e83da5a8db2f4c08b32d1057849541f7f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -21,16 +21,20 @@ import org.scalatest.Matchers._
 
 import org.apache.spark.sql.execution.{Project, TungstenProject}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.SQLTestUtils
 
-class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
-  import org.apache.spark.sql.TestData._
+class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-
-  override def sqlContext(): SQLContext = ctx
+  private lazy val booleanData = {
+    ctx.createDataFrame(ctx.sparkContext.parallelize(
+      Row(false, false) ::
+      Row(false, true) ::
+      Row(true, false) ::
+      Row(true, true) :: Nil),
+      StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
+  }
 
   test("column names with space") {
     val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
@@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
       nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
 
     checkAnswer(
-      ctx.sql("select isnull(null), isnull(1)"),
+      sql("select isnull(null), isnull(1)"),
       Row(true, false))
   }
 
@@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
       nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
 
     checkAnswer(
-      ctx.sql("select isnotnull(null), isnotnull('a')"),
+      sql("select isnotnull(null), isnotnull('a')"),
       Row(false, true))
   }
 
@@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
       Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
 
     checkAnswer(
-      ctx.sql("select isnan(15), isnan('invalid')"),
+      sql("select isnan(15), isnan('invalid')"),
       Row(false, false))
   }
 
@@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
     )
     testData.registerTempTable("t")
     checkAnswer(
-      ctx.sql(
+      sql(
         "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
           " nanvl(b, e), nanvl(e, f) from t"),
       Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
@@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
     }
   }
 
-  val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
-    Row(false, false) ::
-      Row(false, true) ::
-      Row(true, false) ::
-      Row(true, true) :: Nil),
-    StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
-
   test("&&") {
     checkAnswer(
       booleanData.filter($"a" && true),
@@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
     )
 
     checkAnswer(
-      ctx.sql("SELECT upper('aB'), ucase('cDe')"),
+      sql("SELECT upper('aB'), ucase('cDe')"),
       Row("AB", "CDE"))
   }
 
@@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
     )
 
     checkAnswer(
-      ctx.sql("SELECT lower('aB'), lcase('cDe')"),
+      sql("SELECT lower('aB'), lcase('cDe')"),
       Row("ab", "cde"))
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f9cff7440a76e21fe749d1c4416377eeb2b1f0eb..72cf7aab0b977463524c242b63b33b909aab2f6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,15 +17,13 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{BinaryType, DecimalType}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.DecimalType
 
 
-class DataFrameAggregateSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("groupBy") {
     checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 03116a374f3be387e62b05977d841d94828327f9..9d965258e389df1dbfc249d60c2ad5371e0c8bdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -17,17 +17,15 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 /**
  * Test suite for functions in [[org.apache.spark.sql.functions]].
  */
-class DataFrameFunctionsSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("array with column name") {
     val df = Seq((0, 1)).toDF("a", "b")
@@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest {
 
   test("constant functions") {
     checkAnswer(
-      ctx.sql("SELECT E()"),
+      sql("SELECT E()"),
       Row(scala.math.E)
     )
     checkAnswer(
-      ctx.sql("SELECT PI()"),
+      sql("SELECT PI()"),
       Row(scala.math.Pi)
     )
   }
@@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest {
 
   test("nvl function") {
     checkAnswer(
-      ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+      sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
       Row("x", "y", null))
   }
 
@@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row(-1)
     )
     checkAnswer(
-      ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
+      sql("SELECT least(a, 2) as l from testData2 order by l"),
       Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
     )
   }
@@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row(3)
     )
     checkAnswer(
-      ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
+      sql("SELECT greatest(a, 2) as g from testData2 order by g"),
       Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index fbb30706a49430bd5442babcaf07c2f8e5a83f08..e5d7d63441a6b9bff1ee88750f8223ee2b5b11df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql
 
-class DataFrameImplicitsSuite extends QueryTest {
+import org.apache.spark.sql.test.SharedSQLContext
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("RDD of tuples") {
     checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e1c6c706242d2a13483b93dac80c67e091b899c7..e2716d7841d85fd5339cac9082474eb234f8baba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -17,14 +17,12 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.execution.joins.BroadcastHashJoin
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 
-class DataFrameJoinSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("join - join using") {
     val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
@@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest {
 
     checkAnswer(
       df1.join(df2, $"df1.key" === $"df2.key"),
-      ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
+      sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
         .collect().toSeq)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index dbe3b44ee2c79515ee452452dfde8d0db717ef37..cdaa14ac80785626d0fd853d2b8349bde3a6e581 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConversions._
 
+import org.apache.spark.sql.test.SharedSQLContext
 
-class DataFrameNaFunctionsSuite extends QueryTest {
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   def createDF(): DataFrame = {
     Seq[(String, java.lang.Integer, java.lang.Double)](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8f5984e4a8ce27e2dcd472eee9610f339eaa2f4e..28bdd6f83b687b6010aded9c42b1e079805f3697 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,20 +19,17 @@ package org.apache.spark.sql
 
 import java.util.Random
 
-import org.scalatest.Matchers._
-
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSQLContext
 
-class DataFrameStatSuite extends QueryTest {
-
-  private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
-  import sqlCtx.implicits._
+class DataFrameStatSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   private def toLetter(i: Int): String = (i + 97).toChar.toString
 
   test("sample with replacement") {
     val n = 100
-    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
     checkAnswer(
       data.sample(withReplacement = true, 0.05, seed = 13),
       Seq(5, 10, 52, 73).map(Row(_))
@@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest {
 
   test("sample without replacement") {
     val n = 100
-    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
     checkAnswer(
       data.sample(withReplacement = false, 0.05, seed = 13),
       Seq(16, 23, 88, 100).map(Row(_))
@@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest {
 
   test("randomSplit") {
     val n = 600
-    val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
     for (seed <- 1 to 5) {
       val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
       assert(splits.length == 3, "wrong number of splits")
@@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest {
   }
 
   test("Frequent Items 2") {
-    val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
+    val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4)
     // this is a regression test, where when merging partitions, we omitted values with higher
     // counts than those that existed in the map when the map was full. This test should also fail
     // if anything like SPARK-9614 is observed once again
@@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest {
   }
 
   test("sampleBy") {
-    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
     val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
     checkAnswer(
       sampled.groupBy("key").count().orderBy("key"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2feec29955bc8308f0595658648e15cf2a9f5ae4..10bfa9b64f00db81e7b4fe34b320ca7c8b1207d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -23,18 +23,12 @@ import scala.language.postfixOps
 import scala.util.Random
 
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
-import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.execution.datasources.json.JSONRelation
-import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
 
-class DataFrameSuite extends QueryTest with SQLTestUtils {
-  import org.apache.spark.sql.TestData._
-
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
+class DataFrameSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("analysis error should be eagerly reported") {
     // Eager analysis.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index bf8ef9a97bc601fb8f82b3b3a422b648484f4f70..77907e91363ec2655852ee684f010e536794abb5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -27,10 +27,8 @@ import org.apache.spark.sql.types._
  * This is here for now so I can make sure Tungsten project is tested without refactoring existing
  * end-to-end test infra. In the long run this should just go away.
  */
-class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
-
-  override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
+class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("test simple types") {
     withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index 17897caf952a348d9ea36441d52f561af02fa7ca..9080c53c491acdbc67a7bfbd6729a45b8786f2af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -22,19 +22,18 @@ import java.text.SimpleDateFormat
 
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.unsafe.types.CalendarInterval
 
-class DateFunctionsSuite extends QueryTest {
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-
-  import ctx.implicits._
+class DateFunctionsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("function current_date") {
     val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
     val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
     val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
     val d2 = DateTimeUtils.fromJavaDate(
-      ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
+      sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
     val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
     assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
   }
@@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest {
     val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
     checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
     // Execution in one query should return the same value
-    checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
+    checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
       Row(true))
-    assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
+    assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
       0).getTime - System.currentTimeMillis()) < 5000)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ae07eaf91c8720bf44ae78e3ea9bb766af98a356..f5c5046a8ed88438cb893ccc22fcd7d0c2324336 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,22 +17,15 @@
 
 package org.apache.spark.sql
 
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
 
 
-class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
-  // Ensures tables are loaded.
-  TestData
+class JoinSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
-  override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
-  lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-  import ctx.logicalPlanToSparkQuery
+  setupTestData()
 
   test("equi-join is hash-join") {
     val x = testData2.as("x")
@@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
   }
 
   def assertJoin(sqlString: String, c: Class[_]): Any = {
-    val df = ctx.sql(sqlString)
+    val df = sql(sqlString)
     val physical = df.queryExecution.sparkPlan
     val operators = physical.collect {
       case j: ShuffledHashJoin => j
@@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
 
   test("broadcasted hash join operator selection") {
     ctx.cacheManager.clearCache()
-    ctx.sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData")
     for (sortMergeJoinEnabled <- Seq(true, false)) {
       withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
         withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
@@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
         }
       }
     }
-    ctx.sql("UNCACHE TABLE testData")
+    sql("UNCACHE TABLE testData")
   }
 
   test("broadcasted hash outer join operator selection") {
     ctx.cacheManager.clearCache()
-    ctx.sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData")
     withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
       Seq(
         ("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
@@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
           classOf[BroadcastHashOuterJoin])
       ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
     }
-    ctx.sql("UNCACHE TABLE testData")
+    sql("UNCACHE TABLE testData")
   }
 
   test("multiple-key equi-join is hash-join") {
@@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
     // Make sure we are choosing left.outputPartitioning as the
     // outputPartitioning for the outer join operator.
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT l.N, count(*)
           |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
         Row(6, 1) :: Nil)
 
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT r.a, count(*)
           |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
     // Make sure we are choosing right.outputPartitioning as the
     // outputPartitioning for the outer join operator.
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT l.a, count(*)
           |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
       Row(null, 6))
 
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT r.N, count(*)
           |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
 
     // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator.
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT l.a, count(*)
           |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
       Row(null, 10))
 
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT r.N, count(*)
           |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
         Row(null, 4) :: Nil)
 
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT l.N, count(*)
           |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
         Row(null, 4) :: Nil)
 
     checkAnswer(
-      ctx.sql(
+      sql(
         """
           |SELECT r.a, count(*)
           |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
 
   test("broadcasted left semi join operator selection") {
     ctx.cacheManager.clearCache()
-    ctx.sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData")
 
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
       Seq(
@@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
       }
     }
 
-    ctx.sql("UNCACHE TABLE testData")
+    sql("UNCACHE TABLE testData")
   }
 
   test("left semi join") {
-    val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+    val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
     checkAnswer(df,
       Row(1, 1) ::
         Row(1, 2) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 71c26a6f8d3678694299fb0b75bfa1a941a0602f..045fea82e4c897dfcc3ee146a541f9d31b73a175 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql
 
-class JsonFunctionsSuite extends QueryTest {
+import org.apache.spark.sql.test.SharedSQLContext
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("function get_json_object") {
     val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 2089660c52bf7bab8db00115ef3c6bcd01a33df6..babf8835d25454b3fc5f9e743c31d39f3f95fb40 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
 
-class ListTablesSuite extends QueryTest with BeforeAndAfter {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
+  import testImplicits._
 
   private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
 
@@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
       Row("ListTablesSuiteTable", true))
 
     checkAnswer(
-      ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
+      sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
     ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
       Row("ListTablesSuiteTable", true))
 
     checkAnswer(
-      ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
+      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
       Row("ListTablesSuiteTable", true))
 
     ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
       StructField("tableName", StringType, false) ::
       StructField("isTemporary", BooleanType, false) :: Nil)
 
-    Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach {
+    Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
       case tableDF =>
         assert(expectedSchema === tableDF.schema)
 
         tableDF.registerTempTable("tables")
         checkAnswer(
-          ctx.sql(
+          sql(
             "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
           Row(true, "ListTablesSuiteTable")
         )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 8cf2ef5957d8d72b9edd130c705babd20925c0a2..30289c3c1d097cec643214041c1c763f37917936 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.functions.{log => logarithm}
+import org.apache.spark.sql.test.SharedSQLContext
 
 private object MathExpressionsTestData {
   case class DoubleData(a: java.lang.Double, b: java.lang.Double)
   case class NullDoubles(a: java.lang.Double)
 }
 
-class MathExpressionsSuite extends QueryTest {
-
+class MathExpressionsSuite extends QueryTest with SharedSQLContext {
   import MathExpressionsTestData._
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+  import testImplicits._
 
   private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF()
 
@@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest {
   test("toDegrees") {
     testOneToOneMathFunction(toDegrees, math.toDegrees)
     checkAnswer(
-      ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
+      sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
       Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
     )
   }
@@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest {
   test("toRadians") {
     testOneToOneMathFunction(toRadians, math.toRadians)
     checkAnswer(
-      ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
+      sql("SELECT radians(0), radians(1), radians(1.5)"),
       Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
     )
   }
@@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest {
   test("ceil and ceiling") {
     testOneToOneMathFunction(ceil, math.ceil)
     checkAnswer(
-      ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
+      sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
       Row(0.0, 1.0, 2.0))
   }
 
@@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest {
 
     val pi = 3.1415
     checkAnswer(
-      ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
+      sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
         s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
       Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
         BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
@@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneMathFunction[Double](signum, math.signum)
 
     checkAnswer(
-      ctx.sql("SELECT sign(10), signum(-11)"),
+      sql("SELECT sign(10), signum(-11)"),
       Row(1, -1))
   }
 
@@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest {
     testTwoToOneMathFunction(pow, pow, math.pow)
 
     checkAnswer(
-      ctx.sql("SELECT pow(1, 2), power(2, 1)"),
+      sql("SELECT pow(1, 2), power(2, 1)"),
       Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
     )
   }
@@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest {
   test("log / ln") {
     testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
     checkAnswer(
-      ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
+      sql("SELECT ln(0), ln(1), ln(1.5)"),
       Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
     )
   }
@@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest {
       df.select(log2("b") + log2("a")),
       Row(1))
 
-    checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
+    checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
   }
 
   test("sqrt") {
@@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest {
       df.select(sqrt("a"), sqrt("b")),
       Row(1.0, 2.0))
 
-    checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
+    checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
     checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
   }
 
   test("negative") {
     checkAnswer(
-      ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
+      sql("SELECT negative(1), negative(0), negative(-1)"),
       Row(-1, 0, 1))
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 98ba3c99283a18812e018472a899712cb119ebdb..4adcefb7dc4b19df4c64d1ef83a956be5ef4a812 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -71,12 +71,6 @@ class QueryTest extends PlanTest {
     checkAnswer(df, expectedAnswer.collect())
   }
 
-  def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) {
-    test(sqlString) {
-      checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
-    }
-  }
-
   /**
    * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 8a679c7865d6a67c255bd68c576daf3a87687c37..795d4e983f27e844790e8b1728427f2641d82ff3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -20,13 +20,12 @@ package org.apache.spark.sql
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-class RowSuite extends SparkFunSuite {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class RowSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
   test("create row") {
     val expected = new GenericMutableRow(4)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 75791e9d53c20b47537ed3e9fb7f6e315e414444..7699adadd9cc8ab4863e1f6a8826de9692279b12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.test.SharedSQLContext
 
-class SQLConfSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
 
+class SQLConfSuite extends QueryTest with SharedSQLContext {
   private val testKey = "test.key.0"
   private val testVal = "test.val.0"
 
@@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest {
 
   test("parse SQL set commands") {
     ctx.conf.clear()
-    ctx.sql(s"set $testKey=$testVal")
+    sql(s"set $testKey=$testVal")
     assert(ctx.getConf(testKey, testVal + "_") === testVal)
     assert(ctx.getConf(testKey, testVal + "_") === testVal)
 
-    ctx.sql("set some.property=20")
+    sql("set some.property=20")
     assert(ctx.getConf("some.property", "0") === "20")
-    ctx.sql("set some.property = 40")
+    sql("set some.property = 40")
     assert(ctx.getConf("some.property", "0") === "40")
 
     val key = "spark.sql.key"
     val vs = "val0,val_1,val2.3,my_table"
-    ctx.sql(s"set $key=$vs")
+    sql(s"set $key=$vs")
     assert(ctx.getConf(key, "0") === vs)
 
-    ctx.sql(s"set $key=")
+    sql(s"set $key=")
     assert(ctx.getConf(key, "0") === "")
 
     ctx.conf.clear()
@@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest {
 
   test("deprecated property") {
     ctx.conf.clear()
-    ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
+    sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
     assert(ctx.conf.numShufflePartitions === 10)
   }
 
   test("invalid conf value") {
     ctx.conf.clear()
     val e = intercept[IllegalArgumentException] {
-      ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
+      sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
     }
     assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index c8d8796568a4195e8e0ecbb3415f7e2b864cc027..007be12950774d658c4901cb89ece12c4fdd1e8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
 
-class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
 
   override def afterAll(): Unit = {
-    SQLContext.setLastInstantiatedContext(ctx)
+    try {
+      SQLContext.setLastInstantiatedContext(ctx)
+    } finally {
+      super.afterAll()
+    }
   }
 
   test("getOrCreate instantiates SQLContext") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b14ef9bab90cbeb7534020f323ddf5a39660c796..8c2c328f8191c99f4f4fd50c8ce1e1a5e1ede1af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,28 +19,23 @@ package org.apache.spark.sql
 
 import java.sql.Timestamp
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
 
 /** A SQL dialect for testing purposes; it cannot be a nested type. */
 class MyDialect extends DefaultParserDialect
 
-class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
-  // Make sure the tables are loaded.
-  TestData
+class SQLQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
-  val sqlContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
-  import sqlContext.sql
+  setupTestData()
 
   test("having clause") {
     Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav")
@@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
   }
 
   test("show functions") {
-    checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
+    checkAnswer(sql("SHOW functions"),
+      FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
   }
 
   test("describe functions") {
@@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
 
     val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
     // we expect the id to be materialized once
-    val idUDF = udf(() => UUID.randomUUID().toString)
+    val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString)
 
     val dfWithId = df.withColumn("id", idUDF())
     // Make a new DataFrame (actually the same reference to the old one)
@@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
 
     checkAnswer(
       sql(
-        """
-          |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3
-        """.stripMargin),
+        "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
       Row(2, 1, 2, 2, 1))
   }
 
@@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
     validateMetadata(sql("SELECT * FROM personWithMeta"))
     validateMetadata(sql("SELECT id, name FROM personWithMeta"))
     validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
-    validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
+    validateMetadata(sql(
+      "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
   }
 
   test("SPARK-3371 Renaming a function expression with group by gives error") {
@@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       .toDF("num", "str")
     df.registerTempTable("1one")
 
-    checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10))
+    checkAnswer(sql("select count(num) from 1one"), Row(10))
 
     sqlContext.dropTempTable("1one")
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index ab6d3dd96d271b80ad87408df9a12e2f2d5f14d7..295f02f9a7b5d0e08323c5f2ee9caa26bbb3a69c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
 
 case class ReflectData(
     stringField: String,
@@ -71,17 +72,15 @@ case class ComplexReflectData(
     mapFieldContainsNull: Map[Int, Option[Long]],
     dataField: Data)
 
-class ScalaReflectionRelationSuite extends SparkFunSuite {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
       new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
     Seq(data).toDF().registerTempTable("reflectData")
 
-    assert(ctx.sql("SELECT * FROM reflectData").collect().head ===
+    assert(sql("SELECT * FROM reflectData").collect().head ===
       Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
         new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
         new Timestamp(12345), Seq(1, 2, 3)))
@@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
     val data = NullReflectData(null, null, null, null, null, null, null)
     Seq(data).toDF().registerTempTable("reflectNullData")
 
-    assert(ctx.sql("SELECT * FROM reflectNullData").collect().head ===
+    assert(sql("SELECT * FROM reflectNullData").collect().head ===
       Row.fromSeq(Seq.fill(7)(null)))
   }
 
@@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
     val data = OptionalReflectData(None, None, None, None, None, None, None)
     Seq(data).toDF().registerTempTable("reflectOptionalData")
 
-    assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head ===
+    assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
       Row.fromSeq(Seq.fill(7)(null)))
   }
 
@@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
   test("query binary data") {
     Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary")
 
-    val result = ctx.sql("SELECT data FROM reflectBinary")
+    val result = sql("SELECT data FROM reflectBinary")
       .collect().head(0).asInstanceOf[Array[Byte]]
     assert(result.toSeq === Seq[Byte](1))
   }
@@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
         Nested(None, "abc")))
 
     Seq(data).toDF().registerTempTable("reflectComplexData")
-    assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
+    assert(sql("SELECT * FROM reflectComplexData").collect().head ===
       Row(
         Seq(1, 2, 3),
         Seq(1, 2, null),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index e55c9e460b7919da8f91f7ec521ff34313427382..45d0ee4a8e749efc0850f526a2ed7b260df006c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.sql.test.SharedSQLContext
 
-class SerializationSuite extends SparkFunSuite {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+class SerializationSuite extends SparkFunSuite with SharedSQLContext {
 
   test("[SPARK-5235] SQLContext should be serializable") {
-    val sqlContext = new SQLContext(ctx.sparkContext)
-    new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
+    val _sqlContext = new SQLContext(sqlContext.sparkContext)
+    new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index ca298b2434410971da1593477e20abb468bf4464..cc95eede005d7f8e17bc43681fe1db5043e075df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -18,13 +18,12 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.Decimal
 
 
-class StringFunctionsSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class StringFunctionsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("string concat") {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
deleted file mode 100644
index bd9729c431f3022e972684f32862862f9e00a148..0000000000000000000000000000000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test._
-
-
-case class TestData(key: Int, value: String)
-
-object TestData {
-  val testData = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(i, i.toString))).toDF()
-  testData.registerTempTable("testData")
-
-  val negativeData = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
-  negativeData.registerTempTable("negativeData")
-
-  case class LargeAndSmallInts(a: Int, b: Int)
-  val largeAndSmallInts =
-    TestSQLContext.sparkContext.parallelize(
-      LargeAndSmallInts(2147483644, 1) ::
-      LargeAndSmallInts(1, 2) ::
-      LargeAndSmallInts(2147483645, 1) ::
-      LargeAndSmallInts(2, 2) ::
-      LargeAndSmallInts(2147483646, 1) ::
-      LargeAndSmallInts(3, 2) :: Nil).toDF()
-  largeAndSmallInts.registerTempTable("largeAndSmallInts")
-
-  case class TestData2(a: Int, b: Int)
-  val testData2 =
-    TestSQLContext.sparkContext.parallelize(
-      TestData2(1, 1) ::
-      TestData2(1, 2) ::
-      TestData2(2, 1) ::
-      TestData2(2, 2) ::
-      TestData2(3, 1) ::
-      TestData2(3, 2) :: Nil, 2).toDF()
-  testData2.registerTempTable("testData2")
-
-  case class DecimalData(a: BigDecimal, b: BigDecimal)
-
-  val decimalData =
-    TestSQLContext.sparkContext.parallelize(
-      DecimalData(1, 1) ::
-      DecimalData(1, 2) ::
-      DecimalData(2, 1) ::
-      DecimalData(2, 2) ::
-      DecimalData(3, 1) ::
-      DecimalData(3, 2) :: Nil).toDF()
-  decimalData.registerTempTable("decimalData")
-
-  case class BinaryData(a: Array[Byte], b: Int)
-  val binaryData =
-    TestSQLContext.sparkContext.parallelize(
-      BinaryData("12".getBytes(), 1) ::
-      BinaryData("22".getBytes(), 5) ::
-      BinaryData("122".getBytes(), 3) ::
-      BinaryData("121".getBytes(), 2) ::
-      BinaryData("123".getBytes(), 4) :: Nil).toDF()
-  binaryData.registerTempTable("binaryData")
-
-  case class TestData3(a: Int, b: Option[Int])
-  val testData3 =
-    TestSQLContext.sparkContext.parallelize(
-      TestData3(1, None) ::
-      TestData3(2, Some(2)) :: Nil).toDF()
-  testData3.registerTempTable("testData3")
-
-  case class UpperCaseData(N: Int, L: String)
-  val upperCaseData =
-    TestSQLContext.sparkContext.parallelize(
-      UpperCaseData(1, "A") ::
-      UpperCaseData(2, "B") ::
-      UpperCaseData(3, "C") ::
-      UpperCaseData(4, "D") ::
-      UpperCaseData(5, "E") ::
-      UpperCaseData(6, "F") :: Nil).toDF()
-  upperCaseData.registerTempTable("upperCaseData")
-
-  case class LowerCaseData(n: Int, l: String)
-  val lowerCaseData =
-    TestSQLContext.sparkContext.parallelize(
-      LowerCaseData(1, "a") ::
-      LowerCaseData(2, "b") ::
-      LowerCaseData(3, "c") ::
-      LowerCaseData(4, "d") :: Nil).toDF()
-  lowerCaseData.registerTempTable("lowerCaseData")
-
-  case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
-  val arrayData =
-    TestSQLContext.sparkContext.parallelize(
-      ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
-      ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
-  arrayData.toDF().registerTempTable("arrayData")
-
-  case class MapData(data: scala.collection.Map[Int, String])
-  val mapData =
-    TestSQLContext.sparkContext.parallelize(
-      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
-      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
-      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
-      MapData(Map(1 -> "a4", 2 -> "b4")) ::
-      MapData(Map(1 -> "a5")) :: Nil)
-  mapData.toDF().registerTempTable("mapData")
-
-  case class StringData(s: String)
-  val repeatedData =
-    TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
-  repeatedData.toDF().registerTempTable("repeatedData")
-
-  val nullableRepeatedData =
-    TestSQLContext.sparkContext.parallelize(
-      List.fill(2)(StringData(null)) ++
-      List.fill(2)(StringData("test")))
-  nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
-
-  case class NullInts(a: Integer)
-  val nullInts =
-    TestSQLContext.sparkContext.parallelize(
-      NullInts(1) ::
-      NullInts(2) ::
-      NullInts(3) ::
-      NullInts(null) :: Nil
-    ).toDF()
-  nullInts.registerTempTable("nullInts")
-
-  val allNulls =
-    TestSQLContext.sparkContext.parallelize(
-      NullInts(null) ::
-      NullInts(null) ::
-      NullInts(null) ::
-      NullInts(null) :: Nil).toDF()
-  allNulls.registerTempTable("allNulls")
-
-  case class NullStrings(n: Int, s: String)
-  val nullStrings =
-    TestSQLContext.sparkContext.parallelize(
-      NullStrings(1, "abc") ::
-      NullStrings(2, "ABC") ::
-      NullStrings(3, null) :: Nil).toDF()
-  nullStrings.registerTempTable("nullStrings")
-
-  case class TableName(tableName: String)
-  TestSQLContext
-    .sparkContext
-    .parallelize(TableName("test") :: Nil)
-    .toDF()
-    .registerTempTable("tableName")
-
-  val unparsedStrings =
-    TestSQLContext.sparkContext.parallelize(
-      "1, A1, true, null" ::
-      "2, B2, false, null" ::
-      "3, C3, true, null" ::
-      "4, D4, true, 2147483644" :: Nil)
-
-  case class IntField(i: Int)
-  // An RDD with 4 elements and 8 partitions
-  val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
-  withEmptyParts.toDF().registerTempTable("withEmptyParts")
-
-  case class Person(id: Int, name: String, age: Int)
-  case class Salary(personId: Int, salary: Double)
-  val person = TestSQLContext.sparkContext.parallelize(
-    Person(0, "mike", 30) ::
-    Person(1, "jim", 20) :: Nil).toDF()
-  person.registerTempTable("person")
-  val salary = TestSQLContext.sparkContext.parallelize(
-    Salary(0, 2000.0) ::
-    Salary(1, 1000.0) :: Nil).toDF()
-  salary.registerTempTable("salary")
-
-  case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
-  val complexData =
-    TestSQLContext.sparkContext.parallelize(
-      ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
-        :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
-        :: Nil).toDF()
-  complexData.registerTempTable("complexData")
-}
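
The deleted TestData object registered every table eagerly against the TestSQLContext singleton, which is exactly the coupling this change removes. The suites above now import case classes from org.apache.spark.sql.test.SQLTestData and call setupTestData() against their own shared context. That file is outside this portion of the diff; here is a reduced sketch of the shape implied by those imports, covering two of the tables with illustrative bodies:

    import org.apache.spark.sql.{DataFrame, SQLContext}

    object SQLTestDataSketch {
      // Standalone case classes so suites can import them directly,
      // mirroring `import org.apache.spark.sql.test.SQLTestData._` above.
      case class TestData(key: Int, value: String)
      case class TestData3(a: Int, b: Option[Int])
    }

    trait SQLTestDataSketch {
      import SQLTestDataSketch._

      protected def ctx: SQLContext

      protected lazy val testData: DataFrame =
        ctx.createDataFrame((1 to 100).map(i => TestData(i, i.toString)))

      protected lazy val testData3: DataFrame =
        ctx.createDataFrame(TestData3(1, None) :: TestData3(2, Some(2)) :: Nil)

      // Registers the tables against the suite's own context rather than a
      // global singleton; see the setupTestData() calls earlier in the diff.
      def setupTestData(): Unit = {
        testData.registerTempTable("testData")
        testData3.registerTempTable("testData3")
      }
    }
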
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 183dc3407b3ab2c6ea7c82656a5cc9593d70f8f4..eb275af101e2fe13447ec0becf00ce8511aa19ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 
-case class FunctionResult(f1: String, f2: String)
+private case class FunctionResult(f1: String, f2: String)
 
-class UDFSuite extends QueryTest with SQLTestUtils {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-
-  override def sqlContext(): SQLContext = ctx
+class UDFSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("built-in fixed arity expressions") {
     val df = ctx.emptyDataFrame
@@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
   test("SPARK-8003 spark_partition_id") {
     val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
     df.registerTempTable("tmp_table")
-    checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
+    checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
     ctx.dropTempTable("tmp_table")
   }
 
@@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils {
       val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
       ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
-      val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
+      val answer = sql("select input_file_name() from test_table").head().getString(0)
       assert(answer.contains(dir.getCanonicalPath))
-      assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
+      assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
       ctx.dropTempTable("test_table")
     }
   }
@@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils {
 
   test("Simple UDF") {
     ctx.udf.register("strLenScala", (_: String).length)
-    assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
+    assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
   }
 
   test("ZeroArgument UDF") {
     ctx.udf.register("random0", () => { Math.random()})
-    assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
+    assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
   }
 
   test("TwoArgument UDF") {
     ctx.udf.register("strLenScala", (_: String).length + (_: Int))
-    assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
+    assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
   }
 
   test("UDF in a WHERE") {
@@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("integerData")
 
     val result =
-      ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
+      sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
     assert(result.count() === 20)
   }
 
@@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("groupData")
 
     val result =
-      ctx.sql(
+      sql(
         """
          | SELECT g, SUM(v) as s
          | FROM groupData
@@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("groupData")
 
     val result =
-      ctx.sql(
+      sql(
         """
          | SELECT SUM(v)
          | FROM groupData
@@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("groupData")
 
     val result =
-      ctx.sql(
+      sql(
         """
          | SELECT timesHundred(SUM(v)) as v100
          | FROM groupData
@@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
 
     val result =
-      ctx.sql("SELECT returnStruct('test', 'test2') as ret")
+      sql("SELECT returnStruct('test', 'test2') as ret")
         .select($"ret.f1").head().getString(0)
     assert(result === "test")
   }
@@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils {
   test("udf that is transformed") {
     ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
     // 1 + 1 is constant folded causing a transformation.
-    assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+    assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
   }
 
   test("type coercion for udf inputs") {
     ctx.udf.register("intExpected", (x: Int) => x)
     // pass a decimal to intExpected.
-    assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
+    assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9181222f6922b34899903fba3f70499c58a308b6..b6d279ae4726862bd13eaf66feddbaf2269c8539 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.OpenHashSet
@@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
   private[spark] override def asNullable: MyDenseVectorUDT = this
 }
 
-class UserDefinedTypeSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   private lazy val pointsRDD = Seq(
     MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
@@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest {
     ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
     pointsRDD.registerTempTable("points")
     checkAnswer(
-      ctx.sql("SELECT testType(features) from points"),
+      sql("SELECT testType(features) from points"),
       Seq(Row(true), Row(true)))
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 9bca4e7e660d6309932e5a83ec44664ade65efb3..952637c5f9cb8a254b398d6da6a6a8b1acb54cbf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, TestData}
 import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
 
-class InMemoryColumnarQuerySuite extends QueryTest {
-  // Make sure the tables are loaded.
-  TestData
+class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-  import ctx.{logicalPlanToSparkQuery, sql}
+  setupTestData()
 
   test("simple columnar query") {
     val plan = ctx.executePlan(testData.logicalPlan).executedPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 2c0879927a1296c4a92ce3de1e36f790ed0160a5..ab2644eb4581da03f5feacc8e7552be16019203b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -17,20 +17,19 @@
 
 package org.apache.spark.sql.columnar
 
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 
-class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
   private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
   private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
 
   override protected def beforeAll(): Unit = {
+    super.beforeAll()
     // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
     ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
 
@@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
     ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
     // Enable in-memory table scan accumulators
     ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
-  }
-
-  override protected def afterAll(): Unit = {
-    ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
-    ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
-  }
-
-  before {
     ctx.cacheTable("pruningData")
   }
 
-  after {
-    ctx.uncacheTable("pruningData")
+  override protected def afterAll(): Unit = {
+    try {
+      ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
+      ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+      ctx.uncacheTable("pruningData")
+    } finally {
+      super.afterAll()
+    }
   }
 
   // Comparisons
@@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
       expectedQueryResult: => Seq[Int]): Unit = {
 
     test(query) {
-      val df = ctx.sql(query)
+      val df = sql(query)
       val queryExecution = df.queryExecution
 
       assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 79e903c2bbd40f694ff488220b2b3107d41c8ad0..8998f5111124c0ffcb1197c4ea3b9effe9fe2fa8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ExchangeSuite extends SparkPlanTest {
+class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
   test("shuffling UnsafeRows in exchange") {
     val input = (1 to 1000).map(Tuple1.apply)
     checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5582caa0d366ef8725d3d6186490a16922e21540..937a108543531918b320481a127ad52568fab7c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{execution, Row, SQLConf}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans._
@@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
 
 
-class PlannerSuite extends SparkFunSuite with SQLTestUtils {
+class PlannerSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
-  override def sqlContext: SQLContext = TestSQLContext
+  setupTestData()
 
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val _ctx = ctx
+    import _ctx.planner._
     val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
     val planned =
       plannedOption.getOrElse(
@@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
   }
 
   test("unions are collapsed") {
+    val _ctx = ctx
+    import _ctx.planner._
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
     val logicalUnions = query collect { case u: logical.Union => u }
@@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
     def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
-      setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
+      ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
       val fields = fieldTypes.zipWithIndex.map {
         case (dataType, index) => StructField(s"c${index}", dataType, true)
       } :+ StructField("key", IntegerType, true)
       val schema = StructType(fields)
       val row = Row.fromSeq(Seq.fill(fields.size)(null))
-      val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
-      createDataFrame(rowRDD, schema).registerTempTable("testLimit")
+      val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
+      ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
 
       val planned = sql(
         """
@@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
       assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
       assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
 
-      dropTempTable("testLimit")
+      ctx.dropTempTable("testLimit")
     }
 
-    val origThreshold = conf.autoBroadcastJoinThreshold
+    val origThreshold = ctx.conf.autoBroadcastJoinThreshold
 
     val simpleTypes =
       NullType ::
@@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
 
     checkPlan(complexTypes, newThreshold = 901617)
 
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
 
   test("InMemoryRelation statistics propagation") {
-    val origThreshold = conf.autoBroadcastJoinThreshold
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
+    val origThreshold = ctx.conf.autoBroadcastJoinThreshold
+    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
 
     testData.limit(3).registerTempTable("tiny")
     sql("CACHE TABLE tiny")
 
     val a = testData.as("a")
-    val b = table("tiny").as("b")
+    val b = ctx.table("tiny").as("b")
     val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
 
     val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
@@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
     assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
     assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
 
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
 
   test("efficient limit -> project -> sort") {
     val query = testData.sort('key).select('value).limit(2).logicalPlan
-    val planned = planner.TakeOrderedAndProject(query)
+    val planned = ctx.planner.TakeOrderedAndProject(query)
     assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
   }
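
The small `val _ctx = ctx; import _ctx.planner._` dance above is not cosmetic: Scala only allows `import prefix._` when the prefix is a stable identifier, and ctx on the shared trait is a def, so it must be pinned to a local val before its members can be imported. An illustrative sketch of the same constraint, using the public implicits member instead of planner:

    import org.apache.spark.sql.SQLContext

    trait StableImportSketch {
      protected def ctx: SQLContext      // a def: not a stable identifier

      def useImplicits(): Unit = {
        // import ctx.implicits._        // does not compile: "stable identifier required"
        val _ctx = ctx                   // pin the context to a val first
        import _ctx.implicits._          // now the import is legal
        // toDF/column enrichments from _ctx.implicits are in scope here
      }
    }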
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index dd08e9025a9274d8767c041b4df1d6b43b8ebbf0..ef6ad59b71fb312a15dd7bf59c7d73761314e570 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
-class RowFormatConvertersSuite extends SparkPlanTest {
+class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
 
   private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
     case c: ConvertToUnsafe => c
@@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest {
 
   test("planner should insert unsafe->safe conversions when required") {
     val plan = Limit(10, outputsUnsafe)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
   }
 
   test("filter can process unsafe rows") {
     val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(getConverters(preparedPlan).size === 1)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("filter can process safe rows") {
     val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(getConverters(preparedPlan).isEmpty)
     assert(!preparedPlan.outputsUnsafeRows)
   }
@@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest {
   test("union requires all of its input rows' formats to agree") {
     val plan = Union(Seq(outputsSafe, outputsUnsafe))
     assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("union can process safe rows") {
     val plan = Union(Seq(outputsSafe, outputsSafe))
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(!preparedPlan.outputsUnsafeRows)
   }
 
   test("union can process unsafe rows") {
     val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("round trip with ConvertToUnsafe and ConvertToSafe") {
     val input = Seq(("hello", 1), ("world", 2))
     checkAnswer(
-      TestSQLContext.createDataFrame(input),
+      ctx.createDataFrame(input),
       plan => ConvertToSafe(ConvertToUnsafe(plan)),
       input.map(Row.fromTuple)
     )
   }
 
   test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
-    SparkPlan.currentContext.set(TestSQLContext)
+    SparkPlan.currentContext.set(ctx)
     val schema = ArrayType(StringType)
     val rows = (1 to 100).map { i =>
       InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a2c10fdaf6cdb02248365a157d8151978e7d50fa..8fa77b0fcb7b76b7788cebc18d398adff5c712ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.SharedSQLContext
 
-class SortSuite extends SparkPlanTest {
+class SortSuite extends SparkPlanTest with SharedSQLContext {
 
   // This test was originally added as an example of how to use [[SparkPlanTest]];
   // it's not designed to be a comprehensive test of ExternalSort.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index f46855edfe0de949c467d5a3c38cdcb62631e59c..3a87f374d94b0591090d4c3a2af58bcaa38405cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,29 +17,27 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
-
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.util._
+
 /**
  * Base class for writing tests for individual physical operators. For an example of how this
  * class's test helper methods can be used, see [[SortSuite]].
  */
-class SparkPlanTest extends SparkFunSuite {
-
-  protected def sqlContext: SQLContext = TestSQLContext
+private[sql] abstract class SparkPlanTest extends SparkFunSuite {
+  protected def _sqlContext: SQLContext
 
   /**
    * Creates a DataFrame from a local Seq of Product.
    */
   implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
-    sqlContext.implicits.localSeqToDataFrameHolder(data)
+    _sqlContext.implicits.localSeqToDataFrameHolder(data)
   }
 
   /**
@@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite {
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean = true): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite {
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean = true): Unit = {
     SparkPlanTest.checkAnswer(
-        input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
+        input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -151,13 +149,13 @@ object SparkPlanTest {
       planFunction: SparkPlan => SparkPlan,
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean,
-      sqlContext: SQLContext): Option[String] = {
+      _sqlContext: SQLContext): Option[String] = {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
     val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
 
     val expectedAnswer: Seq[Row] = try {
-      executePlan(expectedOutputPlan, sqlContext)
+      executePlan(expectedOutputPlan, _sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -172,7 +170,7 @@ object SparkPlanTest {
     }
 
     val actualAnswer: Seq[Row] = try {
-      executePlan(outputPlan, sqlContext)
+      executePlan(outputPlan, _sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -212,12 +210,12 @@ object SparkPlanTest {
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean,
-      sqlContext: SQLContext): Option[String] = {
+      _sqlContext: SQLContext): Option[String] = {
 
     val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
 
     val sparkAnswer: Seq[Row] = try {
-      executePlan(outputPlan, sqlContext)
+      executePlan(outputPlan, _sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -280,10 +278,10 @@ object SparkPlanTest {
     }
   }
 
-  private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
+  private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = sqlContext.prepareForExecution.execute(
+    val resolvedPlan = _sqlContext.prepareForExecution.execute(
       outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
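
With SparkPlanTest now abstract over _sqlContext, a concrete plan suite supplies the context by mixing in SharedSQLContext, exactly as ExchangeSuite, SortSuite and TungstenSortSuite do in this diff. A minimal usage sketch of the reworked harness follows; the suite itself is illustrative and not part of the change:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.test.SharedSQLContext

    class SparkPlanTestWiringSketch extends SparkPlanTest with SharedSQLContext {
      test("checkAnswer executes a physical plan against the shared context") {
        val input = Seq(("hello", 1), ("world", 2))
        checkAnswer(
          input.toDF("word", "count"),        // via localSeqToDataFrameHolder above
          (plan: SparkPlan) => plan,          // identity: run the scan unchanged
          input.map(Row.fromTuple))
      }
    }
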
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 88bce0e319f9e8696c1b567ca9b5684a53f67f8f..3158458edb8313abc107c74b8562675d914fd125 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -19,25 +19,28 @@ package org.apache.spark.sql.execution
 
 import scala.util.Random
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 /**
  * A test suite that generates randomized data to test the [[TungstenSort]] operator.
  */
-class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
 
   override def beforeAll(): Unit = {
-    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+    super.beforeAll()
+    ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
   }
 
   override def afterAll(): Unit = {
-    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+    try {
+      ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+    } finally {
+      super.afterAll()
+    }
   }
 
   test("sort followed by limit") {
@@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
   }
 
   test("sorting updates peak execution memory") {
-    val sc = TestSQLContext.sparkContext
+    val sc = ctx.sparkContext
     AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
       checkThatPlansAgree(
         (1 to 100).map(v => Tuple1(v)).toDF("a"),
@@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
   ) {
     test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
       val inputData = Seq.fill(1000)(randomDataGenerator())
-      val inputDf = TestSQLContext.createDataFrame(
-        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+      val inputDf = ctx.createDataFrame(
+        ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
         StructType(StructField("a", dataType, nullable = true) :: Nil)
       )
       assert(TungstenSort.supportsSchema(inputDf.schema))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index e03473041c3e9d00f3336017efb3dbc6f12d70f3..d1f0b2b1fc52f787df0e5975092155ce04d2dd9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.Matchers
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
 import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
@@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String
  *
  * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
  */
-class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
+class UnsafeFixedWidthAggregationMapSuite
+  extends SparkFunSuite
+  with Matchers
+  with SharedSQLContext {
 
   import UnsafeFixedWidthAggregationMap._
 
@@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
   }
 
   testWithMemoryLeakDetection("test external sorting") {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
-
     // Memory consumption in the beginning of the task.
     val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
 
@@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
   }
 
   testWithMemoryLeakDetection("test external sorting with an empty map") {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
 
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
@@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
   }
 
   testWithMemoryLeakDetection("test external sorting with empty records") {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
 
     // Memory consumption in the beginning of the task.
     val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
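
The deleted "Calling this make sure we have block manager and everything else setup" lines existed only to force the TestSQLContext singleton to initialize SparkEnv. With SharedSQLContext mixed in, the SparkContext is started in beforeAll, so SparkEnv — and with it the block manager and shuffle memory manager these sorting tests rely on — is already up before any test body runs. An illustrative check, not part of the real suite:

    import org.apache.spark.{SparkEnv, SparkFunSuite}
    import org.apache.spark.sql.test.SharedSQLContext

    class SharedEnvInitializationSketch extends SparkFunSuite with SharedSQLContext {
      test("SparkEnv is available once the shared context is up") {
        // The same environment UnsafeFixedWidthAggregationMapSuite reaches for.
        assert(SparkEnv.get != null)
        assert(SparkEnv.get.executorMemoryManager != null)
      }
    }
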
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index a9515a03acf2c94c41848c2952b030399019ecb0..d3be568a8758c7dcb6f32a420c220201c431052a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark._
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 
 /**
  * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
  */
-class UnsafeKVExternalSorterSuite extends SparkFunSuite {
-
+class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
   private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
   private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
 
@@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
       inputData: Seq[(InternalRow, InternalRow)],
       pageSize: Long,
       spill: Boolean): Unit = {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
 
     val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
     val shuffleMemMgr = new TestShuffleMemoryManager
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index ac22c2f3c0a58021942ad53237187f35c68ba047..5fdb82b067728c7bbbd6e8f4f786463ffe4160bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -21,15 +21,12 @@ import org.apache.spark._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 
-class TungstenAggregationIteratorSuite extends SparkFunSuite {
+class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
 
   test("memory acquired on construction") {
-    // set up environment
-    val ctx = TestSQLContext
-
     val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
     val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
     TaskContext.setTaskContext(taskContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 73d562189781955ec22594feea0362495b6d3fa9..1174b27732f220385a3f87eeac07313ac7a2a58d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory
 import org.apache.spark.rdd.RDD
 import org.scalactic.Tolerance._
 
-import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf}
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
 
-class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
-
-  protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  override def sqlContext: SQLContext = ctx // used by SQLTestUtils
-
-  import ctx.sql
-  import ctx.implicits._
+class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
+  import testImplicits._
 
   test("Type promotion") {
     def checkTypePromotion(expected: Any, actual: Any) {
@@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
 
     val schema = StructType(StructField("a", LongType, true) :: Nil)
     val logicalRelation =
-      ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
+      ctx.read.schema(schema).json(path)
+        .queryExecution.analyzed.asInstanceOf[LogicalRelation]
     val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
     assert(relationWithSchema.paths === Array(path))
     assert(relationWithSchema.schema === schema)
@@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
   }
 
   test("JSONRelation equality test") {
-    val context = org.apache.spark.sql.test.TestSQLContext
-
     val relation0 = new JSONRelation(
       Some(empty),
       1.0,
       Some(StructType(StructField("a", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation0 = LogicalRelation(relation0)
     val relation1 = new JSONRelation(
       Some(singleRow),
       1.0,
       Some(StructType(StructField("a", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation1 = LogicalRelation(relation1)
     val relation2 = new JSONRelation(
       Some(singleRow),
       0.5,
       Some(StructType(StructField("a", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation2 = LogicalRelation(relation2)
     val relation3 = new JSONRelation(
       Some(singleRow),
       1.0,
       Some(StructType(StructField("b", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation3 = LogicalRelation(relation3)
 
     assert(relation0 !== relation1)
@@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
         .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
 
       val d1 = ResolvedDataSource(
-        context,
+        ctx,
         userSpecifiedSchema = None,
         partitionColumns = Array.empty[String],
         provider = classOf[DefaultSource].getCanonicalName,
         options = Map("path" -> path))
 
       val d2 = ResolvedDataSource(
-        context,
+        ctx,
         userSpecifiedSchema = None,
         partitionColumns = Array.empty[String],
         provider = classOf[DefaultSource].getCanonicalName,
@@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
         "abd")
 
         ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
-        checkAnswer(
-          sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
-        checkAnswer(
-          sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
-        checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
+        checkAnswer(sql(
+          "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
+        checkAnswer(sql(
+          "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
+        checkAnswer(sql(
+          "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
     })
   }
 }
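
For reference, the schema-specified JSON read that the tests above exercise can be reproduced standalone. This is a minimal sketch under the assumption of a locally created SQLContext; it reads from an in-memory RDD instead of a temp path, and all names are illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{LongType, StructField, StructType}

object JsonSchemaReadSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("json-sketch"))
    val sqlContext = new SQLContext(sc)

    val rdd = sc.parallelize("""{"a": 1}""" :: """{"a": 2}""" :: Nil)
    val schema = StructType(StructField("a", LongType, nullable = true) :: Nil)

    // Supplying the schema up front skips JSON schema inference entirely.
    val df = sqlContext.read.schema(schema).json(rdd)
    df.show()

    sc.stop()
  }
}
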
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 6b62c9a003df6cfc0f9df7c423d4542b51e60d85..2864181cf91d57ef296d3aff6f9220b843b72286 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 
-trait TestJsonData {
-
-  protected def ctx: SQLContext
+private[json] trait TestJsonData {
+  protected def _sqlContext: SQLContext
 
   def primitiveFieldAndType: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"string":"this is a simple string.",
           "integer":10,
           "long":21474836470,
@@ -36,7 +35,7 @@ trait TestJsonData {
       }"""  :: Nil)
 
   def primitiveFieldValueTypeConflict: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
           "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
       """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -47,14 +46,14 @@ trait TestJsonData {
           "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
 
   def jsonNullStruct: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
         """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
         """{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
         """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
 
   def complexFieldValueTypeConflict: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"num_struct":11, "str_array":[1, 2, 3],
           "array":[], "struct_array":[], "struct": {}}""" ::
       """{"num_struct":{"field":false}, "str_array":null,
@@ -65,14 +64,14 @@ trait TestJsonData {
           "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
 
   def arrayElementTypeConflict: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
           "array2": [{"field":214748364700}, {"field":1}]}""" ::
       """{"array3": [{"field":"str"}, {"field":1}]}""" ::
       """{"array3": [1, 2, 3]}""" :: Nil)
 
   def missingFields: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"a":true}""" ::
       """{"b":21474836470}""" ::
       """{"c":[33, 44]}""" ::
@@ -80,7 +79,7 @@ trait TestJsonData {
       """{"e":"str"}""" :: Nil)
 
   def complexFieldAndType1: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"struct":{"field1": true, "field2": 92233720368547758070},
           "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
           "arrayOfString":["str1", "str2"],
@@ -96,7 +95,7 @@ trait TestJsonData {
          }"""  :: Nil)
 
   def complexFieldAndType2: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
           "complexArrayOfStruct": [
           {
@@ -150,7 +149,7 @@ trait TestJsonData {
       }""" :: Nil)
 
   def mapType1: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"map": {"a": 1}}""" ::
       """{"map": {"b": 2}}""" ::
       """{"map": {"c": 3}}""" ::
@@ -158,7 +157,7 @@ trait TestJsonData {
       """{"map": {"e": null}}""" :: Nil)
 
   def mapType2: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
       """{"map": {"b": {"field2": 2}}}""" ::
       """{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -167,21 +166,21 @@ trait TestJsonData {
       """{"map": {"f": {"field1": null}}}""" :: Nil)
 
   def nullsInArrays: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"field1":[[null], [[["Test"]]]]}""" ::
       """{"field2":[null, [{"Test":1}]]}""" ::
       """{"field3":[[null], [{"Test":"2"}]]}""" ::
       """{"field4":[[null, [1,2,3]]]}""" :: Nil)
 
   def jsonArray: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """[{"a":"str_a_1"}]""" ::
       """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
       """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
       """[]""" :: Nil)
 
   def corruptRecords: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{""" ::
       """""" ::
       """{"a":1, b:2}""" ::
@@ -190,7 +189,7 @@ trait TestJsonData {
       """]""" :: Nil)
 
   def emptyRecords: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{""" ::
         """""" ::
         """{"a": {}}""" ::
@@ -198,9 +197,8 @@ trait TestJsonData {
         """{"b": [{"c": {}}]}""" ::
         """]""" :: Nil)
 
-  lazy val singleRow: RDD[String] =
-    ctx.sparkContext.parallelize(
-      """{"a":123}""" :: Nil)
 
-  def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
+  lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
+
+  def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
 }
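
The _sqlContext indirection above is what lets the data trait be mixed into any suite that can supply a context (here, presumably via SharedSQLContext). A stripped-down sketch of the same shape, with illustrative names:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

// Hypothetical fixture trait: it depends only on an abstract _sqlContext, so it never
// has to reference a global test singleton.
trait NumbersFixture {
  protected def _sqlContext: SQLContext

  def numbers: RDD[String] =
    _sqlContext.sparkContext.parallelize("""{"n": 1}""" :: """{"n": 2}""" :: Nil)
}
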
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
index 866a975ad54043f9768632b6d736ab8b334b3d07..82d40e2b61a10350af4d2cffe10bdfce9688ff48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
@@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord
 import org.apache.hadoop.fs.Path
 import org.apache.parquet.avro.AvroParquetWriter
 
-import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.parquet.test.avro._
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
+class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
   import ParquetCompatibilityTest._
 
-  override val sqlContext: SQLContext = TestSQLContext
-
   private def withWriter[T <: IndexedRecord]
       (path: String, schema: Schema)
-      (f: AvroParquetWriter[T] => Unit) = {
+      (f: AvroParquetWriter[T] => Unit): Unit = {
     val writer = new AvroParquetWriter[T](new Path(path), schema)
     try f(writer) finally writer.close()
   }
@@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
   }
 
   test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") {
-    import sqlContext.implicits._
+    import testImplicits._
 
     withTempPath { dir =>
       val path = dir.getCanonicalPath
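
The withWriter helper above is the usual loan pattern: acquire a resource, pass it to the caller, and close it in a finally block regardless of what the caller does. A generic, self-contained sketch of the same idea (names are illustrative):

import java.io.{File, PrintWriter}

object LoanPatternSketch {
  // Open a resource, lend it to f, and always close it, even if f throws.
  def withResource[R <: java.io.Closeable, T](open: => R)(f: R => T): T = {
    val resource = open
    try f(resource) finally resource.close()
  }

  def main(args: Array[String]): Unit = {
    val tmp = File.createTempFile("loan", ".txt")
    tmp.deleteOnExit()
    withResource(new PrintWriter(tmp)) { writer => writer.println("hello") }
    println(s"wrote ${tmp.length()} bytes")
  }
}
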
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index 0ea64aa2a509b732518770071f472ff47b14488b..b3406729fcc5eaac2a4eb01bd9168f41a19e79d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -22,16 +22,18 @@ import scala.collection.JavaConversions._
 import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.parquet.hadoop.ParquetFileReader
 import org.apache.parquet.schema.MessageType
-import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.QueryTest
 
-abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll {
-  def readParquetSchema(path: String): MessageType = {
+/**
+ * Helper class for testing Parquet compatibility.
+ */
+private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest {
+  protected def readParquetSchema(path: String): MessageType = {
     readParquetSchema(path, { path => !path.getName.startsWith("_") })
   }
 
-  def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
+  protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
     val fsPath = new Path(path)
     val fs = fsPath.getFileSystem(configuration)
     val parquetFiles = fs.listStatus(fsPath, new PathFilter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 7dd9680d8cd65f227126e2ecbfa9d98f99e37b6d..5b4e568bb9838963f7a5765d39e48f1ee8ee4a2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet
 import org.apache.parquet.filter2.predicate.Operators._
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
 
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
 
 /**
  * A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
  * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *    data type is nullable.
  */
-class ParquetFilterSuite extends QueryTest with ParquetTest {
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
 
   private def checkFilterPredicate(
       df: DataFrame,
@@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
   }
 
   test("SPARK-6554: don't push down predicates which reference partition columns") {
-    import sqlContext.implicits._
+    import testImplicits._
 
     withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
       withTempPath { dir =>
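
ParquetFilterSuite asserts on the generated Parquet FilterPredicate directly; the user-facing side of the same feature can be sketched with the public API only. This is illustrative (local context, temp path, and names are assumptions), and whether the filter is actually pushed down is governed by the SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED flag the suite toggles:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object FilterPushdownSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("pushdown"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val path = java.nio.file.Files.createTempDirectory("pushdown").resolve("t").toString
    (1 to 100).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)

    // With Parquet filter pushdown enabled, this comparison is evaluated inside the
    // Parquet reader rather than row-by-row in Spark.
    val filtered = sqlContext.read.parquet(path).filter($"a" > 90)
    filtered.explain()
    println(filtered.count()) // 10

    sc.stop()
  }
}
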
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index cb166349fdb2695863b0758ee6f62b3f9514cc0e..d819f3ab5e6abcfbf18b690fe2d9e501e2e4cbc5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
 /**
  * A test suite that tests basic Parquet I/O.
  */
-class ParquetIOSuite extends QueryTest with ParquetTest {
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
+class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
+  import testImplicits._
 
   /**
    * Writes `data` to a Parquet file, reads it back and check file contents.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 73152de244759bf0d8f07620b69bc616381b469b..ed8bafb10c60b44ed44b8f4d03265533d9e75cad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.io.Files
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql._
 import org.apache.spark.unsafe.types.UTF8String
-import PartitioningUtils._
 
 // The data where the partitioning key exists only in the directory structure.
 case class ParquetData(intField: Int, stringField: String)
@@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String)
 // The data that also includes the partitioning key
 case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
 
-class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
-
-  override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
-  import sqlContext.sql
+class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext {
+  import PartitioningUtils._
+  import testImplicits._
 
   val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
index 981334cf771cfc28a2e87b64d03abadfd0d1a4f6..b290429c2a021fdd0256f20add86a833f2dd29b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest {
-  override def sqlContext: SQLContext = TestSQLContext
+class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
 
   private def readParquetProtobufFile(name: String): DataFrame = {
     val url = Thread.currentThread().getContextClassLoader.getResource(name)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 5e6d9c1cd44a82efb6cf472bedef51d90763dfae..e2f2a8c74478375e78fa19c4f16f493f6a407ebe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -21,16 +21,15 @@ import java.io.File
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 /**
  * A test suite that tests various Parquet queries.
  */
-class ParquetQuerySuite extends QueryTest with ParquetTest {
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.sql
+class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
 
   test("simple select queries") {
     withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
 
   test("appending") {
     val data = (0 until 10).map(i => (i, i.toString))
-    sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+    ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
       sql("INSERT INTO TABLE t SELECT * FROM tmp")
-      checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
+      checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(Seq("tmp"))
+    ctx.catalog.unregisterTable(Seq("tmp"))
   }
 
   test("overwriting") {
     val data = (0 until 10).map(i => (i, i.toString))
-    sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+    ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
       sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
-      checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
+      checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(Seq("tmp"))
+    ctx.catalog.unregisterTable(Seq("tmp"))
   }
 
   test("self-join") {
@@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
     val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
       StructField("time", TimestampType, false)).toArray)
     withTempPath { file =>
-      val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
+      val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
       df.write.parquet(file.getCanonicalPath)
-      val df2 = sqlContext.read.parquet(file.getCanonicalPath)
+      val df2 = ctx.read.parquet(file.getCanonicalPath)
       checkAnswer(df2, df.collect().toSeq)
     }
   }
@@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
     def testSchemaMerging(expectedColumnNumber: Int): Unit = {
       withTempDir { dir =>
         val basePath = dir.getCanonicalPath
-        sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
-        sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+        ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+        ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
         // delete summary files, so if we don't merge part-files, one column will not be included.
         Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
         Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
-        assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+        assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
       }
     }
 
@@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
     def testSchemaMerging(expectedColumnNumber: Int): Unit = {
       withTempDir { dir =>
         val basePath = dir.getCanonicalPath
-        sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
-        sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
-        assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+        ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+        ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+        assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
       }
     }
 
@@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
   test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
     withTempPath { dir =>
       val basePath = dir.getCanonicalPath
-      sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
-      sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
+      ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+      ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
 
       // Disables the global SQL option for schema merging
       withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
         assertResult(2) {
           // Disables schema merging via data source option
-          sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
+          ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
         }
 
         assertResult(3) {
           // Enables schema merging via data source option
-          sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
+          ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
         }
       }
     }
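
The schema-merging tests above boil down to: two Parquet directories under one base path with different column sets, and the mergeSchema data source option deciding whether both columns are visible. A standalone sketch of just that behaviour (the local context and temp path are assumptions):

import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object MergeSchemaSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("merge-schema"))
    val sqlContext = new SQLContext(sc)

    val basePath = java.nio.file.Files.createTempDirectory("merge-schema").toString
    sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
    sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)

    // Without merging, only one data column plus the partition column "foo" is visible;
    // with merging, both "a" and "b" show up.
    println(sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.toSeq)
    println(sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.toSeq)

    sc.stop()
  }
}
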
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 971f71e27bfc613d18b9b41e20f5c516be2d0dbf..9dcbc1a047beac4398dfff4122e5b9938e0e6be2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.parquet.schema.MessageTypeParser
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
-  val sqlContext = TestSQLContext
+abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext {
 
   /**
    * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 3c6e54db4bca723b264885674831747cd0a774a2..5dbc7d1630f27e25022fbcf0b512606689f9df97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -22,9 +22,8 @@ import java.io.File
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
 
 /**
  * A helper trait that provides convenient facilities for Parquet testing.
@@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
  * convenient to use tuples rather than special case classes when writing test cases/suites.
  * Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
  */
-private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
+private[sql] trait ParquetTest extends SQLTestUtils {
+  protected def _sqlContext: SQLContext
+
   /**
    * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
    * returns.
@@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
       (data: Seq[T])
       (f: String => Unit): Unit = {
     withTempPath { file =>
-      sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
+      _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
       f(file.getCanonicalPath)
     }
   }
@@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
   protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
       (data: Seq[T])
       (f: DataFrame => Unit): Unit = {
-    withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
+    withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
   }
 
   /**
@@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
       (data: Seq[T], tableName: String)
       (f: => Unit): Unit = {
     withParquetDataFrame(data) { df =>
-      sqlContext.registerDataFrameAsTable(df, tableName)
+      _sqlContext.registerDataFrameAsTable(df, tableName)
       withTempTable(tableName)(f)
     }
   }
 
   protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
       data: Seq[T], path: File): Unit = {
-    sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+    _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
   }
 
   protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
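
The helpers above all reduce to a write-then-read round trip against a temporary location. A self-contained sketch of that round trip, under the assumption of a locally created SQLContext (names and paths are illustrative):

import java.io.File
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SaveMode, SQLContext}

object ParquetRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("roundtrip"))
    val sqlContext = new SQLContext(sc)

    val path = new File(Files.createTempDirectory("roundtrip").toFile, "data").getCanonicalPath

    val data = (1 to 5).map(i => (i, i.toString))
    sqlContext.createDataFrame(data).toDF("id", "label")
      .write.mode(SaveMode.Overwrite).parquet(path)

    // Read it back and check that the contents survive the round trip.
    assert(sqlContext.read.parquet(path).count() == 5)

    sc.stop()
  }
}
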
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index 92b1d822172d566df2bca56674f13707aa26c9e1..b789c5a106e56ead9c9f461ec4f00bf2c75870b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -17,14 +17,12 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest {
+class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
   import ParquetCompatibilityTest._
 
-  override val sqlContext: SQLContext = TestSQLContext
-
   private val parquetFilePath =
     Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 239deb79738459c20d4807f63506efeb4a4c1b73..22189477d277dc6dca189bd2e08b36ea70926d9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.execution.debug
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
 
-class DebuggingSuite extends SparkFunSuite {
   test("DataFrame.debug()") {
     testData.debug()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index d33a967093ca5edd66b58ea43d0ee76cd4bd4325..4c9187a9a7106c7e2b886d12b6f3afc65bcf4c5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.collection.CompactBuffer
 
 
-class HashedRelationSuite extends SparkFunSuite {
+class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
 
   // Key is simply the record itself
   private val keyProjection = new Projection {
@@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite {
 
   test("GeneralHashedRelation") {
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
-    val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+    val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
     val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
     assert(hashed.isInstanceOf[GeneralHashedRelation])
 
@@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite {
 
   test("UniqueKeyHashedRelation") {
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
-    val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+    val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
     val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
     assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
 
@@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite {
   test("UnsafeHashedRelation") {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
-    val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+    val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
     val toUnsafe = UnsafeProjection.create(schema)
     val unsafeData = data.map(toUnsafe(_).copy()).toArray
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index ddff7cebcc17d8a788fd5db04dac1de5860227c9..cc649b9bd4c450692895491135ad593fe53fbf8b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,97 +17,19 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
-import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 
-class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
-
-  private def testInnerJoin(
-      testName: String,
-      leftRows: DataFrame,
-      rightRows: DataFrame,
-      condition: Expression,
-      expectedAnswer: Seq[Product]): Unit = {
-    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
-    ExtractEquiJoinKeys.unapply(join).foreach {
-      case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-
-        def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
-          val broadcastHashJoin =
-            execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
-          boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
-        }
-
-        def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
-          val shuffledHashJoin =
-            execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
-          val filteredJoin =
-            boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
-          EnsureRequirements(sqlContext).apply(filteredJoin)
-        }
-
-        def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
-          val sortMergeJoin =
-            execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
-          val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-          EnsureRequirements(sqlContext).apply(filteredJoin)
-        }
-
-        test(s"$testName using BroadcastHashJoin (build=left)") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              makeBroadcastHashJoin(left, right, joins.BuildLeft),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
-        }
-
-        test(s"$testName using BroadcastHashJoin (build=right)") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              makeBroadcastHashJoin(left, right, joins.BuildRight),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
-        }
-
-        test(s"$testName using ShuffledHashJoin (build=left)") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              makeShuffledHashJoin(left, right, joins.BuildLeft),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
-        }
-
-        test(s"$testName using ShuffledHashJoin (build=right)") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              makeShuffledHashJoin(left, right, joins.BuildRight),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
-        }
+class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
 
-        test(s"$testName using SortMergeJoin") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              makeSortMergeJoin(left, right),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
-        }
-    }
-  }
-
-  {
-    val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+  private lazy val myUpperCaseData = ctx.createDataFrame(
+    ctx.sparkContext.parallelize(Seq(
       Row(1, "A"),
       Row(2, "B"),
       Row(3, "C"),
@@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
       Row(null, "G")
     )), new StructType().add("N", IntegerType).add("L", StringType))
 
-    val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+  private lazy val myLowerCaseData = ctx.createDataFrame(
+    ctx.sparkContext.parallelize(Seq(
       Row(1, "a"),
       Row(2, "b"),
       Row(3, "c"),
@@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
       Row(null, "e")
     )), new StructType().add("n", IntegerType).add("l", StringType))
 
-    testInnerJoin(
-      "inner join, one match per row",
-      upperCaseData,
-      lowerCaseData,
-      (upperCaseData.col("N") === lowerCaseData.col("n")).expr,
-      Seq(
-        (1, "A", 1, "a"),
-        (2, "B", 2, "b"),
-        (3, "C", 3, "c"),
-        (4, "D", 4, "d")
-      )
-    )
-  }
-
-  private val testData2 = Seq(
+  private lazy val myTestData = Seq(
     (1, 1),
     (1, 2),
     (2, 1),
@@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
     (3, 2)
   ).toDF("a", "b")
 
+  // Note: the input DataFrames and the join condition must be evaluated lazily because
+  // the SQLContext should only be used within a test, which keeps these SQL tests stable
+  private def testInnerJoin(
+      testName: String,
+      leftRows: => DataFrame,
+      rightRows: => DataFrame,
+      condition: () => Expression,
+      expectedAnswer: Seq[Product]): Unit = {
+
+    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
+      ExtractEquiJoinKeys.unapply(join)
+    }
+
+    def makeBroadcastHashJoin(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        boundCondition: Option[Expression],
+        leftPlan: SparkPlan,
+        rightPlan: SparkPlan,
+        side: BuildSide) = {
+      val broadcastHashJoin =
+        execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
+      boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+    }
+
+    def makeShuffledHashJoin(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        boundCondition: Option[Expression],
+        leftPlan: SparkPlan,
+        rightPlan: SparkPlan,
+        side: BuildSide) = {
+      val shuffledHashJoin =
+        execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
+      val filteredJoin =
+        boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
+      EnsureRequirements(sqlContext).apply(filteredJoin)
+    }
+
+    def makeSortMergeJoin(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        boundCondition: Option[Expression],
+        leftPlan: SparkPlan,
+        rightPlan: SparkPlan) = {
+      val sortMergeJoin =
+        execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
+      val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
+      EnsureRequirements(sqlContext).apply(filteredJoin)
+    }
+
+    test(s"$testName using BroadcastHashJoin (build=left)") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeBroadcastHashJoin(
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
+    test(s"$testName using BroadcastHashJoin (build=right)") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeBroadcastHashJoin(
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
+    test(s"$testName using ShuffledHashJoin (build=left)") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeShuffledHashJoin(
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
+    test(s"$testName using ShuffledHashJoin (build=right)") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeShuffledHashJoin(
+              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+
+    test(s"$testName using SortMergeJoin") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+            makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
+        }
+      }
+    }
+  }
+
+  testInnerJoin(
+    "inner join, one match per row",
+    myUpperCaseData,
+    myLowerCaseData,
+    () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
+    Seq(
+      (1, "A", 1, "a"),
+      (2, "B", 2, "b"),
+      (3, "C", 3, "c"),
+      (4, "D", 4, "d")
+    )
+  )
+
   {
-    val left = testData2.where("a = 1")
-    val right = testData2.where("a = 1")
+    lazy val left = myTestData.where("a = 1")
+    lazy val right = myTestData.where("a = 1")
     testInnerJoin(
       "inner join, multiple matches",
       left,
       right,
-      (left.col("a") === right.col("a")).expr,
+      () => (left.col("a") === right.col("a")).expr,
       Seq(
         (1, 1, 1, 1),
         (1, 1, 1, 2),
@@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
   }
 
   {
-    val left = testData2.where("a = 1")
-    val right = testData2.where("a = 2")
+    lazy val left = myTestData.where("a = 1")
+    lazy val right = myTestData.where("a = 2")
     testInnerJoin(
       "inner join, no matches",
       left,
       right,
-      (left.col("a") === right.col("a")).expr,
+      () => (left.col("a") === right.col("a")).expr,
       Seq.empty
     )
   }
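
The refactoring above leans on a small Scala technique: by-name (=> A) and thunk (() => A) parameters defer evaluation, so the DataFrames and join condition are only materialised once a test body actually runs, after the shared SQLContext exists. A minimal, Spark-free illustration:

object ByNameSketch {
  // `value` is passed by name: the block is not evaluated at the call site, only when
  // (and if) `value` is referenced inside this method.
  def registerCheck(name: String, value: => Int)(runNow: Boolean): Unit = {
    if (runNow) println(s"$name = $value")
  }

  def main(args: Array[String]): Unit = {
    registerCheck("deferred", { println("evaluating"); 42 })(runNow = false) // block never runs
    registerCheck("evaluated", { println("evaluating"); 42 })(runNow = true) // prints "evaluating", then "evaluated = 42"
  }
}
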
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index e16f5e39aa2f4d6ec5aaacdb268b7a9a2ed765e1..a1a617d7b73987ec9fc217d09ac17a40af2d4c1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,28 +17,65 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
+import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
-import org.apache.spark.sql.{SQLConf, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
 
-class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
+class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
+
+  private lazy val left = ctx.createDataFrame(
+    ctx.sparkContext.parallelize(Seq(
+      Row(1, 2.0),
+      Row(2, 100.0),
+      Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
+      Row(2, 1.0),
+      Row(3, 3.0),
+      Row(5, 1.0),
+      Row(6, 6.0),
+      Row(null, null)
+    )), new StructType().add("a", IntegerType).add("b", DoubleType))
+
+  private lazy val right = ctx.createDataFrame(
+    ctx.sparkContext.parallelize(Seq(
+      Row(0, 0.0),
+      Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
+      Row(2, -1.0),
+      Row(2, -1.0),
+      Row(2, 3.0),
+      Row(3, 2.0),
+      Row(4, 1.0),
+      Row(5, 3.0),
+      Row(7, 7.0),
+      Row(null, null)
+    )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+  private lazy val condition = {
+    And((left.col("a") === right.col("c")).expr,
+      LessThan(left.col("b").expr, right.col("d").expr))
+  }
 
+  // Note: the input DataFrames and the join condition must be evaluated lazily because
+  // the SQLContext should only be used within a test, which keeps these SQL tests stable
   private def testOuterJoin(
       testName: String,
-      leftRows: DataFrame,
-      rightRows: DataFrame,
+      leftRows: => DataFrame,
+      rightRows: => DataFrame,
       joinType: JoinType,
-      condition: Expression,
+      condition: => Expression,
       expectedAnswer: Seq[Product]): Unit = {
-    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
-    ExtractEquiJoinKeys.unapply(join).foreach {
-      case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-        test(s"$testName using ShuffledHashOuterJoin") {
+
+    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+      ExtractEquiJoinKeys.unapply(join)
+    }
+
+    test(s"$testName using ShuffledHashOuterJoin") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               EnsureRequirements(sqlContext).apply(
@@ -46,19 +83,23 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
               expectedAnswer.map(Row.fromTuple),
               sortAnswers = true)
           }
-        }
+      }
+    }
 
-        if (joinType != FullOuter) {
-          test(s"$testName using BroadcastHashOuterJoin") {
+    if (joinType != FullOuter) {
+      test(s"$testName using BroadcastHashOuterJoin") {
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
             withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
               checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
                 BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
                 expectedAnswer.map(Row.fromTuple),
                 sortAnswers = true)
             }
-          }
+        }
+      }
 
-          test(s"$testName using SortMergeOuterJoin") {
+      test(s"$testName using SortMergeOuterJoin") {
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
             withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
               checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
                 EnsureRequirements(sqlContext).apply(
@@ -66,57 +107,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
                 expectedAnswer.map(Row.fromTuple),
                 sortAnswers = false)
             }
-          }
         }
-    }
-
-    test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
-          expectedAnswer.map(Row.fromTuple),
-          sortAnswers = true)
       }
     }
-
-    test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
-          expectedAnswer.map(Row.fromTuple),
-          sortAnswers = true)
-      }
-    }
-  }
-
-  val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
-    Row(1, 2.0),
-    Row(2, 100.0),
-    Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
-    Row(2, 1.0),
-    Row(3, 3.0),
-    Row(5, 1.0),
-    Row(6, 6.0),
-    Row(null, null)
-  )), new StructType().add("a", IntegerType).add("b", DoubleType))
-
-  val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
-    Row(0, 0.0),
-    Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
-    Row(2, -1.0),
-    Row(2, -1.0),
-    Row(2, 3.0),
-    Row(3, 2.0),
-    Row(4, 1.0),
-    Row(5, 3.0),
-    Row(7, 7.0),
-    Row(null, null)
-  )), new StructType().add("c", IntegerType).add("d", DoubleType))
-
-  val condition = {
-    And(
-      (left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
   }
 
   // --- Basic outer joins ------------------------------------------------------------------------
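
For orientation, the scenario OuterJoinSuite drives through hand-built physical plans corresponds to an ordinary outer join with an equi-key plus a range predicate. The sketch below uses only the public DataFrame API, so it mirrors the scenario but not the specific operators under test; the local context and data are illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object OuterJoinSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("outer-join"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val left = Seq((1, 2.0), (2, 100.0), (3, 3.0)).toDF("a", "b")
    val right = Seq((2, 3.0), (3, 2.0), (4, 1.0)).toDF("c", "d")

    // Equi-join key plus an extra range predicate, matching the shape of `condition` above.
    val joined = left.join(right, left("a") === right("c") && left("b") < right("d"), "left_outer")
    joined.show()

    sc.stop()
  }
}
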
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 4503ed251fcb1b667522be8db5577ca03e496c2f..baa86e320d986275b4ec2bc60cc0a143bd5598cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -17,44 +17,80 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.sql.{SQLConf, DataFrame, Row}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
-import org.apache.spark.sql.{SQLConf, DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
 import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
+
+class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
 
-class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
+  private lazy val left = ctx.createDataFrame(
+    ctx.sparkContext.parallelize(Seq(
+      Row(1, 2.0),
+      Row(1, 2.0),
+      Row(2, 1.0),
+      Row(2, 1.0),
+      Row(3, 3.0),
+      Row(null, null),
+      Row(null, 5.0),
+      Row(6, null)
+    )), new StructType().add("a", IntegerType).add("b", DoubleType))
 
+  private lazy val right = ctx.createDataFrame(
+    ctx.sparkContext.parallelize(Seq(
+      Row(2, 3.0),
+      Row(2, 3.0),
+      Row(3, 2.0),
+      Row(4, 1.0),
+      Row(null, null),
+      Row(null, 5.0),
+      Row(6, null)
+    )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+  private lazy val condition = {
+    And((left.col("a") === right.col("c")).expr,
+      LessThan(left.col("b").expr, right.col("d").expr))
+  }
+
+  // Note: the input DataFrames and the join condition must be evaluated lazily because
+  // the SQLContext should only be used within a test, which keeps these SQL tests stable
   private def testLeftSemiJoin(
       testName: String,
-      leftRows: DataFrame,
-      rightRows: DataFrame,
-      condition: Expression,
+      leftRows: => DataFrame,
+      rightRows: => DataFrame,
+      condition: => Expression,
       expectedAnswer: Seq[Product]): Unit = {
-    val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
-    ExtractEquiJoinKeys.unapply(join).foreach {
-      case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-        test(s"$testName using LeftSemiJoinHash") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              EnsureRequirements(left.sqlContext).apply(
-                LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
+
+    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
+      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+      ExtractEquiJoinKeys.unapply(join)
+    }
+
+    test(s"$testName using LeftSemiJoinHash") {
+      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+            EnsureRequirements(left.sqlContext).apply(
+              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
         }
+      }
+    }
 
-        test(s"$testName using BroadcastLeftSemiJoinHash") {
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-              BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
-          }
+    test(s"$testName using BroadcastLeftSemiJoinHash") {
+      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+            expectedAnswer.map(Row.fromTuple),
+            sortAnswers = true)
         }
+      }
     }
 
     test(s"$testName using LeftSemiJoinBNL") {
@@ -67,33 +103,6 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
     }
   }
 
-  val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
-    Row(1, 2.0),
-    Row(1, 2.0),
-    Row(2, 1.0),
-    Row(2, 1.0),
-    Row(3, 3.0),
-    Row(null, null),
-    Row(null, 5.0),
-    Row(6, null)
-  )), new StructType().add("a", IntegerType).add("b", DoubleType))
-
-  val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
-    Row(2, 3.0),
-    Row(2, 3.0),
-    Row(3, 2.0),
-    Row(4, 1.0),
-    Row(null, null),
-    Row(null, 5.0),
-    Row(6, null)
-  )), new StructType().add("c", IntegerType).add("d", DoubleType))
-
-  val condition = {
-    And(
-      (left.col("a") === right.col("c")).expr,
-      LessThan(left.col("b").expr, right.col("d").expr))
-  }
-
   testLeftSemiJoin(
     "basic test",
     left,
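
The rewritten SemiJoinSuite above relies on two Scala features to keep the shared SQLContext untouched while the suite is being constructed: lazy vals for the input DataFrames and by-name parameters on testLeftSemiJoin. A minimal, Spark-free sketch of that pattern follows (names here are illustrative only, not part of the patch):

object ByNameRegistrationSketch {
  // Stands in for a lazily created DataFrame; nothing runs until first use.
  private lazy val input: Seq[Int] = { println("building input"); 1 to 3 }

  // `rows` is by-name, so the argument is not evaluated at registration time.
  def registerTest(name: String)(rows: => Seq[Int]): () => Unit = {
    println(s"registered: $name")            // happens while the suite is constructed
    () => println(s"$name saw ${rows.sum}")  // `input` is forced only when the body runs
  }

  def main(args: Array[String]): Unit = {
    val body = registerTest("basic test")(input)  // prints only "registered: basic test"
    body()                                        // now prints "building input" and the sum
  }
}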
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 7383d3f8fe024fbf491a0404ff7c4e64d64ab4a4..80006bf077fe8498022e55fdc35c112d0a790125 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.execution.ui.SparkPlanGraph
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.Utils
 
-class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
 
-  override val sqlContext = TestSQLContext
-
-  import sqlContext.implicits._
+class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
   test("LongSQLMetric should not box Long") {
-    val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
+    val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long")
     val f = () => {
       l += 1L
       l.add(1L)
@@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
 
   test("Normal accumulator should do boxing") {
     // We need this test to make sure BoxingFinder works.
-    val l = TestSQLContext.sparkContext.accumulator(0L)
+    val l = ctx.sparkContext.accumulator(0L)
     val f = () => { l += 1L }
     BoxingFinder.getClassReader(f.getClass).foreach { cl =>
       val boxingFinder = new BoxingFinder()
@@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       df: DataFrame,
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
-    val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
+    val previousExecutionIds = ctx.listener.executionIdToData.keySet
     df.collect()
-    TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
-    val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+    ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
+    val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
     assert(executionIds.size === 1)
     val executionId = executionIds.head
-    val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
+    val jobs = ctx.listener.getExecution(executionId).get.jobs
     // Use "<=" because there is a race condition that we may miss some jobs
     // TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
     assert(jobs.size <= expectedNumOfJobs)
     if (jobs.size == expectedNumOfJobs) {
       // If we can track all jobs, check the metric values
-      val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
+      val metricValues = ctx.listener.getExecutionMetrics(executionId)
       val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node =>
         expectedMetrics.contains(node.id)
       }.map { node =>
@@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       SQLConf.TUNGSTEN_ENABLED.key -> "false") {
       // Assume the execution plan is
       // PhysicalRDD(nodeId = 1) -> Project(nodeId = 0)
-      val df = TestData.person.select('name)
+      val df = person.select('name)
       testSparkPlanMetrics(df, 1, Map(
         0L ->("Project", Map(
           "number of rows" -> 2L)))
@@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       SQLConf.TUNGSTEN_ENABLED.key -> "true") {
       // Assume the execution plan is
       // PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0)
-      val df = TestData.person.select('name)
+      val df = person.select('name)
       testSparkPlanMetrics(df, 1, Map(
         0L ->("TungstenProject", Map(
           "number of rows" -> 2L)))
@@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
   test("Filter metrics") {
     // Assume the execution plan is
     // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
-    val df = TestData.person.filter('age < 25)
+    val df = person.filter('age < 25)
     testSparkPlanMetrics(df, 1, Map(
       0L -> ("Filter", Map(
         "number of input rows" -> 2L,
@@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       SQLConf.TUNGSTEN_ENABLED.key -> "false") {
       // Assume the execution plan is
       // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0)
-      val df = TestData.testData2.groupBy().count() // 2 partitions
+      val df = testData2.groupBy().count() // 2 partitions
       testSparkPlanMetrics(df, 1, Map(
         2L -> ("Aggregate", Map(
           "number of input rows" -> 6L,
@@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       )
 
       // 2 partitions and each partition contains 2 keys
-      val df2 = TestData.testData2.groupBy('a).count()
+      val df2 = testData2.groupBy('a).count()
       testSparkPlanMetrics(df2, 1, Map(
         2L -> ("Aggregate", Map(
           "number of input rows" -> 6L,
@@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       // Assume the execution plan is
       // ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) ->
       // SortBasedAggregate(nodeId = 0)
-      val df = TestData.testData2.groupBy().count() // 2 partitions
+      val df = testData2.groupBy().count() // 2 partitions
       testSparkPlanMetrics(df, 1, Map(
         2L -> ("SortBasedAggregate", Map(
           "number of input rows" -> 6L,
@@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       // ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2)
       // -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0)
       // 2 partitions and each partition contains 2 keys
-      val df2 = TestData.testData2.groupBy('a).count()
+      val df2 = testData2.groupBy('a).count()
       testSparkPlanMetrics(df2, 1, Map(
         3L -> ("SortBasedAggregate", Map(
           "number of input rows" -> 6L,
@@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       // Assume the execution plan is
       // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1)
       // -> TungstenAggregate(nodeId = 0)
-      val df = TestData.testData2.groupBy().count() // 2 partitions
+      val df = testData2.groupBy().count() // 2 partitions
       testSparkPlanMetrics(df, 1, Map(
         2L -> ("TungstenAggregate", Map(
           "number of input rows" -> 6L,
@@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
       )
 
       // 2 partitions and each partition contains 2 keys
-      val df2 = TestData.testData2.groupBy('a).count()
+      val df2 = testData2.groupBy('a).count()
       testSparkPlanMetrics(df2, 1, Map(
         2L -> ("TungstenAggregate", Map(
           "number of input rows" -> 6L,
@@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
     // Because SortMergeJoin may skip different rows if the number of partitions is different, this
     // test should use the deterministic number of partitions.
     withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
-      val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+      val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
       testDataForJoin.registerTempTable("testDataForJoin")
       withTempTable("testDataForJoin") {
         // Assume the execution plan is
@@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
     // Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
     // this test should use the deterministic number of partitions.
     withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
-      val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+      val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
       testDataForJoin.registerTempTable("testDataForJoin")
       withTempTable("testDataForJoin") {
         // Assume the execution plan is
@@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
 
   test("ShuffledHashJoin metrics") {
     withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
-      val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+      val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
       testDataForJoin.registerTempTable("testDataForJoin")
       withTempTable("testDataForJoin") {
         // Assume the execution plan is
@@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
 
   test("BroadcastNestedLoopJoin metrics") {
     withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
-      val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+      val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
       testDataForJoin.registerTempTable("testDataForJoin")
       withTempTable("testDataForJoin") {
         // Assume the execution plan is
@@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
   }
 
   test("CartesianProduct metrics") {
-    val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+    val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
     testDataForJoin.registerTempTable("testDataForJoin")
     withTempTable("testDataForJoin") {
       // Assume the execution plan is
@@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
 
   test("save metrics") {
     withTempPath { file =>
-      val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
+      val previousExecutionIds = ctx.listener.executionIdToData.keySet
       // Assume the execution plan is
       // PhysicalRDD(nodeId = 0)
-      TestData.person.select('name).write.format("json").save(file.getAbsolutePath)
-      TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
-      val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+      person.select('name).write.format("json").save(file.getAbsolutePath)
+      ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
+      val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
       assert(executionIds.size === 1)
       val executionId = executionIds.head
-      val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
+      val jobs = ctx.listener.getExecution(executionId).get.jobs
       // Use "<=" because there is a race condition that we may miss some jobs
       // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
       assert(jobs.size <= 1)
-      val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
+      val metricValues = ctx.listener.getExecutionMetrics(executionId)
       // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
       // However, we still can check the value.
       assert(metricValues.values.toSeq === Seq(2L))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 41dd1896c15df9400269e3f21741df63691271ad..80d1e88956949e5e042f77e923996cec89b13534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 
-class SQLListenerSuite extends SparkFunSuite {
+class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
   private def createTestDataFrame: DataFrame = {
-    import TestSQLContext.implicits._
     Seq(
       (1, 1),
       (2, 2)
@@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite {
   }
 
   test("basic") {
-    val listener = new SQLListener(TestSQLContext)
+    val listener = new SQLListener(ctx)
     val executionId = 0
     val df = createTestDataFrame
     val accumulatorIds =
@@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite {
   }
 
   test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
-    val listener = new SQLListener(TestSQLContext)
+    val listener = new SQLListener(ctx)
     val executionId = 0
     val df = createTestDataFrame
     listener.onExecutionStart(
@@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite {
   }
 
   test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
-    val listener = new SQLListener(TestSQLContext)
+    val listener = new SQLListener(ctx)
     val executionId = 0
     val df = createTestDataFrame
     listener.onExecutionStart(
@@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite {
   }
 
   test("onExecutionEnd happens before onJobEnd(JobFailed)") {
-    val listener = new SQLListener(TestSQLContext)
+    val listener = new SQLListener(ctx)
     val executionId = 0
     val df = createTestDataFrame
     listener.onExecutionStart(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index e4dcf4c75d208b79a4274fd662033a9df9c97231..0edac0848c3bb8d168170f25223a6d9079ccdb3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
+class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
+  import testImplicits._
+
   val url = "jdbc:h2:mem:testdb0"
   val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
   var conn: java.sql.Connection = null
@@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
       Some(StringType)
   }
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-  import ctx.sql
-
   before {
     Utils.classForName("org.h2.Driver")
     // Extra properties that will be specified for our database. We need these to test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 84b52ca2c733cd38630bbc53affb090f8890a6dc..5dc3a2c07b8c72b2fff64bf958e77f20c8da1bc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -23,11 +23,13 @@ import java.util.Properties
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{SaveMode, Row}
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
+class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
+
   val url = "jdbc:h2:mem:testdb2"
   var conn: java.sql.Connection = null
   val url1 = "jdbc:h2:mem:testdb3"
@@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
   properties.setProperty("password", "testPass")
   properties.setProperty("rowId", "false")
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-  import ctx.sql
-
   before {
     Utils.classForName("org.h2.Driver")
     conn = DriverManager.getConnection(url)
@@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
       "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
     conn1.commit()
 
-    ctx.sql(
+    sql(
       s"""
         |CREATE TEMPORARY TABLE PEOPLE
         |USING org.apache.spark.sql.jdbc
         |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
 
-    ctx.sql(
+    sql(
       s"""
         |CREATE TEMPORARY TABLE PEOPLE1
         |USING org.apache.spark.sql.jdbc
@@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
   }
 
   test("INSERT to JDBC Datasource") {
-    ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+    sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
     assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
 
   test("INSERT to JDBC Datasource with overwrite") {
-    ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
-    ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
+    sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+    sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
     assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 562c279067048063f6a09504e59cdaa97b84f9d4..9bc3f6bcf6fce5a2b3f00750743be33a80006ea5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -19,28 +19,32 @@ package org.apache.spark.sql.sources
 
 import java.io.{File, IOException}
 
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.execution.datasources.DDLException
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.Utils
 
-class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
-
-  import caseInsensitiveContext.sql
 
+class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
+  protected override lazy val sql = caseInsensitiveContext.sql _
   private lazy val sparkContext = caseInsensitiveContext.sparkContext
-
-  var path: File = null
+  private var path: File = null
 
   override def beforeAll(): Unit = {
+    super.beforeAll()
     path = Utils.createTempDir()
     val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
     caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
   }
 
   override def afterAll(): Unit = {
-    caseInsensitiveContext.dropTempTable("jt")
+    try {
+      caseInsensitiveContext.dropTempTable("jt")
+    } finally {
+      super.afterAll()
+    }
   }
 
   after {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
index 392da0b0826b51e772950f636f3b8a206d24f326..853707c036c9a4b2e7d47a05f21f644b04c7244c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.sql.sources
 
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
 
 
 // please note that the META-INF/services had to be modified for the test directory for this to work
-class DDLSourceLoadSuite extends DataSourceTest {
+class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
 
   test("data sources with the same name") {
     intercept[RuntimeException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 84855ce45e91804af6877bf9ddc82b4d7d23ffee..5f8514e1a24113046c8eddff402791f6f3202667 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
   }
 }
 
-class DDLTestSuite extends DataSourceTest {
+class DDLTestSuite extends DataSourceTest with SharedSQLContext {
+  protected override lazy val sql = caseInsensitiveContext.sql _
 
-  before {
-    caseInsensitiveContext.sql(
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sql(
       """
       |CREATE TEMPORARY TABLE ddlPeople
       |USING org.apache.spark.sql.sources.DDLScanSource
@@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest {
       ))
 
   test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
-    val attributes = caseInsensitiveContext.sql("describe ddlPeople")
+    val attributes = sql("describe ddlPeople")
       .queryExecution.executedPlan.output
     assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
     assert(attributes.map(_.dataType).toSet === Set(StringType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 00cc7d5ea580f0c523b07b844c4c9e938a0e2c0b..d74d29fb0beb00b03c0f786e22f015e5fefecbd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -17,18 +17,23 @@
 
 package org.apache.spark.sql.sources
 
-import org.scalatest.BeforeAndAfter
-
 import org.apache.spark.sql._
-import org.apache.spark.sql.test.TestSQLContext
 
 
-abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
+private[sql] abstract class DataSourceTest extends QueryTest {
+  protected def _sqlContext: SQLContext
+
   // We want to test some edge cases.
-  protected implicit lazy val caseInsensitiveContext = {
-    val ctx = new SQLContext(TestSQLContext.sparkContext)
+  protected lazy val caseInsensitiveContext: SQLContext = {
+    val ctx = new SQLContext(_sqlContext.sparkContext)
     ctx.setConf(SQLConf.CASE_SENSITIVE, false)
     ctx
   }
 
+  protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) {
+    test(sqlString) {
+      checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer)
+    }
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 5ef365797eace9ac93ffc8196562f83f3509e57e..c81c3d3982805f1dfa4d5eda25d8013e0637a3f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -21,6 +21,7 @@ import scala.language.existentials
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 
@@ -96,11 +97,11 @@ object FiltersPushed {
   var list: Seq[Filter] = Nil
 }
 
-class FilteredScanSuite extends DataSourceTest {
+class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
+  protected override lazy val sql = caseInsensitiveContext.sql _
 
-  import caseInsensitiveContext.sql
-
-  before {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
     sql(
       """
         |CREATE TEMPORARY TABLE oneToTenFiltered
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index cdbfaf6455fe4bc4fc1285219983219f003a1c32..78bd3e55829644411643d9d6f4eee345300d8275 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -19,20 +19,17 @@ package org.apache.spark.sql.sources
 
 import java.io.File
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.Utils
 
-class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
-
-  import caseInsensitiveContext.sql
-
+class InsertSuite extends DataSourceTest with SharedSQLContext {
+  protected override lazy val sql = caseInsensitiveContext.sql _
   private lazy val sparkContext = caseInsensitiveContext.sparkContext
-
-  var path: File = null
+  private var path: File = null
 
   override def beforeAll(): Unit = {
+    super.beforeAll()
     path = Utils.createTempDir()
     val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
     caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
@@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
   }
 
   override def afterAll(): Unit = {
-    caseInsensitiveContext.dropTempTable("jsonTable")
-    caseInsensitiveContext.dropTempTable("jt")
-    Utils.deleteRecursively(path)
+    try {
+      caseInsensitiveContext.dropTempTable("jsonTable")
+      caseInsensitiveContext.dropTempTable("jt")
+      Utils.deleteRecursively(path)
+    } finally {
+      super.afterAll()
+    }
   }
 
   test("Simple INSERT OVERWRITE a JSONRelation") {
@@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
       sql("SELECT a * 2 FROM jsonTable"),
       (1 to 10).map(i => Row(i * 2)).toSeq)
 
-    assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
-    checkAnswer(
-      sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
+    assertCached(sql(
+      "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
+    checkAnswer(sql(
+      "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
       (2 to 10).map(i => Row(i, i - 1)).toSeq)
 
     // Insert overwrite and keep the same schema.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index c86ddd7c83e5385c7300de44900519c87d47b4c1..79b6e9b45c009a3abf02eea9d166246937844342 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -19,21 +19,21 @@ package org.apache.spark.sql.sources
 
 import org.apache.spark.sql.{Row, QueryTest}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.Utils
 
-class PartitionedWriteSuite extends QueryTest {
-  import TestSQLContext.implicits._
+class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("write many partitions") {
     val path = Utils.createTempDir()
     path.delete()
 
-    val df = TestSQLContext.range(100).select($"id", lit(1).as("data"))
+    val df = ctx.range(100).select($"id", lit(1).as("data"))
     df.write.partitionBy("id").save(path.getCanonicalPath)
 
     checkAnswer(
-      TestSQLContext.read.load(path.getCanonicalPath),
+      ctx.read.load(path.getCanonicalPath),
       (0 to 99).map(Row(1, _)).toSeq)
 
     Utils.deleteRecursively(path)
@@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest {
     val path = Utils.createTempDir()
     path.delete()
 
-    val base = TestSQLContext.range(100)
+    val base = ctx.range(100)
     val df = base.unionAll(base).select($"id", lit(1).as("data"))
     df.write.partitionBy("id").save(path.getCanonicalPath)
 
     checkAnswer(
-      TestSQLContext.read.load(path.getCanonicalPath),
+      ctx.read.load(path.getCanonicalPath),
       (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
 
     Utils.deleteRecursively(path)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 0d5183444af78599caf8c9324b5e79f4536e5380..a89c5f8007e78e781043dad154d5932b85c3ef82 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -21,6 +21,7 @@ import scala.language.existentials
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 class PrunedScanSource extends RelationProvider {
@@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
   }
 }
 
-class PrunedScanSuite extends DataSourceTest {
+class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
+  protected override lazy val sql = caseInsensitiveContext.sql _
 
-  before {
-    caseInsensitiveContext.sql(
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sql(
       """
         |CREATE TEMPORARY TABLE oneToTenPruned
         |USING org.apache.spark.sql.sources.PrunedScanSource
@@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest {
 
   def testPruning(sqlString: String, expectedColumns: String*): Unit = {
     test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
-      val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
+      val queryExecution = sql(sqlString).queryExecution
       val rawPlan = queryExecution.executedPlan.collect {
         case p: execution.PhysicalRDD => p
       } match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index 31730a3d3f8d3adbb51c2018e0f44f1af8f71723..f18546b4c2d9bc63fa97221363e523b7442a6f58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -19,25 +19,22 @@ package org.apache.spark.sql.sources
 
 import java.io.File
 
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
-
-  import caseInsensitiveContext.sql
-
+class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
+  protected override lazy val sql = caseInsensitiveContext.sql _
   private lazy val sparkContext = caseInsensitiveContext.sparkContext
-
-  var originalDefaultSource: String = null
-
-  var path: File = null
-
-  var df: DataFrame = null
+  private var originalDefaultSource: String = null
+  private var path: File = null
+  private var df: DataFrame = null
 
   override def beforeAll(): Unit = {
+    super.beforeAll()
     originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
 
     path = Utils.createTempDir()
@@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
   }
 
   override def afterAll(): Unit = {
-    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+    try {
+      caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+    } finally {
+      super.afterAll()
+    }
   }
 
   after {
-    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
     Utils.deleteRecursively(path)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index e34e0956d1fddf923f945dea476feab0fa6d8b78..12af8068c398f5b0c3d37fea2d13823c4978132b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 class DefaultSource extends SimpleScanSource
@@ -95,8 +96,8 @@ case class AllDataTypesScan(
   }
 }
 
-class TableScanSuite extends DataSourceTest {
-  import caseInsensitiveContext.sql
+class TableScanSuite extends DataSourceTest with SharedSQLContext {
+  protected override lazy val sql = caseInsensitiveContext.sql _
 
   private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
     Row(
@@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest {
       Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))))
   }.toSeq
 
-  before {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
     sql(
       """
         |CREATE TEMPORARY TABLE oneToTen
@@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest {
       sql("SELECT i * 2 FROM oneToTen"),
       (1 to 10).map(i => Row(i * 2)).toSeq)
 
-    assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
-    checkAnswer(
-      sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
+    assertCached(sql(
+      "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
+    checkAnswer(sql(
+      "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
       (2 to 10).map(i => Row(i, i - 1)).toSeq)
 
     // Verify uncaching
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1374a97476ca14bd3ab372fe14ebb24f4c21e826
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -0,0 +1,290 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+
+/**
+ * A collection of sample data used in SQL tests.
+ */
+private[sql] trait SQLTestData { self =>
+  protected def _sqlContext: SQLContext
+
+  // Helper object to import SQL implicits without a concrete SQLContext
+  private object internalImplicits extends SQLImplicits {
+    protected override def _sqlContext: SQLContext = self._sqlContext
+  }
+
+  import internalImplicits._
+  import SQLTestData._
+
+  // Note: all test data should be lazy because the SQLContext is not set up yet.
+
+  protected lazy val testData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      (1 to 100).map(i => TestData(i, i.toString))).toDF()
+    df.registerTempTable("testData")
+    df
+  }
+
+  protected lazy val testData2: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      TestData2(1, 1) ::
+      TestData2(1, 2) ::
+      TestData2(2, 1) ::
+      TestData2(2, 2) ::
+      TestData2(3, 1) ::
+      TestData2(3, 2) :: Nil, 2).toDF()
+    df.registerTempTable("testData2")
+    df
+  }
+
+  protected lazy val testData3: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      TestData3(1, None) ::
+      TestData3(2, Some(2)) :: Nil).toDF()
+    df.registerTempTable("testData3")
+    df
+  }
+
+  protected lazy val negativeData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
+    df.registerTempTable("negativeData")
+    df
+  }
+
+  protected lazy val largeAndSmallInts: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      LargeAndSmallInts(2147483644, 1) ::
+      LargeAndSmallInts(1, 2) ::
+      LargeAndSmallInts(2147483645, 1) ::
+      LargeAndSmallInts(2, 2) ::
+      LargeAndSmallInts(2147483646, 1) ::
+      LargeAndSmallInts(3, 2) :: Nil).toDF()
+    df.registerTempTable("largeAndSmallInts")
+    df
+  }
+
+  protected lazy val decimalData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      DecimalData(1, 1) ::
+      DecimalData(1, 2) ::
+      DecimalData(2, 1) ::
+      DecimalData(2, 2) ::
+      DecimalData(3, 1) ::
+      DecimalData(3, 2) :: Nil).toDF()
+    df.registerTempTable("decimalData")
+    df
+  }
+
+  protected lazy val binaryData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      BinaryData("12".getBytes, 1) ::
+      BinaryData("22".getBytes, 5) ::
+      BinaryData("122".getBytes, 3) ::
+      BinaryData("121".getBytes, 2) ::
+      BinaryData("123".getBytes, 4) :: Nil).toDF()
+    df.registerTempTable("binaryData")
+    df
+  }
+
+  protected lazy val upperCaseData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      UpperCaseData(1, "A") ::
+      UpperCaseData(2, "B") ::
+      UpperCaseData(3, "C") ::
+      UpperCaseData(4, "D") ::
+      UpperCaseData(5, "E") ::
+      UpperCaseData(6, "F") :: Nil).toDF()
+    df.registerTempTable("upperCaseData")
+    df
+  }
+
+  protected lazy val lowerCaseData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      LowerCaseData(1, "a") ::
+      LowerCaseData(2, "b") ::
+      LowerCaseData(3, "c") ::
+      LowerCaseData(4, "d") :: Nil).toDF()
+    df.registerTempTable("lowerCaseData")
+    df
+  }
+
+  protected lazy val arrayData: RDD[ArrayData] = {
+    val rdd = _sqlContext.sparkContext.parallelize(
+      ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
+      ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
+    rdd.toDF().registerTempTable("arrayData")
+    rdd
+  }
+
+  protected lazy val mapData: RDD[MapData] = {
+    val rdd = _sqlContext.sparkContext.parallelize(
+      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+      MapData(Map(1 -> "a4", 2 -> "b4")) ::
+      MapData(Map(1 -> "a5")) :: Nil)
+    rdd.toDF().registerTempTable("mapData")
+    rdd
+  }
+
+  protected lazy val repeatedData: RDD[StringData] = {
+    val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
+    rdd.toDF().registerTempTable("repeatedData")
+    rdd
+  }
+
+  protected lazy val nullableRepeatedData: RDD[StringData] = {
+    val rdd = _sqlContext.sparkContext.parallelize(
+      List.fill(2)(StringData(null)) ++
+      List.fill(2)(StringData("test")))
+    rdd.toDF().registerTempTable("nullableRepeatedData")
+    rdd
+  }
+
+  protected lazy val nullInts: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      NullInts(1) ::
+      NullInts(2) ::
+      NullInts(3) ::
+      NullInts(null) :: Nil).toDF()
+    df.registerTempTable("nullInts")
+    df
+  }
+
+  protected lazy val allNulls: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      NullInts(null) ::
+      NullInts(null) ::
+      NullInts(null) ::
+      NullInts(null) :: Nil).toDF()
+    df.registerTempTable("allNulls")
+    df
+  }
+
+  protected lazy val nullStrings: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      NullStrings(1, "abc") ::
+      NullStrings(2, "ABC") ::
+      NullStrings(3, null) :: Nil).toDF()
+    df.registerTempTable("nullStrings")
+    df
+  }
+
+  protected lazy val tableName: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
+    df.registerTempTable("tableName")
+    df
+  }
+
+  protected lazy val unparsedStrings: RDD[String] = {
+    _sqlContext.sparkContext.parallelize(
+      "1, A1, true, null" ::
+      "2, B2, false, null" ::
+      "3, C3, true, null" ::
+      "4, D4, true, 2147483644" :: Nil)
+  }
+
+  // An RDD with 4 elements and 8 partitions
+  protected lazy val withEmptyParts: RDD[IntField] = {
+    val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
+    rdd.toDF().registerTempTable("withEmptyParts")
+    rdd
+  }
+
+  protected lazy val person: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      Person(0, "mike", 30) ::
+      Person(1, "jim", 20) :: Nil).toDF()
+    df.registerTempTable("person")
+    df
+  }
+
+  protected lazy val salary: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      Salary(0, 2000.0) ::
+      Salary(1, 1000.0) :: Nil).toDF()
+    df.registerTempTable("salary")
+    df
+  }
+
+  protected lazy val complexData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
+      ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
+      Nil).toDF()
+    df.registerTempTable("complexData")
+    df
+  }
+
+  /**
+   * Initialize all test data such that all temp tables are properly registered.
+   */
+  def loadTestData(): Unit = {
+    assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
+    testData
+    testData2
+    testData3
+    negativeData
+    largeAndSmallInts
+    decimalData
+    binaryData
+    upperCaseData
+    lowerCaseData
+    arrayData
+    mapData
+    repeatedData
+    nullableRepeatedData
+    nullInts
+    allNulls
+    nullStrings
+    tableName
+    unparsedStrings
+    withEmptyParts
+    person
+    salary
+    complexData
+  }
+}
+
+/**
+ * Case classes used in test data.
+ */
+private[sql] object SQLTestData {
+  case class TestData(key: Int, value: String)
+  case class TestData2(a: Int, b: Int)
+  case class TestData3(a: Int, b: Option[Int])
+  case class LargeAndSmallInts(a: Int, b: Int)
+  case class DecimalData(a: BigDecimal, b: BigDecimal)
+  case class BinaryData(a: Array[Byte], b: Int)
+  case class UpperCaseData(N: Int, L: String)
+  case class LowerCaseData(n: Int, l: String)
+  case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
+  case class MapData(data: scala.collection.Map[Int, String])
+  case class StringData(s: String)
+  case class IntField(i: Int)
+  case class NullInts(a: Integer)
+  case class NullStrings(n: Int, s: String)
+  case class TableName(tableName: String)
+  case class Person(id: Int, name: String, age: Int)
+  case class Salary(personId: Int, salary: Double)
+  case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
+}
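
A hypothetical consumer of the new SQLTestData trait (not part of the patch; the suite name and query are illustrative). Referencing a lazy val such as testData2 both builds the DataFrame and registers its temp table, so subsequent SQL can see it:

package org.apache.spark.sql.test

import org.apache.spark.sql.{QueryTest, Row}

class ExampleTestDataSuite extends QueryTest with SharedSQLContext {

  test("testData2 is materialized and registered on first reference") {
    assert(testData2.count() === 6)   // forces the lazy val and registers "testData2"
    checkAnswer(
      sql("SELECT a, b FROM testData2 WHERE a = 1"),
      Row(1, 1) :: Row(1, 2) :: Nil)
  }
}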
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 1066695589778ade053128ad83e03266f8b15968..cdd691e03589792ba5525495d62cad2c860af0d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -21,15 +21,71 @@ import java.io.File
 import java.util.UUID
 
 import scala.util.Try
+import scala.language.implicitConversions
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.util.Utils
 
-trait SQLTestUtils { this: SparkFunSuite =>
-  protected def sqlContext: SQLContext
+/**
+ * Helper trait that should be extended by all SQL test suites.
+ *
+ * This allows subclasses to plug in a custom [[SQLContext]]. It comes with test data
+ * prepared in advance as well as all the implicit conversions used extensively by DataFrames.
+ * To use those implicits, import `testImplicits._` instead of going through the [[SQLContext]].
+ *
+ * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, which is
+ * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
+ */
+private[sql] trait SQLTestUtils
+  extends SparkFunSuite
+  with BeforeAndAfterAll
+  with SQLTestData { self =>
+
+  protected def _sqlContext: SQLContext
+
+  // Whether to materialize all test data before the first test is run
+  private var loadTestDataBeforeTests = false
+
+  // Shorthand for running a query using our SQLContext
+  protected lazy val sql = _sqlContext.sql _
+
+  /**
+   * A helper object for importing SQL implicits.
+   *
+   * Note that the alternative of importing `sqlContext.implicits._` is not possible here.
+   * This is because we create the [[SQLContext]] immediately before the first test is run,
+   * but the implicits import is needed in the constructor.
+   */
+  protected object testImplicits extends SQLImplicits {
+    protected override def _sqlContext: SQLContext = self._sqlContext
+  }
+
+  /**
+   * Materialize the test data immediately after the [[SQLContext]] is set up.
+   * This is necessary when tests refer to the data only by its registered table name
+   * and never through the lazy vals that would otherwise create it on first use.
+   */
+  protected def setupTestData(): Unit = {
+    loadTestDataBeforeTests = true
+  }
 
-  protected def configuration = sqlContext.sparkContext.hadoopConfiguration
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    if (loadTestDataBeforeTests) {
+      loadTestData()
+    }
+  }
+
+  /**
+   * The Hadoop configuration used by the active [[SQLContext]].
+   */
+  protected def configuration: Configuration = {
+    _sqlContext.sparkContext.hadoopConfiguration
+  }
 
   /**
    * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
@@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
    */
   protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
     val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
-    (keys, values).zipped.foreach(sqlContext.conf.setConfString)
+    val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption)
+    (keys, values).zipped.foreach(_sqlContext.conf.setConfString)
     try f finally {
       keys.zip(currentValues).foreach {
-        case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
-        case (key, None) => sqlContext.conf.unsetConf(key)
+        case (key, Some(value)) => _sqlContext.conf.setConfString(key, value)
+        case (key, None) => _sqlContext.conf.unsetConf(key)
       }
     }
   }
@@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
    * Drops temporary table `tableName` after calling `f`.
    */
   protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
-    try f finally tableNames.foreach(sqlContext.dropTempTable)
+    try f finally tableNames.foreach(_sqlContext.dropTempTable)
   }
 
   /**
@@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
   protected def withTable(tableNames: String*)(f: => Unit): Unit = {
     try f finally {
       tableNames.foreach { name =>
-        sqlContext.sql(s"DROP TABLE IF EXISTS $name")
+        _sqlContext.sql(s"DROP TABLE IF EXISTS $name")
       }
     }
   }
@@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
     val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
 
     try {
-      sqlContext.sql(s"CREATE DATABASE $dbName")
+      _sqlContext.sql(s"CREATE DATABASE $dbName")
     } catch { case cause: Throwable =>
       fail("Failed to create temporary database", cause)
     }
 
-    try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
+    try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
   }
 
   /**
@@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite =>
    * `f` returns.
    */
   protected def activateDatabase(db: String)(f: => Unit): Unit = {
-    sqlContext.sql(s"USE $db")
-    try f finally sqlContext.sql(s"USE default")
+    _sqlContext.sql(s"USE $db")
+    try f finally _sqlContext.sql(s"USE default")
+  }
+
+  /**
+   * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier
+   * way to construct [[DataFrame]] directly out of local data without relying on implicits.
+   */
+  protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
+    DataFrame(_sqlContext, plan)
   }
 }
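
An illustrative use of the helpers above (not part of the patch; the suite and table names are made up). withSQLConf scopes a configuration override to the block and restores the previous value afterwards, and withTempTable drops the named temp tables when its block finishes:

package org.apache.spark.sql.test

import org.apache.spark.sql.SQLConf

class ExampleUtilsSuite extends SharedSQLContext {
  import testImplicits._

  test("config overrides and temp tables are cleaned up") {
    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
      // The override is visible only inside this block.
      assert(ctx.conf.numShufflePartitions === 1)
      Seq((1, "a"), (2, "b")).toDF("i", "s").registerTempTable("pairs")
      withTempTable("pairs") {
        assert(sql("SELECT COUNT(*) FROM pairs").collect().head.getLong(0) === 2)
      }
      // "pairs" has been dropped at this point.
    }
    // spark.sql.shuffle.partitions is back to its previous value here.
  }
}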
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3cfd822e2a747412a5314a3643dd85eff3b7c846
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.sql.SQLContext
+
+
+/**
+ * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]].
+ */
+private[sql] trait SharedSQLContext extends SQLTestUtils {
+
+  /**
+   * The [[TestSQLContext]] to use for all tests in this suite.
+   *
+   * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
+   * mode with the default test configurations.
+   */
+  private var _ctx: TestSQLContext = null
+
+  /**
+   * The [[TestSQLContext]] to use for all tests in this suite.
+   */
+  protected def ctx: TestSQLContext = _ctx
+  protected def sqlContext: TestSQLContext = _ctx
+  protected override def _sqlContext: SQLContext = _ctx
+
+  /**
+   * Initialize the [[TestSQLContext]].
+   */
+  protected override def beforeAll(): Unit = {
+    if (_ctx == null) {
+      _ctx = new TestSQLContext
+    }
+    // Ensure we have initialized the context before calling parent code
+    super.beforeAll()
+  }
+
+  /**
+   * Stop the underlying [[org.apache.spark.SparkContext]], if any.
+   */
+  protected override def afterAll(): Unit = {
+    try {
+      if (_ctx != null) {
+        _ctx.sparkContext.stop()
+        _ctx = null
+      }
+    } finally {
+      super.afterAll()
+    }
+  }
+
+}
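
A minimal sketch (not part of the patch) of how a suite is expected to hook into the SharedSQLContext lifecycle: overrides of beforeAll/afterAll must call super so the shared context is created before any suite-specific setup and stopped after all suite-specific cleanup. The suite name and assertions are illustrative.

package org.apache.spark.sql.test

class ExampleSharedContextSuite extends SharedSQLContext {

  protected override def beforeAll(): Unit = {
    super.beforeAll()            // creates the TestSQLContext first
    sql("SELECT 1").collect()    // ctx and sql are safe to use from here on
  }

  protected override def afterAll(): Unit = {
    try {
      // suite-specific cleanup would go here
    } finally {
      super.afterAll()           // stops the shared SparkContext last
    }
  }

  test("the shared context runs locally with the test configuration") {
    assert(ctx.sparkContext.master === "local[2]")
    assert(ctx.sparkContext.getConf.get("spark.sql.testkey") === "true")
  }
}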
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
similarity index 54%
rename from sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index b3a4231da91c2a7e9a77e8015f574769e3263146..92ef2f7d74ba13f81a84b5807c839595e27c5017 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -17,40 +17,36 @@
 
 package org.apache.spark.sql.test
 
-import scala.language.implicitConversions
-
 import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
-/** A SQLContext that can be used for local testing. */
-class LocalSQLContext
-  extends SQLContext(
-    new SparkContext("local[2]", "TestSQLContext", new SparkConf()
-      .set("spark.sql.testkey", "true")
-      // SPARK-8910
-      .set("spark.ui.enabled", "false"))) {
-
-  override protected[sql] def createSession(): SQLSession = {
-    new this.SQLSession()
+import org.apache.spark.sql.{SQLConf, SQLContext}
+
+
+/**
+ * A special [[SQLContext]] prepared for testing.
+ */
+private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
+
+  def this() {
+    this(new SparkContext("local[2]", "test-sql-context",
+      new SparkConf().set("spark.sql.testkey", "true")))
   }
 
+  // Sessions created here use fewer shuffle partitions to speed up testing (see SQLSession below)
+  protected[sql] override def createSession(): SQLSession = new this.SQLSession()
+
+  /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
   protected[sql] class SQLSession extends super.SQLSession {
     protected[sql] override lazy val conf: SQLConf = new SQLConf {
-      /** Fewer partitions to speed up testing. */
       override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
     }
   }
 
-  /**
-   * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to
-   * construct [[DataFrame]] directly out of local data without relying on implicits.
-   */
-  protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
-    DataFrame(this, plan)
+  // Needed for Java tests, which cannot call into the private testData object directly
+  def loadTestData(): Unit = {
+    testData.loadTestData()
   }
 
+  private object testData extends SQLTestData {
+    protected override def _sqlContext: SQLContext = self
+  }
 }
-
-object TestSQLContext extends LocalSQLContext
-
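
As a rough sketch (not part of this patch), the new `TestSQLContext` can also be constructed directly, e.g. from Java-facing tests; the `testData` table name assumes the temp tables registered by `SQLTestData`:

```scala
import org.apache.spark.sql.test.TestSQLContext

val ctx = new TestSQLContext()    // starts a local[2] SparkContext with the test conf
ctx.loadTestData()                // eagerly registers the shared test tables
ctx.sql("SELECT count(*) FROM testData").collect()
ctx.sparkContext.stop()           // shut down when the suite is done
```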
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
index 806240e6de458099c7c55700241a4a2921189c23..bf431cd6b0260fe49b9769b3c06dac74c2ec5813 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
@@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.selenium.WebBrowser
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.sql.hive.HiveContext
 import org.apache.spark.ui.SparkUICssErrorHandler
 
 class UISeleniumSuite
@@ -36,7 +35,6 @@ class UISeleniumSuite
 
   implicit var webDriver: WebDriver = _
   var server: HiveThriftServer2 = _
-  var hc: HiveContext = _
   val uiPort = 20000 + Random.nextInt(10000)
   override def mode: ServerMode.Value = ServerMode.binary
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 59e65ff97b8e023153fc9927f60668f258ce933b..574624d501f2261df761fad598b02a26c00f2479 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
 import org.apache.spark.sql.sources.DataSourceTest
 import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.{Row, SaveMode, SQLContext}
 import org.apache.spark.{Logging, SparkFunSuite}
 
 
@@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
 }
 
 class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
-  override val sqlContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  import testImplicits._
 
   private val testDF = range(1, 3).select(
     ('id + 0.1) cast DecimalType(10, 3) as 'd1,
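
The Hive-side suites below all follow the same mechanical pattern: point `_sqlContext` at the `TestHive` singleton and keep a stable `sqlContext` val so the existing `import sqlContext.…` lines keep compiling. A condensed sketch (the class name is hypothetical):

```scala
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, SQLContext}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils

// Hypothetical suite, shown only to illustrate the recurring pattern.
class SomeHiveSuite extends QueryTest with SQLTestUtils {
  // Hook the TestHive singleton into SQLTestUtils...
  override def _sqlContext: SQLContext = TestHive
  // ...and keep a stable val so `import sqlContext.implicits._` still works.
  private val sqlContext = _sqlContext
  import sqlContext.implicits._
}
```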
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index 1fa005d5f9a15f46a0c0894ba1f6eac5dcbe4424..fe0db5228de1621ba06b6dceb49ea50d1923a162 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.hive
 
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{QueryTest, Row, SQLContext}
 
 case class Cases(lower: String, UPPER: String)
 
 class HiveParquetSuite extends QueryTest with ParquetTest {
-  val sqlContext = TestHive
-
-  import sqlContext._
+  private val ctx = TestHive
+  override def _sqlContext: SQLContext = ctx
 
   test("Case insensitive attribute names") {
     withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
@@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
   test("Converting Hive to Parquet Table via saveAsParquetFile") {
     withTempPath { dir =>
       sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
-      read.parquet(dir.getCanonicalPath).registerTempTable("p")
+      ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p")
       withTempTable("p") {
         checkAnswer(
           sql("SELECT * FROM src ORDER BY key"),
@@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
     withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
       withTempPath { file =>
         sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
-        read.parquet(file.getCanonicalPath).registerTempTable("p")
+        ctx.read.parquet(file.getCanonicalPath).registerTempTable("p")
         withTempTable("p") {
           // let's do three overwrites for good measure
           sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 7f36a483a3965e1e783ad96a88e1d311e01005ca..20a50586d52015b71e5db6e43a6e05f33e3a3a0c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -22,7 +22,6 @@ import java.io.{IOException, File}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapred.InvalidInputException
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.Logging
@@ -42,7 +41,8 @@ import org.apache.spark.util.Utils
  */
 class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
   with Logging {
-  override val sqlContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   var jsonFilePath: String = _
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 73852f13ad20d1a804456659dfd1ce2a5d7a3f3f..417e8b07917cc349a1ac59dbd534e342ce2ae772 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode}
 
 class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
-  override val sqlContext: SQLContext = TestHive
-
-  import sqlContext.sql
+  override val _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   private val df = sqlContext.range(10).coalesce(1)
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
index 251e0324bfa5f7335174b91e575406baf5e15723..13452e71a1b3bd131598c29512c0ee71e039b2ae 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext}
 class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest {
   import ParquetCompatibilityTest.makeNullable
 
-  override val sqlContext: SQLContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   /**
    * Set the staging directory (and hence path to ignore Parquet files under)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 9b3ede43ee2d16c715f0ff03e6ec7c70d8e07e9d..7ee1c8d13aa3f36885005d58969d32dcc1bb020b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,14 +17,12 @@
 
 package org.apache.spark.sql.hive
 
-import org.apache.spark.sql.{Row, QueryTest}
+import org.apache.spark.sql.QueryTest
 
 case class FunctionResult(f1: String, f2: String)
 
 class UDFSuite extends QueryTest {
-
   private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
-  import ctx.implicits._
 
   test("UDF case insensitive") {
     ctx.udf.register("random0", () => { Math.random() })
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7b5aa4763fd9ea2f6e8e052910e2f8b22e4150d0..a312f849582485e4fb420aab5af7816a24dc8181 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,17 +17,18 @@
 
 package org.apache.spark.sql.hive.execution
 
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql._
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql._
-import org.scalatest.BeforeAndAfterAll
 import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
-
-  override val sqlContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  protected val sqlContext = _sqlContext
   import sqlContext.implicits._
 
   var originalUseAggregate2: Boolean = _
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 44c5b80392fa52d1e29b1a88bbf8e294f6a4bd91..11d7a872dff090b8b0ed5188c78d5069c3f70af2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils
  * A set of tests that validates support for Hive Explain command.
  */
 class HiveExplainSuite extends QueryTest with SQLTestUtils {
-
-  def sqlContext: SQLContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   test("explain extended command") {
     checkExistence(sql(" explain   select * from src where key=123 "), true,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 79a136ae6f619a939d2b6ae6347719baa201621b..8b8f520776e70c1956eb00de12f5ba2f01dda25f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect
  * valid, but Hive currently cannot execute it.
  */
 class SQLQuerySuite extends QueryTest with SQLTestUtils {
-  override def sqlContext: SQLContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   test("UDTF") {
     sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 0875232aede3e1efeaef5cf4f6b3bb081fee06e2..9aca40f15ac1556f6608a9f9c2d6fee858bbcfde 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType
 
 class ScriptTransformationSuite extends SparkPlanTest {
 
-  override def sqlContext: SQLContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   private val noSerdeIOSchema = HiveScriptIOSchema(
     inputRowFormat = Seq.empty,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index 145965388da01229d265b7565254e8f7d186976d..f7ba20ff41d8d504e752e9199e99854e02b7c5f3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.test.SQLTestUtils
 
 private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
-  lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
-
+  protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
+  protected val sqlContext = _sqlContext
   import sqlContext.implicits._
   import sqlContext.sparkContext
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 50f02432dacce1efa338edeb669f1aaeb7395ea1..34d3434569f581e13f3ed2324a5e49ec2df03886 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
  * A collection of tests for parquet data with various forms of partitioning.
  */
 abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
-  override def sqlContext: SQLContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  protected val sqlContext = _sqlContext
 
   var partitionedTableDir: File = null
   var normalTableDir: File = null
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index e976125b3706dca57fdea737bfc2d27b7a6813e3..b4640b1616281adc8595c6d74f64b3d6d30b1b32 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -18,14 +18,16 @@
 package org.apache.spark.sql.sources
 
 import org.apache.hadoop.fs.Path
-import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 
 
 class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
-  override val sqlContext = TestHive
+  override def _sqlContext: SQLContext = TestHive
+  private val sqlContext = _sqlContext
 
   // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose.
   val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 2a69d331b6e529b24a0e84c152662b113b866591..af445626fbe4dfa329510afa8b582414ba839d52 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -34,9 +34,8 @@ import org.apache.spark.sql.types._
 
 
 abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
-  override lazy val sqlContext: SQLContext = TestHive
-
-  import sqlContext.sql
+  override def _sqlContext: SQLContext = TestHive
+  protected val sqlContext = _sqlContext
   import sqlContext.implicits._
 
   val dataSourceName: String