Skip to content
Snippets Groups Projects
Commit 0e23ca9f authored by Xiangrui Meng
Browse files

[SPARK-5601][MLLIB] make streaming linear algorithms Java-friendly

Overload `trainOn`, `predictOn`, and `predictOnValues`.

CC freeman-lab

Author: Xiangrui Meng <meng@databricks.com>

Closes #4432 from mengxr/streaming-java and squashes the following commits:

6a79b85 [Xiangrui Meng] add java test for streaming logistic regression
2d7b357 [Xiangrui Meng] organize imports
1f662b3 [Xiangrui Meng] make streaming linear algorithms Java-friendly
parent c4021401
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,9 @@ import scala.reflect.ClassTag ...@@ -21,7 +21,9 @@ import scala.reflect.ClassTag
import org.apache.spark.Logging import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream}
import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.dstream.DStream
/** /**
...@@ -76,7 +78,7 @@ abstract class StreamingLinearAlgorithm[ ...@@ -76,7 +78,7 @@ abstract class StreamingLinearAlgorithm[
* *
* @param data DStream containing labeled data * @param data DStream containing labeled data
*/ */
def trainOn(data: DStream[LabeledPoint]) { def trainOn(data: DStream[LabeledPoint]): Unit = {
if (model.isEmpty) { if (model.isEmpty) {
throw new IllegalArgumentException("Model must be initialized before starting training.") throw new IllegalArgumentException("Model must be initialized before starting training.")
} }
...@@ -99,6 +101,9 @@ abstract class StreamingLinearAlgorithm[ ...@@ -99,6 +101,9 @@ abstract class StreamingLinearAlgorithm[
} }
} }
/**
 * Java-friendly version of `trainOn`.
 *
 * Unwraps the given `JavaDStream` and forwards its underlying `DStream`
 * to the Scala `trainOn` overload.
 *
 * @param data JavaDStream containing labeled data
 */
def trainOn(data: JavaDStream[LabeledPoint]): Unit = {
  trainOn(data.dstream)
}
/** /**
* Use the model to make predictions on batches of data from a DStream * Use the model to make predictions on batches of data from a DStream
* *
...@@ -112,6 +117,11 @@ abstract class StreamingLinearAlgorithm[ ...@@ -112,6 +117,11 @@ abstract class StreamingLinearAlgorithm[
data.map(model.get.predict) data.map(model.get.predict)
} }
/**
 * Java-friendly version of `predictOn`.
 *
 * Runs the Scala `predictOn` overload on the wrapped `DStream` and exposes
 * the result as a `JavaDStream` of `java.lang.Double`.
 *
 * @param data JavaDStream containing feature vectors
 * @return JavaDStream containing predictions
 */
def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = {
  val scalaPredictions = predictOn(data.dstream)
  // Unchecked cast: only relabels the element type for Java callers, the stream is not transformed.
  val javaTyped = scalaPredictions.asInstanceOf[DStream[java.lang.Double]]
  JavaDStream.fromDStream(javaTyped)
}
/** /**
* Use the model to make predictions on the values of a DStream and carry over its keys. * Use the model to make predictions on the values of a DStream and carry over its keys.
* @param data DStream containing feature vectors * @param data DStream containing feature vectors
...@@ -124,4 +134,12 @@ abstract class StreamingLinearAlgorithm[ ...@@ -124,4 +134,12 @@ abstract class StreamingLinearAlgorithm[
} }
data.mapValues(model.get.predict) data.mapValues(model.get.predict)
} }
/**
 * Java-friendly version of `predictOnValues`.
 *
 * Runs the Scala `predictOnValues` overload on the wrapped pair `DStream` and
 * exposes the result as a `JavaPairDStream` keyed by the original keys.
 *
 * @param data JavaPairDStream whose values are feature vectors
 * @tparam K key type carried over from the input stream
 * @return JavaPairDStream of keys paired with predictions
 */
def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = {
  // Java callers cannot supply a ClassTag for K, so a placeholder tag is put in implicit scope.
  implicit val keyTag: ClassTag[K] = fakeClassTag[K]
  val keyedPredictions = predictOnValues(data.dstream)
  // Unchecked cast: only relabels the value type for Java callers, the stream is not transformed.
  val javaTyped = keyedPredictions.asInstanceOf[DStream[(K, java.lang.Double)]]
  JavaPairDStream.fromPairDStream(javaTyped)
}
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
import scala.Tuple2;
import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkConf;
import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
/**
 * Smoke test that exercises the Java-friendly {@code trainOn} and
 * {@code predictOnValues} overloads of {@link StreamingLogisticRegressionWithSGD}.
 *
 * The test only checks that the calls compile and the streams run; it makes
 * no assertion about the predicted values.
 */
public class JavaStreamingLogisticRegressionSuite implements Serializable {

  /** Streaming context recreated for every test; transient so it is not serialized with the suite. */
  protected transient JavaStreamingContext ssc;

  /** Creates a local two-thread streaming context driven by the manual clock. */
  @Before
  public void setUp() {
    SparkConf sparkConf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("test")
      .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
    Duration batchInterval = new Duration(1000);
    ssc = new JavaStreamingContext(sparkConf, batchInterval);
    ssc.checkpoint("checkpoint");
  }

  /** Stops the streaming context and drops the reference to it. */
  @After
  public void tearDown() {
    ssc.stop();
    ssc = null;
  }

  /**
   * Feeds two identical batches of labeled points and two identical batches of
   * keyed feature vectors through the model via the Java API.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void javaAPI() {
    // Training input: the same two-point batch, supplied twice.
    List<LabeledPoint> labeledBatch = Lists.newArrayList(
      new LabeledPoint(1.0, Vectors.dense(1.0)),
      new LabeledPoint(0.0, Vectors.dense(0.0)));
    List<List<LabeledPoint>> labeledBatches = Lists.newArrayList(labeledBatch, labeledBatch);
    JavaDStream<LabeledPoint> training = attachTestInputStream(ssc, labeledBatches, 2);

    // Prediction input: keyed feature vectors, again the same batch twice.
    List<Tuple2<Integer, Vector>> keyedBatch = Lists.newArrayList(
      new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
      new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
    List<List<Tuple2<Integer, Vector>>> keyedBatches = Lists.newArrayList(keyedBatch, keyedBatch);
    JavaDStream<Tuple2<Integer, Vector>> keyedStream =
      attachTestInputStream(ssc, keyedBatches, 2);
    JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(keyedStream);

    StreamingLogisticRegressionWithSGD model = new StreamingLogisticRegressionWithSGD()
      .setNumIterations(2)
      .setInitialWeights(Vectors.dense(0.0));
    model.trainOn(training);

    JavaPairDStream<Integer, Double> predicted = model.predictOnValues(test);
    attachTestOutputStream(predicted.count());
    runStreams(ssc, 2, 2);
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.regression;
import java.io.Serializable;
import java.util.List;
import scala.Tuple2;
import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.apache.spark.SparkConf;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import static org.apache.spark.streaming.JavaTestUtils.*;
/**
 * Smoke test that exercises the Java-friendly {@code trainOn} and
 * {@code predictOnValues} overloads of {@link StreamingLinearRegressionWithSGD}.
 *
 * The test only checks that the calls compile and the streams run; it makes
 * no assertion about the predicted values.
 */
public class JavaStreamingLinearRegressionSuite implements Serializable {

  /** Streaming context recreated for every test; transient so it is not serialized with the suite. */
  protected transient JavaStreamingContext ssc;

  /** Creates a local two-thread streaming context driven by the manual clock. */
  @Before
  public void setUp() {
    SparkConf sparkConf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("test")
      .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
    Duration batchInterval = new Duration(1000);
    ssc = new JavaStreamingContext(sparkConf, batchInterval);
    ssc.checkpoint("checkpoint");
  }

  /** Stops the streaming context and drops the reference to it. */
  @After
  public void tearDown() {
    ssc.stop();
    ssc = null;
  }

  /**
   * Feeds two identical batches of labeled points and two identical batches of
   * keyed feature vectors through the model via the Java API.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void javaAPI() {
    // Training input: the same two-point batch, supplied twice.
    List<LabeledPoint> labeledBatch = Lists.newArrayList(
      new LabeledPoint(1.0, Vectors.dense(1.0)),
      new LabeledPoint(0.0, Vectors.dense(0.0)));
    List<List<LabeledPoint>> labeledBatches = Lists.newArrayList(labeledBatch, labeledBatch);
    JavaDStream<LabeledPoint> training = attachTestInputStream(ssc, labeledBatches, 2);

    // Prediction input: keyed feature vectors, again the same batch twice.
    List<Tuple2<Integer, Vector>> keyedBatch = Lists.newArrayList(
      new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
      new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
    List<List<Tuple2<Integer, Vector>>> keyedBatches = Lists.newArrayList(keyedBatch, keyedBatch);
    JavaDStream<Tuple2<Integer, Vector>> keyedStream =
      attachTestInputStream(ssc, keyedBatches, 2);
    JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(keyedStream);

    StreamingLinearRegressionWithSGD model = new StreamingLinearRegressionWithSGD()
      .setNumIterations(2)
      .setInitialWeights(Vectors.dense(0.0));
    model.trainOn(training);

    JavaPairDStream<Integer, Double> predicted = model.predictOnValues(test);
    attachTestOutputStream(predicted.count());
    runStreams(ssc, 2, 2);
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment