From b88cb63da39786c07cb4bfa70afed32ec5eb3286 Mon Sep 17 00:00:00 2001 From: Sean Owen <sowen@cloudera.com> Date: Sat, 1 Oct 2016 16:10:39 -0400 Subject: [PATCH] [SPARK-17704][ML][MLLIB] ChiSqSelector performance improvement. ## What changes were proposed in this pull request? Partially reverts #15277: the model now sorts the selected feature indices itself and stores the sorted copy internally, rather than requiring its input to be sorted already. ## How was this patch tested? Existing tests. Author: Sean Owen <sowen@cloudera.com> Closes #15299 from srowen/SPARK-17704.2. --- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 22 +++++++++---------- python/pyspark/ml/feature.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 9c131a4185..d0385e220e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] ( import ChiSqSelectorModel._ - /** list of indices to select (filter). Must be ordered asc */ + /** list of indices to select (filter). */ @Since("1.6.0") val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 706ce78f26..c305b36278 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). Must be ordered asc + * @param selectedFeatures list of indices to select (filter). 
*/ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { - require(isSorted(selectedFeatures), "Array has to be sorted asc") + private val filterIndices = selectedFeatures.sorted + @deprecated("not intended for subclasses to use", "2.1.0") protected def isSorted(array: Array[Int]): Boolean = { var i = 1 val len = array.length @@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( */ @Since("1.3.0") override def transform(vector: Vector): Vector = { - compress(vector, selectedFeatures) + compress(vector) } /** @@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. * Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter, must be ordered asc */ - private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + private def compress(features: Vector): Vector = { features match { case SparseVector(size, indices, values) => val newSize = filterIndices.length @@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { */ @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { - val chiSqTestResult = Statistics.chiSqTest(data) + val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { case ChiSqSelector.KBest => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take(numTopFeatures) case ChiSqSelector.Percentile => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => - chiSqTestResult.zipWithIndex - .filter{ case (res, _) => res.pValue < alpha } + chiSqTestResult + .filter { case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown 
ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices }.sorted + val indices = features.map { case (_, index) => index } new ChiSqSelectorModel(indices) } } diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 12a13849dc..64b21caa61 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): @since("2.0.0") def selectedFeatures(self): """ - List of indices to select (filter). Must be ordered asc. + List of indices to select (filter). """ return self._call_java("selectedFeatures") -- GitLab