Skip to content
Snippets Groups Projects
Commit 7a8a81d7 authored by WeichenXu's avatar WeichenXu Committed by Sean Owen
Browse files

[SPARK-17363][ML][MLLIB] fix MultivariantOnlineSummerizer.numNonZeros

## What changes were proposed in this pull request?

fix `MultivariantOnlineSummerizer.numNonZeros` method,
return `nnz` array, instead of  `weightSum` array

## How was this patch tested?

Existing test.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #14923 from WeichenXu123/fix_MultivariantOnlineSummerizer_numNonZeros.
parent d2fde6b7
No related branches found
No related tags found
No related merge requests found
......@@ -231,9 +231,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
*/
@Since("1.1.0")
override def numNonzeros: Vector = {
require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
Vectors.dense(weightSum)
Vectors.dense(nnz.map(_.toDouble))
}
/**
......
......@@ -237,7 +237,7 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
absTol 1E-10, "mean mismatch")
assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857))
absTol 1E-8, "variance mismatch")
assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4))
assert(summarizer.numNonzeros ~== Vectors.dense(Array(3.0, 4.0, 3.0))
absTol 1E-10, "numNonzeros mismatch")
assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch")
assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment