[SYSTEMML-540] Avoid redundant computation of cudnnPoolingForward in 
max_pool_backward

- If the max_pool is invoked in the forward pass, then its output can be
  reused by the max_pool_backward rather than calling cudnnPoolingForward
  again. For sentence CNN with 2 epochs, this reduces the time for
  max_pool_backward from 6.361 to 2.966 seconds.

Closes #691.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/06d5bb07
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/06d5bb07
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/06d5bb07

Branch: refs/heads/master
Commit: 06d5bb073792345f7c4b7ecd0fb4454a335cc421
Parents: 118e3c0
Author: Niketan Pansare <[email protected]>
Authored: Sat Oct 28 13:44:37 2017 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Sat Oct 28 13:45:52 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 163 +++++++++++++------
 .../gpu/ConvolutionGPUInstruction.java          |  43 ++++-
 .../runtime/matrix/data/LibMatrixCuDNN.java     |  51 +++---
 .../sysml/test/gpu/NeuralNetworkOpTests.java    |  82 ++++++++++
 4 files changed, 260 insertions(+), 79 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 50a7ca3..16a8b63 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -47,14 +47,23 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
        private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = 
true;
        // 
-------------------------------------------------------------------------
        
+       // Specifies the type of this hop
        private Hop.ConvOp op;
-
        private int _maxNumThreads = -1; //-1 for unlimited
 
        private ConvolutionOp() {
                //default constructor for clone
        }
 
+       /**
+        * Create a hop from the builtin expression
+        * 
+        * @param l name of the hop
+        * @param dt datatype (only supports matrix datatype)
+        * @param vt valuetype  (only supports matrix valuetype) 
+        * @param o type of this hop
+        * @param inp input hops
+        */
        public ConvolutionOp(String l, DataType dt, ValueType vt, ConvOp o, 
ArrayList<Hop> inp) 
        {
                super(l, dt, vt);
@@ -75,8 +84,7 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                HopsException.check(_input.size() >= 1, this, "should have at 
least one input but has %d inputs", _input.size());
        }
 
-       public ConvOp getOp()
-       {
+       public ConvOp getOp() {
                return op;
        }
        
@@ -163,77 +171,129 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                return input instanceof ConvolutionOp && ((ConvolutionOp) 
input).getOp() == ConvOp.DIRECT_CONV2D;
        }
        
+       /**
+        * Compares the input parameters for max_pool/max_pool_backward 
operations
+        * 
+        * @return true if the following parameters match: stride=[stride, 
stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize, 
imgSize], pool_size=[poolSize1, poolSize2]
+        */
+       private static boolean 
isPoolingParametersEqualAndKnown(ConvolutionParameters param1, 
ConvolutionParameters param2) {
+               return isEqualAndKnown(param1.stride_h, param2.stride_h) && 
isEqualAndKnown(param1.stride_w, param2.stride_w) && 
+                       isEqualAndKnown(param1.pad_h, param2.pad_h) && 
isEqualAndKnown(param1.pad_w, param2.pad_w) &&
+                       isEqualAndKnown(param1.R, param2.R) && 
isEqualAndKnown(param1.S, param2.S) &&
+                       isEqualAndKnown(param1.N, param2.N) && 
isEqualAndKnown(param1.C, param2.C) &&
+                       isEqualAndKnown(param1.H, param2.H) && 
isEqualAndKnown(param1.W, param2.W);
+       }
+       
+       private static boolean isEqualAndKnown(int val1, int val2) {
+               return val1 >= 0 && val2 >= 0 && val1 == val2;
+       }
+       
+       /**
+        * Returns the output lop of maxpool operation with same parameters as 
this hop.
+        * If corresponding output lop is not found or if this is not a 
max_pool_backward operation, this function returns null
+        * 
+        * @return output lop of maxpool operation with same parameters as this 
hop
+        * @throws HopsException if error 
+        * @throws LopsException if error
+        */
+       private Lop getMaxPoolOutputLop() throws HopsException, LopsException {
+               if(op != ConvOp.MAX_POOLING_BACKWARD)
+                       return null;
+               
+               Hop inputImage = getInput().get(0);
+               for(Hop tmpParent : inputImage.getParent()) {
+                       if(!(tmpParent instanceof ConvolutionOp))
+                               continue;
+                       ConvolutionOp parent = (ConvolutionOp) tmpParent;
+                       if(parent.getOp() == ConvOp.MAX_POOLING && 
isPoolingParametersEqualAndKnown(parent._cachedParams, _cachedParams)) {
+                               return parent.constructLops();
+                       }
+               }
+               return null;
+       }
+       
        public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) 
throws HopsException, LopsException {
                if(inputs.size() != getNumExpectedInputs()) 
                        throw new HopsException("Incorrect number of inputs for 
" + op.name());
                
-               Lop in = null; Lop in2 = null;
-               ArrayList<Hop> inputs1 = inputs;
-               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+               // 
---------------------------------------------------------------
+               // Deal with fused operators and construct 
lhsInputLop/optionalRhsInputLop
+               Lop lhsInputLop = null; Lop optionalRhsInputLop = null;
+               ArrayList<Hop> inputsOfPotentiallyFusedOp = inputs;
                OperationTypes lopOp = HopsConv2Lops.get(op);
-
+               
                // RELU_MAX_POOLING and RELU_MAX_POOLING_BACKWARD is extremely 
useful for CP backend 
                // by reducing unnecessary sparse-to-dense-to-sparse conversion.
                // For other backends, this operators is not necessary as it 
reduces an additional relu operator.
                if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && 
op == ConvOp.MAX_POOLING && isInputReLU(inputs.get(0))) {
-                       in = inputs.get(0).getInput().get(0).constructLops();
+                       lhsInputLop = 
inputs.get(0).getInput().get(0).constructLops();
                        lopOp = OperationTypes.RELU_MAX_POOLING;
                }
                else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == 
ExecType.CP && op == ConvOp.MAX_POOLING_BACKWARD && isInputReLU(inputs.get(0))) 
{
-                       in = inputs.get(0).getInput().get(0).constructLops();
+                       lhsInputLop = 
inputs.get(0).getInput().get(0).constructLops();
                        lopOp = OperationTypes.RELU_MAX_POOLING_BACKWARD;
                }
                else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && op == 
ConvOp.BIAS_ADD && isInputConv2d(inputs.get(0))) {
                        lopOp = OperationTypes.DIRECT_CONV2D_BIAS_ADD;
                        
                        // the first lop is image 
-                       in = inputs.get(0).getInput().get(0).constructLops();
+                       lhsInputLop = 
inputs.get(0).getInput().get(0).constructLops();
                        // the second lop is bias
-                       in2 = inputs.get(1).constructLops();
+                       optionalRhsInputLop = inputs.get(1).constructLops();
                        
                        // Use the inputs from conv2d rather than bias_add
-                       inputs1 = inputs.get(0).getInput();
+                       inputsOfPotentiallyFusedOp = inputs.get(0).getInput();
                }
                else {
-                       in = inputs.get(0).constructLops();
+                       lhsInputLop = inputs.get(0).constructLops();
                }
+               // 
---------------------------------------------------------------
                
-//             // TODO: Inserting reblock requires knowing columns apriori
-//             ConvolutionTransform transform1 = new 
ConvolutionTransform(addReblockIfNecessary(et, lopOp, in), lopOp, 
getDataType(), getValueType(), et, k);
-//             setReblockedOutputDimension(et, transform1);
-               double cpIntermediateMemEstimate = 
computeIntermediateMemEstimate(-1, -1, -1 );
+               // 
---------------------------------------------------------------
+               // Compute intermediate memory budget that can be passed to GPU 
operators 
+               // for better CuDNN operator selection at runtime
+               double intermediateMemEstimate = 
computeIntermediateMemEstimate(-1, -1, -1 );
                if(et == ExecType.GPU && _dim1 > 0 && _dim2 > 0) {
                        // This enables us to compile more efficient 
matrix-matrix CuDNN operation instead of 
                        // row-by-row invocation of multiple vector-matrix 
CuDNN operations.
                        // This is possible as the operations on GPU are 
single-threaded
                        double optimisticIntermediateMemEstimate = 
GPUContextPool.initialGPUMemBudget() - getOutputMemEstimate() - 
inputs.get(0).getOutputMemEstimate();
-                       if(in2 != null) {
+                       if(optionalRhsInputLop != null) {
                                optimisticIntermediateMemEstimate -= 
inputs.get(1).getOutputMemEstimate();
                        }
-                       cpIntermediateMemEstimate = 
Math.max(cpIntermediateMemEstimate, optimisticIntermediateMemEstimate);
+                       intermediateMemEstimate = 
Math.max(intermediateMemEstimate, optimisticIntermediateMemEstimate);
                }
-               ConvolutionTransform transform1 = new ConvolutionTransform(in, 
lopOp, getDataType(), getValueType(), et, k, cpIntermediateMemEstimate);
-               setOutputDimensions(transform1);
+               // 
---------------------------------------------------------------
                
-               setLineNumbers(transform1);
-               in.addOutput(transform1);
+               // Construct the lop
+               ConvolutionTransform convolutionLop = new 
ConvolutionTransform(lhsInputLop, lopOp, 
+                               getDataType(), getValueType(), et, 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), 
intermediateMemEstimate);
                
-               if(in2 != null) {
-                       transform1.addInput(in2);
-                       in2.addOutput(transform1);
-               }
+               // Propagate the output dimensions and the line number of 
ConvolutionOp to ConvolutionTransform
+               setOutputDimensions(convolutionLop);
+               setLineNumbers(convolutionLop);
                
-               // stride1, stride2, padding1, padding2  
-               // input_shape1, input_shape2, input_shape3, input_shape4, 
-               // filter_shape1, filter_shape2, filter_shape3, filter_shape4
-               for( int i=1; i < inputs1.size(); i++ )
-               {
-                       Lop ltmp = inputs1.get(i).constructLops();
-                       transform1.addInput(ltmp);
-                       ltmp.addOutput(transform1);
+               // 
---------------------------------------------------------------
+               // Add input/output for parent lops of convolutionLop
+               lhsInputLop.addOutput(convolutionLop);
+               if(optionalRhsInputLop != null) {
+                       convolutionLop.addInput(optionalRhsInputLop);
+                       optionalRhsInputLop.addOutput(convolutionLop);
+               }
+               for( int i=1; i < inputsOfPotentiallyFusedOp.size(); i++ ) {
+                       Lop ltmp = 
inputsOfPotentiallyFusedOp.get(i).constructLops();
+                       convolutionLop.addInput(ltmp);
+                       ltmp.addOutput(convolutionLop);
                }
-               transform1.setLevel(); //force order of added lops
-               return transform1;
+               // Only valid for MAX_POOLING_BACKWARD on GPU
+               Lop optionalMaxPoolOutput = (et == ExecType.GPU) ? 
getMaxPoolOutputLop() : null; 
+               if(optionalMaxPoolOutput != null) {
+                       convolutionLop.addInput(optionalMaxPoolOutput);
+                       optionalMaxPoolOutput.addOutput(convolutionLop);
+               }
+               convolutionLop.setLevel(); //force order of added lops
+               // 
---------------------------------------------------------------
+               return convolutionLop;
        }
 
                        
@@ -453,12 +513,10 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                
                ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? 
ExecType.SPARK : ExecType.MR;
                
-               if( _etypeForced != null )                      
-               {
+               if( _etypeForced != null ) {
                        _etype = _etypeForced;
                }
-               else 
-               {       
+               else {  
                        if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
                                _etype = findExecTypeByMemEstimate();
                        }
@@ -479,8 +537,9 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                return _etype;
        }
        
-       // Caching parameters speed-ups dynamic recompilation time by avoiding 
unnecessary computeSizeInformation
+       // Parameters recomputed in refreshSizeInformation and passed across 
many calls of getDim
        private ConvolutionParameters _cachedParams = new 
ConvolutionParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
_maxNumThreads);
+       
        // stride1, stride2, padding1, padding2  
        // input_shape1, input_shape2, input_shape3, input_shape4, 
        // filter_shape1, filter_shape2, filter_shape3, filter_shape4
@@ -494,16 +553,16 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                        imageHeightHop = getInput().get(8);
                        filterHeightHop = getInput().get(12);
                        _cachedParams.setIfUnknown(
-                                       getInput().get(6),
-                                       getInput().get(7), 
-                                       imageHeightHop, 
-                                       getInput().get(9), 
-                                       getInput().get(10), 
-                                       filterHeightHop, 
-                                       getInput().get(13), 
-                                       getInput().get(2), 
-                                       getInput().get(3), 
-                                       getInput().get(4), 
+                                       getInput().get(6),  // N
+                                       getInput().get(7),  // C
+                                       imageHeightHop,     // H
+                                       getInput().get(9),  // W
+                                       getInput().get(10), // K
+                                       filterHeightHop,    // R
+                                       getInput().get(13), // S
+                                       getInput().get(2),  // stride_h
+                                       getInput().get(3),  // stride_w
+                                       getInput().get(4),  // pad_h
                                        getInput().get(5), _maxNumThreads);
                }
                else {

http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index 354ea63..8565b5a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -92,8 +92,7 @@ public class ConvolutionGPUInstruction extends GPUInstruction 
{
                
                if( ( opcode.equalsIgnoreCase("conv2d")
                         || opcode.equalsIgnoreCase("conv2d_backward_filter")
-                        || opcode.equalsIgnoreCase("conv2d_backward_data")
-                        || opcode.equalsIgnoreCase("maxpooling_backward")) ) {
+                        || opcode.equalsIgnoreCase("conv2d_backward_data")) ) {
                        InstructionUtils.checkNumFields(parts, 16);
                        CPOperand in1 = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
@@ -119,6 +118,39 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        return new ConvolutionGPUInstruction(in1, in2, out, 
opcode, str, stride,
                                        padding, input_shape, filter_shape, 
Double.parseDouble(parts[16]));
                }
+               else if( opcode.equalsIgnoreCase("maxpooling_backward") ) {
+                       boolean withMaxPoolOut = false;
+                       if(parts.length == 18) {
+                               withMaxPoolOut = true;
+                       }
+                       else
+                               InstructionUtils.checkNumFields(parts, 16);
+                       CPOperand in1 = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = withMaxPoolOut ? new 
CPOperand(parts[15]) : null;
+                       CPOperand out = withMaxPoolOut ? new 
CPOperand(parts[16]) : new CPOperand(parts[15]);
+                       double memBudget = withMaxPoolOut ? 
Double.parseDouble(parts[17]) : Double.parseDouble(parts[16]);
+               
+                       ArrayList<CPOperand> stride = new ArrayList<>();
+                       ArrayList<CPOperand> padding = new ArrayList<>();
+                       ArrayList<CPOperand> input_shape = new ArrayList<>();
+                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
+                       stride.add(new CPOperand(parts[3]));
+                       stride.add(new CPOperand(parts[4]));
+                       padding.add(new CPOperand(parts[5]));
+                       padding.add(new CPOperand(parts[6]));
+                       input_shape.add(new CPOperand(parts[7]));
+                       input_shape.add(new CPOperand(parts[8]));
+                       input_shape.add(new CPOperand(parts[9]));
+                       input_shape.add(new CPOperand(parts[10]));
+                       filter_shape.add(new CPOperand(parts[11]));
+                       filter_shape.add(new CPOperand(parts[12]));
+                       filter_shape.add(new CPOperand(parts[13]));
+                       filter_shape.add(new CPOperand(parts[14]));
+
+                       return new ConvolutionGPUInstruction(in1, in2, in3, 
out, opcode, str, stride,
+                                       padding, input_shape, filter_shape, 
memBudget);
+               }
                else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
                        InstructionUtils.checkNumFields(parts, 17);
                        CPOperand in1 = new CPOperand(parts[1]);
@@ -324,7 +356,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
                        MatrixObject image = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
                        MatrixObject dout = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-                       
+                       MatrixObject maxPoolOutput = _input3 != null ? 
getMatrixInputForGPUInstruction(ec, _input3.getName()) : null;
                        if(dout.getNumRows() != N || dout.getNumColumns() != 
C*P*Q) 
                                throw new DMLRuntimeException("Incorrect 
dimensions for dout in maxpooling_backward");
                        if(image.getNumRows() != N || image.getNumColumns() != 
C*H*W) 
@@ -333,7 +365,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                        
                        MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
                        
-                       LibMatrixCuDNN.maxpoolingBackward(ec.getGPUContext(0), 
getExtendedOpcode(), image, dout, out, N, C, H, W,
+                       LibMatrixCuDNN.maxpoolingBackward(ec.getGPUContext(0), 
getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
                                        K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q, _intermediateMemoryBudget);
                }
                else {
@@ -346,7 +378,8 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction {
                if ( !instOpcode.equalsIgnoreCase("maxpooling") )
                        
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
 
-               if (instOpcode.equalsIgnoreCase("conv2d_bias_add"))
+               if (instOpcode.equalsIgnoreCase("conv2d_bias_add") || 
+                       (instOpcode.equalsIgnoreCase("maxpooling_backward") && 
_input3 != null))
                        
ec.releaseMatrixInputForGPUInstruction(_input3.getName());
 
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());

http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
index 7fd766c..e0a6a57 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCuDNN.java
@@ -519,6 +519,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @param instName the invoking instruction's name for record {@link 
Statistics}.
         * @param image image as matrix object
         * @param dout                  delta matrix, output of previous layer
+        * @param maxpoolOutput (optional and can be null) output of maxpool 
forward function
         * @param outputBlock output matrix
         * @param N                             batch size
         * @param C                             number of channels
@@ -537,12 +538,14 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
        public static void maxpoolingBackward(GPUContext gCtx, String instName, 
MatrixObject image, MatrixObject dout,
-                       MatrixObject outputBlock, int N, int C, int H, int W, 
int K, int R,
+                       MatrixObject maxpoolOutput, MatrixObject outputBlock, 
int N, int C, int H, int W, int K, int R,
                        int S, int pad_h, int pad_w, int stride_h, int 
stride_w, int P,
                        int Q, double intermediateMemoryBudget) throws 
DMLRuntimeException {
                long CHW = C*H*W; long CPQ = C*P*Q;  
                long NCHW = N*CHW; long NCPQ = N*CPQ; 
 
+               final boolean isMaxPoolOutputProvided = maxpoolOutput != null;
+               
                if(NCHW < maxNumElementsOfCuDNNTensor && NCPQ < 
maxNumElementsOfCuDNNTensor) {
                        // Filter and output are accounted as dense in the 
memory estimation for conv2dBackwardData
                        long overhead = isInSparseFormat(gCtx, image) ? 
OptimizerUtils.estimateSizeExactSparsity(N, CHW, 1.0) : 0;
@@ -551,19 +554,26 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                        if(overhead <= intermediateMemoryBudget) {
                                Pointer x = getDensePointerForCuDNN(gCtx, 
image, instName);
                                Pointer dy = getDensePointerForCuDNN(gCtx, 
dout, instName);
-                               cudnnMaxpoolingBackward(gCtx, instName, x, dy, 
dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
+                               Pointer y = isMaxPoolOutputProvided ? 
getDensePointerForCuDNN(gCtx, maxpoolOutput, instName) : null;
+                               cudnnMaxpoolingBackward(gCtx, instName, x, dy, 
y, dx, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q);
                        }
                        else {
                                LibMatrixCuDNNInputRowFetcher imgFetcher = new 
LibMatrixCuDNNInputRowFetcher(gCtx, instName, image);
                                LibMatrixCuDNNInputRowFetcher doutFetcher = new 
LibMatrixCuDNNInputRowFetcher(gCtx, instName, dout);
+                               LibMatrixCuDNNInputRowFetcher maxPoolOutFetcher 
= isMaxPoolOutputProvided ? new LibMatrixCuDNNInputRowFetcher(gCtx, instName, 
maxpoolOutput) : null;
                                for(int n = 0; n < N; n++) {
-                                       cudnnMaxpoolingBackward(gCtx, instName, 
imgFetcher.getNthRow(n), doutFetcher.getNthRow(n), 
+                                       Pointer x = imgFetcher.getNthRow(n);
+                                       Pointer dy = doutFetcher.getNthRow(n);
+                                       Pointer y = isMaxPoolOutputProvided ? 
maxPoolOutFetcher.getNthRow(n) : null;
+                                       cudnnMaxpoolingBackward(gCtx, instName, 
x, dy, y, 
                                                        
dx.withByteOffset(n*CHW*sizeOfDataType), 
                                                        1, C, H, W, K, R, S, 
pad_h, pad_w, stride_h, stride_w, P, Q);
                                }
                                // Deallocate temporary array to hold one 
element of input
                                imgFetcher.close();
                                doutFetcher.close();
+                               if(isMaxPoolOutputProvided)
+                                       maxPoolOutFetcher.close();
                        }
                }
                else {
@@ -572,36 +582,33 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
        }
        
        private static void cudnnMaxpoolingBackward(GPUContext gCtx, String 
instName, 
-                       Pointer x, Pointer dy, Pointer dx, 
+                       Pointer x, Pointer dy, Pointer y, Pointer dx, 
                        int N, int C, int H, int W, int K, int R,
                        int S, int pad_h, int pad_w, int stride_h, int 
stride_w, int P,
                        int Q) throws DMLRuntimeException {
                if(LOG.isTraceEnabled()) {
                        LOG.trace("GPU : maxpoolingBackward" + ", GPUContext=" 
+ gCtx);
                }
-               Pointer y = null;
+               
+               boolean isMaxPoolOutputProvided = (y != null);
 
                try(LibMatrixCuDNNPoolingDescriptors desc = 
                                
LibMatrixCuDNNPoolingDescriptors.cudnnMaxpoolingBackwardDescriptors(gCtx, 
instName, N, C, H, W, K, R, S, 
                                                pad_h, pad_w, stride_h, 
stride_w, P, Q)) {
                        long t1=0, t2=0, t3=0;
-                       if (GPUStatistics.DISPLAY_STATISTICS) t1 = 
System.nanoTime();
-                       
-                       // Calling PoolForward first, y is one of the inputs 
for poolBackward
-                       // TODO: Remove calling poolForward after necessary 
changes at language level for poolBackward
-                       long numBytes = N*C*P*Q*sizeOfDataType;
-                       y = gCtx.allocate(numBytes);
-                       
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
-                       
-                       if (GPUStatistics.DISPLAY_STATISTICS) t2 = 
System.nanoTime();
-                       int status = cudnnPoolingForward(getCudnnHandle(gCtx), 
desc.poolingDesc, one(), desc.xDesc, x, zero(), desc.yDesc, y);
-                       if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
-
-                       if(status != 
jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
-                               throw new DMLRuntimeException("Could not 
executed cudnnPoolingForward before cudnnPoolingBackward: " + 
jcuda.jcudnn.cudnnStatus.stringFor(status));
+                       int status;
+                       if(!isMaxPoolOutputProvided) {
+                               if (GPUStatistics.DISPLAY_STATISTICS) t1 = 
System.nanoTime();
+                               long numBytes = N*C*P*Q*sizeOfDataType;
+                               y = gCtx.allocate(numBytes);
+                               if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_INIT, System.nanoTime() - t1);
+                               if (GPUStatistics.DISPLAY_STATISTICS) t2 = 
System.nanoTime();
+                               status = 
cudnnPoolingForward(getCudnnHandle(gCtx), desc.poolingDesc, one(), desc.xDesc, 
x, zero(), desc.yDesc, y);
+                               if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_MAXPOOLING_FORWARD_LIB, System.nanoTime() - t2);
+                               if(status != 
jcuda.jcudnn.cudnnStatus.CUDNN_STATUS_SUCCESS) {
+                                       throw new DMLRuntimeException("Could 
not executed cudnnPoolingForward before cudnnPoolingBackward: " + 
jcuda.jcudnn.cudnnStatus.stringFor(status));
+                               }
                        }
-
                        if (GPUStatistics.DISPLAY_STATISTICS) t3 = 
System.nanoTime();
                        status = cudnnPoolingBackward(getCudnnHandle(gCtx), 
desc.poolingDesc, one(), desc.yDesc, y, desc.dyDesc, dy, desc.xDesc, x, zero(), 
desc.dxDesc, dx);
                        if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_MAXPOOLING_BACKWARD_LIB, System.nanoTime() - t3);
@@ -615,7 +622,7 @@ public class LibMatrixCuDNN extends LibMatrixCUDA {
                finally {
                        long t4=0;
                        if (GPUStatistics.DISPLAY_STATISTICS) t4 = 
System.nanoTime();
-                       if(y != null)
+                       if(!isMaxPoolOutputProvided)
                                gCtx.cudaFreeHelper(instName, y);
                        if (GPUStatistics.DISPLAY_STATISTICS) 
GPUStatistics.maintainCPMiscTimes(instName, 
GPUInstruction.MISC_TIMER_CUDNN_CLEANUP, System.nanoTime() - t4);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/06d5bb07/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java 
b/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
index aba0cae..c57e997 100644
--- a/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
+++ b/src/test/java/org/apache/sysml/test/gpu/NeuralNetworkOpTests.java
@@ -579,5 +579,87 @@ public class NeuralNetworkOpTests extends GPUTests {
                        }
                }
        }
+       
+       
+       @Test
+       @Ignore
+       public void testMaxPoolBackwardWithMaxpoolOut() {
+               String scriptStr = "tmp = max_pool(image, padding=[padH, padW], 
stride=[strideH, strideW], input_shape=[N,C,H,W], pool_size=[R,S]); 
print(sum(tmp)); O = max_pool_backward(image, dout, padding=[padH, padW], 
stride=[strideH, strideW], input_shape=[N,C,H,W], pool_size=[R,S])";
+
+               for (long N : Nlst) {
+                       for (long C : Clst) {
+                               for (long H : Hlst) {
+                                       long W = H;
+                                       for (long R : Rlst) {
+                                               long S = R;
+                                               for (long strideH : strideLst) {
+                                                       long strideW = strideH;
+                                                       for (long padH : 
padLst) {
+                                                               long padW = 
padH;
+                                                               for (double 
sparsity : sparsitylst) {
+
+                                                                       // pool 
is smaller than image + padding
+                                                                       if (R > 
(H + padH) || S > (W + padW))
+                                                                               
continue;
+
+                                                                       // Make 
sure ops fit in GPU memory and within constraints of cudnn
+                                                                       long 
imageSize = N * C * H * W * 8l;
+                                                                       if 
(imageSize > MAX_OP_SIZE)  // image size
+                                                                               
continue;
+                                                                       long 
poolSize = R * S * 8l;
+                                                                       if 
(poolSize > MAX_OP_SIZE)  // filter size
+                                                                               
continue;
+
+                                                                       int P = 
(int) ConvolutionUtils.getP(H, R, strideH, padH);
+                                                                       int Q = 
(int) ConvolutionUtils.getQ(W, S, strideW, padW);
+
+                                                                       long 
doutSize = N * C * P * Q * 8l;
+                                                                       if 
(doutSize > MAX_OP_SIZE) // dout/output size
+                                                                               
continue;
+
+                                                                       double 
imageSizeInMB = imageSize / (1024.0 * 1024.0);
+                                                                       double 
poolSizeInMB = poolSize / (1024.0 * 1024.0);
+                                                                       double 
doutSizeInMB = doutSize / (1024.0 * 1024.0);
+                                                                       
System.out
+                                                                       
.format("max_pool_backward, image[%d,%d,%d,%d](%.1fMB), pool[%d,%d](%.1f), 
dout[%d,%d,%d,%d](%.1fMB), stride[%d,%d], padding[%d,%d]",
+                                                                               
        N, C, H, W, imageSizeInMB, R, S, poolSizeInMB, N, C,
+                                                                               
        P, Q, doutSizeInMB, strideH, strideW, padH, padW);
+
+                                                                       Matrix 
image = generateInputMatrix(spark, (int) N,
+                                                                               
        (int) (C * H * W), -127.0, 127, sparsity, seed, true);
+                                                                       Matrix 
dout = generateInputMatrix(spark, (int) N, (int) (C * P * Q),
+                                                                               
        -127.0, 127, sparsity, seed, true);
+                                                                       
HashMap<String, Object> inputs = new HashMap<>();
+                                                                       
inputs.put("N", N);
+                                                                       
inputs.put("C", C);
+                                                                       
inputs.put("H", H);
+                                                                       
inputs.put("W", W);
+                                                                       
inputs.put("R", R);
+                                                                       
inputs.put("S", S);
+                                                                       
inputs.put("strideH", strideH);
+                                                                       
inputs.put("strideW", strideW);
+                                                                       
inputs.put("padH", padH);
+                                                                       
inputs.put("padW", padW);
+                                                                       
inputs.put("image", image);
+                                                                       
inputs.put("dout", dout);
+                                                                       
List<Object> outCPU = runOnCPU(spark, scriptStr, inputs,
+                                                                               
        Arrays.asList("O"));
+                                                                       
List<Object> outGPU = runOnGPU(spark, scriptStr, inputs,
+                                                                               
        Arrays.asList("O"));
+                                                                       
assertHeavyHitterPresent("gpu_maxpooling_backward");
+                                                                       
assertEqualObjects(outCPU.get(0), outGPU.get(0));
+                                                                       
clearGPUMemory();
+                                                               }
+                                                       }
+                                               }
+                                       }
+
+
+
+
+                               }
+                       }
+               }
+       }
 
 }

Reply via email to