Skip to content
Snippets Groups Projects
Commit 04fa1223 authored by Sean Owen's avatar Sean Owen Committed by Xiangrui Meng
Browse files

SPARK-2293. Replace RDD.zip usage by map with predict inside.

This is the only occurrence of this pattern in the examples that needs to be replaced. It only addresses the example change.

Author: Sean Owen <sowen@cloudera.com>

Closes #1250 from srowen/SPARK-2293 and squashes the following commits:

6b1b28c [Sean Owen] Compute prediction-and-label RDD directly rather than by zipping, for efficiency
parent 5fccb567
No related branches found
No related tags found
No related merge requests found
......@@ -51,9 +51,8 @@ val training = splits(0)
val test = splits(1)
val model = NaiveBayes.train(training, lambda = 1.0)
val prediction = model.predict(test.map(_.features))
val predictionAndLabel = prediction.zip(test.map(_.label))
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
{% endhighlight %}
</div>
......@@ -71,6 +70,7 @@ can be used for evaluation and prediction.
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
......@@ -81,18 +81,12 @@ JavaRDD<LabeledPoint> test = ... // test set
final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
JavaRDD<Double> prediction =
test.map(new Function<LabeledPoint, Double>() {
@Override public Double call(LabeledPoint p) {
return model.predict(p.features());
}
});
JavaPairRDD<Double, Double> predictionAndLabel =
prediction.zip(test.map(new Function<LabeledPoint, Double>() {
@Override public Double call(LabeledPoint p) {
return p.label();
test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
@Override public Tuple2<Double, Double> call(LabeledPoint p) {
return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
}));
});
double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override public Boolean call(Tuple2<Double, Double> pl) {
return pl._1() == pl._2();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment