http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java new file mode 100644 index 0000000..709be6c --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/DnnGPUInstruction.java @@ -0,0 +1,750 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.sysml.runtime.instructions.gpu; + +import java.util.ArrayList; + +import jcuda.Pointer; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.functionobjects.SwapIndex; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.gpu.context.ExecutionConfig; +import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; +import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA; +import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; +import org.apache.sysml.runtime.matrix.operators.ReorgOperator; +import org.apache.sysml.runtime.util.DnnUtils; +import org.apache.sysml.utils.GPUStatistics; + +public class DnnGPUInstruction extends GPUInstruction { + private CPOperand _input1; + private CPOperand _input2; + private CPOperand _input3; + private CPOperand _input4; + private CPOperand _input5; + private CPOperand _input6; + private CPOperand _input7; + private CPOperand _input8; + private CPOperand _output; + private CPOperand _output2; + private CPOperand _output3; + private CPOperand _output4; + private CPOperand _output5; + private ArrayList<CPOperand> _input_shape; + private ArrayList<CPOperand> _filter_shape; + private ArrayList<CPOperand> _stride = new ArrayList<>(); + private ArrayList<CPOperand> _padding = new ArrayList<>(); + private double _intermediateMemoryBudget = 0; + + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, 
istr); + if (!(opcode.equals("bias_add") || opcode.equals("bias_multiply") || opcode.equals("relu_backward"))) { + throw new DMLRuntimeException( + "Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + + opcode); + } + _input1 = in1; + _input2 = in2; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, + CPOperand out, CPOperand out2, String opcode, String istr, + double intermediateMemoryBudget) throws DMLRuntimeException { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + _input1 = in1; + _input2 = in2; + _input3 = in3; + _input4 = in4; + _input5 = in5; + _input6 = in6; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _output2 = out2; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, + CPOperand in6, CPOperand in7, CPOperand in8, + CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, + double intermediateMemoryBudget) throws DMLRuntimeException { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + _input1 = in1; + _input2 = in2; + _input3 = in3; + _input4 = in4; + _input5 = in5; + _input6 = in6; + _input7 = in7; + _input8 = in8; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _output2 = out2; + _output3 = out3; + _output4 = out4; + _output5 = out5; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, + double intermediateMemoryBudget) throws DMLRuntimeException { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + if( !opcode.equals("channel_sums") ) { + 
throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode); + } + _input1 = in1; + _input2 = in2; + _input3 = in3; + _gputype = GPUINSTRUCTION_TYPE.Dnn; + _output = out; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, + String istr, ArrayList<CPOperand> stride, + ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget) + { + this(in1, in2, out, opcode, istr, stride, padding, input_shape, filter_shape, intermediateMemoryBudget); + _input3 = in3; + } + + public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, + String istr, ArrayList<CPOperand> stride, + ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget) + { + super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr); + _gputype = GPUINSTRUCTION_TYPE.Dnn; + + _input1 = in1; + _input2 = in2; + _output = out; + _stride = stride; + _padding = padding; + _input_shape = input_shape; + _filter_shape = filter_shape; + _intermediateMemoryBudget = intermediateMemoryBudget; + } + + public static DnnGPUInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + if( ( opcode.equalsIgnoreCase("conv2d") + || opcode.equalsIgnoreCase("conv2d_backward_filter") + || opcode.equalsIgnoreCase("conv2d_backward_data")) ) { + InstructionUtils.checkNumFields(parts, 16); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[15]); + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = 
new ArrayList<>(); + stride.add(new CPOperand(parts[3])); + stride.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + filter_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + + return new DnnGPUInstruction(in1, in2, out, opcode, str, stride, + padding, input_shape, filter_shape, Double.parseDouble(parts[16])); + } + else if( opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward") ) { + boolean withMaxPoolOut = false; + if(parts.length == 18) { + withMaxPoolOut = true; + } + else + InstructionUtils.checkNumFields(parts, 16); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null; + CPOperand out = withMaxPoolOut ? new CPOperand(parts[16]) : new CPOperand(parts[15]); + double memBudget = withMaxPoolOut ? 
Double.parseDouble(parts[17]) : Double.parseDouble(parts[16]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[3])); + stride.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + filter_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + + return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride, + padding, input_shape, filter_shape, memBudget); + } + else if (opcode.equalsIgnoreCase("conv2d_bias_add")) { + InstructionUtils.checkNumFields(parts, 17); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[16]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[4])); + stride.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + padding.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + input_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + filter_shape.add(new CPOperand(parts[15])); + + return new DnnGPUInstruction(in1, in2, in3, out, 
opcode, str, stride, + padding, input_shape, filter_shape, Double.parseDouble(parts[17])); + } + else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) { + InstructionUtils.checkNumFields(parts, 15); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand out = new CPOperand(parts[14]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[2])); + stride.add(new CPOperand(parts[3])); + padding.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + input_shape.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + filter_shape.add(new CPOperand(parts[10])); + filter_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + + return new DnnGPUInstruction(in1, null, out, opcode, str, stride, + padding, input_shape, filter_shape, Double.parseDouble(parts[15])); + } + else if( opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply") ) { + InstructionUtils.checkNumFields(parts, 4); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4])); + } + else if (opcode.equalsIgnoreCase("channel_sums")) { + InstructionUtils.checkNumFields(parts, 4); + CPOperand in = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand out = new CPOperand(parts[4]); + return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0); + } + else if (opcode.equalsIgnoreCase("lstm")) { + 
InstructionUtils.checkNumFields(parts, 8); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand in5 = new CPOperand(parts[5]); + CPOperand in6 = new CPOperand(parts[6]); + CPOperand out = new CPOperand(parts[7]); + CPOperand out2 = new CPOperand(parts[8]); + return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0); + } + else if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) { + InstructionUtils.checkNumFields(parts, 13); + CPOperand in1 = new CPOperand(parts[1]); // image + CPOperand in2 = new CPOperand(parts[2]); // scale + CPOperand in3 = new CPOperand(parts[3]); // bias + CPOperand in4 = new CPOperand(parts[4]); // runningMean + CPOperand in5 = new CPOperand(parts[5]); // runningVar + CPOperand in6 = new CPOperand(parts[6]); // mode + CPOperand in7 = new CPOperand(parts[7]); // epsilon + CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor + CPOperand out = new CPOperand(parts[9]); // ret + CPOperand out2 = new CPOperand(parts[10]); // retRunningMean + CPOperand out3 = new CPOperand(parts[11]); // retRunningVar + CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean + CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance + return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0); + } + else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) { + InstructionUtils.checkNumFields(parts, 9); + CPOperand in1 = new CPOperand(parts[1]); // image + CPOperand in2 = new CPOperand(parts[2]); // dout + CPOperand in3 = new CPOperand(parts[3]); // scale + CPOperand in4 = new CPOperand(parts[4]); // epsilon + CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean + CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance + CPOperand out = new CPOperand(parts[7]); // dX + 
CPOperand out2 = new CPOperand(parts[8]); // dScale + CPOperand out3 = new CPOperand(parts[9]); // dBias + return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0); + } + else { + throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str); + } + } + + public void processBiasInstruction(String instOpcode, ExecutionContext ec) { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns()); + + if(instOpcode.equalsIgnoreCase("bias_add")) + LibMatrixCUDA.biasAdd(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out); + else if(instOpcode.equalsIgnoreCase("bias_multiply")) + LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), getExtendedOpcode(), input, bias, out); + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + + public void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); + MatrixObject runningMean = getMatrixInputForGPUInstruction(ec, _input4.getName()); + MatrixObject runningVar = getMatrixInputForGPUInstruction(ec, _input5.getName()); + + String phase = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getStringValue(); + double epsilon = ec.getScalarInput(_input7.getName(), _input7.getValueType(), 
_input7.isLiteral()).getDoubleValue(); + + MatrixObject ret = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); + + if(phase.equalsIgnoreCase("train")) { + double exponentialAverageFactor = 1-ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue(); + MatrixObject retRunningMean = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); + MatrixObject retRunningVar = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); + MatrixObject resultSaveMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), runningMean.getNumRows(), runningMean.getNumColumns()); + MatrixObject resultSaveInvVariance = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), runningVar.getNumRows(), runningVar.getNumColumns()); + LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), + image, scale, bias, runningMean, runningVar, ret, + retRunningMean, retRunningVar, epsilon, exponentialAverageFactor, resultSaveMean, resultSaveInvVariance); + ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output4.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output5.getName()); + } + else if(phase.equalsIgnoreCase("test")) { + LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), + image, scale, bias, runningMean, runningVar, ret, epsilon); + ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode()); + ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode()); + ec.setMatrixOutput(_output4.getName(), 
new MatrixBlock((int)runningMean.getNumRows(), (int)runningMean.getNumColumns(), true), getExtendedOpcode()); + ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)runningVar.getNumRows(), (int)runningVar.getNumColumns(), true), getExtendedOpcode()); + } + else { + throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase); + } + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + ec.releaseMatrixInputForGPUInstruction(_input4.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + + public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject scale = getMatrixInputForGPUInstruction(ec, _input3.getName()); + double epsilon = ec.getScalarInput(_input4.getName(), _input4.getValueType(), _input4.isLiteral()).getDoubleValue(); + MatrixObject resultSaveMean = getMatrixInputForGPUInstruction(ec, _input5.getName()); + MatrixObject resultSaveInvVariance = getMatrixInputForGPUInstruction(ec, _input6.getName()); + + MatrixObject dX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), image.getNumRows(), image.getNumColumns()); + MatrixObject dScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scale.getNumRows(), scale.getNumColumns()); + MatrixObject dBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scale.getNumRows(), scale.getNumColumns()); + + LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), image, + dout, scale, dX, dScale, dBias, + epsilon, 
resultSaveMean, resultSaveInvVariance); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + ec.releaseMatrixInputForGPUInstruction(_input6.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output3.getName()); + } + + // (X > 0) * dout + public void processReLUBackwardInstruction(ExecutionContext ec) { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), input.getNumRows(), input.getNumColumns()); + + LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), input, dout, out); + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + + public void processChannelSumsInstruction(ExecutionContext ec) { + GPUStatistics.incrementNoOfExecutedGPUInst(); + MatrixObject input = getMatrixInputForGPUInstruction(ec, _input1.getName()); + int C = (int) ec.getScalarInput(_input2.getName(), _input2.getValueType(), _input2.isLiteral()).getLongValue(); + int HW = (int) ec.getScalarInput(_input3.getName(), _input3.getValueType(), _input3.isLiteral()).getLongValue(); + if(C*HW != input.getNumColumns()) { + throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns()); + } + MatrixObject outputBlock = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), C, 
1); + + LibMatrixCUDA.channelSums(ec.getGPUContext(0), getExtendedOpcode(), input, outputBlock, C, HW); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + + private static int toInt(long num) throws DMLRuntimeException { + if(num >= Integer.MAX_VALUE || num <= Integer.MIN_VALUE) { + throw new DMLRuntimeException("GPU : Exceeded supported size " + num); + } + return (int)num; + } + +// private Pointer transpose(ExecutionContext ec, MatrixObject X) throws DMLRuntimeException { +// GPUContext gCtx = ec.getGPUContext(0); +// String instructionName = getExtendedOpcode(); +// long numRowsX = X.getNumRows(); long numColsX = X.getNumColumns(); +// Pointer tX = gCtx.allocate(instructionName, numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType); +// jcuda.runtime.JCuda.cudaMemcpy(tX, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), numRowsX*numColsX*LibMatrixCUDA.sizeOfDataType, jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice); +// // LibMatrixCUDA.denseTranspose(ec, gCtx, instructionName, LibMatrixCUDA.getDensePointer(gCtx, X, instructionName), tX, numRowsX, numColsX); +// return tX; +// } + + private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + GPUContext gCtx = ec.getGPUContext(0); + String instructionName = getExtendedOpcode(); + + MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); + int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) + Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); + + MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); + long numRowsW = W.getNumRows(); + int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... 
numFeatures + Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); + Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); + Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", + ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), + sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + + + MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); + Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); + int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) + long numColsX = X.getNumColumns(); + int T = toInt(numColsX/ D); // since X:(N, T*D) ... seqLength + Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", + ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), + xPointer, cudnnInput, N, D, T*D, N*T*D); + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + + Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); + boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); + + // LibMatrixCuDNN.lstm(ec, gCtx, instructionName, + // cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); + // String xName, Pointer hx, Pointer cx, Pointer wPointer, String doutName, String dcyName, // input + // String dxName, String dwName, String dbName, String dhxName, String dcxName, // output + String dxName = _output.getName(); + String dwName = 
_output2.getName(); + String dbName = _output3.getName(); + String dhxName = _output4.getName(); + String dcxName = _output5.getName(); + String doutName = _input7.getName(); + String dcyName = _input8.getName(); + LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, + cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, // input + dxName, dwName, dbName, dhxName, dcxName, // output + return_sequences, N, M, D, T); + gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input4.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + } + + private void processLstmInstruction(ExecutionContext ec) throws DMLRuntimeException { + // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M + // input X:(N, T*D), ==> (T, D, N) + // weight W:(D+M+2, 4M) + // previous output out0 (also represented by hx) and cell state c0 (also represented by cx): (N, M) ==> (1, M, N) + // out: (N, T*M) or (N, M) ==> (T, M, N) + GPUStatistics.incrementNoOfExecutedGPUInst(); + GPUContext gCtx = ec.getGPUContext(0); + String instructionName = getExtendedOpcode(); + + MatrixObject out0 = getMatrixInputForGPUInstruction(ec, _input4.getName()); + int M = toInt(out0.getNumColumns()); // hiddenSize .. since out0: (N, M) + Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName); + + MatrixObject W = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input3.getName()); + long numRowsW = W.getNumRows(); + int D = toInt(numRowsW) - M; // since W:(D+M, 4M) ... 
numFeatures + Pointer sysmlWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D+M, 4*M); + Pointer sysmlBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4*M); + Pointer cudnnWPointer = gCtx.allocate(instructionName, (D+M+2)*(4*M)*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", + ExecutionConfig.getConfigForSimpleVectorOperations((D+M+2)*(4*M)), + sysmlWPointer, sysmlBiasPointer, cudnnWPointer, D, M); + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + + boolean return_sequences = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getBooleanValue(); + + // Because the matrices are released immediately, the output for transpose need not be taken into account + MatrixObject X = getMatrixInputForGPUInstruction(ec, _input1.getName()); + Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName); + int N = toInt(X.getNumRows()); // batchSize .. since X:(N, T*D) + long numColsX = X.getNumColumns(); + int T = toInt(numColsX/ D); // since X:(N, T*D) ... 
seqLength + Pointer cudnnInput = gCtx.allocate(instructionName, (N*T*D)*LibMatrixCUDA.sizeOfDataType); + LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", + ExecutionConfig.getConfigForSimpleVectorOperations(N*T*D), + xPointer, cudnnInput, N, D, T*D, N*T*D); + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + + Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, getMatrixInputForGPUInstruction(ec, _input5.getName()), instructionName); + + LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, _output.getName(), _output2.getName(), N, M, D, T); + gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE); + gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE); + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input4.getName()); + ec.releaseMatrixInputForGPUInstruction(_input5.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output2.getName()); + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + + @Override + public void processInstruction(ExecutionContext ec) { + if (instOpcode.equalsIgnoreCase("bias_add") || instOpcode.equalsIgnoreCase("bias_multiply")) { + processBiasInstruction(instOpcode, ec); + return; + } + else if (instOpcode.equalsIgnoreCase("relu_backward")) { + processReLUBackwardInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("channel_sums")) { + processChannelSumsInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("lstm")) { + processLstmInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("lstm_backward")) { + processLstmBackwardInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("batch_norm2d")) { + processBatchNorm2dInstruction(ec); + return; + } + else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) { + processBatchNorm2dBackwardInstruction(ec); + return; + } + + 
GPUStatistics.incrementNoOfExecutedGPUInst(); + + int pad_h = getScalarInput(ec, _padding, 0); + int pad_w = getScalarInput(ec, _padding, 1); + int stride_h = getScalarInput(ec, _stride, 0); + int stride_w = getScalarInput(ec, _stride, 1); + + int N = getScalarInput(ec, _input_shape, 0); + int C = getScalarInput(ec, _input_shape, 1); + int H = getScalarInput(ec, _input_shape, 2); + int W = getScalarInput(ec, _input_shape, 3); + + int K = getScalarInput(ec, _filter_shape, 0); + + int R = getScalarInput(ec, _filter_shape, 2); + int S = getScalarInput(ec, _filter_shape, 3); + + int P = (int) DnnUtils.getP(H, R, stride_h, pad_h); + int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w); + + if (instOpcode.equalsIgnoreCase("conv2d")) { + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input2.getName()); + + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) + throw new DMLRuntimeException("Incorrect dimensions for image in conv2d"); + if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) + throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d"); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q); + + LibMatrixCuDNN.conv2d(ec.getGPUContext(0), getExtendedOpcode(), image, filter, out, N, C, H, W, + K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); + } + else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) { + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject bias = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input3.getName()); + + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) + throw new DMLRuntimeException("Incorrect dimensions for image in conv2d"); + if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) + throw new 
DMLRuntimeException("Incorrect dimensions for filter in conv2d"); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, K * P * Q); + + LibMatrixCuDNN.conv2dBiasAdd(ec.getGPUContext(0), getExtendedOpcode(), image, bias, filter, out, N, C, H, W, + K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); + } + else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) { + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); + + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) + throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter"); + if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q) + throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: " + + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + K*P*Q); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), K, C * R * S); + + LibMatrixCuDNN.conv2dBackwardFilter(ec.getGPUContext(0), getExtendedOpcode(), image, dout, out, N, C, H, W, + K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); + // TODO: For now always copy the device data to host + // ec.gpuCtx.copyDeviceToHost(outputBlock); + } + else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) { + MatrixObject filter = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); + + if(filter.getNumRows() != K || filter.getNumColumns() != C*R*S) + throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data"); + if(dout.getNumRows() != N || dout.getNumColumns() != K*P*Q) + throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: " + + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + 
K*P*Q); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W); + + LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), getExtendedOpcode(), filter, dout, out, N, C, H, W, + K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, _intermediateMemoryBudget); + } + else if (instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling")) { + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) + throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " + + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + C*H*W); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * P * Q); + PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling") ? PoolingType.MAX : PoolingType.AVG; + LibMatrixCuDNN.pooling(ec.getGPUContext(0), getExtendedOpcode(), image, out, N, C, H, W, + K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget); + } + else if (instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward")) { + MatrixObject image = getMatrixInputForGPUInstruction(ec, _input1.getName()); + MatrixObject dout = getMatrixInputForGPUInstruction(ec, _input2.getName()); + MatrixObject maxPoolOutput = _input3 != null ? 
getMatrixInputForGPUInstruction(ec, _input3.getName()) : null; + if(dout.getNumRows() != N || dout.getNumColumns() != C*P*Q) + throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward"); + if(image.getNumRows() != N || image.getNumColumns() != C*H*W) + throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " + + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + K*P*Q); + + MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), N, C * H * W); + PoolingType poolType = instOpcode.equalsIgnoreCase("maxpooling_backward") ? PoolingType.MAX : PoolingType.AVG; + LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W, + K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, _intermediateMemoryBudget); + } + else { + throw new DMLRuntimeException("Unsupported GPU context for " + instOpcode); + } + + // release inputs/outputs + ec.releaseMatrixInputForGPUInstruction(_input1.getName()); + + boolean isPool = instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("avgpooling"); + boolean isPoolBackward = instOpcode.equalsIgnoreCase("maxpooling_backward") || instOpcode.equalsIgnoreCase("avgpooling_backward"); + + if ( !isPool ) + ec.releaseMatrixInputForGPUInstruction(_input2.getName()); + + if (instOpcode.equalsIgnoreCase("conv2d_bias_add") || + (isPoolBackward && _input3 != null)) + ec.releaseMatrixInputForGPUInstruction(_input3.getName()); + + ec.releaseMatrixOutputForGPUInstruction(_output.getName()); + } + + + private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) { + return (int) ec.getScalarInput(aL.get(index).getName(), + aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue(); + } +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java index 1a9e632..f865f9b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java @@ -38,7 +38,7 @@ public abstract class GPUInstruction extends Instruction { AggregateUnary, AggregateBinary, RelationalBinary, - Convolution, + Dnn, MMTSJ, Reorg, Append, http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java deleted file mode 100644 index da24b6d..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ConvolutionSPInstruction.java +++ /dev/null @@ -1,402 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.sysml.runtime.instructions.spark; - -import java.util.ArrayList; -import java.util.Iterator; - -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.PairFlatMapFunction; -import org.apache.spark.broadcast.Broadcast; -import org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.functionobjects.SwapIndex; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator; -import org.apache.sysml.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock; -import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; -import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.MetaDataFormat; -import org.apache.sysml.runtime.matrix.data.ConvolutionParameters; -import org.apache.sysml.runtime.matrix.data.InputInfo; -import org.apache.sysml.runtime.matrix.data.LibMatrixDNN; -import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; -import org.apache.sysml.runtime.matrix.data.LibMatrixNative; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import 
org.apache.sysml.runtime.matrix.data.MatrixIndexes; -import org.apache.sysml.runtime.matrix.data.OutputInfo; -import org.apache.sysml.runtime.matrix.operators.ReorgOperator; -import org.apache.sysml.runtime.util.ConvolutionUtils; -import org.apache.sysml.utils.NativeHelper; - -import scala.Tuple2; - -public class ConvolutionSPInstruction extends UnarySPInstruction { - private CPOperand _in2; - private CPOperand _in3; - private ArrayList<CPOperand> _input_shape; - private ArrayList<CPOperand> _filter_shape; - private ArrayList<CPOperand> _stride = new ArrayList<>(); - private ArrayList<CPOperand> _padding = new ArrayList<>(); - - private ConvolutionSPInstruction(CPOperand in, CPOperand out, String opcode, String istr, - ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, - ArrayList<CPOperand> filter_shape) { - super(SPType.Convolution, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); - _stride = stride; - _padding = padding; - _input_shape = input_shape; - _filter_shape = filter_shape; - } - - private ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, - ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, - ArrayList<CPOperand> filter_shape) { - super(SPType.Convolution, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); - _in2 = in2; - _stride = stride; - _padding = padding; - _input_shape = input_shape; - _filter_shape = filter_shape; - } - - private ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, - String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, - ArrayList<CPOperand> filter_shape) { - super(SPType.Convolution, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); - _in2 = in2; - _in3 = in3; - _stride = stride; - _padding = padding; - _input_shape = 
input_shape; - _filter_shape = filter_shape; - } - - private ConvolutionSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr) { - super(SPType.Convolution, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); - _in2 = in2; - } - - public static ConvolutionSPInstruction parseInstruction( String str ) { - CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - - String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - String opcode = parts[0]; - if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) { - InstructionUtils.checkNumFields(parts, 14); - // stride1, stride2, padding1, padding2 - // input_shape1, input_shape2, input_shape3, input_shape4, - // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k - in.split(parts[1]); - out.split(parts[14]); - - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new ArrayList<>(); - stride.add(new CPOperand(parts[2])); - stride.add(new CPOperand(parts[3])); - padding.add(new CPOperand(parts[4])); - padding.add(new CPOperand(parts[5])); - input_shape.add(new CPOperand(parts[6])); - input_shape.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - filter_shape.add(new CPOperand(parts[10])); - filter_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - - return new ConvolutionSPInstruction(in, out, opcode, str, stride, - padding, input_shape, filter_shape); - } - else if (opcode.equalsIgnoreCase("maxpooling_backward") - || opcode.equalsIgnoreCase("conv2d") - || opcode.equalsIgnoreCase("conv2d_backward_filter") - || 
opcode.equalsIgnoreCase("conv2d_backward_data")) { - InstructionUtils.checkNumFields(parts, 15); - // dout, stride1, stride2, padding1, padding2 - // input_shape1, input_shape2, input_shape3, input_shape4, - // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k - in.split(parts[1]); - CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - in2.split(parts[2]); - out.split(parts[15]); - - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new ArrayList<>(); - stride.add(new CPOperand(parts[3])); - stride.add(new CPOperand(parts[4])); - padding.add(new CPOperand(parts[5])); - padding.add(new CPOperand(parts[6])); - input_shape.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - input_shape.add(new CPOperand(parts[10])); - filter_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - filter_shape.add(new CPOperand(parts[14])); - - return new ConvolutionSPInstruction(in, in2, out, opcode, str, stride, - padding, input_shape, filter_shape); - } - else if (opcode.equalsIgnoreCase("conv2d_bias_add")) { - InstructionUtils.checkNumFields(parts, 16); - // dout, stride1, stride2, padding1, padding2 - // input_shape1, input_shape2, input_shape3, input_shape4, - // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k - in.split(parts[1]); - CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - in2.split(parts[2]); - CPOperand in3 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - in3.split(parts[3]); - out.split(parts[16]); - - ArrayList<CPOperand> stride = new ArrayList<>(); - ArrayList<CPOperand> padding = new ArrayList<>(); - ArrayList<CPOperand> input_shape = new ArrayList<>(); - ArrayList<CPOperand> filter_shape = new 
ArrayList<>(); - stride.add(new CPOperand(parts[4])); - stride.add(new CPOperand(parts[5])); - padding.add(new CPOperand(parts[6])); - padding.add(new CPOperand(parts[7])); - input_shape.add(new CPOperand(parts[8])); - input_shape.add(new CPOperand(parts[9])); - input_shape.add(new CPOperand(parts[10])); - input_shape.add(new CPOperand(parts[11])); - filter_shape.add(new CPOperand(parts[12])); - filter_shape.add(new CPOperand(parts[13])); - filter_shape.add(new CPOperand(parts[14])); - filter_shape.add(new CPOperand(parts[15])); - - return new ConvolutionSPInstruction(in, in2, in3, out, opcode, str, stride, - padding, input_shape, filter_shape); - } - else if (opcode.equalsIgnoreCase("bias_add")) { - InstructionUtils.checkNumFields(parts, 3); - in.split(parts[1]); - CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - in2.split(parts[2]); - out.split(parts[3]); - return new ConvolutionSPInstruction(in, in2, out, opcode, str); - } - else { - throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str); - } - } - - private static JavaPairRDD<MatrixIndexes,MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) { - JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( name ); - MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(name); - if(mcRdd.getColsPerBlock() < mcRdd.getCols() || mcRdd.getRowsPerBlock() != 1) { - MatrixCharacteristics mcOut = new MatrixCharacteristics(mcRdd); - mcOut.setColsPerBlock((int)mcRdd.getCols()); - mcOut.setRowsPerBlock(numRowsPerBlock); - in1 = RDDAggregateUtils.mergeByKey(in1.flatMapToPair(new ExtractBlockForBinaryReblock(mcRdd, mcOut))); - // TODO: Inject checkpoint to avoid doing this repeated for validation set -// sec.setRDDHandleForVariable(name, in1); -// sec.setMetaData(name, new MatrixDimensionsMetaData(mcOut)); - } - return in1; - } - - private Broadcast<MatrixBlock> 
getBroadcast(SparkExecutionContext sec, String name) { - MatrixBlock mb = sec.getMatrixInput( name, getExtendedOpcode() ); - sec.releaseMatrixInput(name, getExtendedOpcode()); - return sec.getSparkContext().broadcast(mb); - } - - @Override - public void processInstruction(ExecutionContext ec) { - SparkExecutionContext sec = (SparkExecutionContext)ec; - if(instOpcode.equalsIgnoreCase("conv2d") || instOpcode.equalsIgnoreCase("conv2d_bias_add") - || instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { - String rddVar = input1.getName(); - int numRowsPerBlock = 1; - JavaPairRDD<MatrixIndexes,MatrixBlock> inputRDD = reblockAsRectangularMatrices(sec, rddVar, numRowsPerBlock); - MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar); - - // ------------------------------------ - // TODO: Handle large filters > 2G - Broadcast<MatrixBlock> filterBroadcast = null; - Broadcast<MatrixBlock> biasBroadcast = null; - if(instOpcode.equalsIgnoreCase("conv2d")) { - filterBroadcast = getBroadcast(sec, _in2.getName()); - } - else if(instOpcode.equalsIgnoreCase("conv2d_bias_add")) { - filterBroadcast = getBroadcast(sec, _in3.getName()); - biasBroadcast = getBroadcast(sec, _in2.getName()); - } - // ------------------------------------ - - int pad_h = getScalarInput(ec, _padding, 0); - int pad_w = getScalarInput(ec, _padding, 1); - int stride_h = getScalarInput(ec, _stride, 0); - int stride_w = getScalarInput(ec, _stride, 1); - - // int N = getScalarInput(ec, _input_shape, 0); - int C = getScalarInput(ec, _input_shape, 1); - int H = getScalarInput(ec, _input_shape, 2); - int W = getScalarInput(ec, _input_shape, 3); - - int K = getScalarInput(ec, _filter_shape, 0); - int R = getScalarInput(ec, _filter_shape, 2); - int S = getScalarInput(ec, _filter_shape, 3); - int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h); - int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w); - - ConvolutionParameters params = new 
ConvolutionParameters(numRowsPerBlock, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, 1); - boolean enableNativeBLAS = NativeHelper.isNativeLibraryLoaded(); - JavaPairRDD<MatrixIndexes,MatrixBlock> out = inputRDD.mapPartitionsToPair(new RDDConv2dMapMMFunction(filterBroadcast, params, instOpcode, biasBroadcast, mcRdd.getRows(), enableNativeBLAS), true); - - //put output RDD handle into symbol table - sec.setRDDHandleForVariable(output.getName(), out); - sec.addLineageRDD(output.getName(), rddVar); - - long nnz = -1; // TODO: Handle nnz - long numCols = ((long)K)*((long)P)*((long)Q); - if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { - numCols = ((long)C)*((long)P)*((long)Q); - } - if(numCols > Integer.MAX_VALUE) { - throw new DMLRuntimeException("The current operator doesnot support large outputs."); - } - sec.setMetaData(output.getName(), - new MetaDataFormat(new MatrixCharacteristics(mcRdd.getRows(), numCols, numRowsPerBlock, (int)numCols, nnz), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); - } - else { - throw new DMLRuntimeException("Not implemented: " + instOpcode); - } - } - - private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) { - return (int) ec.getScalarInput(aL.get(index).getName(), - aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue(); - } - - private static class RDDConv2dMapMMFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> { - // PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, MatrixBlock> { - private static final long serialVersionUID = -2106155380020232155L; - Broadcast<MatrixBlock> filterBroadcast = null; - Broadcast<MatrixBlock> biasBroadcast = null; - ConvolutionParameters params = null; - String instOpcode = null; boolean enableNative; - long numRows = 0; - public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast, - 
ConvolutionParameters params, String instOpcode, Broadcast<MatrixBlock> biasBroadcast, long numRows, boolean enableNativeBLAS) { - this.filterBroadcast = filterBroadcast; - this.params = params; - this.instOpcode = instOpcode; - this.biasBroadcast = biasBroadcast; - this.numRows = numRows; - this.enableNative = enableNativeBLAS; - } - - private MatrixBlock processRectangularBlock(MatrixBlock matBlock) throws Exception { - MatrixBlock outputBlock = null; - if(instOpcode.equalsIgnoreCase("conv2d")) { - MatrixBlock filter = filterBroadcast.getValue(); - if(filter.isEmptyBlock() || matBlock.isEmptyBlock()) { - outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true); - } - else { - outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, false).allocateDenseBlock(); - if(enableNative) - LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); - else - LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); - } - } - else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) { - MatrixBlock filter = filterBroadcast.getValue(); - MatrixBlock bias = biasBroadcast.getValue(); - if((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) { - outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true); - } - else { - outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, false).allocateDenseBlock(); - if(!bias.isEmptyBlock()) - params.bias = bias; - if(enableNative) - LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); - else - LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); - } - } - else if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { - if(matBlock.isEmptyBlock()) { - outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true); - } - else { - outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock(); - if(instOpcode.equalsIgnoreCase("maxpooling")) - 
outputBlock.getDenseBlock().set(-Double.MAX_VALUE); - LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.MAX); - } - } - else if(instOpcode.equalsIgnoreCase("avgpooling") || instOpcode.equalsIgnoreCase("relu_avgpooling")) { - if(matBlock.isEmptyBlock()) { - outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true); - } - else { - outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock(); - LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.AVG); - } - } - else { - throw new RuntimeException("Not implemented"); - } - return outputBlock; - } - - @Override - public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( - Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) - throws Exception { - return new MapsideConvolutionPartitionIterator(arg0); - } - - // Avoid materialization of partitions - private class MapsideConvolutionPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> { - public MapsideConvolutionPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { - super(in); - } - - @Override - protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { - if(arg._1.getRowIndex() > numRows || arg._1.getColumnIndex() != 1) { - throw new RuntimeException("Expected the inputs to be reblocked as rectangular RDD"); - } - MatrixBlock out = processRectangularBlock(arg._2); - if(out.getNumRows() != 1) - throw new RuntimeException("Expected the output to have 1 row"); - return new Tuple2<>(arg._1, out); - } - } - - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/spark/DnnSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/DnnSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/DnnSPInstruction.java new file 
mode 100644 index 0000000..fbb214e --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/DnnSPInstruction.java @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysml.runtime.instructions.spark; + +import java.util.ArrayList; +import java.util.Iterator; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.broadcast.Broadcast; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.SwapIndex; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator; +import org.apache.sysml.runtime.instructions.spark.functions.ExtractBlockForBinaryReblock; +import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; +import 
org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.MetaDataFormat; +import org.apache.sysml.runtime.matrix.data.DnnParameters; +import org.apache.sysml.runtime.matrix.data.InputInfo; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNN; +import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType; +import org.apache.sysml.runtime.matrix.data.LibMatrixNative; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.runtime.matrix.operators.ReorgOperator; +import org.apache.sysml.runtime.util.DnnUtils; +import org.apache.sysml.utils.NativeHelper; + +import scala.Tuple2; + +public class DnnSPInstruction extends UnarySPInstruction { + private CPOperand _in2; + private CPOperand _in3; + private ArrayList<CPOperand> _input_shape; + private ArrayList<CPOperand> _filter_shape; + private ArrayList<CPOperand> _stride = new ArrayList<>(); + private ArrayList<CPOperand> _padding = new ArrayList<>(); + + private DnnSPInstruction(CPOperand in, CPOperand out, String opcode, String istr, + ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape) { + super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); + _stride = stride; + _padding = padding; + _input_shape = input_shape; + _filter_shape = filter_shape; + } + + private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, + ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape) { + super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); + _in2 = in2; + _stride = stride; + _padding = padding; + _input_shape = input_shape; + _filter_shape = filter_shape; + } + + private 
DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand out, String opcode, + String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, + ArrayList<CPOperand> filter_shape) { + super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); + _in2 = in2; + _in3 = in3; + _stride = stride; + _padding = padding; + _input_shape = input_shape; + _filter_shape = filter_shape; + } + + private DnnSPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr) { + super(SPType.Dnn, new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr); + _in2 = in2; + } + + public static DnnSPInstruction parseInstruction( String str ) { + CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); + CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); + + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("relu_maxpooling")) { + InstructionUtils.checkNumFields(parts, 14); + // stride1, stride2, padding1, padding2 + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k + in.split(parts[1]); + out.split(parts[14]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[2])); + stride.add(new CPOperand(parts[3])); + padding.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + input_shape.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + filter_shape.add(new CPOperand(parts[10])); + filter_shape.add(new 
CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + + return new DnnSPInstruction(in, out, opcode, str, stride, + padding, input_shape, filter_shape); + } + else if (opcode.equalsIgnoreCase("maxpooling_backward") + || opcode.equalsIgnoreCase("conv2d") + || opcode.equalsIgnoreCase("conv2d_backward_filter") + || opcode.equalsIgnoreCase("conv2d_backward_data")) { + InstructionUtils.checkNumFields(parts, 15); + // dout, stride1, stride2, padding1, padding2 + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k + in.split(parts[1]); + CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); + in2.split(parts[2]); + out.split(parts[15]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[3])); + stride.add(new CPOperand(parts[4])); + padding.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + input_shape.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + filter_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + + return new DnnSPInstruction(in, in2, out, opcode, str, stride, + padding, input_shape, filter_shape); + } + else if (opcode.equalsIgnoreCase("conv2d_bias_add")) { + InstructionUtils.checkNumFields(parts, 16); + // dout, stride1, stride2, padding1, padding2 + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4, k + in.split(parts[1]); + CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN); + in2.split(parts[2]); + CPOperand in3 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); + in3.split(parts[3]); + out.split(parts[16]); + + ArrayList<CPOperand> stride = new ArrayList<>(); + ArrayList<CPOperand> padding = new ArrayList<>(); + ArrayList<CPOperand> input_shape = new ArrayList<>(); + ArrayList<CPOperand> filter_shape = new ArrayList<>(); + stride.add(new CPOperand(parts[4])); + stride.add(new CPOperand(parts[5])); + padding.add(new CPOperand(parts[6])); + padding.add(new CPOperand(parts[7])); + input_shape.add(new CPOperand(parts[8])); + input_shape.add(new CPOperand(parts[9])); + input_shape.add(new CPOperand(parts[10])); + input_shape.add(new CPOperand(parts[11])); + filter_shape.add(new CPOperand(parts[12])); + filter_shape.add(new CPOperand(parts[13])); + filter_shape.add(new CPOperand(parts[14])); + filter_shape.add(new CPOperand(parts[15])); + + return new DnnSPInstruction(in, in2, in3, out, opcode, str, stride, + padding, input_shape, filter_shape); + } + else if (opcode.equalsIgnoreCase("bias_add")) { + InstructionUtils.checkNumFields(parts, 3); + in.split(parts[1]); + CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); + in2.split(parts[2]); + out.split(parts[3]); + return new DnnSPInstruction(in, in2, out, opcode, str); + } + else { + throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str); + } + } + + private static JavaPairRDD<MatrixIndexes,MatrixBlock> reblockAsRectangularMatrices(SparkExecutionContext sec, String name, int numRowsPerBlock) { + JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable( name ); + MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(name); + if(mcRdd.getColsPerBlock() < mcRdd.getCols() || mcRdd.getRowsPerBlock() != 1) { + MatrixCharacteristics mcOut = new MatrixCharacteristics(mcRdd); + mcOut.setColsPerBlock((int)mcRdd.getCols()); + mcOut.setRowsPerBlock(numRowsPerBlock); + in1 = 
RDDAggregateUtils.mergeByKey(in1.flatMapToPair(new ExtractBlockForBinaryReblock(mcRdd, mcOut))); + // TODO: Inject checkpoint to avoid doing this repeated for validation set +// sec.setRDDHandleForVariable(name, in1); +// sec.setMetaData(name, new MatrixDimensionsMetaData(mcOut)); + } + return in1; + } + + private Broadcast<MatrixBlock> getBroadcast(SparkExecutionContext sec, String name) { + MatrixBlock mb = sec.getMatrixInput( name, getExtendedOpcode() ); + sec.releaseMatrixInput(name, getExtendedOpcode()); + return sec.getSparkContext().broadcast(mb); + } + + @Override + public void processInstruction(ExecutionContext ec) { + SparkExecutionContext sec = (SparkExecutionContext)ec; + if(instOpcode.equalsIgnoreCase("conv2d") || instOpcode.equalsIgnoreCase("conv2d_bias_add") + || instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { + String rddVar = input1.getName(); + int numRowsPerBlock = 1; + JavaPairRDD<MatrixIndexes,MatrixBlock> inputRDD = reblockAsRectangularMatrices(sec, rddVar, numRowsPerBlock); + MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar); + + // ------------------------------------ + // TODO: Handle large filters > 2G + Broadcast<MatrixBlock> filterBroadcast = null; + Broadcast<MatrixBlock> biasBroadcast = null; + if(instOpcode.equalsIgnoreCase("conv2d")) { + filterBroadcast = getBroadcast(sec, _in2.getName()); + } + else if(instOpcode.equalsIgnoreCase("conv2d_bias_add")) { + filterBroadcast = getBroadcast(sec, _in3.getName()); + biasBroadcast = getBroadcast(sec, _in2.getName()); + } + // ------------------------------------ + + int pad_h = getScalarInput(ec, _padding, 0); + int pad_w = getScalarInput(ec, _padding, 1); + int stride_h = getScalarInput(ec, _stride, 0); + int stride_w = getScalarInput(ec, _stride, 1); + + // int N = getScalarInput(ec, _input_shape, 0); + int C = getScalarInput(ec, _input_shape, 1); + int H = getScalarInput(ec, _input_shape, 2); + int W = getScalarInput(ec, 
_input_shape, 3); + + int K = getScalarInput(ec, _filter_shape, 0); + int R = getScalarInput(ec, _filter_shape, 2); + int S = getScalarInput(ec, _filter_shape, 3); + int P = (int) DnnUtils.getP(H, R, stride_h, pad_h); + int Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w); + + DnnParameters params = new DnnParameters(numRowsPerBlock, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, 1); + boolean enableNativeBLAS = NativeHelper.isNativeLibraryLoaded(); + JavaPairRDD<MatrixIndexes,MatrixBlock> out = inputRDD.mapPartitionsToPair(new RDDConv2dMapMMFunction(filterBroadcast, params, instOpcode, biasBroadcast, mcRdd.getRows(), enableNativeBLAS), true); + + //put output RDD handle into symbol table + sec.setRDDHandleForVariable(output.getName(), out); + sec.addLineageRDD(output.getName(), rddVar); + + long nnz = -1; // TODO: Handle nnz + long numCols = ((long)K)*((long)P)*((long)Q); + if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { + numCols = ((long)C)*((long)P)*((long)Q); + } + if(numCols > Integer.MAX_VALUE) { + throw new DMLRuntimeException("The current operator doesnot support large outputs."); + } + sec.setMetaData(output.getName(), + new MetaDataFormat(new MatrixCharacteristics(mcRdd.getRows(), numCols, numRowsPerBlock, (int)numCols, nnz), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); + } + else { + throw new DMLRuntimeException("Not implemented: " + instOpcode); + } + } + + private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) { + return (int) ec.getScalarInput(aL.get(index).getName(), + aL.get(index).getValueType(), aL.get(index).isLiteral()).getLongValue(); + } + + private static class RDDConv2dMapMMFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> { + // PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, MatrixBlock> { + private static final long serialVersionUID = 
-2106155380020232155L; + Broadcast<MatrixBlock> filterBroadcast = null; + Broadcast<MatrixBlock> biasBroadcast = null; + DnnParameters params = null; + String instOpcode = null; boolean enableNative; + long numRows = 0; + public RDDConv2dMapMMFunction(Broadcast<MatrixBlock> filterBroadcast, + DnnParameters params, String instOpcode, Broadcast<MatrixBlock> biasBroadcast, long numRows, boolean enableNativeBLAS) { + this.filterBroadcast = filterBroadcast; + this.params = params; + this.instOpcode = instOpcode; + this.biasBroadcast = biasBroadcast; + this.numRows = numRows; + this.enableNative = enableNativeBLAS; + } + + private MatrixBlock processRectangularBlock(MatrixBlock matBlock) throws Exception { + MatrixBlock outputBlock = null; + if(instOpcode.equalsIgnoreCase("conv2d")) { + MatrixBlock filter = filterBroadcast.getValue(); + if(filter.isEmptyBlock() || matBlock.isEmptyBlock()) { + outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true); + } + else { + outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, false).allocateDenseBlock(); + if(enableNative) + LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); + else + LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); + } + } + else if (instOpcode.equalsIgnoreCase("conv2d_bias_add")) { + MatrixBlock filter = filterBroadcast.getValue(); + MatrixBlock bias = biasBroadcast.getValue(); + if((filter.isEmptyBlock() || matBlock.isEmptyBlock()) && bias.isEmptyBlock()) { + outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, true); + } + else { + outputBlock = new MatrixBlock(params.N, params.K*params.P*params.Q, false).allocateDenseBlock(); + if(!bias.isEmptyBlock()) + params.bias = bias; + if(enableNative) + LibMatrixNative.conv2d(matBlock, filter, outputBlock, params); + else + LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); + } + } + else if(instOpcode.equalsIgnoreCase("maxpooling") || instOpcode.equalsIgnoreCase("relu_maxpooling")) { + 
if(matBlock.isEmptyBlock()) { + outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true); + } + else { + outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock(); + if(instOpcode.equalsIgnoreCase("maxpooling")) + outputBlock.getDenseBlock().set(-Double.MAX_VALUE); + LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.MAX); + } + } + else if(instOpcode.equalsIgnoreCase("avgpooling") || instOpcode.equalsIgnoreCase("relu_avgpooling")) { + if(matBlock.isEmptyBlock()) { + outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, true); + } + else { + outputBlock = new MatrixBlock(params.N, params.C*params.P*params.Q, false).allocateBlock(); + LibMatrixDNN.pooling(matBlock, outputBlock, params, PoolingType.AVG); + } + } + else { + throw new RuntimeException("Not implemented"); + } + return outputBlock; + } + + @Override + public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call( + Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) + throws Exception { + return new MapsideDnnPartitionIterator(arg0); + } + + // Avoid materialization of partitions + private class MapsideDnnPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> { + public MapsideDnnPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) { + super(in); + } + + @Override + protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { + if(arg._1.getRowIndex() > numRows || arg._1.getColumnIndex() != 1) { + throw new RuntimeException("Expected the inputs to be reblocked as rectangular RDD"); + } + MatrixBlock out = processRectangularBlock(arg._2); + if(out.getNumRows() != 1) + throw new RuntimeException("Expected the output to have 1 row"); + return new Tuple2<>(arg._1, out); + } + } + + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java index 69d6a72..234b02b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java @@ -36,7 +36,7 @@ public abstract class SPInstruction extends Instruction { CentralMoment, Covariance, QSort, QPick, ParameterizedBuiltin, MAppend, RAppend, GAppend, GAlignedAppend, Rand, MatrixReshape, Ctable, Quaternary, CumsumAggregate, CumsumOffset, BinUaggChain, UaggOuterChain, - Write, SpoofFused, Convolution + Write, SpoofFused, Dnn } protected final SPType _sptype; http://git-wip-us.apache.org/repos/asf/systemml/blob/9fa5a09b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java deleted file mode 100644 index 99d21a2..0000000 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.runtime.matrix.data;

import java.io.Serializable;

import org.apache.sysml.hops.Hop;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.util.ConvolutionUtils;

/**
 * This class is container that stores parameters required for executing following operations:
 * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, maxpooling_backward
 */
public class ConvolutionParameters implements Serializable
{
	private static final long serialVersionUID = -212362627205772829L;

	// Tensor dimensions: N=batch size, C=input channels, H/W=input height/width,
	// K=number of filters, R/S=filter height/width, P/Q=output height/width.
	// A value of -1 marks a dimension as unknown (see setIfUnknown).
	public int N, C, H, W, K, R, S, P, Q;
	public int stride_h, stride_w, pad_h, pad_w;
	public int numThreads;

	// Optional variables used by ConvolutionCPInstruction
	public boolean enableNative = false;

	public MatrixBlock input1; public MatrixBlock input2; public MatrixBlock output;

	public MatrixBlock bias;
	// Precomputed per-output-cell pooling window boundaries (set by the pooling kernels)
	public int [] start_indexes_h, end_indexes_h, start_indexes_w, end_indexes_w;

	public double minValForMaxPoolOperations = -Double.MAX_VALUE;

	/**
	 * Constructs parameters from long-typed dimensions, validating that each fits in an int.
	 * P and Q are computed inline from the standard output-size formula, or set to -1 when
	 * any contributing dimension is unknown (negative).
	 *
	 * @throws DMLRuntimeException if any dimension exceeds Integer.MAX_VALUE
	 */
	public ConvolutionParameters(long N, long C, long H, long W,
			long K, long R, long S, long stride_h, long stride_w,
			long pad_h, long pad_w, int numThreads) {
		this.N = convertToInt(N);
		this.C = convertToInt(C);
		this.H = convertToInt(H);
		this.W = convertToInt(W);
		this.K = convertToInt(K);
		this.R = convertToInt(R);
		this.S = convertToInt(S);
		this.stride_h = convertToInt(stride_h);
		this.stride_w = convertToInt(stride_w);
		this.pad_h = convertToInt(pad_h);
		this.pad_w = convertToInt(pad_w);
		// NOTE(review): the guard admits stride_h == 0, which would divide by zero below;
		// the int-typed constructor uses stricter guards (H <= 0 etc.) — confirm callers
		// never pass a zero stride here.
		if(H >= 0 && pad_h >= 0 && R >= 0 && stride_h >= 0)
			P = (int) ((H + 2 * pad_h - R) / stride_h + 1);
		else
			P = -1;

		if(W >= 0 && pad_w >= 0 && S >= 0 && stride_w >= 0)
			Q = (int) ((W + 2 * pad_w - S) / stride_w + 1);
		else
			Q = -1;

		this.numThreads = numThreads;
	}

	/**
	 * Constructs parameters from int-typed dimensions. P and Q are derived via
	 * ConvolutionUtils, or set to -1 when any contributing dimension is unknown.
	 */
	public ConvolutionParameters(int N, int C, int H, int W,
			int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int numThreads) {
		this.N = N;
		this.C = C;
		this.H = H;
		this.W = W;
		this.K = K;
		this.R = R;
		this.S = S;
		this.stride_h = stride_h;
		this.stride_w = stride_w;
		this.pad_h = pad_h;
		this.pad_w = pad_w;
		if(H <= 0 || R <= 0 || stride_h < 0 || pad_h < 0)
			P = -1;
		else
			P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
		if(W <= 0 || S <= 0 || stride_w < 0 || pad_w < 0)
			Q = -1;
		else
			Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);
		this.numThreads = numThreads;
	}

	/** Narrows a long to int, failing fast on overflow. */
	private static int convertToInt(long val) {
		if( val > Integer.MAX_VALUE )
			throw new DMLRuntimeException("The value for ConvolutionParameters is too large:" + val);
		return (int) val;
	}

	/**
	 * Returns true iff all dimension, stride, padding, and thread-count fields match.
	 * Note: P and Q are derived and therefore not compared.
	 */
	public boolean compare(ConvolutionParameters that) {
		if(this.N == that.N && this.C == that.C && this.H == that.H && this.W == that.W
				&& this.K == that.K && this.R == that.R && this.S == that.S && this.stride_h == that.stride_h
				&& this.stride_w == that.stride_w && this.pad_h == that.pad_h
				&& this.pad_w == that.pad_w && this.numThreads == that.numThreads) {
			return true;
		}
		return false;
	}

	@Override
	public String toString() {
		return "(NCHW=[" + N + " " + C + " " + H + " " + W + "], KCRS=[" + K + " " + R + " " + S + "], stride=[" + stride_h + "," + stride_w +
				"], pad=[" + pad_h + "," + pad_w + "])";
	}

	/**
	 * Fills in any dimension still marked unknown (< 0) from compile-time size
	 * information of the given Hops, then (re)derives P and Q where possible.
	 */
	public void setIfUnknown(Hop N, Hop C, Hop H, Hop W,
			Hop K, Hop R, Hop S, Hop stride_h, Hop stride_w, Hop pad_h, Hop pad_w, int numThreads) {
		if(this.N < 0) this.N = convertToInt(Hop.computeSizeInformation(N));
		if(this.C < 0) this.C = convertToInt(Hop.computeSizeInformation(C));
		if(this.H < 0) this.H = convertToInt(Hop.computeSizeInformation(H));
		if(this.W < 0) this.W = convertToInt(Hop.computeSizeInformation(W));
		if(this.K < 0) this.K = convertToInt(Hop.computeSizeInformation(K));
		if(this.R < 0) this.R = convertToInt(Hop.computeSizeInformation(R));
		if(this.S < 0) this.S = convertToInt(Hop.computeSizeInformation(S));
		if(this.stride_h < 0) this.stride_h = convertToInt(Hop.computeSizeInformation(stride_h));
		if(this.stride_w < 0) this.stride_w = convertToInt(Hop.computeSizeInformation(stride_w));
		if(this.pad_h < 0) this.pad_h = convertToInt(Hop.computeSizeInformation(pad_h));
		if(this.pad_w < 0) this.pad_w = convertToInt(Hop.computeSizeInformation(pad_w));
		if(this.P < 0 && this.H >= 0 && this.R >= 0 && this.stride_h >= 0 && this.pad_h >= 0) {
			this.P = (int) ConvolutionUtils.getP(this.H, this.R, this.stride_h, this.pad_h);
		}
		if(this.Q < 0 && this.W >= 0 && this.S >= 0 && this.stride_w >= 0 && this.pad_w >= 0) {
			this.Q = (int) ConvolutionUtils.getQ(this.W, this.S, this.stride_w, this.pad_w);
		}
		this.numThreads = numThreads;
	}

	/** Delegates to the output block; requires {@code output} to be set first. */
	public boolean isOutputThreadSafe() {
		return output.isThreadSafe();
	}

	/** True iff stride is 1x1 and padding is 0x0 (enables simplified kernels). */
	public boolean isStride1Pad0() {
		return (stride_h==1 && stride_w==1
			&& pad_h==0 && pad_w==0);
	}

	/**
	 * True iff every given value equals 1 (vacuously true for no arguments).
	 * NOTE(review): unboxing {@code int param : params} throws NPE on a null element.
	 */
	public boolean isAllOnes(Integer...params) {
		boolean ret = true;
		for(int param : params)
			ret &= (param == 1);
		return ret;
	}
}
