Commit 8e0a072f authored by Eric Liang, committed by Xiangrui Meng

[SPARK-9895] User Guide for RFormula Feature Transformer

mengxr

Author: Eric Liang <ekl@databricks.com>

Closes #8293 from ericl/docs-2.
parent b0dbaec4
@@ -1477,3 +1477,111 @@ print(output.select("features", "clicked").first())
</div> </div>
</div> </div>
## RFormula
`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. As when formulas are used in R for linear regression, string input columns are one-hot encoded and numeric columns are cast to doubles. If the label column is not already present in the DataFrame, it is created from the response variable specified in the formula.
**Examples**
Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`:
~~~
id | country | hour | clicked
---|---------|------|---------
7 | "US" | 18 | 1.0
8 | "CA" | 12 | 0.0
9 | "NZ" | 15 | 0.0
~~~
If we use `RFormula` with a formula string of `clicked ~ country + hour`, which indicates that we want to
predict `clicked` based on `country` and `hour`, after transformation we should get the following DataFrame:
~~~
id | country | hour | clicked | features | label
---|---------|------|---------|------------------|-------
7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0
8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0
9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0
~~~
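To make the table above concrete, here is a minimal plain-Python sketch of how such a `features` vector could be assembled by hand. It is not Spark's implementation. The category order `NZ`, `CA`, `US` and the choice of the last category as the all-zeros reference level are assumptions picked only so that the output matches the table; the name `encode_row` is hypothetical.

```python
# Illustrative only: plain Python, not Spark's implementation.
# ASSUMPTION: the category order below is chosen to match the table above;
# the last category ("US") is the reference level and maps to all zeros.
CATEGORY_ORDER = ["NZ", "CA", "US"]  # hypothetical ordering


def encode_row(country, hour, categories=CATEGORY_ORDER):
    """Return the feature vector for one row of `clicked ~ country + hour`."""
    # One indicator slot per category except the last (reference) one.
    one_hot = [1.0 if country == c else 0.0 for c in categories[:-1]]
    # Numeric columns are cast to double and appended after the indicators.
    return one_hot + [float(hour)]


rows = [(7, "US", 18, 1.0), (8, "CA", 12, 0.0), (9, "NZ", 15, 0.0)]
# Pair each feature vector with its label (the `clicked` value as a double).
encoded = [(encode_row(country, hour), float(clicked))
           for _, country, hour, clicked in rows]

for features, label in encoded:
    print(features, label)
```

With three countries, two indicator slots are enough: a row whose indicators are all zero belongs to the reference category.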
<div class="codetabs">
<div data-lang="scala" markdown="1">
[`RFormula`](api/scala/index.html#org.apache.spark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns.
{% highlight scala %}
import org.apache.spark.ml.feature.RFormula

// `sqlContext` is assumed to be an existing SQLContext.
val dataset = sqlContext.createDataFrame(Seq(
  (7, "US", 18, 1.0),
  (8, "CA", 12, 0.0),
  (9, "NZ", 15, 0.0)
)).toDF("id", "country", "hour", "clicked")
val formula = new RFormula()
  .setFormula("clicked ~ country + hour")
  .setFeaturesCol("features")
  .setLabelCol("label")
val output = formula.fit(dataset).transform(dataset)
output.select("features", "label").show()
{% endhighlight %}
</div>
<div data-lang="java" markdown="1">
[`RFormula`](api/java/org/apache/spark/ml/feature/RFormula.html) takes an R formula string, and optional parameters for the names of its output columns.
{% highlight java %}
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.RFormula;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.*;
import static org.apache.spark.sql.types.DataTypes.*;

// `jsc` and `sqlContext` are assumed to be an existing JavaSparkContext and SQLContext.
StructType schema = createStructType(new StructField[] {
  createStructField("id", IntegerType, false),
  createStructField("country", StringType, false),
  createStructField("hour", IntegerType, false),
  createStructField("clicked", DoubleType, false)
});
JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
  RowFactory.create(7, "US", 18, 1.0),
  RowFactory.create(8, "CA", 12, 0.0),
  RowFactory.create(9, "NZ", 15, 0.0)
));
DataFrame dataset = sqlContext.createDataFrame(rdd, schema);

RFormula formula = new RFormula()
  .setFormula("clicked ~ country + hour")
  .setFeaturesCol("features")
  .setLabelCol("label");

DataFrame output = formula.fit(dataset).transform(dataset);
output.select("features", "label").show();
{% endhighlight %}
</div>
<div data-lang="python" markdown="1">
[`RFormula`](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns.
{% highlight python %}
from pyspark.ml.feature import RFormula

# `sqlContext` is assumed to be an existing SQLContext.
dataset = sqlContext.createDataFrame(
    [(7, "US", 18, 1.0),
     (8, "CA", 12, 0.0),
     (9, "NZ", 15, 0.0)],
    ["id", "country", "hour", "clicked"])
formula = RFormula(
    formula="clicked ~ country + hour",
    featuresCol="features",
    labelCol="label")
output = formula.fit(dataset).transform(dataset)
output.select("features", "label").show()
{% endhighlight %}
</div>
</div>
@@ -42,8 +42,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
 /**
  * :: Experimental ::
  * Implements the transforms required for fitting a dataset against an R model formula. Currently
- * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula
- * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
+ * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the
+ * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html
  */
 @Experimental
 class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase {
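
The four operators named in the Scaladoc of this hunk can be illustrated with a small plain-Python sketch of how a formula might be resolved into a label and a list of feature columns. It follows the usual R reading (`~` separates response from terms, `+` adds a term, `-` removes one, `.` stands for every column except the response) and is an assumption-laden illustration, not Spark's parser. The function name `resolve_formula` is hypothetical, and column names containing `+`, `-`, or `~` are not handled.

```python
def resolve_formula(formula, columns):
    """Return (label, feature_columns) for a formula using only '~', '+', '-', '.'.

    Illustrative only; this is not how Spark parses formulas.
    """
    # '~' separates the response (label) from the terms.
    label, rhs = (part.strip() for part in formula.split("~"))
    # Split the right-hand side into operators and term names.
    tokens = rhs.replace("+", " + ").replace("-", " - ").split()

    features = []
    sign = "+"  # the first term is implicitly added
    for tok in tokens:
        if tok in ("+", "-"):
            sign = tok
            continue
        # '.' expands to every column except the label.
        expanded = [c for c in columns if c != label] if tok == "." else [tok]
        for col in expanded:
            if sign == "+" and col not in features:
                features.append(col)
            elif sign == "-" and col in features:
                features.remove(col)
    return label, features


columns = ["id", "country", "hour", "clicked"]
print(resolve_formula("clicked ~ country + hour", columns))
print(resolve_formula("clicked ~ . - id", columns))
```

Under these assumptions both formulas select the same feature columns, `country` and `hour`, because `. - id` starts from every non-label column and then drops `id`.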