From 23f5bdf06a388e08ea5a69e848f0ecd5165aa481 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Wed, 12 Nov 2014 18:15:14 -0800
Subject: [PATCH] [SPARK-4373][MLLIB] fix MLlib maven tests

We want to make sure there is at most one spark context inside the same jvm. JoshRosen

Author: Xiangrui Meng <meng@databricks.com>

Closes #3235 from mengxr/SPARK-4373 and squashes the following commits:

6574b69 [Xiangrui Meng] rename LocalSparkContext to MLlibTestSparkContext
913d48d [Xiangrui Meng] make sure there is at most one spark context inside the same jvm
---
 .../LogisticRegressionSuite.scala             | 22 ++++++++++++++----
 .../spark/ml/tuning/CrossValidatorSuite.scala | 15 ++++++++----
 .../LogisticRegressionSuite.scala             |  4 ++--
 .../classification/NaiveBayesSuite.scala      |  4 ++--
 .../spark/mllib/classification/SVMSuite.scala |  4 ++--
 .../spark/mllib/clustering/KMeansSuite.scala  |  4 ++--
 .../evaluation/AreaUnderCurveSuite.scala      |  4 ++--
 .../BinaryClassificationMetricsSuite.scala    |  4 ++--
 .../evaluation/MulticlassMetricsSuite.scala   |  4 ++--
 .../evaluation/MultilabelMetricsSuite.scala   |  4 ++--
 .../evaluation/RankingMetricsSuite.scala      |  4 ++--
 .../evaluation/RegressionMetricsSuite.scala   |  4 ++--
 .../spark/mllib/feature/HashingTFSuite.scala  |  4 ++--
 .../apache/spark/mllib/feature/IDFSuite.scala |  4 ++--
 .../spark/mllib/feature/NormalizerSuite.scala |  4 ++--
 .../mllib/feature/StandardScalerSuite.scala   |  4 ++--
 .../spark/mllib/feature/Word2VecSuite.scala   |  4 ++--
 .../distributed/CoordinateMatrixSuite.scala   |  4 ++--
 .../distributed/IndexedRowMatrixSuite.scala   |  4 ++--
 .../linalg/distributed/RowMatrixSuite.scala   |  4 ++--
 .../optimization/GradientDescentSuite.scala   |  4 ++--
 .../spark/mllib/optimization/LBFGSSuite.scala |  4 ++--
 .../spark/mllib/random/RandomRDDsSuite.scala  |  4 ++--
 .../spark/mllib/rdd/RDDFunctionsSuite.scala   |  4 ++--
 .../spark/mllib/recommendation/ALSSuite.scala |  4 ++--
 .../spark/mllib/regression/LassoSuite.scala   |  4 ++--
 .../regression/LinearRegressionSuite.scala    |  4 ++--
 .../regression/RidgeRegressionSuite.scala     |  4 ++--
 .../spark/mllib/stat/CorrelationSuite.scala   |  4 ++--
 .../mllib/stat/HypothesisTestSuite.scala      |  4 ++--
 .../spark/mllib/tree/DecisionTreeSuite.scala  |  4 ++--
 .../mllib/tree/GradientBoostingSuite.scala    |  4 ++--
 .../spark/mllib/tree/RandomForestSuite.scala  |  4 ++--
 .../mllib/tree/impl/BaggedPointSuite.scala    |  4 ++--
 .../spark/mllib/util/MLUtilsSuite.scala       |  2 +-
 ...text.scala => MLlibTestSparkContext.scala} | 23 +++++++++++++------
 36 files changed, 108 insertions(+), 82 deletions(-)
 rename mllib/src/test/scala/org/apache/spark/mllib/util/{LocalSparkContext.scala => MLlibTestSparkContext.scala} (66%)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 625af299a5..e8030fef55 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,16 +20,24 @@ package org.apache.spark.ml.classification
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
-class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
+class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
-  import sqlContext._
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("logistic regression") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset)
     model.transform(dataset)
@@ -38,6 +46,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression with setters") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
@@ -48,6 +58,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression fit and transform with varargs") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
     model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 72a334ae93..41cc13da4d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -22,14 +22,19 @@ import org.scalatest.FunSuite
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
-class CrossValidatorSuite extends FunSuite with LocalSparkContext {
+class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
 
-  import sqlContext._
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("cross validation with logistic regression") {
     val lr = new LogisticRegression
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index e954baaf7d..6c1c784a19 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
 object LogisticRegressionSuite {
@@ -57,7 +57,7 @@ object LogisticRegressionSuite {
   }
 }
 
-class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
+class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
   def validatePrediction(
       predictions: Seq[Double],
       input: Seq[LabeledPoint],
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 80989bc074..e68fe89d6c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 
 object NaiveBayesSuite {
 
@@ -60,7 +60,7 @@ object NaiveBayesSuite {
   }
 }
 
-class NaiveBayesSuite extends FunSuite with LocalSparkContext {
+class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 65e5df58db..a2de7fbd41 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 
 object SVMSuite {
 
@@ -58,7 +58,7 @@ object SVMSuite {
 
 }
 
-class SVMSuite extends FunSuite with LocalSparkContext {
+class SVMSuite extends FunSuite with MLlibTestSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index afa1f79b95..9ebef8466c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -22,10 +22,10 @@ import scala.util.Random
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
-class KMeansSuite extends FunSuite with LocalSparkContext {
+class KMeansSuite extends FunSuite with MLlibTestSparkContext {
 
   import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM}
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
index 994e0feb86..79847633ff 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
+class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext {
   test("auc computation") {
     val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
     val auc = 4.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index a733f88b60..3a29ccb519 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
+class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext {
 
   def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index 1ea503971c..7dc4f3cfbc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Matrices
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
+class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext {
   test("Multiclass evaluation metrics") {
     /*
      * Confusion matrix for 3-class classification with total 9 instances:
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
index 342baa0274..2537dd62c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 
-class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
+class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext {
   test("Multilabel evaluation metrics") {
     /*
     * Documents true labels (5x class0, 3x class1, 4x class2):
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index a2d4bb4148..609eed983f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class RankingMetricsSuite extends FunSuite with LocalSparkContext {
+class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext {
   test("Ranking metrics: map, ndcg") {
     val predictionAndLabels = sc.parallelize(
       Seq(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index 5396d7b2b7..670b4c34e6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class RegressionMetricsSuite extends FunSuite with LocalSparkContext {
+class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext {
 
   test("regression metrics") {
     val predictionAndObservations = sc.parallelize(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index a599e0d938..0c4dfb7b97 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.mllib.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class HashingTFSuite extends FunSuite with LocalSparkContext {
+class HashingTFSuite extends FunSuite with MLlibTestSparkContext {
 
   test("hashing tf on a single doc") {
     val hashingTF = new HashingTF(1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 43974f84e3..30147e7fd9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -21,10 +21,10 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class IDFSuite extends FunSuite with LocalSparkContext {
+class IDFSuite extends FunSuite with MLlibTestSparkContext {
 
   test("idf") {
     val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
index 2bf9d9816a..85fdd271b5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala
@@ -22,10 +22,10 @@ import org.scalatest.FunSuite
 import breeze.linalg.{norm => brzNorm}
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class NormalizerSuite extends FunSuite with LocalSparkContext {
+class NormalizerSuite extends FunSuite with MLlibTestSparkContext {
 
   val data = Array(
     Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
index e217b93ceb..4c93c0ca4f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
 import org.apache.spark.rdd.RDD
 
-class StandardScalerSuite extends FunSuite with LocalSparkContext {
+class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
   private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
     data.treeAggregate(new MultivariateOnlineSummarizer)(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index e34335d89e..52278690db 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class Word2VecSuite extends FunSuite with LocalSparkContext {
+class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
 
   // TODO: add more tests
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
index cd45438fb6..f8709751ef 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala
@@ -21,10 +21,10 @@ import org.scalatest.FunSuite
 
 import breeze.linalg.{DenseMatrix => BDM}
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.linalg.Vectors
 
-class CoordinateMatrixSuite extends FunSuite with LocalSparkContext {
+class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   val m = 5
   val n = 4
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index f7c46f23b7..e25bc02b06 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -21,11 +21,11 @@ import org.scalatest.FunSuite
 
 import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV}
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Matrices, Vectors}
 
-class IndexedRowMatrixSuite extends FunSuite with LocalSparkContext {
+class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   val m = 4
   val n = 3
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 63f3ed58c0..dbf55ff81c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -23,9 +23,9 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector}
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 
-class RowMatrixSuite extends FunSuite with LocalSparkContext {
+class RowMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   val m = 4
   val n = 3
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index bf040110e2..86481c6e66 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.{FunSuite, Matchers}
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
 object GradientDescentSuite {
@@ -61,7 +61,7 @@ object GradientDescentSuite {
   }
 }
 
-class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers {
+class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers {
 
   test("Assert the loss is decreasing.") {
     val nPoints = 10000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index ccba004baa..70c64775e4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -23,10 +23,10 @@ import org.scalatest.{FunSuite, Matchers}
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
-class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
+class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers {
 
   val nPoints = 10000
   val A = 2.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
index c50b78bcbc..ea5889b3ec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.StatCounter
 
@@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter
  *
  * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged
  */
-class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable {
+class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable {
 
   def testGeneratedRDD(rdd: RDD[Double],
       expectedSize: Long,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
index 4ef67a40b9..681ce92639 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.rdd
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.rdd.RDDFunctions._
 
-class RDDFunctionsSuite extends FunSuite with LocalSparkContext {
+class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext {
 
   test("sliding") {
     val data = 0 until 6
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 017c39edb1..603d0ad127 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
 import org.jblas.DoubleMatrix
 
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.recommendation.ALS.BlockStats
 
 object ALSSuite {
@@ -85,7 +85,7 @@ object ALSSuite {
 }
 
 
-class ALSSuite extends FunSuite with LocalSparkContext {
+class ALSSuite extends FunSuite with MLlibTestSparkContext {
 
   test("rank-1 matrices") {
     testALS(50, 100, 1, 15, 0.7, 0.3)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 7aa96421ae..2668dcc14a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -23,9 +23,9 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
-  LocalSparkContext}
+  MLlibTestSparkContext}
 
-class LassoSuite extends FunSuite with LocalSparkContext {
+class LassoSuite extends FunSuite with MLlibTestSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 4f89112b65..864622a929 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -23,9 +23,9 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
-  LocalSparkContext}
+  MLlibTestSparkContext}
 
-class LinearRegressionSuite extends FunSuite with LocalSparkContext {
+class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 727bbd051f..18d3bf5ea4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -24,9 +24,9 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator,
-  LocalSparkContext}
+  MLlibTestSparkContext}
 
-class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
+class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext {
 
   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
     predictions.zip(input).map { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
index 34548c86eb..d20a09b4b4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala
@@ -24,9 +24,9 @@ import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation,
   SpearmanCorrelation}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class CorrelationSuite extends FunSuite with LocalSparkContext {
+class CorrelationSuite extends FunSuite with MLlibTestSparkContext {
 
   // test input data
   val xData = Array(1.0, 0.0, -2.0)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
index 6de3840b3f..15418e6035 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala
@@ -25,10 +25,10 @@ import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.test.ChiSqTest
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class HypothesisTestSuite extends FunSuite with LocalSparkContext {
+class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
 
   test("chi squared pearson goodness of fit") {
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index c579cb5854..972c905ec9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -30,9 +30,9 @@ import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class DecisionTreeSuite extends FunSuite with LocalSparkContext {
+class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
 
   test("Binary classification with continuous features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
index ae0028a688..84de40103d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -25,12 +25,12 @@ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
 import org.apache.spark.mllib.tree.impurity.Variance
 import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 /**
  * Test suite for [[GradientBoosting]].
  */
-class GradientBoostingSuite extends FunSuite with LocalSparkContext {
+class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
 
   test("Regression with continuous features: SquaredError") {
     GradientBoostingSuite.testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 73c4393c35..2734e089d6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -28,12 +28,12 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
 import org.apache.spark.mllib.tree.model.Node
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 /**
  * Test suite for [[RandomForest]].
  */
-class RandomForestSuite extends FunSuite with LocalSparkContext {
+class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
   def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
     val rdd = sc.parallelize(arr)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index 5cb433232e..b184e93667 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree.impl
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.EnsembleTestHelper
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 /**
  * Test suite for [[BaggedPoint]].
  */
-class BaggedPointSuite extends FunSuite with LocalSparkContext  {
+class BaggedPointSuite extends FunSuite with MLlibTestSparkContext  {
 
   test("BaggedPoint RDD: without subsampling") {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 0dbe766b4d..88bc49cc61 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
 import org.apache.spark.util.Utils
 
-class MLUtilsSuite extends FunSuite with LocalSparkContext {
+class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
 
   test("epsilon computation") {
     assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
similarity index 66%
rename from mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
rename to mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index 4417d66adf..b658889476 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,17 +17,26 @@
 
 package org.apache.spark.mllib.util
 
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
 
-trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
-  @transient val sc = new SparkContext("local", "test")
-  @transient lazy val sqlContext = new SQLContext(sc)
+trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    super.beforeAll()
+    val conf = new SparkConf()
+      .setMaster("local[2]")
+      .setAppName("MLlibUnitTest")
+    sc = new SparkContext(conf)
+  }
 
   override def afterAll() {
-    sc.stop()
+    if (sc != null) {
+      sc.stop()
+    }
     super.afterAll()
   }
 }
-- 
GitLab