Repository: systemml
Updated Branches:
  refs/heads/master 5c1ac17cc -> 401b79657


[SYSTEMML-2041] Performance/memory-efficiency ultra-sparse conv2d 

This patch improves the performance and memory efficiency of
ultra-sparse conv2d operations by allocating the conv2d output in
sparse format, forcing ultra-sparse matrix multiplications (with
sparse outputs), and generalizing all related operations to handle
sparse intermediates.
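
For illustration, the new output-allocation decision boils down to a size
comparison between the ultra-sparse input and a hypothetical dense output.
The sketch below is not code from the patch: the class and method names are
hypothetical, while the MatrixBlock calls (isUltraSparse, getInMemorySize,
estimateSizeDenseInMemory, allocateBlock) are the ones used in the
ConvolutionCPInstruction change further down.

  import org.apache.sysml.runtime.matrix.data.MatrixBlock;

  public class Conv2dOutputAllocSketch {
      // Allocate the conv2d output of shape [N, K*P*Q] in sparse format only if
      // the input is ultra-sparse, no bias addition follows (the bias-add path
      // assumes dense rows), and the sparse input is actually smaller than a
      // dense output would be.
      public static MatrixBlock allocateConv2dOutput(MatrixBlock input,
          boolean hasBias, int N, int K, int P, int Q)
      {
          boolean sparse = input.isUltraSparse(false) && !hasBias
              && input.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K*P*Q);
          return new MatrixBlock(N, K*P*Q, sparse).allocateBlock();
      }
  }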

Furthermore, this patch includes a fix for ultra-sparse matrix
multiplications where the rhs is much larger than the lhs: this case
used to trigger parallelization over rows of the rhs instead of the
lhs, which the ultra-sparse matrix multiplication kernels do not
support.
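
In essence, the LibMatrixMult changes (a) let callers fix the output
representation via the new fixedRet flag (replacing maintainNnz) and
(b) disable right-input parallelization on the ultra-sparse path. The
fragment below is condensed from the hunks further down (variable and
method names as in the diff), not a self-contained implementation:

  // pre-processing: take the ultra-sparse path either because the caller
  // fixed a sparse output (fixedRet) or because the inputs qualify
  boolean ultraSparse = (fixedRet && ret.sparse)
      || (!fixedRet && isUltraSparseMatrixMult(m1, m2));
  ret.sparse = ultraSparse;
  ret.allocateBlock();

  // parallelization over rows/columns of the rhs is not supported by the
  // ultra-sparse kernels (sparse output), so it is disabled on that path
  boolean pm2r = !ultraSparse && checkParMatrixMultRightInputRows(m1, m2, k);
  boolean pm2c = !ultraSparse && checkParMatrixMultRightInputCols(m1, m2, k, pm2r);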


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/08c7f00d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/08c7f00d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/08c7f00d

Branch: refs/heads/master
Commit: 08c7f00d5b3d7383ca0cf0e893736026c8ee7151
Parents: 5c1ac17
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Feb 9 21:40:22 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Feb 9 21:40:22 2018 -0800

----------------------------------------------------------------------
 .../cp/ConvolutionCPInstruction.java            |  4 +-
 .../runtime/matrix/data/LibMatrixDNNConv2d.java | 58 ++++++++++++-------
 .../runtime/matrix/data/LibMatrixDNNHelper.java |  2 +-
 .../runtime/matrix/data/LibMatrixDNNIm2Col.java | 16 +++--
 .../runtime/matrix/data/LibMatrixMult.java      | 61 +++++++++-----------
 .../runtime/matrix/data/LibMatrixNative.java    |  2 +-
 6 files changed, 76 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/08c7f00d/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index e32c3cf..34daf33 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -440,7 +440,9 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
                                outputBlock = new MatrixBlock(N, K*P*Q, true);
                        }
                        else {
-                               outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock();
+                               boolean sparse = matBlock.isUltraSparse(false) && params.bias == null
+                                       && matBlock.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K*P*Q);
+                               outputBlock = new MatrixBlock(N, K*P*Q, sparse).allocateBlock();
                                if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat())
                                        LibMatrixNative.conv2d(matBlock, filter, outputBlock, params);
                                else

http://git-wip-us.apache.org/repos/asf/systemml/blob/08c7f00d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
index e30de2c..22a5e1a 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2d.java
@@ -52,7 +52,7 @@ public class LibMatrixDNNConv2d
                
                MatrixBlock in1 = params.input1;
                boolean isEmptyDenseInput = !in1.isInSparseFormat() && in1.denseBlock == null;
-               boolean isTransPref = in1.sparse && !params.input2.sparse && 
+               boolean isTransPref = in1.sparse && !params.input2.sparse && !params.output.sparse &&
                        MatrixBlock.evalSparseFormatInMemory(in1.clen, in1.rlen, in1.nonZeros);
                boolean applyNative = isEligibleForConv2dSparse(params)
                        && !(!isEmptyDenseInput && isTransPref);
@@ -171,8 +171,8 @@ public class LibMatrixDNNConv2d
                @Override
                public Long call() throws Exception {
                        final int PQ = _params.P*_params.Q, K = _params.K, CRS = _params.C*_params.R*_params.S;
-                       MatrixBlock outIm2col = new MatrixBlock(CRS, PQ, false);
-                       MatrixBlock outMM = new MatrixBlock(K, PQ, false);
+                       MatrixBlock outIm2col = new MatrixBlock(CRS, PQ, _params.input1.sparse);
+                       MatrixBlock outMM = new MatrixBlock(K, PQ, _params.output.sparse);
                        Im2colWorker im2ColWorker = Im2colWorker.getWorker( _params.input1, outIm2col, _params, false);
                        long time1 = 0; long time2 = 0;
                        for(int n = _rl; n < _ru; n++)  {
@@ -182,7 +182,7 @@ public class LibMatrixDNNConv2d
                                long t2 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0;
                                
                                // filter %*% _im2ColOutBlock => matMultOutBlock
-                               outMM.reset(outMM.rlen, outMM.clen, false);
+                               outMM.reset(outMM.rlen, outMM.clen, _params.output.sparse);
                                LibMatrixDNNHelper.singleThreadedMatMult(_params.input2, outIm2col, outMM, false, true, _params);
                                long t3 = DMLScript.FINEGRAINED_STATISTICS ? System.nanoTime() : 0;
                                
@@ -191,8 +191,8 @@ public class LibMatrixDNNConv2d
                                        time2 += t3 - t2;
                                }
                                
-                               // Copy the matrix matMultOutBlock of shape [K X PQ] to params.output.denseBlock + destPos
-                               partialCopy1(outMM, _params.output.getDenseBlockValues(), n*K*PQ, K, PQ);
+                               // Copy the outMM of shape [K x PQ] to a row in params.output 
+                               partialCopy1(outMM, _params.output, n, K, PQ);
                                
                                // Add bias to current row if necessary, always dense
                                if(_params.bias != null)
@@ -209,27 +209,43 @@ public class LibMatrixDNNConv2d
                        return _params.output.recomputeNonZeros(_rl, _ru-1);
                }
                
-               // Copy the matrix src of shape [K X PQ] to params.output.denseBlock + destPos
-               private static void partialCopy1(MatrixBlock src, double [] dest, int destPos, int K, int PQ) {
+               // Copy the matrix src of shape [K X PQ] to the r-th row in params.output 
+               private static void partialCopy1(MatrixBlock src, MatrixBlock dest, int r, int K, int PQ) {
                        // Copying is required as LibMatrixMult.matrixMult (and/or Java) is not pointer aware.
-                       // This is not required in Native implementation
                        if( src.isEmptyBlock() )
                                return;
-                       if(src.isInSparseFormat()) {
-                               SparseBlock sblock = src.sparseBlock;
+                       if( src.sparse ) { //* <- SPARSE
+                               SparseBlock srcBlock = src.sparseBlock;
+                               SparseBlock sdestBlock = dest.sparseBlock;
+                               double[] ddestBlock = dest.getDenseBlockValues();
+                               
                                for(int k = 0; k < src.getNumRows(); k++) {
-                                       if( sblock.isEmpty(k) ) continue;
-                                       int apos = sblock.pos(k);
-                                       int alen = sblock.size(k);
-                                       int[] aix = sblock.indexes(k);
-                                       double[] avals = sblock.values(k);
-                                       int desPosK = destPos + k*PQ;
-                                       for(int j = apos; j < apos+alen; j++)
-                                               dest[desPosK+aix[j]] = avals[j];
+                                       if( srcBlock.isEmpty(k) ) continue;
+                                       int apos = srcBlock.pos(k);
+                                       int alen = srcBlock.size(k);
+                                       int[] aix = srcBlock.indexes(k);
+                                       double[] avals = srcBlock.values(k);
+                                       if( dest.sparse ) {
+                                               sdestBlock.setIndexRange(r,
+                                                       0, K*PQ, avals, aix, apos, alen);
+                                       }
+                                       else {
+                                               int desPosK = r + k*PQ;
+                                               for(int j = apos; j < apos+alen; j++)
+                                                       ddestBlock[desPosK+aix[j]] = avals[j];
+                                       }
+                               }
+                       }
+                       else { //* <- DENSE
+                               if( dest.sparse ) {
+                                       dest.getSparseBlock().setIndexRange(r, 0, K*PQ,
+                                               src.getDenseBlockValues(), 0, K*PQ);
+                               }
+                               else {
+                                       System.arraycopy(src.getDenseBlockValues(), 0,
+                                               dest.getDenseBlockValues(), r*K*PQ, K*PQ);
                                }
                        }
-                       else 
-                               System.arraycopy(src.getDenseBlockValues(), 0, dest, destPos, K * PQ);
                }
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/08c7f00d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index 32a0eaa..b985c42 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -62,7 +62,7 @@ public class LibMatrixDNNHelper
                if( !params.enableNative || m1.sparse || m2.sparse ) {
                        prepNonZerosForMatrixMult(m1, recomputeNNZM1);
                        prepNonZerosForMatrixMult(m2, recomputeNNZM2);
-                       LibMatrixMult.matrixMult(m1, m2, ret, false);
+                       LibMatrixMult.matrixMult(m1, m2, ret, true);
                }
                else {
                        ret.sparse = false;

http://git-wip-us.apache.org/repos/asf/systemml/blob/08c7f00d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
index ac7a6a8..4af4933 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
@@ -54,7 +54,7 @@ public class LibMatrixDNNIm2Col {
                                //preallocate sparse-rows (larger than average sparsity to account for skew)
                                int estnnz = (int)Math.ceil(4*input.getSparsity()*out.clen);
                                for(int r = 0; r < out.rlen; r++)
-                                       out.getSparseBlock().allocate(r, Math.min(estnnz, out.clen));
+                                       out.getSparseBlock().allocate(r, Math.max(Math.min(estnnz, out.clen),16));
                                return new SparseSparseIm2colWorkerAllChan(input, out, params, trans);
                        }
                }
@@ -145,16 +145,15 @@ public class LibMatrixDNNIm2Col {
         */
        private static class SparseSparseIm2colWorkerAllChan implements Im2colWorker {
                private final MatrixBlock input, output;
-               private final int S, R, P, Q, W, HW, RS;
+               private final int S, R, P, Q, H, W, RS;
                private final int stride_h, stride_w, pad_h, pad_w;
                private final boolean trans;
                private final boolean simple;
                public SparseSparseIm2colWorkerAllChan(MatrixBlock input, MatrixBlock im2ColOutBlock, ConvolutionParameters params, boolean trans) {
                        this.input = input;
                        this.output = im2ColOutBlock;
-                       this.HW = params.H * params.W;
                        this.RS = params.R * params.S;
-                       this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
+                       this.H = params.H; this.W = params.W; this.R = params.R; this.S = params.S; this.P = params.P; this.Q = params.Q;
                        this.stride_h = params.stride_h; this.stride_w = params.stride_w;
                        this.pad_h = params.pad_h; this.pad_w = params.pad_w;
                        this.trans = trans;
@@ -175,20 +174,19 @@ public class LibMatrixDNNIm2Col {
                        double[] avals = sblock.values(n);
                        
                        // Iterate over the sparse block
+                       CellIndex3 ix = new CellIndex3();
                        for(int j=apos; j<apos+alen; j++) {
                                // Note: the input is of shape [N, CHW]
                                int chw = aix[j];
                                
                                // Get individual zero-based c,h,w indexes from zero-based 'chw'
-                               int cInput = chw / HW;
-                               int hInput = (chw - cInput*HW)/W;
-                               int wInput = chw % W; 
+                               ix = LibMatrixDNNHelper.computeTensorIndexes(chw, H, W, ix);
                                
                                if( simple )
-                                       appendInputValueToIm2colOutputSimple(output, cInput, hInput, wInput, 
+                                       appendInputValueToIm2colOutputSimple(output, ix.ix1, ix.ix2, ix.ix3, 
                                                avals[j], R, S, RS, P, trans);
                                else
-                                       appendInputValueToIm2colOutput(output, cInput, hInput, wInput, avals[j], 
+                                       appendInputValueToIm2colOutput(output, ix.ix1, ix.ix2, ix.ix3, avals[j], 
                                                R, S, P, Q, stride_h, stride_w, pad_h, pad_w, trans);
                        }
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/08c7f00d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 6119e95..d5ed8b2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -95,22 +95,18 @@ public class LibMatrixMult
         * @param m1 first matrix
         * @param m2 second matrix
         * @param ret result matrix
-        * @param maintainNnz if false, nnzs are not recomputed and evaluated
+        * @param fixedRet if true, output representation is fixed and nnzs not recomputed
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean maintainNnz) 
-                       throws DMLRuntimeException
-       {       
-               matrixMult(m1, m2, ret, 0, m1.rlen, maintainNnz);
+       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet) throws DMLRuntimeException {
+               matrixMult(m1, m2, ret, 0, m1.rlen, fixedRet);
        }
        
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) 
-                       throws DMLRuntimeException
-       {
-               matrixMult(m1, m2, ret, rl, ru, true);
+       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) throws DMLRuntimeException {
+               matrixMult(m1, m2, ret, rl, ru, false);
        }
        
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean maintainNnz) 
+       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet) 
                throws DMLRuntimeException
        {
                //check inputs / outputs
@@ -122,14 +118,16 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
                
                //pre-processing: output allocation
-               boolean ultraSparse = isUltraSparseMatrixMult(m1, m2);
+               boolean ultraSparse = (fixedRet && ret.sparse)
+                       || (!fixedRet && isUltraSparseMatrixMult(m1, m2));
                boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
                m2 = prepMatrixMultRightInput(m1, m2);
                ret.sparse = ultraSparse;
                ret.allocateBlock();
                
                //prepare row-upper for special cases of vector-matrix
-               boolean pm2 = checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
+               boolean pm2 = !ultraSparse &&
+                       checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
                int ru2 = (pm2 && ru==m1.rlen) ? m2.rlen : ru; 
                int cu = m2.clen;
                
@@ -146,7 +144,7 @@ public class LibMatrixMult
                        matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru2);
                
                //post-processing: nnz/representation
-               if( maintainNnz ) {
+               if( !fixedRet ) {
                        if( !ret.sparse )
                                ret.recomputeNonZeros();
                        ret.examSparsity();
@@ -200,8 +198,8 @@ public class LibMatrixMult
                }
                
                //prepare row-upper for special cases of vector-matrix / matrix-matrix
-               boolean pm2r = checkParMatrixMultRightInputRows(m1, m2, k);
-               boolean pm2c = checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
+               boolean pm2r = !ultraSparse && checkParMatrixMultRightInputRows(m1, m2, k);
+               boolean pm2c = !ultraSparse && checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
                int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen; 
                
                //core multi-threaded matrix mult computation
@@ -1642,25 +1640,20 @@ public class LibMatrixMult
                {
                        SparseBlock b = m2.sparseBlock;
                        
-                       for(int k = 0; k < cd; k++ ) 
-                       {
-                               if( !b.isEmpty(k) ) 
-                               {
-                                       int bpos = b.pos(k);
-                                       int blen = b.size(k);
-                                       int[] bixs = b.indexes(k);
-                                       double[] bvals = b.values(k);
-                                       for( int j=bpos; j<bpos+blen; j++ )
-                                       {
-                                               double bval = bvals[j];
-                                               int bix = bixs[j];
-                                               for( int i=rl; i<ru; i++ )
-                                               {
-                                                       double cvald = bval*m1.quickGetValue(i, k);
-                                                       if( cvald != 0 ){
-                                                               double cval = ret.quickGetValue(i, bix);
-                                                               ret.quickSetValue(i, bix, cval+cvald);
-                                                       }
+                       for(int k = 0; k < cd; k++ ) {
+                               if( b.isEmpty(k) ) continue; 
+                               int bpos = b.pos(k);
+                               int blen = b.size(k);
+                               int[] bixs = b.indexes(k);
+                               double[] bvals = b.values(k);
+                               for( int j=bpos; j<bpos+blen; j++ ) {
+                                       double bval = bvals[j];
+                                       int bix = bixs[j];
+                                       for( int i=rl; i<ru; i++ ) {
+                                               double cvald = bval*m1.quickGetValue(i, k);
+                                               if( cvald != 0 ){
+                                                       double cval = ret.quickGetValue(i, bix);
+                                                       ret.quickSetValue(i, bix, cval+cvald);
                                                }
                                        }
                                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/08c7f00d/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
index 9e3a6ee..7e0a6d7 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixNative.java
@@ -109,7 +109,7 @@ public class LibMatrixNative
                        Statistics.incrementNativeFailuresCounter();
                }
                if (k == 1)
-                       LibMatrixMult.matrixMult(m1, m2, ret, examSparsity);
+                       LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);
                else
                        LibMatrixMult.matrixMult(m1, m2, ret, k);
        }
