Skip to content
Snippets Groups Projects
Commit e43803b8 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-6948] [MLLIB] compress vectors in VectorAssembler

The compression is based on storage. brkyvz

Author: Xiangrui Meng <meng@databricks.com>

Closes #5985 from mengxr/SPARK-6948 and squashes the following commits:

df56a00 [Xiangrui Meng] update python tests
6d90d45 [Xiangrui Meng] compress vectors in VectorAssembler
parent 658a478d
No related branches found
No related tags found
No related merge requests found
...@@ -102,6 +102,6 @@ object VectorAssembler { ...@@ -102,6 +102,6 @@ object VectorAssembler {
case o => case o =>
throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
} }
Vectors.sparse(cur, indices.result(), values.result()) Vectors.sparse(cur, indices.result(), values.result()).compressed
} }
} }
...@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature ...@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.scalatest.FunSuite import org.scalatest.FunSuite
import org.apache.spark.SparkException import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.{Row, SQLContext}
...@@ -48,6 +48,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { ...@@ -48,6 +48,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
} }
} }
test("assemble should compress vectors") {
import org.apache.spark.ml.feature.VectorAssembler.assemble
val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
assert(v1.isInstanceOf[SparseVector])
val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
assert(v2.isInstanceOf[DenseVector])
}
test("VectorAssembler") { test("VectorAssembler") {
val df = sqlContext.createDataFrame(Seq( val df = sqlContext.createDataFrame(Seq(
(0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
......
...@@ -121,12 +121,12 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): ...@@ -121,12 +121,12 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
>>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF() >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
>>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features") >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
>>> vecAssembler.transform(df).head().features >>> vecAssembler.transform(df).head().features
SparseVector(3, {0: 1.0, 2: 3.0}) DenseVector([1.0, 0.0, 3.0])
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
SparseVector(3, {0: 1.0, 2: 3.0}) DenseVector([1.0, 0.0, 3.0])
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
>>> vecAssembler.transform(df, params).head().vector >>> vecAssembler.transform(df, params).head().vector
SparseVector(2, {1: 1.0}) DenseVector([0.0, 1.0])
""" """
_java_class = "org.apache.spark.ml.feature.VectorAssembler" _java_class = "org.apache.spark.ml.feature.VectorAssembler"
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment