diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 810b02febbe77734739d06a3c1eb74d19f3fdd2a..99321bcc7cf98831edcbaa4aecc2842a809eb0cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap
 private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
 
   /**
-   * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
-   * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
-   * bucket, at index numLabels.
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+   * bucket, at index numLabels).
    * Default: "error"
    * @group param
    */
   @Since("1.6.0")
   val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
-    "unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
-    "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
-    "at index numLabels).",
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
 
-  setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
 
   /** @group getParam */
   @Since("1.6.0")
@@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") (
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val counts = dataset.select(col($(inputCol)).cast(StringType))
+    val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
       .rdd
       .map(_.getString(0))
       .countByValue()
@@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") (
 
 @Since("1.6.0")
 object StringIndexer extends DefaultParamsReadable[StringIndexer] {
-  private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
-  private[feature] val ERROR_UNSEEN_LABEL: String = "error"
-  private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
   private[feature] val supportedHandleInvalids: Array[String] =
-    Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
 
   @Since("1.6.0")
   override def load(path: String): StringIndexer = super.load(path)
@@ -188,7 +189,7 @@ class StringIndexerModel (
     transformSchema(dataset.schema, logging = true)
 
     val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
+      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
       case _ => labels
     }
 
@@ -196,22 +197,31 @@ class StringIndexerModel (
       .withName($(outputCol)).withValues(filteredLabels).toMetadata()
     // If we are skipping invalid records, filter them out.
     val (filteredDataset, keepInvalid) = getHandleInvalid match {
-      case StringIndexer.SKIP_UNSEEN_LABEL =>
+      case StringIndexer.SKIP_INVALID =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
-        (dataset.where(filterer(dataset($(inputCol)))), false)
-      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
+        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
+      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
     }
 
     val indexer = udf { label: String =>
-      if (labelToIndex.contains(label)) {
-        labelToIndex(label)
-      } else if (keepInvalid) {
-        labels.length
+      if (label == null) {
+        if (keepInvalid) {
+          labels.length
+        } else {
+          throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
+            "NULLS, try setting StringIndexer.handleInvalid.")
+        }
       } else {
-        throw new SparkException(s"Unseen label: $label.  To handle unseen labels, " +
-          s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
+        if (labelToIndex.contains(label)) {
+          labelToIndex(label)
+        } else if (keepInvalid) {
+          labels.length
+        } else {
+          throw new SparkException(s"Unseen label: $label.  To handle unseen labels, " +
+            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
+        }
       }
     }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 188dffb3dd55ffd4c194e8358678e8361c03d682..8d9042b31e03322277e71aa361c66620b9b84fb4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -122,6 +122,51 @@ class StringIndexerSuite
     assert(output === expected)
   }
 
+  test("StringIndexer with NULLs") {
+    val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
+    val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
+    val df = data.toDF("id", "label")
+    val df2 = data2.toDF("id", "label")
+
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+
+    withClue("StringIndexer should throw error when setHandleInvalid=error " +
+      "when given NULL values") {
+      intercept[SparkException] {
+        indexer.setHandleInvalid("error")
+        indexer.fit(df).transform(df2).collect()
+      }
+    }
+
+    indexer.setHandleInvalid("skip")
+    val transformedSkip = indexer.fit(df).transform(df2)
+    val attrSkip = Attribute
+      .fromStructField(transformedSkip.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attrSkip.values.get === Array("b", "a"))
+    val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0
+    val expectedSkip = Set((0, 1.0), (1, 0.0))
+    assert(outputSkip === expectedSkip)
+
+    indexer.setHandleInvalid("keep")
+    val transformedKeep = indexer.fit(df).transform(df2)
+    val attrKeep = Attribute
+      .fromStructField(transformedKeep.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+    val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0, null -> 2
+    val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
+    assert(outputKeep === expectedKeep)
+  }
+
   test("StringIndexerModel should keep silent if the input column does not exist.") {
     val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
       .setInputCol("label")