Skip to content
Snippets Groups Projects
Commit 6e1c1ec6 authored by freeman's avatar freeman Committed by Xiangrui Meng
Browse files

[SPARK-6345][STREAMING][MLLIB] Fix for training with prediction

This patch fixes a reported bug causing model updates to not properly propagate to model predictions during streaming regression. These minor changes in model declaration fix the problem, and I expanded the tests to include the scenario in which the bug was arising. The two new tests failed prior to the patch and now pass.

cc mengxr

Author: freeman <the.freeman.lab@gmail.com>

Closes #5037 from freeman-lab/train-predict-fix and squashes the following commits:

3af953e [freeman] Expand test coverage to include combined training and prediction
8f84fc8 [freeman] Move model declaration
parent 8a0aa81c
No related branches found
No related tags found
No related merge requests found
......@@ -63,6 +63,8 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
protected val algorithm = new LogisticRegressionWithSGD(
stepSize, numIterations, regParam, miniBatchFraction)
protected var model: Option[LogisticRegressionModel] = None
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
......
......@@ -60,7 +60,7 @@ abstract class StreamingLinearAlgorithm[
A <: GeneralizedLinearAlgorithm[M]] extends Logging {
/** The model to be updated and used for prediction. */
protected var model: Option[M] = None
protected var model: Option[M]
/** The algorithm to use for updating. */
protected val algorithm: A
......@@ -114,7 +114,7 @@ abstract class StreamingLinearAlgorithm[
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction.")
}
data.map(model.get.predict)
data.map{x => model.get.predict(x)}
}
/** Java-friendly version of `predictOn`. */
......@@ -132,7 +132,7 @@ abstract class StreamingLinearAlgorithm[
if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting prediction")
}
data.mapValues(model.get.predict)
data.mapValues{x => model.get.predict(x)}
}
......
......@@ -59,6 +59,8 @@ class StreamingLinearRegressionWithSGD private[mllib] (
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
protected var model: Option[LinearRegressionModel] = None
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
......
......@@ -132,4 +132,31 @@ class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase {
assert(errors.forall(x => x <= 0.4))
}
// Test training combined with prediction: the same input stream is used both to
// train the model and to predict, so predictions must reflect the updated model.
test("training and prediction") {
  // create model with a single initial weight of -0.1 (not zero); the assertion
  // at the end expects the first batch's error to be high and the last to be low
  val model = new StreamingLogisticRegressionWithSGD()
    .setInitialWeights(Vectors.dense(-0.1))
    .setStepSize(0.01)
    .setNumIterations(10)

  // generate sequence of simulated data for testing, one batch per seed
  val numBatches = 10
  val nPoints = 100
  val testInput = (0 until numBatches).map { i =>
    LogisticRegressionSuite.generateLogisticInput(0.0, 5.0, nPoints, 42 * (i + 1))
  }

  // train and predict on the same stream; output pairs are (label, prediction)
  val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
    model.trainOn(inputDStream)
    model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
  })
  val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)

  // assert that prediction error improves, ensuring that the updated model is being used
  val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
  // use the logical (short-circuit) && rather than the bitwise-style & on Booleans
  assert(error.head > 0.8 && error.last < 0.2)
}
}
......@@ -139,4 +139,32 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
assert(errors.forall(x => x <= 0.1))
}
// Verify that predictions issued while the model is being trained on the same
// stream pick up the model updates.
test("training and prediction") {
  // streaming linear regression starting from all-zero weights
  val regression = new StreamingLinearRegressionWithSGD()
    .setInitialWeights(Vectors.dense(0.0, 0.0))
    .setStepSize(0.2)
    .setNumIterations(25)

  // simulated input: numBatches batches of nPoints points, each with its own seed
  val numBatches = 10
  val nPoints = 100
  val batches =
    for (i <- 0 until numBatches)
      yield LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1))

  // register training first, then predict on (label, features) pairs of the same stream
  val ssc = setupStreams(batches, (stream: DStream[LabeledPoint]) => {
    regression.trainOn(stream)
    val labelled = stream.map(point => (point.label, point.features))
    regression.predictOnValues(labelled)
  })
  val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)

  // mean absolute error per batch; it must drop by more than 2 between the first
  // and last batch, which only happens if the updated model is used for prediction
  val batchErrors = output.map { batch =>
    batch.map { case (label, prediction) => math.abs(label - prediction) }.sum / nPoints
  }.toList
  assert((batchErrors.head - batchErrors.last) > 2)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment