From f5e10a34e644edf3cbce9a7714d31bc433f3ccbd Mon Sep 17 00:00:00 2001
From: WeichenXu <weichen.xu@databricks.com>
Date: Thu, 31 Aug 2017 16:25:10 -0700
Subject: [PATCH] [SPARK-21862][ML] Add overflow check in PCA

## What changes were proposed in this pull request?

Add an overflow check in PCA; without it, the underlying SVD can throw a `NegativeArraySizeException` when `k` and `numFeatures` are too large.
The overflow-checking formula comes from the breeze SVD workspace estimate:
https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/linalg/functions/svd.scala#L87
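
For intuition, here is a minimal standalone sketch of that estimate, kept in `Long` arithmetic so the comparison against `Int.MaxValue` cannot itself overflow (the object name `MemoryCostSketch` is illustrative; in the patch this logic lives in `PCAUtil.memoryCost`):

```scala
// Sketch of the SVD workspace estimate behind the new check.
// Mirrors breeze's formula: 3 * n^2 + max(m, 4 * n^2 + 4 * n),
// where n = min(k, numFeatures) and m = max(k, numFeatures).
object MemoryCostSketch {
  def memoryCost(k: Int, numFeatures: Int): Long = {
    val n = math.min(k, numFeatures).toLong
    val m = math.max(k, numFeatures).toLong
    3L * n * n + math.max(m, 4L * n * n + 4L * n)
  }

  def main(args: Array[String]): Unit = {
    println(memoryCost(10, 100))                      // 740, far below Int.MaxValue
    println(memoryCost(40000, 60000) > Int.MaxValue)  // true: 4 * 40000^2 alone is 6.4e9
  }
}
```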

## How was this patch tested?

Added a unit test in `PCASuite` ("memory cost computation") that checks the formula on a small, safe case and on a `k`/`numFeatures` pair large enough to overflow `Int`.
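
For context, a hedged end-to-end sketch of the behavior change (not part of the patch; the local-mode setup and object name are illustrative): with dimensions this large, `fit` now fails fast in the new `require` instead of dying later inside breeze with a `NegativeArraySizeException`.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.PCA
import org.apache.spark.mllib.linalg.Vectors

object PCAOverflowDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[1]", "pca-overflow-demo")
    // One sparse 60000-dimensional vector is enough: the check only needs
    // the dimensions, and it runs before the RowMatrix is built.
    val data = sc.parallelize(Seq(Vectors.sparse(60000, Array(0), Array(1.0))))
    try {
      new PCA(40000).fit(data) // k <= numFeatures holds, but memoryCost overflows Int
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)
    }
    sc.stop()
  }
}
```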

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19078 from WeichenXu123/SVD_overflow_check.
---
 .../org/apache/spark/mllib/feature/PCA.scala  | 19 +++++++++++++++++++
 .../apache/spark/mllib/feature/PCASuite.scala |  6 ++++++
 2 files changed, 25 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
index aaecfa8d45..a01503f4b8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -44,6 +44,11 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
     require(k <= numFeatures,
       s"source vector size $numFeatures must be no less than k=$k")
 
+    require(PCAUtil.memoryCost(k, numFeatures) < Int.MaxValue,
+      "The parameters k and numFeatures are too large for SVD computation. " +
+      "Try reducing the parameter k for PCA, or reducing the input feature " +
+      "vector dimension to make this tractable.")
+
     val mat = new RowMatrix(sources)
     val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
     val densePC = pc match {
@@ -110,3 +115,17 @@ class PCAModel private[spark] (
     }
   }
 }
+
+private[feature] object PCAUtil {
+
+  // This memory cost formula is from breeze code:
+  // https://github.com/scalanlp/breeze/blob/
+  // 6e541be066d547a097f5089165cd7c38c3ca276d/math/src/main/scala/breeze/linalg/
+  // functions/svd.scala#L87
+  def memoryCost(k: Int, numFeatures: Int): Long = {
+    val n = math.min(k, numFeatures).toLong
+    // 3 * n^2 + max(max(k, numFeatures), 4 * n^2 + 4 * n), in Long arithmetic
+    3L * n * n + math.max(math.max(k, numFeatures).toLong, 4L * n * n + 4L * n)
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
index 2f90afdcee..8eab12416a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -48,4 +48,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
     assert(pca.explainedVariance ~== explainedVariance relTol 1e-8)
   }
+
+  test("memory cost computation") {
+    assert(PCAUtil.memoryCost(10, 100) < Int.MaxValue)
+    // check that a large k/numFeatures pair is detected as overflowing Int
+    assert(PCAUtil.memoryCost(40000, 60000) > Int.MaxValue)
+  }
 }
-- 
GitLab