diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 09183fe65b72202c6af2837b5ed749dd10da42e6..035bfc07b684dbdd8ec32bf4c3549b6b5fd191a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -21,13 +21,11 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
 
 class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("MinMaxScaler fit basic case") {
-    val sqlContext = new SQLContext(sc)
-
     val data = Array(
       Vectors.dense(1, 0, Long.MinValue),
       Vectors.dense(2, 0, 0),
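
The removal above (and the matching ones in the suites below) works because `MLlibTestSparkContext` already exposes a shared `sqlContext` field alongside `sc`, so each suite was needlessly constructing a second `SQLContext` on the same `SparkContext`. A minimal sketch of that fixture, assuming the Spark 1.x test-trait layout (the real trait lives in `org.apache.spark.mllib.util`):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.scalatest.{BeforeAndAfterAll, Suite}

trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    val conf = new SparkConf().setMaster("local[2]").setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
    sqlContext = new SQLContext(sc)  // the single instance the suites now reuse
  }

  override def afterAll(): Unit = {
    sqlContext = null
    if (sc != null) sc.stop()
    sc = null
    super.afterAll()
  }
}
```
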
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index de3d438ce83be8be70ffbdbc660afc95508177ee..468833901995a36ffc92325065fb515eace253db 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 
 class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -61,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       Vectors.sparse(3, Seq())
     )
 
-    val sqlContext = new SQLContext(sc)
     dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
     normalizer = new Normalizer()
       .setInputCol("features")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 74706a23e0936660a7363150a642f8a956cced47..8acc3369c489c53aa5682539486b6a94b6d4b651 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -54,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
   }
 
   test("Test vector slicer") {
-    val sqlContext = new SQLContext(sc)
-
     val data = Array(
       Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
       Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 460849c79f04fe89b9465baf695b3eb03e3b966f..4e2d0e93bd4122af59d8a554daa5eece229ef63c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val sqlContext = new SQLContext(data.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(data.sparkContext)
     import sqlContext.implicits._
     val df = data.toDF()
     val numFeatures = data.first().features.size
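
`TreeTests` is a helper object rather than a suite that mixes in `MLlibTestSparkContext`, so it cannot simply drop its local variable; instead it switches to `SQLContext.getOrCreate`, which returns the JVM-wide singleton for the given `SparkContext` and constructs one only on first use. A small sketch of that contract (`sc` here is assumed to be an already-running `SparkContext`):

```scala
import org.apache.spark.sql.SQLContext

val first = SQLContext.getOrCreate(sc)
val second = SQLContext.getOrCreate(sc)
// Repeated calls hand back the same cached instance, so helper code such as
// TreeTests now shares the context created by the running test suite instead
// of layering a fresh SQLContext on top of it.
assert(first eq second)
```
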
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index dd6366050c02071eb6d751766665148cbda7d2d3..d281084f913c0689a8bfb84c82291d4ad53a4cc7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
 class CrossValidatorSuite
@@ -39,7 +39,6 @@ class CrossValidatorSuite
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sqlContext = new SQLContext(sc)
     dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
   }
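
With these hunks applied, each suite's setup reduces to the inherited fixture; for example, `CrossValidatorSuite.beforeAll` now reads (reconstructed from the hunk above):

```scala
override def beforeAll(): Unit = {
  super.beforeAll()  // MLlibTestSparkContext initializes both sc and sqlContext
  dataset = sqlContext.createDataFrame(
    sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
}
```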