Skip to content
Snippets Groups Projects
Commit d5f18de1 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7612] [MLLIB] update NB training to use mllib's BLAS

This is similar to the changes to k-means, which gives us better control on the performance. dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #6128 from mengxr/SPARK-7612 and squashes the following commits:

b5c24c5 [Xiangrui Meng] merge master
a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS
parent 3113da9c
No related branches found
No related tags found
No related merge requests found
...@@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable} ...@@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable}
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
import breeze.numerics.{exp => brzExp, log => brzLog} import breeze.numerics.{exp => brzExp, log => brzLog}
import org.json4s.JsonDSL._ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JValue}
import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
...@@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] ( ...@@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] (
val brzData = testData.toBreeze val brzData = testData.toBreeze
modelType match { modelType match {
case "Multinomial" => case "Multinomial" =>
labels (brzArgmax (brzPi + brzTheta * brzData) ) labels(brzArgmax(brzPi + brzTheta * brzData))
case "Bernoulli" => case "Bernoulli" =>
if (!brzData.forall(v => v == 0.0 || v == 1.0)) { if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException( throw new SparkException(
s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
} }
labels (brzArgmax (brzPi + labels(brzArgmax(brzPi +
(brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ => case _ =>
// This should never happen. // This should never happen.
...@@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { ...@@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Check schema explicitly since erasure makes it hard to use match-case for checking. // Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema) checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0) val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray val pi = data.getAs[Seq[Double]](1).toArray
...@@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { ...@@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
// Check schema explicitly since erasure makes it hard to use match-case for checking. // Check schema explicitly since erasure makes it hard to use match-case for checking.
checkSchema[Data](dataRDD.schema) checkSchema[Data](dataRDD.schema)
val dataArray = dataRDD.select("labels", "pi", "theta").take(1) val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
val data = dataArray(0) val data = dataArray(0)
val labels = data.getAs[Seq[Double]](0).toArray val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray val pi = data.getAs[Seq[Double]](1).toArray
...@@ -288,10 +286,8 @@ class NaiveBayes private ( ...@@ -288,10 +286,8 @@ class NaiveBayes private (
def run(data: RDD[LabeledPoint]): NaiveBayesModel = { def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match { val values = v match {
case SparseVector(size, indices, values) => case sv: SparseVector => sv.values
values case dv: DenseVector => dv.values
case DenseVector(values) =>
values
} }
if (!values.forall(_ >= 0.0)) { if (!values.forall(_ >= 0.0)) {
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
...@@ -300,10 +296,8 @@ class NaiveBayes private ( ...@@ -300,10 +296,8 @@ class NaiveBayes private (
val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
val values = v match { val values = v match {
case SparseVector(size, indices, values) => case sv: SparseVector => sv.values
values case dv: DenseVector => dv.values
case DenseVector(values) =>
values
} }
if (!values.forall(v => v == 0.0 || v == 1.0)) { if (!values.forall(v => v == 0.0 || v == 1.0)) {
throw new SparkException( throw new SparkException(
...@@ -314,21 +308,24 @@ class NaiveBayes private ( ...@@ -314,21 +308,24 @@ class NaiveBayes private (
// Aggregates term frequencies per label. // Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage. // TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
createCombiner = (v: Vector) => { createCombiner = (v: Vector) => {
if (modelType == "Bernoulli") { if (modelType == "Bernoulli") {
requireZeroOneBernoulliValues(v) requireZeroOneBernoulliValues(v)
} else { } else {
requireNonnegativeValues(v) requireNonnegativeValues(v)
} }
(1L, v.toBreeze.toDenseVector) (1L, v.copy.toDense)
}, },
mergeValue = (c: (Long, BDV[Double]), v: Vector) => { mergeValue = (c: (Long, DenseVector), v: Vector) => {
requireNonnegativeValues(v) requireNonnegativeValues(v)
(c._1 + 1L, c._2 += v.toBreeze) BLAS.axpy(1.0, v, c._2)
(c._1 + 1L, c._2)
}, },
mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
(c1._1 + c2._1, c1._2 += c2._2) BLAS.axpy(1.0, c2._2, c1._2)
(c1._1 + c2._1, c1._2)
}
).collect() ).collect()
val numLabels = aggregated.length val numLabels = aggregated.length
...@@ -348,7 +345,7 @@ class NaiveBayes private ( ...@@ -348,7 +345,7 @@ class NaiveBayes private (
labels(i) = label labels(i) = label
pi(i) = math.log(n + lambda) - piLogDenom pi(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = modelType match { val thetaLogDenom = modelType match {
case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
case "Bernoulli" => math.log(n + 2.0 * lambda) case "Bernoulli" => math.log(n + 2.0 * lambda)
case _ => case _ =>
// This should never happen. // This should never happen.
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment