diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index ebfa972532358afa06fb00aef2434c1fa6d6a39a..e4485eb038409674020780ad43faa1e649f5ff4d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap
 /**
  * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
  */
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
+    with HasHandleInvalid {
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -65,13 +66,15 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
 
   def this() = this(Identifiable.randomUID("strIdx"))
 
+  /** @group setParam */
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, "error")
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
   /** @group setParam */
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  // TODO: handle unseen labels
-
   override def fit(dataset: DataFrame): StringIndexerModel = {
     val counts = dataset.select(col($(inputCol)).cast(StringType))
@@ -111,6 +115,10 @@ class StringIndexerModel private[ml] (
     map
   }
 
+  /** @group setParam */
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, "error")
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -128,14 +136,23 @@ class StringIndexerModel private[ml] (
       if (labelToIndex.contains(label)) {
         labelToIndex(label)
       } else {
-        // TODO: handle unseen labels
         throw new SparkException(s"Unseen label: $label.")
       }
     }
+
     val outputColName = $(outputCol)
     val metadata = NominalAttribute.defaultAttr
       .withName(outputColName).withValues(labels).toMetadata()
-    dataset.select(col("*"),
+    // If we are skipping invalid records, filter them out.
+    val filteredDataset = getHandleInvalid match {
+      case "skip" =>
+        val filterer = udf { label: String =>
+          labelToIndex.contains(label)
+        }
+        dataset.where(filterer(dataset($(inputCol))))
+      case _ => dataset
+    }
+    filteredDataset.select(col("*"),
       indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
   }
 
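For readers skimming the hunk above: the "skip" path is just a user-defined filter over the labels captured at fit time. Below is a standalone sketch of the same technique, with a hypothetical labelToIndex map and column name standing in for the model's state; it is an illustration, not part of the patch.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.udf

    // Hypothetical stand-ins for the fitted model's state
    val labelToIndex: Map[String, Double] = Map("b" -> 0.0, "a" -> 1.0)
    val inputCol = "label"

    // Keep only rows whose label was seen during fit; unseen labels are
    // dropped, mirroring handleInvalid = "skip"
    def dropUnseen(df: DataFrame): DataFrame = {
      val seen = udf { label: String => labelToIndex.contains(label) }
      df.where(seen(df(inputCol)))
    }
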
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index a97c8059b8d458423fff3497216bf279a99183a3..da4c076830391bd29b2c11b6de39d0f6359e42e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -59,6 +59,10 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
         isValid = "ParamValidators.gtEq(1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
+      ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
+        "will filter out rows with bad values), or error (which will throw an errror). More " +
+        "options may be added later.",
+        isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
       ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
         " before fitting the model.", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index f332630c32f1bcbf49c437e6e906eb43de41eb59..23e2b6cc439966f3fca6278aaa9035c576cdbcac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -247,6 +247,21 @@ private[ml] trait HasFitIntercept extends Params {
   final def getFitIntercept: Boolean = $(fitIntercept)
 }
 
+/**
+ * Trait for shared param handleInvalid.
+ */
+private[ml] trait HasHandleInvalid extends Params {
+
+  /**
+   * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
+   * @group param
+   */
+  final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
+
+  /** @group getParam */
+  final def getHandleInvalid: String = $(handleInvalid)
+}
+
 /**
  * Trait for shared param standardization (default: true).
  */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d0295a0fe2fc15162ce000c2f6628093a2d482e4..b111036087e6ac928daeb597a86a2440c699acb9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
@@ -62,6 +63,37 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
       reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
   }
 
+  test("StringIndexerUnseen") {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
+    val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    // Verify that the default ("error") setting throws on unseen labels
+    intercept[SparkException] {
+      indexer.transform(df2).collect()
+    }
+    val indexerSkipInvalid = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .setHandleInvalid("skip")
+      .fit(df)
+    // Verify that we skip the c record
+    val transformed = indexerSkipInvalid.transform(df2)
+    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr.values.get === Array("b", "a"))
+    val output = transformed.select("id", "labelIndex").map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0
+    val expected = Set((0, 1.0), (1, 0.0))
+    assert(output === expected)
+  }
+
   test("StringIndexer with a numeric input column") {
     val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
     val df = sqlContext.createDataFrame(data).toDF("id", "label")
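A minimal end-to-end sketch of the new parameter, mirroring the test above. It assumes a spark-shell style sqlContext; the data and column names are illustrative.

    import org.apache.spark.ml.feature.StringIndexer

    val train = sqlContext.createDataFrame(Seq((0, "a"), (1, "b"), (2, "b"))).toDF("id", "label")
    val unseen = sqlContext.createDataFrame(Seq((0, "a"), (1, "c"))).toDF("id", "label")

    val model = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .setHandleInvalid("skip") // the default, "error", throws a SparkException on unseen labels
      .fit(train)

    // The "c" row is filtered out before indexing rather than raising an error
    model.transform(unseen).show()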