diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 964f419d120dd8f983a0a9cbec82cc36a4fce554..7a2a7a35a91cdf9dadecdc47f6375653961007c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -231,9 +231,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def numNonzeros: Vector = { - require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") + require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.dense(weightSum) + Vectors.dense(nnz.map(_.toDouble)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 165a3f314a2012155ce85cba80a3b2aadd2abfc8..797e84fcc73770606b1ed671b189dc1c1c323342 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -237,7 +237,7 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { absTol 1E-10, "mean mismatch") assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) absTol 1E-8, "variance mismatch") - assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4)) + assert(summarizer.numNonzeros ~== Vectors.dense(Array(3.0, 4.0, 3.0)) absTol 1E-10, "numNonzeros mismatch") assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch")