diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 4b3a7cab32118a51660cdde29b74df628083e02f..1d1d7dcf6ffcbeede9fc0153ac0719d42e392a59 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -51,9 +51,8 @@ val training = splits(0)
 val test = splits(1)
 
 val model = NaiveBayes.train(training, lambda = 1.0)
-val prediction = model.predict(test.map(_.features))
 
-val predictionAndLabel = prediction.zip(test.map(_.label))
+val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
 val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
 {% endhighlight %}
 </div>
@@ -71,6 +70,7 @@ can be used for evaluation and prediction.
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.classification.NaiveBayes;
 import org.apache.spark.mllib.classification.NaiveBayesModel;
 import org.apache.spark.mllib.regression.LabeledPoint;
@@ -81,18 +81,12 @@ JavaRDD<LabeledPoint> test = ... // test set
 
 final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
 
-JavaRDD<Double> prediction =
-  test.map(new Function<LabeledPoint, Double>() {
-    @Override public Double call(LabeledPoint p) {
-      return model.predict(p.features());
-    }
-  });
 JavaPairRDD<Double, Double> predictionAndLabel = 
-  prediction.zip(test.map(new Function<LabeledPoint, Double>() {
-    @Override public Double call(LabeledPoint p) {
-      return p.label();
+  test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+    @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+      return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
     }
-  }));
+  });
 double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
     @Override public Boolean call(Tuple2<Double, Double> pl) {
       return pl._1() == pl._2();