From 0f2f56c37b8d09eec2722a5ffba3015d7f3b626f Mon Sep 17 00:00:00 2001 From: Wayne Zhang <actuaryzhang@uber.com> Date: Sun, 21 May 2017 16:51:55 -0700 Subject: [PATCH] [SPARK-20736][PYTHON] PySpark StringIndexer supports StringOrderType ## What changes were proposed in this pull request? PySpark StringIndexer supports StringOrderType added in #17879. Author: Wayne Zhang <actuaryzhang@uber.com> Closes #17978 from actuaryzhang/PythonStringIndexer. --- python/pyspark/ml/feature.py | 51 ++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 8d25f5b3a7..955bc9768c 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2082,10 +2082,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, """ A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. - The indices are in [0, numLabels), ordered by label frequencies. - So the most frequent label gets index 0. + The indices are in [0, numLabels). By default, this is ordered by label frequencies + so the most frequent label gets index 0. The ordering behavior is controlled by + setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'. - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid='error') + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error", + ... stringOrderType="frequencyDesc") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), @@ -2111,26 +2113,45 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> loadedInverter = IndexToString.load(indexToStringPath) >>> loadedInverter.getLabels() == inverter.getLabels() True + >>> stringIndexer.getStringOrderType() + 'frequencyDesc' + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error", + ... stringOrderType="alphabetDesc") + >>> model = stringIndexer.fit(stringIndDf) + >>> td = model.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] .. versionadded:: 1.4.0 """ + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): + def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", + stringOrderType="frequencyDesc"): """ - __init__(self, inputCol=None, outputCol=None, handleInvalid="error") + __init__(self, inputCol=None, outputCol=None, handleInvalid="error", \ + stringOrderType="frequencyDesc") """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error") + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): + def setParams(self, inputCol=None, outputCol=None, handleInvalid="error", + stringOrderType="frequencyDesc"): """ - setParams(self, inputCol=None, outputCol=None, handleInvalid="error") + setParams(self, inputCol=None, outputCol=None, handleInvalid="error", \ + stringOrderType="frequencyDesc") Sets params for this StringIndexer. """ kwargs = self._input_kwargs @@ -2139,6 +2160,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, def _create_model(self, java_model): return StringIndexerModel(java_model) + @since("2.3.0") + def setStringOrderType(self, value): + """ + Sets the value of :py:attr:`stringOrderType`. + """ + return self._set(stringOrderType=value) + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. + """ + return self.getOrDefault(self.stringOrderType) + class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ -- GitLab