diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 77de1cc18246dd56c184ad07a6b89f0498ff51ae..25ad06f682ed93690a4612c6fe23866da27a0f50 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                             "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
                             typeConverter=TypeConverters.toString)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+                          "labels or NULL values). Options are 'skip' (filter out rows with " +
+                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "in a special additional bucket, at index numLabels).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
                  stringOrderType="frequencyDesc"):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 17a39472e1fe511dbbbc374b4cc4ad190f37fcfb..ffb8b0a890ff806add00ba6ddab579f44f913f54 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -551,6 +551,27 @@ class FeatureTests(SparkSessionTestCase):
         for i in range(0, len(expected)):
             self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
 
+    def test_string_indexer_handle_invalid(self):
+        df = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "d"),
+            (2, None)], ["id", "label"])
+
+        si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
+                            stringOrderType="alphabetAsc")
+        model1 = si1.fit(df)
+        td1 = model1.transform(df)
+        actual1 = td1.select("id", "indexed").collect()
+        expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
+        self.assertEqual(actual1, expected1)
+
+        si2 = si1.setHandleInvalid("skip")
+        model2 = si2.fit(df)
+        td2 = model2.transform(df)
+        actual2 = td2.select("id", "indexed").collect()
+        expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
+        self.assertEqual(actual2, expected2)
+
 
 class HasInducedError(Params):