Commit 0e821ec6 authored by sethah, committed by Yanbo Liang

[SPARK-19313][ML][MLLIB] GaussianMixture should limit the number of features

## What changes were proposed in this pull request?

The following test fails on current master:

````scala
test("gmm fails on high dimensional data") {
  val ctx = spark.sqlContext
  import ctx.implicits._
  val df = Seq(
    Vectors.sparse(GaussianMixture.MAX_NUM_FEATURES + 1, Array(0, 4), Array(3.0, 8.0)),
    Vectors.sparse(GaussianMixture.MAX_NUM_FEATURES + 1, Array(1, 5), Array(4.0, 9.0)))
    .map(Tuple1.apply).toDF("features")
  val gm = new GaussianMixture()
  intercept[IllegalArgumentException] {
    gm.fit(df)
  }
}
````

Instead of an `IllegalArgumentException`, you'll get an `ArrayIndexOutOfBoundsException` (or something similar in MLlib). That's because the covariance matrix allocates an array of size `numFeatures * numFeatures`, which overflows `Int` in this case. Although there is already a warning that the algorithm does not perform well with a high number of features, we should add an explicit check to communicate this limitation to users.
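The overflow can be reproduced without Spark. The following is a minimal sketch, not code from the patch: `CovarianceSizeSketch`, `denseSizeInt`, and `denseSizeLong` are hypothetical names chosen for illustration. It compares the `Int` product with the same product computed as a `Long`, using the feature count from the test above (one more than `math.sqrt(Int.MaxValue).toInt`).

```scala
object CovarianceSizeSketch {
  // Largest d for which d * d still fits in an Int (same expression as the patch's limit).
  val maxNumFeatures: Int = math.sqrt(Int.MaxValue).toInt

  // Entry count of a dense d x d matrix in Int arithmetic; wraps around silently on overflow.
  def denseSizeInt(d: Int): Int = d * d

  // The same count in Long arithmetic, which cannot overflow for any Int d.
  def denseSizeLong(d: Int): Long = d.toLong * d.toLong

  def main(args: Array[String]): Unit = {
    val d = maxNumFeatures + 1
    println(s"d = $d, Int size = ${denseSizeInt(d)}, Long size = ${denseSizeLong(d)}")
  }
}
```

With `d = maxNumFeatures + 1`, the `Int` product wraps to a negative value, so it cannot serve as a valid array length.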

This patch adds a `require(numFeatures < GaussianMixture.MAX_NUM_FEATURES)` check to the ML and MLlib algorithms. One option is to set the limit just low enough to avoid integer overflow, i.e. `math.sqrt(Int.MaxValue).toInt` (about 46k), which eliminates the cryptic error. However, other algorithms are stricter: WLS, for example, collects an array on the order of `numFeatures * numFeatures` to the driver and therefore limits itself to 4096 features. We may want to keep that convention here for consistency.
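To make the trade-off between the two candidate limits concrete, here is a hedged sketch of the guard together with a rough estimate of dense covariance storage. It is not the patch itself: `FeatureLimitSketch`, `checkNumFeatures`, and `denseCovarianceBytes` are hypothetical names, and the estimate assumes one 8-byte `Double` per entry of a full `numFeatures` by `numFeatures` matrix.

```scala
object FeatureLimitSketch {
  // Overflow-safe bound: largest d with d * d <= Int.MaxValue.
  val overflowLimit: Int = math.sqrt(Int.MaxValue).toInt

  // Stricter bound that the description attributes to WLS.
  val wlsLimit: Int = 4096

  // Same shape as the check in the patch; throws IllegalArgumentException when violated.
  def checkNumFeatures(numFeatures: Int, limit: Int): Unit =
    require(numFeatures < limit,
      s"GaussianMixture cannot handle more than $limit features because the size of the " +
        "covariance matrix is quadratic in the number of features.")

  // Bytes for a dense matrix of Doubles (assumption: full storage, 8 bytes per entry).
  def denseCovarianceBytes(numFeatures: Int): Long =
    numFeatures.toLong * numFeatures.toLong * 8L

  def main(args: Array[String]): Unit = {
    Seq(wlsLimit, overflowLimit).foreach { d =>
      val mib = denseCovarianceBytes(d) / (1024L * 1024L)
      println(s"$d features -> about $mib MiB per dense covariance matrix")
    }
  }
}
```

Under that storage assumption, 4096 features need 128 MiB per matrix, while the overflow-safe limit needs just under 16 GiB per matrix.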

## How was this patch tested?
Unit tests in ML and MLlib.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #16661 from sethah/gmm_high_dim.
parent 76db394f
......@@ -278,7 +278,9 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
* While this process is generally guaranteed to converge, it is not guaranteed
* to find a global optimum.
*
- * @note For high-dimensional data (with many features), this algorithm may perform poorly.
+ * @note This algorithm is limited in its number of features since it requires storing a covariance
+ * matrix which has size quadratic in the number of features. Even when the number of features does
+ * not exceed this limit, this algorithm may perform poorly on high-dimensional data.
* This is due to high-dimensional data (a) making it difficult to cluster at all (based
* on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.
*/
......@@ -344,6 +346,9 @@ class GaussianMixture @Since("2.0.0") (
// Extract the number of features.
val numFeatures = instances.first().size
+ require(numFeatures < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
+ s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
+ s" matrix is quadratic in the number of features.")
val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
......@@ -391,8 +396,8 @@ class GaussianMixture @Since("2.0.0") (
val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, cov, weight) =>
GaussianMixture.updateWeightsAndGaussians(mean, cov, weight, sumWeights)
}.collect().unzip
- Array.copy(ws.toArray, 0, weights, 0, ws.length)
- Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
+ Array.copy(ws, 0, weights, 0, ws.length)
+ Array.copy(gs, 0, gaussians, 0, gs.length)
} else {
var i = 0
while (i < numClusters) {
......@@ -486,6 +491,9 @@ class GaussianMixture @Since("2.0.0") (
@Since("2.0.0")
object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
+ /** Limit number of features such that numFeatures^2^ < Int.MaxValue */
+ private[clustering] val MAX_NUM_FEATURES = math.sqrt(Int.MaxValue).toInt
@Since("2.0.0")
override def load(path: String): GaussianMixture = super.load(path)
......
......@@ -46,7 +46,9 @@ import org.apache.spark.util.Utils
* is considered to have occurred.
* @param maxIterations Maximum number of iterations allowed.
*
- * @note For high-dimensional data (with many features), this algorithm may perform poorly.
+ * @note This algorithm is limited in its number of features since it requires storing a covariance
+ * matrix which has size quadratic in the number of features. Even when the number of features does
+ * not exceed this limit, this algorithm may perform poorly on high-dimensional data.
* This is due to high-dimensional data (a) making it difficult to cluster at all (based
* on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.
*/
......@@ -170,6 +172,9 @@ class GaussianMixture private (
// Get length of the input vectors
val d = breezeData.first().length
+ require(d < GaussianMixture.MAX_NUM_FEATURES, s"GaussianMixture cannot handle more " +
+ s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
+ s" matrix is quadratic in the number of features.")
val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(k, d)
......@@ -211,8 +216,8 @@ class GaussianMixture private (
val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) =>
updateWeightsAndGaussians(mean, sigma, weight, sumWeights)
}.collect().unzip
- Array.copy(ws.toArray, 0, weights, 0, ws.length)
- Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
+ Array.copy(ws, 0, weights, 0, ws.length)
+ Array.copy(gs, 0, gaussians, 0, gs.length)
} else {
var i = 0
while (i < k) {
......@@ -272,6 +277,10 @@ class GaussianMixture private (
}
private[clustering] object GaussianMixture {
+ /** Limit number of features such that numFeatures^2^ < Int.MaxValue */
+ private[clustering] val MAX_NUM_FEATURES = math.sqrt(Int.MaxValue).toInt
/**
* Heuristic to distribute the computation of the `MultivariateGaussian`s, approximately when
* d is greater than 25 except for when k is very small.
......
......@@ -53,6 +53,20 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
rDataset = rData.map(FeatureData).toDF()
}
+ test("gmm fails on high dimensional data") {
+ val df = Seq(
+ Vectors.sparse(GaussianMixture.MAX_NUM_FEATURES + 1, Array(0, 4), Array(3.0, 8.0)),
+ Vectors.sparse(GaussianMixture.MAX_NUM_FEATURES + 1, Array(1, 5), Array(4.0, 9.0)))
+ .map(Tuple1.apply).toDF("features")
+ val gm = new GaussianMixture()
+ withClue(s"GMM should restrict the maximum number of features to be < " +
+ s"${GaussianMixture.MAX_NUM_FEATURES}") {
+ intercept[IllegalArgumentException] {
+ gm.fit(df)
+ }
+ }
+ }
test("default parameters") {
val gm = new GaussianMixture()
......
......@@ -25,6 +25,20 @@ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
+ test("gmm fails on high dimensional data") {
+ val rdd = sc.parallelize(Seq(
+ Vectors.sparse(GaussianMixture.MAX_NUM_FEATURES + 1, Array(0, 4), Array(3.0, 8.0)),
+ Vectors.sparse(GaussianMixture.MAX_NUM_FEATURES + 1, Array(1, 5), Array(4.0, 9.0))))
+ val gm = new GaussianMixture()
+ withClue(s"GMM should restrict the maximum number of features to be < " +
+ s"${GaussianMixture.MAX_NUM_FEATURES}") {
+ intercept[IllegalArgumentException] {
+ gm.run(rdd)
+ }
+ }
+ }
test("single cluster") {
val data = sc.parallelize(Array(
Vectors.dense(6.0, 9.0),
......