From c19680be1c532dded1e70edce7a981ba28af09ad Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Sun, 2 Jul 2017 16:17:03 +0800
Subject: [PATCH] [SPARK-19852][PYSPARK][ML] Python StringIndexer supports
 'keep' to handle invalid data

## What changes were proposed in this pull request?

This PR maintains API parity with the changes made in SPARK-17498, exposing the
new 'keep' option of StringIndexer, which handles unseen labels or NULL values,
in PySpark.

Note: This is an updated version of #17237; the primary author of this PR is
VinceShieh.

## How was this patch tested?

Unit tests.

Author: VinceShieh <vincent.xie@intel.com>
Author: Yanbo Liang <ybliang8@gmail.com>

Closes #18453 from yanboliang/spark-19852.
---
 python/pyspark/ml/feature.py |  6 ++++++
 python/pyspark/ml/tests.py   | 21 +++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 77de1cc182..25ad06f682 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                             "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
                             typeConverter=TypeConverters.toString)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
+                          "labels or NULL values). Options are 'skip' (filter out rows with " +
+                          "invalid data), 'error' (throw an error), or 'keep' (put invalid data " +
+                          "in a special additional bucket, at index numLabels).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
                  stringOrderType="frequencyDesc"):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 17a39472e1..ffb8b0a890 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -551,6 +551,27 @@ class FeatureTests(SparkSessionTestCase):
         for i in range(0, len(expected)):
             self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
 
+    def test_string_indexer_handle_invalid(self):
+        df = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "d"),
+            (2, None)], ["id", "label"])
+
+        si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
+                            stringOrderType="alphabetAsc")
+        model1 = si1.fit(df)
+        td1 = model1.transform(df)
+        actual1 = td1.select("id", "indexed").collect()
+        expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
+        self.assertEqual(actual1, expected1)
+
+        si2 = si1.setHandleInvalid("skip")
+        model2 = si2.fit(df)
+        td2 = model2.transform(df)
+        actual2 = td2.select("id", "indexed").collect()
+        expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
+        self.assertEqual(actual2, expected2)
+
 
 class HasInducedError(Params):
--
GitLab
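
For illustration, a minimal standalone usage sketch of the behavior this patch
enables (not part of the patch itself; assumes a Spark build with this change
applied and an active SparkSession bound to `spark`):

```python
from pyspark.ml.feature import StringIndexer

# Fit on data whose labels are a, b, c; numLabels == 3.
train = spark.createDataFrame([(0, "a"), (1, "b"), (2, "c")], ["id", "label"])
# Score data containing an unseen label ('d') and a NULL value.
test = spark.createDataFrame([(3, "a"), (4, "d"), (5, None)], ["id", "label"])

indexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep")
model = indexer.fit(train)

# With 'keep', the unseen label 'd' and the NULL row are both mapped to the
# extra bucket at index numLabels (here 3.0), instead of raising an error
# ('error', the default) or being filtered out ('skip').
model.transform(test).show()
```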