From d9f4ce6943c16a7e29f98e57c33acbfc0379b54d Mon Sep 17 00:00:00 2001
From: Nick Pentreath <nickp@za.ibm.com>
Date: Fri, 24 Mar 2017 08:01:15 -0700
Subject: [PATCH] [SPARK-15040][ML][PYSPARK] Add Imputer to PySpark

Add Python wrapper for `Imputer` feature transformer.

## How was this patch tested?

New doc tests and tweak to PySpark ML `tests.py`

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #17316 from MLnick/SPARK-15040-pyspark-imputer.
---
 .../org/apache/spark/ml/feature/Imputer.scala |  10 +-
 python/pyspark/ml/feature.py                  | 160 ++++++++++++++++++
 python/pyspark/ml/tests.py                    |  10 ++
 3 files changed, 175 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index b1a802ee13..ec4c6ad75e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -93,12 +93,12 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
 /**
  * :: Experimental ::
  * Imputation estimator for completing missing values, either using the mean or the median
- * of the column in which the missing values are located. The input column should be of
- * DoubleType or FloatType. Currently Imputer does not support categorical features yet
+ * of the columns in which the missing values are located. The input columns should be of
+ * DoubleType or FloatType. Currently Imputer does not support categorical features
  * (SPARK-15041) and possibly creates incorrect values for a categorical feature.
  *
  * Note that the mean/median value is computed after filtering out missing values.
- * All Null values in the input column are treated as missing, and so are also imputed. For
+ * All Null values in the input columns are treated as missing, and so are also imputed. For
  * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
  */
 @Experimental
@@ -176,8 +176,8 @@ object Imputer extends DefaultParamsReadable[Imputer] {
  * :: Experimental ::
  * Model fitted by [[Imputer]].
  *
- * @param surrogateDF a DataFrame contains inputCols and their corresponding surrogates, which are
- *                    used to replace the missing values in the input DataFrame.
+ * @param surrogateDF a DataFrame containing inputCols and their corresponding surrogates,
+ *                    which are used to replace the missing values in the input DataFrame.
  */
 @Experimental
 class ImputerModel private[ml](
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 92f8549e9c..8d25f5b3a7 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -36,6 +36,7 @@ __all__ = ['Binarizer',
            'ElementwiseProduct',
            'HashingTF',
            'IDF', 'IDFModel',
+           'Imputer', 'ImputerModel',
            'IndexToString',
            'MaxAbsScaler', 'MaxAbsScalerModel',
            'MinHashLSH', 'MinHashLSHModel',
@@ -870,6 +871,165 @@ class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable):
         return self._call_java("idf")
 
 
+@inherit_doc
+class Imputer(JavaEstimator, HasInputCols, JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Imputation estimator for completing missing values, either using the mean or the median
+    of the columns in which the missing values are located. The input columns should be of
+    DoubleType or FloatType. Currently Imputer does not support categorical features and
+    possibly creates incorrect values for a categorical feature.
+
+    Note that the mean/median value is computed after filtering out missing values.
+    All Null values in the input columns are treated as missing, and so are also imputed. For
+    computing median, :py:meth:`pyspark.sql.DataFrame.approxQuantile` is used with a
+    relative error of `0.001`.
+
+    >>> df = spark.createDataFrame([(1.0, float("nan")), (2.0, float("nan")), (float("nan"), 3.0),
+    ...                             (4.0, 4.0), (5.0, 5.0)], ["a", "b"])
+    >>> imputer = Imputer(inputCols=["a", "b"], outputCols=["out_a", "out_b"])
+    >>> model = imputer.fit(df)
+    >>> model.surrogateDF.show()
+    +---+---+
+    |  a|  b|
+    +---+---+
+    |3.0|4.0|
+    +---+---+
+    ...
+    >>> model.transform(df).show()
+    +---+---+-----+-----+
+    |  a|  b|out_a|out_b|
+    +---+---+-----+-----+
+    |1.0|NaN|  1.0|  4.0|
+    |2.0|NaN|  2.0|  4.0|
+    |NaN|3.0|  3.0|  3.0|
+    ...
+    >>> imputer.setStrategy("median").setMissingValue(1.0).fit(df).transform(df).show()
+    +---+---+-----+-----+
+    |  a|  b|out_a|out_b|
+    +---+---+-----+-----+
+    |1.0|NaN|  4.0|  NaN|
+    ...
+    >>> imputerPath = temp_path + "/imputer"
+    >>> imputer.save(imputerPath)
+    >>> loadedImputer = Imputer.load(imputerPath)
+    >>> loadedImputer.getStrategy() == imputer.getStrategy()
+    True
+    >>> loadedImputer.getMissingValue()
+    1.0
+    >>> modelPath = temp_path + "/imputer-model"
+    >>> model.save(modelPath)
+    >>> loadedModel = ImputerModel.load(modelPath)
+    >>> loadedModel.transform(df).head().out_a == model.transform(df).head().out_a
+    True
+
+    .. versionadded:: 2.2.0
+    """
+
+    outputCols = Param(Params._dummy(), "outputCols",
+                       "output column names.", typeConverter=TypeConverters.toListString)
+
+    strategy = Param(Params._dummy(), "strategy",
+                     "strategy for imputation. If mean, then replace missing values using the mean "
+                     "value of the feature. If median, then replace missing values using the "
+                     "median value of the feature.",
+                     typeConverter=TypeConverters.toString)
+
+    missingValue = Param(Params._dummy(), "missingValue",
+                         "The placeholder for the missing values. All occurrences of missingValue "
+                         "will be imputed.", typeConverter=TypeConverters.toFloat)
+
+    @keyword_only
+    def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None,
+                 outputCols=None):
+        """
+        __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, \
+                 outputCols=None):
+        """
+        super(Imputer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid)
+        self._setDefault(strategy="mean", missingValue=float("nan"))
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.2.0")
+    def setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None,
+                  outputCols=None):
+        """
+        setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, \
+                  outputCols=None)
+        Sets params for this Imputer.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.2.0")
+    def setOutputCols(self, value):
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        return self._set(outputCols=value)
+
+    @since("2.2.0")
+    def getOutputCols(self):
+        """
+        Gets the value of :py:attr:`outputCols` or its default value.
+        """
+        return self.getOrDefault(self.outputCols)
+
+    @since("2.2.0")
+    def setStrategy(self, value):
+        """
+        Sets the value of :py:attr:`strategy`.
+        """
+        return self._set(strategy=value)
+
+    @since("2.2.0")
+    def getStrategy(self):
+        """
+        Gets the value of :py:attr:`strategy` or its default value.
+        """
+        return self.getOrDefault(self.strategy)
+
+    @since("2.2.0")
+    def setMissingValue(self, value):
+        """
+        Sets the value of :py:attr:`missingValue`.
+        """
+        return self._set(missingValue=value)
+
+    @since("2.2.0")
+    def getMissingValue(self):
+        """
+        Gets the value of :py:attr:`missingValue` or its default value.
+        """
+        return self.getOrDefault(self.missingValue)
+
+    def _create_model(self, java_model):
+        return ImputerModel(java_model)
+
+
+class ImputerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+    """
+    .. note:: Experimental
+
+    Model fitted by :py:class:`Imputer`.
+
+    .. versionadded:: 2.2.0
+    """
+
+    @property
+    @since("2.2.0")
+    def surrogateDF(self):
+        """
+        Returns a DataFrame containing inputCols and their corresponding surrogates,
+        which are used to replace the missing values in the input DataFrame.
+        """
+        return self._call_java("surrogateDF")
+
+
 @inherit_doc
 class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
     """
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index f052f5bb77..cc559db587 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1273,6 +1273,7 @@ class DefaultValuesTests(PySparkTestCase):
     """
 
     def check_params(self, py_stage):
+        import pyspark.ml.feature
         if not hasattr(py_stage, "_to_java"):
             return
         java_stage = py_stage._to_java()
@@ -1292,6 +1293,15 @@ class DefaultValuesTests(PySparkTestCase):
                     _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param))
                 py_stage._clear(p)
                 py_default = py_stage.getOrDefault(p)
+                if isinstance(py_stage, pyspark.ml.feature.Imputer) and p.name == "missingValue":
+                    # SPARK-15040 - default value for Imputer param 'missingValue' is NaN,
+                    # and NaN != NaN, so handle it specially here
+                    import math
+                    self.assertTrue(math.isnan(java_default) and math.isnan(py_default),
+                                    "Java default %s and python default %s are not both NaN for "
+                                    "param %s for Params %s"
+                                    % (str(java_default), str(py_default), p.name, str(py_stage)))
+                    return
                 self.assertEqual(java_default, py_default,
                                  "Java default %s != python default %s of param %s for Params %s"
                                  % (str(java_default), str(py_default), p.name, str(py_stage)))
-- 
GitLab