diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 3025d4837cab42daec042d84804d16d21792b90e..fab7c4405c65dc179ad1d7e6cd72302fe71d984d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} /** * :: DeveloperApi :: @@ -72,9 +72,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - sample.toBreeze.activeIterator.foreach { - case (_, 0.0) => // Skip explicit zero elements. - case (i, value) => + @inline def update(i: Int, value: Double) = { + if (value != 0.0) { if (currMax(i) < value) { currMax(i) = value } @@ -89,6 +88,24 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currL1(i) += math.abs(value) nnz(i) += 1.0 + } + } + + sample match { + case dv: DenseVector => { + var j = 0 + while (j < dv.size) { + update(j, dv.values(j)) + j += 1 + } + } + case sv: SparseVector => + var j = 0 + while (j < sv.indices.size) { + update(sv.indices(j), sv.values(j)) + j += 1 + } + case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } totalCnt += 1