Skip to content
Snippets Groups Projects
Commit 29396e7d authored by Bjarne Fruergaard's avatar Bjarne Fruergaard Committed by Joseph K. Bradley
Browse files

[SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector

## What changes were proposed in this pull request?

* changes the implementation of gemv with transposed SparseMatrix and SparseVector both in mllib-local and mllib (identical)
* adds a test that was failing before this change, but succeeds with these changes.

The problem in the previous implementation was that it only increments `i`, that is enumerating the columns of a row in the SparseMatrix, when the row-index of the vector matches the column-index of the SparseMatrix. In cases where a particular row of the SparseMatrix has non-zero values at column-indices lower than corresponding non-zero row-indices of the SparseVector, the non-zero values of the SparseVector are enumerated without ever matching the column-index at index `i` and the remaining column-indices i+1,...,indEnd-1 are never attempted. The test cases in this PR illustrate this issue.

## How was this patch tested?

I have run the specific `gemv` tests in both mllib-local and mllib. I am currently still running `./dev/run-tests`.

## ___
As per instructions, I hereby state that this is my original work and that I license the work to the project (Apache Spark) under the project's open source license.

Mentioning dbtsai, viirya and brkyvz whom I can see have worked/authored on these parts before.

Author: Bjarne Fruergaard <bwahlgreen@gmail.com>

Closes #15296 from bwahlgreen/bugfix-spark-17721.
parent 4ecc648a
No related branches found
No related tags found
No related merge requests found
......@@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable {
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
var k = 0
while (k < xNnz && i < indEnd) {
while (i < indEnd && k < xNnz) {
if (xIndices(k) == Acols(i)) {
sum += Avals(i) * xValues(k)
k += 1
i += 1
} else if (xIndices(k) < Acols(i)) {
k += 1
} else {
i += 1
}
k += 1
}
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
rowCounter += 1
......
......@@ -392,6 +392,23 @@ class BLASSuite extends SparkMLFunSuite {
}
}
val y17 = new DenseVector(Array(0.0, 0.0))
val y18 = y17.copy
val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
.transpose
val sA4 =
new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
val expected4 = new DenseVector(Array(5.0, 4.0))
gemv(1.0, sA3, sx3, 0.0, y17)
gemv(1.0, sA4, sx3, 0.0, y18)
assert(y17 ~== expected4 absTol 1e-15)
assert(y18 ~== expected4 absTol 1e-15)
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
......
......@@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging {
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
var k = 0
while (k < xNnz && i < indEnd) {
while (i < indEnd && k < xNnz) {
if (xIndices(k) == Acols(i)) {
sum += Avals(i) * xValues(k)
k += 1
i += 1
} else if (xIndices(k) < Acols(i)) {
k += 1
} else {
i += 1
}
k += 1
}
yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
rowCounter += 1
......
......@@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite {
}
}
val y17 = new DenseVector(Array(0.0, 0.0))
val y18 = y17.copy
val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
.transpose
val sA4 =
new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
val expected4 = new DenseVector(Array(5.0, 4.0))
gemv(1.0, sA3, sx3, 0.0, y17)
gemv(1.0, sA4, sx3, 0.0, y18)
assert(y17 ~== expected4 absTol 1e-15)
assert(y18 ~== expected4 absTol 1e-15)
val dAT =
new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
val sAT =
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment