http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
new file mode 100644
index 0000000..4532240
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java
@@ -0,0 +1,645 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.cp;
+
+import java.util.ArrayList;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.DnnParameters;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
+import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.DnnUtils;
+import org.apache.sysml.utils.NativeHelper;
+
+public class DnnCPInstruction extends UnaryCPInstruction {
	private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName());
	// One-shot guard flag; presumably gates a single warning about thread under-utilization
	// (its read site is not visible in this chunk — confirm against the rest of the file).
	// NOTE(review): identifier misspells "Utilization"; kept as-is since it may be referenced
	// elsewhere in this file outside the visible region.
	private static boolean warnedUnderUtilitization = false;
	
	// Additional input operands beyond input1 (inherited from UnaryCPInstruction);
	// which of these are non-null depends on the opcode (see the constructors below).
	private final CPOperand _in2;
	private final CPOperand _in3;
	private final CPOperand _in4;
	private final CPOperand _in5;
	private final CPOperand _in6;
	private final CPOperand _in7;
	private final CPOperand _in8;
	// Additional output operands; only used by the batch_norm2d opcodes.
	private final CPOperand _out2;
	private final CPOperand _out3;
	private final CPOperand _out4;
	private final CPOperand _out5;
	// Convolution/pooling geometry operands (null for opcodes that take none).
	private final ArrayList<CPOperand> _input_shape;
	private final ArrayList<CPOperand> _filter_shape;
	private final ArrayList<CPOperand> _stride;
	private final ArrayList<CPOperand> _padding;
	// Degree of parallelism for the CP operator.
	private final int _numThreads;
	// Memory budget (bytes) for intermediates, as planned by the optimizer.
	private final double _intermediateMemoryBudget;
+       
	/**
	 * Primary constructor used (directly or via delegation) by the pooling and
	 * convolution opcodes: stores up to three inputs plus the geometry operands.
	 * The batch-norm-only operands (_in4.._in8, _out2.._out5) are set to null.
	 *
	 * @param in first input (e.g. image)
	 * @param in2 second input or null
	 * @param in3 third input or null
	 * @param out output operand
	 * @param stride stride operands (h, w)
	 * @param padding padding operands (h, w)
	 * @param input_shape input shape operands (N, C, H, W)
	 * @param filter_shape filter shape operands (K, C, R, S)
	 * @param numThreads degree of parallelism
	 * @param intermediateMemoryBudget memory budget for intermediates
	 * @param opcode instruction opcode
	 * @param istr full instruction string
	 */
	public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, 
			ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
			ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget, String opcode, String istr) {
		super(CPType.Dnn, null, in, out, opcode, istr);
		_in2 = in2;
		_in3 = in3;
		_in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null;
		_out2 = null; _out3 = null; _out4 = null; _out5 = null;
		_stride = stride;
		_padding = padding;
		_input_shape = input_shape;
		_filter_shape = filter_shape;
		_numThreads = numThreads;
		_intermediateMemoryBudget = intermediateMemoryBudget;
	}
+       
+       public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, 
String opcode, String istr, int numThreads, double intermediateMemoryBudget) {
+               this(in, in2, null, out, null, null, null, null, numThreads, 
intermediateMemoryBudget, opcode, istr);
+               if( !(opcode.equals("bias_add") || 
opcode.equals("relu_backward") || opcode.equals("bias_multiply") ) ) {
+                       throw new DMLRuntimeException("Incorrect usage. 
Expected the opcode to be bias_add or bias_multiply or relu_backward, but found 
" + opcode);
+               }
+       }
+       
	/**
	 * Constructor for the channel_sums opcode: three inputs (data matrix X,
	 * scalar channel count C, scalar spatial size H*W). Rejects any other opcode.
	 *
	 * @param in input data matrix
	 * @param in2 scalar operand C
	 * @param in3 scalar operand H*W
	 * @param out output operand (C x 1 result)
	 * @param opcode must be channel_sums
	 * @param istr full instruction string
	 * @param numThreads degree of parallelism (unused by channel_sums itself)
	 * @param intermediateMemoryBudget memory budget for intermediates
	 */
	public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) {
		this(in, in2, in3, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr);
		if( !opcode.equals("channel_sums") ) {
			throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
		}
	}
+
	/**
	 * Delegating constructor for single-input geometry opcodes (e.g. maxpooling,
	 * avgpooling): second and third inputs are null.
	 */
	private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, String istr,
			ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
			ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
		this(in, null, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
	}
+       
	/**
	 * Delegating constructor for two-input geometry opcodes (e.g. conv2d,
	 * pooling backward variants): third input is null.
	 */
	public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode,
			String istr, ArrayList<CPOperand> stride,
			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
			ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
		this(in, in2, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
	}
+       
	/**
	 * Delegating constructor for three-input geometry opcodes
	 * (e.g. conv2d_bias_add).
	 */
	public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode,
			String istr, ArrayList<CPOperand> stride,
			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
			ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
		this(in, in2, in3, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
	}
+       
	/**
	 * Constructor for the batch_norm2d and batch_norm2d_backward opcodes:
	 * up to eight inputs and five outputs, no geometry operands, no thread count.
	 * Unused operand slots are passed as null by the caller.
	 *
	 * @param in1..in8 opcode-specific inputs (image, scale, bias, running stats, scalars)
	 * @param out..out5 opcode-specific outputs
	 * @param opcode batch_norm2d or batch_norm2d_backward
	 * @param istr full instruction string
	 * @param intermediateMemoryBudget memory budget for intermediates
	 */
	public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
			CPOperand in6, CPOperand in7, CPOperand in8,
			CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, 
			double intermediateMemoryBudget) throws DMLRuntimeException {
		super(CPType.Dnn, null, in1, out, opcode, istr);
		_in2 = in2;
		_in3 = in3;
		_in4 = in4;
		_in5 = in5;
		_in6 = in6;
		_in7 = in7;
		_in8 = in8;
		_out2 = out2;
		_out3 = out3;
		_out4 = out4;
		_out5 = out5;
		// batch-norm opcodes carry no geometry operands
		_stride = null;
		_padding = null;
		_input_shape = null;
		_filter_shape = null;
		_numThreads = 0;
		_intermediateMemoryBudget = intermediateMemoryBudget;
	}
+
+       public static DnnCPInstruction parseInstruction(String str) {
+
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+               if (opcode.equalsIgnoreCase("maxpooling") || 
opcode.equalsIgnoreCase("relu_maxpooling") ||
+                       opcode.equalsIgnoreCase("avgpooling")) {
+                       InstructionUtils.checkNumFields(parts, 16);
+                       // stride1, stride2, padding1, padding2
+                       // input_shape1, input_shape2, input_shape3, 
input_shape4,
+                       // filter_shape1, filter_shape2, filter_shape3, 
filter_shape4, k
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand out = new CPOperand(parts[14]);
+
+                       ArrayList<CPOperand> stride = new ArrayList<>();
+                       ArrayList<CPOperand> padding = new ArrayList<>();
+                       ArrayList<CPOperand> input_shape = new ArrayList<>();
+                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
+                       stride.add(new CPOperand(parts[2]));
+                       stride.add(new CPOperand(parts[3]));
+                       padding.add(new CPOperand(parts[4]));
+                       padding.add(new CPOperand(parts[5]));
+                       input_shape.add(new CPOperand(parts[6]));
+                       input_shape.add(new CPOperand(parts[7]));
+                       input_shape.add(new CPOperand(parts[8]));
+                       input_shape.add(new CPOperand(parts[9]));
+                       filter_shape.add(new CPOperand(parts[10]));
+                       filter_shape.add(new CPOperand(parts[11]));
+                       filter_shape.add(new CPOperand(parts[12]));
+                       filter_shape.add(new CPOperand(parts[13]));
+                       int k = Integer.parseInt(parts[15]);
+
+                       return new DnnCPInstruction(in, out, opcode, str, 
stride,
+                                       padding, input_shape, filter_shape, k, 
Double.parseDouble(parts[16]));
+               } 
+               else if (opcode.equalsIgnoreCase("maxpooling_backward") || 
opcode.equalsIgnoreCase("relu_maxpooling_backward")
+                               || 
opcode.equalsIgnoreCase("avgpooling_backward")
+                               || opcode.equalsIgnoreCase("conv2d")
+                               || 
opcode.equalsIgnoreCase("conv2d_backward_filter")
+                               || 
opcode.equalsIgnoreCase("conv2d_backward_data")) {
+                       InstructionUtils.checkNumFields(parts, 17);
+                       // dout, stride1, stride2, padding1, padding2
+                       // input_shape1, input_shape2, input_shape3, 
input_shape4,
+                       // filter_shape1, filter_shape2, filter_shape3, 
filter_shape4, k
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand out = new CPOperand(parts[15]);
+
+                       ArrayList<CPOperand> stride = new ArrayList<>();
+                       ArrayList<CPOperand> padding = new ArrayList<>();
+                       ArrayList<CPOperand> input_shape = new ArrayList<>();
+                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
+                       stride.add(new CPOperand(parts[3]));
+                       stride.add(new CPOperand(parts[4]));
+                       padding.add(new CPOperand(parts[5]));
+                       padding.add(new CPOperand(parts[6]));
+                       input_shape.add(new CPOperand(parts[7]));
+                       input_shape.add(new CPOperand(parts[8]));
+                       input_shape.add(new CPOperand(parts[9]));
+                       input_shape.add(new CPOperand(parts[10]));
+                       filter_shape.add(new CPOperand(parts[11]));
+                       filter_shape.add(new CPOperand(parts[12]));
+                       filter_shape.add(new CPOperand(parts[13]));
+                       filter_shape.add(new CPOperand(parts[14]));
+                       int k = Integer.parseInt(parts[16]);
+
+                       return new DnnCPInstruction(in, in2, out, opcode, str, 
stride,
+                                       padding, input_shape, filter_shape, k, 
Double.parseDouble(parts[17]));
+               }
+               else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
+                       InstructionUtils.checkNumFields(parts, 18);
+                       // dout, stride1, stride2, padding1, padding2
+                       // input_shape1, input_shape2, input_shape3, 
input_shape4,
+                       // filter_shape1, filter_shape2, filter_shape3, 
filter_shape4, k
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand out = new CPOperand(parts[16]);
+
+                       ArrayList<CPOperand> stride = new ArrayList<>();
+                       ArrayList<CPOperand> padding = new ArrayList<>();
+                       ArrayList<CPOperand> input_shape = new ArrayList<>();
+                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
+                       stride.add(new CPOperand(parts[4]));
+                       stride.add(new CPOperand(parts[5]));
+                       padding.add(new CPOperand(parts[6]));
+                       padding.add(new CPOperand(parts[7]));
+                       input_shape.add(new CPOperand(parts[8]));
+                       input_shape.add(new CPOperand(parts[9]));
+                       input_shape.add(new CPOperand(parts[10]));
+                       input_shape.add(new CPOperand(parts[11]));
+                       filter_shape.add(new CPOperand(parts[12]));
+                       filter_shape.add(new CPOperand(parts[13]));
+                       filter_shape.add(new CPOperand(parts[14]));
+                       filter_shape.add(new CPOperand(parts[15]));
+                       int k = Integer.parseInt(parts[17]);
+
+                       return new DnnCPInstruction(in, in2, in3, out, opcode, 
str, stride,
+                                       padding, input_shape, filter_shape, k, 
Double.parseDouble(parts[18]));
+               }
+               else if (opcode.equalsIgnoreCase("bias_add") || 
opcode.equals("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) {
+                       InstructionUtils.checkNumFields(parts, 5);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand out = new CPOperand(parts[3]);
+                       int k = Integer.parseInt(parts[4]);
+                       return new DnnCPInstruction(in, in2, out, opcode, str, 
k, Double.parseDouble(parts[5]));
+               }
+               else if (opcode.equalsIgnoreCase("channel_sums")) {
+                       InstructionUtils.checkNumFields(parts, 4);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand out = new CPOperand(parts[4]);
+                       return new DnnCPInstruction(in, in2, in3, out, opcode, 
str, -1, 0);
+               }
+               else if (opcode.equalsIgnoreCase("batch_norm2d")) {
+                       InstructionUtils.checkNumFields(parts, 13);
+                       CPOperand in1 = new CPOperand(parts[1]); // image
+                       CPOperand in2 = new CPOperand(parts[2]); // scale
+                       CPOperand in3 = new CPOperand(parts[3]); // bias
+                       CPOperand in4 = new CPOperand(parts[4]); // runningMean
+                       CPOperand in5 = new CPOperand(parts[5]); // runningVar
+                       CPOperand in6 = new CPOperand(parts[6]); // mode
+                       CPOperand in7 = new CPOperand(parts[7]); // epsilon
+                       CPOperand in8 = new CPOperand(parts[8]); // 
exponentialAverageFactor
+                       CPOperand out = new CPOperand(parts[9]);  // ret
+                       CPOperand out2 = new CPOperand(parts[10]); // 
retRunningMean
+                       CPOperand out3 = new CPOperand(parts[11]); // 
retRunningVar
+                       CPOperand out4 = new CPOperand(parts[12]); // 
resultSaveMean
+                       CPOperand out5 = new CPOperand(parts[13]); // 
resultSaveInvVariance
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
+               }
+               else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
+                       InstructionUtils.checkNumFields(parts, 9);
+                       CPOperand in1 = new CPOperand(parts[1]); // image
+                       CPOperand in2 = new CPOperand(parts[2]); // dout
+                       CPOperand in3 = new CPOperand(parts[3]); // scale
+                       CPOperand in4 = new CPOperand(parts[4]); // epsilon
+                       CPOperand in5 = new CPOperand(parts[5]); // 
resultSaveMean
+                       CPOperand in6 = new CPOperand(parts[6]); // 
resultSaveInvVariance
+                       CPOperand out = new CPOperand(parts[7]);  // dX
+                       CPOperand out2 = new CPOperand(parts[8]); // dScale
+                       CPOperand out3 = new CPOperand(parts[9]); // dBias
+                       return new DnnCPInstruction(in1, in2, in3, in4, in5, 
in6, null, null, out, out2, out3, null, null, opcode, str, 0);
+               }
+               else {
+                       throw new DMLRuntimeException("Unknown opcode while 
parsing a DnnCPInstruction: " + str);
+               }
+       }
+
+       private static int getScalarInput(ExecutionContext ec, 
ArrayList<CPOperand> aL, int index) {
+               return (int) ec.getScalarInput(aL.get(index).getName(),
+                       aL.get(index).getValueType(), 
aL.get(index).isLiteral()).getLongValue();
+       }
+       
+       public void processReluBackwardInstruction(ExecutionContext ec) {
+               // (X > 0) * dout
+               MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
+               MatrixBlock dout = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(),
+                       input.isInSparseFormat() || dout.isInSparseFormat() );
+               
+               if( !input.isEmpty() && !dout.isEmpty() ) { //sparse-safe
+                       outputBlock.allocateBlock();
+                       LibMatrixDNN.reluBackward(input, dout, outputBlock, 
_numThreads);
+               }
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
+       }
+       
+       public void processBiasAddInstruction(ExecutionContext ec) {
+               MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
+               MatrixBlock bias = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               MatrixBlock outputBlock = null;
+               
+               if(bias.getNumColumns() != 1) {
+                       throw new DMLRuntimeException("Expected the number of 
columns of bias matrix to be 1, but found " + bias.getNumColumns());
+               }
+               
+               if(input.isEmpty() && bias.isEmpty()) {
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), true);
+               }
+               else if(bias.isEmpty()) {
+                       outputBlock = new MatrixBlock(input);
+               }
+               else {
+                       // As we always fill the output first with bias
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), false);
+                       outputBlock.allocateDenseBlock();
+                       LibMatrixDNN.biasAdd(input, bias, outputBlock, 
_numThreads);
+               }
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
+       }
+       
+       public void processBiasMultiplyInstruction(ExecutionContext ec) {
+               MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
+               MatrixBlock bias = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               MatrixBlock outputBlock = null;
+               
+               if(bias.getNumColumns() != 1) {
+                       throw new DMLRuntimeException("Expected the number of 
columns of bias matrix to be 1, but found " + bias.getNumColumns());
+               }
+               
+               if(bias.isEmpty()) {
+                       // Anything multiplied by zero is zero
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), true);
+               }
+               else {
+                       // As we always fill the output first with bias
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), 
+                               input.isInSparseFormat()).allocateBlock();
+                       LibMatrixDNN.biasMultiply(input, bias, outputBlock, 
_numThreads);
+               }
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
+       }
+       
+       public void processChannelSumsInstruction(ExecutionContext ec) {
+               MatrixBlock input = ec.getMatrixInput(input1.getName(), 
getExtendedOpcode());
+               int C = (int) ec.getScalarInput(_in2.getName(), 
_in2.getValueType(), _in2.isLiteral()).getLongValue();
+               int HW = (int) ec.getScalarInput(_in3.getName(), 
_in3.getValueType(), _in3.isLiteral()).getLongValue();
+               if(C*HW != input.getNumColumns()) {
+                       throw new DMLRuntimeException("Expected rows*cols" + C 
+ "*" + HW + " to be equal to number of columns of input " + 
input.getNumColumns());
+               }
+               MatrixBlock outputBlock = null;
+               if(input.isEmpty()) {
+                       outputBlock = new MatrixBlock(C, 1, true);
+               }
+               else {
+                       outputBlock = new MatrixBlock(C, 1, 
false).allocateBlock();
+                       LibMatrixDNN.channelSums(input, outputBlock, C, HW);
+               }
+               
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
+       }
+       
+       
+       
	/**
	 * Executes the batch_norm2d forward pass: normalizes the image with the
	 * given scale/bias and running statistics, delegating the computation to
	 * {@link LibMatrixDNN#batchNorm2D}.
	 *
	 * Inputs: image, scale, bias, runningMean, runningVar (matrices),
	 * phase/mode (string, _in6), epsilon (_in7), exponential average factor mu (_in8).
	 * Outputs: normalized result plus updated running and saved statistics
	 * (output, _out2.._out5).
	 *
	 * @param ec execution context providing all inputs and receiving the five outputs
	 */
	public void processBatchNorm2dInstruction(ExecutionContext ec) {
		MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
		MatrixBlock scale = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
		MatrixBlock bias = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
		MatrixBlock runningMean = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
		MatrixBlock runningVar = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
		String phase = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getStringValue();
		double epsilon = ec.getScalarInput(_in7.getName(), _in7.getValueType(), _in7.isLiteral()).getDoubleValue();
		double mu = ec.getScalarInput(_in8.getName(), _in8.getValueType(), _in8.isLiteral()).getDoubleValue();
		
		// pre-allocate all five dense outputs; statistics outputs are sized like the running stats
		MatrixBlock ret = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
		MatrixBlock retRunningMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
		MatrixBlock retRunningVar = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
		MatrixBlock resultSaveMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
		MatrixBlock resultSaveInvVariance = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
		
		LibMatrixDNN.batchNorm2D(image, scale, bias, runningMean, runningVar, phase, epsilon, mu, ret, 
				retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance);
		
		// release inputs/outputs (scalar inputs _in6.._in8 need no release)
		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
		ec.setMatrixOutput(output.getName(), ret, getExtendedOpcode());
		ec.setMatrixOutput(_out2.getName(), retRunningMean, getExtendedOpcode());
		ec.setMatrixOutput(_out3.getName(), retRunningVar, getExtendedOpcode());
		ec.setMatrixOutput(_out4.getName(), resultSaveMean, getExtendedOpcode());
		ec.setMatrixOutput(_out5.getName(), resultSaveInvVariance, getExtendedOpcode());
	}
+       
	/**
	 * Executes the batch_norm2d backward pass, delegating to
	 * {@link LibMatrixDNN#batchNorm2DBackward}.
	 *
	 * Inputs: image, dout, scale (matrices), epsilon (scalar, _in4),
	 * saved mean and inverse variance from the forward pass (_in5, _in6).
	 * Outputs: dX (output), dScale (_out2), dBias (_out3).
	 *
	 * @param ec execution context providing all inputs and receiving the three gradients
	 */
	public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) {
		MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
		MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
		MatrixBlock scale = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
		double epsilon = ec.getScalarInput(_in4.getName(), _in4.getValueType(), _in4.isLiteral()).getDoubleValue();
		MatrixBlock resultSaveMean = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
		MatrixBlock resultSaveInvVariance = ec.getMatrixInput(_in6.getName(), getExtendedOpcode());
		
		// pre-allocate dense gradient outputs; dScale/dBias are sized like scale
		MatrixBlock dX = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
		MatrixBlock dScale = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
		MatrixBlock dBias = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
		
		LibMatrixDNN.batchNorm2DBackward(image, dout, scale, epsilon, resultSaveMean, resultSaveInvVariance, dX, dScale, dBias);
		
		// release inputs/outputs (_in4 is a scalar and needs no release)
		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
		ec.releaseMatrixInput(_in6.getName(), getExtendedOpcode());
		ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode());
		ec.setMatrixOutput(_out2.getName(), dScale, getExtendedOpcode());
		ec.setMatrixOutput(_out3.getName(), dBias, getExtendedOpcode());
	}
+       
+       
+       // Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is 
true
+       // This increases the number of native calls. For example:the cases 
where filter is sparse but input is dense
+       private static boolean isFilterSparse(MatrixBlock filter) {
+               long numElems = filter.getNumRows()*filter.getNumColumns();
+               // if filter is less than 10 MB in dense format (which handles 
almost all the cases).
+               // In fact, using threshold of 1 MB is still sufficient for 
common CNNs.
+               if(filter.isInSparseFormat() && numElems < 10e+6)
+                       filter.sparseToDense(); 
+               return filter.isInSparseFormat();
+       }
+       
+       
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               
+               if (instOpcode.equalsIgnoreCase("bias_add")) {
+                       processBiasAddInstruction(ec);
+                       return;
+               }
+               else if (instOpcode.equalsIgnoreCase("bias_multiply")) {
+                       processBiasMultiplyInstruction(ec);
+                       return;
+               }
+               else if (instOpcode.equalsIgnoreCase("relu_backward")) {
+                       processReluBackwardInstruction(ec);
+                       return;
+               }
+               else if (instOpcode.equalsIgnoreCase("channel_sums")) {
+                       processChannelSumsInstruction(ec);
+                       return;
+               }
+               else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
+                       processBatchNorm2dInstruction(ec);
+                       return;
+               }
+               else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
+                       processBatchNorm2dBackwardInstruction(ec);
+                       return;
+               }
+               
+               // acquire inputs
+               MatrixBlock outputBlock = null;
+               MatrixBlock matBlock = 
instOpcode.equalsIgnoreCase("avgpooling_backward") ? null : 
ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+               int pad_h = getScalarInput(ec, _padding, 0);
+               int pad_w = getScalarInput(ec, _padding, 1);
+               int stride_h = getScalarInput(ec, _stride, 0);
+               int stride_w = getScalarInput(ec, _stride, 1);
+
+               int N = getScalarInput(ec, _input_shape, 0);
+               int C = getScalarInput(ec, _input_shape, 1);
+               int H = getScalarInput(ec, _input_shape, 2);
+               int W = getScalarInput(ec, _input_shape, 3);
+
+               int K = getScalarInput(ec, _filter_shape, 0);
+               
+               int R = getScalarInput(ec, _filter_shape, 2);
+               int S = getScalarInput(ec, _filter_shape, 3);
+               int P = (int) DnnUtils.getP(H, R, stride_h, pad_h);
+               int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
+               
+               DnnParameters params = new DnnParameters(N, C, H, W, K, R, S, 
stride_h, stride_w, pad_h, pad_w, _numThreads);
+               params.enableNative = NativeHelper.isNativeLibraryLoaded();
+               if (instOpcode.equalsIgnoreCase("maxpooling") || 
instOpcode.equalsIgnoreCase("relu_maxpooling") ||
+                       instOpcode.equalsIgnoreCase("avgpooling")) {
+                       if(matBlock.isEmpty()) {
+                               outputBlock = new MatrixBlock(N, C*P*Q, true);
+                       }
+                       else {
+                               outputBlock = new MatrixBlock(N, C*P*Q, 
false).allocateBlock();
+                               
+                               PoolingType poolType = 
(instOpcode.equalsIgnoreCase("maxpooling") || 
instOpcode.equalsIgnoreCase("relu_maxpooling")) ? PoolingType.MAX : 
PoolingType.AVG;
+                               
if(instOpcode.equalsIgnoreCase("relu_maxpooling"))
+                                       params.minValForMaxPoolOperations = 0;
+                               LibMatrixDNN.pooling(matBlock, outputBlock, 
params, poolType);
+                       }
+               }
+               else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || 
instOpcode.equalsIgnoreCase("relu_maxpooling_backward") ||
+                               
instOpcode.equalsIgnoreCase("avgpooling_backward")) {
+                       MatrixBlock dout = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+                       boolean isEmpty = 
instOpcode.equalsIgnoreCase("avgpooling_backward") ? dout.isEmpty() : 
(matBlock.isEmpty() || dout.isEmpty());
+                       if(isEmpty) {
+                               outputBlock = new MatrixBlock(N, C*H*W, true);
+                       }
+                       else {
+                               outputBlock = new MatrixBlock(N, C*H*W, 
false).allocateBlock();
+                               PoolingType poolType = 
(instOpcode.equalsIgnoreCase("maxpooling_backward") || 
instOpcode.equalsIgnoreCase("relu_maxpooling_backward")) ? PoolingType.MAX : 
PoolingType.AVG;
+                               boolean performReLUBackward = 
instOpcode.equalsIgnoreCase("relu_maxpooling_backward");
+                               if(performReLUBackward)
+                                       params.minValForMaxPoolOperations = 0;
+                               LibMatrixDNN.poolingBackward(matBlock, dout, 
outputBlock, params, performReLUBackward, poolType);
+                       }
+                       ec.releaseMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               }
+               else if (instOpcode.equalsIgnoreCase("conv2d")) {
+                       resetNumThreads(params, C*R*S, P*Q, 
matBlock.getNonZeros() / (matBlock.getNumRows()*matBlock.getNumColumns()));
+                       MatrixBlock filter = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+                       if(filter.isEmpty() || matBlock.isEmpty()) {
+                               outputBlock = new MatrixBlock(N, K*P*Q, true);
+                       }
+                       else {
+                               boolean sparse = matBlock.isUltraSparse(false) 
&& params.bias == null
+                                       && matBlock.getInMemorySize() < 
MatrixBlock.estimateSizeDenseInMemory(N, K*P*Q);
+                               outputBlock = new MatrixBlock(N, K*P*Q, 
sparse).allocateBlock();
+                               if(params.enableNative && 
!isFilterSparse(filter) && !matBlock.isInSparseFormat())
+                                       LibMatrixNative.conv2d(matBlock, 
filter, outputBlock, params);
+                               else
+                                       LibMatrixDNN.conv2d(matBlock, filter, 
outputBlock, params);
+                       }
+                       ec.releaseMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               }
+               else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
+                       resetNumThreads(params, C*R*S, P*Q, 
matBlock.getNonZeros() / (matBlock.getNumRows()*matBlock.getNumColumns()));
+                       MatrixBlock filter = ec.getMatrixInput(_in3.getName(), 
getExtendedOpcode());
+                       MatrixBlock bias = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+                       if(bias.getNumRows() != params.K || 
bias.getNumColumns() != 1) {
+                               throw new DMLRuntimeException("Incorrect shape 
of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. "
+                                               + "Expected: [" + params.K + ", 
1]");
+                       }
+                       boolean isOutputConvEmpty = filter.isEmpty() || 
matBlock.isEmpty();
+                       if(isOutputConvEmpty && bias.isEmpty()) {
+                               // bias_add(empty mb, empty mb) = empty mb
+                               outputBlock = new MatrixBlock(N, K*P*Q, true);
+                       }
+                       else if(isOutputConvEmpty && !bias.isEmpty()) {
+                               // Add bias to empty output block
+                               // bias_add(empty mb, bias)
+                               outputBlock = new MatrixBlock(N, K*P*Q, 
false).allocateBlock();
+                               for(int n = 0;  n < params.N; n++) 
+                                       DnnUtils.fillBias(bias, 
outputBlock.getDenseBlockValues(),
+                                               n, n+1, params.N, params.K, 
params.P*params.Q);
+                       }
+                       else {
+                               outputBlock = new MatrixBlock(N, K*P*Q, 
false).allocateBlock();
+                               if(!bias.isEmpty()) {
+                                       // Handle situation where both input 
and filter are non empty, but bias is empty
+                                       params.bias = bias;
+                               }
+                               if(params.enableNative && 
!isFilterSparse(filter) && !matBlock.isInSparseFormat())
+                                       LibMatrixNative.conv2d(matBlock, 
filter, outputBlock, params);
+                               else
+                                       LibMatrixDNN.conv2d(matBlock, filter, 
outputBlock, params);
+                       }
+                       ec.releaseMatrixInput(_in3.getName(), 
getExtendedOpcode());
+                       ec.releaseMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               }
+               else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) 
{
+                       MatrixBlock dout = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+                       if(dout.isEmpty() || matBlock.isEmpty()) {
+                               outputBlock = new MatrixBlock(K, C*R*S, true);
+                       }
+                       else {
+                               outputBlock = new MatrixBlock(K, C*R*S, 
false).allocateBlock();
+                               if(params.enableNative && 
!matBlock.isInSparseFormat() && !dout.isInSparseFormat())
+                                       
LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
+                               else
+                                       
LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
+                       }
+                       ec.releaseMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               }
+               else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
+                       MatrixBlock dout = ec.getMatrixInput(_in2.getName(), 
getExtendedOpcode());
+                       if(dout.isEmpty() || matBlock.isEmpty()) {
+                               outputBlock = new MatrixBlock(N, C * H * W, 
true);
+                       }
+                       else {
+                               outputBlock = new MatrixBlock(N, C * H * W, 
false).allocateBlock();
+                               if(params.enableNative && 
!isFilterSparse(matBlock) && !dout.isInSparseFormat())
+                                       
LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
+                               else
+                                       
LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params);
+                       }
+                       ec.releaseMatrixInput(_in2.getName(), 
getExtendedOpcode());
+               }
+               else {
+                       throw new DMLRuntimeException("Unsupported op code " + 
instOpcode);
+               }
+               
+               // release inputs/outputs
+               if(!instOpcode.equalsIgnoreCase("avgpooling_backward"))
+                       ec.releaseMatrixInput(input1.getName(), 
getExtendedOpcode());
+               ec.setMatrixOutput(getOutputVariableName(), outputBlock, 
getExtendedOpcode());
+       }
+       
+       /**
+        * Reset the number of thread to respect the intermediate CP memory 
budget
+        * 
+        * @param params convolution parameters
+        * @param numRows number of rows of intermediate matrix used per thread
+        * @param numCols number of rows of intermediate matrix used per thread
+        * @param sparsity sparsity of intermediate matrix used per thread
+        */
+       private void resetNumThreads(DnnParameters params, int numRows, int 
numCols, double sparsity) {
+               if(DMLScript.USE_ACCELERATOR) {
+                       double memBudget1Thread = 
OptimizerUtils.estimateSizeExactSparsity(numRows, numCols, sparsity);
+                       int limitedDegreeOfParallelism = (int) 
Math.floor(_intermediateMemoryBudget / memBudget1Thread);
+                       if(params.numThreads > limitedDegreeOfParallelism) {
+                               params.numThreads = limitedDegreeOfParallelism;
+                               if(!warnedUnderUtilitization)
+                                       LOG.warn("CPU Under-utilization to 
respect the intermediate memory budget. To avoid this, please try reducing the 
mini-batch or forcing gpu execution.");
+                               warnedUnderUtilitization = true;
+                       }
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
deleted file mode 100644
index 7d8a0fc..0000000
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ /dev/null
@@ -1,750 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysml.runtime.instructions.gpu;
-
-import java.util.ArrayList;
-
-import jcuda.Pointer;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.functionobjects.SwapIndex;
-import org.apache.sysml.runtime.instructions.InstructionUtils;
-import org.apache.sysml.runtime.instructions.cp.CPOperand;
-import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig;
-import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
-import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
-import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
-import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
-import org.apache.sysml.runtime.util.ConvolutionUtils;
-import org.apache.sysml.utils.GPUStatistics;
-
-public class ConvolutionGPUInstruction extends GPUInstruction {
-       private CPOperand _input1;
-       private CPOperand _input2;
-       private CPOperand _input3;
-       private CPOperand _input4;
-       private CPOperand _input5;
-       private CPOperand _input6;
-       private CPOperand _input7;
-       private CPOperand _input8;
-       private CPOperand _output;
-       private CPOperand _output2;
-       private CPOperand _output3;
-       private CPOperand _output4;
-       private CPOperand _output5;
-       private ArrayList<CPOperand> _input_shape;
-       private ArrayList<CPOperand> _filter_shape;
-       private ArrayList<CPOperand> _stride = new ArrayList<>();
-       private ArrayList<CPOperand> _padding = new ArrayList<>();
-       private double _intermediateMemoryBudget = 0;
-       
-       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
-               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
-               if (!(opcode.equals("bias_add") || 
opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) {
-                       throw new DMLRuntimeException(
-                                       "Incorrect usage. Expected the opcode 
to be bias_add or bias_multiply or relu_backward, but found "
-                                                       + opcode);
-               }
-               _input1 = in1;
-               _input2 = in2;
-               _gputype = GPUINSTRUCTION_TYPE.Convolution;
-               _output = out;
-               _intermediateMemoryBudget = intermediateMemoryBudget;
-       }
-       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, 
-                       CPOperand out, CPOperand out2, String opcode, String 
istr, 
-                       double intermediateMemoryBudget) throws 
DMLRuntimeException {
-               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
-               _input1 = in1;
-               _input2 = in2;
-               _input3 = in3;
-               _input4 = in4;
-               _input5 = in5;
-               _input6 = in6;
-               _gputype = GPUINSTRUCTION_TYPE.Convolution;
-               _output = out;
-               _output2 = out2;
-               _intermediateMemoryBudget = intermediateMemoryBudget;
-       }
-       
-       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand in4, CPOperand in5,
-                       CPOperand in6, CPOperand in7, CPOperand in8,
-                       CPOperand out, CPOperand out2, CPOperand out3, 
CPOperand out4, CPOperand out5, String opcode, String istr, 
-                       double intermediateMemoryBudget) throws 
DMLRuntimeException {
-               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
-               _input1 = in1;
-               _input2 = in2;
-               _input3 = in3;
-               _input4 = in4;
-               _input5 = in5;
-               _input6 = in6;
-               _input7 = in7;
-               _input8 = in8;
-               _gputype = GPUINSTRUCTION_TYPE.Convolution;
-               _output = out;
-               _output2 = out2;
-               _output3 = out3;
-               _output4 = out4;
-               _output5 = out5;
-               _intermediateMemoryBudget = intermediateMemoryBudget;
-       }
-       
-       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand out, String opcode, String istr, 
-                       double intermediateMemoryBudget) throws 
DMLRuntimeException {
-               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
-               if( !opcode.equals("channel_sums") ) {
-                       throw new DMLRuntimeException("Incorrect usage. 
Expected the opcode to be channel_sums, but found " + opcode);
-               }
-               _input1 = in1;
-               _input2 = in2;
-               _input3 = in3;
-               _gputype = GPUINSTRUCTION_TYPE.Convolution;
-               _output = out;
-               _intermediateMemoryBudget = intermediateMemoryBudget;
-       }
-       
-       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand in3, CPOperand out, String opcode,
-                       String istr, ArrayList<CPOperand> stride,
-                       ArrayList<CPOperand> padding, ArrayList<CPOperand> 
input_shape,
-                       ArrayList<CPOperand> filter_shape, double 
intermediateMemoryBudget) 
-       {
-               this(in1, in2, out, opcode, istr, stride, padding,  
input_shape, filter_shape, intermediateMemoryBudget);
-               _input3 = in3;
-       }
-       
-       public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, 
CPOperand out, String opcode,
-                       String istr, ArrayList<CPOperand> stride,
-                       ArrayList<CPOperand> padding, ArrayList<CPOperand> 
input_shape,
-                       ArrayList<CPOperand> filter_shape, double 
intermediateMemoryBudget) 
-       {
-               super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), 
opcode, istr);
-               _gputype = GPUINSTRUCTION_TYPE.Convolution;
-
-               _input1 = in1;
-               _input2 = in2;
-               _output = out;
-               _stride = stride;
-               _padding = padding;
-               _input_shape = input_shape;
-               _filter_shape = filter_shape;
-               _intermediateMemoryBudget = intermediateMemoryBudget;
-       }
-
-       public static ConvolutionGPUInstruction parseInstruction(String str) {
-               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
-               String opcode = parts[0];
-               if( ( opcode.equalsIgnoreCase("conv2d")
-                        || opcode.equalsIgnoreCase("conv2d_backward_filter")
-                        || opcode.equalsIgnoreCase("conv2d_backward_data")) ) {
-                       InstructionUtils.checkNumFields(parts, 16);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand out = new CPOperand(parts[15]);
-                       ArrayList<CPOperand> stride = new ArrayList<>();
-                       ArrayList<CPOperand> padding = new ArrayList<>();
-                       ArrayList<CPOperand> input_shape = new ArrayList<>();
-                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
-                       stride.add(new CPOperand(parts[3]));
-                       stride.add(new CPOperand(parts[4]));
-                       padding.add(new CPOperand(parts[5]));
-                       padding.add(new CPOperand(parts[6]));
-                       input_shape.add(new CPOperand(parts[7]));
-                       input_shape.add(new CPOperand(parts[8]));
-                       input_shape.add(new CPOperand(parts[9]));
-                       input_shape.add(new CPOperand(parts[10]));
-                       filter_shape.add(new CPOperand(parts[11]));
-                       filter_shape.add(new CPOperand(parts[12]));
-                       filter_shape.add(new CPOperand(parts[13]));
-                       filter_shape.add(new CPOperand(parts[14]));
-
-                       return new ConvolutionGPUInstruction(in1, in2, out, 
opcode, str, stride,
-                                       padding, input_shape, filter_shape, 
Double.parseDouble(parts[16]));
-               }
-               else if( opcode.equalsIgnoreCase("maxpooling_backward") || 
opcode.equalsIgnoreCase("avgpooling_backward") ) {
-                       boolean withMaxPoolOut = false;
-                       if(parts.length == 18) {
-                               withMaxPoolOut = true;
-                       }
-                       else
-                               InstructionUtils.checkNumFields(parts, 16);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand in3 = withMaxPoolOut ? new 
CPOperand(parts[15]) : null;
-                       CPOperand out = withMaxPoolOut ? new 
CPOperand(parts[16]) : new CPOperand(parts[15]);
-                       double memBudget = withMaxPoolOut ? 
Double.parseDouble(parts[17]) : Double.parseDouble(parts[16]);
-               
-                       ArrayList<CPOperand> stride = new ArrayList<>();
-                       ArrayList<CPOperand> padding = new ArrayList<>();
-                       ArrayList<CPOperand> input_shape = new ArrayList<>();
-                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
-                       stride.add(new CPOperand(parts[3]));
-                       stride.add(new CPOperand(parts[4]));
-                       padding.add(new CPOperand(parts[5]));
-                       padding.add(new CPOperand(parts[6]));
-                       input_shape.add(new CPOperand(parts[7]));
-                       input_shape.add(new CPOperand(parts[8]));
-                       input_shape.add(new CPOperand(parts[9]));
-                       input_shape.add(new CPOperand(parts[10]));
-                       filter_shape.add(new CPOperand(parts[11]));
-                       filter_shape.add(new CPOperand(parts[12]));
-                       filter_shape.add(new CPOperand(parts[13]));
-                       filter_shape.add(new CPOperand(parts[14]));
-
-                       return new ConvolutionGPUInstruction(in1, in2, in3, 
out, opcode, str, stride,
-                                       padding, input_shape, filter_shape, 
memBudget);
-               }
-               else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
-                       InstructionUtils.checkNumFields(parts, 17);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand in3 = new CPOperand(parts[3]);
-                       CPOperand out = new CPOperand(parts[16]);
-               
-                       ArrayList<CPOperand> stride = new ArrayList<>();
-                       ArrayList<CPOperand> padding = new ArrayList<>();
-                       ArrayList<CPOperand> input_shape = new ArrayList<>();
-                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
-                       stride.add(new CPOperand(parts[4]));
-                       stride.add(new CPOperand(parts[5]));
-                       padding.add(new CPOperand(parts[6]));
-                       padding.add(new CPOperand(parts[7]));
-                       input_shape.add(new CPOperand(parts[8]));
-                       input_shape.add(new CPOperand(parts[9]));
-                       input_shape.add(new CPOperand(parts[10]));
-                       input_shape.add(new CPOperand(parts[11]));
-                       filter_shape.add(new CPOperand(parts[12]));
-                       filter_shape.add(new CPOperand(parts[13]));
-                       filter_shape.add(new CPOperand(parts[14]));
-                       filter_shape.add(new CPOperand(parts[15]));
-
-                       return new ConvolutionGPUInstruction(in1, in2, in3, 
out, opcode, str, stride,
-                                       padding, input_shape, filter_shape, 
Double.parseDouble(parts[17]));
-               }
-               else if (opcode.equalsIgnoreCase("maxpooling") || 
opcode.equalsIgnoreCase("avgpooling")) {
-                       InstructionUtils.checkNumFields(parts, 15);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand out = new CPOperand(parts[14]);
-               
-                       ArrayList<CPOperand> stride = new ArrayList<>();
-                       ArrayList<CPOperand> padding = new ArrayList<>();
-                       ArrayList<CPOperand> input_shape = new ArrayList<>();
-                       ArrayList<CPOperand> filter_shape = new ArrayList<>();
-                       stride.add(new CPOperand(parts[2]));
-                       stride.add(new CPOperand(parts[3]));
-                       padding.add(new CPOperand(parts[4]));
-                       padding.add(new CPOperand(parts[5]));
-                       input_shape.add(new CPOperand(parts[6]));
-                       input_shape.add(new CPOperand(parts[7]));
-                       input_shape.add(new CPOperand(parts[8]));
-                       input_shape.add(new CPOperand(parts[9]));
-                       filter_shape.add(new CPOperand(parts[10]));
-                       filter_shape.add(new CPOperand(parts[11]));
-                       filter_shape.add(new CPOperand(parts[12]));
-                       filter_shape.add(new CPOperand(parts[13]));
-
-                       return new ConvolutionGPUInstruction(in1, null, out, 
opcode, str, stride,
-                                       padding, input_shape, filter_shape, 
Double.parseDouble(parts[15]));
-               }
-               else if( opcode.equalsIgnoreCase("bias_add") || 
opcode.equalsIgnoreCase("relu_backward") || 
opcode.equalsIgnoreCase("bias_multiply")  ) {
-                       InstructionUtils.checkNumFields(parts, 4);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand out = new CPOperand(parts[3]);
-                       return new ConvolutionGPUInstruction(in1, in2, out, 
opcode, str, Double.parseDouble(parts[4]));
-               }
-               else if (opcode.equalsIgnoreCase("channel_sums")) {
-                       InstructionUtils.checkNumFields(parts, 4);
-                       CPOperand in = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand in3 = new CPOperand(parts[3]);
-                       CPOperand out = new CPOperand(parts[4]);
-                       return new ConvolutionGPUInstruction(in, in2, in3, out, 
opcode, str, 0);
-               }
-               else if (opcode.equalsIgnoreCase("lstm")) {
-                       InstructionUtils.checkNumFields(parts, 8);
-                       CPOperand in1 = new CPOperand(parts[1]);
-                       CPOperand in2 = new CPOperand(parts[2]);
-                       CPOperand in3 = new CPOperand(parts[3]);
-                       CPOperand in4 = new CPOperand(parts[4]);
-                       CPOperand in5 = new CPOperand(parts[5]);
-                       CPOperand in6 = new CPOperand(parts[6]);
-                       CPOperand out = new CPOperand(parts[7]);
-                       CPOperand out2 = new CPOperand(parts[8]);
-                       return new ConvolutionGPUInstruction(in1, in2, in3, 
in4, in5, in6, out, out2, opcode, str, 0);
-               }
-               else if (opcode.equalsIgnoreCase("batch_norm2d") || 
opcode.equalsIgnoreCase("lstm_backward")) {
-                       InstructionUtils.checkNumFields(parts, 13);
-                       CPOperand in1 = new CPOperand(parts[1]); // image
-                       CPOperand in2 = new CPOperand(parts[2]); // scale
-                       CPOperand in3 = new CPOperand(parts[3]); // bias
-                       CPOperand in4 = new CPOperand(parts[4]); // runningMean
-                       CPOperand in5 = new CPOperand(parts[5]); // runningVar
-                       CPOperand in6 = new CPOperand(parts[6]); // mode
-                       CPOperand in7 = new CPOperand(parts[7]); // epsilon
-                       CPOperand in8 = new CPOperand(parts[8]); // 
exponentialAverageFactor
-                       CPOperand out = new CPOperand(parts[9]);  // ret
-                       CPOperand out2 = new CPOperand(parts[10]); // 
retRunningMean
-                       CPOperand out3 = new CPOperand(parts[11]); // 
retRunningVar
-                       CPOperand out4 = new CPOperand(parts[12]); // 
resultSaveMean
-                       CPOperand out5 = new CPOperand(parts[13]); // 
resultSaveInvVariance
-                       return new ConvolutionGPUInstruction(in1, in2, in3, 
in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
-               }
-               else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
-                       InstructionUtils.checkNumFields(parts, 9);
-                       CPOperand in1 = new CPOperand(parts[1]); // image
-                       CPOperand in2 = new CPOperand(parts[2]); // dout
-                       CPOperand in3 = new CPOperand(parts[3]); // scale
-                       CPOperand in4 = new CPOperand(parts[4]); // epsilon
-                       CPOperand in5 = new CPOperand(parts[5]); // 
resultSaveMean
-                       CPOperand in6 = new CPOperand(parts[6]); // 
resultSaveInvVariance
-                       CPOperand out = new CPOperand(parts[7]);  // dX
-                       CPOperand out2 = new CPOperand(parts[8]); // dScale
-                       CPOperand out3 = new CPOperand(parts[9]); // dBias
-                       return new ConvolutionGPUInstruction(in1, in2, in3, 
in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
-               }
-               else {
-                       throw new DMLRuntimeException("Unknown opcode while 
parsing a ConvolutionGPUInstruction: " + str);      
-               }
-       }
-
-       public void processBiasInstruction(String instOpcode, ExecutionContext 
ec) {
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               MatrixObject input = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-               MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, 
_output.getName(), input.getNumRows(), input.getNumColumns());
-               
-               if(instOpcode.equalsIgnoreCase("bias_add"))
-                       LibMatrixCUDA.biasAdd(ec.getGPUContext(0), 
getExtendedOpcode(), input, bias, out);
-               else if(instOpcode.equalsIgnoreCase("bias_multiply"))
-                       LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), 
getExtendedOpcode(), input, bias, out);
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-       }
-       
-       public void processBatchNorm2dInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               MatrixObject image = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               MatrixObject scale = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-               MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
-               MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
-               MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, 
_input5.getName());
-               
-               String phase = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getStringValue();
-               double epsilon = ec.getScalarInput(_input7.getName(), 
_input7.getValueType(), _input7.isLiteral()).getDoubleValue();
-               
-               MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, 
_output.getName(), image.getNumRows(), image.getNumColumns());
-               
-               if(phase.equalsIgnoreCase("train")) {
-                       double exponentialAverageFactor = 
1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), 
_input8.isLiteral()).getDoubleValue();
-                       MatrixObject retRunningMean = 
getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), 
runningMean.getNumRows(), runningMean.getNumColumns());
-                       MatrixObject retRunningVar = 
getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), 
runningVar.getNumRows(), runningVar.getNumColumns());
-                       MatrixObject resultSaveMean = 
getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), 
runningMean.getNumRows(), runningMean.getNumColumns());
-                       MatrixObject resultSaveInvVariance = 
getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), 
runningVar.getNumRows(), runningVar.getNumColumns());
-                       
LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), 
getExtendedOpcode(), 
-                               image, scale, bias, runningMean, runningVar, 
ret, 
-                               retRunningMean, retRunningVar, epsilon, 
exponentialAverageFactor, resultSaveMean, resultSaveInvVariance);
-                       
ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-                       
ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
-                       
ec.releaseMatrixOutputForGPUInstruction(_output4.getName());
-                       
ec.releaseMatrixOutputForGPUInstruction(_output5.getName());
-               }
-               else if(phase.equalsIgnoreCase("test")) {
-                       
LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), 
getExtendedOpcode(), 
-                                       image, scale, bias, runningMean, 
runningVar, ret, epsilon);
-                       ec.setMatrixOutput(_output2.getName(), new 
MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), 
true), getExtendedOpcode());
-                       ec.setMatrixOutput(_output3.getName(), new 
MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), 
true), getExtendedOpcode());
-                       ec.setMatrixOutput(_output4.getName(), new 
MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), 
true), getExtendedOpcode());
-                       ec.setMatrixOutput(_output5.getName(), new 
MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), 
true), getExtendedOpcode());
-               }
-               else {
-                       throw new DMLRuntimeException("Incorrect mode: Expected 
either train or test, but found " + phase);
-               }
-               
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-       }
-       
-       public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) 
throws DMLRuntimeException {
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               MatrixObject image = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               MatrixObject dout = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-               MatrixObject scale = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
-               double epsilon = ec.getScalarInput(_input4.getName(), 
_input4.getValueType(), _input4.isLiteral()).getDoubleValue();
-               MatrixObject resultSaveMean = 
getMatrixInputForGPUInstruction(ec, _input5.getName());
-               MatrixObject resultSaveInvVariance = 
getMatrixInputForGPUInstruction(ec, _input6.getName());
-               
-               MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, 
_output.getName(), image.getNumRows(), image.getNumColumns());
-               MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, 
_output2.getName(), scale.getNumRows(), scale.getNumColumns());
-               MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, 
_output3.getName(), scale.getNumRows(), scale.getNumColumns());
-               
-               LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), 
getExtendedOpcode(), image, 
-                               dout, scale, dX, dScale, dBias,
-                               epsilon, resultSaveMean, resultSaveInvVariance);
-               
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input6.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output3.getName());
-       }
-
-       // (X > 0) * dout
-       public void processReLUBackwardInstruction(ExecutionContext ec) {
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               MatrixObject input = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               MatrixObject dout = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-               
-               MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, 
_output.getName(), input.getNumRows(), input.getNumColumns());
-               
-               LibMatrixCUDA.reluBackward(ec.getGPUContext(0), 
getExtendedOpcode(), input, dout, out);
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-       }
-       
-       public void processChannelSumsInstruction(ExecutionContext ec) {
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               MatrixObject input = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               int C = (int) ec.getScalarInput(_input2.getName(), 
_input2.getValueType(), _input2.isLiteral()).getLongValue();
-               int HW = (int) ec.getScalarInput(_input3.getName(), 
_input3.getValueType(), _input3.isLiteral()).getLongValue();
-               if(C*HW != input.getNumColumns()) {
-                       throw new DMLRuntimeException("Expected rows*cols" + C 
+ "*" + HW + " to be equal to number of columns of input " + 
input.getNumColumns());
-               }
-               MatrixObject outputBlock = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1);
-               
-               LibMatrixCUDA.channelSums(ec.getGPUContext(0), 
getExtendedOpcode(), input, outputBlock, C, HW);
-               
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-       }
-       
-       private static int toInt(long num) throws DMLRuntimeException {
-               if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) {
-                       throw new DMLRuntimeException("GPU : Exceeded supported 
size " + num);
-               }
-               return (int)num;
-       }
-       
-//     private Pointer transpose(ExecutionContext ec, MatrixObject X) throws 
DMLRuntimeException {
-//             GPUContext gCtx = ec.getGPUContext(0);
-//             String instructionName = getExtendedOpcode();
-//             long numRowsX = X.getNumRows(); long numColsX = 
X.getNumColumns();
-//             Pointer tX = gCtx.allocate(instructionName, 
numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType);
-//             jcuda.runtime.JCuda.cudaMemcpy(tX, 
LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), 
numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType,  
jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice);
-//             // LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, 
LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, 
numColsX);
-//             return tX;
-//     }
-       
-       private void processLstmBackwardInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               GPUContext gCtx = ec.getGPUContext(0);
-               String instructionName = getExtendedOpcode();
-               
-               MatrixObject out0 = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
-               int M = toInt(out0.getNumColumns()); // hiddenSize .. since 
out0: (N, M)
-               Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, 
out0, instructionName);
-               
-               MatrixObject W = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-               MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
-               long numRowsW = W.getNumRows();
-               int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... 
numFeatures 
-               Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
-               Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
-               Pointer cudnnWPointer = gCtx.allocate(instructionName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
-                               sysmlWPointer, sysmlBiasPointer, cudnnWPointer, 
D, M);
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-               
-               
-               MatrixObject X = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, 
instructionName); 
-               int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
-               long numColsX = X.getNumColumns();
-               int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-               Pointer cudnnInput = gCtx.allocate(instructionName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
-                               xPointer, cudnnInput, N, D, T*D, N*T*D);
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               
-               Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName);
-               boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-               
-               // LibMatrixCuDNN.lstm(ec, gCtx, instructionName, 
-                               // cudnnInput, cudnnWPointer, out0Pointer, 
c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T);
-                               // String xName, Pointer hx, Pointer cx, 
Pointer wPointer, String doutName, String dcyName,  // input
-                               // String dxName, String dwName, String dbName, 
String dhxName, String dcxName,         // output
-               String dxName = _output.getName();
-               String dwName = _output2.getName();
-               String dbName = _output3.getName();
-               String dhxName = _output4.getName();
-               String dcxName = _output5.getName();
-               String doutName = _input7.getName();
-               String dcyName = _input8.getName();
-               LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, 
-                               cudnnInput, out0Pointer, c0Pointer, 
cudnnWPointer, doutName, dcyName,  // input
-                               dxName, dwName, dbName, dhxName, dcxName, // 
output 
-                               return_sequences, N, M, D, T);
-               gCtx.cudaFreeHelper(instructionName, cudnnWPointer, 
DMLScript.EAGER_CUDA_FREE);
-               gCtx.cudaFreeHelper(instructionName, cudnnInput, 
DMLScript.EAGER_CUDA_FREE);
-               
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-       }
-       
-       private void processLstmInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
-               // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M
-               // input  X:(N, T*D),   ==> (T, D, N)
-               // weight W:(D+M+2, 4M) 
-               // previous output out0 (also represented by hx) and cell state 
c0 (also represented by cx): (N, M) ==> (1, M, N)
-               // out: (N, T*M) or (N, M) ==> (T, M, N)
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-               GPUContext gCtx = ec.getGPUContext(0);
-               String instructionName = getExtendedOpcode();
-               
-               MatrixObject out0 = getMatrixInputForGPUInstruction(ec, 
_input4.getName());
-               int M = toInt(out0.getNumColumns()); // hiddenSize .. since 
out0: (N, M)
-               Pointer out0Pointer =  LibMatrixCUDA.getDensePointer(gCtx, 
out0, instructionName);
-               
-               MatrixObject W = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-               MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input3.getName());
-               long numRowsW = W.getNumRows();
-               int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... 
numFeatures 
-               Pointer sysmlWPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M);
-               Pointer sysmlBiasPointer = 
LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M);
-               Pointer cudnnWPointer = gCtx.allocate(instructionName, 
(D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)),
-                               sysmlWPointer, sysmlBiasPointer, cudnnWPointer, 
D, M);
-               ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-               
-               boolean return_sequences = ec.getScalarInput(_input6.getName(), 
_input6.getValueType(), _input6.isLiteral()).getBooleanValue();
-               
-               // Beause the matrices are released immediately, the output for 
transpose need not be taken into account
-               MatrixObject X = getMatrixInputForGPUInstruction(ec, 
_input1.getName());
-               Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, 
instructionName); 
-               int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D)
-               long numColsX = X.getNumColumns();
-               int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength
-               Pointer cudnnInput = gCtx.allocate(instructionName, 
(N*T*D)*LibMatrixCUDA.sizeOfDataType);
-               
LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input",
-                               
ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D),
-                               xPointer, cudnnInput, N, D, T*D, N*T*D);
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               
-               Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, 
getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); 
-               
-               LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, 
cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), 
_output2.getName(), N, M, D, T);
-               gCtx.cudaFreeHelper(instructionName, cudnnWPointer, 
DMLScript.EAGER_CUDA_FREE);
-               gCtx.cudaFreeHelper(instructionName, cudnnInput, 
DMLScript.EAGER_CUDA_FREE);
-               
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input4.getName());
-               ec.releaseMatrixInputForGPUInstruction(_input5.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output2.getName());
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-       }
-       
-       @Override
-       public void processInstruction(ExecutionContext ec) {
-               if (instOpcode.equalsIgnoreCase("bias_add") || 
instOpcode.equalsIgnoreCase("bias_multiply")) {
-                       processBiasInstruction(instOpcode, ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("relu_backward")) {
-                       processReLUBackwardInstruction(ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("channel_sums")) {
-                       processChannelSumsInstruction(ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("lstm")) {
-                       processLstmInstruction(ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
-                       processLstmBackwardInstruction(ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
-                       processBatchNorm2dInstruction(ec);
-                       return;
-               }
-               else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
-                       processBatchNorm2dBackwardInstruction(ec);
-                       return;
-               }
-               
-               GPUStatistics.incrementNoOfExecutedGPUInst();
-                                       
-               int pad_h = getScalarInput(ec, _padding, 0);
-               int pad_w = getScalarInput(ec, _padding, 1);
-               int stride_h = getScalarInput(ec, _stride, 0);
-               int stride_w = getScalarInput(ec, _stride, 1);
-
-               int N = getScalarInput(ec, _input_shape, 0);
-               int C = getScalarInput(ec, _input_shape, 1);
-               int H = getScalarInput(ec, _input_shape, 2);
-               int W = getScalarInput(ec, _input_shape, 3);
-
-               int K = getScalarInput(ec, _filter_shape, 0);
-               
-               int R = getScalarInput(ec, _filter_shape, 2);
-               int S = getScalarInput(ec, _filter_shape, 3);
-               
-               int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
-               int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);
-               
-               if (instOpcode.equalsIgnoreCase("conv2d")) {
-                       MatrixObject image = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
-                       MatrixObject filter = 
getMatrixInputForGPUInstruction(ec, _input2.getName());
-
-                       if(image.getNumRows() != N || image.getNumColumns() != 
C*H*W) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for image in conv2d");
-                       if(filter.getNumRows() != K || filter.getNumColumns() 
!= C*R*S) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for filter in conv2d");
-                       
-                       MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q);
-                       
-                       LibMatrixCuDNN.conv2d(ec.getGPUContext(0), 
getExtendedOpcode(), image, filter, out, N, C, H, W,
-                                       K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q, _intermediateMemoryBudget);
-               }
-               else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) {
-                       MatrixObject image = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
-                       MatrixObject bias = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-                       MatrixObject filter = 
getMatrixInputForGPUInstruction(ec, _input3.getName());
-
-                       if(image.getNumRows() != N || image.getNumColumns() != 
C*H*W) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for image in conv2d");
-                       if(filter.getNumRows() != K || filter.getNumColumns() 
!= C*R*S) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for filter in conv2d");
-                       
-                       MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q);
-                       
-                       LibMatrixCuDNN.conv2dBiasAdd(ec.getGPUContext(0), 
getExtendedOpcode(), image, bias, filter, out, N, C, H, W,
-                                               K, R, S, pad_h, pad_w, 
stride_h, stride_w, P, Q, _intermediateMemoryBudget);
-               }
-               else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) 
{
-                       MatrixObject image = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
-                       MatrixObject dout = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-
-                       if(image.getNumRows() != N || image.getNumColumns() != 
C*H*W) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for image in conv2d_backward_filter");
-                       if(dout.getNumRows() != N || dout.getNumColumns() != 
K*P*Q) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for dout in conv2d_backward_filter: " + 
-                                               dout.getNumRows() + " != " +  N 
+ " || " + dout.getNumColumns() + " != " + K*P*Q);
-                       
-                       MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), K, C * R * S);
-                       
-                       
LibMatrixCuDNN.conv2dBackwardFilter(ec.getGPUContext(0), getExtendedOpcode(), 
image, dout, out, N, C, H, W,
-                                       K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q, _intermediateMemoryBudget);
-                       // TODO: For now always copy the device data to host
-                       // ec.gpuCtx.copyDeviceToHost(outputBlock);
-               }
-               else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
-                       MatrixObject filter = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
-                       MatrixObject dout = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-
-                       if(filter.getNumRows() != K || filter.getNumColumns() 
!= C*R*S) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for filter in convolution_backward_data");
-                       if(dout.getNumRows() != N || dout.getNumColumns() != 
K*P*Q) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for dout in conv2d_backward_data: " + 
-                                               dout.getNumRows() + " != " +  N 
+ " || " + dout.getNumColumns() + " != " + K*P*Q);
-                       
-                       MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
-                       
-                       LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), 
getExtendedOpcode(), filter, dout, out, N, C, H, W,
-                                       K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q, _intermediateMemoryBudget);
-               }
-               else if (instOpcode.equalsIgnoreCase("maxpooling") || 
instOpcode.equalsIgnoreCase("avgpooling")) {
-                       MatrixObject image = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
-
-                       if(image.getNumRows() != N || image.getNumColumns() != 
C*H*W) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for image in maxpooling: " + 
-                                               image.getNumRows() + " != " +  
N + " || " + image.getNumColumns() + " != " + C*H*W);
-                       
-                       MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * P * Q);
-                       PoolingType poolType = 
instOpcode.equalsIgnoreCase("maxpooling") ? PoolingType.MAX : PoolingType.AVG;
-                       LibMatrixCuDNN.pooling(ec.getGPUContext(0), 
getExtendedOpcode(), image, out, N, C, H, W,
-                                       K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q, poolType, _intermediateMemoryBudget);
-               }
-               else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || 
instOpcode.equalsIgnoreCase("avgpooling_backward")) {
-                       MatrixObject image = 
getMatrixInputForGPUInstruction(ec, _input1.getName());
-                       MatrixObject dout = getMatrixInputForGPUInstruction(ec, 
_input2.getName());
-                       MatrixObject maxPoolOutput = _input3 != null ? 
getMatrixInputForGPUInstruction(ec, _input3.getName()) : null;
-                       if(dout.getNumRows() != N || dout.getNumColumns() != 
C*P*Q) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for dout in maxpooling_backward");
-                       if(image.getNumRows() != N || image.getNumColumns() != 
C*H*W) 
-                               throw new DMLRuntimeException("Incorrect 
dimensions for image in maxpooling_backward: " + 
-                                               image.getNumRows() + " != " +  
N + " || " + image.getNumColumns() + " != " + K*P*Q);
-                       
-                       MatrixObject out = 
getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W);
-                       PoolingType poolType = 
instOpcode.equalsIgnoreCase("maxpooling_backward") ? PoolingType.MAX : 
PoolingType.AVG;
-                       LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), 
getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W,
-                                       K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q, poolType, _intermediateMemoryBudget);
-               }
-               else {
-                       throw new DMLRuntimeException("Unsupported GPU context 
for " + instOpcode);
-               }
-               
-               // release inputs/outputs
-               ec.releaseMatrixInputForGPUInstruction(_input1.getName());
-               
-               boolean isPool = instOpcode.equalsIgnoreCase("maxpooling") || 
instOpcode.equalsIgnoreCase("avgpooling");
-               boolean isPoolBackward = 
instOpcode.equalsIgnoreCase("maxpooling_backward") || 
instOpcode.equalsIgnoreCase("avgpooling_backward");
-
-               if ( !isPool )
-                       
ec.releaseMatrixInputForGPUInstruction(_input2.getName());
-
-               if (instOpcode.equalsIgnoreCase("conv2d_bias_add") || 
-                       (isPoolBackward && _input3 != null))
-                       
ec.releaseMatrixInputForGPUInstruction(_input3.getName());
-
-               ec.releaseMatrixOutputForGPUInstruction(_output.getName());
-       }
-
-
-       private static int getScalarInput(ExecutionContext ec, 
ArrayList<CPOperand> aL, int index) {
-               return (int) ec.getScalarInput(aL.get(index).getName(),
-                       aL.get(index).getValueType(), 
aL.get(index).isLiteral()).getLongValue();
-       }
-}

Reply via email to