Repository: spark Updated Branches: refs/heads/branch-1.6 b999fa43e -> 376545e4d
[SPARK-17721][MLLIB][BACKPORT] Fix for multiplying transposed SparseMatrix with SparseVector Backport PR of changes relevant to mllib only, but otherwise identical to #15296 jkbradley Author: Bjarne Fruergaard <[email protected]> Closes #15311 from bwahlgreen/bugfix-spark-17721-1.6. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/376545e4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/376545e4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/376545e4 Branch: refs/heads/branch-1.6 Commit: 376545e4d38cd41b4a3233819d63bb81f5c83283 Parents: b999fa4 Author: Bjarne Fruergaard <[email protected]> Authored: Sat Oct 1 19:28:51 2016 -0700 Committer: Joseph K. Bradley <[email protected]> Committed: Sat Oct 1 19:28:51 2016 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/mllib/linalg/BLAS.scala | 8 ++++++-- .../org/apache/spark/mllib/linalg/BLASSuite.scala | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/376545e4/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index df9f4ae..8f91d63 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 http://git-wip-us.apache.org/repos/asf/spark/blob/376545e4/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 96e5ffe..31c23c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
