Commit 92940449 authored by Xiangrui Meng

[SPARK-5885][MLLIB] Add VectorAssembler as a feature transformer

VectorAssembler merges multiple columns into a vector column. This PR contains content from #5195.

~~carry ML attributes~~ (moved to a follow-up PR)
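
For context, a minimal usage sketch of the new transformer (not part of this commit); the column names `hour`, `mobile`, and `userFeatures` are made up, and `sqlContext` is assumed to be an existing SQLContext:

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.mllib.linalg.Vectors

// Two numeric columns and one vector column to be merged into a single feature vector.
val df = sqlContext.createDataFrame(Seq(
  (18, 1.0, Vectors.dense(0.0, 10.0, 0.5))
)).toDF("hour", "mobile", "userFeatures")

val assembler = new VectorAssembler()
  .setInputCols(Array("hour", "mobile", "userFeatures"))
  .setOutputCol("features")

// Appends a "features" vector column assembled from the input columns;
// the integer "hour" column is cast to Double before assembly.
assembler.transform(df).select("features").show()
```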

Author: Xiangrui Meng <meng@databricks.com>

Closes #5196 from mengxr/SPARK-5885 and squashes the following commits:

a52b101 [Xiangrui Meng] recognize more types
35daac2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5885
bb5e64b [Xiangrui Meng] add TODO for null
976a3d6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5885
0859311 [Xiangrui Meng] Revert "add CreateStruct"
29fb6ac [Xiangrui Meng] use CreateStruct
adb71c4 [Xiangrui Meng] Merge branch 'SPARK-6542' into SPARK-5885
85f3106 [Xiangrui Meng] add CreateStruct
4ff16ce [Xiangrui Meng] add VectorAssembler
parent 685ddcf5
@@ -29,5 +29,5 @@ private[ml] trait Identifiable extends Serializable {
    * random hex chars.
    */
   private[ml] val uid: String =
-    this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
+    this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8)
 }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
 * :: AlphaComponent ::
 * A feature transformer that merges multiple columns into a vector column.
 */
@AlphaComponent
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
  /** @group setParam */
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    val map = this.paramMap ++ paramMap
    val assembleFunc = udf { r: Row =>
      VectorAssembler.assemble(r.toSeq: _*)
    }
    val schema = dataset.schema
    val inputColNames = map(inputCols)
    // Pass Double and Vector columns through as-is; cast other native (primitive) types to Double.
    val args = inputColNames.map { c =>
      schema(c).dataType match {
        case DoubleType => UnresolvedAttribute(c)
        case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
        case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
      }
    }
    dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = this.paramMap ++ paramMap
    val inputColNames = map(inputCols)
    val outputColName = map(outputCol)
    val inputDataTypes = inputColNames.map(name => schema(name).dataType)
    inputDataTypes.foreach {
      case _: NativeType =>
      case t if t.isInstanceOf[VectorUDT] =>
      case other =>
        throw new IllegalArgumentException(s"Data type $other is not supported.")
    }
    if (schema.fieldNames.contains(outputColName)) {
      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
    }
    StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
  }
}

@AlphaComponent
object VectorAssembler {
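
  /**
   * Assembles the given values (each a Double or a Vector) into a single sparse Vector,
   * concatenating them in order; nulls and other types are rejected with a SparkException.
   */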
  private[feature] def assemble(vv: Any*): Vector = {
    val indices = ArrayBuilder.make[Int]
    val values = ArrayBuilder.make[Double]
    var cur = 0
    vv.foreach {
      case v: Double =>
        if (v != 0.0) {
          indices += cur
          values += v
        }
        cur += 1
      case vec: Vector =>
        vec.foreachActive { case (i, v) =>
          if (v != 0.0) {
            indices += cur + i
            values += v
          }
        }
        cur += vec.size
      case null =>
        // TODO: output Double.NaN?
        throw new SparkException("Values to assemble cannot be null.")
      case o =>
        throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
    }
    Vectors.sparse(cur, indices.result(), values.result())
  }
}
@@ -140,6 +140,16 @@ private[ml] trait HasInputCol extends Params {
  def getInputCol: String = get(inputCol)
}

private[ml] trait HasInputCols extends Params {

  /**
   * Param for input column names.
   */
  val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names")

  /** @group getParam */
  def getInputCols: Array[String] = get(inputCols)
}

private[ml] trait HasOutputCol extends Params {

  /**
   * param for output column name
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.scalatest.FunSuite
import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {

  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("assemble") {
    import org.apache.spark.ml.feature.VectorAssembler.assemble
    assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
    assert(assemble(0.0, 1.0) === Vectors.sparse(2, Array(1), Array(1.0)))
    val dv = Vectors.dense(2.0, 0.0)
    assert(assemble(0.0, dv, 1.0) === Vectors.sparse(4, Array(1, 3), Array(2.0, 1.0)))
    val sv = Vectors.sparse(2, Array(0, 1), Array(3.0, 4.0))
    assert(assemble(0.0, dv, 1.0, sv) ===
      Vectors.sparse(6, Array(1, 3, 4, 5), Array(2.0, 1.0, 3.0, 4.0)))
    for (v <- Seq(1, "a", null)) {
      intercept[SparkException](assemble(v))
      intercept[SparkException](assemble(1.0, v))
    }
  }

  test("VectorAssembler") {
    val df = sqlContext.createDataFrame(Seq(
      (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
    )).toDF("id", "x", "y", "name", "z", "n")
    val assembler = new VectorAssembler()
      .setInputCols(Array("x", "y", "z", "n"))
      .setOutputCol("features")
    assembler.transform(df).select("features").collect().foreach {
      case Row(v: Vector) =>
        assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0)))
    }
  }
}