Skip to content
Snippets Groups Projects
Commit b5d17ef1 authored by DB Tsai's avatar DB Tsai Committed by Xiangrui Meng
Browse files

[SPARK-4431][MLlib] Implement efficient foreachActive for dense and sparse vector

Previously, we were using Breeze's activeIterator to access the non-zero elements
in dense/sparse vector. Due to the overhead, we switched back to native `while loop`
in #SPARK-4129.

However, #SPARK-4129 requires de-reference the dv.values/sv.values in
each access to the value, which is very expensive. Also, in MultivariateOnlineSummarizer,
we're using Breeze's dense vector to store the partial stats, and this is very expensive compared
with using primitive scala array.

In this PR, efficient foreachActive is implemented to unify the code path for dense and sparse
vector operation which makes codebase easier to maintain. Breeze dense vector is replaced
by primitive array to reduce the overhead further.

Benchmarking with mnist8m dataset on single JVM
with first 200 samples loaded in memory, and repeating 5000 times.

Before change:
Sparse Vector - 30.02
Dense Vector - 38.27

With this PR:
Sparse Vector - 6.29
Dense Vector - 11.72

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #3288 from dbtsai/activeIterator and squashes the following commits:

844b0e6 [DB Tsai] formating
03dd693 [DB Tsai] futher performance tunning.
1907ae1 [DB Tsai] address feedback
98448bb [DB Tsai] Made the override final, and had a local copy of variables which made the accessing a single step operation.
c0cbd5a [DB Tsai] fix a bug
6441f92 [DB Tsai] Finished SPARK-4431
parent ce95bd8e
No related branches found
No related tags found
No related merge requests found
...@@ -76,6 +76,15 @@ sealed trait Vector extends Serializable { ...@@ -76,6 +76,15 @@ sealed trait Vector extends Serializable {
def copy: Vector = { def copy: Vector = {
throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
} }
/**
* Applies a function `f` to all the active elements of dense and sparse vector.
*
* @param f the function takes two parameters where the first parameter is the index of
* the vector with type `Int`, and the second parameter is the corresponding value
* with type `Double`.
*/
private[spark] def foreachActive(f: (Int, Double) => Unit)
} }
/** /**
...@@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector { ...@@ -273,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def copy: DenseVector = { override def copy: DenseVector = {
new DenseVector(values.clone()) new DenseVector(values.clone())
} }
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
val localValuesSize = values.size
val localValues = values
while (i < localValuesSize) {
f(i, localValues(i))
i += 1
}
}
} }
/** /**
...@@ -309,4 +329,16 @@ class SparseVector( ...@@ -309,4 +329,16 @@ class SparseVector(
} }
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
val localValuesSize = values.size
val localIndices = indices
val localValues = values
while (i < localValuesSize) {
f(localIndices(i), localValues(i))
i += 1
}
}
} }
...@@ -17,10 +17,8 @@ ...@@ -17,10 +17,8 @@
package org.apache.spark.mllib.stat package org.apache.spark.mllib.stat
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.annotation.DeveloperApi import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.linalg.{Vectors, Vector}
/** /**
* :: DeveloperApi :: * :: DeveloperApi ::
...@@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector ...@@ -40,37 +38,14 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
private var n = 0 private var n = 0
private var currMean: BDV[Double] = _ private var currMean: Array[Double] = _
private var currM2n: BDV[Double] = _ private var currM2n: Array[Double] = _
private var currM2: BDV[Double] = _ private var currM2: Array[Double] = _
private var currL1: BDV[Double] = _ private var currL1: Array[Double] = _
private var totalCnt: Long = 0 private var totalCnt: Long = 0
private var nnz: BDV[Double] = _ private var nnz: Array[Double] = _
private var currMax: BDV[Double] = _ private var currMax: Array[Double] = _
private var currMin: BDV[Double] = _ private var currMin: Array[Double] = _
/**
* Adds input value to position i.
*/
private[this] def add(i: Int, value: Double) = {
if (value != 0.0) {
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}
val prevMean = currMean(i)
val diff = value - prevMean
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * diff
currM2(i) += value * value
currL1(i) += math.abs(value)
nnz(i) += 1.0
}
}
/** /**
* Add a new sample to this summarizer, and update the statistical summary. * Add a new sample to this summarizer, and update the statistical summary.
...@@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -83,33 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(sample.size > 0, s"Vector should have dimension larger than zero.") require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size n = sample.size
currMean = BDV.zeros[Double](n) currMean = Array.ofDim[Double](n)
currM2n = BDV.zeros[Double](n) currM2n = Array.ofDim[Double](n)
currM2 = BDV.zeros[Double](n) currM2 = Array.ofDim[Double](n)
currL1 = BDV.zeros[Double](n) currL1 = Array.ofDim[Double](n)
nnz = BDV.zeros[Double](n) nnz = Array.ofDim[Double](n)
currMax = BDV.fill(n)(Double.MinValue) currMax = Array.fill[Double](n)(Double.MinValue)
currMin = BDV.fill(n)(Double.MaxValue) currMin = Array.fill[Double](n)(Double.MaxValue)
} }
require(n == sample.size, s"Dimensions mismatch when adding new sample." + require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.") s" Expecting $n but got ${sample.size}.")
sample match { sample.foreachActive { (index, value) =>
case dv: DenseVector => { if (value != 0.0) {
var j = 0 if (currMax(index) < value) {
while (j < dv.size) { currMax(index) = value
add(j, dv.values(j))
j += 1
} }
} if (currMin(index) > value) {
case sv: SparseVector => currMin(index) = value
var j = 0
while (j < sv.indices.size) {
add(sv.indices(j), sv.values(j))
j += 1
} }
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
val prevMean = currMean(index)
val diff = value - prevMean
currMean(index) = prevMean + diff / (nnz(index) + 1.0)
currM2n(index) += (value - currMean(index)) * diff
currM2(index) += value * value
currL1(index) += math.abs(value)
nnz(index) += 1.0
}
} }
totalCnt += 1 totalCnt += 1
...@@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -152,14 +130,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
} }
} else if (totalCnt == 0 && other.totalCnt != 0) { } else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n this.n = other.n
this.currMean = other.currMean.copy this.currMean = other.currMean.clone
this.currM2n = other.currM2n.copy this.currM2n = other.currM2n.clone
this.currM2 = other.currM2.copy this.currM2 = other.currM2.clone
this.currL1 = other.currL1.copy this.currL1 = other.currL1.clone
this.totalCnt = other.totalCnt this.totalCnt = other.totalCnt
this.nnz = other.nnz.copy this.nnz = other.nnz.clone
this.currMax = other.currMax.copy this.currMax = other.currMax.clone
this.currMin = other.currMin.copy this.currMin = other.currMin.clone
} }
this this
} }
...@@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -167,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def mean: Vector = { override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.") require(totalCnt > 0, s"Nothing has been added to this summarizer.")
val realMean = BDV.zeros[Double](n) val realMean = Array.ofDim[Double](n)
var i = 0 var i = 0
while (i < n) { while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt) realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1 i += 1
} }
Vectors.fromBreeze(realMean) Vectors.dense(realMean)
} }
override def variance: Vector = { override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.") require(totalCnt > 0, s"Nothing has been added to this summarizer.")
val realVariance = BDV.zeros[Double](n) val realVariance = Array.ofDim[Double](n)
val denominator = totalCnt - 1.0 val denominator = totalCnt - 1.0
...@@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -194,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
i += 1 i += 1
} }
} }
Vectors.dense(realVariance)
Vectors.fromBreeze(realVariance)
} }
override def count: Long = totalCnt override def count: Long = totalCnt
...@@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -203,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def numNonzeros: Vector = { override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.") require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.fromBreeze(nnz) Vectors.dense(nnz)
} }
override def max: Vector = { override def max: Vector = {
...@@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -214,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1 i += 1
} }
Vectors.fromBreeze(currMax) Vectors.dense(currMax)
} }
override def min: Vector = { override def min: Vector = {
...@@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S ...@@ -225,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1 i += 1
} }
Vectors.fromBreeze(currMin) Vectors.dense(currMin)
} }
override def normL2: Vector = { override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.") require(totalCnt > 0, s"Nothing has been added to this summarizer.")
val realMagnitude = BDV.zeros[Double](n) val realMagnitude = Array.ofDim[Double](n)
var i = 0 var i = 0
while (i < currM2.size) { while (i < currM2.size) {
realMagnitude(i) = math.sqrt(currM2(i)) realMagnitude(i) = math.sqrt(currM2(i))
i += 1 i += 1
} }
Vectors.dense(realMagnitude)
Vectors.fromBreeze(realMagnitude)
} }
override def normL1: Vector = { override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.") require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.fromBreeze(currL1)
Vectors.dense(currL1)
} }
} }
...@@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite { ...@@ -173,4 +173,28 @@ class VectorsSuite extends FunSuite {
val v = Vectors.fromBreeze(x(::, 0)) val v = Vectors.fromBreeze(x(::, 0))
assert(v.size === x.rows) assert(v.size === x.rows)
} }
test("foreachActive") {
val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0)
val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0)))
val dvMap = scala.collection.mutable.Map[Int, Double]()
dv.foreachActive { (index, value) =>
dvMap.put(index, value)
}
assert(dvMap.size === 4)
assert(dvMap.get(0) === Some(0.0))
assert(dvMap.get(1) === Some(1.2))
assert(dvMap.get(2) === Some(3.1))
assert(dvMap.get(3) === Some(0.0))
val svMap = scala.collection.mutable.Map[Int, Double]()
sv.foreachActive { (index, value) =>
svMap.put(index, value)
}
assert(svMap.size === 3)
assert(svMap.get(1) === Some(1.2))
assert(svMap.get(2) === Some(3.1))
assert(svMap.get(3) === Some(0.0))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment