This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c96d64  [SPARK-35707][ML] optimize sparse GEMM by skipping bound 
checking
5c96d64 is described below

commit 5c96d643eeb4ca1ad7e4e9cc711971203fcacc6c
Author: Ruifeng Zheng <ruife...@foxmail.com>
AuthorDate: Wed Jun 16 08:57:27 2021 +0800

    [SPARK-35707][ML] optimize sparse GEMM by skipping bound checking
    
    ### What changes were proposed in this pull request?
    Sparse GEMM uses the method `DenseMatrix.apply` to access the values, which can 
be optimized by skipping the bounds checks and the `isTransposed` check:
    
    ```
      override def apply(i: Int, j: Int): Double = values(index(i, j))
    
      private[ml] def index(i: Int, j: Int): Int = {
        require(i >= 0 && i < numRows, s"Expected 0 <= i < $numRows, got i = 
$i.")
        require(j >= 0 && j < numCols, s"Expected 0 <= j < $numCols, got j = 
$j.")
        if (!isTransposed) i + numRows * j else j + numCols * i
      }
    
    ```
    
    ### Why are the changes needed?
    To improve performance: about 15% faster in the designed test case.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing test suites and an additional performance test.
    
    Closes #32857 from zhengruifeng/gemm_opt_index.
    
    Authored-by: Ruifeng Zheng <ruife...@foxmail.com>
    Signed-off-by: Ruifeng Zheng <ruife...@foxmail.com>
---
 mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala | 4 ++--
 mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala    | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala 
b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
index 0bc8b2f..d1255de 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -480,7 +480,7 @@ private[spark] object BLAS extends Serializable {
             val indEnd = AcolPtrs(rowCounterForA + 1)
             var sum = 0.0
             while (i < indEnd) {
-              sum += Avals(i) * B(ArowIndices(i), colCounterForB)
+              sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i))
               i += 1
             }
             val Cindex = Cstart + rowCounterForA
@@ -522,7 +522,7 @@ private[spark] object BLAS extends Serializable {
           while (colCounterForA < kA) {
             var i = AcolPtrs(colCounterForA)
             val indEnd = AcolPtrs(colCounterForA + 1)
-            val Bval = B(colCounterForA, colCounterForB) * alpha
+            val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha
             while (i < indEnd) {
               Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
               i += 1
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index e38cfe4..5cbec53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -462,7 +462,7 @@ private[spark] object BLAS extends Serializable with 
Logging {
             val indEnd = AcolPtrs(rowCounterForA + 1)
             var sum = 0.0
             while (i < indEnd) {
-              sum += Avals(i) * B(ArowIndices(i), colCounterForB)
+              sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i))
               i += 1
             }
             val Cindex = Cstart + rowCounterForA
@@ -504,7 +504,7 @@ private[spark] object BLAS extends Serializable with 
Logging {
           while (colCounterForA < kA) {
             var i = AcolPtrs(colCounterForA)
             val indEnd = AcolPtrs(colCounterForA + 1)
-            val Bval = B(colCounterForA, colCounterForB) * alpha
+            val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha
             while (i < indEnd) {
               Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
               i += 1

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to