diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index bbbcc8436b7c29e14095a6556e703b44c1c86be0..ab475af264dd33311441c6046679044713226bfc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging {
       "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
     if (alpha == 0.0 && beta == 1.0) {
       logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
+    } else if (alpha == 0.0) {
+      f2jBLAS.dscal(C.values.length, beta, C.values, 1)
     } else {
       A match {
         case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
@@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging {
         }
       }
     } else {
-      // Scale matrix first if `beta` is not equal to 0.0
-      if (beta != 0.0) {
+      // Scale matrix first if `beta` is not equal to 1.0
+      if (beta != 1.0) {
         f2jBLAS.dscal(C.values.length, beta, C.values, 1)
       }
       // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
@@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging {
       s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
     require(A.numRows == y.size,
       s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}")
-    if (alpha == 0.0) {
-      logDebug("gemv: alpha is equal to 0. Returning y.")
+    if (alpha == 0.0 && beta == 1.0) {
+      logDebug("gemv: alpha is equal to 0 and beta is equal to 1. Returning y.")
+    } else if (alpha == 0.0) {
+      scal(beta, y)
     } else {
       (A, x) match {
         case (smA: SparseMatrix, dvx: DenseVector) =>
@@ -526,11 +530,6 @@ private[spark] object BLAS extends Serializable with Logging {
     val xValues = x.values
     val yValues = y.values
 
-    if (alpha == 0.0) {
-      scal(beta, y)
-      return
-    }
-
     if (A.isTransposed) {
       var rowCounterForA = 0
       while (rowCounterForA < mA) {
@@ -581,11 +580,6 @@ private[spark] object BLAS extends Serializable with Logging {
     val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
     val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
 
-    if (alpha == 0.0) {
-      scal(beta, y)
-      return
-    }
-
     if (A.isTransposed) {
       var rowCounter = 0
       while (rowCounter < mA) {
@@ -604,7 +598,7 @@ private[spark] object BLAS extends Serializable with Logging {
         rowCounter += 1
       }
     } else {
-      scal(beta, y)
+      if (beta != 1.0) scal(beta, y)
 
       var colCounterForA = 0
       var k = 0
@@ -659,7 +653,7 @@ private[spark] object BLAS extends Serializable with Logging {
         rowCounter += 1
       }
     } else {
-      scal(beta, y)
+      if (beta != 1.0) scal(beta, y)
       // Perform matrix-vector multiplication and add to y
       var colCounterForA = 0
       while (colCounterForA < nA) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index d119e0b50a3936c41df3a5e4cbd7a51297a2ef33..8db5c8424abe95f10e859052c1a46f0c44ea217e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite {
     val C14 = C1.copy
     val C15 = C1.copy
     val C16 = C1.copy
+    val C17 = C1.copy
     val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
     val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
     val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
@@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite {
     assert(C2 ~== expected2 absTol 1e-15)
     assert(C3 ~== expected3 absTol 1e-15)
     assert(C4 ~== expected3 absTol 1e-15)
+    gemm(1.0, dA, B, 0.0, C17)
+    assert(C17 ~== expected absTol 1e-15)
+    gemm(1.0, sA, B, 0.0, C17)
+    assert(C17 ~== expected absTol 1e-15)
 
     withClue("columns of A don't match the rows of B") {
       intercept[Exception] {