From 8beab68152348c44cf2f89850f792f164b06470d Mon Sep 17 00:00:00 2001
From: Xusen Yin <yinxusen@gmail.com>
Date: Tue, 26 Jan 2016 11:56:46 -0800
Subject: [PATCH] [SPARK-11923][ML] Python API for ml.feature.ChiSqSelector

https://issues.apache.org/jira/browse/SPARK-11923

Author: Xusen Yin <yinxusen@gmail.com>

Closes #10186 from yinxusen/SPARK-11923.
---
 python/pyspark/ml/feature.py | 98 +++++++++++++++++++++++++++++++++++-
 1 file changed, 97 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index f139d81bc4..32f324685a 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -33,7 +33,8 @@ __all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel',
            'PolynomialExpansion', 'QuantileDiscretizer', 'RegexTokenizer', 'RFormula',
            'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel',
            'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer',
-           'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel']
+           'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel',
+           'ChiSqSelector', 'ChiSqSelectorModel']
 
 
 @inherit_doc
@@ -2237,6 +2238,101 @@ class RFormulaModel(JavaModel):
     """
 
 
+@inherit_doc
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
+    """
+    .. note:: Experimental
+
+    Chi-Squared feature selection, which selects categorical features to use for predicting a
+    categorical label.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> df = sqlContext.createDataFrame(
+    ...    [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0),
+    ...     (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0),
+    ...     (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)],
+    ...    ["features", "label"])
+    >>> selector = ChiSqSelector(numTopFeatures=1, outputCol="selectedFeatures")
+    >>> model = selector.fit(df)
+    >>> model.transform(df).head().selectedFeatures
+    DenseVector([1.0])
+    >>> model.selectedFeatures
+    [3]
+
+    .. versionadded:: 2.0.0
+    """
+
+    # a placeholder to make it appear in the generated doc
+    numTopFeatures = \
+        Param(Params._dummy(), "numTopFeatures",
+              "Number of features that selector will select, ordered by statistics value " +
+              "descending. If the number of features is < numTopFeatures, then this will select " +
+              "all features.")
+
+    @keyword_only
+    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
+        """
+        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
+        """
+        super(ChiSqSelector, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
+        self.numTopFeatures = \
+            Param(self, "numTopFeatures",
+                  "Number of features that selector will select, ordered by statistics value " +
+                  "descending. If the number of features is < numTopFeatures, then this will " +
+                  "select all features.")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    @since("2.0.0")
+    def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                  labelCol="labels"):
+        """
+        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
+                  labelCol="labels")
+        Sets params for this ChiSqSelector.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+    @since("2.0.0")
+    def setNumTopFeatures(self, value):
+        """
+        Sets the value of :py:attr:`numTopFeatures`.
+        """
+        self._paramMap[self.numTopFeatures] = value
+        return self
+
+    @since("2.0.0")
+    def getNumTopFeatures(self):
+        """
+        Gets the value of numTopFeatures or its default value.
+        """
+        return self.getOrDefault(self.numTopFeatures)
+
+    def _create_model(self, java_model):
+        return ChiSqSelectorModel(java_model)
+
+
+class ChiSqSelectorModel(JavaModel):
+    """
+    .. note:: Experimental
+
+    Model fitted by ChiSqSelector.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @property
+    @since("2.0.0")
+    def selectedFeatures(self):
+        """
+        List of indices to select (filter). Must be ordered asc.
+        """
+        return self._call_java("selectedFeatures")
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
-- 
GitLab