Skip to content
Snippets Groups Projects
Commit 5cd3e6f6 authored by movelikeriver, committed by Xiangrui Meng
Browse files

[SPARK-13257][IMPROVEMENT] Refine naive Bayes example by checking model after loading it

Refine naive Bayes example by checking model after loading it

Author: movelikeriver <mars.lenjoy@gmail.com>

Closes #11125 from movelikeriver/naive_bayes.
parent 764ca180
No related branches found
No related tags found
No related merge requests found
...@@ -17,9 +17,15 @@ ...@@ -17,9 +17,15 @@
""" """
NaiveBayes Example. NaiveBayes Example.
Usage:
`spark-submit --master local[4] examples/src/main/python/mllib/naive_bayes_example.py`
""" """
from __future__ import print_function from __future__ import print_function
import shutil
from pyspark import SparkContext from pyspark import SparkContext
# $example on$ # $example on$
from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel
...@@ -50,8 +56,15 @@ if __name__ == "__main__": ...@@ -50,8 +56,15 @@ if __name__ == "__main__":
# Make prediction and test accuracy. # Make prediction and test accuracy.
predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label)) predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label))
accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
print('model accuracy {}'.format(accuracy))
# Save and load model # Save and load model
model.save(sc, "target/tmp/myNaiveBayesModel") output_dir = 'target/tmp/myNaiveBayesModel'
sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel") shutil.rmtree(output_dir, ignore_errors=True)
model.save(sc, output_dir)
sameModel = NaiveBayesModel.load(sc, output_dir)
predictionAndLabel = test.map(lambda p: (sameModel.predict(p.features), p.label))
accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
print('sameModel accuracy {}'.format(accuracy))
# $example off$ # $example off$
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment