Skip to content
Snippets Groups Projects
Commit c77f4066 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for...

[SPARK-3087][MLLIB] fix col indexing bug in chi-square and add a check for number of distinct values

There is a bug determining the column index. dorx

Author: Xiangrui Meng <meng@databricks.com>

Closes #1997 from mengxr/chisq-index and squashes the following commits:

8fc2ab2 [Xiangrui Meng] fix col indexing bug and add a check for number of distinct values
parent 95470a03
No related branches found
No related tags found
No related merge requests found
......@@ -155,7 +155,7 @@ object Statistics {
* :: Experimental ::
* Conduct Pearson's independence test for every feature against the label across the input RDD.
* For each feature, the (feature, label) pairs are converted into a contingency matrix for which
* the chi-squared statistic is computed.
* the chi-squared statistic is computed. All label and feature values must be categorical.
*
* @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features.
* Real-valued features will be treated as categorical for each distinct value.
......
......@@ -20,11 +20,13 @@ package org.apache.spark.mllib.stat.test
import breeze.linalg.{DenseMatrix => BDM}
import cern.jet.stat.Probability.chiSquareComplemented
import org.apache.spark.Logging
import org.apache.spark.{SparkException, Logging}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import scala.collection.mutable
/**
* Conduct the chi-squared test for the input RDDs using the specified method.
* Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted
......@@ -75,21 +77,42 @@ private[stat] object ChiSqTest extends Logging {
*/
def chiSquaredFeatures(data: RDD[LabeledPoint],
methodName: String = PEARSON.name): Array[ChiSqTestResult] = {
val maxCategories = 10000
val numCols = data.first().features.size
val results = new Array[ChiSqTestResult](numCols)
var labels: Map[Double, Int] = null
// At most 100 columns at a time
val batchSize = 100
// at most 1000 columns at a time
val batchSize = 1000
var batch = 0
while (batch * batchSize < numCols) {
// The following block of code can be cleaned up and made public as
// chiSquared(data: RDD[(V1, V2)])
val startCol = batch * batchSize
val endCol = startCol + math.min(batchSize, numCols - startCol)
val pairCounts = data.flatMap { p =>
// assume dense vectors
p.features.toArray.slice(startCol, endCol).zipWithIndex.map { case (feature, col) =>
(col, feature, p.label)
val pairCounts = data.mapPartitions { iter =>
val distinctLabels = mutable.HashSet.empty[Double]
val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] =
Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*)
var i = 1
iter.flatMap { case LabeledPoint(label, features) =>
if (i % 1000 == 0) {
if (distinctLabels.size > maxCategories) {
throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ s"found more than $maxCategories distinct label values.")
}
allDistinctFeatures.foreach { case (col, distinctFeatures) =>
if (distinctFeatures.size > maxCategories) {
throw new SparkException(s"Chi-square test expect factors (categorical values) but "
+ s"found more than $maxCategories distinct values in column $col.")
}
}
}
i += 1
distinctLabels += label
features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) =>
allDistinctFeatures(col) += feature
(col, feature, label)
}
}
}.countByValue()
......
......@@ -17,8 +17,11 @@
package org.apache.spark.mllib.stat
import java.util.Random
import org.scalatest.FunSuite
import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.test.ChiSqTest
......@@ -107,12 +110,13 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
// labels: 1.0 (2 / 6), 0.0 (4 / 6)
// feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6)
// feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6)
val data = Array(new LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
new LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
new LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
new LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
new LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
new LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
val data = Seq(
LabeledPoint(0.0, Vectors.dense(0.5, 10.0)),
LabeledPoint(0.0, Vectors.dense(1.5, 20.0)),
LabeledPoint(1.0, Vectors.dense(1.5, 30.0)),
LabeledPoint(0.0, Vectors.dense(3.5, 30.0)),
LabeledPoint(0.0, Vectors.dense(3.5, 40.0)),
LabeledPoint(1.0, Vectors.dense(3.5, 40.0)))
for (numParts <- List(2, 4, 6, 8)) {
val chi = Statistics.chiSqTest(sc.parallelize(data, numParts))
val feature1 = chi(0)
......@@ -130,10 +134,25 @@ class HypothesisTestSuite extends FunSuite with LocalSparkContext {
}
// Test that the right number of results is returned
val numCols = 321
val sparseData = Array(new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((200, 1.0)))))
val numCols = 1001
val sparseData = Array(
new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))),
new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0)))))
val chi = Statistics.chiSqTest(sc.parallelize(sparseData))
assert(chi.size === numCols)
assert(chi(1000) != null) // SPARK-3087
// Detect continous features or labels
val random = new Random(11L)
val continuousLabel =
Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2))))
intercept[SparkException] {
Statistics.chiSqTest(sc.parallelize(continuousLabel, 2))
}
val continuousFeature =
Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble())))
intercept[SparkException] {
Statistics.chiSqTest(sc.parallelize(continuousFeature, 2))
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment