diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index a33b6137cf9cc0d1ad2836bf4e7f33b1efeb71cd..81a6c0550bda79fde53acd56cf29c86534495a58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -54,12 +54,15 @@ private[mllib] class GridPartitioner(
   /**
    * Returns the index of the partition the input coordinate belongs to.
    *
-   * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
-   *            multiplication. k is ignored in computing partitions.
+   * @param key The coordinate (i, j), a tuple (i, j, k) where k is the inner index used in
+   *            multiplication (k is ignored in computing partitions), or a bare partition id
+   *            computed previously by this method (as done in `simulateMultiply`).
    * @return The index of the partition, which the coordinate belongs to.
    */
   override def getPartition(key: Any): Int = {
     key match {
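+      // A bare Int is an already-computed partition id (see `simulateMultiply`).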
+      case i: Int => i
       case (i: Int, j: Int) =>
         getPartitionId(i, j)
       case (i: Int, j: Int, _: Int) =>
@@ -352,12 +355,59 @@ class BlockMatrix @Since("1.3.0") (
     }
   }
 
+  /** Block (i,j) --> Set of destination partitions */
+  private type BlockDestinations = Map[(Int, Int), Set[Int]]
+
+  /**
+   * Simulates the multiplication using just the block indices in order to cut communication
+   * costs when the matrices are actually shuffled.
+   * The `colsPerBlock` of this matrix must equal the `rowsPerBlock` of `other`.
+   * Exposed for tests.
+   *
+   * @param other The BlockMatrix to multiply
+   * @param partitioner The partitioner that will be used for the resulting matrix `C = A * B`
+   * @return A tuple of [[BlockDestinations]]. The first element maps each block of `this` to
+   *         the set of partitions it needs to be shuffled to, and the second element is the
+   *         corresponding Map for `other`.
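+   *
+   * For example, if `this` has a block at (0, 1) and `other` has blocks at (1, 0) and (1, 2),
+   * block (0, 1) must be shuffled to the partitions of the result blocks (0, 0) and (0, 2).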
+   */
+  private[distributed] def simulateMultiply(
+      other: BlockMatrix,
+      partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
+    val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
+    val rightMatrix = other.blocks.keys.collect()
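+    // For each block (i, j) of `this`, find the blocks (j, k) of `other` it must be multiplied
+    // with, and record the partitions that will hold the corresponding result blocks (i, k).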
+    val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
+      val rightCounterparts = rightMatrix.filter(_._1 == colIndex)
+      val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2)))
+      ((rowIndex, colIndex), partitions.toSet)
+    }.toMap
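+    // Symmetrically, map each block (j, k) of `other` to the partitions of the result
+    // blocks (i, k) it contributes to.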
+    val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
+      val leftCounterparts = leftMatrix.filter(_._2 == rowIndex)
+      val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex)))
+      ((rowIndex, colIndex), partitions.toSet)
+    }.toMap
+    (leftDestinations, rightDestinations)
+  }
+
   /**
    * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
    * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
    * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output
    * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause
    * some performance issues until support for multiplying two sparse matrices is added.
+   *
+   * Note: The behavior of `multiply` has changed in 1.6.0. `multiply` used to throw an error
+   * when there were blocks with duplicate indices. Now, blocks with duplicate indices are
+   * added together.
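+   *
+   * For example, if `this` contains two blocks at index (0, 0), they are effectively summed
+   * before being multiplied with the matching blocks of `other`.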
    */
   @Since("1.3.0")
   def multiply(other: BlockMatrix): BlockMatrix = {
@@ -368,33 +418,34 @@
     if (colsPerBlock == other.rowsPerBlock) {
       val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks,
         math.max(blocks.partitions.length, other.blocks.partitions.length))
-      // Each block of A must be multiplied with the corresponding blocks in each column of B.
-      // TODO: Optimize to send block to a partition once, similar to ALS
+      val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)
+      // Each block of A must be multiplied with the corresponding blocks in the columns of B.
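+      // Thanks to the simulated destinations, each block is shuffled only to the partitions
+      // that actually need it, instead of being replicated across a full block row or column.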
       val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
-        Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block))
+        val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
+        destinations.map(pid => (pid, (blockRowIndex, blockColIndex, block)))
       }
       // Each block of B must be multiplied with the corresponding blocks in each row of A.
       val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
-        Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block))
+        val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
+        destinations.map(pid => (pid, (blockRowIndex, blockColIndex, block)))
       }
-      val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner)
-        .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) =>
-          if (a.size > 1 || b.size > 1) {
-            throw new SparkException("There are multiple MatrixBlocks with indices: " +
-              s"($blockRowIndex, $blockColIndex). Please remove them.")
-          }
-          if (a.nonEmpty && b.nonEmpty) {
-            val C = b.head match {
-              case dense: DenseMatrix => a.head.multiply(dense)
-              case sparse: SparseMatrix => a.head.multiply(sparse.toDense)
-              case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.")
+      val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (_, (a, b)) =>
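+        // Within each partition, pair every left block (i, j) with the right blocks (j, k)
+        // that share its inner index j, and key each product by its result coordinate (i, k).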
+        a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
+          b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
+            val C = rightBlock match {
+              case dense: DenseMatrix => leftBlock.multiply(dense)
+              case sparse: SparseMatrix => leftBlock.multiply(sparse.toDense)
+              case _ =>
+                throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.")
             }
-            Iterator(((blockRowIndex, blockColIndex), C.toBreeze))
-          } else {
-            Iterator()
+            ((leftRowIndex, rightColIndex), C.toBreeze)
           }
-      }.reduceByKey(resultPartitioner, (a, b) => a + b)
-        .mapValues(Matrices.fromBreeze)
+        }
+      }.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze)
       // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices
       new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols())
     } else {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index 93fe04c139b9a05ad1391fb6cd1c04169f8c97d9..b8eb10305801c5a1efcc4bcc61f2e42b55f09a85 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -235,6 +235,25 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(localC ~== result absTol 1e-8)
   }
 
+  test("simulate multiply") {
+    val blocks: Seq[((Int, Int), Matrix)] = Seq(
+      ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))),
+      ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))))
+    val rdd = sc.parallelize(blocks, 2)
+    val B = new BlockMatrix(rdd, colPerPart, rowPerPart)
+    val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks,
+      math.max(numPartitions, 2))
+    val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner)
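+    // Each block should only be shuffled to the partitions holding result blocks it affects.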
+    assert(destinationsA((0, 0)) === Set(0))
+    assert(destinationsA((0, 1)) === Set(2))
+    assert(destinationsA((1, 0)) === Set(0))
+    assert(destinationsA((1, 1)) === Set(2))
+    assert(destinationsA((2, 1)) === Set(3))
+    assert(destinationsB((0, 0)) === Set(0))
+    assert(destinationsB((1, 1)) === Set(2, 3))
+  }
+
   test("validate") {
     // No error
     gridBasedMat.validate()