From fbef566a107b47e5fddde0ea65b8587d5039062d Mon Sep 17 00:00:00 2001 From: Yanbo Liang <ybliang8@gmail.com> Date: Fri, 31 Jul 2015 13:11:42 -0700 Subject: [PATCH] [SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities Make NaiveBayesModel support predicting class probabilities, inherit from ProbabilisticClassificationModel. Author: Yanbo Liang <ybliang8@gmail.com> Closes #7672 from yanboliang/spark-9308 and squashes the following commits: 25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place 3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities --- .../spark/ml/classification/NaiveBayes.scala | 65 ++++++++++++++----- .../ml/classification/NaiveBayesSuite.scala | 54 ++++++++++++++- 2 files changed, 101 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5be35fe209..b46b676204 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * The input feature values must be nonnegative. */ class NaiveBayes(override val uid: String) - extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { def this() = this(Identifiable.randomUID("nb")) @@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } - override protected def predict(features: Vector): Double = { + override val numClasses: Int = pi.size + + private def multinomialCalculation(features: Vector) = { + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob + } + + private def bernoulliCalculation(features: Vector) = { + features.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + override protected def predictRaw(features: Vector): Vector = { $(modelType) match { case Multinomial => - val prob = theta.multiply(features) - BLAS.axpy(1.0, pi, prob) - prob.argmax + multinomialCalculation(features) case Bernoulli => - features.foreachActive{ (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") - } - } - val prob = thetaMinusNegTheta.get.multiply(features) - BLAS.axpy(1.0, pi, prob) - BLAS.axpy(1.0, negThetaSum.get, prob) - prob.argmax + bernoulliCalculation(features) case _ => // This should never happen. throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } } + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val maxLog = dv.values.max + while (i < size) { + dv.values(i) = math.exp(dv.values(i) - maxLog) + i += 1 + } + val probSum = dv.values.sum + i = 0 + while (i < size) { + dv.values(i) = dv.values(i) / probSum + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in NaiveBayesModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 264bde3703..aea3d9b694 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.ml.classification +import breeze.linalg.{Vector => BV} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.classification.NaiveBayes import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -28,6 +31,8 @@ import org.apache.spark.sql.Row class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + import NaiveBayes.{Multinomial, Bernoulli} + def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { case Row(prediction: Double, label: Double) => @@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") } + def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) + val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) + val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def validateProbabilities( + featureAndProbabilities: DataFrame, + model: NaiveBayesModel, + modelType: String): Unit = { + featureAndProbabilities.collect().foreach { + case Row(features: Vector, probability: Vector) => { + assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) + val expected = modelType match { + case Multinomial => + expectedMultinomialProbabilities(model, features) + case Bernoulli => + expectedBernoulliProbabilities(model, features) + case _ => + throw new UnknownError(s"Invalid modelType: $modelType.") + } + assert(probability ~== expected relTol 1.0e-10) + } + } + } + test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "multinomial") } test("Naive Bayes Bernoulli") { @@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "bernoulli") } } -- GitLab