Repository: systemml
Updated Branches:
  refs/heads/master 4d7606574 -> 3c7a2fc96


[SYSTEMML-2034] Performance sparse-dense maxpooling_backward

The existing sparse-dense maxpooling backward implementation re-scanned
each sparse input row for every channel and every output position
(p, q), which caused significant overhead on LeNet over MNIST. This
patch addresses the performance issue by introducing a "real"
sparse-dense implementation that uses small auxiliary data structures
and a single, sequential scan over the sparse input data. Additionally,
the number of non-zeros (nnz) is now maintained in a thread-local
manner, which avoids a full recomputation over the output and improves
utilization in multi-threaded environments. There is additional
potential to exploit overlapping zero regions across gaps, but since
this is not a major bottleneck, we postpone it to the next release.
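
As an illustration of the core idea, here is a minimal, hypothetical
sketch of the single-pass index maintenance for one sparse row and one
channel (simplified helper signatures; the actual implementation is
PoolingBackwardSparseDense in the diff below):

  //maxVal/maxIx are the small auxiliary structures of size P*Q
  double[] maxVal = new double[P*Q];
  int[] maxIx = new int[P*Q];
  Arrays.fill(maxVal, -Double.MAX_VALUE);
  int lastix = -1;
  for(int j = 0; j < alen; j++) { //single scan over sorted non-zeros
    update0(lastix+1, aix[j], maxVal, maxIx); //implicit zeros in gap
    update(avals[j], aix[j], maxVal, maxIx);  //current non-zero
    lastix = aix[j];
  }
  update0(lastix+1, H*W, maxVal, maxIx); //trailing zeros
  for(int pq = 0; pq < P*Q; pq++) //backward: route gradients to max
    out[maxIx[pq]] += dout[pq];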

On lenet over mnist60k w/ C=1, epochs=1, Hin=28, Win=28, Hf=5, Wf=5,
stride=1, pad=2, F1=32, F2=64, N3=512, this patch improved end-to-end
performance from 1,355s (1,098s maxpooling backward) to 259s (24s
maxpooling backward).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3c7a2fc9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3c7a2fc9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3c7a2fc9

Branch: refs/heads/master
Commit: 3c7a2fc96069e99332a21ecfa00907c99486e5a5
Parents: 4d76065
Author: Matthias Boehm <[email protected]>
Authored: Sat Dec 2 00:15:43 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Sat Dec 2 01:45:57 2017 -0800

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |  13 +-
 .../runtime/matrix/data/LibMatrixDNNHelper.java |  25 ++-
 .../data/LibMatrixDNNPoolingBackwardHelper.java | 167 ++++++++++++++-----
 3 files changed, 142 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3c7a2fc9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 67d4a1a..0e4a468 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -313,23 +313,22 @@ public class LibMatrixDNN {
                }
                
                if(DMLScript.FINEGRAINED_STATISTICS) {
-                       if(input.isInSparseFormat() || dout.isInSparseFormat()) {
+                       if(input.isInSparseFormat() || dout.isInSparseFormat())
                                maxPoolBwdSparseCount.addAndGet(1);
-                       }
-                       else {
+                       else
                                maxPoolBwdDenseCount.addAndGet(1);
-                       }
                }
                
                if (params.output.isInSparseFormat())
                        throw new DMLRuntimeException("Sparse maxpooling_backward is not supported");
 
-               fillIndexesArray(params);
+               if( !(params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) )
+                       fillIndexesArray(params); //not needed for sparse-dense
                
-               execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
+               long nnz = execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, performReluBackward), params);
                
                //post-processing: maintain nnz 
-               outputBlock.recomputeNonZeros(); 
+               outputBlock.setNonZeros(nnz); 
                outputBlock.examSparsity();
        }
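
The driver change above relies on the pooling workers returning partial
nnz counts per row range. A minimal sketch of this aggregation pattern,
assuming plain java.util.concurrent primitives instead of SystemML's
actual execute helper:

  import java.util.List;
  import java.util.concurrent.*;

  static long executeAndSumNnz(List<Callable<Long>> tasks, int k)
      throws InterruptedException, ExecutionException {
    ExecutorService pool = Executors.newFixedThreadPool(k);
    long nnz = 0;
    for(Future<Long> f : pool.invokeAll(tasks))
      nnz += f.get(); //sum thread-local nnz counts
    pool.shutdown();
    return nnz;
  }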
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/3c7a2fc9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index 7b749dd..ee21ce3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -75,19 +75,17 @@ public class LibMatrixDNNHelper {
                ArrayList<Callable<Long>> ret = new ArrayList<>();
                int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
                int taskSize = (int)(Math.ceil((double)params.N / k));
+               boolean sparse1 = params.input1.isInSparseFormat();
+               boolean sparse2 = params.input2.isInSparseFormat();
                for(int i = 0; i*taskSize < params.N; i++) {
-                       if(!params.input1.isInSparseFormat()) {
-                               if(!params.input2.isInSparseFormat()) 
-                                       ret.add(new PoolingBackwardDenseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-                               else
-                                       ret.add(new PoolingBackwardDenseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-                       }
-                       else {
-                               if(!params.input2.isInSparseFormat()) 
-                                       ret.add(new PoolingBackwardSparseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-                               else
-                                       ret.add(new PoolingBackwardSparseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
-                       }
+                       if( !sparse1 && !sparse2 )
+                               ret.add(new PoolingBackwardDenseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+                       else if( !sparse1 && sparse2 )
+                               ret.add(new PoolingBackwardDenseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+                       else if( sparse1 && !sparse2 ) 
+                               ret.add(new PoolingBackwardSparseDense(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
+                       else if( sparse1 && sparse2 )
+                               ret.add(new PoolingBackwardSparseSparse(i*taskSize, Math.min((i+1)*taskSize, params.N), params, performReluBackward));
                }
                return ret;
        }
@@ -417,9 +415,6 @@ public class LibMatrixDNNHelper {
         * @throws DMLRuntimeException if error occurs
         */
        static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
-               if(!input.isInSparseFormat())
-                       throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
-               
                int [] tensorIndexes = new int[3];
                
                int start_h = params.start_indexes_h[p];
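
For context, getMaxPoolingBackwardWorkers above partitions the N input
rows into balanced, clamped ranges. A minimal sketch of the scheme with
hypothetical names:

  //split rows [0, N) into ranges of size ceil(N/k), clamped to N
  int taskSize = (int) Math.ceil((double) N / k);
  for(int i = 0; i*taskSize < N; i++) {
    int rl = i*taskSize;
    int ru = Math.min((i+1)*taskSize, N);
    tasks.add(createWorker(rl, ru)); //hypothetical worker factory
  }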

http://git-wip-us.apache.org/repos/asf/systemml/blob/3c7a2fc9/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
index 5b04e59..4d8319b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -18,6 +18,7 @@
  */
 package org.apache.sysml.runtime.matrix.data;
 
+import java.util.Arrays;
 import java.util.concurrent.Callable;
 
 /**
@@ -31,8 +32,9 @@ public class LibMatrixDNNPoolingBackwardHelper {
        {
                public int _rl; public int _ru; 
                private final ConvolutionParameters _params; 
-               double [] outputArray; boolean performReluBackward;
-               double [] inputArray; double [] doutArray;
+               boolean performReluBackward;
+               double [] inputArray, doutArray;
+               MatrixBlock output;
                int C; int CHW; int P; int Q; int HW; int CPQ; int PQ;
                public PoolingBackwardDenseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
                        _rl = rl; _ru = ru;
@@ -40,16 +42,17 @@ public class LibMatrixDNNPoolingBackwardHelper {
                        this.performReluBackward = performReluBackward;
                        inputArray = params.input1.getDenseBlock();
                        doutArray = params.input2.getDenseBlock();
-                       outputArray = params.output.getDenseBlock();
+                       output = params.output;
                        C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
                        P = params.P; Q = params.Q; CPQ = params.C*params.P*params.Q;
                        PQ = params.P*params.Q;
-                       if (inputArray == null || doutArray == null || outputArray == null )
+                       if (inputArray == null || doutArray == null || output.getDenseBlock() == null )
                                throw new RuntimeException("Incorrect usage: empty inputs");
                }
                
                @Override
                public Long call() throws Exception {
+                       double[] out = output.getDenseBlock();
                        for(int n = _rl; n < _ru; n++)  {
                                for (int c = 0; c < C; c++) {
                                        final int inputOffset = n*CHW + c*HW;
@@ -58,12 +61,13 @@ public class LibMatrixDNNPoolingBackwardHelper {
                                                for (int q = 0; q < Q; q++) {
                                                        int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
                                                        if(maxIndex != -1)
-                                                               outputArray[maxIndex] += doutArray[outputOffset +  p * Q + q];
+                                                               out[maxIndex] += doutArray[outputOffset +  p * Q + q];
                                                }
                                        }
                                }
                        }
-                       return 0L;
+                       //thread-local nnz maintenance
+                       return output.recomputeNonZeros(_rl, _ru-1);
                }
        }
        
@@ -74,7 +78,8 @@ public class LibMatrixDNNPoolingBackwardHelper {
        {
                public int _rl; public int _ru; 
                private final ConvolutionParameters _params; 
-               double [] outputArray; boolean performReluBackward;
+               MatrixBlock output; 
+               boolean performReluBackward;
                double [] inputArray;  MatrixBlock dout;
                int C; int CHW; int P; int Q; int HW;
                public PoolingBackwardDenseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
@@ -83,10 +88,10 @@ public class LibMatrixDNNPoolingBackwardHelper {
                        this.performReluBackward = performReluBackward;
                        inputArray = params.input1.getDenseBlock();
                        dout = params.input2;
-                       outputArray = params.output.getDenseBlock();
+                       output = params.output;
                        C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
                        P = params.P; Q = params.Q; 
-                       if (inputArray == null || outputArray == null )
+                       if (inputArray == null || output.getDenseBlock() == null )
                                throw new RuntimeException("Incorrect usage: empty inputs");
                        if (!params.input2.isInSparseFormat())
                                throw new RuntimeException("Incorrect usage: Call optimized versions");
@@ -94,6 +99,7 @@ public class LibMatrixDNNPoolingBackwardHelper {
                
                @Override
                public Long call() throws Exception {
+                       double[] out = output.getDenseBlock();
                        for(int n = _rl; n < _ru; n++)  {
                                if( !dout.sparseBlock.isEmpty(n) ) {
                                        int [] tensorIndexes = new int[3];
@@ -109,11 +115,12 @@ public class LibMatrixDNNPoolingBackwardHelper {
                                                final int inputOffset = n*CHW + c*HW;
                                                int maxIndex = LibMatrixDNNHelper.getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
                                                if(maxIndex != -1)
-                                                       outputArray[maxIndex] += avals[j];
+                                                       out[maxIndex] += avals[j];
                                        }
                                }
                        }
-                       return 0L;
+                       //thread-local nnz maintenance
+                       return output.recomputeNonZeros(_rl, _ru-1);
                }
        }
        
@@ -122,45 +129,120 @@ public class LibMatrixDNNPoolingBackwardHelper {
         */
        public static class PoolingBackwardSparseDense implements Callable<Long> 
        {
-               public int _rl; public int _ru; 
+               private final int _rl, _ru; 
                private final ConvolutionParameters _params; 
-               double [] outputArray; boolean performReluBackward;
-               double [] doutArray;
-               int C; int CHW; int P; int Q; int HW; int CPQ; int PQ;
-               public PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
-                       _rl = rl; _ru = ru;
-                       _params = params;
-                       this.performReluBackward = performReluBackward;
-                       doutArray = params.input2.getDenseBlock();
-                       outputArray = params.output.getDenseBlock();
-                       C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
-                       P = params.P; Q = params.Q; CPQ = params.C*params.P*params.Q;
-                       PQ = params.P*params.Q;
-                       if (doutArray == null || outputArray == null )
+               private final double[] dout;
+               private final MatrixBlock output;
+               private final boolean reluBack;
+               
+               public PoolingBackwardSparseDense(int rl, int ru, ConvolutionParameters params, boolean relu) {
+                       _rl = rl; _ru = ru; _params = params;
+                       reluBack = relu;
+                       dout = params.input2.getDenseBlock();
+                       output = params.output;
+                       if (dout == null || output.getDenseBlock() == null )
                                throw new RuntimeException("Incorrect usage: empty inputs");
                        if (!params.input1.isInSparseFormat())
-                               throw new RuntimeException("Incorrect usage: Call optimized versions");
+                               throw new RuntimeException("Incorrect usage: sparse input1 expected");
                }
                
                @Override
-               public Long call() throws Exception {
+               public Long call() throws Exception 
+               {
+                       SparseBlock sblock = _params.input1.getSparseBlock();
+                       double[] out = output.getDenseBlock();
+                       final int P = _params.P, Q = _params.Q, W = _params.W;
+                       final int C = _params.C, R = _params.R, S = _params.S;
+                       final int padh = _params.pad_h, padw = _params.pad_w;
+                       final int strideh = _params.stride_h, stridew = _params.stride_w;
+                       final int PQ = _params.P * _params.Q;
+                       final int CPQ = _params.C * _params.P * _params.Q;
+                       final int HW = _params.H * _params.W;
+                       final int CHW = _params.C * _params.H * _params.W;
+                       
+                       //allocate auxiliary data structures
+                       double[] maxVal = new double[PQ];
+                       int[] maxIx = new int[PQ];
+                       
                        for(int n = _rl; n < _ru; n++)  {
                                for (int c = 0; c < C; c++) {
+                                       //step 0: basic initializations
                                        final int doutOffset = n*CPQ + c*PQ;
-                                       final int inputOffset = n*CHW + c*HW;
-                                       for (int p = 0; p < P; p++) {
-                                               for (int q = 0; q < Q; q++) {
-                                                       double inVal = doutArray[doutOffset +  p * Q + q];
-                                                       if(inVal != 0) {
-                                                               int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
-                                                               if(maxIndex != -1)
-                                                                       outputArray[maxIndex] += inVal;
+                                       final int outOffset = n*CHW + c*HW;
+                                       
+                                       //step 1: perform maxpooling w/ index maintenance in a 
+                                       //single, sequential pass over the sparse input matrix
+                                       if( !sblock.isEmpty(n) ) {
+                                               Arrays.fill(maxVal, -Double.MAX_VALUE);
+                                               int apos = sblock.pos(n);
+                                               int alen = sblock.size(n);
+                                               int[] aix = sblock.indexes(n);
+                                               double[] avals = sblock.values(n);
+                                               //find channel start and end, w/ robustness for non-existing entries
+                                               int cpos = (c==0) ? 0 : sblock.posFIndexGTE(n, c*HW);
+                                               int cpos2 = (c+1==C) ? alen : sblock.posFIndexGTE(n, (c+1)*HW);
+                                               cpos = (cpos>=0) ? cpos : alen;
+                                               cpos2 = (cpos2>=0) ? cpos2 : alen;
+                                               int lastix = c*HW-1;
+                                               for(int j=apos+cpos; j<apos+cpos2; j++) {
+                                                       //handle skipped zero values
+                                                       update0(lastix+1, aix[j], maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+                                                       //handle current non-zero value
+                                                       int h = (aix[j] % HW) / W;
+                                                       int w = aix[j] % W;
+                                                       double val = reluBack && avals[j] < 0 ? 0 : avals[j];
+                                                       update(val, maxVal, maxIx, h, w, padh, padw, strideh, stridew, P, Q, R, S, W);
+                                                       //memoize last seen index
+                                                       lastix = aix[j];
+                                               }
+                                               //handle skipped zero values at end of row
+                                               update0(lastix+1, (c+1)*HW, maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+                                       }
+                                       else {
+                                               //handle empty row
+                                               Arrays.fill(maxVal, 0);
+                                               for(int p = 0, ix=0; p < P; p++) {
+                                                       int h = Math.max(-padh+p*strideh, 0);
+                                                       for(int q = 0; q < Q; q++, ix++) {
+                                                               int w = Math.max(-padw+q*stridew, 0);
+                                                               maxIx[ix] = h * W + w;
                                                        }
                                                }
                                        }
+                                       
+                                       //step 2: perform maxpooling backward
+                                       for (int pq = 0; pq < PQ; pq++)
+                                               out[ outOffset + maxIx[pq] ] += dout[ doutOffset + pq ];
                                }
                        }
-                       return 0L;
+                       //thread-local nnz maintenance
+                       return output.recomputeNonZeros(_rl, _ru-1);
+               }
+               
+               private static void update0(int lix, int uix, double[] maxVal, int[] maxIx, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int HW, int W) {
+                       //TODO exploit constant value and overlap for potential early abort
+                       for(int i = lix; i<uix; i++)
+                               update(0, maxVal, maxIx, (i%HW)/W, i%W, padh, padw, strideh, stridew, P, Q, R, S, W);
+               }
+               
+               private static void update(double val, double[] maxVal, int[] maxIx, int h, int w, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int W) {
+                       //determine lower and upper bounds for p and q
+                       //(see fillIndexesArray, solved for p and q, reversed)
+                       int lp = Math.max((h+padh-R+strideh)/strideh, 0);
+                       int up = Math.min((h+padh+strideh)/strideh, P);
+                       int lq = Math.max((w+padw-S+stridew)/stridew, 0);
+                       int uq = Math.min((w+padw+stridew)/stridew, Q);
+                       
+                       //maintain max index for all relevant p and q
+                       int maxIndex = h * W + w;
+                       for(int p = lp; p < up; p++) 
+                               for(int q = lq; q < uq; q++) {
+                                       int ix = p * Q + q;
+                                       if( maxVal[ix] < val ) {
+                                               maxVal[ix] = val;
+                                               maxIx[ix] = maxIndex;
+                                       }
+                               }
                }
        }
        
@@ -171,16 +253,17 @@ public class LibMatrixDNNPoolingBackwardHelper {
        {
                public int _rl; public int _ru; 
                private final ConvolutionParameters _params; 
-               double [] outputArray; boolean performReluBackward;
+               MatrixBlock output;
+               boolean performReluBackward;
                int C; int CHW; int P; int Q; int HW; 
                public PoolingBackwardSparseSparse(int rl, int ru, ConvolutionParameters params, boolean performReluBackward) {
                        _rl = rl; _ru = ru;
                        _params = params;
                        this.performReluBackward = performReluBackward;
-                       outputArray = params.output.getDenseBlock();
+                       output = params.output;
                        C = params.C; CHW = params.C*params.H*params.W; HW = params.H*params.W;
                        P = params.P; Q = params.Q;
-                       if (outputArray == null )
+                       if (output.getDenseBlock() == null )
                                throw new RuntimeException("Incorrect usage: empty outputs");
                        if (!params.input1.isInSparseFormat() || !params.input2.isInSparseFormat())
                                throw new RuntimeException("Incorrect usage: Call optimized versions");
@@ -188,6 +271,7 @@ public class LibMatrixDNNPoolingBackwardHelper {
                
                @Override
                public Long call() throws Exception {
+                       double[] out = output.getDenseBlock();
                        for(int n = _rl; n < _ru; n++)  {
                                if( !_params.input2.sparseBlock.isEmpty(n) ) {
                                        int [] tensorIndexes = new int[3];
@@ -203,11 +287,12 @@ public class LibMatrixDNNPoolingBackwardHelper {
                                                final int inputOffset = n*CHW + c*HW;
                                                int maxIndex = LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, _params, performReluBackward);
                                                if(maxIndex != -1)
-                                                       outputArray[maxIndex] += avals[j];
+                                                       out[maxIndex] += avals[j];
                                        }
                                }
                        }
-                       return 0L;
+                       //thread-local nnz maintenance
+                       return output.recomputeNonZeros(_rl, _ru-1);
                }
        }
 }
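
For reference, the lp/up/lq/uq bounds in update() follow from the
forward indexing (see fillIndexesArray): output position p covers input
rows h with

  p*strideh - padh <= h <= p*strideh - padh + R - 1.

Solving both inequalities for p gives

  (h + padh - R + 1) / strideh <= p <= (h + padh) / strideh,

which in integer arithmetic, clamped to [0, P), yields

  lp = max((h + padh - R + strideh) / strideh, 0)  //ceil of lower bound
  up = min((h + padh + strideh) / strideh, P)      //exclusive upper bound

and analogously for q with S, padw, stridew, and Q.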
