diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 639295c69525534841043fb55b9ba1d93e1ac5c7..9782350587061dec434d6ac9ba0a817b2009bd8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -426,16 +426,21 @@ class BlockMatrix @Since("1.3.0") (
       partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
     val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
     val rightMatrix = other.blocks.keys.collect()
+
+    val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2))
     val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
-      val rightCounterparts = rightMatrix.filter(_._1 == colIndex)
-      val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2)))
+      val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array())
+      val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b)))
       ((rowIndex, colIndex), partitions.toSet)
     }.toMap
+
+    val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1))
     val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
-      val leftCounterparts = leftMatrix.filter(_._2 == rowIndex)
-      val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex)))
+      val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array())
+      val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex)))
       ((rowIndex, colIndex), partitions.toSet)
     }.toMap
+
     (leftDestinations, rightDestinations)
   }