diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index b03a07a6bc1e79b33dc5665cd292361096670ad1..f1a7676c74b0ebbcdbbf6e244e8704a6b072c9e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") (
   extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
   with NaiveBayesParams with DefaultParamsWritable {
 
-  import NaiveBayes.{Bernoulli, Multinomial}
+  import NaiveBayes._
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("nb"))
@@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") (
   @Since("2.1.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
 
+  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
+    trainWithLabelCheck(dataset, positiveLabel = true)
+  }
+
   /**
    * ml assumes input labels in range [0, numClasses). But this implementation
    * is also called by mllib NaiveBayes which allows other kinds of input labels
-   * such as {-1, +1}. Here we use this parameter to switch between different processing logic.
-   * It should be removed when we remove mllib NaiveBayes.
+   * such as {-1, +1}. `positiveLabel` is used to determine whether the label
+   * should be checked; this parameter should be removed along with mllib NaiveBayes.
    */
-  private[spark] var isML: Boolean = true
-
-  private[spark] def setIsML(isML: Boolean): this.type = {
-    this.isML = isML
-    this
-  }
-
-  override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
-    if (isML) {
+  private[spark] def trainWithLabelCheck(
+      dataset: Dataset[_],
+      positiveLabel: Boolean): NaiveBayesModel = {
+    if (positiveLabel) {
       val numClasses = getNumClasses(dataset)
       if (isDefined(thresholds)) {
         require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") (
       }
     }
 
-    val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
-      val values = v match {
-        case sv: SparseVector => sv.values
-        case dv: DenseVector => dv.values
-      }
-
-      require(values.forall(_ >= 0.0),
-        s"Naive Bayes requires nonnegative feature values but found $v.")
-    }
-
-    val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
-      val values = v match {
-        case sv: SparseVector => sv.values
-        case dv: DenseVector => dv.values
-      }
-
-      require(values.forall(v => v == 0.0 || v == 1.0),
-        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
-    }
-
+    val modelTypeValue = $(modelType)
     val requireValues: Vector => Unit = {
-      $(modelType) match {
+      modelTypeValue match {
         case Multinomial =>
           requireNonnegativeValues
         case Bernoulli =>
@@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") (
 @Since("1.6.0")
 object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
   /** String name for multinomial model type. */
-  private[spark] val Multinomial: String = "multinomial"
+  private[classification] val Multinomial: String = "multinomial"
 
   /** String name for Bernoulli model type. */
-  private[spark] val Bernoulli: String = "bernoulli"
+  private[classification] val Bernoulli: String = "bernoulli"
 
   /* Set of modelTypes that NaiveBayes supports */
-  private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
+  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+
+  private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
+    val values = v match {
+      case sv: SparseVector => sv.values
+      case dv: DenseVector => dv.values
+    }
+
+    require(values.forall(_ >= 0.0),
+      s"Naive Bayes requires nonnegative feature values but found $v.")
+  }
+
+  private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = {
+    val values = v match {
+      case sv: SparseVector => sv.values
+      case dv: DenseVector => dv.values
+    }
+
+    require(values.forall(v => v == 0.0 || v == 1.0),
+      s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
+  }
 
   @Since("1.6.0")
   override def load(path: String): NaiveBayes = super.load(path)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 33561be4b5bc1703d487540de395f8d490314aa6..767d056861a8b7bcc6694764eae57cd4cbf590f2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -364,12 +364,12 @@ class NaiveBayes private (
     val nb = new NewNaiveBayes()
       .setModelType(modelType)
       .setSmoothing(lambda)
-      .setIsML(false)
 
     val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
       .toDF("label", "features")
 
-    val newModel = nb.fit(dataset)
+    // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` to false.
+    val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false)
 
     val pi = newModel.pi.toArray
     val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0)
@@ -378,7 +378,7 @@ class NaiveBayes private (
         theta(i)(j) = v
     }
 
-    require(newModel.oldLabels != null,
+    assert(newModel.oldLabels != null,
       "The underlying ML NaiveBayes training does not produce labels.")
     new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
   }