This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new fd1ba7c520 [MINOR] MM Specializations
fd1ba7c520 is described below
commit fd1ba7c520cb565a46554d383d83403b9fcb196f
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Tue Feb 4 12:47:10 2025 +0100
[MINOR] MM Specializations
This commit adds specializations for matrix multiplication with the
following:
1. dense-sparse input with a sparse output
2. ultra-sparse output with dense-dense input
3. sparse output with a sparse vector as the right-hand-side input
Furthermore, I modified the call stack so that the branch to the native
matrix multiplication happens inside LibMatrixMult. This makes it easy to
add native support for CLA by calling LibMatrixMult directly, instead of
having to go through a MatrixBlock.
Closes #2212
---
.../sysds/runtime/matrix/data/LibMatrixMult.java | 143 ++++++++++++++++++++-
.../sysds/runtime/matrix/data/LibMatrixNative.java | 2 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 5 +-
.../test/component/matrix/MatrixMultiplyTest.java | 37 ++++--
4 files changed, 163 insertions(+), 24 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index c4eddd90fa..66f7c3c944 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -159,6 +159,13 @@ public class LibMatrixMult
* @return ret Matrix Block
*/
public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2,
MatrixBlock ret, int k) {
+ if(NativeHelper.isNativeLibraryLoaded())
+ return LibMatrixNative.matrixMult(m1, m2, ret, k);
+ else
+ return matrixMult(m1, m2, ret, false, k);
+ }
+
+ public static MatrixBlock matrixMultNonNative(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, int k) {
return matrixMult(m1, m2, ret, false, k);
}
@@ -256,7 +263,7 @@ public class LibMatrixMult
// core matrix mult computation
if(ultraSparse && !fixedRet)
matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2);
- else if(!m1.sparse && !m2.sparse)
+ else if(!m1.sparse && !m2.sparse && !ret.sparse)
matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0,
m2.clen);
else if(m1.sparse && m2.sparse)
matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0,
ru2);
@@ -1257,6 +1264,100 @@ public class LibMatrixMult
}
private static void matrixMultDenseSparse(MatrixBlock m1, MatrixBlock
m2, MatrixBlock ret, boolean pm2, int rl, int ru) {
+ if(ret.isInSparseFormat()){
+ if(!m1.sparse && !m2.sparse)
+ matrixMultDenseDenseOutSparse(m1,m2,ret, pm2,
rl, ru);
+ else
+ matrixMultDenseSparseOutSparse(m1, m2, ret,
pm2, rl, ru);
+ }
+ else
+ matrixMultDenseSparseOutDense(m1, m2, ret, pm2, rl, ru);
+ }
+
+
+ private static void matrixMultDenseDenseOutSparse(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, boolean pm2,
+ int rl, int ru) {
+ final DenseBlock a = m1.getDenseBlock();
+ final DenseBlock b = m2.getDenseBlock();
+ final SparseBlock c = ret.getSparseBlock();
+ final int m = m1.rlen; // rows left
+ final int cd = m1.clen; // common dim
+ final int n = m2.clen;
+
+ final int rl1 = pm2 ? 0 : rl;
+ final int ru1 = pm2 ? m : ru;
+ final int rl2 = pm2 ? rl : 0;
+ final int ru2 = pm2 ? ru : cd;
+
+ final int blocksizeK = 32;
+ final int blocksizeI = 32;
+
+ for(int bi = rl1; bi < ru1; bi += blocksizeI) {
+ for(int bk = rl2, bimin = Math.min(ru1, bi +
blocksizeI); bk < ru2; bk += blocksizeK) {
+ final int bkmin = Math.min(ru2, bk +
blocksizeK);
+ // core sub block matrix multiplication
+ for(int i = bi; i < bimin; i++) { // rows left
+ final double[] avals = a.values(i);
+ final int aix = a.pos(i);
+ for(int k = bk; k < bkmin; k++) { //
common dimension
+ final double aval = avals[aix +
k];
+ if(aval != 0) {
+ final double[] bvals =
b.values(k);
+ final int bpos =
b.pos(k);
+ for(int j = 0; j < n;
j++) {
+ final double bv
= bvals[bpos + j];
+ c.add(i, j,
aval * bv);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+
+ private static void matrixMultDenseSparseOutSparse(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, boolean pm2,
+ int rl, int ru) {
+ final DenseBlock a = m1.getDenseBlock();
+ final SparseBlock b = m2.getSparseBlock();
+ final SparseBlock c = ret.getSparseBlock();
+ final int m = m1.rlen; // rows left
+ final int cd = m1.clen; // common dim
+
+ final int rl1 = pm2 ? 0 : rl;
+ final int ru1 = pm2 ? m : ru;
+ final int rl2 = pm2 ? rl : 0;
+ final int ru2 = pm2 ? ru : cd;
+
+ final int blocksizeK = 32;
+ final int blocksizeI = 32;
+
+ for(int bi = rl1; bi < ru1; bi += blocksizeI) {
+ for(int bk = rl2, bimin = Math.min(ru1, bi +
blocksizeI); bk < ru2; bk += blocksizeK) {
+ final int bkmin = Math.min(ru2, bk +
blocksizeK);
+ // core sub block matrix multiplication
+ for(int i = bi; i < bimin; i++) { // rows left
+ final double[] avals = a.values(i);
+ final int aix = a.pos(i);
+ for(int k = bk; k < bkmin; k++) { //
common dimension
+ final double aval = avals[aix +
k];
+ if(aval == 0 || b.isEmpty(k))
+ continue;
+ final int[] bIdx = b.indexes(k);
+ final double[] bVals =
b.values(k);
+ final int bPos = b.pos(k);
+ final int bEnd = bPos +
b.size(k);
+ for(int j = bPos; j < bEnd ;
j++){
+ c.add(i, bIdx[j], aval
* bVals[j]);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ private static void matrixMultDenseSparseOutDense(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl,
+ int ru) {
DenseBlock a = m1.getDenseBlock();
DenseBlock c = ret.getDenseBlock();
int m = m1.rlen;
@@ -1907,8 +2008,10 @@ public class LibMatrixMult
if(ret.isInSparseFormat()){
if(m1.isInSparseFormat())
matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru);
- else
+ else if (m2.isInSparseFormat())
matrixMultUltraSparseRightDenseLeftSparseOut(m1, m2, ret, rl, ru);
+ else
+ matrixMultUltraSparseDenseInput(m1, m2, ret,
rl, ru);
}
else if(ret.getDenseBlock().isContiguous())
matrixMultUltraSparseRightDenseOut(m1, m2, ret, rl, ru);
@@ -1990,6 +2093,30 @@ public class LibMatrixMult
}
}
+ private static void matrixMultUltraSparseDenseInput(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, int rl, int ru){
+ final int cd = m1.clen;
+ final int rc = m2.clen;
+ final DenseBlock a = m1.denseBlock;
+ final DenseBlock b = m2.denseBlock;
+ final SparseBlockMCSR c = (SparseBlockMCSR) ret.sparseBlock;
+
+ for(int i = rl; i < ru; i++) {
+ // it is known that the left matrix is most likely
containing many zeros.
+ final double[] av = a.values(i);
+ final int pos = a.pos(i);
+ for(int k = 0; k < cd; k++) {
+ final double v = av[pos + k];
+ if(v != 0) {
+ final double[] bv = b.values(k);
+ final int posb = b.pos(k);
+ for(int j = 0; j < rc; j++) {
+ c.add(i,j, bv[posb + j] * v);
+ }
+ }
+ }
+ }
+ }
+
private static void mmDenseMatrixSparseRow(int bpos, int blen, int[]
bixs, double[] bvals, int k, int i,
DenseBlock a, SparseBlockMCSR c) {
final double[] aval = a.values(i);
@@ -4419,6 +4546,8 @@ public class LibMatrixMult
}
public static boolean isSparseOutputMatrixMult(MatrixBlock m1,
MatrixBlock m2) {
+ if(m2.rlen == 1 && m2.nonZeros < m2.clen / 4) // vector right
... that is sparse.
+ return true;
//output is a matrix (not vector), very likely sparse, and
output rows fit into L1 cache
if( !(m1.sparse && m2.sparse && m1.rlen > 1 && m2.clen > 1) )
return false;
@@ -4551,7 +4680,7 @@ public class LibMatrixMult
private final boolean _pm2r; //par over m2 rows
private final boolean _pm2c; //par over m2 rows
private final boolean _m1Perm; //sparse permutation
- private final boolean _sparse; //sparse output
+ // private final boolean _sparse; //sparse output
private final int _rl;
private final int _ru;
private final ConcurrentHashMap<double[], double[]> _cache;
@@ -4565,7 +4694,7 @@ public class LibMatrixMult
_pm2r = pm2r;
_pm2c = pm2c;
_m1Perm = m1Perm;
- _sparse = sparse;
+ // _sparse = sparse;
_rl = rl;
_ru = ru;
_cache = cache;
@@ -4594,14 +4723,14 @@ public class LibMatrixMult
//compute block matrix multiplication
if( _ret.sparse ) //ultra-sparse
matrixMultUltraSparse(_m1, _m2, _ret, _m1Perm,
rl, ru);
- else if(!_m1.sparse && !_m2.sparse)
+ else if(!_m1.sparse && !_m2.sparse && !_ret.sparse){
if(_m1.denseBlock instanceof
DenseBlockFP64DEDUP && _m2.denseBlock.isContiguous(0,_m1.clen) && cl == 0 && cu
== _m2.clen)
matrixMultDenseDenseMMDedup((DenseBlockFP64DEDUP) _m1.denseBlock,
_m2.denseBlock, (DenseBlockFP64DEDUP) _ret.denseBlock, _m2.clen, _m1.clen, rl,
ru, _cache);
else
matrixMultDenseDense(_m1, _m2, _ret,
_tm2, _pm2r, rl, ru, cl, cu);
-
+ }
else if(_m1.sparse && _m2.sparse)
- matrixMultSparseSparse(_m1, _m2, _ret, _pm2r,
_sparse, rl, ru);
+ matrixMultSparseSparse(_m1, _m2, _ret, _pm2r,
_ret.sparse, rl, ru);
else if(_m1.sparse)
matrixMultSparseDense(_m1, _m2, _ret, _pm2r,
rl, ru);
else
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index 4c8dc98ced..f8edf8af81 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -140,7 +140,7 @@ public class LibMatrixNative
else
LOG.warn("Was valid for native MM but native lib was
not loaded");
- return LibMatrixMult.matrixMult(m1, m2, ret, k);
+ return LibMatrixMult.matrixMultNonNative(m1, m2, ret, k);
}
public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean
leftTrans, int k) {
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 057811d2db..af068d2523 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4994,10 +4994,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
public MatrixBlock aggregateBinaryOperations(MatrixBlock m1,
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
checkAggregateBinaryOperations(m1, m2, op);
final int k = op.getNumThreads();
- if(NativeHelper.isNativeLibraryLoaded())
- return LibMatrixNative.matrixMult(m1, m2, ret, k);
- else
- return LibMatrixMult.matrixMult(m1, m2, ret, k);
+ return LibMatrixMult.matrixMult(m1, m2, ret, k);
}
protected void checkAggregateBinaryOperations(MatrixBlock m1,
MatrixBlock m2, AggregateBinaryOperator op) {
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
index a463d49b50..0934898bcc 100644
---
a/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
+++
b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
@@ -52,10 +52,13 @@ public class MatrixMultiplyTest {
// parallelization degree
private final int k;
- public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int
p) {
+ public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int
p, boolean self) {
try {
this.left =
TestUtils.ceil(TestUtils.generateTestMatrixBlock(i, j, -10, 10, i == 1 && j ==
1 ? 1 : s, 13));
- this.right =
TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k ==
1 ? 1 : s2, 14));
+ if(self)
+ this.right = left;
+ else
+ this.right =
TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k ==
1 ? 1 : s2, 14));
this.exp = multiply(left, right, 1);
this.k = p;
@@ -83,7 +86,7 @@ public class MatrixMultiplyTest {
for(int i = 0; i < is.length;
i++) {
for(int j = 0; j <
js.length; j++) {
for(int k = 0;
k < ks.length; k++) {
-
tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2],
par[p]});
+
tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2],
par[p], false});
}
}
}
@@ -91,15 +94,25 @@ public class MatrixMultiplyTest {
}
}
- tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001,
6});
- tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6});
- tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005,
6});
- tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6});
-
- tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001,
6});
- tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6});
- tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005,
6});
- tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6});
+ tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 6,
false});
+ tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6,
false});
+ tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 6,
false});
+ tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6,
false});
+
+ tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 6,
false});
+ tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6,
false});
+ tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 6,
false});
+ tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6,
false});
+
+ // 0.00004 ultra sparse turn point
+ tests.add(new Object[]{100, 100, 10000, 0.5, 0.00003,
6, false});
+ tests.add(new Object[]{10000, 100, 100, 0.00003, 0.6,
6, false});
+
+
+ tests.add(new Object[]{3, 10, 100000, 1.0, 0.00003, 6,
false});
+ tests.add(new Object[]{100000, 10, 3, 0.00003, 1.0, 6,
false});
+
+ tests.add(new Object[]{1000, 1000, 1000, 0.005, 0.6, 6,
true});
}
catch(Exception e) {