This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8af2559  [SYSTEMDS-2866] Extended sparse-sparse matrix multiplication
8af2559 is described below

commit 8af25591620265e582e85befc8167a040d282004
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 20 17:05:48 2021 +0100

    [SYSTEMDS-2866] Extended sparse-sparse matrix multiplication
    
    We already had four kernels for sparse-sparse matrix multiplication and
    several kernels for ultra-sparse matrix multiplication, but only the
    ultra-sparse kernels worked directly with sparse outputs, while the
    sparse-sparse kernels worked over dense outputs, which created
    unnecessary allocation and GC overhead. This patch adds an additional
    sparse-sparse kernel that is selected when (among other conditions) the
    estimated number of non-zeros in the output indicates a sparse result.
    It writes directly into the sparse output by computing each output row
    in a temporary dense row buffer and immediately compacting it into the
    sparse row.
---
 .../runtime/controlprogram/caching/CacheBlock.java |  3 +
 .../org/apache/sysds/runtime/data/TensorBlock.java |  7 ++
 .../sysds/runtime/matrix/data/FrameBlock.java      |  7 ++
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 88 ++++++++++++++++++----
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  1 +
 5 files changed, 93 insertions(+), 13 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheBlock.java
index f95dbba..17741f8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheBlock.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.controlprogram.caching;
 
 import org.apache.hadoop.io.Writable;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
 
 
 /**
@@ -34,6 +35,8 @@ public interface CacheBlock extends Writable
 
        public int getNumColumns();
        
+       public DataCharacteristics getDataCharacteristics();
+       
        /**
         * Get the in-memory size in bytes of the cache block.
         * 
diff --git a/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
index 030209d..d2167c8 100644
--- a/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/TensorBlock.java
@@ -27,6 +27,8 @@ import 
org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.TensorCharacteristics;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
 import java.io.DataInput;
@@ -249,6 +251,11 @@ public class TensorBlock implements CacheBlock, 
Externalizable {
        }
 
        @Override
+       public DataCharacteristics getDataCharacteristics() {
+               return new TensorCharacteristics(getLongDims(), -1);
+       }
+       
+       @Override
        public long getInMemorySize() {
                // TODO Auto-generated method stub
                return 0;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 227fa0c..d82d40b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -54,6 +54,8 @@ import 
org.apache.sysds.runtime.functionobjects.ValueComparisonFunction;
 import org.apache.sysds.runtime.instructions.cp.*;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.transform.encode.EncoderRecode;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.DMVUtils;
@@ -164,6 +166,11 @@ public class FrameBlock implements CacheBlock, 
Externalizable  {
                return (_schema != null) ? _schema.length : 0;
        }
 
+       @Override
+       public DataCharacteristics getDataCharacteristics() {
+               return new MatrixCharacteristics(getNumRows(), getNumColumns(), 
-1);
+       }
+       
        /**
         * Returns the schema of the frame block.
         *
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 02156f6..bd3310d 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -43,7 +43,9 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.DenseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlock.Type;
 import org.apache.sysds.runtime.data.SparseBlockCSR;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlockMCSR;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -128,9 +130,11 @@ public class LibMatrixMult
                boolean m1Perm = m1.isSparsePermutationMatrix();
                boolean ultraSparse = (fixedRet && ret.sparse)
                        || (!fixedRet && isUltraSparseMatrixMult(m1, m2, 
m1Perm));
+               boolean sparse = !m1Perm && !ultraSparse && !fixedRet 
+                       && isSparseOutputMatrixMult(m1, m2);
                boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
                m2 = prepMatrixMultRightInput(m1, m2);
-               ret.sparse = ultraSparse;
+               ret.sparse = ultraSparse | sparse;
                ret.allocateBlock();
                
                //prepare row-upper for special cases of vector-matrix
@@ -145,7 +149,7 @@ public class LibMatrixMult
                else if(!m1.sparse && !m2.sparse)
                        matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, 
cu);
                else if(m1.sparse && m2.sparse)
-                       matrixMultSparseSparse(m1, m2, ret, pm2, 0, ru2);
+                       matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, 
ru2);
                else if(m1.sparse)
                        matrixMultSparseDense(m1, m2, ret, pm2, 0, ru2);
                else
@@ -190,9 +194,10 @@ public class LibMatrixMult
                //we need to allocate sparse as well in order to prevent 
synchronization)
                boolean m1Perm = m1.isSparsePermutationMatrix();
                boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
+               boolean sparse = !ultraSparse && !m1Perm && 
isSparseOutputMatrixMult(m1, m2);
                boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
                m2 = prepMatrixMultRightInput(m1, m2);
-               ret.sparse = ultraSparse;
+               ret.sparse = ultraSparse | sparse;
                ret.allocateBlock();
                
                if (!ret.isThreadSafe()) {
@@ -201,7 +206,7 @@ public class LibMatrixMult
                }
                
                //prepare row-upper for special cases of vector-matrix / 
matrix-matrix
-               boolean pm2r = !ultraSparse && 
checkParMatrixMultRightInputRows(m1, m2, k);
+               boolean pm2r = !ultraSparse && !sparse && 
checkParMatrixMultRightInputRows(m1, m2, k);
                boolean pm2c = !ultraSparse && 
checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
                int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen; 
                
@@ -212,7 +217,7 @@ public class LibMatrixMult
                        ArrayList<MatrixMultTask> tasks = new ArrayList<>();
                        ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r||pm2c));
                        for( int i=0, lb=0; i<blklens.size(); 
lb+=blklens.get(i), i++ )
-                               tasks.add(new MatrixMultTask(m1, m2, ret, tm2, 
pm2r, pm2c, m1Perm, lb, lb+blklens.get(i)));
+                               tasks.add(new MatrixMultTask(m1, m2, ret, tm2, 
pm2r, pm2c, m1Perm, sparse, lb, lb+blklens.get(i)));
                        //execute tasks
                        List<Future<Object>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
@@ -1409,24 +1414,26 @@ public class LibMatrixMult
                }
        }
 
-       private static void matrixMultSparseSparse(MatrixBlock m1, MatrixBlock 
m2, MatrixBlock ret, boolean pm2, int rl, int ru) {
+       private static void matrixMultSparseSparse(MatrixBlock m1, MatrixBlock 
m2, MatrixBlock ret, boolean pm2, boolean sparse, int rl, int ru) {
                SparseBlock a = m1.sparseBlock;
                SparseBlock b = m2.sparseBlock;
-               DenseBlock c = ret.getDenseBlock();
                int m = m1.rlen;
                int cd = m1.clen;
+               int n = m2.clen;
                
                // MATRIX-MATRIX (VV, MV not applicable here because V always 
dense)
                if(LOW_LEVEL_OPTIMIZATION) {
                        if( pm2 && m==1 )               //VECTOR-MATRIX
-                               matrixMultSparseSparseVM(a, b, c, rl, ru);
+                               matrixMultSparseSparseVM(a, b, 
ret.getDenseBlock(), rl, ru);
+                       else if( sparse )               //SPARSE OUPUT
+                               
ret.setNonZeros(matrixMultSparseSparseSparseMM(a, b, ret.getSparseBlock(), n, 
rl, ru));
                        else if( m2.nonZeros < 2048 )   //MATRIX-SMALL MATRIX
-                               matrixMultSparseSparseMMSmallRHS(a, b, c, rl, 
ru);
+                               matrixMultSparseSparseMMSmallRHS(a, b, 
ret.getDenseBlock(), rl, ru);
                        else                            //MATRIX-MATRIX
-                               matrixMultSparseSparseMM(a, b, c, m, cd, 
m1.nonZeros, rl, ru);
+                               matrixMultSparseSparseMM(a, b, 
ret.getDenseBlock(), m, cd, m1.nonZeros, rl, ru);
                }
                else {
-                       matrixMultSparseSparseMMGeneric(a, b, c, rl, ru);
+                       matrixMultSparseSparseMMGeneric(a, b, 
ret.getDenseBlock(), rl, ru);
                }
        }
        
@@ -1452,6 +1459,39 @@ public class LibMatrixMult
                        }
        }
        
+       private static long matrixMultSparseSparseSparseMM(SparseBlock a, 
SparseBlock b, SparseBlock c, int n, int rl, int ru) {
+               double[] tmp = new double[n];
+               long nnz = 0;
+               for( int i=rl; i<Math.min(ru, a.numRows()); i++ ) {
+                       if( a.isEmpty(i) ) continue;
+                       final int apos = a.pos(i);
+                       final int alen = a.size(i);
+                       int[] aix = a.indexes(i);
+                       double[] avals = a.values(i);
+                       //compute row output in dense buffer
+                       boolean hitNonEmpty = false;
+                       for(int k = apos; k < apos+alen; k++) {
+                               int aixk = aix[k];
+                               if( b.isEmpty(aixk) ) continue;
+                               vectMultiplyAdd(avals[k], b.values(aixk), tmp,
+                                       b.indexes(aixk), b.pos(aixk), 0, 
b.size(aixk));
+                               hitNonEmpty = true;
+                       }
+                       //copy dense buffer into sparse output (CSR or MCSR)
+                       if( hitNonEmpty ) {
+                               int rnnz = UtilFunctions.computeNnz(tmp, 0, n);
+                               nnz += rnnz;
+                               c.allocate(i, rnnz);
+                               for(int j=0; j<n; j++)
+                                       if( tmp[j] != 0 ) {
+                                               c.append(i, j, tmp[j]);
+                                               tmp[j] = 0;
+                                       }
+                       }
+               }
+               return nnz;
+       }
+       
        private static void matrixMultSparseSparseMMSmallRHS(SparseBlock a, 
SparseBlock b, DenseBlock c, int rl, int ru) {
                for( int i=rl; i<Math.min(ru, a.numRows()); i++ ) {
                        if( a.isEmpty(i) ) continue;
@@ -3887,6 +3927,17 @@ public class LibMatrixMult
                                && outSp < MatrixBlock.SPARSITY_TURN_POINT);
        }
        
+       public static boolean isSparseOutputMatrixMult(MatrixBlock m1, 
MatrixBlock m2) {
+               //output is a matrix (not vector), very likely sparse, and 
output rows fit into L1 cache
+               if( !(m1.sparse && m2.sparse && m1.rlen > 1 && m2.clen > 1) )
+                       return false;
+               double estSp = OptimizerUtils.getMatMultSparsity(
+                       m1.getSparsity(), m2.getSparsity(), m1.rlen, m1.clen, 
m2.clen, false);
+               long estNnz = (long)(estSp * m1.rlen * m2.clen);
+               boolean sparseOut = 
MatrixBlock.evalSparseFormatInMemory(m1.rlen, m2.clen, estNnz);
+               return m2.clen < 4*1024 && sparseOut;
+       }
+       
        public static boolean isOuterProductTSMM(int rlen, int clen, boolean 
left) {
                return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1;
        }
@@ -3929,6 +3980,15 @@ public class LibMatrixMult
        }
        
        @SuppressWarnings("unused")
+       private static void compactSparseOutput(MatrixBlock ret) {
+               if( !ret.sparse || ret.nonZeros > ret.rlen || ret.isEmpty() 
+                       || ret.getSparseBlock() instanceof SparseBlockCSR )
+                       return; //early abort
+               ret.sparseBlock = SparseBlockFactory
+                       .copySparseBlock(Type.CSR, ret.sparseBlock, false);
+       }
+       
+       @SuppressWarnings("unused")
        private static void resetPosVect(int[] curk, SparseBlock sblock, int 
rl, int ru) {
                if( sblock instanceof SparseBlockMCSR ) {
                        //all rows start at position 0 (individual arrays)
@@ -3989,11 +4049,12 @@ public class LibMatrixMult
                private final boolean _pm2r; //par over m2 rows
                private final boolean _pm2c; //par over m2 rows
                private final boolean _m1Perm; //sparse permutation
+               private final boolean _sparse; //sparse output
                private final int _rl;
                private final int _ru;
 
                protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret,
-                       boolean tm2, boolean pm2r, boolean pm2c, boolean 
m1Perm, int rl, int ru )
+                       boolean tm2, boolean pm2r, boolean pm2c, boolean 
m1Perm, boolean sparse, int rl, int ru )
                {
                        _m1 = m1;
                        _m2 = m2;
@@ -4001,6 +4062,7 @@ public class LibMatrixMult
                        _pm2r = pm2r;
                        _pm2c = pm2c;
                        _m1Perm = m1Perm;
+                       _sparse = sparse;
                        _rl = rl;
                        _ru = ru;
                        
@@ -4031,7 +4093,7 @@ public class LibMatrixMult
                        else if(!_m1.sparse && !_m2.sparse)
                                matrixMultDenseDense(_m1, _m2, _ret, _tm2, 
_pm2r, rl, ru, cl, cu);
                        else if(_m1.sparse && _m2.sparse)
-                               matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, 
rl, ru);
+                               matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, 
_sparse, rl, ru);
                        else if(_m1.sparse)
                                matrixMultSparseDense(_m1, _m2, _ret, _pm2r, 
rl, ru);
                        else
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 1695465..a97e196 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -490,6 +490,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                return OptimizerUtils.getSparsity(rlen, clen, nonZeros);
        }
        
+       @Override
        public DataCharacteristics getDataCharacteristics() {
                return new MatrixCharacteristics(rlen, clen, -1, nonZeros);
        }

Reply via email to