diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 8c8e4a161aa5b5d561897bf2c5f17a9afdb2e016..a967df857bed399bab3cb186af2ae2e4f3d1f86c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -93,10 +93,10 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
   def run(data: RDD[LabeledPoint]) = {
     val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
       val values = v match {
-        case sv: SparseVector =>
-          sv.values
-        case dv: DenseVector =>
-          dv.values
+        case SparseVector(size, indices, values) =>
+          values
+        case DenseVector(values) =>
+          values
       }
       if (!values.forall(_ >= 0.0)) {
         throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 19120e1e8af19f75491a7ec726b4fc84b50be1a9..3260f27513c7f78dee5c482fcf7038632cb9ebec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -86,20 +86,20 @@ private object IDF {
         df = BDV.zeros(doc.size)
       }
       doc match {
-        case sv: SparseVector =>
-          val nnz = sv.indices.size
+        case SparseVector(size, indices, values) =>
+          val nnz = indices.size
           var k = 0
           while (k < nnz) {
-            if (sv.values(k) > 0) {
-              df(sv.indices(k)) += 1L
+            if (values(k) > 0) {
+              df(indices(k)) += 1L
             }
             k += 1
           }
-        case dv: DenseVector =>
-          val n = dv.size
+        case DenseVector(values) =>
+          val n = values.size
           var j = 0
           while (j < n) {
-            if (dv.values(j) > 0.0) {
+            if (values(j) > 0.0) {
               df(j) += 1L
             }
             j += 1
@@ -207,20 +207,20 @@ private object IDFModel {
   def transform(idf: Vector, v: Vector): Vector = {
     val n = v.size
     v match {
-      case sv: SparseVector =>
-        val nnz = sv.indices.size
+      case SparseVector(size, indices, values) =>
+        val nnz = indices.size
         val newValues = new Array[Double](nnz)
         var k = 0
         while (k < nnz) {
-          newValues(k) = sv.values(k) * idf(sv.indices(k))
+          newValues(k) = values(k) * idf(indices(k))
           k += 1
         }
-        Vectors.sparse(n, sv.indices, newValues)
-      case dv: DenseVector =>
+        Vectors.sparse(n, indices, newValues)
+      case DenseVector(values) =>
         val newValues = new Array[Double](n)
         var j = 0
         while (j < n) {
-          newValues(j) = dv.values(j) * idf(j)
+          newValues(j) = values(j) * idf(j)
           j += 1
         }
         Vectors.dense(newValues)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 1ced26a9b70a26dc84dbd8247ed320d428be3341..32848e039eb81bbfb26c2beb161b182c18680307 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -52,8 +52,8 @@ class Normalizer(p: Double) extends VectorTransformer {
       // However, for sparse vector, the `index` array will not be changed,
       // so we can re-use it to save memory.
       vector match {
-        case dv: DenseVector =>
-          val values = dv.values.clone()
+        case DenseVector(vs) =>
+          val values = vs.clone()
           val size = values.size
           var i = 0
           while (i < size) {
@@ -61,15 +61,15 @@ class Normalizer(p: Double) extends VectorTransformer {
             i += 1
           }
           Vectors.dense(values)
-        case sv: SparseVector =>
-          val values = sv.values.clone()
+        case SparseVector(size, ids, vs) =>
+          val values = vs.clone()
           val nnz = values.size
           var i = 0
           while (i < nnz) {
             values(i) /= norm
             i += 1
           }
-          Vectors.sparse(sv.size, sv.indices, values)
+          Vectors.sparse(size, ids, values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 8c4c5db5258d5d3c2add6d6cfebd49fb11bb4d94..3c2091732f9b06698db490ce5d2ac0b2d5810007 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -105,8 +105,8 @@ class StandardScalerModel private[mllib] (
       // This can be avoid by having a local reference of `shift`.
       val localShift = shift
       vector match {
-        case dv: DenseVector =>
-          val values = dv.values.clone()
+        case DenseVector(vs) =>
+          val values = vs.clone()
           val size = values.size
           if (withStd) {
             // Having a local reference of `factor` to avoid overhead as the comment before.
@@ -130,8 +130,8 @@ class StandardScalerModel private[mllib] (
       // Having a local reference of `factor` to avoid overhead as the comment before.
      val localFactor = factor
       vector match {
-        case dv: DenseVector =>
-          val values = dv.values.clone()
+        case DenseVector(vs) =>
+          val values = vs.clone()
           val size = values.size
           var i = 0
           while(i < size) {
@@ -139,18 +139,17 @@ class StandardScalerModel private[mllib] (
             i += 1
           }
           Vectors.dense(values)
-        case sv: SparseVector =>
+        case SparseVector(size, indices, vs) =>
           // For sparse vector, the `index` array inside sparse vector object will not be changed,
           // so we can re-use it to save memory.
-          val indices = sv.indices
-          val values = sv.values.clone()
+          val values = vs.clone()
           val nnz = values.size
           var i = 0
           while (i < nnz) {
             values(i) *= localFactor(indices(i))
             i += 1
           }
-          Vectors.sparse(sv.size, indices, values)
+          Vectors.sparse(size, indices, values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index d40f13342a3d9a79dd38b9054dc7e373f90b4beb..bf1faa25ef0e054ae82c839e8397e616c819b890 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -108,16 +108,16 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   override def serialize(obj: Any): Row = {
     val row = new GenericMutableRow(4)
     obj match {
-      case sv: SparseVector =>
+      case SparseVector(size, indices, values) =>
         row.setByte(0, 0)
-        row.setInt(1, sv.size)
-        row.update(2, sv.indices.toSeq)
-        row.update(3, sv.values.toSeq)
-      case dv: DenseVector =>
+        row.setInt(1, size)
+        row.update(2, indices.toSeq)
+        row.update(3, values.toSeq)
+      case DenseVector(values) =>
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, dv.values.toSeq)
+        row.update(3, values.toSeq)
     }
     row
   }
@@ -271,8 +271,8 @@
   def norm(vector: Vector, p: Double): Double = {
     require(p >= 1.0)
     val values = vector match {
-      case dv: DenseVector => dv.values
-      case sv: SparseVector => sv.values
+      case DenseVector(vs) => vs
+      case SparseVector(n, ids, vs) => vs
       case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
     val size = values.size
@@ -427,6 +427,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
   }
 }
 
+object DenseVector {
+  def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
+}
+
 /**
  * A sparse vector represented by an index array and an value array.
  *
@@ -474,3 +478,8 @@ class SparseVector(
     }
   }
 }
+
+object SparseVector {
+  def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
+    Some((sv.size, sv.indices, sv.values))
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index a3fca53929ab7cc56ddcdd003d64a082b4158f85..fbd35e372f9b10c22d4ebc79542488d44ae0b2c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -528,21 +528,21 @@ class RowMatrix(
       iter.flatMap { row =>
         val buf = new ListBuffer[((Int, Int), Double)]()
         row match {
-          case sv: SparseVector =>
-            val nnz = sv.indices.size
+          case SparseVector(size, indices, values) =>
+            val nnz = indices.size
             var k = 0
             while (k < nnz) {
-              scaled(k) = sv.values(k) / q(sv.indices(k))
+              scaled(k) = values(k) / q(indices(k))
               k += 1
             }
             k = 0
             while (k < nnz) {
-              val i = sv.indices(k)
+              val i = indices(k)
               val iVal = scaled(k)
               if (iVal != 0 && rand.nextDouble() < p(i)) {
                 var l = k + 1
                 while (l < nnz) {
-                  val j = sv.indices(l)
+                  val j = indices(l)
                   val jVal = scaled(l)
                   if (jVal != 0 && rand.nextDouble() < p(j)) {
                     buf += (((i, j), iVal * jVal))
@@ -552,11 +552,11 @@ class RowMatrix(
               }
               k += 1
             }
-          case dv: DenseVector =>
-            val n = dv.values.size
+          case DenseVector(values) =>
+            val n = values.size
             var i = 0
             while (i < n) {
-              scaled(i) = dv.values(i) / q(i)
+              scaled(i) = values(i) / q(i)
               i += 1
             }
             i = 0
@@ -620,11 +620,9 @@ object RowMatrix {
     // TODO: Find a better home (breeze?) for this method.
     val n = v.size
     v match {
-      case dv: DenseVector =>
-        blas.dspr("U", n, alpha, dv.values, 1, U)
-      case sv: SparseVector =>
-        val indices = sv.indices
-        val values = sv.values
+      case DenseVector(values) =>
+        blas.dspr("U", n, alpha, values, 1, U)
+      case SparseVector(size, indices, values) =>
         val nnz = indices.length
         var colStartIdx = 0
         var prevCol = 0
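With the `unapply` extractors added to the `DenseVector` and `SparseVector` companion objects above, callers can destructure a `Vector` directly in a `match` expression instead of reaching through the `dv.values` / `sv.indices` accessors. A minimal usage sketch follows; the `sumOfSquares` helper is illustrative only and not part of this patch:

    import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}

    // Illustrative helper (not part of the patch): destructures a vector with
    // the new extractors and sums the squares of its stored values.
    def sumOfSquares(v: Vector): Double = v match {
      case DenseVector(values) =>
        values.map(x => x * x).sum
      case SparseVector(size, indices, values) =>
        // Only the explicitly stored entries contribute; implicit zeros add nothing.
        values.map(x => x * x).sum
      case other =>
        throw new IllegalArgumentException("Do not support vector type " + other.getClass)
    }

    sumOfSquares(Vectors.dense(1.0, 0.0, 3.0))                    // 10.0
    sumOfSquares(Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))) // 10.0

The extractors hand back the underlying arrays rather than copies, so matching this way does not copy vector contents; it only changes how the existing fields are bound to names at the pattern site.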