Skip to content
Snippets Groups Projects
Commit fbef566a authored by Yanbo Liang's avatar Yanbo Liang Committed by Joseph K. Bradley
Browse files

[SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities

Make NaiveBayesModel support predicting class probabilities by inheriting from ProbabilisticClassificationModel.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7672 from yanboliang/spark-9308 and squashes the following commits:

25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place
3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace
c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities
parent 060c79aa
No related branches found
No related tags found
No related merge requests found
......@@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* The input feature values must be nonnegative.
*/
class NaiveBayes(override val uid: String)
extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {
def this() = this(Identifiable.randomUID("nb"))
......@@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
import OldNaiveBayes.{Bernoulli, Multinomial}
......@@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] (
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
override protected def predict(features: Vector): Double = {
override val numClasses: Int = pi.size
/**
 * Multinomial scoring: multiplies `theta` by `features` and then adds `pi` to the
 * product through `BLAS.axpy` with a scale of 1.0.
 *
 * @param features input feature vector
 * @return the freshly computed product vector after `pi` has been accumulated into it
 */
private def multinomialCalculation(features: Vector) = {
  val scores = theta.multiply(features)
  // Accumulate pi directly into the product rather than allocating another vector.
  BLAS.axpy(1.0, pi, scores)
  scores
}
/**
 * Bernoulli scoring: validates that every active feature value is exactly 0 or 1,
 * then multiplies `thetaMinusNegTheta` by `features` and accumulates `pi` and
 * `negThetaSum` into the product through `BLAS.axpy` with a scale of 1.0.
 *
 * @param features input feature vector; active entries must be 0.0 or 1.0
 * @return the product vector after `pi` and `negThetaSum` have been accumulated into it
 * @throws SparkException if any active entry is neither 0.0 nor 1.0
 */
private def bernoulliCalculation(features: Vector) = {
  // Validate first so that no scoring work happens on malformed input.
  features.foreachActive { (_, entry) =>
    val isBinary = entry == 0.0 || entry == 1.0
    if (!isBinary) {
      throw new SparkException(
        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
    }
  }
  val scores = thetaMinusNegTheta.get.multiply(features)
  BLAS.axpy(1.0, pi, scores)
  BLAS.axpy(1.0, negThetaSum.get, scores)
  scores
}
/**
 * Builds the raw prediction vector for `features` by dispatching on `$(modelType)`.
 *
 * NOTE(review): this listing appears to interleave lines the commit removed (the inline
 * computations that end in `prob.argmax`) with the lines it added. As written here, the
 * value of each branch is the trailing helper call. Confirm against the committed file.
 *
 * @param features input feature vector
 * @return the result of `multinomialCalculation` or `bernoulliCalculation`
 * @throws SparkException in the Bernoulli branch when an active value is neither 0 nor 1
 * @throws UnknownError if the model type matches neither `Multinomial` nor `Bernoulli`
 */
override protected def predictRaw(features: Vector): Vector = {
$(modelType) match {
case Multinomial =>
val prob = theta.multiply(features)
BLAS.axpy(1.0, pi, prob)
prob.argmax
// The helper call below is the last expression of this branch.
multinomialCalculation(features)
case Bernoulli =>
features.foreachActive{ (index, value) =>
if (value != 0.0 && value != 1.0) {
throw new SparkException(
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
}
}
val prob = thetaMinusNegTheta.get.multiply(features)
BLAS.axpy(1.0, pi, prob)
BLAS.axpy(1.0, negThetaSum.get, prob)
prob.argmax
// The helper call below is the last expression of this branch.
bernoulliCalculation(features)
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
}
/**
 * Converts a raw prediction vector into normalised probabilities, overwriting the
 * vector's own backing array.
 *
 * Each entry is replaced by `exp(entry - max)` and then divided by the sum of those
 * exponentials. Subtracting the maximum keeps every exponent argument at or below
 * zero, so `math.exp` cannot overflow.
 *
 * @param rawPrediction raw prediction vector; must be a `DenseVector`
 * @return the same `DenseVector` instance, now holding the normalised values
 * @throws RuntimeException if `rawPrediction` is a `SparseVector`
 */
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
  rawPrediction match {
    case dense: DenseVector =>
      val n = dense.size
      val peak = dense.values.max

      // Pass 1: shift by the maximum and exponentiate in place.
      var k = 0
      while (k < n) {
        dense.values(k) = math.exp(dense.values(k) - peak)
        k += 1
      }

      // Pass 2: divide every entry by the total so the entries sum to one.
      val total = dense.values.sum
      k = 0
      while (k < n) {
        dense.values(k) = dense.values(k) / total
        k += 1
      }
      dense

    case _: SparseVector =>
      throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
        " raw2probabilityInPlace encountered SparseVector")
  }
}
/**
 * Creates a new `NaiveBayesModel` with the same `uid`, `pi` and `theta`, gives it this
 * model's parent, and passes it to `copyValues` together with `extra`.
 *
 * @param extra extra parameters handed to `copyValues`
 * @return the value returned by `copyValues`
 */
override def copy(extra: ParamMap): NaiveBayesModel = {
  val duplicate = new NaiveBayesModel(uid, pi, theta).setParent(this.parent)
  copyValues(duplicate, extra)
}
......
......@@ -17,8 +17,11 @@
package org.apache.spark.ml.classification
import breeze.linalg.{Vector => BV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
......@@ -28,6 +31,8 @@ import org.apache.spark.sql.Row
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
import NaiveBayes.{Multinomial, Bernoulli}
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
case Row(prediction: Double, label: Double) =>
......@@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
}
/**
 * Reference computation of multinomial class probabilities for one feature vector:
 * exponentiates `pi + theta * feature` element-wise and divides by the sum.
 *
 * @param model   model supplying `pi` and `theta`
 * @param feature feature vector to score
 * @return dense vector of the normalised values
 */
def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
  val logScores: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
  val unnormalized = logScores.toArray.map(math.exp)
  val normalizer = unnormalized.sum
  Vectors.dense(unnormalized.map(_ / normalizer))
}
/**
 * Reference computation of Bernoulli class probabilities for one feature vector.
 *
 * Sums `pi + theta * feature` with `log(1 - exp(theta)) * (1 - feature)`, exponentiates
 * the result element-wise and divides by the sum.
 *
 * @param model   model supplying `pi` and `theta`
 * @param feature feature vector to score
 * @return dense vector of the normalised values
 */
def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
  // Complement terms: log(1 - exp(theta)) pairs with (1 - feature).
  val complementTheta = model.theta.map(t => math.log(1.0 - math.exp(t)))
  val complementFeature = Vectors.dense(feature.toArray.map(x => 1.0 - x))

  val presentTerm: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
  val logScores: BV[Double] =
    presentTerm + complementTheta.multiply(complementFeature).toBreeze

  val unnormalized = logScores.toArray.map(math.exp)
  val normalizer = unnormalized.sum
  Vectors.dense(unnormalized.map(_ / normalizer))
}
/**
 * Checks every collected (features, probability) row: the probability vector must sum
 * to 1.0 and must match the reference computation for the given model type, both
 * within a relative tolerance of 1.0e-10.
 *
 * @param featureAndProbabilities rows of a features vector and a probability vector
 * @param model                   model passed to the reference computations
 * @param modelType               selects the multinomial or Bernoulli reference
 * @throws UnknownError if `modelType` matches neither `Multinomial` nor `Bernoulli`
 */
def validateProbabilities(
    featureAndProbabilities: DataFrame,
    model: NaiveBayesModel,
    modelType: String): Unit = {
  // Picks the reference implementation that corresponds to the requested model type.
  def referenceFor(features: Vector): Vector = modelType match {
    case Multinomial => expectedMultinomialProbabilities(model, features)
    case Bernoulli => expectedBernoulliProbabilities(model, features)
    case _ => throw new UnknownError(s"Invalid modelType: $modelType.")
  }

  val rows = featureAndProbabilities.collect()
  rows.foreach {
    case Row(features: Vector, probability: Vector) =>
      assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
      val expected = referenceFor(features)
      assert(probability ~== expected relTol 1.0e-10)
  }
}
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
......@@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
val featureAndProbabilities = model.transform(validationDataset)
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "multinomial")
}
test("Naive Bayes Bernoulli") {
......@@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
val featureAndProbabilities = model.transform(validationDataset)
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment