Skip to content
Snippets Groups Projects
Commit 3bdbbc6c authored by Mike Dusenberry's avatar Mike Dusenberry Committed by Xiangrui Meng
Browse files

[SPARK-6488][MLLIB][PYTHON] Support addition/multiplication in PySpark's BlockMatrix

This PR adds addition and multiplication to PySpark's `BlockMatrix` class via `add` and `multiply` functions.

Author: Mike Dusenberry <mwdusenb@us.ibm.com>

Closes #9139 from dusenberrymw/SPARK-6488_Add_Addition_and_Multiplication_to_PySpark_BlockMatrix.
parent 9fc16a82
No related branches found
No related tags found
No related merge requests found
...@@ -775,6 +775,74 @@ class BlockMatrix(DistributedMatrix): ...@@ -775,6 +775,74 @@ class BlockMatrix(DistributedMatrix):
""" """
return self._java_matrix_wrapper.call("numCols") return self._java_matrix_wrapper.call("numCols")
def add(self, other):
"""
Adds two block matrices together. The matrices must have the
same size and matching `rowsPerBlock` and `colsPerBlock` values.
If one of the sub matrix blocks that are being added is a
SparseMatrix, the resulting sub matrix block will also be a
SparseMatrix, even if it is being added to a DenseMatrix. If
two dense sub matrix blocks are added, the output block will
also be a DenseMatrix.
>>> dm1 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])
>>> dm2 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
>>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12])
>>> blocks1 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
>>> blocks2 = sc.parallelize([((0, 0), dm1), ((1, 0), dm2)])
>>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm2)])
>>> mat1 = BlockMatrix(blocks1, 3, 2)
>>> mat2 = BlockMatrix(blocks2, 3, 2)
>>> mat3 = BlockMatrix(blocks3, 3, 2)
>>> mat1.add(mat2).toLocalMatrix()
DenseMatrix(6, 2, [2.0, 4.0, 6.0, 14.0, 16.0, 18.0, 8.0, 10.0, 12.0, 20.0, 22.0, 24.0], 0)
>>> mat1.add(mat3).toLocalMatrix()
DenseMatrix(6, 2, [8.0, 2.0, 3.0, 14.0, 16.0, 18.0, 4.0, 16.0, 18.0, 20.0, 22.0, 24.0], 0)
"""
if not isinstance(other, BlockMatrix):
raise TypeError("Other should be a BlockMatrix, got %s" % type(other))
other_java_block_matrix = other._java_matrix_wrapper._java_model
java_block_matrix = self._java_matrix_wrapper.call("add", other_java_block_matrix)
return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)
def multiply(self, other):
"""
Left multiplies this BlockMatrix by `other`, another
BlockMatrix. The `colsPerBlock` of this matrix must equal the
`rowsPerBlock` of `other`. If `other` contains any SparseMatrix
blocks, they will have to be converted to DenseMatrix blocks.
The output BlockMatrix will only consist of DenseMatrix blocks.
This may cause some performance issues until support for
multiplying two sparse matrices is added.
>>> dm1 = Matrices.dense(2, 3, [1, 2, 3, 4, 5, 6])
>>> dm2 = Matrices.dense(2, 3, [7, 8, 9, 10, 11, 12])
>>> dm3 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6])
>>> dm4 = Matrices.dense(3, 2, [7, 8, 9, 10, 11, 12])
>>> sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 1, 2], [7, 11, 12])
>>> blocks1 = sc.parallelize([((0, 0), dm1), ((0, 1), dm2)])
>>> blocks2 = sc.parallelize([((0, 0), dm3), ((1, 0), dm4)])
>>> blocks3 = sc.parallelize([((0, 0), sm), ((1, 0), dm4)])
>>> mat1 = BlockMatrix(blocks1, 2, 3)
>>> mat2 = BlockMatrix(blocks2, 3, 2)
>>> mat3 = BlockMatrix(blocks3, 3, 2)
>>> mat1.multiply(mat2).toLocalMatrix()
DenseMatrix(2, 2, [242.0, 272.0, 350.0, 398.0], 0)
>>> mat1.multiply(mat3).toLocalMatrix()
DenseMatrix(2, 2, [227.0, 258.0, 394.0, 450.0], 0)
"""
if not isinstance(other, BlockMatrix):
raise TypeError("Other should be a BlockMatrix, got %s" % type(other))
other_java_block_matrix = other._java_matrix_wrapper._java_model
java_block_matrix = self._java_matrix_wrapper.call("multiply", other_java_block_matrix)
return BlockMatrix(java_block_matrix, self.rowsPerBlock, self.colsPerBlock)
def toLocalMatrix(self): def toLocalMatrix(self):
""" """
Collect the distributed matrix on the driver as a DenseMatrix. Collect the distributed matrix on the driver as a DenseMatrix.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment