diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala index 6a8098b59d3345c45b4124b4e78c09d051bd6a57..0a99b78cf859b0f388fc1fc98c7d0f80a06c366e 100644 --- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala @@ -35,10 +35,11 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { // Generate input of the form Y = logistic(offset + scale*X) def generateLogisticInput( - offset: Double, - scale: Double, - nPoints: Int) : Seq[(Double, Array[Double])] = { - val rnd = new Random(42) + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): Seq[(Double, Array[Double])] = { + val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) @@ -60,12 +61,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { } def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) { - val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => + val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => // A prediction is off if the prediction is more than 0.5 away from expected value. math.abs(prediction - expected) > 0.5 }.size // At least 80% of the predictions should be on. - assert(offPredictions < input.length / 5) + assert(numOffPredictions < input.length / 5) } // Test if we can correctly learn A, B where Y = logistic(A + B*X) @@ -74,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val A = 2.0 val B = -1.5 - val testData = generateLogisticInput(A, B, nPoints) + val testData = generateLogisticInput(A, B, nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -87,11 +88,13 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + val validationData = generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) + validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData) // Test prediction on Array. - validatePrediction(testData.map(row => model.predict(row._2)), testData) + validatePrediction(validationData.map(row => model.predict(row._2)), validationData) } test("logistic regression with initial weights") { @@ -99,7 +102,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val A = 2.0 val B = -1.5 - val testData = generateLogisticInput(A, B, nPoints) + val testData = generateLogisticInput(A, B, nPoints, 42) val initialB = -1.0 val initialWeights = Array(initialB) @@ -116,10 +119,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + val validationData = generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. - validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) + validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData) // Test prediction on Array. - validatePrediction(testData.map(row => model.predict(row._2)), testData) + validatePrediction(validationData.map(row => model.predict(row._2)), validationData) } }