This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 88fe2b0eb4 [SYSTEMDS-3653] Ultra Sparse Right MM Optimization
88fe2b0eb4 is described below
commit 88fe2b0eb4eb1fd342f37c2741629056155c56a2
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Thu Nov 30 17:49:43 2023 +0100
[SYSTEMDS-3653] Ultra Sparse Right MM Optimization
Right-side ultra-sparse optimizations, going from 8.525 to 4.575
on 100 repetitions of 100k by 1000 dense %*% 1000 by 1000 with 30 non-zeros.
Closes #1952
---
.../sysds/runtime/matrix/data/LibMatrixMult.java | 47 +++++++++++++++++++---
1 file changed, 42 insertions(+), 5 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 41dc7f2264..e956f61906 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -49,6 +49,7 @@ import org.apache.sysds.runtime.data.SparseBlock.Type;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseBlockFactory;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
+import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.data.SparseRowScalar;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
@@ -194,7 +195,7 @@ public class LibMatrixMult
(!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
boolean sparse = !fixedRet && !ultraSparse && !m1Perm
&& isSparseOutputMatrixMult(m1, m2);
-
+
// allocate output
if(ret == null)
ret = new MatrixBlock(m1.rlen, m2.clen, ultraSparse |
sparse);
@@ -1718,7 +1719,6 @@ public class LibMatrixMult
matrixMultUltraSparseLeft(m1, m2, ret, rl, ru);
else
matrixMultUltraSparseRight(m1, m2, ret, rl, ru);
- //no need to recompute nonzeros because maintained internally
}
private static void matrixMultUltraSparseSelf(MatrixBlock m1,
MatrixBlock ret, int rl, int ru) {
@@ -1926,10 +1926,14 @@ public class LibMatrixMult
private static void matrixMultUltraSparseRight(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, int rl, int ru) {
- if(!ret.isInSparseFormat() &&
ret.getDenseBlock().isContiguous())
+ if(ret.isInSparseFormat()){
+ if(m1.isInSparseFormat())
+
matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru);
+ else
+
matrixMultUltraSparseRightDenseLeftSparseOut(m1, m2, ret, rl, ru);
+ }
+ else if(ret.getDenseBlock().isContiguous())
matrixMultUltraSparseRightDenseOut(m1, m2, ret, rl, ru);
- else if(m1.isInSparseFormat() && ret.isInSparseFormat())
- matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1,
m2, ret, rl, ru);
else
matrixMultUltraSparseRightGeneric(m1, m2, ret, rl, ru);
}
@@ -1990,6 +1994,39 @@ public class LibMatrixMult
}
}
+ private static void
matrixMultUltraSparseRightDenseLeftSparseOut(MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, int rl, int ru) {
+ final int cd = m1.clen;
+ final DenseBlock a = m1.denseBlock;
+ final SparseBlock b = m2.sparseBlock;
+ final SparseBlockMCSR c = (SparseBlockMCSR) ret.sparseBlock;
+
+ for(int k = 0; k < cd; k++){
+ if(b.isEmpty(k))
+ continue; // skip emptry rows right side.
+ final int bpos = b.pos(k);
+ final int blen = b.size(k);
+ final int[] bixs = b.indexes(k);
+ final double[] bvals = b.values(k);
+ for(int i = rl; i < ru; i++)
+ mmDenseMatrixSparseRow(bpos, blen, bixs, bvals,
k, i, a, c);
+ }
+ }
+
+ private static void mmDenseMatrixSparseRow(int bpos, int blen, int[]
bixs, double[] bvals, int k, int i,
+ DenseBlock a, SparseBlockMCSR c) {
+ final double[] aval = a.values(i);
+ final int apos = a.pos(i);
+ if(!c.isAllocated(i))
+ c.allocate(i, Math.max(blen, 2));
+ final SparseRowVector srv = (SparseRowVector) c.get(i); //
guaranteed
+ for(int j = bpos; j < bpos + blen; j++) { // right side columns
+ final int bix = bixs[j];
+ final double bval = bvals[j];
+ srv.add(bix, bval * aval[apos + k]);
+ }
+
+ }
+
private static void matrixMultUltraSparseRightGeneric(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, int rl, int ru) {
final int cd = m1.clen;