Skip to content
Snippets Groups Projects
Commit f54ff19b authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[SPARK-11349][ML] Support transform string label for RFormula

Currently ```RFormula``` can only handle label with ```NumericType``` or ```BinaryType``` (cast it to ```DoubleType``` as the label of Linear Regression training), we should also support label of ```StringType``` which is needed for Logistic Regression (glm with family = "binomial").
For label of ```StringType```, we should use ```StringIndexer``` to transform it to 0-based index.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9302 from yanboliang/spark-11349.
parent 3434572b
No related branches found
No related tags found
No related merge requests found
......@@ -132,6 +132,14 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
.setOutputCol($(featuresCol))
encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)
encoderStages += new ColumnPruner(tempColumns.toSet)
if (dataset.schema.fieldNames.contains(resolvedFormula.label) &&
dataset.schema(resolvedFormula.label).dataType == StringType) {
encoderStages += new StringIndexer()
.setInputCol(resolvedFormula.label)
.setOutputCol($(labelCol))
}
val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
}
......@@ -172,7 +180,7 @@ class RFormulaModel private[feature](
override def transformSchema(schema: StructType): StructType = {
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(schema)) {
if (hasLabelCol(withFeatures)) {
withFeatures
} else if (schema.exists(_.name == resolvedFormula.label)) {
val nullable = schema(resolvedFormula.label).dataType match {
......
......@@ -107,6 +107,25 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(result.collect() === expected.collect())
}
test("index string label") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(
Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
).toDF("id", "a", "b")
val model = formula.fit(original)
val result = model.transform(original)
val resultSchema = model.transformSchema(original.schema)
val expected = sqlContext.createDataFrame(
Seq(
("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0))
).toDF("id", "a", "b", "features", "label")
// assert(result.schema.toString == resultSchema.toString)
assert(result.collect() === expected.collect())
}
test("attribute generation") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment