Skip to content
Snippets Groups Projects
Commit 217b5e91 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-3108][MLLIB] add predictOnValues to StreamingLR and fix predictOn

It is useful in streaming to allow users to carry extra data with the prediction, for monitoring the prediction error for example. freeman-lab

Author: Xiangrui Meng <meng@databricks.com>

Closes #2023 from mengxr/predict-on-values and squashes the following commits:

cac47b8 [Xiangrui Meng] add classtag
2821b3b [Xiangrui Meng] use mapValues
0925efa [Xiangrui Meng] add predictOnValues to StreamingLR and fix predictOn
parent c8b16ca0
No related branches found
No related tags found
No related merge requests found
......@@ -59,10 +59,10 @@ object StreamingLinearRegression {
val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
val model = new StreamingLinearRegressionWithSGD()
.setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0)))
.setInitialWeights(Vectors.zeros(args(3).toInt))
model.trainOn(trainingData)
model.predictOn(testData).print()
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
ssc.start()
ssc.awaitTermination()
......
......@@ -17,8 +17,12 @@
package org.apache.spark.mllib.regression
import org.apache.spark.annotation.DeveloperApi
import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.DStream
/**
......@@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[
/**
* Use the model to make predictions on batches of data from a DStream
*
* @param data DStream containing labeled data
* @param data DStream containing feature vectors
* @return DStream containing predictions
*/
def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
def predictOn(data: DStream[Vector]): DStream[Double] = {
if (Option(model.weights) == None) {
logError("Initial weights must be set before starting prediction")
throw new IllegalArgumentException
val msg = "Initial weights must be set before starting prediction"
logError(msg)
throw new IllegalArgumentException(msg)
}
data.map(x => model.predict(x.features))
data.map(model.predict)
}
/**
* Use the model to make predictions on the values of a DStream and carry over its keys.
* @param data DStream containing feature vectors
* @tparam K key type
* @return DStream containing the input keys and the predictions as values
*/
def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = {
if (Option(model.weights) == None) {
val msg = "Initial weights must be set before starting prediction"
logError(msg)
throw new IllegalArgumentException(msg)
}
data.mapValues(model.predict)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment