diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index c8a97b8c53d9ba05c0568f830d3d019ff9ad8437..89b38679b74947b9858ae33de04044f8c6b1bc36 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -256,8 +256,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ - def zeros(numRows: Int, numCols: Int): DenseMatrix = + def zeros(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, new Array[Double](numRows * numCols)) + } /** * Generate a `DenseMatrix` consisting of ones. @@ -265,8 +268,11 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ - def ones(numRows: Int, numCols: Int): DenseMatrix = + def ones(numRows: Int, numCols: Int): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(1.0)) + } /** * Generate an Identity Matrix in `DenseMatrix` format. @@ -291,6 +297,8 @@ object DenseMatrix { * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } @@ -302,6 +310,8 @@ object DenseMatrix { * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { + require(numRows.toLong * numCols <= Int.MaxValue, + s"$numRows x $numCols dense matrix is too large to allocate") new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) }