From b88cb63da39786c07cb4bfa70afed32ec5eb3286 Mon Sep 17 00:00:00 2001
From: Sean Owen <sowen@cloudera.com>
Date: Sat, 1 Oct 2016 16:10:39 -0400
Subject: [PATCH] [SPARK-17704][ML][MLLIB] ChiSqSelector performance
 improvement.

## What changes were proposed in this pull request?

Partial revert of #15277 to instead sort and store input to model rather than require sorted input

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #15299 from srowen/SPARK-17704.2.
---
 .../spark/ml/feature/ChiSqSelector.scala      |  2 +-
 .../spark/mllib/feature/ChiSqSelector.scala   | 22 +++++++++----------
 python/pyspark/ml/feature.py                  |  2 +-
 3 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 9c131a4185..d0385e220e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] (
 
   import ChiSqSelectorModel._
 
-  /** list of indices to select (filter). Must be ordered asc */
+  /** list of indices to select (filter). */
   @Since("1.6.0")
   val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 706ce78f26..c305b36278 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession}
 /**
  * Chi Squared selector model.
  *
- * @param selectedFeatures list of indices to select (filter). Must be ordered asc
+ * @param selectedFeatures list of indices to select (filter).
  */
 @Since("1.3.0")
 class ChiSqSelectorModel @Since("1.3.0") (
   @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
-  require(isSorted(selectedFeatures), "Array has to be sorted asc")
+  private val filterIndices = selectedFeatures.sorted
 
+  @deprecated("not intended for subclasses to use", "2.1.0")
   protected def isSorted(array: Array[Int]): Boolean = {
     var i = 1
     val len = array.length
@@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
    */
   @Since("1.3.0")
   override def transform(vector: Vector): Vector = {
-    compress(vector, selectedFeatures)
+    compress(vector)
   }
 
   /**
@@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") (
    * Preserves the order of filtered features the same as their indices are stored.
    * Might be moved to Vector as .slice
    * @param features vector
-   * @param filterIndices indices of features to filter, must be ordered asc
    */
-  private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
+  private def compress(features: Vector): Vector = {
     features match {
       case SparseVector(size, indices, values) =>
         val newSize = filterIndices.length
@@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
    */
   @Since("1.3.0")
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
-    val chiSqTestResult = Statistics.chiSqTest(data)
+    val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex
     val features = selectorType match {
       case ChiSqSelector.KBest =>
-        chiSqTestResult.zipWithIndex
+        chiSqTestResult
           .sortBy { case (res, _) => -res.statistic }
           .take(numTopFeatures)
       case ChiSqSelector.Percentile =>
-        chiSqTestResult.zipWithIndex
+        chiSqTestResult
           .sortBy { case (res, _) => -res.statistic }
           .take((chiSqTestResult.length * percentile).toInt)
       case ChiSqSelector.FPR =>
-        chiSqTestResult.zipWithIndex
-          .filter{ case (res, _) => res.pValue < alpha }
+        chiSqTestResult
+          .filter { case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
     }
-    val indices = features.map { case (_, indices) => indices }.sorted
+    val indices = features.map { case (_, index) => index }
     new ChiSqSelectorModel(indices)
   }
 }
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 12a13849dc..64b21caa61 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
     @since("2.0.0")
     def selectedFeatures(self):
         """
-        List of indices to select (filter). Must be ordered asc.
+        List of indices to select (filter).
         """
         return self._call_java("selectedFeatures")
 
-- 
GitLab