Skip to content
Snippets Groups Projects
Commit f47700c9 authored by Wayne Zhang's avatar Wayne Zhang Committed by Yanbo Liang
Browse files

[SPARK-14659][ML] RFormula consistent with R when handling strings

## What changes were proposed in this pull request?
When handling strings, the category dropped by RFormula and R are different:
- RFormula drops the least frequent level
- R drops the first level after ascending alphabetical ordering

This PR supports different string ordering types in StringIndexer #17879 so that RFormula can drop the same level as R when handling strings using`stringOrderType = "alphabetDesc"`.

## How was this patch tested?
new tests

Author: Wayne Zhang <actuaryzhang@uber.com>

Closes #17967 from actuaryzhang/RFormula.
parent 2dbe0c52
No related branches found
No related tags found
No related merge requests found
......@@ -26,7 +26,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap}
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
......@@ -37,6 +37,42 @@ import org.apache.spark.sql.types._
*/
private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol {
/**
* Param for how to order categories of a string FEATURE column used by `StringIndexer`.
* The last category after ordering is dropped when encoding strings.
* Supported options: 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 'alphabetAsc'.
* The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', `RFormula`
* drops the same category as R when encoding strings.
*
* The options are explained using an example `'b', 'a', 'b', 'a', 'c', 'b'`:
* {{{
* +-----------------+---------------------------------------+----------------------------------+
* | Option | Category mapped to 0 by StringIndexer | Category dropped by RFormula |
* +-----------------+---------------------------------------+----------------------------------+
* | 'frequencyDesc' | most frequent category ('b') | least frequent category ('c') |
* | 'frequencyAsc' | least frequent category ('c') | most frequent category ('b') |
* | 'alphabetDesc' | last alphabetical category ('c') | first alphabetical category ('a')|
* | 'alphabetAsc' | first alphabetical category ('a') | last alphabetical category ('c') |
* +-----------------+---------------------------------------+----------------------------------+
* }}}
* Note that this ordering option is NOT used for the label column. When the label column is
* indexed, it uses the default descending frequency ordering in `StringIndexer`.
*
* @group param
*/
@Since("2.3.0")
final val stringIndexerOrderType: Param[String] = new Param(this, "stringIndexerOrderType",
"How to order categories of a string FEATURE column used by StringIndexer. " +
"The last category after ordering is dropped when encoding strings. " +
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}. " +
"The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " +
"RFormula drops the same category as R when encoding strings.",
ParamValidators.inArray(StringIndexer.supportedStringOrderType))
/** @group getParam */
@Since("2.3.0")
def getStringIndexerOrderType: String = $(stringIndexerOrderType)
protected def hasLabelCol(schema: StructType): Boolean = {
schema.map(_.name).contains($(labelCol))
}
......@@ -125,6 +161,11 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
@Since("2.1.0")
def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)
/** @group setParam */
@Since("2.3.0")
def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value)
setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc)
/** Whether the formula specifies fitting an intercept. */
private[ml] def hasIntercept: Boolean = {
require(isDefined(formula), "Formula must be defined first.")
......@@ -155,6 +196,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
encoderStages += new StringIndexer()
.setInputCol(term)
.setOutputCol(indexCol)
.setStringOrderType($(stringIndexerOrderType))
prefixesToRewrite(indexCol + "_") = term + "_"
(term, indexCol)
case _ =>
......
......@@ -47,7 +47,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
* @group param
*/
@Since("1.6.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
"invalid data (unseen labels or NULL values). " +
"Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
"or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
......@@ -73,7 +73,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
*/
@Since("2.3.0")
final val stringOrderType: Param[String] = new Param(this, "stringOrderType",
"how to order labels of string column. " +
"How to order labels of string column. " +
"The first label after ordering is assigned an index of 0. " +
s"Supported options: ${StringIndexer.supportedStringOrderType.mkString(", ")}.",
ParamValidators.inArray(StringIndexer.supportedStringOrderType))
......
......@@ -129,6 +129,90 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(result.collect() === expected.collect())
}
test("encodes string terms with string indexer order type") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5))
.toDF("id", "a", "b")
val expected = Seq(
Seq(
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
(3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
(4, "aaz", 5, Vectors.dense(0.0, 1.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label"),
Seq(
(1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 0.0, 4.0), 2.0),
(3, "bar", 5, Vectors.dense(0.0, 0.0, 5.0), 3.0),
(4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label"),
Seq(
(1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0),
(3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0),
(4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label"),
Seq(
(1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0),
(3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0),
(4, "aaz", 5, Vectors.dense(1.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label")
)
var idx = 0
for (orderType <- StringIndexer.supportedStringOrderType) {
val model = formula.setStringIndexerOrderType(orderType).fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected(idx).collect())
idx += 1
}
}
test("test consistency with R when encoding string terms") {
/*
R code:
df <- data.frame(id = c(1, 2, 3, 4),
a = c("foo", "bar", "bar", "aaz"),
b = c(4, 4, 5, 5))
model.matrix(id ~ a + b, df)[, -1]
abar afoo b
0 1 4
1 0 4
1 0 5
0 0 5
*/
val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "aaz", 5))
.toDF("id", "a", "b")
val formula = new RFormula().setFormula("id ~ a + b")
.setStringIndexerOrderType(StringIndexer.alphabetDesc)
/*
Note that the category dropped after encoding is the same between R and Spark
(i.e., "aaz" is treated as the reference level).
However, the column order is still different:
R renders the columns in ascending alphabetical order ("bar", "foo"), while
RFormula renders the columns in descending alphabetical order ("foo", "bar").
*/
val expected = Seq(
(1, "foo", 4, Vectors.dense(1.0, 0.0, 4.0), 1.0),
(2, "bar", 4, Vectors.dense(0.0, 1.0, 4.0), 2.0),
(3, "bar", 5, Vectors.dense(0.0, 1.0, 5.0), 3.0),
(4, "aaz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
).toDF("id", "a", "b", "features", "label")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
}
test("index string label") {
val formula = new RFormula().setFormula("id ~ a + b")
val original =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment