From 35c9599b94de759204ed33cdd46d8ee108bccd86 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Fri, 8 May 2015 15:48:39 -0700
Subject: [PATCH] [SPARK-5913] [MLLIB] Python API for ChiSqSelector

Add a Python API for mllib.feature.ChiSqSelector
https://issues.apache.org/jira/browse/SPARK-5913

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5939 from yanboliang/spark-5913 and squashes the following commits:

cdaac99 [Yanbo Liang] Python API for ChiSqSelector
---
 .../mllib/api/python/PythonMLLibAPI.scala     | 10 ++++
 python/pyspark/mllib/feature.py               | 59 ++++++++++++++++++-
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 426306d78c..8c30ad4b39 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -494,6 +494,16 @@ private[python] class PythonMLLibAPI extends Serializable {
     new StandardScaler(withMean, withStd).fit(data.rdd)
   }
 
+  /**
+   * Java stub for ChiSqSelector.fit(). This stub returns a
+   * handle to the Java object instead of the content of the Java object.
+   * Extra care needs to be taken in the Python code to ensure it gets freed on
+   * exit; see the Py4J documentation.
+   */
+  def fitChiSqSelector(numTopFeatures: Int, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector(numTopFeatures).fit(data.rdd)
+  }
+
   /**
    * Java stub for IDF.fit(). This stub returns a
    * handle to the Java object instead of the content of the Java object.
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 1140539a24..aac305db6c 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -33,10 +33,12 @@ from py4j.protocol import Py4JJavaError
 from pyspark import SparkContext
 from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors, _convert_to_vector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
 
 __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
-           'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
+           'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
+           'ChiSqSelector', 'ChiSqSelectorModel']
 
 
 class VectorTransformer(object):
@@ -199,6 +201,59 @@ class StandardScaler(object):
         return StandardScalerModel(jmodel)
 
 
+class ChiSqSelectorModel(JavaVectorTransformer):
+    """
+    .. note:: Experimental
+
+    Represents a Chi Squared selector model.
+    """
+    def transform(self, vector):
+        """
+        Applies transformation on a vector.
+
+        :param vector: Vector or RDD of Vector to be transformed.
+        :return: transformed vector.
+        """
+        return JavaVectorTransformer.transform(self, vector)
+
+
+class ChiSqSelector(object):
+    """
+    .. note:: Experimental
+
+    Creates a ChiSquared feature selector.
+
+    >>> data = [
+    ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
+    ...     LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})),
+    ...     LabeledPoint(1.0, [0.0, 9.0, 8.0]),
+    ...     LabeledPoint(2.0, [8.0, 9.0, 5.0])
+    ... ]
+    >>> model = ChiSqSelector(1).fit(sc.parallelize(data))
+    >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
+    SparseVector(1, {0: 6.0})
+    >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
+    DenseVector([5.0])
+    """
+    def __init__(self, numTopFeatures):
+        """
+        :param numTopFeatures: number of features that selector will select.
+        """
+        self.numTopFeatures = int(numTopFeatures)
+
+    def fit(self, data):
+        """
+        Returns a ChiSquared feature selector.
+
+        :param data: an `RDD[LabeledPoint]` containing the labeled dataset
+                 with categorical features. Real-valued features will be
+                 treated as categorical for each distinct value.
+                 Apply feature discretizer before using this function.
+        """
+        jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data)
+        return ChiSqSelectorModel(jmodel)
+
+
 class HashingTF(object):
     """
     .. note:: Experimental
-- 
GitLab