From f7082ac12518ae84d6d1d4b7330a9f12cf95e7c1 Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Thu, 29 Sep 2016 04:30:42 -0700
Subject: [PATCH] [SPARK-17704][ML][MLLIB] ChiSqSelector performance
 improvement.

## What changes were proposed in this pull request?
Several performance improvement for ```ChiSqSelector```:
1. Keep ```selectedFeatures``` sorted in ascending order.
```ChiSqSelectorModel.transform``` needs ```selectedFeatures``` to be sorted in order to make predictions. We should sort it once when training the model rather than on every prediction, since users usually train a model once and then use it to make predictions many times.
2. When training an ```fpr```-type ```ChiSqSelectorModel```, it is not necessary to sort the chi-squared test results by test statistic.

## How was this patch tested?
Existing unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15277 from yanboliang/spark-17704.
---
 .../spark/mllib/feature/ChiSqSelector.scala   | 45 ++++++++++++-------
 project/MimaExcludes.scala                    |  3 --
 2 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 0f7c6e8bc0..706ce78f26 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
 /**
  * Chi Squared selector model.
  *
- * @param selectedFeatures list of indices to select (filter).
+ * @param selectedFeatures list of indices to select (filter). Must be ordered asc
  */
 @Since("1.3.0")
 class ChiSqSelectorModel @Since("1.3.0") (
   @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
+  require(isSorted(selectedFeatures), "Array has to be sorted asc")
+
+  protected def isSorted(array: Array[Int]): Boolean = {
+    var i = 1
+    val len = array.length
+    while (i < len) {
+      if (array(i) < array(i-1)) return false
+      i += 1
+    }
+    true
+  }
+
   /**
    * Applies transformation on a vector.
    *
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
    * Preserves the order of filtered features the same as their indices are stored.
    * Might be moved to Vector as .slice
    * @param features vector
-   * @param filterIndices indices of features to filter
+   * @param filterIndices indices of features to filter, must be ordered asc
    */
   private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
-    val orderedIndices = filterIndices.sorted
     features match {
       case SparseVector(size, indices, values) =>
-        val newSize = orderedIndices.length
+        val newSize = filterIndices.length
         val newValues = new ArrayBuilder.ofDouble
         val newIndices = new ArrayBuilder.ofInt
         var i = 0
         var j = 0
         var indicesIdx = 0
         var filterIndicesIdx = 0
-        while (i < indices.length && j < orderedIndices.length) {
+        while (i < indices.length && j < filterIndices.length) {
           indicesIdx = indices(i)
-          filterIndicesIdx = orderedIndices(j)
+          filterIndicesIdx = filterIndices(j)
           if (indicesIdx == filterIndicesIdx) {
             newIndices += j
             newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
         Vectors.sparse(newSize, newIndices.result(), newValues.result())
       case DenseVector(values) =>
         val values = features.toArray
-        Vectors.dense(orderedIndices.map(i => values(i)))
+        Vectors.dense(filterIndices.map(i => values(i)))
       case other =>
         throw new UnsupportedOperationException(
           s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   @Since("1.3.0")
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
     val chiSqTestResult = Statistics.chiSqTest(data)
-      .zipWithIndex.sortBy { case (res, _) => -res.statistic }
     val features = selectorType match {
-      case ChiSqSelector.KBest => chiSqTestResult
-        .take(numTopFeatures)
-      case ChiSqSelector.Percentile => chiSqTestResult
-        .take((chiSqTestResult.length * percentile).toInt)
-      case ChiSqSelector.FPR => chiSqTestResult
-        .filter{ case (res, _) => res.pValue < alpha }
+      case ChiSqSelector.KBest =>
+        chiSqTestResult.zipWithIndex
+          .sortBy { case (res, _) => -res.statistic }
+          .take(numTopFeatures)
+      case ChiSqSelector.Percentile =>
+        chiSqTestResult.zipWithIndex
+          .sortBy { case (res, _) => -res.statistic }
+          .take((chiSqTestResult.length * percentile).toInt)
+      case ChiSqSelector.FPR =>
+        chiSqTestResult.zipWithIndex
+          .filter{ case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
     }
-    val indices = features.map { case (_, indices) => indices }
+    val indices = features.map { case (_, indices) => indices }.sorted
     new ChiSqSelectorModel(indices)
   }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8024fbd21b..4db3edb733 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -817,9 +817,6 @@ object MimaExcludes {
     ) ++ Seq(
       // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
-    ) ++ Seq(
-      // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
-      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
     ) ++ Seq(
       // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
       ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
-- 
GitLab