This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new fd1ba7c520 [MINOR] MM Specializations
fd1ba7c520 is described below

commit fd1ba7c520cb565a46554d383d83403b9fcb196f
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Tue Feb 4 12:47:10 2025 +0100

    [MINOR] MM Specializations
    
    This commit adds specializations for matrix multiplication with the 
following:
    
    1. dense-sparse input with sparse output
    2. ultra-sparse output with dense-dense input
    3. sparse output with a sparse-vector right-hand-side input
    
    Furthermore, the call stack was modified to branch to the native matrix
    multiplication inside LibMatrixMult, allowing easy native support for CLA
    by calling LibMatrixMult directly instead of going through a MatrixBlock.
    
    Closes #2212
---
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 143 ++++++++++++++++++++-
 .../sysds/runtime/matrix/data/LibMatrixNative.java |   2 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   5 +-
 .../test/component/matrix/MatrixMultiplyTest.java  |  37 ++++--
 4 files changed, 163 insertions(+), 24 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index c4eddd90fa..66f7c3c944 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -159,6 +159,13 @@ public class LibMatrixMult
         * @return ret Matrix Block
         */
        public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, int k) {
+               if(NativeHelper.isNativeLibraryLoaded())
+                       return LibMatrixNative.matrixMult(m1, m2, ret, k);
+               else
+                       return matrixMult(m1, m2, ret, false, k);
+       }
+
+       public static MatrixBlock matrixMultNonNative(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, int k) {
                return matrixMult(m1, m2, ret, false, k);
        }
        
@@ -256,7 +263,7 @@ public class LibMatrixMult
                // core matrix mult computation
                if(ultraSparse && !fixedRet)
                        matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2);
-               else if(!m1.sparse && !m2.sparse)
+               else if(!m1.sparse && !m2.sparse && !ret.sparse)
                        matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, 
m2.clen);
                else if(m1.sparse && m2.sparse)
                        matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, 
ru2);
@@ -1257,6 +1264,100 @@ public class LibMatrixMult
        }
 
        private static void matrixMultDenseSparse(MatrixBlock m1, MatrixBlock 
m2, MatrixBlock ret, boolean pm2, int rl, int ru) {
+               if(ret.isInSparseFormat()){
+                       if(!m1.sparse && !m2.sparse)
+                               matrixMultDenseDenseOutSparse(m1,m2,ret, pm2, 
rl, ru);
+                       else 
+                               matrixMultDenseSparseOutSparse(m1, m2, ret, 
pm2, rl, ru);
+               }
+               else
+                       matrixMultDenseSparseOutDense(m1, m2, ret, pm2, rl, ru);
+       }
+
+
+       private static void matrixMultDenseDenseOutSparse(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, boolean pm2,
+               int rl, int ru) {
+               final DenseBlock a = m1.getDenseBlock();
+               final DenseBlock b = m2.getDenseBlock();
+               final SparseBlock c = ret.getSparseBlock();
+               final int m = m1.rlen;  // rows left
+               final int cd = m1.clen; // common dim
+               final int n = m2.clen;
+
+               final int rl1 = pm2 ? 0 : rl;
+               final int ru1 = pm2 ? m : ru;
+               final int rl2 = pm2 ? rl : 0;
+               final int ru2 = pm2 ? ru : cd;
+
+               final int blocksizeK = 32;
+               final int blocksizeI = 32;
+
+               for(int bi = rl1; bi < ru1; bi += blocksizeI) {
+                       for(int bk = rl2, bimin = Math.min(ru1, bi + 
blocksizeI); bk < ru2; bk += blocksizeK) {
+                               final int bkmin = Math.min(ru2, bk + 
blocksizeK);
+                               // core sub block matrix multiplication
+                               for(int i = bi; i < bimin; i++) { // rows left
+                                       final double[] avals = a.values(i);
+                                       final int aix = a.pos(i);
+                                       for(int k = bk; k < bkmin; k++) { // 
common dimension
+                                               final double aval = avals[aix + 
k];
+                                               if(aval != 0) {
+                                                       final double[] bvals = 
b.values(k);
+                                                       final int bpos = 
b.pos(k);
+                                                       for(int j = 0; j < n; 
j++) {
+                                                               final double bv 
= bvals[bpos + j];
+                                                               c.add(i, j, 
aval * bv);
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+
+
+       private static void matrixMultDenseSparseOutSparse(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, boolean pm2,
+               int rl, int ru) {
+               final DenseBlock a = m1.getDenseBlock();
+               final SparseBlock b = m2.getSparseBlock();
+               final SparseBlock c = ret.getSparseBlock();
+               final int m = m1.rlen;  // rows left
+               final int cd = m1.clen; // common dim
+
+               final int rl1 = pm2 ? 0 : rl;
+               final int ru1 = pm2 ? m : ru;
+               final int rl2 = pm2 ? rl : 0;
+               final int ru2 = pm2 ? ru : cd;
+
+               final int blocksizeK = 32;
+               final int blocksizeI = 32;
+
+               for(int bi = rl1; bi < ru1; bi += blocksizeI) {
+                       for(int bk = rl2, bimin = Math.min(ru1, bi + 
blocksizeI); bk < ru2; bk += blocksizeK) {
+                               final int bkmin = Math.min(ru2, bk + 
blocksizeK);
+                               // core sub block matrix multiplication
+                               for(int i = bi; i < bimin; i++) { // rows left
+                                       final double[] avals = a.values(i);
+                                       final int aix = a.pos(i);
+                                       for(int k = bk; k < bkmin; k++) { // 
common dimension
+                                               final double aval = avals[aix + 
k];
+                                               if(aval == 0 || b.isEmpty(k))
+                                                       continue;
+                                               final int[] bIdx = b.indexes(k);
+                                               final double[] bVals = 
b.values(k);
+                                               final int bPos = b.pos(k);
+                                               final int bEnd = bPos + 
b.size(k);
+                                               for(int j = bPos; j < bEnd ; 
j++){
+                                                       c.add(i, bIdx[j], aval 
* bVals[j]);
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+
+       private static void matrixMultDenseSparseOutDense(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl,
+               int ru) {
                DenseBlock a = m1.getDenseBlock();
                DenseBlock c = ret.getDenseBlock();
                int m = m1.rlen;
@@ -1907,8 +2008,10 @@ public class LibMatrixMult
                if(ret.isInSparseFormat()){
                        if(m1.isInSparseFormat())
                                
matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru);
-                       else
+                       else if (m2.isInSparseFormat())
                                
matrixMultUltraSparseRightDenseLeftSparseOut(m1, m2, ret, rl, ru);
+                       else 
+                               matrixMultUltraSparseDenseInput(m1, m2, ret, 
rl, ru);
                }
                else if(ret.getDenseBlock().isContiguous())
                        matrixMultUltraSparseRightDenseOut(m1, m2, ret, rl, ru);
@@ -1990,6 +2093,30 @@ public class LibMatrixMult
                }
        }
 
+       private static void matrixMultUltraSparseDenseInput(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, int rl, int ru){
+               final int cd = m1.clen;
+               final int rc = m2.clen;
+               final DenseBlock a = m1.denseBlock;
+               final DenseBlock b = m2.denseBlock;
+               final SparseBlockMCSR c = (SparseBlockMCSR) ret.sparseBlock;
+
+               for(int i = rl; i < ru; i++) {
+                       // it is known that the left matrix is most likely 
containing many zeros.
+                       final double[] av = a.values(i);
+                       final int pos = a.pos(i);
+                       for(int k = 0; k < cd; k++) {
+                               final double v = av[pos + k];
+                               if(v != 0) {
+                                       final double[] bv = b.values(k);
+                                       final int posb = b.pos(k);
+                                       for(int j = 0; j < rc; j++) {
+                                               c.add(i,j, bv[posb + j] * v);
+                                       }
+                               }
+                       }
+               }
+       }
+
        private static void mmDenseMatrixSparseRow(int bpos, int blen, int[] 
bixs, double[] bvals, int k, int i,
                DenseBlock a, SparseBlockMCSR c) {
                final double[] aval = a.values(i);
@@ -4419,6 +4546,8 @@ public class LibMatrixMult
        }
        
        public static boolean isSparseOutputMatrixMult(MatrixBlock m1, 
MatrixBlock m2) {
+               if(m2.rlen == 1 && m2.nonZeros < m2.clen / 4) // vector right 
... that is sparse.
+                       return true;
                //output is a matrix (not vector), very likely sparse, and 
output rows fit into L1 cache
                if( !(m1.sparse && m2.sparse && m1.rlen > 1 && m2.clen > 1) )
                        return false;
@@ -4551,7 +4680,7 @@ public class LibMatrixMult
                private final boolean _pm2r; //par over m2 rows
                private final boolean _pm2c; //par over m2 rows
                private final boolean _m1Perm; //sparse permutation
-               private final boolean _sparse; //sparse output
+               // private final boolean _sparse; //sparse output
                private final int _rl;
                private final int _ru;
                private final ConcurrentHashMap<double[], double[]> _cache;
@@ -4565,7 +4694,7 @@ public class LibMatrixMult
                        _pm2r = pm2r;
                        _pm2c = pm2c;
                        _m1Perm = m1Perm;
-                       _sparse = sparse;
+                       // _sparse = sparse;
                        _rl = rl;
                        _ru = ru;
                        _cache = cache;
@@ -4594,14 +4723,14 @@ public class LibMatrixMult
                        //compute block matrix multiplication
                        if( _ret.sparse ) //ultra-sparse
                                matrixMultUltraSparse(_m1, _m2, _ret, _m1Perm, 
rl, ru);
-                       else if(!_m1.sparse && !_m2.sparse)
+                       else if(!_m1.sparse && !_m2.sparse && !_ret.sparse){
                                if(_m1.denseBlock instanceof 
DenseBlockFP64DEDUP && _m2.denseBlock.isContiguous(0,_m1.clen) && cl == 0 && cu 
== _m2.clen)
                                        
matrixMultDenseDenseMMDedup((DenseBlockFP64DEDUP) _m1.denseBlock, 
_m2.denseBlock, (DenseBlockFP64DEDUP) _ret.denseBlock, _m2.clen, _m1.clen, rl, 
ru, _cache);
                                else
                                        matrixMultDenseDense(_m1, _m2, _ret, 
_tm2, _pm2r, rl, ru, cl, cu);
-
+                       }
                        else if(_m1.sparse && _m2.sparse)
-                               matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, 
_sparse, rl, ru);
+                               matrixMultSparseSparse(_m1, _m2, _ret, _pm2r,  
_ret.sparse, rl, ru);
                        else if(_m1.sparse)
                                matrixMultSparseDense(_m1, _m2, _ret, _pm2r, 
rl, ru);
                        else
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index 4c8dc98ced..f8edf8af81 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -140,7 +140,7 @@ public class LibMatrixNative
                else
                        LOG.warn("Was valid for native MM but native lib was 
not loaded");
                
-               return LibMatrixMult.matrixMult(m1, m2, ret, k);
+               return LibMatrixMult.matrixMultNonNative(m1, m2, ret, k);
        }
        
        public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean 
leftTrans, int k) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 057811d2db..af068d2523 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4994,10 +4994,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
                checkAggregateBinaryOperations(m1, m2, op);
                final int k = op.getNumThreads();
-               if(NativeHelper.isNativeLibraryLoaded())
-                       return LibMatrixNative.matrixMult(m1, m2, ret, k);
-               else 
-                       return LibMatrixMult.matrixMult(m1, m2, ret, k);
+               return LibMatrixMult.matrixMult(m1, m2, ret, k);
        }
 
        protected void checkAggregateBinaryOperations(MatrixBlock m1, 
MatrixBlock m2, AggregateBinaryOperator op) {
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java 
b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
index a463d49b50..0934898bcc 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
@@ -52,10 +52,13 @@ public class MatrixMultiplyTest {
        // parallelization degree
        private final int k;
 
-       public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int 
p) {
+       public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int 
p, boolean self) {
                try {
                        this.left = 
TestUtils.ceil(TestUtils.generateTestMatrixBlock(i, j, -10, 10, i == 1 && j == 
1 ? 1 : s, 13));
-                       this.right = 
TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k == 
1 ? 1 : s2, 14));
+                       if(self)
+                               this.right = left;
+                       else 
+                               this.right = 
TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k == 
1 ? 1 : s2, 14));
 
                        this.exp = multiply(left, right, 1);
                        this.k = p;
@@ -83,7 +86,7 @@ public class MatrixMultiplyTest {
                                                for(int i = 0; i < is.length; 
i++) {
                                                        for(int j = 0; j < 
js.length; j++) {
                                                                for(int k = 0; 
k < ks.length; k++) {
-                                                                       
tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], 
par[p]});
+                                                                       
tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], 
par[p], false});
                                                                }
                                                        }
                                                }
@@ -91,15 +94,25 @@ public class MatrixMultiplyTest {
                                }
                        }
 
-                       tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 
6});
-                       tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6});
-                       tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 
6});
-                       tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6});
-
-                       tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 
6});
-                       tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6});
-                       tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 
6});
-                       tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6});
+                       tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 6, 
false});
+                       tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6, 
false});
+                       tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 6, 
false});
+                       tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6, 
false});
+
+                       tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 6, 
false});
+                       tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6, 
false});
+                       tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 6, 
false});
+                       tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6, 
false});
+                       
+                       // 0.00004 ultra sparse turn point
+                       tests.add(new Object[]{100, 100, 10000, 0.5, 0.00003, 
6, false});
+                       tests.add(new Object[]{10000, 100, 100, 0.00003, 0.6, 
6, false});
+
+
+                       tests.add(new Object[]{3, 10, 100000, 1.0, 0.00003, 6, 
false});
+                       tests.add(new Object[]{100000, 10, 3, 0.00003, 1.0, 6, 
false});
+                       
+                       tests.add(new Object[]{1000, 1000, 1000, 0.005, 0.6, 6, 
true});
 
                }
                catch(Exception e) {

Reply via email to