From c19680be1c532dded1e70edce7a981ba28af09ad Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Sun, 2 Jul 2017 16:17:03 +0800
Subject: [PATCH] [SPARK-19852][PYSPARK][ML] Python StringIndexer supports
 'keep' to handle invalid data

## What changes were proposed in this pull request?
This PR maintains API parity with the changes made in SPARK-17498 by supporting the new option
'keep' for handleInvalid in the PySpark StringIndexer, which handles unseen labels or NULL values.
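
For example, with handleInvalid="keep", unseen labels and NULL values are mapped to index numLabels instead of failing the transform. A minimal usage sketch (assumes an active SparkSession bound to `spark`; the data and column names are illustrative, not from this patch):

```python
from pyspark.ml.feature import StringIndexer

df = spark.createDataFrame([(0, "a"), (1, "b"), (2, None)], ["id", "label"])
indexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep")
indexer.fit(df).transform(df).show()
# The NULL row is kept and indexed into the extra bucket at numLabels (2.0 here),
# whereas the default handleInvalid="error" would raise an error on invalid data.
```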

Note: This is an updated version of #17237; the primary author of this PR is VinceShieh.

## How was this patch tested?
New unit tests (test_string_indexer_handle_invalid in python/pyspark/ml/tests.py).
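
These can be run with Spark's Python test runner; one plausible invocation (exact flags may vary by branch):

```bash
./python/run-tests --modules=pyspark-ml
```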

Author: VinceShieh <vincent.xie@intel.com>
Author: Yanbo Liang <ybliang8@gmail.com>

Closes #18453 from yanboliang/spark-19852.
---
 python/pyspark/ml/feature.py |  6 ++++++
 python/pyspark/ml/tests.py   | 21 +++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 77de1cc182..25ad06f682 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                             "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
                             typeConverter=TypeConverters.toString)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+                          "labels or NULL values). Options are 'skip' (filter out rows with " +
+                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "in a special additional bucket, at index numLabels).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
                  stringOrderType="frequencyDesc"):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 17a39472e1..ffb8b0a890 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -551,6 +551,27 @@ class FeatureTests(SparkSessionTestCase):
         for i in range(0, len(expected)):
             self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
 
+    def test_string_indexer_handle_invalid(self):
+        df = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "d"),
+            (2, None)], ["id", "label"])
+
+        si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
+                            stringOrderType="alphabetAsc")
+        model1 = si1.fit(df)
+        td1 = model1.transform(df)
+        actual1 = td1.select("id", "indexed").collect()
+        expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
+        self.assertEqual(actual1, expected1)
+
+        si2 = si1.setHandleInvalid("skip")
+        model2 = si2.fit(df)
+        td2 = model2.transform(df)
+        actual2 = td2.select("id", "indexed").collect()
+        expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
+        self.assertEqual(actual2, expected2)
+
 
 class HasInducedError(Params):
 
-- 
GitLab