diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index ff81a2f03e2a85b1c49dfa9a5150181480cd0982..20d68a34bf3eab6258b869bc517421315bdefc51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -425,22 +425,27 @@ class BlockMatrix @Since("1.3.0") (
    */
   private[distributed] def simulateMultiply(
       other: BlockMatrix,
-      partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
-    val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
-    val rightMatrix = other.blocks.keys.collect()
+      partitioner: GridPartitioner,
+      midDimSplitNum: Int): (BlockDestinations, BlockDestinations) = {
+    val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
+    val rightMatrix = other.blockInfo.keys.collect()
 
     val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2))
     val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
       val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array.empty[Int])
       val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b)))
-      ((rowIndex, colIndex), partitions.toSet)
+      val midDimSplitIndex = colIndex % midDimSplitNum
+      ((rowIndex, colIndex),
+        partitions.toSet.map((pid: Int) => pid * midDimSplitNum + midDimSplitIndex))
     }.toMap
 
     val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1))
     val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
       val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array.empty[Int])
       val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex)))
-      ((rowIndex, colIndex), partitions.toSet)
+      val midDimSplitIndex = rowIndex % midDimSplitNum
+      ((rowIndex, colIndex),
+        partitions.toSet.map((pid: Int) => pid * midDimSplitNum + midDimSplitIndex))
     }.toMap
 
     (leftDestinations, rightDestinations)
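
To make the remapping above concrete: each partition id `pid` chosen by the result partitioner fans out to `pid * midDimSplitNum + midDimSplitIndex`, so blocks that previously met in a single partition are spread over `midDimSplitNum` intermediate partitions according to their middle-dimension index. A minimal sketch with hypothetical toy values, pasteable into a Scala REPL (not part of the patch):

    val midDimSplitNum = 4
    val colIndex = 5            // middle-dimension index of a left block
    val pid = 3                 // partition picked by the result GridPartitioner
    val midDimSplitIndex = colIndex % midDimSplitNum               // = 1
    val intermediatePid = pid * midDimSplitNum + midDimSplitIndex  // = 13
    // Intermediate ids 12..15 all derive from result partition 3 (13 / 4 == 3);
    // each holds the partial products for one quarter of the middle dimension.
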
@@ -459,14 +464,39 @@ class BlockMatrix @Since("1.3.0") (
    */
   @Since("1.3.0")
   def multiply(other: BlockMatrix): BlockMatrix = {
+    multiply(other, 1)
+  }
+
+  /**
+   * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
+   * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains any
+   * `SparseMatrix` blocks, they will have to be converted to `DenseMatrix` blocks. The output
+   * [[BlockMatrix]] will only consist of `DenseMatrix` blocks. This may cause some
+   * performance issues until support for multiplying two sparse matrices is added.
+   * Blocks with duplicate indices will be added together.
+   *
+   * @param other Matrix `B` in `A * B = C`
+   * @param numMidDimSplits Number of splits to cut on the middle dimension when doing
+   *                        multiplication. For example, when multiplying a Matrix `A` of
+   *                        size `m x n` with Matrix `B` of size `n x k`, this parameter
+   *                        configures the parallelism to use when grouping the matrices. The
+   *                        parallelism will increase from `m x k` to `m x k x numMidDimSplits`,
+   *                        which in some cases also reduces total shuffled data.
+   */
+  @Since("2.2.0")
+  def multiply(
+      other: BlockMatrix,
+      numMidDimSplits: Int): BlockMatrix = {
     require(numCols() == other.numRows(), "The number of columns of A and the number of rows " +
       s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " +
       "think they should be equal, try setting the dimensions of A and B explicitly while " +
       "initializing them.")
+    require(numMidDimSplits > 0, "numMidDimSplits should be a positive integer.")
     if (colsPerBlock == other.rowsPerBlock) {
       val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks,
         math.max(blocks.partitions.length, other.blocks.partitions.length))
-      val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)
+      val (leftDestinations, rightDestinations) =
+        simulateMultiply(other, resultPartitioner, numMidDimSplits)
       // Each block of A must be multiplied with the corresponding blocks in the columns of B.
       val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
         val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
@@ -477,7 +507,11 @@ class BlockMatrix @Since("1.3.0") (
         val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
         destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
       }
-      val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) =>
+      val intermediatePartitioner = new Partitioner {
+        override def numPartitions: Int = resultPartitioner.numPartitions * numMidDimSplits
+        override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+      }
+      val newBlocks = flatA.cogroup(flatB, intermediatePartitioner).flatMap { case (pId, (a, b)) =>
         a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
           b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
             val C = rightBlock match {
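
For reference, a hedged usage sketch of the new overload (toy matrices and illustrative identifiers, assuming an existing SparkContext `sc`):

    import org.apache.spark.mllib.linalg.Matrices
    import org.apache.spark.mllib.linalg.distributed.BlockMatrix

    val aBlocks = Seq(
      ((0, 0), Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))),
      ((1, 0), Matrices.dense(2, 2, Array(2.0, 0.0, 0.0, 2.0))))
    val bBlocks = Seq(
      ((0, 0), Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))))
    val A = new BlockMatrix(sc.parallelize(aBlocks), 2, 2)  // 4 x 2 matrix
    val B = new BlockMatrix(sc.parallelize(bBlocks), 2, 2)  // 2 x 2 matrix

    val C = A.multiply(B, numMidDimSplits = 2)  // same result as A.multiply(B)

Note the design of the intermediate partitioner: every key emitted by `flatA` and `flatB` is already a final intermediate partition id, so `getPartition` is just a cast, and the cogroup runs over `resultPartitioner.numPartitions * numMidDimSplits` partitions instead of `resultPartitioner.numPartitions`.
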
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index 61266f3c78dbc3e8aa3e702eece156b9a85b0f5a..f6a996940291cd221ad77609795ebb0be91dd582 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -267,6 +267,15 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze())
   }
 
+  def testMultiply(A: BlockMatrix, B: BlockMatrix, expectedResult: Matrix,
+      numMidDimSplits: Int): Unit = {
+    val C = A.multiply(B, numMidDimSplits)
+    val localC = C.toLocalMatrix()
+    assert(C.numRows() === A.numRows())
+    assert(C.numCols() === B.numCols())
+    assert(localC ~== expectedResult absTol 1e-8)
+  }
+
   test("multiply") {
     // identity matrix
     val blocks: Seq[((Int, Int), Matrix)] = Seq(
@@ -302,12 +311,13 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     // Try it with increased number of partitions
     val largeA = new BlockMatrix(sc.parallelize(largerAblocks, 10), 6, 4)
     val largeB = new BlockMatrix(sc.parallelize(largerBblocks, 8), 4, 4)
-    val largeC = largeA.multiply(largeB)
-    val localC = largeC.toLocalMatrix()
+
     val result = largeA.toLocalMatrix().multiply(largeB.toLocalMatrix().asInstanceOf[DenseMatrix])
-    assert(largeC.numRows() === largeA.numRows())
-    assert(largeC.numCols() === largeB.numCols())
-    assert(localC ~== result absTol 1e-8)
+
+    testMultiply(largeA, largeB, result, 1)
+    testMultiply(largeA, largeB, result, 2)
+    testMultiply(largeA, largeB, result, 3)
+    testMultiply(largeA, largeB, result, 4)
   }
 
   test("simulate multiply") {
@@ -318,7 +328,7 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     val B = new BlockMatrix(rdd, colPerPart, rowPerPart)
     val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks,
       math.max(numPartitions, 2))
-    val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner)
+    val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner, 1)
     assert(destinationsA((0, 0)) === Set(0))
     assert(destinationsA((0, 1)) === Set(2))
     assert(destinationsA((1, 0)) === Set(0))
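
Following the remapping in `simulateMultiply`, running the same fixtures with a split factor of 2 would turn each left-side id `pid` into `pid * 2 + (colIndex % 2)`; a sketch of how the assertions could extend (derived from the assertions above, not part of the patch):

    val (destA2, _) = gridBasedMat.simulateMultiply(B, resultPartitioner, 2)
    assert(destA2((0, 0)) === Set(0))  // pid 0, colIndex 0: 0 * 2 + 0
    assert(destA2((0, 1)) === Set(5))  // pid 2, colIndex 1: 2 * 2 + 1
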