Skip to content
Snippets Groups Projects
Commit 87a9dd89 authored by Reynold Xin's avatar Reynold Xin
Browse files

Made RegressionModel serializable and added unit tests to make sure predict methods would work.

parent 401aac8b
No related branches found
No related tags found
No related merge requests found
......@@ -39,9 +39,9 @@ object GradientDescent {
* @param miniBatchFraction - fraction of the input data set that should be used for
* one iteration of SGD. Default value 1.0.
*
* @return A tuple containing two elements. The first element is a column matrix containing
* weights for every feature, and the second element is an array containing the stochastic
* loss computed for every iteration.
*/
def runMiniBatchSGD(
data: RDD[(Double, Array[Double])],
......
......@@ -23,13 +23,13 @@ abstract class Updater extends Serializable {
/**
 * Compute an updated value for weights given the gradient, stepSize and iteration number.
 *
 * @param weightsOlds - Column matrix of size nx1 where n is the number of features.
 * @param gradient - Column matrix of size nx1 where n is the number of features.
 * @param stepSize - step size across iterations
 * @param iter - Iteration number
 *
 * @return A tuple of 2 elements. The first element is a column matrix containing updated weights,
 *         and the second element is the regularization value.
 */
def compute(weightsOlds: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
(DoubleMatrix, Double)
......
......@@ -36,8 +36,12 @@ class LogisticRegressionModel(
private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
override def predict(testData: spark.RDD[Array[Double]]) = {
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed. Capturing them in locals keeps the closure from dragging
// `this` (the whole model) across the wire.
val localWeights = weightsMatrix
val localIntercept = intercept
testData.map { x =>
// margin = w . x + b; logistic link maps it to a probability in (0, 1).
val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
1.0/ (1.0 + math.exp(margin * -1))
}
}
......
......@@ -19,7 +19,7 @@ package spark.mllib.regression
import spark.RDD
trait RegressionModel extends Serializable {
/**
* Predict values for the given data set using the model trained.
*
......
......@@ -37,8 +37,11 @@ class RidgeRegressionModel(
extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
// A small optimization to avoid serializing the entire model: copy the two fields the
// closure needs into locals so only they are captured, not `this`.
val localIntercept = this.intercept
val localWeights = this.weights
testData.map { x =>
// Linear prediction: w . x + b.
(new DoubleMatrix(1, x.length, x:_*).mmul(localWeights)).get(0) + localIntercept
}
}
......
......@@ -23,7 +23,6 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
......@@ -51,15 +50,24 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
// y <- A + B*x + rLogis()
// y <- as.numeric(y > 0)
val y: Seq[Double] = (0 until nPoints).map { i =>
val yVal = offset + scale * x1(i) + rLogis(i)
if (yVal > 0) 1.0 else 0.0
}
val testData = (0 until nPoints).map(i => (y(i), Array(x1(i))))
testData
}
/**
 * Check that at least 80% of the predictions land within 0.5 of the expected label.
 *
 * @param predictions - predicted values, aligned element-wise with `input`.
 * @param input - (expected label, features) pairs the predictions were made for.
 */
def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
// A prediction is off if it is more than 0.5 away from the expected value.
// `count` avoids materializing the intermediate filtered collection.
val numOffPredictions = predictions.zip(input).count { case (prediction, (expected, _)) =>
math.abs(prediction - expected) > 0.5
}
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)
}
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
test("logistic regression") {
val nPoints = 10000
......@@ -70,14 +78,20 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(20)
val model = lr.train(testRDD)
// Test the weights
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
// Test prediction on RDD.
validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
// Test prediction on Array.
validatePrediction(testData.map(row => model.predict(row._2)), testData)
}
test("logistic regression with initial weights") {
......@@ -94,13 +108,18 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
testRDD.cache()
// Use half as many iterations as the previous test.
val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(10)
val model = lr.train(testRDD, initialWeights)
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
// Test prediction on RDD.
validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
// Test prediction on Array.
validatePrediction(testData.map(row => model.predict(row._2)), testData)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment