Skip to content
Snippets Groups Projects
Commit 61e05fc5 authored by leahmcguire, committed by Joseph K. Bradley
Browse files

[SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that...

[SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that both training and predict features have values of 0 or 1

Author: leahmcguire <lmcguire@salesforce.com>

Closes #6073 from leahmcguire/binaryCheckNB and squashes the following commits:

b8442c2 [leahmcguire] changed to if else for value checks
911bf83 [leahmcguire] undid reformat
4eedf1e [leahmcguire] moved bernoulli check
9ee9e84 [leahmcguire] fixed style error
3f3b32c [leahmcguire] fixed zero one check so only called in combiner
831fd27 [leahmcguire] got test working
f44bb3c [leahmcguire] removed changes from CV branch
67253f0 [leahmcguire] added check to bernoulli to ensure feature values are zero or one
f191c71 [leahmcguire] fixed name
58d060b [leahmcguire] changed param name and test according to comments
04f0d3c [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
parent 5db18ba6
No related branches found
No related tags found
No related merge requests found
......@@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] (
}
override def predict(testData: Vector): Double = {
val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
labels (brzArgmax (brzPi + brzTheta * brzData) )
case "Bernoulli" =>
if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
}
labels (brzArgmax (brzPi +
(brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
......@@ -293,12 +298,29 @@ class NaiveBayes private (
}
}
// Validates a feature vector for the Bernoulli model: every entry of the vector's
// underlying `values` array must be exactly 0.0 or 1.0. For a sparse vector only the
// explicitly stored values are inspected. Throws SparkException on the first violation found.
val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
  val storedValues = v match {
    case SparseVector(_, _, sparseValues) => sparseValues
    case DenseVector(denseValues) => denseValues
  }
  // A value is invalid when it is neither 0.0 nor 1.0 (NaN is therefore rejected as well).
  val hasInvalidValue = storedValues.exists(x => x != 0.0 && x != 1.0)
  if (hasInvalidValue) {
    throw new SparkException(
      s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
  }
}
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => {
requireNonnegativeValues(v)
if (modelType == "Bernoulli") {
requireZeroOneBernoulliValues(v)
} else {
requireNonnegativeValues(v)
}
(1L, v.toBreeze.toDenseVector)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
......
......@@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
}
}
test("detect non zero or one values in Bernoulli") {
  // Builds single-feature labeled points from (label, feature) pairs, preserving order.
  def points(pairs: (Double, Double)*): Seq[LabeledPoint] =
    pairs.map { case (label, feature) => LabeledPoint(label, Vectors.dense(feature)) }

  // Training must fail when any feature value is something other than 0 or 1 (here 2.0).
  val invalidTraining = points((1.0, 1.0), (0.0, 2.0), (1.0, 1.0), (1.0, 0.0))
  intercept[SparkException] {
    NaiveBayes.train(sc.makeRDD(invalidTraining, 2), 1.0, "Bernoulli")
  }

  // A model trained on valid 0/1 data must still reject a 2.0 feature at prediction time.
  val validTraining = points(
    (1.0, 1.0), (0.0, 0.0), (1.0, 1.0), (1.0, 1.0), (0.0, 0.0), (1.0, 1.0), (1.0, 1.0))
  val invalidQueries = Seq(1.0, 2.0, 1.0, 0.0).map(x => Vectors.dense(x))
  val model = NaiveBayes.train(sc.makeRDD(validTraining, 2), 1.0, "Bernoulli")
  intercept[SparkException] {
    // collect() forces evaluation so the per-record check actually runs.
    model.predict(sc.makeRDD(invalidQueries, 2)).collect()
  }
}
test("model save/load: 2.0 to 2.0") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment