http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java new file mode 100644 index 0000000..4532240 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java @@ -0,0 +1,645 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.cp; + +import java.util.ArrayList; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.DnnParameters; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNN; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; +import org.apache.sysml.runtime.matrix.data.LibMatrixNative; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.util.DnnUtils; +import org.apache.sysml.utils.NativeHelper; + +public class DnnCPInstruction extends UnaryCPInstruction { + private static final Log LOG = LogFactory.getLog(DnnCPInstruction.class.getName()); + private static boolean warnedUnderUtilitization = false; + + private final CPOperand _in2; + private final CPOperand _in3; + private final CPOperand _in4; + private final CPOperand _in5; + private final CPOperand _in6; + private final CPOperand _in7; + private final CPOperand _in8; + private final CPOperand _out2; + private final CPOperand _out3; + private final CPOperand _out4; + private final CPOperand _out5; + private final ArrayList<CPOperand> _input_shape; + private final ArrayList<CPOperand> _filter_shape; + private final ArrayList<CPOperand> _stride; + private final ArrayList<CPOperand> _padding; + private final int _numThreads; + private final double _intermediateMemoryBudget; + + public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, + ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget, String opcode, String istr) { + super(CPType.Dnn, null, in, out, opcode, istr); + _in2 = in2; + _in3 = in3; + _in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null; + _out2 = null; _out3 = null; _out4 = null; _out5 = null; + _stride = stride; + _padding = padding; + _input_shape = input_shape; + _filter_shape = filter_shape; + _numThreads = numThreads; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + + public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) { + this(in, in2, null, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr); + if( !(opcode.equals("bias_add") || opcode.equals("relu_backward") || opcode.equals("bias_multiply") ) ) { + throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode); + } + } + + public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, int numThreads, double intermediateMemoryBudget) { + this(in, in2, in3, out, null, null, null, null, numThreads, intermediateMemoryBudget, opcode, istr); + if( !opcode.equals("channel_sums") ) { + throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode); + } + } + + private DnnCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, + ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) { + this(in, null, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr); + } + + public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, + String istr, ArrayList<CPOperand> stride, + ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) { + this(in, in2, null, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr); + } + + public DnnCPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, + String istr, ArrayList<CPOperand> stride, + ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) { + this(in, in2, in3, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr); + } + + public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, + CPOperand in6, CPOperand in7, CPOperand in8, + CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, + double intermediateMemoryBudget) throws DMLRuntimeException { + super(CPType.Dnn, null, in1, out, opcode, istr); + _in2 = in2; + _in3 = in3; + _in4 = in4; + _in5 = in5; + _in6 = in6; + _in7 = in7; + _in8 = in8; + _out2 = out2; + _out3 = out3; + _out4 = out4; + _out5 = out5; + _stride = null; + _padding = null; + _input_shape = null; + _filter_shape = null; + _numThreads = 0; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + + public static DnnCPInstruction parseInstruction(String str) { + + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling") || + opcode.equalsIgnoreCase("avgpooling")) { + InstructionUtils.checkNumFields(parts, 16); + // stride1, stride2, padding1, padding2 + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k + CPOperand in = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[14]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[2])); + stride.add(new CPOperand(parts[3])); + padding.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + input_shape.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + filter_shape.add(new CPOperand(parts[10])); + filter_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + int k = Integer.parseInt(parts[15]); + + return new DnnCPInstruction(in, out, opcode, str, stride, + padding, input_shape, filter_shape, k, Double.parseDouble(parts[16])); + } + else if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("relu_maxpooling_backward") + || opcode.equalsIgnoreCase("avgpooling_backward") + || opcode.equalsIgnoreCase("conv2d") + || opcode.equalsIgnoreCase("conv2d_backward_filter") + || opcode.equalsIgnoreCase("conv2d_backward_data")) { + InstructionUtils.checkNumFields(parts, 17); + // dout, stride1, stride2, padding1, padding2 + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[15]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[3])); + stride.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + filter_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + int k = Integer.parseInt(parts[16]); + + return new DnnCPInstruction(in, in2, out, opcode, str, stride, + padding, input_shape, filter_shape, k, Double.parseDouble(parts[17])); + } + else if (opcode.equalsIgnoreCase("conv2d_bias_add")) { + InstructionUtils.checkNumFields(parts, 18); + // dout, stride1, stride2, padding1, padding2 + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[16]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[4])); + stride.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + padding.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + input_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + filter_shape.add(new CPOperand(parts[15])); + int k = Integer.parseInt(parts[17]); + + return new DnnCPInstruction(in, in2, in3, out, opcode, str, stride, + padding, input_shape, filter_shape, k, Double.parseDouble(parts[18])); + } + else if (opcode.equalsIgnoreCase("bias_add") || opcode.equals("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) { + InstructionUtils.checkNumFields(parts, 5); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + int k = Integer.parseInt(parts[4]); + return new DnnCPInstruction(in, in2, out, opcode, str, k, Double.parseDouble(parts[5])); + } + else if (opcode.equalsIgnoreCase("channel_sums")) { + InstructionUtils.checkNumFields(parts, 4); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[4]); + return new DnnCPInstruction(in, in2, in3, out, opcode, str, -1, 0); + } + else if (opcode.equalsIgnoreCase("batch_norm2d")) { + InstructionUtils.checkNumFields(parts, 13); + CPOperand in1 = new CPOperand(parts[1]); // image + CPOperand in2 = new CPOperand(parts[2]); // scale + CPOperand in3 = new CPOperand(parts[3]); // bias + CPOperand in4 = new CPOperand(parts[4]); // runningMean + CPOperand in5 = new CPOperand(parts[5]); // runningVar + CPOperand in6 = new CPOperand(parts[6]); // mode + CPOperand in7 = new CPOperand(parts[7]); // epsilon + CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor + CPOperand out = new CPOperand(parts[9]); // ret + CPOperand out2 = new CPOperand(parts[10]); // retRunningMean + CPOperand out3 = new CPOperand(parts[11]); // retRunningVar + CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean + CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance + return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0); + } + else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) { + InstructionUtils.checkNumFields(parts, 9); + CPOperand in1 = new CPOperand(parts[1]); // image + CPOperand in2 = new CPOperand(parts[2]); // dout + CPOperand in3 = new CPOperand(parts[3]); // scale + CPOperand in4 = new CPOperand(parts[4]); // epsilon + CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean + CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance + CPOperand out = new CPOperand(parts[7]); // dX + CPOperand out2 = new CPOperand(parts[8]); // dScale + CPOperand out3 = new CPOperand(parts[9]); // dBias + return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0); + } + else { + throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str); + } + } + + private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) { + return (int) ec.getScalarInput(aL.get(index).getName(), + aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue(); + } + + public void processReluBackwardInstruction(ExecutionContext ec) { + // (X > 0) * dout + MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + MatrixBlock outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), + input.isInSparseFormat() || dout.isInSparseFormat() ); + + if( !input.isEmpty() && !dout.isEmpty() ) { //sparse-safe + outputBlock.allocateBlock(); + LibMatrixDNN.reluBackward(input, dout, outputBlock, _numThreads); + } + + // release inputs/outputs + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); + } + + public void processBiasAddInstruction(ExecutionContext ec) { + MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock bias = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + MatrixBlock outputBlock = null; + + if(bias.getNumColumns() != 1) { + throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns()); + } + + if(input.isEmpty() && bias.isEmpty()) { + outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true); + } + else if(bias.isEmpty()) { + outputBlock = new MatrixBlock(input); + } + else { + // As we always fill the output first with bias + outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), false); + outputBlock.allocateDenseBlock(); + LibMatrixDNN.biasAdd(input, bias, outputBlock, _numThreads); + } + + // release inputs/outputs + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); + } + + public void processBiasMultiplyInstruction(ExecutionContext ec) { + MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock bias = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + MatrixBlock outputBlock = null; + + if(bias.getNumColumns() != 1) { + throw new DMLRuntimeException("Expected the number of columns of bias matrix to be 1, but found " + bias.getNumColumns()); + } + + if(bias.isEmpty()) { + // Anything multiplied by zero is zero + outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), true); + } + else { + // As we always fill the output first with bias + outputBlock = new MatrixBlock(input.getNumRows(), input.getNumColumns(), + input.isInSparseFormat()).allocateBlock(); + LibMatrixDNN.biasMultiply(input, bias, outputBlock, _numThreads); + } + + // release inputs/outputs + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); + } + + public void processChannelSumsInstruction(ExecutionContext ec) { + MatrixBlock input = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + int C = (int) ec.getScalarInput(_in2.getName(), _in2.getValueType(), _in2.isLiteral()).getLongValue(); + int HW = (int) ec.getScalarInput(_in3.getName(), _in3.getValueType(), _in3.isLiteral()).getLongValue(); + if(C*HW != input.getNumColumns()) { + throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns()); + } + MatrixBlock outputBlock = null; + if(input.isEmpty()) { + outputBlock = new MatrixBlock(C, 1, true); + } + else { + outputBlock = new MatrixBlock(C, 1, false).allocateBlock(); + LibMatrixDNN.channelSums(input, outputBlock, C, HW); + } + + // release inputs/outputs + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); + } + + + + public void processBatchNorm2dInstruction(ExecutionContext ec) { + MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock scale = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + MatrixBlock bias = ec.getMatrixInput(_in3.getName(), getExtendedOpcode()); + MatrixBlock runningMean = ec.getMatrixInput(_in4.getName(), getExtendedOpcode()); + MatrixBlock runningVar = ec.getMatrixInput(_in5.getName(), getExtendedOpcode()); + String phase = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getStringValue(); + double epsilon = ec.getScalarInput(_in7.getName(), _in7.getValueType(), _in7.isLiteral()).getDoubleValue(); + double mu = ec.getScalarInput(_in8.getName(), _in8.getValueType(), _in8.isLiteral()).getDoubleValue(); + + MatrixBlock ret = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock(); + MatrixBlock retRunningMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock(); + MatrixBlock retRunningVar = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock(); + MatrixBlock resultSaveMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock(); + MatrixBlock resultSaveInvVariance = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock(); + + LibMatrixDNN.batchNorm2D(image, scale, bias, runningMean, runningVar, phase, epsilon, mu, ret, + retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance); + + // release inputs/outputs + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode()); + ec.setMatrixOutput(output.getName(), ret, getExtendedOpcode()); + ec.setMatrixOutput(_out2.getName(), retRunningMean, getExtendedOpcode()); + ec.setMatrixOutput(_out3.getName(), retRunningVar, getExtendedOpcode()); + ec.setMatrixOutput(_out4.getName(), resultSaveMean, getExtendedOpcode()); + ec.setMatrixOutput(_out5.getName(), resultSaveInvVariance, getExtendedOpcode()); + } + + public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) { + MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + MatrixBlock scale = ec.getMatrixInput(_in3.getName(), getExtendedOpcode()); + double epsilon = ec.getScalarInput(_in4.getName(), _in4.getValueType(), _in4.isLiteral()).getDoubleValue(); + MatrixBlock resultSaveMean = ec.getMatrixInput(_in5.getName(), getExtendedOpcode()); + MatrixBlock resultSaveInvVariance = ec.getMatrixInput(_in6.getName(), getExtendedOpcode()); + + MatrixBlock dX = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock(); + MatrixBlock dScale = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock(); + MatrixBlock dBias = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock(); + + LibMatrixDNN.batchNorm2DBackward(image, dout, scale, epsilon, resultSaveMean, resultSaveInvVariance, dX, dScale, dBias); + + // release inputs/outputs + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in6.getName(), getExtendedOpcode()); + ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode()); + ec.setMatrixOutput(_out2.getName(), dScale, getExtendedOpcode()); + ec.setMatrixOutput(_out3.getName(), dBias, getExtendedOpcode()); + } + + + // Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is true + // This increases the number of native calls. For example:the cases where filter is sparse but input is dense + private static boolean isFilterSparse(MatrixBlock filter) { + long numElems = filter.getNumRows()*filter.getNumColumns(); + // if filter is less than 10 MB in dense format (which handles almost all the cases). + // In fact, using threshold of 1 MB is still sufficient for common CNNs. + if(filter.isInSparseFormat() && numElems < 10e+6) + filter.sparseToDense(); + return filter.isInSparseFormat(); + } + + + @Override + public void processInstruction(ExecutionContext ec) { + + if (instOpcode.equalsIgnoreCase("bias_add")) { + processBiasAddInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("bias_multiply")) { + processBiasMultiplyInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("relu_backward")) { + processReluBackwardInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("channel_sums")) { + processChannelSumsInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("batch_norm2d")) { + processBatchNorm2dInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) { + processBatchNorm2dBackwardInstruction(ec); + return; + } + + // acquire inputs + MatrixBlock outputBlock = null; + MatrixBlock matBlock = instOpcode.equalsIgnoreCase("avgpooling_backward") ? null : ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + int pad_h = getScalarInput(ec, _padding, 0); + int pad_w = getScalarInput(ec, _padding, 1); + int stride_h = getScalarInput(ec, _stride, 0); + int stride_w = getScalarInput(ec, _stride, 1); + + int N = getScalarInput(ec, _input_shape, 0); + int C = getScalarInput(ec, _input_shape, 1); + int H = getScalarInput(ec, _input_shape, 2); + int W = getScalarInput(ec, _input_shape, 3); + + int K = getScalarInput(ec, _filter_shape, 0); + + int R = getScalarInput(ec, _filter_shape, 2); + int S = getScalarInput(ec, _filter_shape, 3); + int P = (int) DnnUtils.getP(H, R, stride_h, pad_h); + int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w); + + DnnParameters params = new DnnParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, _numThreads); + params.enableNative = NativeHelper.isNativeLibraryLoaded(); + if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling") || + instOpcode.equalsIgnoreCase("avgpooling")) { + if(matBlock.isEmpty()) { + outputBlock = new MatrixBlock(N, C*P*Q, true); + } + else { + outputBlock = new MatrixBlock(N, C*P*Q, false).allocateBlock(); + + PoolingType poolType = (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) ? PoolingType.MAX : PoolingType.AVG; + if(instOpcode.equalsIgnoreCase("relu_maxpooling")) + params.minValForMaxPoolOperations = 0; + LibMatrixDNN.pooling(matBlock, outputBlock, params, poolType); + } + } + else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward") || + instOpcode.equalsIgnoreCase("avgpooling_backward")) { + MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + boolean isEmpty = instOpcode.equalsIgnoreCase("avgpooling_backward") ? dout.isEmpty() : (matBlock.isEmpty() || dout.isEmpty()); + if(isEmpty) { + outputBlock = new MatrixBlock(N, C*H*W, true); + } + else { + outputBlock = new MatrixBlock(N, C*H*W, false).allocateBlock(); + PoolingType poolType = (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("relu_maxpooling_backward")) ? PoolingType.MAX : PoolingType.AVG; + boolean performReLUBackward = instOpcode.equalsIgnoreCase("relu_maxpooling_backward"); + if(performReLUBackward) + params.minValForMaxPoolOperations = 0; + LibMatrixDNN.poolingBackward(matBlock, dout, outputBlock, params, performReLUBackward, poolType); + } + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + } + else if (instOpcode.equalsIgnoreCase("conv2d")) { + resetNumThreads(params, C*R*S, P*Q, matBlock.getNonZeros() / (matBlock.getNumRows()*matBlock.getNumColumns())); + MatrixBlock filter = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + if(filter.isEmpty() || matBlock.isEmpty()) { + outputBlock = new MatrixBlock(N, K*P*Q, true); + } + else { + boolean sparse = matBlock.isUltraSparse(false) && params.bias == null + && matBlock.getInMemorySize() < MatrixBlock.estimateSizeDenseInMemory(N, K*P*Q); + outputBlock = new MatrixBlock(N, K*P*Q, sparse).allocateBlock(); + if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat()) + LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); + else + LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); + } + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + } + else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) { + resetNumThreads(params, C*R*S, P*Q, matBlock.getNonZeros() / (matBlock.getNumRows()*matBlock.getNumColumns())); + MatrixBlock filter = ec.getMatrixInput(_in3.getName(), getExtendedOpcode()); + MatrixBlock bias = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + if(bias.getNumRows() != params.K || bias.getNumColumns() != 1) { + throw new DMLRuntimeException("Incorrect shape of bias matrix: [" + bias.getNumRows() + " " + bias.getNumColumns() + "]. " + + "Expected: [" + params.K + ", 1]"); + } + boolean isOutputConvEmpty = filter.isEmpty() || matBlock.isEmpty(); + if(isOutputConvEmpty && bias.isEmpty()) { + // bias_add(empty mb, empty mb) = empty mb + outputBlock = new MatrixBlock(N, K*P*Q, true); + } + else if(isOutputConvEmpty && !bias.isEmpty()) { + // Add bias to empty output block + // bias_add(empty mb, bias) + outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock(); + for(int n = 0; n < params.N; n++) + DnnUtils.fillBias(bias, outputBlock.getDenseBlockValues(), + n, n+1, params.N, params.K, params.P*params.Q); + } + else { + outputBlock = new MatrixBlock(N, K*P*Q, false).allocateBlock(); + if(!bias.isEmpty()) { + // Handle situation where both input and filter are non empty, but bias is empty + params.bias = bias; + } + if(params.enableNative && !isFilterSparse(filter) && !matBlock.isInSparseFormat()) + LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); + else + LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); + } + ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode()); + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + } + else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) { + MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + if(dout.isEmpty() || matBlock.isEmpty()) { + outputBlock = new MatrixBlock(K, C*R*S, true); + } + else { + outputBlock = new MatrixBlock(K, C*R*S, false).allocateBlock(); + if(params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat()) + LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params); + else + LibMatrixDNN.conv2dBackwardFilter(matBlock, dout, outputBlock, params); + } + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + } + else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) { + MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode()); + if(dout.isEmpty() || matBlock.isEmpty()) { + outputBlock = new MatrixBlock(N, C * H * W, true); + } + else { + outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock(); + if(params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat()) + LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params); + else + LibMatrixDNN.conv2dBackwardData(matBlock, dout, outputBlock, params); + } + ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode()); + } + else { + throw new DMLRuntimeException("Unsupported op code " + instOpcode); + } + + // release inputs/outputs + if(!instOpcode.equalsIgnoreCase("avgpooling_backward")) + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode()); + } + + /** + * Reset the number of thread to respect the intermediate CP memory budget + * + * @param params convolution parameters + * @param numRows number of rows of intermediate matrix used per thread + * @param numCols number of rows of intermediate matrix used per thread + * @param sparsity sparsity of intermediate matrix used per thread + */ + private void resetNumThreads(DnnParameters params, int numRows, int numCols, double sparsity) { + if(DMLScript.USE_ACCELERATOR) { + double memBudget1Thread = OptimizerUtils.estimateSizeExactSparsity(numRows, numCols, sparsity); + int limitedDegreeOfParallelism = (int) Math.floor(_intermediateMemoryBudget / memBudget1Thread); + if(params.numThreads > limitedDegreeOfParallelism) { + params.numThreads = limitedDegreeOfParallelism; + if(!warnedUnderUtilitization) + LOG.warn("CPU Under-utilization to respect the intermediate memory budget. To avoid this, please try reducing the mini-batch or forcing gpu execution."); + warnedUnderUtilitization = true; + } + } + } +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java deleted file mode 100644 index 7d8a0fc..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java +++ /dev/null @@ -1,750 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.sysml.runtime.instructions.gpu; - -import java.util.ArrayList; - -import jcuda.Pointer; - -import org.apache.sysml.api.DMLScript; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.functionobjects.SwapIndex; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig; -import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; -import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA; -import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; -import org.apache.sysml.runtime.matrix.operators.ReorgOperator; -import org.apache.sysml.runtime.util.ConvolutionUtils; -import org.apache.sysml.utils.GPUStatistics; - -public class ConvolutionGPUInstruction extends GPUInstruction { - private CPOperand _input1; - private CPOperand _input2; - private CPOperand _input3; - private CPOperand _input4; - private CPOperand _input5; - private CPOperand _input6; - private CPOperand _input7; - private CPOperand _input8; - private CPOperand _output; - private CPOperand _output2; - private CPOperand _output3; - private CPOperand _output4; - private CPOperand _output5; - private ArrayList<CPOperand> _input_shape; - private ArrayList<CPOperand> _filter_shape; - private ArrayList<CPOperand> _stride = new ArrayList<>(); - private ArrayList<CPOperand> _padding = new ArrayList<>(); - private double _intermediateMemoryBudget = 0; - - public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) { - super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) { - throw new DMLRuntimeException( - "Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " - + opcode); - } - _input1 = in1; - _input2 = in2; - _gputype = GPUINSTRUCTION_TYPE.Convolution; - _output = out; - _intermediateMemoryBudget = intermediateMemoryBudget; - } - public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, - CPOperand out, CPOperand out2, String opcode, String istr, - double intermediateMemoryBudget) throws DMLRuntimeException { - super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - _input1 = in1; - _input2 = in2; - _input3 = in3; - _input4 = in4; - _input5 = in5; - _input6 = in6; - _gputype = GPUINSTRUCTION_TYPE.Convolution; - _output = out; - _output2 = out2; - _intermediateMemoryBudget = intermediateMemoryBudget; - } - - public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, - CPOperand in6, CPOperand in7, CPOperand in8, - CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, - double intermediateMemoryBudget) throws DMLRuntimeException { - super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - _input1 = in1; - _input2 = in2; - _input3 = in3; - _input4 = in4; - _input5 = in5; - _input6 = in6; - _input7 = in7; - _input8 = in8; - _gputype = GPUINSTRUCTION_TYPE.Convolution; - _output = out; - _output2 = out2; - _output3 = out3; - _output4 = out4; - _output5 = out5; - _intermediateMemoryBudget = intermediateMemoryBudget; - } - - public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, - double intermediateMemoryBudget) throws DMLRuntimeException { - super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - if( !opcode.equals("channel_sums") ) { - throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode); - } - _input1 = in1; - _input2 = in2; - _input3 = in3; - _gputype = GPUINSTRUCTION_TYPE.Convolution; - _output = out; - _intermediateMemoryBudget = intermediateMemoryBudget; - } - - public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, - String istr, ArrayList<CPOperand> stride, - ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, - ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget) - { - this(in1, in2, out, opcode, istr, stride, padding, input_shape, filter_shape, intermediateMemoryBudget); - _input3 = in3; - } - - public ConvolutionGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, - String istr, ArrayList<CPOperand> stride, - ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, - ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget) - { - super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); - _gputype = GPUINSTRUCTION_TYPE.Convolution; - - _input1 = in1; - _input2 = in2; - _output = out; - _stride = stride; - _padding = padding; - _input_shape = input_shape; - _filter_shape = filter_shape; - _intermediateMemoryBudget = intermediateMemoryBudget; - } - - public static ConvolutionGPUInstruction parseInstruction(String str) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - String opcode = parts[0]; - if( ( opcode.equalsIgnoreCase("conv2d") - || opcode.equalsIgnoreCase("conv2d_backward_filter") - || opcode.equalsIgnoreCase("conv2d_backward_data")) ) { - InstructionUtils.checkNumFields(parts, 16); - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand out = new CPOperand(parts[15]); - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new ArrayList<>(); - stride.add(new CPOperand(parts[3])); - stride.add(new CPOperand(parts[4])); - padding.add(new CPOperand(parts[5])); - padding.add(new CPOperand(parts[6])); - input_shape.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - input_shape.add(new CPOperand(parts[10])); - filter_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - filter_shape.add(new CPOperand(parts[14])); - - return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, stride, - padding, input_shape, filter_shape, Double.parseDouble(parts[16])); - } - else if( opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") ) { - boolean withMaxPoolOut = false; - if(parts.length == 18) { - withMaxPoolOut = true; - } - else - InstructionUtils.checkNumFields(parts, 16); - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null; - CPOperand out = withMaxPoolOut ? new CPOperand(parts[16]) : new CPOperand(parts[15]); - double memBudget = withMaxPoolOut ? Double.parseDouble(parts[17]) : Double.parseDouble(parts[16]); - - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new ArrayList<>(); - stride.add(new CPOperand(parts[3])); - stride.add(new CPOperand(parts[4])); - padding.add(new CPOperand(parts[5])); - padding.add(new CPOperand(parts[6])); - input_shape.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - input_shape.add(new CPOperand(parts[10])); - filter_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - filter_shape.add(new CPOperand(parts[14])); - - return new ConvolutionGPUInstruction(in1, in2, in3, out, opcode, str, stride, - padding, input_shape, filter_shape, memBudget); - } - else if (opcode.equalsIgnoreCase("conv2d_bias_add")) { - InstructionUtils.checkNumFields(parts, 17); - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = new CPOperand(parts[3]); - CPOperand out = new CPOperand(parts[16]); - - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new ArrayList<>(); - stride.add(new CPOperand(parts[4])); - stride.add(new CPOperand(parts[5])); - padding.add(new CPOperand(parts[6])); - padding.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - input_shape.add(new CPOperand(parts[10])); - input_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - filter_shape.add(new CPOperand(parts[14])); - filter_shape.add(new CPOperand(parts[15])); - - return new ConvolutionGPUInstruction(in1, in2, in3, out, opcode, str, stride, - padding, input_shape, filter_shape, Double.parseDouble(parts[17])); - } - else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) { - InstructionUtils.checkNumFields(parts, 15); - CPOperand in1 = new CPOperand(parts[1]); - CPOperand out = new CPOperand(parts[14]); - - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new ArrayList<>(); - stride.add(new CPOperand(parts[2])); - stride.add(new CPOperand(parts[3])); - padding.add(new CPOperand(parts[4])); - padding.add(new CPOperand(parts[5])); - input_shape.add(new CPOperand(parts[6])); - input_shape.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - filter_shape.add(new CPOperand(parts[10])); - filter_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - - return new ConvolutionGPUInstruction(in1, null, out, opcode, str, stride, - padding, input_shape, filter_shape, Double.parseDouble(parts[15])); - } - else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) { - InstructionUtils.checkNumFields(parts, 4); - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand out = new CPOperand(parts[3]); - return new ConvolutionGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4])); - } - else if (opcode.equalsIgnoreCase("channel_sums")) { - InstructionUtils.checkNumFields(parts, 4); - CPOperand in = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = new CPOperand(parts[3]); - CPOperand out = new CPOperand(parts[4]); - return new ConvolutionGPUInstruction(in, in2, in3, out, opcode, str, 0); - } - else if (opcode.equalsIgnoreCase("lstm")) { - InstructionUtils.checkNumFields(parts, 8); - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = new CPOperand(parts[3]); - CPOperand in4 = new CPOperand(parts[4]); - CPOperand in5 = new CPOperand(parts[5]); - CPOperand in6 = new CPOperand(parts[6]); - CPOperand out = new CPOperand(parts[7]); - CPOperand out2 = new CPOperand(parts[8]); - return new ConvolutionGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0); - } - else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) { - InstructionUtils.checkNumFields(parts, 13); - CPOperand in1 = new CPOperand(parts[1]); // image - CPOperand in2 = new CPOperand(parts[2]); // scale - CPOperand in3 = new CPOperand(parts[3]); // bias - CPOperand in4 = new CPOperand(parts[4]); // runningMean - CPOperand in5 = new CPOperand(parts[5]); // runningVar - CPOperand in6 = new CPOperand(parts[6]); // mode - CPOperand in7 = new CPOperand(parts[7]); // epsilon - CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor - CPOperand out = new CPOperand(parts[9]); // ret - CPOperand out2 = new CPOperand(parts[10]); // retRunningMean - CPOperand out3 = new CPOperand(parts[11]); // retRunningVar - CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean - CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance - return new ConvolutionGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0); - } - else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) { - InstructionUtils.checkNumFields(parts, 9); - CPOperand in1 = new CPOperand(parts[1]); // image - CPOperand in2 = new CPOperand(parts[2]); // dout - CPOperand in3 = new CPOperand(parts[3]); // scale - CPOperand in4 = new CPOperand(parts[4]); // epsilon - CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean - CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance - CPOperand out = new CPOperand(parts[7]); // dX - CPOperand out2 = new CPOperand(parts[8]); // dScale - CPOperand out3 = new CPOperand(parts[9]); // dBias - return new ConvolutionGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0); - } - else { - throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionGPUInstruction: " + str); - } - } - - public void processBiasInstruction(String instOpcode, ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns()); - - if(instOpcode.equalsIgnoreCase("bias_add")) - LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out); - else if(instOpcode.equalsIgnoreCase("bias_multiply")) - LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out); - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - public void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); - MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); - MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); - - String phase = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getStringValue(); - double epsilon = ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue(); - - MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); - - if(phase.equalsIgnoreCase("train")) { - double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue(); - MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); - MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); - MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); - MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); - LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), - image, scale, bias, runningMean, runningVar, ret, - retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance); - ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output4.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output5.getName()); - } - else if(phase.equalsIgnoreCase("test")) { - LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), - image, scale, bias, runningMean, runningVar, ret, epsilon); - ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode()); - ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode()); - ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode()); - ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode()); - } - else { - throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase); - } - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixInputForGPUInstruction(_input4.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input3.getName()); - double epsilon = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue(); - MatrixObject resultSaveMean = getMatrixInputForGPUInstruction(ec, _input5.getName()); - MatrixObject resultSaveInvVariance = getMatrixInputForGPUInstruction(ec, _input6.getName()); - - MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); - MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scale.getNumRows(), scale.getNumColumns()); - MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scale.getNumRows(), scale.getNumColumns()); - - LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), image, - dout, scale, dX, dScale, dBias, - epsilon, resultSaveMean, resultSaveInvVariance); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixInputForGPUInstruction(_input6.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); - } - - // (X > 0) * dout - public void processReLUBackwardInstruction(ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns()); - - LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), input, dout, out); - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - public void processChannelSumsInstruction(ExecutionContext ec) { - GPUStatistics.incrementNoOfExecutedGPUInst(); - MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); - int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue(); - int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue(); - if(C*HW != input.getNumColumns()) { - throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns()); - } - MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 1); - - LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - private static int toInt(long num) throws DMLRuntimeException { - if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) { - throw new DMLRuntimeException("GPU : Exceeded supported size " + num); - } - return (int)num; - } - -// private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException { -// GPUContext gCtx = ec.getGPUContext(0); -// String instructionName = getExtendedOpcode(); -// long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns(); -// Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType); -// jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice); -// // LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX); -// return tX; -// } - - private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { - GPUStatistics.incrementNoOfExecutedGPUInst(); - GPUContext gCtx = ec.getGPUContext(0); - String instructionName = getExtendedOpcode(); - - MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); - int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) - Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); - - MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); - long numRowsW = W.getNumRows(); - int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures - Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); - Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); - Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); - LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", - ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), - sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - - - MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); - Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); - int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) - long numColsX = X.getNumColumns(); - int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength - Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); - LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", - ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), - xPointer, cudnnInput, N, D, T*D, N*T*D); - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - - Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); - boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); - - // LibMatrixCuDNN.lstm(ec, gCtx, instructionName, - // cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); - // String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input - // String dxName, String dwName, String dbName, String dhxName, String dcxName, // output - String dxName = _output.getName(); - String dwName = _output2.getName(); - String dbName = _output3.getName(); - String dhxName = _output4.getName(); - String dcxName = _output5.getName(); - String doutName = _input7.getName(); - String dcyName = _input8.getName(); - LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, - cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, // input - dxName, dwName, dbName, dhxName, dcxName, // output - return_sequences, N, M, D, T); - gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE); - gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input4.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - } - - private void processLstmInstruction(ExecutionContext ec) throws DMLRuntimeException { - // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M - // input X:(N, T*D), ==> (T, D, N) - // weight W:(D+M+2, 4M) - // previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N) - // out: (N, T*M) or (N, M) ==> (T, M, N) - GPUStatistics.incrementNoOfExecutedGPUInst(); - GPUContext gCtx = ec.getGPUContext(0); - String instructionName = getExtendedOpcode(); - - MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); - int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) - Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); - - MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); - long numRowsW = W.getNumRows(); - int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... numFeatures - Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); - Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); - Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); - LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", - ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), - sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - - boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); - - // Beause the matrices are released immediately, the output for transpose need not be taken into account - MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); - Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); - int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) - long numColsX = X.getNumColumns(); - int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength - Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); - LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", - ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), - xPointer, cudnnInput, N, D, T*D, N*T*D); - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - - Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); - - LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); - gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE); - gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE); - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input4.getName()); - ec.releaseMatrixInputForGPUInstruction(_input5.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - @Override - public void processInstruction(ExecutionContext ec) { - if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) { - processBiasInstruction(instOpcode, ec); - return; - } - else if (instOpcode.equalsIgnoreCase("relu_backward")) { - processReLUBackwardInstruction(ec); - return; - } - else if (instOpcode.equalsIgnoreCase("channel_sums")) { - processChannelSumsInstruction(ec); - return; - } - else if (instOpcode.equalsIgnoreCase("lstm")) { - processLstmInstruction(ec); - return; - } - else if (instOpcode.equalsIgnoreCase("lstm_backward")) { - processLstmBackwardInstruction(ec); - return; - } - else if (instOpcode.equalsIgnoreCase("batch_norm2d")) { - processBatchNorm2dInstruction(ec); - return; - } - else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) { - processBatchNorm2dBackwardInstruction(ec); - return; - } - - GPUStatistics.incrementNoOfExecutedGPUInst(); - - int pad_h = getScalarInput(ec, _padding, 0); - int pad_w = getScalarInput(ec, _padding, 1); - int stride_h = getScalarInput(ec, _stride, 0); - int stride_w = getScalarInput(ec, _stride, 1); - - int N = getScalarInput(ec, _input_shape, 0); - int C = getScalarInput(ec, _input_shape, 1); - int H = getScalarInput(ec, _input_shape, 2); - int W = getScalarInput(ec, _input_shape, 3); - - int K = getScalarInput(ec, _filter_shape, 0); - - int R = getScalarInput(ec, _filter_shape, 2); - int S = getScalarInput(ec, _filter_shape, 3); - - int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h); - int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w); - - if (instOpcode.equalsIgnoreCase("conv2d")) { - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input2.getName()); - - if(image.getNumRows() != N || image.getNumColumns() != C*H*W) - throw new DMLRuntimeException("Incorrect dimensions for image in conv2d"); - if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) - throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d"); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q); - - LibMatrixCuDNN.conv2d(ec.getGPUContext(0), getExtendedOpcode(), image, filter, out, N, C, H, W, - K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); - } - else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) { - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input3.getName()); - - if(image.getNumRows() != N || image.getNumColumns() != C*H*W) - throw new DMLRuntimeException("Incorrect dimensions for image in conv2d"); - if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) - throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d"); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q); - - LibMatrixCuDNN.conv2dBiasAdd(ec.getGPUContext(0), getExtendedOpcode(), image, bias, filter, out, N, C, H, W, - K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); - } - else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) { - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - - if(image.getNumRows() != N || image.getNumColumns() != C*H*W) - throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter"); - if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q) - throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: " + - dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), K, C * R * S); - - LibMatrixCuDNN.conv2dBackwardFilter(ec.getGPUContext(0), getExtendedOpcode(), image, dout, out, N, C, H, W, - K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); - // TODO: For now always copy the device data to host - // ec.gpuCtx.copyDeviceToHost(outputBlock); - } - else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) { - MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - - if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) - throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data"); - if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q) - throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: " + - dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W); - - LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), getExtendedOpcode(), filter, dout, out, N, C, H, W, - K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); - } - else if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling")) { - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - - if(image.getNumRows() != N || image.getNumColumns() != C*H*W) - throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " + - image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * P * Q); - PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling") ? PoolingType.MAX : PoolingType.AVG; - LibMatrixCuDNN.pooling(ec.getGPUContext(0), getExtendedOpcode(), image, out, N, C, H, W, - K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget); - } - else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward")) { - MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); - MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); - MatrixObject maxPoolOutput = _input3 != null ? getMatrixInputForGPUInstruction(ec, _input3.getName()) : null; - if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q) - throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward"); - if(image.getNumRows() != N || image.getNumColumns() != C*H*W) - throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " + - image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + K*P*Q); - - MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W); - PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling_backward") ? PoolingType.MAX : PoolingType.AVG; - LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W, - K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget); - } - else { - throw new DMLRuntimeException("Unsupported GPU context for " + instOpcode); - } - - // release inputs/outputs - ec.releaseMatrixInputForGPUInstruction(_input1.getName()); - - boolean isPool = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling"); - boolean isPoolBackward = instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward"); - - if ( !isPool ) - ec.releaseMatrixInputForGPUInstruction(_input2.getName()); - - if (instOpcode.equalsIgnoreCase("conv2d_bias_add") || - (isPoolBackward && _input3 != null)) - ec.releaseMatrixInputForGPUInstruction(_input3.getName()); - - ec.releaseMatrixOutputForGPUInstruction(_output.getName()); - } - - - private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) { - return (int) ec.getScalarInput(aL.get(index).getName(), - aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue(); - } -}
