diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 36d8cadd2bdd7787c868f80e4c8e0413969c805c..181f507516485ae0b43e31970c0c36139d9b8318 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -102,6 +102,9 @@ class IndexedRowMatrix( k: Int, computeU: Boolean = false, rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + + val n = numCols().toInt + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") val indices = rows.map(_.index) val svd = toRowMatrix().computeSVD(k, computeU, rCond) val U = if (computeU) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index fbd35e372f9b10c22d4ebc79542488d44ae0b2c6..d5abba6a4b645588f92099be8bae52b7980609df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -212,7 +212,7 @@ class RowMatrix( tol: Double, mode: String): SingularValueDecomposition[RowMatrix, Matrix] = { val n = numCols().toInt - require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.") + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") object SVDMode extends Enumeration { val LocalARPACK, LocalLAPACK, DistARPACK = Value diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index e25bc02b06c9a08d8584dc2375c1f9c34e8ea880..741cd4997b853b332a10660f430e880fbc6790bf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -113,6 +113,13 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate k in svd") { + val A = new IndexedRowMatrix(indexedRows) + intercept[IllegalArgumentException] { + A.computeSVD(-1) + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index dbf55ff81ca99ba7783273ac03261db79b321986..3309713e91f875dc3e494b28eac378e04f089f67 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -171,6 +171,14 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } + test("validate k in svd") { + for (mat <- Seq(denseMat, sparseMat)) { + intercept[IllegalArgumentException] { + mat.computeSVD(-1) + } + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 }