diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ea60fab029582363d94f5599e9ecc2e4b3113570..3f763a10d4066d623e1dc085069be1acd34b0f6e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -981,7 +981,7 @@ class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWrita
 
 @inherit_doc
 class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
-                 HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable):
+                 HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable):
     """
     Naive Bayes Classifiers.
     It supports both Multinomial and Bernoulli NB. `Multinomial NB
@@ -995,23 +995,23 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
     >>> from pyspark.sql import Row
     >>> from pyspark.ml.linalg import Vectors
     >>> df = spark.createDataFrame([
-    ...     Row(label=0.0, features=Vectors.dense([0.0, 0.0])),
-    ...     Row(label=0.0, features=Vectors.dense([0.0, 1.0])),
-    ...     Row(label=1.0, features=Vectors.dense([1.0, 0.0]))])
-    >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
+    ...     Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+    ...     Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+    ...     Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])
+    >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
     >>> model = nb.fit(df)
     >>> model.pi
-    DenseVector([-0.51..., -0.91...])
+    DenseVector([-0.81..., -0.58...])
     >>> model.theta
-    DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1)
+    DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)
     >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()
     >>> result = model.transform(test0).head()
     >>> result.prediction
     1.0
     >>> result.probability
-    DenseVector([0.42..., 0.57...])
+    DenseVector([0.32..., 0.67...])
     >>> result.rawPrediction
-    DenseVector([-1.60..., -1.32...])
+    DenseVector([-1.72..., -0.99...])
     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
     >>> model.transform(test1).head().prediction
     1.0
@@ -1045,11 +1045,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
-                 modelType="multinomial", thresholds=None):
+                 modelType="multinomial", thresholds=None, weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
-                 modelType="multinomial", thresholds=None)
+                 modelType="multinomial", thresholds=None, weightCol=None)
         """
         super(NaiveBayes, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -1062,11 +1062,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
     @since("1.5.0")
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0,
-                  modelType="multinomial", thresholds=None):
+                  modelType="multinomial", thresholds=None, weightCol=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \
-                  modelType="multinomial", thresholds=None)
+                  modelType="multinomial", thresholds=None, weightCol=None)
         Sets params for Naive Bayes.
         """
         kwargs = self.setParams._input_kwargs