diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index a33b6137cf9cc0d1ad2836bf4e7f33b1efeb71cd..81a6c0550bda79fde53acd56cf29c86534495a58 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -54,12 +54,14 @@ private[mllib] class GridPartitioner(
   /**
    * Returns the index of the partition the input coordinate belongs to.
    *
-   * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
-   *            multiplication. k is ignored in computing partitions.
+   * @param key The partition id i (calculated through this method for coordinate (i, j) in
+   *            `simulateMultiply`, the coordinate (i, j) or a tuple (i, j, k), where k is
+   *            the inner index used in multiplication. k is ignored in computing partitions.
    * @return The index of the partition, which the coordinate belongs to.
    */
   override def getPartition(key: Any): Int = {
     key match {
+      case i: Int => i
       case (i: Int, j: Int) =>
         getPartitionId(i, j)
       case (i: Int, j: Int, _: Int) =>
@@ -352,12 +354,49 @@ class BlockMatrix @Since("1.3.0") (
     }
   }
 
+  /** Block (i,j) --> Set of destination partitions */
+  private type BlockDestinations = Map[(Int, Int), Set[Int]]
+
+  /**
+   * Simulate the multiplication with just block indices in order to cut costs on communication,
+   * when we are actually shuffling the matrices.
+   * The `colsPerBlock` of this matrix must equal the `rowsPerBlock` of `other`.
+   * Exposed for tests.
+   *
+   * @param other The BlockMatrix to multiply
+   * @param partitioner The partitioner that will be used for the resulting matrix `C = A * B`
+   * @return A tuple of [[BlockDestinations]]. The first element is the Map of the set of partitions
+   *         that we need to shuffle each blocks of `this`, and the second element is the Map for
+   *         `other`.
+   */
+  private[distributed] def simulateMultiply(
+      other: BlockMatrix,
+      partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
+    val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
+    val rightMatrix = other.blocks.keys.collect()
+    val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
+      val rightCounterparts = rightMatrix.filter(_._1 == colIndex)
+      val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2)))
+      ((rowIndex, colIndex), partitions.toSet)
+    }.toMap
+    val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
+      val leftCounterparts = leftMatrix.filter(_._2 == rowIndex)
+      val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex)))
+      ((rowIndex, colIndex), partitions.toSet)
+    }.toMap
+    (leftDestinations, rightDestinations)
+  }
+
   /**
    * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
    * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
    * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output
    * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause
    * some performance issues until support for multiplying two sparse matrices is added.
+   *
+   * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when
+   * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added
+   * with each other.
    */
   @Since("1.3.0")
   def multiply(other: BlockMatrix): BlockMatrix = {
@@ -368,33 +407,30 @@ class BlockMatrix @Since("1.3.0") (
     if (colsPerBlock == other.rowsPerBlock) {
       val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks,
         math.max(blocks.partitions.length, other.blocks.partitions.length))
-      // Each block of A must be multiplied with the corresponding blocks in each column of B.
-      // TODO: Optimize to send block to a partition once, similar to ALS
+      val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)
+      // Each block of A must be multiplied with the corresponding blocks in the columns of B.
       val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
-        Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block))
+        val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
+        destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
       }
       // Each block of B must be multiplied with the corresponding blocks in each row of A.
       val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
-        Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block))
+        val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
+        destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
       }
-      val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner)
-        .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) =>
-          if (a.size > 1 || b.size > 1) {
-            throw new SparkException("There are multiple MatrixBlocks with indices: " +
-              s"($blockRowIndex, $blockColIndex). Please remove them.")
-          }
-          if (a.nonEmpty && b.nonEmpty) {
-            val C = b.head match {
-              case dense: DenseMatrix => a.head.multiply(dense)
-              case sparse: SparseMatrix => a.head.multiply(sparse.toDense)
-              case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.")
+      val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) =>
+        a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
+          b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
+            val C = rightBlock match {
+              case dense: DenseMatrix => leftBlock.multiply(dense)
+              case sparse: SparseMatrix => leftBlock.multiply(sparse.toDense)
+              case _ =>
+                throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.")
             }
-            Iterator(((blockRowIndex, blockColIndex), C.toBreeze))
-          } else {
-            Iterator()
+            ((leftRowIndex, rightColIndex), C.toBreeze)
           }
-        }.reduceByKey(resultPartitioner, (a, b) => a + b)
-          .mapValues(Matrices.fromBreeze)
+        }
+      }.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze)
       // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices
       new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols())
     } else {
test("simulate multiply") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0)))) + val rdd = sc.parallelize(blocks, 2) + val B = new BlockMatrix(rdd, colPerPart, rowPerPart) + val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks, + math.max(numPartitions, 2)) + val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner) + assert(destinationsA((0, 0)) === Set(0)) + assert(destinationsA((0, 1)) === Set(2)) + assert(destinationsA((1, 0)) === Set(0)) + assert(destinationsA((1, 1)) === Set(2)) + assert(destinationsA((2, 1)) === Set(3)) + assert(destinationsB((0, 0)) === Set(0)) + assert(destinationsB((1, 1)) === Set(2, 3)) + } + test("validate") { // No error gridBasedMat.validate()