diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index 0f7c6e8bc04bb88891fe37f2b93cf69cfad1c56b..706ce78f260a6c44382dbb5a16211ee0e668ff56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -35,12 +35,24 @@ import org.apache.spark.sql.{Row, SparkSession}
 /**
  * Chi Squared selector model.
  *
- * @param selectedFeatures list of indices to select (filter).
+ * @param selectedFeatures list of indices to select (filter). Must be ordered asc
  */
 @Since("1.3.0")
 class ChiSqSelectorModel @Since("1.3.0") (
   @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
+  require(isSorted(selectedFeatures), "Array has to be sorted asc")
+
+  protected def isSorted(array: Array[Int]): Boolean = {
+    var i = 1
+    val len = array.length
+    while (i < len) {
+      if (array(i) < array(i-1)) return false
+      i += 1
+    }
+    true
+  }
+
   /**
    * Applies transformation on a vector.
    *
@@ -57,22 +69,21 @@ class ChiSqSelectorModel @Since("1.3.0") (
    * Preserves the order of filtered features the same as their indices are stored.
    * Might be moved to Vector as .slice
    * @param features vector
-   * @param filterIndices indices of features to filter
+   * @param filterIndices indices of features to filter, must be ordered asc
    */
   private def compress(features: Vector, filterIndices: Array[Int]): Vector = {
-    val orderedIndices = filterIndices.sorted
     features match {
       case SparseVector(size, indices, values) =>
-        val newSize = orderedIndices.length
+        val newSize = filterIndices.length
         val newValues = new ArrayBuilder.ofDouble
         val newIndices = new ArrayBuilder.ofInt
         var i = 0
         var j = 0
         var indicesIdx = 0
         var filterIndicesIdx = 0
-        while (i < indices.length && j < orderedIndices.length) {
+        while (i < indices.length && j < filterIndices.length) {
           indicesIdx = indices(i)
-          filterIndicesIdx = orderedIndices(j)
+          filterIndicesIdx = filterIndices(j)
           if (indicesIdx == filterIndicesIdx) {
             newIndices += j
             newValues += values(i)
@@ -90,7 +101,7 @@ class ChiSqSelectorModel @Since("1.3.0") (
         Vectors.sparse(newSize, newIndices.result(), newValues.result())
       case DenseVector(values) =>
         val values = features.toArray
-        Vectors.dense(orderedIndices.map(i => values(i)))
+        Vectors.dense(filterIndices.map(i => values(i)))
       case other =>
         throw new UnsupportedOperationException(
           s"Only sparse and dense vectors are supported but got ${other.getClass}.")
@@ -220,18 +231,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   @Since("1.3.0")
   def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = {
     val chiSqTestResult = Statistics.chiSqTest(data)
-      .zipWithIndex.sortBy { case (res, _) => -res.statistic }
     val features = selectorType match {
-      case ChiSqSelector.KBest => chiSqTestResult
-        .take(numTopFeatures)
-      case ChiSqSelector.Percentile => chiSqTestResult
-        .take((chiSqTestResult.length * percentile).toInt)
-      case ChiSqSelector.FPR => chiSqTestResult
-        .filter{ case (res, _) => res.pValue < alpha }
+      case ChiSqSelector.KBest =>
+        chiSqTestResult.zipWithIndex
+          .sortBy { case (res, _) => -res.statistic }
+          .take(numTopFeatures)
+      case ChiSqSelector.Percentile =>
+        chiSqTestResult.zipWithIndex
+          .sortBy { case (res, _) => -res.statistic }
+          .take((chiSqTestResult.length * percentile).toInt)
+      case ChiSqSelector.FPR =>
+        chiSqTestResult.zipWithIndex
+          .filter{ case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
     }
-    val indices = features.map { case (_, indices) => indices }
+    val indices = features.map { case (_, indices) => indices }.sorted
     new ChiSqSelectorModel(indices)
   }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8024fbd21bbfc8cc736118c89c417b649928c51e..4db3edb733a568d0bb1709e8a02fe1e48b21a29f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -817,9 +817,6 @@ object MimaExcludes {
     ) ++ Seq(
       // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
-    ) ++ Seq(
-      // [SPARK-17017] Add chiSquare selector based on False Positive Rate (FPR) test
-      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted")
     ) ++ Seq(
       // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
       ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
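
Note on the caller-facing effect of this patch: since the require(isSorted(...)) check is restored and compress() no longer re-sorts, selectedFeatures must be passed in ascending order; fit() keeps this invariant by calling .sorted on the selected indices before building the model. A minimal usage sketch of the public mllib API touched here (not part of the patch; the expected output below is inferred from the compress() logic in this diff):

import org.apache.spark.mllib.feature.ChiSqSelectorModel
import org.apache.spark.mllib.linalg.Vectors

// Indices are passed pre-sorted ascending; the constructor now enforces this
// via require(isSorted(selectedFeatures), ...) instead of re-sorting inside compress().
val model = new ChiSqSelectorModel(Array(1, 3))

// compress() walks filterIndices in order, so features 1 and 3 are kept
// in ascending index order in the output vector.
val transformed = model.transform(Vectors.dense(10.0, 20.0, 30.0, 40.0))
// expected: [20.0,40.0]

// An unsorted index array now fails fast with an IllegalArgumentException
// rather than being silently reordered:
// new ChiSqSelectorModel(Array(3, 1))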