From 69b62f76fced18efa35a107c9be4bc22eba72878 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Thu, 30 Jul 2015 23:03:48 -0700
Subject: [PATCH] [SPARK-9214] [ML] [PySpark] support ml.NaiveBayes for Python

support ml.NaiveBayes for Python

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7568 from yanboliang/spark-9214 and squashes the following commits:

5ee3fd6 [Yanbo Liang] fix typos
3ecd046 [Yanbo Liang] fix typos
f9c94d1 [Yanbo Liang] change lambda_ to smoothing and fix other issues
180452a [Yanbo Liang] fix typos
7dda1f4 [Yanbo Liang] support ml.NaiveBayes for Python
---
 .../spark/ml/classification/NaiveBayes.scala  |  10 +-
 .../classification/JavaNaiveBayesSuite.java   |   4 +-
 .../ml/classification/NaiveBayesSuite.scala   |   6 +-
 python/pyspark/ml/classification.py           | 116 +++++++++++++++++-
 4 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 1f547e4a98..5be35fe209 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
    * (default = 1.0).
    * @group param
    */
-  final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.",
+  final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.",
     ParamValidators.gtEq(0))
 
   /** @group getParam */
-  final def getLambda: Double = $(lambda)
+  final def getSmoothing: Double = $(smoothing)
 
   /**
    * The model type which is a string (case-sensitive).
@@ -79,8 +79,8 @@ class NaiveBayes(override val uid: String)
    * Default is 1.0.
    * @group setParam
    */
-  def setLambda(value: Double): this.type = set(lambda, value)
-  setDefault(lambda -> 1.0)
+  def setSmoothing(value: Double): this.type = set(smoothing, value)
+  setDefault(smoothing -> 1.0)
 
   /**
    * Set the model type using a string (case-sensitive).
@@ -92,7 +92,7 @@ class NaiveBayes(override val uid: String)
 
   override protected def train(dataset: DataFrame): NaiveBayesModel = {
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
-    val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType))
+    val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
     NaiveBayesModel.fromOld(oldModel, this)
   }
 
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 09a9fba0c1..a700c9cddb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -68,7 +68,7 @@ public class JavaNaiveBayesSuite implements Serializable {
     assert(nb.getLabelCol() == "label");
     assert(nb.getFeaturesCol() == "features");
     assert(nb.getPredictionCol() == "prediction");
-    assert(nb.getLambda() == 1.0);
+    assert(nb.getSmoothing() == 1.0);
     assert(nb.getModelType() == "multinomial");
   }
 
@@ -89,7 +89,7 @@ public class JavaNaiveBayesSuite implements Serializable {
     });
 
     DataFrame dataset = jsql.createDataFrame(jrdd, schema);
-    NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial");
+    NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
     NaiveBayesModel model = nb.fit(dataset);
 
     DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 76381a2741..264bde3703 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -58,7 +58,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(nb.getLabelCol === "label")
     assert(nb.getFeaturesCol === "features")
     assert(nb.getPredictionCol === "prediction")
-    assert(nb.getLambda === 1.0)
+    assert(nb.getSmoothing === 1.0)
     assert(nb.getModelType === "multinomial")
   }
 
@@ -75,7 +75,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
       piArray, thetaArray, nPoints, 42, "multinomial"))
-    val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial")
+    val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
     val model = nb.fit(testDataset)
 
     validateModelFit(pi, theta, model)
@@ -101,7 +101,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
       piArray, thetaArray, nPoints, 45, "bernoulli"))
-    val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli")
+    val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
     val model = nb.fit(testDataset)
 
     validateModelFit(pi, theta, model)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 5a82bc286d..93ffcd4094 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -25,7 +25,8 @@ from pyspark.mllib.common import inherit_doc
 
 __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier',
            'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel',
-           'RandomForestClassifier', 'RandomForestClassificationModel']
+           'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes',
+           'NaiveBayesModel']
 
 
 @inherit_doc
@@ -576,6 +577,119 @@ class GBTClassificationModel(TreeEnsembleModels):
     """
 
 
+@inherit_doc
+class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol):
+    """
+    Naive Bayes Classifiers.
+
+    >>> from pyspark.sql import Row
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame([
+    ...     Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
+    ...     Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
+    ...     Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
+    >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
+    >>> model = nb.fit(df)
+    >>> model.pi
+    DenseVector([-0.51..., -0.91...])
+    >>> model.theta
+    DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1)
+    >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
+    >>> model.transform(test0).head().prediction
+    1.0
+    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
+    >>> model.transform(test1).head().prediction
+    1.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
+                      "default is 1.0")
+    modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
+                      "(case-sensitive). Supported options: multinomial (default) and bernoulli.")
+
+    @keyword_only
+    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 smoothing=1.0, modelType="multinomial"):
+        """
+        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 smoothing=1.0, modelType="multinomial")
+        """
+        super(NaiveBayes, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.classification.NaiveBayes", self.uid)
+        #: param for the smoothing parameter.
+        self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " +
+                               "default is 1.0")
+        #: param for the model type.
+        self.modelType = Param(self, "modelType", "The model type which is a string " +
+                               "(case-sensitive). Supported options: multinomial (default) " +
+                               "and bernoulli.")
+        self._setDefault(smoothing=1.0, modelType="multinomial")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  smoothing=1.0, modelType="multinomial"):
+        """
+        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  smoothing=1.0, modelType="multinomial")
+        Sets params for Naive Bayes.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    def _create_model(self, java_model):
+        return NaiveBayesModel(java_model)
+
+    def setSmoothing(self, value):
+        """
+        Sets the value of :py:attr:`smoothing`.
+        """
+        self._paramMap[self.smoothing] = value
+        return self
+
+    def getSmoothing(self):
+        """
+        Gets the value of smoothing or its default value.
+        """
+        return self.getOrDefault(self.smoothing)
+
+    def setModelType(self, value):
+        """
+        Sets the value of :py:attr:`modelType`.
+        """
+        self._paramMap[self.modelType] = value
+        return self
+
+    def getModelType(self):
+        """
+        Gets the value of modelType or its default value.
+        """
+        return self.getOrDefault(self.modelType)
+
+
+class NaiveBayesModel(JavaModel):
+    """
+    Model fitted by NaiveBayes.
+    """
+
+    @property
+    def pi(self):
+        """
+        log of class priors.
+        """
+        return self._call_java("pi")
+
+    @property
+    def theta(self):
+        """
+        log of class conditional probabilities.
+        """
+        return self._call_java("theta")
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
-- 
GitLab