Repository: systemml
Updated Branches:
  refs/heads/master 3c7a2fc96 -> 6992c3896


[SYSTEMML-2034] Performance sparse-sparse maxpooling_backward

This patch improves the performance of sparse-sparse maxpooling_backward
operations by reusing the new, efficient sparse-dense maxpooling forward
pass over the sparse input data, combined with a single sequential scan
of the relevant channel segment of the sparse rhs (dout) row for each
input row and channel. Furthermore, this patch introduces additional
tests to ensure correctness in various scenarios.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6992c389
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6992c389
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6992c389

Branch: refs/heads/master
Commit: 6992c389628540c60b0b5e3cd60f934cdac3822a
Parents: 3c7a2fc
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sat Dec 2 14:45:38 2017 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sat Dec 2 15:00:39 2017 -0800

----------------------------------------------------------------------
 .../data/LibMatrixDNNPoolingBackwardHelper.java | 167 ++++++++++---------
 .../functions/tensor/PoolBackwardTest.java      |  12 ++
 2 files changed, 96 insertions(+), 83 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6992c389/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
index 4d8319b..1992b79 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -131,16 +131,20 @@ public class LibMatrixDNNPoolingBackwardHelper {
        {
                private final int _rl, _ru; 
                private final ConvolutionParameters _params; 
-               private final double[] dout;
-               private final MatrixBlock output;
                private final boolean reluBack;
+               protected final MatrixBlock doutput, output;
                
-               public PoolingBackwardSparseDense(int rl, int ru, 
ConvolutionParameters params, boolean relu) {
-                       _rl = rl; _ru = ru; _params = params;
+               protected PoolingBackwardSparseDense(int rl, int ru, 
ConvolutionParameters params, boolean relu, MatrixBlock dout, MatrixBlock out) {
+                       _rl = rl; _ru = ru; 
+                       _params = params;
                        reluBack = relu;
-                       dout = params.input2.getDenseBlock();
-                       output = params.output;
-                       if (dout == null || output.getDenseBlock() == null )
+                       doutput = dout;
+                       output = out;
+               }
+               
+               public PoolingBackwardSparseDense(int rl, int ru, 
ConvolutionParameters params, boolean relu) {
+                       this(rl, ru, params, relu, params.input2, 
params.output);
+                       if (doutput.getDenseBlock() == null || 
output.getDenseBlock() == null )
                                throw new RuntimeException("Incorrect usage: 
empty inputs");
                        if (!params.input1.isInSparseFormat())
                                throw new RuntimeException("Incorrect usage: 
sparse input1 expected");
@@ -149,8 +153,6 @@ public class LibMatrixDNNPoolingBackwardHelper {
                @Override
                public Long call() throws Exception 
                {
-                       SparseBlock sblock = _params.input1.getSparseBlock();
-                       double[] out = output.getDenseBlock();
                        final int P = _params.P, Q = _params.Q, W = _params.W;
                        final int C = _params.C, R = _params.R, S = _params.S;
                        final int padh = _params.pad_h, padw = _params.pad_w;
@@ -167,58 +169,70 @@ public class LibMatrixDNNPoolingBackwardHelper {
                        for(int n = _rl; n < _ru; n++)  {
                                for (int c = 0; c < C; c++) {
                                        //step 0: basic initializations
-                                       final int doutOffset = n*CPQ + c*PQ;
                                        final int outOffset = n*CHW + c*HW;
                                        
                                        //step 1: perform maxpooling w/ index 
maintenance in a 
                                        //single, sequential pass over the 
sparse input matrix
-                                       if( !sblock.isEmpty(n) ) {
-                                               Arrays.fill(maxVal, -Double.MAX_VALUE);
-                                               int apos = sblock.pos(n);
-                                               int alen = sblock.size(n);
-                                               int[] aix = sblock.indexes(n);
-                                               double[] avals = 
sblock.values(n);
-                                               //find channel start and end, 
w/ robustness for non-existing entries
-                                               int cpos = (c==0) ? 0 : 
sblock.posFIndexGTE(n, c*HW);
-                                               int cpos2 = (c+1==C) ? alen : 
sblock.posFIndexGTE(n, (c+1)*HW);
-                                               cpos = (cpos>=0) ? cpos : alen;
-                                               cpos2 = (cpos2>=0) ? cpos2 : 
alen;
-                                               int lastix = c*HW-1;
-                                               for(int j=apos+cpos; 
j<apos+cpos2; j++) {
-                                                       //handle skipped zero 
values
-                                                       update0(lastix+1, 
aix[j], maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
-                                                       //handle current 
non-zero value
-                                                       int h = (aix[j] % HW) / 
W;
-                                                       int w = aix[j] % W;
-                                                       double val = reluBack 
&& avals[j] < 0 ? 0 : avals[j];
-                                                       update(val, maxVal, 
maxIx, h, w, padh, padw, strideh, stridew, P, Q, R, S, W);
-                                                       //memoize last seen 
index
-                                                       lastix = aix[j];
-                                               }
-                                               //handle skipped zero values at 
end of row
-                                               update0(lastix+1, (c+1)*HW, 
maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
-                                       }
-                                       else {
-                                               //handle empty row
-                                               Arrays.fill(maxVal, 0);
-                                               for(int p = 0, ix=0; p < P; 
p++) {
-                                                       int h = 
Math.max(-padh+p*strideh, 0);
-                                                       for(int q = 0; q < Q; 
q++, ix++) {
-                                                               int w = 
Math.max(-padw+q*stridew, 0);
-                                                               maxIx[ix] = h * 
W + w;
-                                                       }
-                                               }
-                                       }
+                                       maxpoolingForward(maxVal, maxIx, n, c,
+                                               padh, padw, strideh, stridew, 
C, P, Q, R, S, HW, W);
                                        
                                        //step 2: perform maxpooling backward
-                                       for (int pq = 0; pq < PQ; pq++)
-                                               out[ outOffset + maxIx[pq] ] += 
dout[ doutOffset + pq ];
+                                       maxpoolingBackward(maxIx, outOffset, n, 
c, C, Q, PQ, CPQ);
                                }
                        }
                        //thread-local nnz maintenance
                        return output.recomputeNonZeros(_rl, _ru-1);
                }
                
+               protected void maxpoolingForward(double[] maxVal, int[] maxIx, 
int n, int c, int padh, int padw, int strideh, int stridew, int C, int P, int 
Q, int R, int S, int HW, int W) {
+                       SparseBlock sblock = _params.input1.getSparseBlock();
+                       if( !sblock.isEmpty(n) ) {
+                               Arrays.fill(maxVal, -Double.MAX_VALUE);
+                               int apos = sblock.pos(n);
+                               int alen = sblock.size(n);
+                               int[] aix = sblock.indexes(n);
+                               double[] avals = sblock.values(n);
+                               //find channel start and end, w/ robustness for 
non-existing entries
+                               int cpos = (c==0) ? 0 : sblock.posFIndexGTE(n, 
c*HW);
+                               int cpos2 = (c+1==C) ? alen : 
sblock.posFIndexGTE(n, (c+1)*HW);
+                               cpos = (cpos>=0) ? cpos : alen;
+                               cpos2 = (cpos2>=0) ? cpos2 : alen;
+                               int lastix = c*HW-1;
+                               for(int j=apos+cpos; j<apos+cpos2; j++) {
+                                       //handle skipped zero values
+                                       update0(lastix+1, aix[j], maxVal, 
maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+                                       //handle current non-zero value
+                                       int h = (aix[j] % HW) / W;
+                                       int w = aix[j] % W;
+                                       double val = reluBack && avals[j] < 0 ? 
0 : avals[j];
+                                       update(val, maxVal, maxIx, h, w, padh, 
padw, strideh, stridew, P, Q, R, S, W);
+                                       //memoize last seen index
+                                       lastix = aix[j];
+                               }
+                               //handle skipped zero values at end of row
+                               update0(lastix+1, (c+1)*HW, maxVal, maxIx, 
padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+                       }
+                       else {
+                               //handle empty row
+                               Arrays.fill(maxVal, 0);
+                               for(int p = 0, ix=0; p < P; p++) {
+                                       int h = Math.max(-padh+p*strideh, 0);
+                                       for(int q = 0; q < Q; q++, ix++) {
+                                               int w = 
Math.max(-padw+q*stridew, 0);
+                                               maxIx[ix] = h * W + w;
+                                       }
+                               }
+                       }
+               }
+               
+               protected void maxpoolingBackward(int[] maxIx, int outOffset, 
int n, int c, int C, int Q, int PQ, int CPQ) {
+                       double[] dout = doutput.getDenseBlock();
+                       double[] out = output.getDenseBlock();
+                       final int doutOffset = n*CPQ + c*PQ;
+                       for( int pq = 0; pq < PQ; pq++ )
+                               out[ outOffset + maxIx[pq] ] += dout[ 
doutOffset + pq ];
+               }
+               
                private static void update0(int lix, int uix, double[] maxVal, 
int[] maxIx, int padh, int padw, int strideh, int stridew, int P, int Q, int R, 
int S, int HW, int W) {
                        //TODO exploit constant value and overlap for potential 
early abort
                        for(int i = lix; i<uix; i++)
@@ -249,20 +263,10 @@ public class LibMatrixDNNPoolingBackwardHelper {
        /**
         * Performs the maxpooling backward operation for sparse input and 
sparse error (dout)
         */
-       public static class PoolingBackwardSparseSparse implements 
Callable<Long> 
+       public static class PoolingBackwardSparseSparse extends 
PoolingBackwardSparseDense
        {
-               public int _rl; public int _ru; 
-               private final ConvolutionParameters _params; 
-               MatrixBlock output;
-               boolean performReluBackward;
-               int C; int CHW; int P; int Q; int HW; 
-               public PoolingBackwardSparseSparse(int rl, int ru, 
ConvolutionParameters params, boolean performReluBackward) {
-                       _rl = rl; _ru = ru;
-                       _params = params;
-                       this.performReluBackward = performReluBackward;
-                       output = params.output;
-                       C = params.C; CHW = params.C*params.H*params.W; HW = 
params.H*params.W;
-                       P = params.P; Q = params.Q;
+               public PoolingBackwardSparseSparse(int rl, int ru, 
ConvolutionParameters params, boolean relu) {
+                       super(rl, ru, params, relu, params.input2, 
params.output);
                        if (output.getDenseBlock() == null )
                                throw new RuntimeException("Incorrect usage: 
empty outputs");
                        if (!params.input1.isInSparseFormat() || 
!params.input2.isInSparseFormat())
@@ -270,29 +274,26 @@ public class LibMatrixDNNPoolingBackwardHelper {
                }
                
                @Override
-               public Long call() throws Exception {
+               protected void maxpoolingBackward(int[] maxIx, int outOffset, 
int n, int c, int C, int Q, int PQ, int CPQ) {
+                       SparseBlock sblock = doutput.getSparseBlock();
                        double[] out = output.getDenseBlock();
-                       for(int n = _rl; n < _ru; n++)  {
-                               if( !_params.input2.sparseBlock.isEmpty(n) ) {
-                                       int [] tensorIndexes = new int[3];
-                                       int apos = 
_params.input2.sparseBlock.pos(n);
-                                       int alen = 
_params.input2.sparseBlock.size(n);
-                                       int[] aix = 
_params.input2.sparseBlock.indexes(n);
-                                       double[] avals = 
_params.input2.sparseBlock.values(n);
-                                       for(int j = apos; j < apos+alen; j++) {
-                                               
LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, P, Q);
-                                               int c = tensorIndexes[0];
-                                               int p = tensorIndexes[1];
-                                               int q = tensorIndexes[2];
-                                               final int inputOffset = n*CHW + 
c*HW;
-                                               int maxIndex = 
LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, 
_params, performReluBackward);
-                                               if(maxIndex != -1)
-                                                       out[maxIndex] += 
avals[j];
-                                       }
-                               }
+                       if( sblock.isEmpty(n) )
+                               return;
+                       int apos = sblock.pos(n);
+                       int alen = sblock.size(n);
+                       int[] aix = sblock.indexes(n);
+                       double[] avals = sblock.values(n);
+                       //find channel start and end, w/ robustness for 
non-existing entries
+                       int cpos = (c==0) ? 0 : sblock.posFIndexGTE(n, c*PQ);
+                       int cpos2 = (c+1==C) ? alen : sblock.posFIndexGTE(n, 
(c+1)*PQ);
+                       cpos = (cpos>=0) ? cpos : alen;
+                       cpos2 = (cpos2>=0) ? cpos2 : alen;
+                       for(int j = apos+cpos; j<apos+cpos2; j++) {
+                               int p = (aix[j] % PQ) / Q;
+                               int q = aix[j] % Q;
+                               int pq = p * Q + q;
+                               out[ outOffset + maxIx[pq] ] += avals[j];
                        }
-                       //thread-local nnz maintenance
-                       return output.recomputeNonZeros(_rl, _ru-1);
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/6992c389/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
index ca3a19a..2c661ec 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolBackwardTest.java
@@ -120,6 +120,18 @@ public class PoolBackwardTest extends AutomatedTestBase
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true, false);
        }
        
+       @Test
+       public void testMaxPool2DBackwardSparse11() {
+               int numImg = 10; int imgSize = 28; int numChannels = 1;  int 
stride = 1; int pad = 0; int poolSize1 = 5; int poolSize2 = 5;
+               runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true, true);
+       }
+       
+       @Test
+       public void testMaxPool2DBackwardSparse12() {
+               int numImg = 10; int imgSize = 28; int numChannels = 4;  int 
stride = 1; int pad = 0; int poolSize1 = 5; int poolSize2 = 5;
+               runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true, true);
+       }
+       
        public void runPoolTest( ExecType et, int imgSize, int numImg, int 
numChannels, int stride, 
                        int pad, int poolSize1, int poolSize2, String poolMode, 
boolean sparse1, boolean sparse2) 
        {

Reply via email to