diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index e7ca7ada74c8cade37adfa90e26c3e1fdf9529c1..12a76dbbfb4b30338c96db0ecf8a52558c741b6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -125,6 +125,7 @@ class RFormula(override val uid: String)
           encoderStages += new StringIndexer()
             .setInputCol(term)
             .setOutputCol(indexCol)
+          prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
           (term, term)
@@ -229,7 +230,7 @@ class RFormulaModel private[feature](
   override def copy(extra: ParamMap): RFormulaModel = copyValues(
     new RFormulaModel(uid, resolvedFormula, pipelineModel))
 
-  override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)"
+  override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
 
   private def transformLabel(dataset: DataFrame): DataFrame = {
     val labelName = resolvedFormula.label
@@ -400,14 +401,10 @@ private class VectorAttributeRewriter(
       val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
       val attrs = group.attributes.get.map { attr =>
         if (attr.name.isDefined) {
-          val name = attr.name.get
-          val replacement = prefixesToRewrite.filter { case (k, _) => name.startsWith(k) }
-          if (replacement.nonEmpty) {
-            val (k, v) = replacement.headOption.get
-            attr.withName(v + name.stripPrefix(k))
-          } else {
-            attr
+          val name = prefixesToRewrite.foldLeft(attr.name.get) { case (curName, (from, to)) =>
+            curName.replace(from, to)
           }
+          attr.withName(name)
         } else {
           attr
         }
@@ -416,7 +413,7 @@ private class VectorAttributeRewriter(
     }
     val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
     val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
-    dataset.select((otherCols :+ rewrittenCol): _*)
+    dataset.select(otherCols :+ rewrittenCol : _*)
   }
 
   override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index c579a0d68ec6ce92c67a43436c93cb889d9231a6..faa0f6f407b349cae6d8723d689b9fea44a2df35 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -161,15 +161,14 @@ class StringIndexerModel (
     }
 
     val metadata = NominalAttribute.defaultAttr
-      .withName($(inputCol)).withValues(labels).toMetadata()
+      .withName($(outputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val filteredDataset = (getHandleInvalid) match {
-      case "skip" => {
+    val filteredDataset = getHandleInvalid match {
+      case "skip" =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
         dataset.where(filterer(dataset($(inputCol))))
-      }
       case _ => dataset
     }
     filteredDataset.select(col("*"),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d40e69dceda959c16083c52aa6d03c93be85ee4e..2c3255ef3336c4056b50c340ccc08300724011fe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -210,4 +210,17 @@ class StringIndexerSuite
       .setLabels(Array("a", "b", "c"))
     testDefaultReadWrite(t)
   }
+
+  test("StringIndexer metadata") {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val attrs =
+      NominalAttribute.decodeStructField(transformed.schema("labelIndex"), preserveName = true)
+    assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+  }
 }