Repository: systemml
Updated Branches:
  refs/heads/master 426d7fa0d -> a97bc53f7


[SYSTEMML-1955] Memory-efficient ultra-sparse matrix multiplication

So far the code path for ultra-sparse matrix multiply was only used if
either of the two inputs was classified as ultra-sparse in terms of its
sparsity and absolute number of non-zeros. For special cases of
permutation matrix multiply with large ultra-sparse matrices, this is not
the case, leading to sparse-dense matrix multiplication, which allocates
the output in dense format.

This patch improves the code path selection for ultra-sparse matrix
multiplication to include such permutation matrix multiplies, so that
the output is allocated in sparse format, which is much more
memory-efficient. Additionally, it improves performance by avoiding
repeated reallocations.

For example, on a scenario of 58825360 x 2519371 (nnz=58825360) times
2519371 x 300 (nnz=755810998), this patch enables successful
execution in CP (previously, it crashed due to allocating a dense output
larger than 16GB). By avoiding reallocations, this patch further
improves the execution time from 112s to 24s due to significantly
reduced GC overhead.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0b5480b7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0b5480b7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0b5480b7

Branch: refs/heads/master
Commit: 0b5480b77de3391bec0ee1e60ab7d71de80e9605
Parents: 426d7fa
Author: Matthias Boehm <[email protected]>
Authored: Wed Oct 11 14:20:00 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Oct 12 01:13:05 2017 -0700

----------------------------------------------------------------------
 .../runtime/matrix/data/LibMatrixMult.java      | 88 +++++++++++---------
 .../sysml/runtime/matrix/data/MatrixBlock.java  | 22 +++--
 2 files changed, 66 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0b5480b7/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 0885508..aedf975 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -29,6 +29,7 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 
 import org.apache.commons.math3.util.FastMath;
+import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.lops.MapMultChain.ChainType;
 import org.apache.sysml.lops.WeightedCrossEntropy.WCeMMType;
 import org.apache.sysml.lops.WeightedDivMM.WDivMMType;
@@ -111,19 +112,20 @@ public class LibMatrixMult
        
        public static void matrixMult(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, int rl, int ru, boolean examSparsity) 
                throws DMLRuntimeException
-       {       
+       {
                //check inputs / outputs
                if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
                        ret.examSparsity(); //turn empty dense into sparse
                        return;
                }
-                       
+               
                //Timing time = new Timing(true);
-                       
+               
                //pre-processing: output allocation
+               boolean ultraSparse = isUltraSparseMatrixMult(m1, m2);
                boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
                m2 = prepMatrixMultRightInput(m1, m2);
-               ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse());
+               ret.sparse = ultraSparse;
                if( !ret.sparse )
                        ret.allocateDenseBlock();
                
@@ -133,7 +135,7 @@ public class LibMatrixMult
                int cu = m2.clen;
                
                //core matrix mult computation
-               if( m1.isUltraSparse() || m2.isUltraSparse() )
+               if( ultraSparse )
                        matrixMultUltraSparse(m1, m2, ret, 0, ru2);
                else if(!m1.sparse && !m2.sparse)
                        matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, 
cu);
@@ -147,13 +149,11 @@ public class LibMatrixMult
                //post-processing: nnz/representation
                if( !ret.sparse )
                        ret.recomputeNonZeros();
-               
                if(examSparsity)
                        ret.examSparsity();
                
-               
                //System.out.println("MM 
("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x"
 +
-               //                            
"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+")
 in "+time.stop());
+               //              
"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+")
 in "+time.stop());
        }
        
        /**
@@ -168,13 +168,13 @@ public class LibMatrixMult
         */
        public static void matrixMult(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, int k) 
                throws DMLRuntimeException
-       {       
+       {
                //check inputs / outputs
                if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
                        ret.examSparsity(); //turn empty dense into sparse
                        return;
                }
-                       
+               
                //check too high additional vector-matrix memory requirements 
(fallback to sequential)
                //check too small workload in terms of flops (fallback to 
sequential too)
                if( m1.rlen == 1 && (8L * m2.clen * k > MEM_OVERHEAD_THRESHOLD 
|| !LOW_LEVEL_OPTIMIZATION || m2.clen==1 || m1.isUltraSparse() || 
m2.isUltraSparse()) 
@@ -188,15 +188,13 @@ public class LibMatrixMult
                
                //pre-processing: output allocation (in contrast to 
single-threaded,
                //we need to allocate sparse as well in order to prevent 
synchronization)
+               boolean ultraSparse = isUltraSparseMatrixMult(m1, m2);
                boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
                m2 = prepMatrixMultRightInput(m1, m2);
-               ret.sparse = (m1.isUltraSparse() || m2.isUltraSparse());
-               if( !ret.sparse )
-                       ret.allocateDenseBlock();
-               else
-                       ret.allocateSparseRowsBlock();
+               ret.sparse = ultraSparse;
+               ret.allocateBlock();
                
-               if (!ret.isThreadSafe()){
+               if (!ret.isThreadSafe()) {
                        matrixMult(m1, m2, ret);
                        return;
                }
@@ -216,7 +214,7 @@ public class LibMatrixMult
                        for( int i=0, lb=0; i<blklens.size(); 
lb+=blklens.get(i), i++ )
                                tasks.add(new MatrixMultTask(m1, m2, ret, tm2, 
pm2r, pm2c, lb, lb+blklens.get(i)));
                        //execute tasks
-                       List<Future<Object>> taskret = pool.invokeAll(tasks);   
+                       List<Future<Object>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
                        //aggregate partial results (nnz, ret for vector/matrix)
                        ret.nonZeros = 0; //reset after execute
@@ -233,12 +231,11 @@ public class LibMatrixMult
                        throw new DMLRuntimeException(ex);
                }
                
-               
                //post-processing (nnz maintained in parallel)
                ret.examSparsity();
                
                //System.out.println("MM k="+k+" 
("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x"
 +
-               //                            
"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+")
 in "+time.stop());
+               //              
"("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+")
 in "+time.stop());
        }
        
        /**
@@ -752,7 +749,7 @@ public class LibMatrixMult
                try 
                {                       
                        ExecutorService pool = Executors.newFixedThreadPool(k);
-                       ArrayList<MatrixMultWDivTask> tasks = new 
ArrayList<MatrixMultWDivTask>();                      
+                       ArrayList<MatrixMultWDivTask> tasks = new 
ArrayList<MatrixMultWDivTask>();
                        //create tasks (for wdivmm-left, parallelization over 
columns;
                        //for wdivmm-right, parallelization over rows; both 
ensure disjoint results)
                        if( wt.isLeft() ) {
@@ -827,7 +824,7 @@ public class LibMatrixMult
                ret.allocateDenseBlock();
                
                try 
-               {                       
+               {
                        ExecutorService pool = Executors.newFixedThreadPool(k);
                        ArrayList<MatrixMultWCeTask> tasks = new 
ArrayList<MatrixMultWCeTask>();
                        int blklen = (int)(Math.ceil((double)mW.rlen/k));
@@ -1476,9 +1473,8 @@ public class LibMatrixMult
        private static void matrixMultUltraSparse(MatrixBlock m1, MatrixBlock 
m2, MatrixBlock ret, int rl, int ru) 
                throws DMLRuntimeException 
        {
-               //TODO perf sparse block, consider iterators
-               
-               boolean leftUS = m1.isUltraSparse();
+               final boolean leftUS = m1.isUltraSparse()
+                       || (m1.isUltraSparse(false) && !m2.isUltraSparse());
                final int m  = m1.rlen;
                final int cd = m1.clen;
                final int n  = m2.clen;
@@ -1486,6 +1482,7 @@ public class LibMatrixMult
                if( leftUS ) //left is ultra-sparse (IKJ)
                {
                        SparseBlock a = m1.sparseBlock;
+                       SparseBlock c = ret.sparseBlock;
                        boolean rightSparse = m2.sparse;
                        
                        for( int i=rl; i<ru; i++ )
@@ -1504,13 +1501,19 @@ public class LibMatrixMult
                                                        if( 
!m2.sparseBlock.isEmpty(aix) ) {
                                                                ret.rlen=m;
                                                                
ret.allocateSparseRowsBlock(false); //allocation on demand
-                                                               
ret.sparseBlock.set(i, m2.sparseBlock.get(aix), true); 
+                                                               boolean ldeep = 
(m2.sparseBlock instanceof SparseBlockMCSR);
+                                                               
ret.sparseBlock.set(i, m2.sparseBlock.get(aix), ldeep);
                                                                ret.nonZeros += 
ret.sparseBlock.size(i);
                                                        }
                                                }
                                                else { //dense right matrix 
(append all values)
-                                                       for( int j=0; j<n; j++ )
-                                                               
ret.appendValue(i, j, m2.quickGetValue(aix, j));
+                                                       int lnnz = 
(int)m2.recomputeNonZeros(aix, aix, 0, n-1);
+                                                       if( lnnz > 0 ) {
+                                                               c.allocate(i, 
lnnz); //allocate once
+                                                               for( int j=0; 
j<n; j++ )
+                                                                       
c.append(i, j, m2.quickGetValue(aix, j));
+                                                               ret.nonZeros += 
lnnz;
+                                                       }
                                                }
                                        }
                                        else //GENERAL CASE
@@ -1536,13 +1539,13 @@ public class LibMatrixMult
                        SparseBlock b = m2.sparseBlock;
                        
                        for(int k = 0; k < cd; k++ ) 
-                       {                       
+                       {
                                if( !b.isEmpty(k) ) 
                                {
                                        int bpos = b.pos(k);
                                        int blen = b.size(k);
                                        int[] bixs = b.indexes(k);
-                                       double[] bvals = b.values(k);           
                                                
+                                       double[] bvals = b.values(k);
                                        for( int j=bpos; j<bpos+blen; j++ )
                                        {
                                                double bval = bvals[j];
@@ -3552,6 +3555,14 @@ public class LibMatrixMult
                                && m2.clen > k * 1024 && m1.rlen < k * 32 && 
!pm2r
                                && 8*m1.rlen*m1.clen < 256*1024 ); //lhs fits 
in L2 cache
        }
+       
+       public static boolean isUltraSparseMatrixMult(MatrixBlock m1, 
MatrixBlock m2) {
+               //note: ultra-sparse matrix mult implies also sparse outputs, 
hence we need
+               //to be conservative an cannot use this for all ultra-sparse 
matrices.
+               return (m1.isUltraSparse() || m2.isUltraSparse()) //base case
+                       || (m1.isUltraSparsePermutationMatrix() 
+                               && OptimizerUtils.getSparsity(m2.rlen, m2.clen, 
m2.nonZeros)<1.0);
+       }
 
        private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, 
MatrixBlock m2 ) 
                throws DMLRuntimeException
@@ -3643,17 +3654,16 @@ public class LibMatrixMult
 
        private static class MatrixMultTask implements Callable<Object> 
        {
-               private MatrixBlock _m1  = null;
-               private MatrixBlock _m2  = null;
+               private final MatrixBlock _m1;
+               private final MatrixBlock _m2;
                private MatrixBlock _ret = null;
-               private boolean _tm2 = false; //transposed m2
-               private boolean _pm2r = false; //par over m2 rows
-               private boolean _pm2c = false; //par over m2 rows
-               
-               private int _rl = -1;
-               private int _ru = -1;
+               private final boolean _tm2; //transposed m2
+               private final boolean _pm2r; //par over m2 rows
+               private final boolean _pm2c; //par over m2 rows
+               private final int _rl;
+               private final int _ru;
 
-               protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, 
+               protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret,
                                boolean tm2, boolean pm2r, boolean pm2c, int 
rl, int ru )
                {
                        _m1 = m1;
@@ -3687,7 +3697,7 @@ public class LibMatrixMult
                                _ret.allocateDenseBlock();
                        
                        //compute block matrix multiplication
-                       if( _m1.isUltraSparse() || _m2.isUltraSparse() )
+                       if( _ret.sparse ) //ultra-sparse
                                matrixMultUltraSparse(_m1, _m2, _ret, rl, ru);
                        else if(!_m1.sparse && !_m2.sparse)
                                matrixMultDenseDense(_m1, _m2, _ret, _tm2, 
_pm2r, rl, ru, cl, cu);

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b5480b7/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index ff05fa0..8f9ae9e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -916,16 +916,28 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
         * 
         * @return true if sparse
         */
-       public boolean isInSparseFormat()
-       {
+       public boolean isInSparseFormat() {
                return sparse;
        }
+       
+       public boolean isUltraSparse() {
+               return isUltraSparse(true);
+       }
 
-       public boolean isUltraSparse()
-       {
+       public boolean isUltraSparse(boolean checkNnz) {
                double sp = ((double)nonZeros/rlen)/clen;
                //check for sparse representation in order to account for 
vectors in dense
-               return sparse && sp<ULTRA_SPARSITY_TURN_POINT && nonZeros<40;
+               return sparse && sp<ULTRA_SPARSITY_TURN_POINT && (!checkNnz || 
nonZeros<40);
+       }
+       
+       public boolean isUltraSparsePermutationMatrix() {
+               if( !isUltraSparse(false) )
+                       return false;
+               boolean isPM = true;
+               SparseBlock sblock = getSparseBlock();
+               for( int i=0; i<rlen & isPM; i++ )
+                       isPM &= sblock.isEmpty(i) || sblock.size(i) == 1;
+               return isPM;
        }
 
        /**

Reply via email to