diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index ff81a2f03e2a85b1c49dfa9a5150181480cd0982..20d68a34bf3eab6258b869bc517421315bdefc51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -425,22 +425,27 @@ class BlockMatrix @Since("1.3.0") (
    */
   private[distributed] def simulateMultiply(
       other: BlockMatrix,
-      partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
-    val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
-    val rightMatrix = other.blocks.keys.collect()
+      partitioner: GridPartitioner,
+      midDimSplitNum: Int): (BlockDestinations, BlockDestinations) = {
+    val leftMatrix = blockInfo.keys.collect()
+    val rightMatrix = other.blockInfo.keys.collect()
 
     val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2))
     val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
       val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array.empty[Int])
       val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b)))
-      ((rowIndex, colIndex), partitions.toSet)
+      val midDimSplitIndex = colIndex % midDimSplitNum
+      ((rowIndex, colIndex),
+        partitions.toSet.map((pid: Int) => pid * midDimSplitNum + midDimSplitIndex))
     }.toMap
 
     val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1))
     val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
       val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array.empty[Int])
       val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex)))
-      ((rowIndex, colIndex), partitions.toSet)
+      val midDimSplitIndex = rowIndex % midDimSplitNum
+      ((rowIndex, colIndex),
+        partitions.toSet.map((pid: Int) => pid * midDimSplitNum + midDimSplitIndex))
     }.toMap
 
     (leftDestinations, rightDestinations)
@@ -459,14 +464,39 @@ class BlockMatrix @Since("1.3.0") (
    */
   @Since("1.3.0")
   def multiply(other: BlockMatrix): BlockMatrix = {
+    multiply(other, 1)
+  }
+
+  /**
+   * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock`
+   * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
+   * `SparseMatrix`, they will have to be converted to a `DenseMatrix`. The output
+   * [[BlockMatrix]] will only consist of blocks of `DenseMatrix`. This may cause
+   * some performance issues until support for multiplying two sparse matrices is added.
+   * Blocks with duplicate indices will be added with each other.
+   *
+   * @param other Matrix `B` in `A * B = C`
+   * @param numMidDimSplits Number of splits to cut on the middle dimension when doing
+   *                        multiplication. For example, when multiplying a Matrix `A` of
+   *                        size `m x n` with Matrix `B` of size `n x k`, this parameter
+   *                        configures the parallelism to use when grouping the matrices. The
+   *                        parallelism will increase from `m x k` to `m x k x numMidDimSplits`,
+   *                        which in some cases also reduces total shuffled data.
+   */
+  @Since("2.2.0")
+  def multiply(
+      other: BlockMatrix,
+      numMidDimSplits: Int): BlockMatrix = {
     require(numCols() == other.numRows(), "The number of columns of A and the number of rows " +
       s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " +
       "think they should be equal, try setting the dimensions of A and B explicitly while " +
       "initializing them.")
+    require(numMidDimSplits > 0, "numMidDimSplits should be a positive integer.")
     if (colsPerBlock == other.rowsPerBlock) {
       val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks,
         math.max(blocks.partitions.length, other.blocks.partitions.length))
-      val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)
+      val (leftDestinations, rightDestinations)
+        = simulateMultiply(other, resultPartitioner, numMidDimSplits)
       // Each block of A must be multiplied with the corresponding blocks in the columns of B.
       val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
         val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
@@ -477,7 +507,11 @@ class BlockMatrix @Since("1.3.0") (
         val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty)
         destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
       }
-      val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) =>
+      val intermediatePartitioner = new Partitioner {
+        override def numPartitions: Int = resultPartitioner.numPartitions * numMidDimSplits
+        override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+      }
+      val newBlocks = flatA.cogroup(flatB, intermediatePartitioner).flatMap { case (pId, (a, b)) =>
         a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
           b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
             val C = rightBlock match {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
index 61266f3c78dbc3e8aa3e702eece156b9a85b0f5a..f6a996940291cd221ad77609795ebb0be91dd582 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -267,6 +267,15 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(sparseBM.subtract(sparseBM).toBreeze() === sparseBM.subtract(denseBM).toBreeze())
   }
 
+  def testMultiply(A: BlockMatrix, B: BlockMatrix, expectedResult: Matrix,
+      numMidDimSplits: Int): Unit = {
+    val C = A.multiply(B, numMidDimSplits)
+    val localC = C.toLocalMatrix()
+    assert(C.numRows() === A.numRows())
+    assert(C.numCols() === B.numCols())
+    assert(localC ~== expectedResult absTol 1e-8)
+  }
+
   test("multiply") {
     // identity matrix
     val blocks: Seq[((Int, Int), Matrix)] = Seq(
@@ -302,12 +311,13 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     // Try it with increased number of partitions
     val largeA = new BlockMatrix(sc.parallelize(largerAblocks, 10), 6, 4)
     val largeB = new BlockMatrix(sc.parallelize(largerBblocks, 8), 4, 4)
-    val largeC = largeA.multiply(largeB)
-    val localC = largeC.toLocalMatrix()
+
     val result = largeA.toLocalMatrix().multiply(largeB.toLocalMatrix().asInstanceOf[DenseMatrix])
-    assert(largeC.numRows() === largeA.numRows())
-    assert(largeC.numCols() === largeB.numCols())
-    assert(localC ~== result absTol 1e-8)
+
+    testMultiply(largeA, largeB, result, 1)
+    testMultiply(largeA, largeB, result, 2)
+    testMultiply(largeA, largeB, result, 3)
+    testMultiply(largeA, largeB, result, 4)
   }
 
   test("simulate multiply") {
@@ -318,7 +328,7 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     val B = new BlockMatrix(rdd, colPerPart, rowPerPart)
     val resultPartitioner = GridPartitioner(gridBasedMat.numRowBlocks, B.numColBlocks,
       math.max(numPartitions, 2))
-    val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner)
+    val (destinationsA, destinationsB) = gridBasedMat.simulateMultiply(B, resultPartitioner, 1)
     assert(destinationsA((0, 0)) === Set(0))
     assert(destinationsA((0, 1)) === Set(2))
     assert(destinationsA((1, 0)) === Set(0))
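
For reference, a minimal usage sketch of the new overload introduced by this patch; it is not part of the diff itself. It builds two small block matrices and multiplies them with the middle dimension cut into two splits. The SparkContext `sc`, the block values, and the choice of two splits are illustrative assumptions.

import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.linalg.distributed.BlockMatrix

// A is 2 x 4 (one row of two 2x2 blocks); B is 4 x 2 (one column of two 2x2
// blocks). The middle dimension n = 4 therefore spans two blocks, which the
// new numMidDimSplits parameter can spread over extra partitions.
val aBlocks = sc.parallelize(Seq(
  ((0, 0), Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))),
  ((0, 1), Matrices.dense(2, 2, Array(2.0, 0.0, 0.0, 2.0)))))
val bBlocks = sc.parallelize(Seq(
  ((0, 0), Matrices.dense(2, 2, Array(3.0, 0.0, 0.0, 3.0))),
  ((1, 0), Matrices.dense(2, 2, Array(4.0, 0.0, 0.0, 4.0)))))
val A = new BlockMatrix(aBlocks, 2, 2)
val B = new BlockMatrix(bBlocks, 2, 2)

// With numMidDimSplits = 2, the cogroup inside multiply runs on twice as many
// partitions as the result GridPartitioner alone would use.
val C = A.multiply(B, 2)
println(C.toLocalMatrix()) // 1*3*I + 2*4*I = 11 * I (2 x 2)

The result is identical to plain A.multiply(B); only the partitioning of the intermediate shuffle changes, which is why the test suite checks several split counts against one locally computed expected matrix.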