diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index bbbcc8436b7c29e14095a6556e703b44c1c86be0..ab475af264dd33311441c6046679044713226bfc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging { "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0 && beta == 1.0) { logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.") + } else if (alpha == 0.0) { + f2jBLAS.dscal(C.values.length, beta, C.values, 1) } else { A match { case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) @@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging { } } } else { - // Scale matrix first if `beta` is not equal to 0.0 - if (beta != 0.0) { + // Scale matrix first if `beta` is not equal to 1.0 + if (beta != 1.0) { f2jBLAS.dscal(C.values.length, beta, C.values, 1) } // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of @@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging { s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") - if (alpha == 0.0) { - logDebug("gemv: alpha is equal to 0. Returning y.") + if (alpha == 0.0 && beta == 1.0) { + logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.") + } else if (alpha == 0.0) { + scal(beta, y) } else { (A, x) match { case (smA: SparseMatrix, dvx: DenseVector) => @@ -526,11 +530,6 @@ private[spark] object BLAS extends Serializable with Logging { val xValues = x.values val yValues = y.values - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounterForA = 0 while (rowCounterForA < mA) { @@ -581,11 +580,6 @@ private[spark] object BLAS extends Serializable with Logging { val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { @@ -604,7 +598,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) var colCounterForA = 0 var k = 0 @@ -659,7 +653,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index d119e0b50a3936c41df3a5e4cbd7a51297a2ef33..8db5c8424abe95f10e859052c1a46f0c44ea217e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite { val C14 = C1.copy val C15 = C1.copy val C16 = C1.copy + val C17 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0)) @@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite { assert(C2 ~== expected2 absTol 1e-15) assert(C3 ~== expected3 absTol 1e-15) assert(C4 ~== expected3 absTol 1e-15) + gemm(1.0, dA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) + gemm(1.0, sA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] {