Skip to content
Snippets Groups Projects
Commit f5e10a34 authored by WeichenXu's avatar WeichenXu Committed by Joseph K. Bradley
Browse files

[SPARK-21862][ML] Add overflow check in PCA

## What changes were proposed in this pull request?

add overflow check in PCA, otherwise it is possible to throw `NegativeArraySizeException` when `k` and `numFeatures` are too large.
The overflow checking formula is here:
https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/linalg/functions/svd.scala#L87

## How was this patch tested?

N/A

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19078 from WeichenXu123/SVD_overflow_check.
parent 96028e36
No related branches found
No related tags found
No related merge requests found
......@@ -44,6 +44,11 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
require(k <= numFeatures,
s"source vector size $numFeatures must be no less than k=$k")
require(PCAUtil.memoryCost(k, numFeatures) < Int.MaxValue,
"The param k and numFeatures is too large for SVD computation. " +
"Try reducing the parameter k for PCA, or reduce the input feature " +
"vector dimension to make this tractable.")
val mat = new RowMatrix(sources)
val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
val densePC = pc match {
......@@ -110,3 +115,17 @@ class PCAModel private[spark] (
}
}
}
private[feature] object PCAUtil {
// This memory cost formula is from breeze code:
// https://github.com/scalanlp/breeze/blob/
// 6e541be066d547a097f5089165cd7c38c3ca276d/math/src/main/scala/breeze/linalg/
// functions/svd.scala#L87
def memoryCost(k: Int, numFeatures: Int): Long = {
3L * math.min(k, numFeatures) * math.min(k, numFeatures)
+ math.max(math.max(k, numFeatures), 4L * math.min(k, numFeatures)
* math.min(k, numFeatures) + 4L * math.min(k, numFeatures))
}
}
......@@ -48,4 +48,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
}
assert(pca.explainedVariance ~== explainedVariance relTol 1e-8)
}
test("memory cost computation") {
assert(PCAUtil.memoryCost(10, 100) < Int.MaxValue)
// check overflowing
assert(PCAUtil.memoryCost(40000, 60000) > Int.MaxValue)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment