Skip to content
Snippets Groups Projects
Commit b2ba83d1 authored by Yanbo Liang
Browse files

[SPARK-14077][ML][FOLLOW-UP] Minor refactor and cleanup for NaiveBayes


## What changes were proposed in this pull request?
* Refactor out ```trainWithLabelCheck``` and make ```mllib.NaiveBayes``` call into it.
* Avoid capturing the outer object for ```modelType```.
* Move ```requireNonnegativeValues``` and ```requireZeroOneBernoulliValues``` to companion object.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15826 from yanboliang/spark-14077-2.

(cherry picked from commit 22cb3a06)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
parent 89335514
No related branches found
No related tags found
No related merge requests found
......@@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") (
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams with DefaultParamsWritable {
import NaiveBayes.{Bernoulli, Multinomial}
import NaiveBayes._
@Since("1.5.0")
def this() = this(Identifiable.randomUID("nb"))
......@@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") (
@Since("2.1.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
/**
 * Trains a model on `dataset` by delegating to `trainWithLabelCheck`
 * with `positiveLabel` set to true.
 */
override protected def train(dataset: Dataset[_]): NaiveBayesModel =
  trainWithLabelCheck(dataset, positiveLabel = true)
/**
* ml assumes input labels in range [0, numClasses). But this implementation
* is also called by mllib NaiveBayes which allows other kinds of input labels
* such as {-1, +1}. Here we use this parameter to switch between different processing logic.
* It should be removed when we remove mllib NaiveBayes.
* such as {-1, +1}. `positiveLabel` determines whether the label should be checked;
* this parameter should be removed when we remove mllib NaiveBayes.
*/
private[spark] var isML: Boolean = true
private[spark] def setIsML(isML: Boolean): this.type = {
this.isML = isML
this
}
override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
if (isML) {
private[spark] def trainWithLabelCheck(
dataset: Dataset[_],
positiveLabel: Boolean): NaiveBayesModel = {
if (positiveLabel) {
val numClasses = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
......@@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") (
}
}
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case sv: SparseVector => sv.values
case dv: DenseVector => dv.values
}
require(values.forall(_ >= 0.0),
s"Naive Bayes requires nonnegative feature values but found $v.")
}
val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
val values = v match {
case sv: SparseVector => sv.values
case dv: DenseVector => dv.values
}
require(values.forall(v => v == 0.0 || v == 1.0),
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
}
val modelTypeValue = $(modelType)
val requireValues: Vector => Unit = {
$(modelType) match {
modelTypeValue match {
case Multinomial =>
requireNonnegativeValues
case Bernoulli =>
......@@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") (
@Since("1.6.0")
object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/** String name for multinomial model type. */
private[spark] val Multinomial: String = "multinomial"
private[classification] val Multinomial: String = "multinomial"
/** String name for Bernoulli model type. */
private[spark] val Bernoulli: String = "bernoulli"
private[classification] val Bernoulli: String = "bernoulli"
/* Set of modelTypes that NaiveBayes supports */
private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli)
private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
/**
 * Checks that every element of the vector's `values` array is nonnegative.
 *
 * @param v feature vector to validate
 * @throws IllegalArgumentException (raised by `require`) if some element fails `>= 0.0`
 */
private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
  val stored = v match {
    case dense: DenseVector => dense.values
    case sparse: SparseVector => sparse.values
  }
  // Keep the `>= 0.0` form: a NaN entry fails this comparison and is therefore rejected.
  val allNonnegative = stored.forall(entry => entry >= 0.0)
  require(allNonnegative,
    s"Naive Bayes requires nonnegative feature values but found $v.")
}
/**
 * Checks that every element of the vector's `values` array is exactly 0.0 or 1.0.
 *
 * @param v feature vector to validate
 * @throws IllegalArgumentException (raised by `require`) if some element is neither 0.0 nor 1.0
 */
private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = {
  val stored = v match {
    case dense: DenseVector => dense.values
    case sparse: SparseVector => sparse.values
  }
  // The element is named `entry` so it does not shadow the vector `v` used in the message.
  val allZeroOrOne = stored.forall(entry => entry == 0.0 || entry == 1.0)
  require(allZeroOrOne,
    s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
}
/**
 * Loads a `NaiveBayes` instance from `path` by delegating to the inherited `load`.
 *
 * @param path location to read from
 * @return the loaded `NaiveBayes` instance
 */
@Since("1.6.0")
override def load(path: String): NaiveBayes = super.load(path)
......
......@@ -364,12 +364,12 @@ class NaiveBayes private (
val nb = new NewNaiveBayes()
.setModelType(modelType)
.setSmoothing(lambda)
.setIsML(false)
val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) }
.toDF("label", "features")
val newModel = nb.fit(dataset)
// mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false.
val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false)
val pi = newModel.pi.toArray
val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0)
......@@ -378,7 +378,7 @@ class NaiveBayes private (
theta(i)(j) = v
}
require(newModel.oldLabels != null,
assert(newModel.oldLabels != null,
"The underlying ML NaiveBayes training does not produce labels.")
new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment