This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5c96d64 [SPARK-35707][ML] optimize sparse GEMM by skipping bound checking
5c96d64 is described below
commit 5c96d643eeb4ca1ad7e4e9cc711971203fcacc6c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jun 16 08:57:27 2021 +0800
[SPARK-35707][ML] optimize sparse GEMM by skipping bound checking
### What changes were proposed in this pull request?
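Sparse GEMM uses the method `DenseMatrix.apply` to access the values of `B`.
Every call goes through `index`, which checks both bounds and branches on
`isTransposed`. At the affected call sites the layout of `B` is already
known, so the element can instead be read directly from `Bvals` at offset
`col + nB * row`. This is the same offset as the transposed branch of `index`
below, and it skips the bounds checks and the `isTransposed` test: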
```
override def apply(i: Int, j: Int): Double = values(index(i, j))
private[ml] def index(i: Int, j: Int): Int = {
  require(i >= 0 && i < numRows, s"Expected 0 <= i < $numRows, got i = $i.")
  require(j >= 0 && j < numCols, s"Expected 0 <= j < $numCols, got j = $j.")
  if (!isTransposed) i + numRows * j else j + numCols * i
}
```
### Why are the changes needed?
To improve performance: sparse GEMM is about 15% faster in the designed benchmark case.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing test suites and an additional performance test.
Closes #32857 from zhengruifeng/gemm_opt_index.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala | 4 ++--
mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
index 0bc8b2f..d1255de 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -480,7 +480,7 @@ private[spark] object BLAS extends Serializable {
val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * B(ArowIndices(i), colCounterForB)
+ sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
@@ -522,7 +522,7 @@ private[spark] object BLAS extends Serializable {
while (colCounterForA < kA) {
var i = AcolPtrs(colCounterForA)
val indEnd = AcolPtrs(colCounterForA + 1)
- val Bval = B(colCounterForA, colCounterForB) * alpha
+ val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha
while (i < indEnd) {
Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index e38cfe4..5cbec53 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -462,7 +462,7 @@ private[spark] object BLAS extends Serializable with Logging {
val indEnd = AcolPtrs(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
- sum += Avals(i) * B(ArowIndices(i), colCounterForB)
+ sum += Avals(i) * Bvals(colCounterForB + nB * ArowIndices(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
@@ -504,7 +504,7 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForA < kA) {
var i = AcolPtrs(colCounterForA)
val indEnd = AcolPtrs(colCounterForA + 1)
- val Bval = B(colCounterForA, colCounterForB) * alpha
+ val Bval = Bvals(colCounterForB + nB * colCounterForA) * alpha
while (i < indEnd) {
Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
i += 1
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]