diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 02ede711372d3fa93deb647accf6778f00aa8eda..05322b024d5f6d4e97e000543a6f28b5bcefc2f5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object LogisticRegressionSuite {
 
@@ -66,19 +67,7 @@ object LogisticRegressionSuite {
 
 }
 
-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
-
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
       prediction != expected.label
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index b615f76e66cf91cfad2d4042f0f4dc6e20e68b74..9dd6c79ee6ad8cc50a7ae202854995daf0452078 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object NaiveBayesSuite {
 
@@ -59,17 +59,7 @@ object NaiveBayesSuite {
   }
 }
 
-class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class NaiveBayesSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 3357b86f9b7061cf4ef5be89c6c29a312118a5b1..bc7abb568a172fc5daec6660d8d601e5a280d90f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -25,8 +25,9 @@ import org.scalatest.FunSuite
 
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.{SparkException, SparkContext}
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object SVMSuite {
 
@@ -58,17 +59,7 @@ object SVMSuite {
 
 }
 
-class SVMSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class SVMSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 73657cac893ce9180e54696fe18dba6d0605397a..4ef1d1f64ff06dd92eca61638e8a5052d17f55e5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext
 
-
-class KMeansSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class KMeansSuite extends FunSuite with LocalSparkContext {
 
   val EPSILON = 1e-4
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a6028a1e981dd6ff80b0f7cdcac94e5dca26149f..a453de6767aa2214555ebbc607640ce7982b269d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object GradientDescentSuite {
 
@@ -62,17 +63,7 @@ object GradientDescentSuite {
   }
 }
 
-class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
 
   test("Assert the loss is decreasing.") {
     val nPoints = 10000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 4e8dbde65801c05712fe98d1aed98c7b649ba744..5dcec7dc3eb9b3b4451445c0410c6cc78225cbf7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -23,10 +23,10 @@ import scala.util.Random
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
-
 import org.jblas._
 
+import org.apache.spark.mllib.util.LocalSparkContext
+
 object ALSSuite {
 
   def generateRatingsAsJavaList(
@@ -73,17 +73,7 @@ object ALSSuite {
 }
 
 
-class ALSSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class ALSSuite extends FunSuite with LocalSparkContext {
 
   test("rank-1 matrices") {
     testALS(50, 100, 1, 15, 0.7, 0.3)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index b2c8df97a82a77f62a41b804d392c5f0767dc54b..64e4cbb860f61d8f805fe956fc98fd50ff062a85 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-
-class LassoSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LassoSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 406afbaa3e2c108f71404f075d856ebd4cd4add5..281f9df36ddb3a025d21e68b6f2d6e85dbabf6b9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -20,20 +20,9 @@ package org.apache.spark.mllib.regression
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 1d6a10b66e89238dfe4cc9a2b95a41f63ec890ed..67dd06cc0f5eb9d1cb56dc5a231c6749bb5d5671 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -22,20 +22,10 @@ import org.jblas.DoubleMatrix
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
 
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
 
   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
     predictions.zip(input).map { case (prediction, expected) =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7d840043e5c6b2cf9dc5a2e736849e395beb56fc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -0,0 +1,34 @@
+package org.apache.spark.mllib.util
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkContext
+
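+/**
+ * Mixes a shared local-mode SparkContext into a suite: created once before
+ * the first test, stopped after the last, replacing the per-suite
+ * setup/teardown boilerplate removed from the suites above.
+ */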
+trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
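+  // Marked @transient so the non-serializable context is dropped if a test
+  // closure ever captures the enclosing suite instance.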
+  @transient var sc: SparkContext = _
+
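+  // Delegate to super so this trait stacks cleanly with any other
+  // BeforeAndAfterAll mix-ins on the same suite.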
+  override def beforeAll() {
+    sc = new SparkContext("local", "test")
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    if (sc != null) {
+      sc.stop()
+    }
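+    // Avoid Akka rebinding to the same port, since it doesn't unbind
+    // immediately on shutdown.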
+    System.clearProperty("spark.driver.port")
+    super.afterAll()
+  }
+}