From 0f2f56c37b8d09eec2722a5ffba3015d7f3b626f Mon Sep 17 00:00:00 2001
From: Wayne Zhang <actuaryzhang@uber.com>
Date: Sun, 21 May 2017 16:51:55 -0700
Subject: [PATCH] [SPARK-20736][PYTHON] PySpark StringIndexer supports
 StringOrderType

## What changes were proposed in this pull request?
PySpark StringIndexer supports StringOrderType added in #17879.

Author: Wayne Zhang <actuaryzhang@uber.com>

Closes #17978 from actuaryzhang/PythonStringIndexer.
---
 python/pyspark/ml/feature.py | 51 ++++++++++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 8d25f5b3a7..955bc9768c 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2082,10 +2082,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
     """
     A label indexer that maps a string column of labels to an ML column of label indices.
     If the input column is numeric, we cast it to string and index the string values.
-    The indices are in [0, numLabels), ordered by label frequencies.
-    So the most frequent label gets index 0.
+    The indices are in [0, numLabels). By default, this is ordered by label frequencies
+    so the most frequent label gets index 0. The ordering behavior is controlled by
+    setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'.
 
-    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid='error')
+    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
+    ...     stringOrderType="frequencyDesc")
     >>> model = stringIndexer.fit(stringIndDf)
     >>> td = model.transform(stringIndDf)
     >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
@@ -2111,26 +2113,45 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
     >>> loadedInverter = IndexToString.load(indexToStringPath)
     >>> loadedInverter.getLabels() == inverter.getLabels()
     True
+    >>> stringIndexer.getStringOrderType()
+    'frequencyDesc'
+    >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error",
+    ...     stringOrderType="alphabetDesc")
+    >>> model = stringIndexer.fit(stringIndDf)
+    >>> td = model.transform(stringIndDf)
+    >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]),
+    ...     key=lambda x: x[0])
+    [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
 
     .. versionadded:: 1.4.0
     """
 
+    stringOrderType = Param(Params._dummy(), "stringOrderType",
+                            "How to order labels of string column. The first label after " +
+                            "ordering is assigned an index of 0. Supported options: " +
+                            "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
+                            typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
+    def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
+                 stringOrderType="frequencyDesc"):
         """
-        __init__(self, inputCol=None, outputCol=None, handleInvalid="error")
+        __init__(self, inputCol=None, outputCol=None, handleInvalid="error", \
+                 stringOrderType="frequencyDesc")
         """
         super(StringIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
-        self._setDefault(handleInvalid="error")
+        self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
+    def setParams(self, inputCol=None, outputCol=None, handleInvalid="error",
+                  stringOrderType="frequencyDesc"):
         """
-        setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
+        setParams(self, inputCol=None, outputCol=None, handleInvalid="error", \
+                  stringOrderType="frequencyDesc")
         Sets params for this StringIndexer.
         """
         kwargs = self._input_kwargs
@@ -2139,6 +2160,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
     def _create_model(self, java_model):
         return StringIndexerModel(java_model)
 
+    @since("2.3.0")
+    def setStringOrderType(self, value):
+        """
+        Sets the value of :py:attr:`stringOrderType`.
+        """
+        return self._set(stringOrderType=value)
+
+    @since("2.3.0")
+    def getStringOrderType(self):
+        """
+        Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'.
+        """
+        return self.getOrDefault(self.stringOrderType)
+
 
 class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
     """
-- 
GitLab