Repository: spark
Updated Branches:
  refs/heads/branch-1.5 576265f83 -> 0a0472197


[SPARK-17721][MLLIB][BACKPORT] Fix for multiplying transposed SparseMatrix with SparseVector

Backport PR containing only the changes relevant to mllib; otherwise identical to #15296.

jkbradley

Author: Bjarne Fruergaard <[email protected]>

Closes #15311 from bwahlgreen/bugfix-spark-17721-1.6.

(cherry picked from commit 376545e4d38cd41b4a3233819d63bb81f5c83283)
Signed-off-by: Joseph K. Bradley <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a047219
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a047219
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a047219

Branch: refs/heads/branch-1.5
Commit: 0a04721973c34a3324c41ac68b4f9c203ecedf40
Parents: 576265f
Author: Bjarne Fruergaard <[email protected]>
Authored: Sat Oct 1 19:28:51 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Sat Oct 1 19:29:14 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/linalg/BLAS.scala |  8 ++++++--
 .../org/apache/spark/mllib/linalg/BLASSuite.scala  | 17 +++++++++++++++++
 2 files changed, 23 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a047219/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index ab475af..a14809d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -587,12 +587,16 @@ private[spark] object BLAS extends Serializable with Logging {
         val indEnd = Arows(rowCounter + 1)
         var sum = 0.0
         var k = 0
-        while (k < xNnz && i < indEnd) {
+        while (i < indEnd && k < xNnz) {
           if (xIndices(k) == Acols(i)) {
             sum += Avals(i) * xValues(k)
+            k += 1
+            i += 1
+          } else if (xIndices(k) < Acols(i)) {
+            k += 1
+          } else {
             i += 1
           }
-          k += 1
         }
         yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
         rowCounter += 1

http://git-wip-us.apache.org/repos/asf/spark/blob/0a047219/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 8db5c84..8d02751 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -367,6 +367,23 @@ class BLASSuite extends SparkFunSuite {
       }
     }
 
+    val y17 = new DenseVector(Array(0.0, 0.0))
+    val y18 = y17.copy
+
+    val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), 
Array(2.0, 1.0, 1.0, 2.0))
+      .transpose
+    val sA4 =
+      new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 
2.0, 2.0, 1.0))
+    val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
+
+    val expected4 = new DenseVector(Array(5.0, 4.0))
+
+    gemv(1.0, sA3, sx3, 0.0, y17)
+    gemv(1.0, sA4, sx3, 0.0, y18)
+
+    assert(y17 ~== expected4 absTol 1e-15)
+    assert(y18 ~== expected4 absTol 1e-15)
+
     val dAT =
       new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 
0.0, 0.0, 3.0))
     val sAT =


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to