Skip to content
Snippets Groups Projects
Commit 740b034f authored by Sean Owen's avatar Sean Owen
Browse files

[SPARK-4362] [MLLIB] Make prediction probability available in NaiveBayesModel

Add predictProbabilities to Naive Bayes, return class probabilities.

Continues https://github.com/apache/spark/pull/6761

Author: Sean Owen <sowen@cloudera.com>

Closes #7376 from srowen/SPARK-4362 and squashes the following commits:

23d5a76 [Sean Owen] Fix model.labels -> model.theta
95d91fb [Sean Owen] Check that predicted probabilities sum to 1
b32d1c8 [Sean Owen] Add predictProbabilities to Naive Bayes, return class probabilities
parent 4b5cfc98
No related branches found
No related tags found
No related merge requests found
...@@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] ( ...@@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = { override def predict(testData: Vector): Double = {
modelType match { modelType match {
case Multinomial => case Multinomial =>
val prob = thetaMatrix.multiply(testData) labels(multinomialCalculation(testData).argmax)
BLAS.axpy(1.0, piVector, prob)
labels(prob.argmax)
case Bernoulli => case Bernoulli =>
testData.foreachActive { (index, value) => labels(bernoulliCalculation(testData).argmax)
if (value != 0.0 && value != 1.0) { }
throw new SparkException( }
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
} /**
} * Predict values for the given data set using the model trained.
val prob = thetaMinusNegTheta.get.multiply(testData) *
BLAS.axpy(1.0, piVector, prob) * @param testData RDD representing data points to be predicted
BLAS.axpy(1.0, negThetaSum.get, prob) * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
labels(prob.argmax) * in the same order as class labels
case _ => */
// This should never happen. def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
throw new UnknownError(s"Invalid modelType: $modelType.") val bcModel = testData.context.broadcast(this)
testData.mapPartitions { iter =>
val model = bcModel.value
iter.map(model.predictProbabilities)
} }
} }
/**
* Predict posterior class probabilities for a single data point using the model trained.
*
* @param testData array representing a single data point
* @return predicted posterior class probabilities from the trained model,
* in the same order as class labels
*/
def predictProbabilities(testData: Vector): Vector = {
modelType match {
case Multinomial =>
posteriorProbabilities(multinomialCalculation(testData))
case Bernoulli =>
posteriorProbabilities(bernoulliCalculation(testData))
}
}
private def multinomialCalculation(testData: Vector) = {
val prob = thetaMatrix.multiply(testData)
BLAS.axpy(1.0, piVector, prob)
prob
}
private def bernoulliCalculation(testData: Vector) = {
testData.foreachActive((_, value) =>
if (value != 0.0 && value != 1.0) {
throw new SparkException(
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
}
)
val prob = thetaMinusNegTheta.get.multiply(testData)
BLAS.axpy(1.0, piVector, prob)
BLAS.axpy(1.0, negThetaSum.get, prob)
prob
}
private def posteriorProbabilities(logProb: DenseVector) = {
val logProbArray = logProb.toArray
val maxLog = logProbArray.max
val scaledProbs = logProbArray.map(lp => math.exp(lp - maxLog))
val probSum = scaledProbs.sum
new DenseVector(scaledProbs.map(_ / probSum))
}
override def save(sc: SparkContext, path: String): Unit = { override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
......
...@@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification ...@@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification
import scala.util.Random import scala.util.Random
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial} import breeze.stats.distributions.{Multinomial => BrzMultinomial}
import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
object NaiveBayesSuite { object NaiveBayesSuite {
...@@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test prediction on Array. // Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData) validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
// Test posteriors
validationData.map(_.features).foreach { features =>
val predicted = model.predictProbabilities(features).toArray
assert(predicted.sum ~== 1.0 relTol 1.0e-10)
val expected = expectedMultinomialProbabilities(model, features)
expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
}
}
/**
* @param model Multinomial Naive Bayes model
* @param testData input to compute posterior probabilities for
* @return posterior class probabilities (in order of labels) for input
*/
private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
val piVector = new BDV(model.pi)
// model.theta is row-major; treat it as col-major representation of transpose, and transpose:
val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze)
val classProbs = logClassProbs.toArray.map(math.exp)
val classProbsSum = classProbs.sum
classProbs.map(_ / classProbsSum)
} }
test("Naive Bayes Bernoulli") { test("Naive Bayes Bernoulli") {
...@@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test prediction on Array. // Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData) validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
// Test posteriors
validationData.map(_.features).foreach { features =>
val predicted = model.predictProbabilities(features).toArray
assert(predicted.sum ~== 1.0 relTol 1.0e-10)
val expected = expectedBernoulliProbabilities(model, features)
expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
}
}
/**
* @param model Bernoulli Naive Bayes model
* @param testData input to compute posterior probabilities for
* @return posterior class probabilities (in order of labels) for input
*/
private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
val piVector = new BDV(model.pi)
val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t
val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length,
model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t
val testBreeze = testData.toBreeze
val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze
val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze)
val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze)
val classProbs = logClassProbs.toArray.map(math.exp)
val classProbsSum = classProbs.sum
classProbs.map(_ / classProbsSum)
} }
test("detect negative values") { test("detect negative values") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment