Skip to content
Snippets Groups Projects
Commit 0c8444cf authored by Yanbo Liang's avatar Yanbo Liang
Browse files

[SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference...

[SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference category when encoding string terms

## What changes were proposed in this pull request?

Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for detail of this bug.
I searched online and test some other cases, found when we fit R glm model(or other models powered by R formula) w/o intercept on a dataset including string/category features, one of the categories in the first category feature is being used as reference category, we will not drop any category for that feature.
I think we should keep consistent semantics between Spark RFormula and R formula.
## How was this patch tested?

Add standard unit tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12414 from yanboliang/spark-14657.
parent 376d90d5
No related branches found
No related tags found
No related merge requests found
......@@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
}.toMap
// Then we handle one-hot encoding and interactions between terms.
var keepReferenceCategory = false
val encodedTerms = resolvedFormula.terms.map {
case Seq(term) if dataset.schema(term).dataType == StringType =>
val encodedCol = tmpColumn("onehot")
encoderStages += new OneHotEncoder()
var encoder = new OneHotEncoder()
.setInputCol(indexed(term))
.setOutputCol(encodedCol)
// Formula w/o intercept, one of the categories in the first category feature is
// being used as reference category, we will not drop any category for that feature.
if (!hasIntercept && !keepReferenceCategory) {
encoder = encoder.setDropLast(false)
keepReferenceCategory = true
}
encoderStages += encoder
prefixesToRewrite(encodedCol + "_") = term + "_"
encodedCol
case Seq(term) =>
......
......@@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(result.collect() === expected.collect())
}
test("formula w/o intercept, we should output reference category when encoding string terms") {
/*
R code:
df <- data.frame(id = c(1, 2, 3, 4),
a = c("foo", "bar", "bar", "baz"),
b = c("zq", "zz", "zz", "zz"),
c = c(4, 4, 5, 5))
model.matrix(id ~ a + b + c - 1, df)
abar abaz afoo bzz c
1 0 0 1 0 4
2 1 0 0 1 4
3 1 0 0 1 5
4 0 1 0 1 5
model.matrix(id ~ a:b + c - 1, df)
c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz
1 4 0 0 1 0 0 0
2 4 0 0 0 1 0 0
3 5 0 0 0 1 0 0
4 5 0 0 0 0 1 0
*/
val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", "zz", 5),
(4, "baz", "zz", 5)).toDF("id", "a", "b", "c")
val formula1 = new RFormula().setFormula("id ~ a + b + c - 1")
.setStringIndexerOrderType(StringIndexer.alphabetDesc)
val model1 = formula1.fit(original)
val result1 = model1.transform(original)
val resultSchema1 = model1.transformSchema(original.schema)
// Note the column order is different between R and Spark.
val expected1 = Seq(
(1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0),
(2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0),
(3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0),
(4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
assert(result1.schema.toString == resultSchema1.toString)
assert(result1.collect() === expected1.collect())
val attrs1 = AttributeGroup.fromStructField(result1.schema("features"))
val expectedAttrs1 = new AttributeGroup(
"features",
Array[Attribute](
new BinaryAttribute(Some("a_foo"), Some(1)),
new BinaryAttribute(Some("a_baz"), Some(2)),
new BinaryAttribute(Some("a_bar"), Some(3)),
new BinaryAttribute(Some("b_zz"), Some(4)),
new NumericAttribute(Some("c"), Some(5))))
assert(attrs1 === expectedAttrs1)
// There is no impact for string terms interaction.
val formula2 = new RFormula().setFormula("id ~ a:b + c - 1")
.setStringIndexerOrderType(StringIndexer.alphabetDesc)
val model2 = formula2.fit(original)
val result2 = model2.transform(original)
val resultSchema2 = model2.transformSchema(original.schema)
// Note the column order is different between R and Spark.
val expected2 = Seq(
(1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0),
(2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 2.0),
(3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0),
(4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0)
).toDF("id", "a", "b", "c", "features", "label")
assert(result2.schema.toString == resultSchema2.toString)
assert(result2.collect() === expected2.collect())
val attrs2 = AttributeGroup.fromStructField(result2.schema("features"))
val expectedAttrs2 = new AttributeGroup(
"features",
Array[Attribute](
new NumericAttribute(Some("a_foo:b_zz"), Some(1)),
new NumericAttribute(Some("a_foo:b_zq"), Some(2)),
new NumericAttribute(Some("a_baz:b_zz"), Some(3)),
new NumericAttribute(Some("a_baz:b_zq"), Some(4)),
new NumericAttribute(Some("a_bar:b_zz"), Some(5)),
new NumericAttribute(Some("a_bar:b_zq"), Some(6)),
new NumericAttribute(Some("c"), Some(7))))
assert(attrs2 === expectedAttrs2)
}
test("index string label") {
val formula = new RFormula().setFormula("id ~ a + b")
val original =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment