Commit f8544981 authored by Reynold Xin

Merge pull request #469 from ajtulloch/use-local-spark-context-in-tests-for-mllib

[MLlib] Use a LocalSparkContext trait in test suites

Replaces the 9 instances of

```scala
class XXXSuite extends FunSuite with BeforeAndAfterAll {
  @transient private var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")
  }

  override def afterAll() {
    sc.stop()
    System.clearProperty("spark.driver.port")
  }
```

with

```scala
class XXXSuite extends FunSuite with LocalSparkContext {
```
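
As a usage sketch, a suite that mixes in the trait simply uses `sc` in its test bodies; the suite name and test below are hypothetical, not taken from this patch:

```scala
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext

// Hypothetical suite, shown only to illustrate usage: `sc` is created by
// LocalSparkContext.beforeAll() before any test runs and stopped in
// afterAll(), so test bodies can use it directly.
class ExampleSuite extends FunSuite with LocalSparkContext {
  test("parallelize and count") {
    val rdd = sc.parallelize(1 to 100)
    assert(rdd.count() === 100)
  }
}
```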
Parents: 77b986f6, 3a067b4a
Showing 10 changed files with 43 additions and 113 deletions
LogisticRegressionSuite.scala:

```diff
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext

 object LogisticRegressionSuite {
@@ -66,19 +67,7 @@ object LogisticRegressionSuite {
 }

-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
       prediction != expected.label
```
NaiveBayesSuite.scala:

```diff
@@ -23,7 +23,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext

 object NaiveBayesSuite {
@@ -59,17 +59,7 @@ object NaiveBayesSuite {
   }
 }

-class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class NaiveBayesSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {
```
SVMSuite.scala:

```diff
@@ -25,8 +25,9 @@ import org.scalatest.FunSuite
 import org.jblas.DoubleMatrix

-import org.apache.spark.{SparkException, SparkContext}
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext

 object SVMSuite {
@@ -58,17 +59,7 @@ object SVMSuite {
 }

-class SVMSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class SVMSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
```
KMeansSuite.scala:

```diff
@@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext

-class KMeansSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class KMeansSuite extends FunSuite with LocalSparkContext {

   val EPSILON = 1e-4
```
GradientDescentSuite.scala:

```diff
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext

 object GradientDescentSuite {
@@ -62,17 +63,7 @@ object GradientDescentSuite {
   }
 }

-class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {

   test("Assert the loss is decreasing.") {
     val nPoints = 10000
```
ALSSuite.scala:

```diff
@@ -23,10 +23,10 @@ import scala.util.Random
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
 import org.jblas._

+import org.apache.spark.mllib.util.LocalSparkContext
+
 object ALSSuite {

   def generateRatingsAsJavaList(
@@ -73,17 +73,7 @@ object ALSSuite {
 }

-class ALSSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class ALSSuite extends FunSuite with LocalSparkContext {

   test("rank-1 matrices") {
     testALS(50, 100, 1, 15, 0.7, 0.3)
```
LassoSuite.scala:

```diff
@@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

-class LassoSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LassoSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
```
LinearRegressionSuite.scala:

```diff
@@ -20,20 +20,9 @@ package org.apache.spark.mllib.regression
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

-class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LinearRegressionSuite extends FunSuite with LocalSparkContext {

   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
```
RidgeRegressionSuite.scala:

```diff
@@ -22,20 +22,10 @@ import org.jblas.DoubleMatrix
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite

-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

-class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class RidgeRegressionSuite extends FunSuite with LocalSparkContext {

   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
     predictions.zip(input).map { case (prediction, expected) =>
```
New file, LocalSparkContext.scala:

```scala
package org.apache.spark.mllib.util

import org.scalatest.Suite
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkContext

trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")
    super.beforeAll()
  }

  override def afterAll() {
    if (sc != null) {
      sc.stop()
    }
    System.clearProperty("spark.driver.port")
    super.afterAll()
  }
}
```
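
Two details of the trait are deliberate. The `null` guard in `afterAll()` keeps teardown from throwing when context creation itself failed, and both callbacks chain to `super`, so the trait stacks cleanly with other `BeforeAndAfterAll` fixtures; clearing `spark.driver.port` matters because Spark of this era recorded the bound driver port as a system property, and a stale value would make the next suite's context try to reuse it.

A minimal sketch of the stacking behavior, with a hypothetical second fixture (`TempDirFixture` and `StackedSuite` are illustrative, not part of this patch):

```scala
import org.scalatest.{BeforeAndAfterAll, FunSuite, Suite}

import org.apache.spark.mllib.util.LocalSparkContext

// Hypothetical extra fixture: because both traits call
// super.beforeAll()/super.afterAll(), mixing them into one suite runs both
// setups before any test and both teardowns after the last one.
trait TempDirFixture extends BeforeAndAfterAll { self: Suite =>
  @transient var tmpDir: java.io.File = _

  override def beforeAll() {
    tmpDir = java.nio.file.Files.createTempDirectory("mllib-test").toFile
    super.beforeAll()
  }

  override def afterAll() {
    tmpDir.delete()
    super.afterAll()
  }
}

// Trait linearization: StackedSuite -> TempDirFixture -> LocalSparkContext ->
// BeforeAndAfterAll, so beforeAll() creates the temp dir, then the context.
class StackedSuite extends FunSuite with LocalSparkContext with TempDirFixture {
  test("both fixtures are initialized") {
    assert(sc != null && tmpDir.exists())
  }
}
```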