Skip to content
Snippets Groups Projects
Commit e4765a46 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-9544] [MLLIB] add Python API for RFormula

Add Python API for RFormula. Similar to other feature transformers in Python. This is just a thin wrapper over the Scala implementation. ericl MechCoder

Author: Xiangrui Meng <meng@databricks.com>

Closes #7879 from mengxr/SPARK-9544 and squashes the following commits:

3d5ff03 [Xiangrui Meng] add a doctest for . and -
5e969a5 [Xiangrui Meng] fix pydoc
1cd41f8 [Xiangrui Meng] organize imports
3c18b10 [Xiangrui Meng] add Python API for RFormula
parent 8ca287eb
No related branches found
No related tags found
No related merge requests found
...@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature ...@@ -19,16 +19,14 @@ package org.apache.spark.ml.feature
import scala.collection.mutable import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.annotation.Experimental import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
/** /**
...@@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R ...@@ -63,31 +61,26 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
*/ */
val formula: Param[String] = new Param(this, "formula", "R model formula") val formula: Param[String] = new Param(this, "formula", "R model formula")
private var parsedFormula: Option[ParsedRFormula] = None
/** /**
* Sets the formula to use for this transformer. Must be called before use. * Sets the formula to use for this transformer. Must be called before use.
* @group setParam * @group setParam
* @param value an R formula in string form (e.g. "y ~ x + z") * @param value an R formula in string form (e.g. "y ~ x + z")
*/ */
def setFormula(value: String): this.type = { def setFormula(value: String): this.type = set(formula, value)
parsedFormula = Some(RFormulaParser.parse(value))
set(formula, value)
this
}
/** @group getParam */ /** @group getParam */
def getFormula: String = $(formula) def getFormula: String = $(formula)
/** Whether the formula specifies fitting an intercept. */ /** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = { private[ml] def hasIntercept: Boolean = {
require(parsedFormula.isDefined, "Must call setFormula() first.") require(isDefined(formula), "Formula must be defined first.")
parsedFormula.get.hasIntercept RFormulaParser.parse($(formula)).hasIntercept
} }
override def fit(dataset: DataFrame): RFormulaModel = { override def fit(dataset: DataFrame): RFormulaModel = {
require(parsedFormula.isDefined, "Must call setFormula() first.") require(isDefined(formula), "Formula must be defined first.")
val resolvedFormula = parsedFormula.get.resolve(dataset.schema) val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
// StringType terms and terms representing interactions need to be encoded before assembly. // StringType terms and terms representing interactions need to be encoded before assembly.
// TODO(ekl) add support for feature interactions // TODO(ekl) add support for feature interactions
val encoderStages = ArrayBuffer[PipelineStage]() val encoderStages = ArrayBuffer[PipelineStage]()
......
...@@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc ...@@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc
__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer',
'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel'] 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel']
@inherit_doc @inherit_doc
...@@ -1110,6 +1110,89 @@ class PCAModel(JavaModel): ...@@ -1110,6 +1110,89 @@ class PCAModel(JavaModel):
""" """
@inherit_doc
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
    """
    .. note:: Experimental

    Implements the transforms required for fitting a dataset against an
    R model formula. Currently we support a limited subset of the R
    operators, including '~', '+', '-', and '.'. Also see the R formula
    docs:
    http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html

    >>> df = sqlContext.createDataFrame([
    ...     (1.0, 1.0, "a"),
    ...     (0.0, 2.0, "b"),
    ...     (0.0, 0.0, "a")
    ... ], ["y", "x", "s"])
    >>> rf = RFormula(formula="y ~ x + s")
    >>> rf.fit(df).transform(df).show()
    +---+---+---+---------+-----+
    |  y|  x|  s| features|label|
    +---+---+---+---------+-----+
    |1.0|1.0|  a|[1.0,1.0]|  1.0|
    |0.0|2.0|  b|[2.0,0.0]|  0.0|
    |0.0|0.0|  a|[0.0,1.0]|  0.0|
    +---+---+---+---------+-----+
    ...
    >>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
    +---+---+---+--------+-----+
    |  y|  x|  s|features|label|
    +---+---+---+--------+-----+
    |1.0|1.0|  a|   [1.0]|  1.0|
    |0.0|2.0|  b|   [2.0]|  0.0|
    |0.0|0.0|  a|   [0.0]|  0.0|
    +---+---+---+--------+-----+
    ...
    """

    # a placeholder to make it appear in the generated doc;
    # shadowed by an instance-bound Param in __init__
    formula = Param(Params._dummy(), "formula", "R model formula")

    @keyword_only
    def __init__(self, formula=None, featuresCol="features", labelCol="label"):
        """
        __init__(self, formula=None, featuresCol="features", labelCol="label")
        """
        super(RFormula, self).__init__()
        # Thin wrapper over the Scala implementation: all fitting logic
        # lives in the JVM object created here.
        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
        self.formula = Param(self, "formula", "R model formula")
        # _input_kwargs is captured by the @keyword_only decorator
        kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, formula=None, featuresCol="features", labelCol="label"):
        """
        setParams(self, formula=None, featuresCol="features", labelCol="label")
        Sets params for RFormula.
        """
        kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setFormula(self, value):
        """
        Sets the value of :py:attr:`formula`.
        """
        self._paramMap[self.formula] = value
        return self

    def getFormula(self):
        """
        Gets the value of :py:attr:`formula`.
        """
        return self.getOrDefault(self.formula)

    def _create_model(self, java_model):
        # Wrap the fitted JVM model returned by JavaEstimator.fit().
        return RFormulaModel(java_model)
class RFormulaModel(JavaModel):
    """
    Model fitted by :py:class:`RFormula`.

    A thin wrapper around the JVM-side fitted model; use
    ``transform()`` (inherited from :py:class:`JavaModel`) to append
    the ``features`` and ``label`` columns to a DataFrame.
    """
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
from pyspark.context import SparkContext from pyspark.context import SparkContext
......
0% Loading
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment