diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 4b3a7cab32118a51660cdde29b74df628083e02f..1d1d7dcf6ffcbeede9fc0153ac0719d42e392a59 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -51,9 +51,8 @@ val training = splits(0)
 val test = splits(1)
 
 val model = NaiveBayes.train(training, lambda = 1.0)
-val prediction = model.predict(test.map(_.features))
 
-val predictionAndLabel = prediction.zip(test.map(_.label))
+val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
 val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
 {% endhighlight %}
 </div>
@@ -71,6 +70,7 @@ can be used for evaluation and prediction.
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.classification.NaiveBayes;
 import org.apache.spark.mllib.classification.NaiveBayesModel;
 import org.apache.spark.mllib.regression.LabeledPoint;
@@ -81,18 +81,12 @@ JavaRDD<LabeledPoint> test = ... // test set
 
 final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
 
-JavaRDD<Double> prediction =
-  test.map(new Function<LabeledPoint, Double>() {
-    @Override public Double call(LabeledPoint p) {
-      return model.predict(p.features());
-    }
-  });
 JavaPairRDD<Double, Double> predictionAndLabel =
-  prediction.zip(test.map(new Function<LabeledPoint, Double>() {
-    @Override public Double call(LabeledPoint p) {
-      return p.label();
+  test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
-  }));
+  });
 double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
     @Override public Boolean call(Tuple2<Double, Double> pl) {
       return pl._1() == pl._2();
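
For reference, the updated Scala evaluation step reads as follows when shown in context. This is only a minimal sketch: the `model`, `predictionAndLabel`, `accuracy`, and split lines are taken from the hunk above, while the SparkContext `sc`, the sample file path, and the comma/space parsing are assumed from the surrounding page and are illustrative rather than authoritative.

```scala
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Assumed input format: "label,feature1 feature2 ..." per line; path is illustrative.
val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}

// Split the data into a training set (60%) and a test set (40%).
val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)

val model = NaiveBayes.train(training, lambda = 1.0)

// One pass over the test set: pair each prediction with its true label.
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
```

Building the (prediction, label) pairs in a single `map` over the test RDD replaces the earlier two-pass `predict` + `zip` pattern; the Java hunk makes the same change using `mapToPair` with a `PairFunction` that returns a `Tuple2`.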