Skip to content
Snippets Groups Projects
Commit d5f18de1 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7612] [MLLIB] update NB training to use mllib's BLAS

This is similar to the changes to k-means, which gives us better control over performance. dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #6128 from mengxr/SPARK-7612 and squashes the following commits:

b5c24c5 [Xiangrui Meng] merge master
a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS
parent 3113da9c
No related branches found
No related tags found
No related merge requests found
......@@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.numerics.{exp => brzExp, log => brzLog}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JValue}
import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
......@@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] (
val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
labels (brzArgmax (brzPi + brzTheta * brzData) )
labels(brzArgmax(brzPi + brzTheta * brzData))
case "Bernoulli" =>
if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
}
labels (brzArgmax (brzPi +
labels(brzArgmax(brzPi +
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ =>
// This should never happen.
......@@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
......@@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
......@@ -288,10 +286,8 @@ class NaiveBayes private (
def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
case SparseVector(size, indices, values) =>
values
case DenseVector(values) =>
values
case sv: SparseVector => sv.values
case dv: DenseVector => dv.values
}
if (!values.forall(_ >= 0.0)) {
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
......@@ -300,10 +296,8 @@ class NaiveBayes private (
val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
val values = v match {
case SparseVector(size, indices, values) =>
values
case DenseVector(values) =>
values
case sv: SparseVector => sv.values
case dv: DenseVector => dv.values
}
if (!values.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException(
......@@ -314,21 +308,24 @@ class NaiveBayes private (
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
createCombiner = (v: Vector) => {
if (modelType == "Bernoulli") {
requireZeroOneBernoulliValues(v)
} else {
requireNonnegativeValues(v)
}
(1L, v.toBreeze.toDenseVector)
(1L, v.copy.toDense)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
mergeValue = (c: (Long, DenseVector), v: Vector) => {
requireNonnegativeValues(v)
(c._1 + 1L, c._2 += v.toBreeze)
BLAS.axpy(1.0, v, c._2)
(c._1 + 1L, c._2)
},
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
(c1._1 + c2._1, c1._2 += c2._2)
mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
BLAS.axpy(1.0, c2._2, c1._2)
(c1._1 + c2._1, c1._2)
}
).collect()
val numLabels = aggregated.length
......@@ -348,7 +345,7 @@ class NaiveBayes private (
labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match {
case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
case "Bernoulli" => math.log(n + 2.0 * lambda)
case _ =>
// This should never happen.
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment