From 54d13bed87fcf2968f77e1f1153e85184ec91d78 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Fri, 25 Mar 2016 16:00:09 -0700
Subject: [PATCH] [SPARK-14159][ML] Fixed bug in StringIndexer + related issue
 in RFormula

## What changes were proposed in this pull request?

StringIndexerModel.transform sets the output column metadata to use name inputCol.  It should not.  Fixing this causes a problem with the metadata produced by RFormula.

Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I modified VectorAttributeRewriter to find and replace all "prefixes" since attributes collect multiple prefixes from StringIndexer + Interaction.

Note that "prefixes" is no longer accurate since internal strings may be replaced.

## How was this patch tested?

Unit test which failed before this fix.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11965 from jkbradley/StringIndexer-fix.
---
 .../org/apache/spark/ml/feature/RFormula.scala    | 15 ++++++---------
 .../apache/spark/ml/feature/StringIndexer.scala   |  7 +++----
 .../spark/ml/feature/StringIndexerSuite.scala     | 13 +++++++++++++
 3 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index e7ca7ada74..12a76dbbfb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -125,6 +125,7 @@ class RFormula(override val uid: String)
           encoderStages += new StringIndexer()
             .setInputCol(term)
             .setOutputCol(indexCol)
+          prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
           (term, term)
@@ -229,7 +230,7 @@ class RFormulaModel private[feature](
   override def copy(extra: ParamMap): RFormulaModel = copyValues(
     new RFormulaModel(uid, resolvedFormula, pipelineModel))
 
-  override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)"
+  override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
 
   private def transformLabel(dataset: DataFrame): DataFrame = {
     val labelName = resolvedFormula.label
@@ -400,14 +401,10 @@ private class VectorAttributeRewriter(
       val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
       val attrs = group.attributes.get.map { attr =>
         if (attr.name.isDefined) {
-          val name = attr.name.get
-          val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) }
-          if (replacement.nonEmpty) {
-            val (k, v) = replacement.headOption.get
-            attr.withName(v + name.stripPrefix(k))
-          } else {
-            attr
+          val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) =>
+            curName.replace(from, to)
           }
+          attr.withName(name)
         } else {
           attr
         }
@@ -416,7 +413,7 @@ private class VectorAttributeRewriter(
     }
     val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
     val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
-    dataset.select((otherCols :+ rewrittenCol): _*)
+    dataset.select(otherCols :+ rewrittenCol : _*)
   }
 
   override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index c579a0d68e..faa0f6f407 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -161,15 +161,14 @@ class StringIndexerModel (
     }
 
     val metadata = NominalAttribute.defaultAttr
-      .withName($(inputCol)).withValues(labels).toMetadata()
+      .withName($(outputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val filteredDataset = (getHandleInvalid) match {
-      case "skip" => {
+    val filteredDataset = getHandleInvalid match {
+      case "skip" =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
         dataset.where(filterer(dataset($(inputCol))))
-      }
       case _ => dataset
     }
     filteredDataset.select(col("*"),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d40e69dced..2c3255ef33 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -210,4 +210,17 @@ class StringIndexerSuite
       .setLabels(Array("a", "b", "c"))
     testDefaultReadWrite(t)
   }
+
+  test("StringIndexer metadata") {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val attrs =
+      NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
+    assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+  }
 }
-- 
GitLab