Skip to content
Snippets Groups Projects
Commit c77f4066 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for...

[SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for number of distinct values

There is a bug in determining the column index. dorx

Author: Xiangrui Meng <meng@databricks.com>

Closes #1997 from mengxr/chisq-index and squashes the following commits:

8fc2ab2 [Xiangrui Meng] fix col indexing bug and add a check for number of distinct values
parent 95470a03
No related branches found
No related tags found
No related merge requests found
...@@ -155,7 +155,7 @@ object Statistics { ...@@ -155,7 +155,7 @@ object Statistics {
* :: Experimental :: * :: Experimental ::
* Conduct Pearson's independence test for every feature against the label across the input RDD. * Conduct Pearson's independence test for every feature against the label across the input RDD.
* For each feature, the (feature, label) pairs are converted into a contingency matrix for which * For each feature, the (feature, label) pairs are converted into a contingency matrix for which
* the chi-squared statistic is computed. * the chi-squared statistic is computed. All label and feature values must be categorical.
* *
* @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
* Real-valued features will be treated as categorical for each distinct value. * Real-valued features will be treated as categorical for each distinct value.
......
...@@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test ...@@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test
import breeze.linalg.{DenseMatrix => BDM} import breeze.linalg.{DenseMatrix => BDM}
import cern.jet.stat.Probability.chiSquareComplemented import cern.jet.stat.Probability.chiSquareComplemented
import org.apache.spark.Logging import org.apache.spark.{SparkException, Logging}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import scala.collection.mutable
/** /**
* Conduct the chi-squared test for the input RDDs using the specified method. * Conduct the chi-squared test for the input RDDs using the specified method.
* Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted
...@@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging { ...@@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging {
*/ */
def chiSquaredFeatures(data: RDD[LabeledPoint], def chiSquaredFeatures(data: RDD[LabeledPoint],
methodName: String = PEARSON.name): Array[ChiSqTestResult] = { methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
val maxCategories = 10000
val numCols = data.first().features.size val numCols = data.first().features.size
val results = new Array[ChiSqTestResult](numCols) val results = new Array[ChiSqTestResult](numCols)
var labels: Map[Double, Int] = null var labels: Map[Double, Int] = null
// At most 100 columns at a time // at most 1000 columns at a time
val batchSize = 100 val batchSize = 1000
var batch = 0 var batch = 0
while (batch * batchSize < numCols) { while (batch * batchSize < numCols) {
// The following block of code can be cleaned up and made public as // The following block of code can be cleaned up and made public as
// chiSquared(data: RDD[(V1, V2)]) // chiSquared(data: RDD[(V1, V2)])
val startCol = batch * batchSize val startCol = batch * batchSize
val endCol = startCol + math.min(batchSize, numCols - startCol) val endCol = startCol + math.min(batchSize, numCols - startCol)
val pairCounts = data.flatMap { p => val pairCounts = data.mapPartitions { iter =>
// assume dense vectors val distinctLabels = mutable.HashSet.empty[Double]
p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) => val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] =
(col, feature, p.label) Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*)
var i = 1
iter.flatMap { case LabeledPoint(label, features) =>
if (i % 1000 == 0) {
if (distinctLabels.size > maxCategories) {
throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ s"found more than $maxCategories distinct label values.")
}
allDistinctFeatures.foreach { case (col, distinctFeatures) =>
if (distinctFeatures.size > maxCategories) {
throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ s"found more than $maxCategories distinct values in column $col.")
}
}
}
i += 1
distinctLabels += label
features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) =>
allDistinctFeatures(col) += feature
(col, feature, label)
}
} }
}.countByValue() }.countByValue()
......
...@@ -17,8 +17,11 @@ ...@@ -17,8 +17,11 @@
package org.apache.spark.mllib.stat package org.apache.spark.mllib.stat
import java.util.Random
import org.scalatest.FunSuite import org.scalatest.FunSuite
import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.stat.test.ChiSqTest
...@@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { ...@@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
// labels: 1.0 (2 / 6), 0.0 (4 / 6) // labels: 1.0 (2 / 6), 0.0 (4 / 6)
// feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6)
// feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6)
val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), val data = Seq(
new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
new LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
for (numParts <- List(2, 4, 6, 8)) { for (numParts <- List(2, 4, 6, 8)) {
val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) val chi = Statistics.chiSqTest(sc.parallelize(data, numParts))
val feature1 = chi(0) val feature1 = chi(0)
...@@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext { ...@@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
} }
// Test that the right number of results is returned // Test that the right number of results is returned
val numCols = 321 val numCols = 1001
val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), val sparseData = Array(
new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0))))) new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0)))))
val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) val chi = Statistics.chiSqTest(sc.parallelize(sparseData))
assert(chi.size === numCols) assert(chi.size === numCols)
assert(chi(1000) != null) // SPARK-3087
// Detect continous features or labels
val random = new Random(11L)
val continuousLabel =
Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
intercept[SparkException] {
Statistics.chiSqTest(sc.parallelize(continuousLabel, 2))
}
val continuousFeature =
Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
intercept[SparkException] {
Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
}
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment