Skip to content
Snippets Groups Projects
Commit 2210e8cc authored by Reynold Xin's avatar Reynold Xin
Browse files

Use a different validation dataset for Logistic Regression prediction testing.

parent 87a9dd89
No related branches found
No related tags found
No related merge requests found
...@@ -35,10 +35,11 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { ...@@ -35,10 +35,11 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
// Generate input of the form Y = logistic(offset + scale*X) // Generate input of the form Y = logistic(offset + scale*X)
def generateLogisticInput( def generateLogisticInput(
offset: Double, offset: Double,
scale: Double, scale: Double,
nPoints: Int) : Seq[(Double, Array[Double])] = { nPoints: Int,
val rnd = new Random(42) seed: Int): Seq[(Double, Array[Double])] = {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
// NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1) // NOTE: if U is uniform[0, 1] then ln(u) - ln(1-u) is Logistic(0,1)
...@@ -60,12 +61,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { ...@@ -60,12 +61,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
} }
def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) { def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
// A prediction is off if the prediction is more than 0.5 away from expected value. // A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected) > 0.5 math.abs(prediction - expected) > 0.5
}.size }.size
// At least 80% of the predictions should be on. // At least 80% of the predictions should be on.
assert(offPredictions < input.length / 5) assert(numOffPredictions < input.length / 5)
} }
// Test if we can correctly learn A, B where Y = logistic(A + B*X) // Test if we can correctly learn A, B where Y = logistic(A + B*X)
...@@ -74,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { ...@@ -74,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val A = 2.0 val A = 2.0
val B = -1.5 val B = -1.5
val testData = generateLogisticInput(A, B, nPoints) val testData = generateLogisticInput(A, B, nPoints, 42)
val testRDD = sc.parallelize(testData, 2) val testRDD = sc.parallelize(testData, 2)
testRDD.cache() testRDD.cache()
...@@ -87,11 +88,13 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { ...@@ -87,11 +88,13 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
val validationData = generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD. // Test prediction on RDD.
validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
// Test prediction on Array. // Test prediction on Array.
validatePrediction(testData.map(row => model.predict(row._2)), testData) validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
} }
test("logistic regression with initial weights") { test("logistic regression with initial weights") {
...@@ -99,7 +102,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { ...@@ -99,7 +102,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val A = 2.0 val A = 2.0
val B = -1.5 val B = -1.5
val testData = generateLogisticInput(A, B, nPoints) val testData = generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0 val initialB = -1.0
val initialWeights = Array(initialB) val initialWeights = Array(initialB)
...@@ -116,10 +119,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { ...@@ -116,10 +119,12 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
val validationData = generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD. // Test prediction on RDD.
validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
// Test prediction on Array. // Test prediction on Array.
validatePrediction(testData.map(row => model.predict(row._2)), testData) validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment