Skip to content
Snippets Groups Projects
Commit f6a18993 authored by Jeremy Freeman's avatar Jeremy Freeman Committed by Xiangrui Meng
Browse files

Streaming mllib [SPARK-2438][MLLIB]

This PR implements a streaming linear regression analysis, in which a linear regression model is trained online as new data arrive. The design is based on discussions with tdas and mengxr, in which we determined how to add this functionality in a general way, with minimal changes to existing libraries.

__Summary of additions:__

_StreamingLinearAlgorithm_
- An abstract class for fitting generalized linear models online to streaming data, including training on (and updating) a model, and making predictions.

_StreamingLinearRegressionWithSGD_
- Class and companion object for running streaming linear regression

_StreamingLinearRegressionTestSuite_
- Unit tests

_StreamingLinearRegression_
- Example use case: fitting a model online to data from one stream, and making predictions on other data

__Notes__
- If this looks good, I can use the StreamingLinearAlgorithm class to easily implement other analyses that follow the same logic (Ridge, Lasso, Logistic, SVM).

Author: Jeremy Freeman <the.freeman.lab@gmail.com>
Author: freeman <the.freeman.lab@gmail.com>

Closes #1361 from freeman-lab/streaming-mllib and squashes the following commits:

775ea29 [Jeremy Freeman] Throw error if user doesn't initialize weights
4086fee [Jeremy Freeman] Fixed current weight formatting
8b95b27 [Jeremy Freeman] Restored broadcasting
29f27ec [Jeremy Freeman] Formatting
8711c41 [Jeremy Freeman] Used return to avoid indentation
777b596 [Jeremy Freeman] Restored treeAggregate
74cf440 [Jeremy Freeman] Removed static methods
d28cf9a [Jeremy Freeman] Added usage notes
c3326e7 [Jeremy Freeman] Improved documentation
9541a41 [Jeremy Freeman] Merge remote-tracking branch 'upstream/master' into streaming-mllib
66eba5e [Jeremy Freeman] Fixed line lengths
2fe0720 [Jeremy Freeman] Minor cleanup
7d51378 [Jeremy Freeman] Moved streaming loader to MLUtils
b9b69f6 [Jeremy Freeman] Added setter methods
c3f8b5a [Jeremy Freeman] Modified logging
00aafdc [Jeremy Freeman] Add modifiers
14b801e [Jeremy Freeman] Name changes
c7d38a3 [Jeremy Freeman] Move check for empty data to GradientDescent
4b0a5d3 [Jeremy Freeman] Cleaned up tests
74188d6 [Jeremy Freeman] Eliminate dependency on commons
50dd237 [Jeremy Freeman] Removed experimental tag
6bfe1e6 [Jeremy Freeman] Fixed imports
a2a63ad [freeman] Makes convergence test more robust
86220bc [freeman] Streaming linear regression unit tests
fb4683a [freeman] Minor changes for scalastyle consistency
fd31e03 [freeman] Changed logging behavior
453974e [freeman] Fixed indentation
c4b1143 [freeman] Streaming linear regression
604f4d7 [freeman] Expanded private class to include mllib
d99aa85 [freeman] Helper methods for streaming MLlib apps
0898add [freeman] Added dependency on streaming
parent e8e0fd69
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.examples.mllib
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
/**
* Train a linear regression model on one stream of data and make predictions
* on another stream, where the data streams arrive as text files
* into two different directories.
*
* The rows of the text files must be labeled data points in the form
* `(y,[x1,x2,x3,...,xn])`
* Where n is the number of features. n must be the same for train and test.
*
* Usage: StreamingLinearRegression <trainingDir> <testDir> <batchDuration> <numFeatures>
*
* To run on your local machine using the two directories `trainingDir` and `testDir`,
* with updates every 5 seconds, and 2 features per data point, call:
* $ bin/run-example \
* org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2
*
* As you add text files to `trainingDir` the model will continuously update.
* Anytime you add text files to `testDir`, you'll see predictions from the current model.
*
*/
object StreamingLinearRegression {
def main(args: Array[String]) {
if (args.length != 4) {
System.err.println(
"Usage: StreamingLinearRegression <trainingDir> <testDir> <batchDuration> <numFeatures>")
System.exit(1)
}
val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression")
val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0))
val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1))
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0)))
model.trainOn(trainingData)
model.predictOn(testData).print()
ssc.start()
ssc.awaitTermination()
}
}
......@@ -40,6 +40,11 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
......
......@@ -162,6 +162,14 @@ object GradientDescent extends Logging {
val numExamples = data.count()
val miniBatchSize = numExamples * miniBatchFraction
// if no data, return initial weights to avoid NaNs
if (numExamples == 0) {
logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
return (initialWeights, stochasticLossHistory.toArray)
}
// Initialize weights as a column vector
var weights = Vectors.dense(initialWeights.toArray)
val n = weights.size
......@@ -202,5 +210,6 @@ object GradientDescent extends Logging {
stochasticLossHistory.takeRight(10).mkString(", ")))
(weights, stochasticLossHistory.toArray)
}
}
......@@ -49,7 +49,7 @@ class LinearRegressionModel (
* its corresponding right hand side label y.
* See also the documentation for the precise formulation.
*/
class LinearRegressionWithSGD private (
class LinearRegressionWithSGD private[mllib] (
private var stepSize: Double,
private var numIterations: Int,
private var miniBatchFraction: Double)
......@@ -68,7 +68,7 @@ class LinearRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.regression
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.streaming.dstream.DStream
/**
* :: DeveloperApi ::
* StreamingLinearAlgorithm implements methods for continuously
* training a generalized linear model model on streaming data,
* and using it for prediction on (possibly different) streaming data.
*
* This class takes as type parameters a GeneralizedLinearModel,
* and a GeneralizedLinearAlgorithm, making it easy to extend to construct
* streaming versions of any analyses using GLMs.
* Initial weights must be set before calling trainOn or predictOn.
* Only weights will be updated, not an intercept. If the model needs
* an intercept, it should be manually appended to the input data.
*
* For example usage, see `StreamingLinearRegressionWithSGD`.
*
* NOTE(Freeman): In some use cases, the order in which trainOn and predictOn
* are called in an application will affect the results. When called on
* the same DStream, if trainOn is called before predictOn, when new data
* arrive the model will update and the prediction will be based on the new
* model. Whereas if predictOn is called first, the prediction will use the model
* from the previous update.
*
* NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this
* will generate predictions for each one all using the current model.
* It is also ok to call trainOn on different streams; this will update
* the model using each of the different sources, in sequence.
*
*/
@DeveloperApi
abstract class StreamingLinearAlgorithm[
M <: GeneralizedLinearModel,
A <: GeneralizedLinearAlgorithm[M]] extends Logging {
/** The model to be updated and used for prediction. */
protected var model: M
/** The algorithm to use for updating. */
protected val algorithm: A
/** Return the latest model. */
def latestModel(): M = {
model
}
/**
* Update the model by training on batches of data from a DStream.
* This operation registers a DStream for training the model,
* and updates the model based on every subsequent
* batch of data from the stream.
*
* @param data DStream containing labeled data
*/
def trainOn(data: DStream[LabeledPoint]) {
if (Option(model.weights) == None) {
logError("Initial weights must be set before starting training")
throw new IllegalArgumentException
}
data.foreachRDD { (rdd, time) =>
model = algorithm.run(rdd, model.weights)
logInfo("Model updated at time %s".format(time.toString))
val display = model.weights.size match {
case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...")
case _ => model.weights.toArray.mkString("[", ",", "]")
}
logInfo("Current model: weights, %s".format (display))
}
}
/**
* Use the model to make predictions on batches of data from a DStream
*
* @param data DStream containing labeled data
* @return DStream containing predictions
*/
def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
if (Option(model.weights) == None) {
logError("Initial weights must be set before starting prediction")
throw new IllegalArgumentException
}
data.map(x => model.predict(x.features))
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.regression
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
* Train or predict a linear regression model on streaming data. Training uses
* Stochastic Gradient Descent to update the model based on each new batch of
* incoming data from a DStream (see `LinearRegressionWithSGD` for model equation)
*
* Each batch of data is assumed to be an RDD of LabeledPoints.
* The number of data points per batch can vary, but the number
* of features must be constant. An initial weight
* vector must be provided.
*
* Use a builder pattern to construct a streaming linear regression
* analysis in an application, like:
*
* val model = new StreamingLinearRegressionWithSGD()
* .setStepSize(0.5)
* .setNumIterations(10)
* .setInitialWeights(Vectors.dense(...))
* .trainOn(DStream)
*
*/
@Experimental
class StreamingLinearRegressionWithSGD (
private var stepSize: Double,
private var numIterations: Int,
private var miniBatchFraction: Double,
private var initialWeights: Vector)
extends StreamingLinearAlgorithm[
LinearRegressionModel, LinearRegressionWithSGD] with Serializable {
/**
* Construct a StreamingLinearRegression object with default parameters:
* {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0}.
* Initial weights must be set before using trainOn or predictOn
* (see `StreamingLinearAlgorithm`)
*/
def this() = this(0.1, 50, 1.0, null)
val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
var model = algorithm.createModel(initialWeights, 0.0)
/** Set the step size for gradient descent. Default: 0.1. */
def setStepSize(stepSize: Double): this.type = {
this.algorithm.optimizer.setStepSize(stepSize)
this
}
/** Set the number of iterations of gradient descent to run per update. Default: 50. */
def setNumIterations(numIterations: Int): this.type = {
this.algorithm.optimizer.setNumIterations(numIterations)
this
}
/** Set the fraction of each batch to use for updates. Default: 1.0. */
def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction)
this
}
/** Set the initial weights. Default: [0.0, 0.0]. */
def setInitialWeights(initialWeights: Vector): this.type = {
this.model = algorithm.createModel(initialWeights, 0.0)
this
}
}
......@@ -30,6 +30,8 @@ import org.apache.spark.util.random.BernoulliSampler
import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.DStream
/**
* Helper methods to load, save and pre-process data used in ML Lib.
......@@ -192,6 +194,19 @@ object MLUtils {
def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] =
loadLabeledPoints(sc, dir, sc.defaultMinPartitions)
/**
* Loads streaming labeled points from a stream of text files
* where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`.
* See `StreamingContext.textFileStream` for more details on how to
* generate a stream from files
*
* @param ssc Streaming context
* @param dir Directory path in any Hadoop-supported file system URI
* @return Labeled points stored as a DStream[LabeledPoint]
*/
def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] =
ssc.textFileStream(dir).map(LabeledPointParser.parse)
/**
* Load labeled data from a file. The data format used here is
* <L>, <f1> <f2> ...
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.regression
import java.io.File
import java.nio.charset.Charset
import scala.collection.mutable.ArrayBuffer
import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils}
import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.util.Utils
class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
// Assert that two values are equal within tolerance epsilon
def assertEqual(v1: Double, v2: Double, epsilon: Double) {
def errorMessage = v1.toString + " did not equal " + v2.toString
assert(math.abs(v1-v2) <= epsilon, errorMessage)
}
// Assert that model predictions are correct
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
math.abs(prediction - expected.label) > 0.5
}
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)
}
// Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
test("streaming linear regression parameter accuracy") {
val testDir = Files.createTempDir()
val numBatches = 10
val batchDuration = Milliseconds(1000)
val ssc = new StreamingContext(sc, batchDuration)
val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString)
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0, 0.0))
.setStepSize(0.1)
.setNumIterations(50)
model.trainOn(data)
ssc.start()
// write data to a file stream
for (i <- 0 until numBatches) {
val samples = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
val file = new File(testDir, i.toString)
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
Thread.sleep(batchDuration.milliseconds)
}
ssc.stop(stopSparkContext=false)
System.clearProperty("spark.driver.port")
Utils.deleteRecursively(testDir)
// check accuracy of final parameter estimates
assertEqual(model.latestModel().intercept, 0.0, 0.1)
assertEqual(model.latestModel().weights(0), 10.0, 0.1)
assertEqual(model.latestModel().weights(1), 10.0, 0.1)
// check accuracy of predictions
val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
validatePrediction(validationData.map(row => model.latestModel().predict(row.features)),
validationData)
}
// Test that parameter estimates improve when learning Y = 10*X1 on streaming data
test("streaming linear regression parameter convergence") {
val testDir = Files.createTempDir()
val batchDuration = Milliseconds(2000)
val ssc = new StreamingContext(sc, batchDuration)
val numBatches = 5
val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString)
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(0.0))
.setStepSize(0.1)
.setNumIterations(50)
model.trainOn(data)
ssc.start()
// write data to a file stream
val history = new ArrayBuffer[Double](numBatches)
for (i <- 0 until numBatches) {
val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
val file = new File(testDir, i.toString)
Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
Thread.sleep(batchDuration.milliseconds)
// wait an extra few seconds to make sure the update finishes before new data arrive
Thread.sleep(4000)
history.append(math.abs(model.latestModel().weights(0) - 10.0))
}
ssc.stop(stopSparkContext=false)
System.clearProperty("spark.driver.port")
Utils.deleteRecursively(testDir)
val deltas = history.drop(1).zip(history.dropRight(1))
// check error stability (it always either shrinks, or increases with small tol)
assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
// check that error shrunk on at least 2 batches
assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment