Skip to content
Snippets Groups Projects
Commit 9e2ffb13 authored by Burak Yavuz's avatar Burak Yavuz Committed by Xiangrui Meng
Browse files

[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python

The wrapper required the implementation of the `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast` which is an internal function to obtain `Array[T]` from `Seq[T]`

Author: Burak Yavuz <brkyvz@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #5930 from brkyvz/ml-feat and squashes the following commits:

73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388
c221db9 [Xiangrui Meng] overload StringArrayParam.w
c81072d [Burak Yavuz] addressed comments
99c2ebf [Burak Yavuz] add to python_shared_params
39ecb07 [Burak Yavuz] fix scalastyle
7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
parent ed9be06a
No related branches found
No related tags found
No related merge requests found
...@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._ ...@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
/** /**
* :: AlphaComponent :: * :: AlphaComponent ::
* A feature transformer than merge multiple columns into a vector column. * A feature transformer that merges multiple columns into a vector column.
*/ */
@AlphaComponent @AlphaComponent
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
......
...@@ -22,6 +22,7 @@ import java.util.NoSuchElementException ...@@ -22,6 +22,7 @@ import java.util.NoSuchElementException
import scala.annotation.varargs import scala.annotation.varargs
import scala.collection.mutable import scala.collection.mutable
import scala.collection.JavaConverters._
import org.apache.spark.annotation.AlphaComponent import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.util.Identifiable
...@@ -218,6 +219,19 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV ...@@ -218,6 +219,19 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value) override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
} }
/** Specialized version of [[Param[Array[T]]]] for Java. */
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {
def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
/** /**
* A param amd its value. * A param amd its value.
*/ */
...@@ -310,9 +324,7 @@ trait Params extends Identifiable with Serializable { ...@@ -310,9 +324,7 @@ trait Params extends Identifiable with Serializable {
* Sets a parameter in the embedded param map. * Sets a parameter in the embedded param map.
*/ */
protected final def set[T](param: Param[T], value: T): this.type = { protected final def set[T](param: Param[T], value: T): this.type = {
shouldOwn(param) set(param -> value)
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
} }
/** /**
...@@ -322,6 +334,15 @@ trait Params extends Identifiable with Serializable { ...@@ -322,6 +334,15 @@ trait Params extends Identifiable with Serializable {
set(getParam(param), value) set(getParam(param), value)
} }
/**
* Sets a parameter in the embedded param map.
*/
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
paramMap.put(paramPair)
this
}
/** /**
* Optionally returns the user-supplied value of a param. * Optionally returns the user-supplied value of a param.
*/ */
......
...@@ -85,6 +85,7 @@ private[shared] object SharedParamsCodeGen { ...@@ -85,6 +85,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam" case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam" case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam" case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
case _ => s"Param[${getTypeString(c)}]" case _ => s"Param[${getTypeString(c)}]"
} }
} }
......
...@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params { ...@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names. * Param for input column names.
* @group param * @group param
*/ */
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
/** @group getParam */ /** @group getParam */
final def getInputCols: Array[String] = $(inputCols) final def getInputCols: Array[String] = $(inputCols)
......
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
# #
from pyspark.rdd import ignore_unicode_prefix from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer from pyspark.ml.wrapper import JavaTransformer
from pyspark.mllib.common import inherit_doc from pyspark.mllib.common import inherit_doc
__all__ = ['Tokenizer', 'HashingTF'] __all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']
@inherit_doc @inherit_doc
...@@ -112,6 +112,45 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): ...@@ -112,6 +112,45 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
return self._set(**kwargs) return self._set(**kwargs)
@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
"""
A feature transformer that merges multiple columns into a vector column.
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
>>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
>>> vecAssembler.transform(df).head().features
SparseVector(3, {0: 1.0, 2: 3.0})
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
SparseVector(3, {0: 1.0, 2: 3.0})
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
>>> vecAssembler.transform(df, params).head().vector
SparseVector(2, {1: 1.0})
"""
_java_class = "org.apache.spark.ml.feature.VectorAssembler"
@keyword_only
def __init__(self, inputCols=None, outputCol=None):
"""
__init__(self, inputCols=None, outputCol=None)
"""
super(VectorAssembler, self).__init__()
self._setDefault()
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, inputCols=None, outputCol=None):
"""
setParams(self, inputCols=None, outputCol=None)
Sets params for this VectorAssembler.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
from pyspark.context import SparkContext from pyspark.context import SparkContext
......
...@@ -95,6 +95,7 @@ if __name__ == "__main__": ...@@ -95,6 +95,7 @@ if __name__ == "__main__":
("predictionCol", "prediction column name", "'prediction'"), ("predictionCol", "prediction column name", "'prediction'"),
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"), ("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
("inputCol", "input column name", None), ("inputCol", "input column name", None),
("inputCols", "input column names", None),
("outputCol", "output column name", None), ("outputCol", "output column name", None),
("numFeatures", "number of features", None)] ("numFeatures", "number of features", None)]
code = [] code = []
......
...@@ -223,6 +223,35 @@ class HasInputCol(Params): ...@@ -223,6 +223,35 @@ class HasInputCol(Params):
return self.getOrDefault(self.inputCol) return self.getOrDefault(self.inputCol)
class HasInputCols(Params):
"""
Mixin for param inputCols: input column names.
"""
# a placeholder to make it appear in the generated doc
inputCols = Param(Params._dummy(), "inputCols", "input column names")
def __init__(self):
super(HasInputCols, self).__init__()
#: param for input column names
self.inputCols = Param(self, "inputCols", "input column names")
if None is not None:
self._setDefault(inputCols=None)
def setInputCols(self, value):
"""
Sets the value of :py:attr:`inputCols`.
"""
self.paramMap[self.inputCols] = value
return self
def getInputCols(self):
"""
Gets the value of inputCols or its default value.
"""
return self.getOrDefault(self.inputCols)
class HasOutputCol(Params): class HasOutputCol(Params):
""" """
Mixin for param outputCol: output column name. Mixin for param outputCol: output column name.
......
...@@ -67,7 +67,9 @@ class JavaWrapper(Params): ...@@ -67,7 +67,9 @@ class JavaWrapper(Params):
paramMap = self.extractParamMap(params) paramMap = self.extractParamMap(params)
for param in self.params: for param in self.params:
if param in paramMap: if param in paramMap:
java_obj.set(param.name, paramMap[param]) value = paramMap[param]
java_param = java_obj.getParam(param.name)
java_obj.set(java_param.w(value))
def _empty_java_param_map(self): def _empty_java_param_map(self):
""" """
...@@ -79,7 +81,8 @@ class JavaWrapper(Params): ...@@ -79,7 +81,8 @@ class JavaWrapper(Params):
paramMap = self._empty_java_param_map() paramMap = self._empty_java_param_map()
for param, value in params.items(): for param, value in params.items():
if param.parent is self: if param.parent is self:
paramMap.put(java_obj.getParam(param.name), value) java_param = java_obj.getParam(param.name)
paramMap.put(java_param.w(value))
return paramMap return paramMap
...@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper): ...@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):
def transform(self, dataset, params={}): def transform(self, dataset, params={}):
java_obj = self._java_obj() java_obj = self._java_obj()
self._transfer_params_to_java({}, java_obj) self._transfer_params_to_java(params, java_obj)
java_param_map = self._create_java_param_map(params, java_obj) return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
dataset.sql_ctx)
@inherit_doc @inherit_doc
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment