Skip to content
Snippets Groups Projects
Commit fdfb45e6 authored by Xusen Yin's avatar Xusen Yin Committed by Patrick Wendell
Browse files

[WIP] [SPARK-1328] Add vector statistics

As with the new vector system in MLlib, we find that it is good to add some new APIs to precess the `RDD[Vector]`. Beside, the former implementation of `computeStat` is not stable which could loss precision, and has the possibility to cause `Nan` in scientific computing, just as said in the [SPARK-1328](https://spark-project.atlassian.net/browse/SPARK-1328).

APIs contain:

* rowMeans(): RDD[Double]
* rowNorm2(): RDD[Double]
* rowSDs(): RDD[Double]
* colMeans(): Vector
* colMeans(size: Int): Vector
* colNorm2(): Vector
* colNorm2(size: Int): Vector
* colSDs(): Vector
* colSDs(size: Int): Vector
* maxOption((Vector, Vector) => Boolean): Option[Vector]
* minOption((Vector, Vector) => Boolean): Option[Vector]
* rowShrink(): RDD[Vector]
* colShrink(): RDD[Vector]

This is working in process now, and some more APIs will add to `LabeledPoint`. Moreover, the implicit declaration will move from `MLUtils` to `MLContext` later.

Author: Xusen Yin <yinxusen@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #268 from yinxusen/vector-statistics and squashes the following commits:

d61363f [Xusen Yin] rebase to latest master
16ae684 [Xusen Yin] fix minor error and remove useless method
10cf5d3 [Xusen Yin] refine some return type
b064714 [Xusen Yin] remove computeStat in MLUtils
cbbefdb [Xiangrui Meng] update multivariate statistical summary interface and clean tests
4eaf28a [Xusen Yin] merge VectorRDDStatistics into RowMatrix
48ee053 [Xusen Yin] fix minor error
e624f93 [Xusen Yin] fix scala style error
1fba230 [Xusen Yin] merge while loop together
69e1f37 [Xusen Yin] remove lazy eval, and minor memory footprint
548e9de [Xusen Yin] minor revision
86522c4 [Xusen Yin] add comments on functions
dc77e38 [Xusen Yin] test sparse vector RDD
18cf072 [Xusen Yin] change def to lazy val to make sure that the computations in function be evaluated only once
f7a3ca2 [Xusen Yin] fix the corner case of maxmin
967d041 [Xusen Yin] full revision with Aggregator class
138300c [Xusen Yin] add new Aggregator class
1376ff4 [Xusen Yin] rename variables and adjust code
4a5c38d [Xusen Yin] add scala doc, refine code and comments
036b7a5 [Xusen Yin] fix the bug of Nan occur
f6e8e9a [Xusen Yin] add sparse vectors test
4cfbadf [Xusen Yin] fix bug of min max
4e4fbd1 [Xusen Yin] separate seqop and combop out as independent functions
a6d5a2e [Xusen Yin] rewrite for only computing non-zero elements
3980287 [Xusen Yin] rename variables
62a2c3e [Xusen Yin] use axpy and in-place if possible
9a75ebd [Xusen Yin] add case class to wrap return values
d816ac7 [Xusen Yin] remove useless APIs
c4651bb [Xusen Yin] remove row-wise APIs and refine code
1338ea1 [Xusen Yin] all-in-one version test passed
cc65810 [Xusen Yin] add parallel mean and variance
9af2e95 [Xusen Yin] refine the code style
ad6c82d [Xusen Yin] add shrink test
e09d5d2 [Xusen Yin] add scala docs and refine shrink method
8ef3377 [Xusen Yin] pass all tests
28cf060 [Xusen Yin] fix error of column means
54b19ab [Xusen Yin] add new API to shrink RDD[Vector]
8c6c0e1 [Xusen Yin] add basic statistics
parent 7038b00b
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
import java.util
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}
......@@ -27,6 +27,138 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
/**
* Column statistics aggregator implementing
* [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
* together with add() and merge() function.
* A numerically stable algorithm is implemented to compute sample mean and variance:
*[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
* Zero elements (including explicit zero values) are skipped when calling add() and merge(),
* to have time complexity O(nnz) instead of O(n) for each column.
*/
private class ColumnStatisticsAggregator(private val n: Int)
extends MultivariateStatisticalSummary with Serializable {
private val currMean: BDV[Double] = BDV.zeros[Double](n)
private val currM2n: BDV[Double] = BDV.zeros[Double](n)
private var totalCnt = 0.0
private val nnz: BDV[Double] = BDV.zeros[Double](n)
private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
override def mean: Vector = {
val realMean = BDV.zeros[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * nnz(i) / totalCnt
i += 1
}
Vectors.fromBreeze(realMean)
}
override def variance: Vector = {
val realVariance = BDV.zeros[Double](n)
val denominator = totalCnt - 1.0
// Sample variance is computed, if the denominator is less than 0, the variance is just 0.
if (denominator > 0.0) {
val deltaMean = currMean
var i = 0
while (i < currM2n.size) {
realVariance(i) =
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
realVariance(i) /= denominator
i += 1
}
}
Vectors.fromBreeze(realVariance)
}
override def count: Long = totalCnt.toLong
override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
override def max: Vector = {
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMax)
}
override def min: Vector = {
var i = 0
while (i < n) {
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
Vectors.fromBreeze(currMin)
}
/**
* Aggregates a row.
*/
def add(currData: BV[Double]): this.type = {
currData.activeIterator.foreach {
case (_, 0.0) => // Skip explicit zero elements.
case (i, value) =>
if (currMax(i) < value) {
currMax(i) = value
}
if (currMin(i) > value) {
currMin(i) = value
}
val tmpPrevMean = currMean(i)
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
nnz(i) += 1.0
}
totalCnt += 1.0
this
}
/**
* Merges another aggregator.
*/
def merge(other: ColumnStatisticsAggregator): this.type = {
require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
val deltaMean = currMean - other.currMean
var i = 0
while (i < n) {
// merge mean together
if (other.currMean(i) != 0.0) {
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
(nnz(i) + other.nnz(i))
}
// merge m2n together
if (nnz(i) + other.nnz(i) != 0.0) {
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
(nnz(i) + other.nnz(i))
}
if (currMax(i) < other.currMax(i)) {
currMax(i) = other.currMax(i)
}
if (currMin(i) > other.currMin(i)) {
currMin(i) = other.currMin(i)
}
i += 1
}
nnz += other.nnz
this
}
}
/**
* :: Experimental ::
......@@ -182,13 +314,7 @@ class RowMatrix(
combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2)
)
// Update _m if it is not set, or verify its value.
if (nRows <= 0L) {
nRows = m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
updateNumRows(m)
mean :/= m.toDouble
......@@ -240,6 +366,19 @@ class RowMatrix(
}
}
/**
* Computes column-wise summary statistics.
*/
def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
val zeroValue = new ColumnStatisticsAggregator(numCols().toInt)
val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)(
(aggregator, data) => aggregator.add(data),
(aggregator1, aggregator2) => aggregator1.merge(aggregator2)
)
updateNumRows(summary.count)
summary
}
/**
* Multiply this matrix by a local matrix on the right.
*
......@@ -276,6 +415,16 @@ class RowMatrix(
}
mat
}
/** Updates or verfires the number of rows. */
private def updateNumRows(m: Long) {
if (nRows <= 0) {
nRows == m
} else {
require(nRows == m,
s"The number of rows $m is different from what specified or previously computed: ${nRows}.")
}
}
}
object RowMatrix {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.stat
import org.apache.spark.mllib.linalg.Vector
/**
* Trait for multivariate statistical summary of a data matrix.
*/
trait MultivariateStatisticalSummary {
/**
* Sample mean vector.
*/
def mean: Vector
/**
* Sample variance vector. Should return a zero vector if the sample size is 1.
*/
def variance: Vector
/**
* Sample size.
*/
def count: Long
/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
*/
def numNonzeros: Vector
/**
* Maximum value of each column.
*/
def max: Vector
/**
* Minimum value of each column.
*/
def min: Vector
}
......@@ -17,14 +17,13 @@
package org.apache.spark.mllib.util
import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
squaredDistance => breezeSquaredDistance}
import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.Vectors
/**
* Helper methods to load, save and pre-process data used in ML Lib.
......@@ -158,58 +157,6 @@ object MLUtils {
dataStr.saveAsTextFile(dir)
}
/**
* Utility function to compute mean and standard deviation on a given dataset.
*
* @param data - input data set whose statistics are computed
* @param numFeatures - number of features
* @param numExamples - number of examples in input dataset
*
* @return (yMean, xColMean, xColSd) - Tuple consisting of
* yMean - mean of the labels
* xColMean - Row vector with mean for every column (or feature) of the input data
* xColSd - Row vector standard deviation for every column (or feature) of the input data.
*/
private[mllib] def computeStats(
data: RDD[LabeledPoint],
numFeatures: Int,
numExamples: Long): (Double, Vector, Vector) = {
val brzData = data.map { case LabeledPoint(label, features) =>
(label, features.toBreeze)
}
val aggStats = brzData.aggregate(
(0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
)(
seqOp = (c, v) => (c, v) match {
case ((n, sumLabel, sum, sumSq), (label, features)) =>
features.activeIterator.foreach { case (i, x) =>
sumSq(i) += x * x
}
(n + 1L, sumLabel + label, sum += features, sumSq)
},
combOp = (c1, c2) => (c1, c2) match {
case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
(n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
}
)
val (nl, sumLabel, sum, sumSq) = aggStats
require(nl > 0, "Input data is empty.")
require(nl == numExamples)
val n = nl.toDouble
val yMean = sumLabel / n
val mean = sum / n
val std = new Array[Double](sum.length)
var i = 0
while (i < numFeatures) {
std(i) = sumSq(i) / n - mean(i) * mean(i)
i += 1
}
(yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
}
/**
* Returns the squared Euclidean distance between two vectors. The following formula will be used
* if it does not introduce too much numerical error:
......
......@@ -170,4 +170,19 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext {
))
}
}
test("compute column summary statistics") {
for (mat <- Seq(denseMat, sparseMat)) {
val summary = mat.computeColumnSummaryStatistics()
// Run twice to make sure no internal states are changed.
for (k <- 0 to 1) {
assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch")
assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch")
assert(summary.count === m, "count mismatch.")
assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch")
assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch")
assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.")
}
}
}
}
......@@ -27,7 +27,6 @@ import com.google.common.base.Charsets
import com.google.common.io.Files
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
class MLUtilsSuite extends FunSuite with LocalSparkContext {
......@@ -56,18 +55,6 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
}
}
test("compute stats") {
val data = Seq.fill(3)(Seq(
LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
)).flatten
val rdd = sc.parallelize(data, 2)
val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
assert(meanLabel === 0.5)
assert(mean === Vectors.dense(2.0, 3.0, 4.0))
assert(std === Vectors.dense(1.0, 1.0, 1.0))
}
test("loadLibSVMData") {
val lines =
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment