Repository: systemml
Updated Branches:
  refs/heads/master 9861d7a3c -> 86f0e3f70
[SYSTEMML-445] Added CP implementations for batch_norm2d and batch_norm2d_backward.

- This feature is required for the NN tests.
- The current version of batch_norm2d_backward only supports a dense image and a dense dout. This will be fixed in a future commit.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/86f0e3f7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/86f0e3f7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/86f0e3f7

Branch: refs/heads/master
Commit: 86f0e3f705874877c261a4325a11fa45dfe851cc
Parents: 9861d7a
Author: Niketan Pansare <[email protected]>
Authored: Fri Jun 1 21:25:40 2018 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jun 1 21:26:20 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/FunctionOp.java  |  43 ++-
 .../instructions/CPInstructionParser.java       |   2 +
 .../cp/ConvolutionCPInstruction.java            | 176 +++++++---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 322 +++++++++++++++++++
 4 files changed, 490 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index ec2fda8..da13cd1 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops;
 
 import java.util.ArrayList;
 
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.lops.FunctionCallCP;
 import org.apache.sysml.lops.FunctionCallCPSingle;
 import org.apache.sysml.lops.Lop;
@@ -168,10 +169,22 @@ public class FunctionOp extends Hop
             long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
             return outputVectors+outputValues;
         }
-        else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
+        else if ( getFunctionName().equalsIgnoreCase("lstm") ) {
             // TODO: To allow for initial version to always run on the GPU
             return 0;
         }
+        else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") ) {
+            return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
+                OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
+                OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
+                OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) +
+                OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
+        }
+        else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
+            return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
+                OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
+                OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0);
+        }
         else if ( getFunctionName().equalsIgnoreCase("svd") ) {
             long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
             long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0);
@@ -202,7 +215,10 @@ public class FunctionOp extends Hop
             return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0)
                 + 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0);
         }
-        else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+        else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+            return 0;
+        }
+        else if ( getFunctionName().equalsIgnoreCase("lstm") ) {
             // TODO: To allow for initial version to always run on the GPU
             return 0;
         }
@@ -274,15 +290,20 @@ public class FunctionOp extends Hop
                 || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
                 && OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
             }
-            else if( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
-//              if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
-//                  _etype = findExecTypeByMemEstimate();
-//              }
-//              else {
-//                  _etype = ExecType.CP;
-//              }
-//              _etype = _etype == REMOTE ? ExecType.CP : _etype; // lstm not supported on Spark
-                _etype = ExecType.GPU;
+            else if( getFunctionName().equalsIgnoreCase("lstm")) {
+                if(DMLScript.USE_ACCELERATOR)
+                    _etype = ExecType.GPU;
+                else
+                    throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
+            }
+            else if( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+                if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+                    _etype = findExecTypeByMemEstimate();
+                }
+                else {
+                    _etype = ExecType.CP;
+                }
+                _etype = _etype == REMOTE ? ExecType.CP : _etype; // batch_norm2d and batch_norm2d_backward are not supported on Spark
             }
             else {
                 // Since the memory estimate is only conservative, do not throw


http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 82d4418..ece9d65 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -247,6 +247,8 @@ public class CPInstructionParser extends InstructionParser
         String2CPInstructionType.put( "bias_add"      , CPType.Convolution);
         String2CPInstructionType.put( "bias_multiply" , CPType.Convolution);
         String2CPInstructionType.put( "channel_sums"  , CPType.Convolution);
+        String2CPInstructionType.put( "batch_norm2d",          CPType.Convolution);
+        String2CPInstructionType.put( "batch_norm2d_backward", CPType.Convolution);
 
         // Quaternary instruction opcodes
         String2CPInstructionType.put( "wsloss" , CPType.Quaternary);
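
For context, batch_norm2d implements the standard batch-normalization forward
pass over rows of shape [N, K*P*Q]. In train mode, with m = N*P*Q values per
channel k (notation introduced here for exposition; \lambda denotes the
exponentialAverageFactor operand, named mu in the code):

    \mu_k = \frac{1}{m}\sum x, \qquad
    \sigma_k^2 = \frac{1}{m}\sum x^2 - \mu_k^2

    y = \gamma_k \, \frac{x - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}} + \beta_k

    ema\_mean\_upd_k = \lambda \cdot ema\_mean_k + (1 - \lambda)\,\mu_k

and the running variance is updated analogously. In test mode, the stored
ema_mean and ema_var are used in place of the batch statistics.
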
http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index f1c313b..6addfe4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -20,21 +20,18 @@ package org.apache.sysml.runtime.instructions.cp;
 
 import java.util.ArrayList;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.SparseBlock;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 import org.apache.sysml.utils.NativeHelper;
 
@@ -44,6 +41,15 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
     private final CPOperand _in2;
     private final CPOperand _in3;
+    private final CPOperand _in4;
+    private final CPOperand _in5;
+    private final CPOperand _in6;
+    private final CPOperand _in7;
+    private final CPOperand _in8;
+    private final CPOperand _out2;
+    private final CPOperand _out3;
+    private final CPOperand _out4;
+    private final CPOperand _out5;
     private final ArrayList<CPOperand> _input_shape;
     private final ArrayList<CPOperand> _filter_shape;
     private final ArrayList<CPOperand> _stride;
@@ -57,6 +63,8 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
         super(CPType.Convolution, null, in, out, opcode, istr);
         _in2 = in2;
         _in3 = in3;
+        _in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null;
+        _out2 = null; _out3 = null; _out4 = null; _out5 = null;
         _stride = stride;
         _padding = padding;
         _input_shape = input_shape;
@@ -98,6 +106,30 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
             ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
         this(in, in2, in3, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
     }
+
+    public ConvolutionCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
+            CPOperand in6, CPOperand in7, CPOperand in8,
+            CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr,
+            double intermediateMemoryBudget) throws DMLRuntimeException {
+        super(CPType.Convolution, null, in1, out, opcode, istr);
+        _in2 = in2;
+        _in3 = in3;
+        _in4 = in4;
+        _in5 = in5;
+        _in6 = in6;
+        _in7 = in7;
+        _in8 = in8;
+        _out2 = out2;
+        _out3 = out3;
+        _out4 = out4;
+        _out5 = out5;
+        _stride = null;
+        _padding = null;
+        _input_shape = null;
+        _filter_shape = null;
+        _numThreads = 0;
+        _intermediateMemoryBudget = intermediateMemoryBudget;
+    }
 
     public static ConvolutionCPInstruction parseInstruction(String str) {
@@ -214,6 +246,36 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
             CPOperand out = new CPOperand(parts[4]);
             return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, -1, 0);
         }
+        else if (opcode.equalsIgnoreCase("batch_norm2d")) {
+            InstructionUtils.checkNumFields(parts, 13);
+            CPOperand in1 = new CPOperand(parts[1]); // image
+            CPOperand in2 = new CPOperand(parts[2]); // scale
+            CPOperand in3 = new CPOperand(parts[3]); // bias
+            CPOperand in4 = new CPOperand(parts[4]); // runningMean
+            CPOperand in5 = new CPOperand(parts[5]); // runningVar
+            CPOperand in6 = new CPOperand(parts[6]); // mode
+            CPOperand in7 = new CPOperand(parts[7]); // epsilon
+            CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
+            CPOperand out = new CPOperand(parts[9]);   // ret
+            CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
+            CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
+            CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
+            CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
+            return new ConvolutionCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
+        }
+        else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
+            InstructionUtils.checkNumFields(parts, 9);
+            CPOperand in1 = new CPOperand(parts[1]); // image
+            CPOperand in2 = new CPOperand(parts[2]); // dout
+            CPOperand in3 = new CPOperand(parts[3]); // scale
+            CPOperand in4 = new CPOperand(parts[4]); // epsilon
+            CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
+            CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
+            CPOperand out = new CPOperand(parts[7]);  // dX
+            CPOperand out2 = new CPOperand(parts[8]); // dScale
+            CPOperand out3 = new CPOperand(parts[9]); // dBias
+            return new ConvolutionCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
+        }
         else {
             throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
         }
@@ -309,45 +371,7 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
         }
         else {
             outputBlock = new MatrixBlock(C, 1, false).allocateBlock();
-            double [] output = outputBlock.getDenseBlockValues();
-            if(input.isInSparseFormat()) {
-                SparseBlock sblock = input.getSparseBlock();
-                for(int n = 0; n < input.getNumRows(); n++) {
-                    if( sblock.isEmpty(n) )
-                        continue;
-                    int apos = sblock.pos(n);
-                    int alen = sblock.size(n);
-                    int[] aix = sblock.indexes(n);
-                    double[] avals = sblock.values(n);
-                    
-                    // Iterate over the sparse block
-                    for(int j=apos; j<apos+alen; j++) {
-                        // Note: the input is of shape [N, CHW]
-                        int chw = aix[j];
-                        
-                        // Get individual zero-based c,h,w indexes from zero-based 'chw'
-                        int c = chw / HW;
-                        output[c] += avals[j];
-                    }
-                }
-            }
-            else {
-                double [] inArr = input.getDenseBlockValues();
-                if(inArr != null) {
-                    KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
-                    for(int c = 0; c < C; c++) {
-                        KahanObject sum = new KahanObject(0.0, 0.0);
-                        for(int n = 0; n < input.getNumRows(); n++) {
-                            int index = n*C*HW + c*HW;
-                            for(int hw = 0; hw < HW; hw++, index++) {
-                                kplus.execute2(sum, inArr[index]);
-                            }
-                        }
-                        output[c] = sum._sum;
-                    }
-                }
-            }
-            outputBlock.recomputeNonZeros();
+            LibMatrixDNN.channelSums(input, outputBlock, C, HW);
         }
         
         // release inputs/outputs
@@ -355,6 +379,66 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
         ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode());
     }
+
+    public void processBatchNorm2dInstruction(ExecutionContext ec) {
+        MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+        MatrixBlock scale = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+        MatrixBlock bias = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+        MatrixBlock runningMean = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+        MatrixBlock runningVar = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+        String phase = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getStringValue();
+        double epsilon = ec.getScalarInput(_in7.getName(), _in7.getValueType(), _in7.isLiteral()).getDoubleValue();
+        double mu = ec.getScalarInput(_in8.getName(), _in8.getValueType(), _in8.isLiteral()).getDoubleValue();
+
+        MatrixBlock ret = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
+        MatrixBlock retRunningMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
+        MatrixBlock retRunningVar = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
+        MatrixBlock resultSaveMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
+        MatrixBlock resultSaveInvVariance = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
+
+        LibMatrixDNN.batchNorm2D(image, scale, bias, runningMean, runningVar, phase, epsilon, mu, ret,
+            retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance);
+
+        // release inputs/outputs
+        ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+        ec.setMatrixOutput(output.getName(), ret, getExtendedOpcode());
+        ec.setMatrixOutput(_out2.getName(), retRunningMean, getExtendedOpcode());
+        ec.setMatrixOutput(_out3.getName(), retRunningVar, getExtendedOpcode());
+        ec.setMatrixOutput(_out4.getName(), resultSaveMean, getExtendedOpcode());
+        ec.setMatrixOutput(_out5.getName(), resultSaveInvVariance, getExtendedOpcode());
+    }
+
+    public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) {
+        MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+        MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+        MatrixBlock scale = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+        double epsilon = ec.getScalarInput(_in4.getName(), _in4.getValueType(), _in4.isLiteral()).getDoubleValue();
+        MatrixBlock resultSaveMean = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+        MatrixBlock resultSaveInvVariance = ec.getMatrixInput(_in6.getName(), getExtendedOpcode());
+
+        MatrixBlock dX = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
+        MatrixBlock dScale = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
+        MatrixBlock dBias = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
+
+        LibMatrixDNN.batchNorm2DBackward(image, dout, scale, epsilon, resultSaveMean, resultSaveInvVariance, dX, dScale, dBias);
+
+        // release inputs/outputs
+        ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+        ec.releaseMatrixInput(_in6.getName(), getExtendedOpcode());
+        ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode());
+        ec.setMatrixOutput(_out2.getName(), dScale, getExtendedOpcode());
+        ec.setMatrixOutput(_out3.getName(), dBias, getExtendedOpcode());
+    }
 
     // Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is true
     // This increases the number of native calls. For example: the cases where the filter is sparse but the input is dense.
     private static boolean isFilterSparse(MatrixBlock filter) {
@@ -385,6 +469,14 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
             processChannelSumsInstruction(ec);
             return;
         }
+        else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
+            processBatchNorm2dInstruction(ec);
+            return;
+        }
+        else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
+            processBatchNorm2dBackwardInstruction(ec);
+            return;
+        }
 
         // acquire inputs
         MatrixBlock outputBlock = null;
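
The channel_sums hunk above moves the existing summation code into
LibMatrixDNN.channelSums (next file), whose dense path accumulates per-channel
sums via SystemML's KahanPlus/KahanObject. As a minimal, self-contained sketch
of the underlying compensated (Kahan) summation pattern, in plain Java with
illustrative names only (not SystemML API):

    public class KahanSumSketch {
        // Sums values while carrying the accumulated rounding error forward,
        // which reduces drift compared to naive accumulation.
        public static double kahanSum(double[] values) {
            double sum = 0, correction = 0;
            for (double v : values) {
                double corrected = v - correction;        // re-apply previously lost low-order bits
                double newSum = sum + corrected;          // high-order accumulation (may drop bits)
                correction = (newSum - sum) - corrected;  // recover the bits dropped in this step
                sum = newSum;
            }
            return sum;
        }

        public static void main(String[] args) {
            double[] vals = new double[1_000_000];
            java.util.Arrays.fill(vals, 0.1);
            System.out.println(kahanSum(vals)); // ~100000.0, tighter than a plain loop
        }
    }
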
http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 2d2cf63..f8318a5 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -19,6 +19,7 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -30,6 +31,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
@@ -332,6 +335,325 @@ public class LibMatrixDNN {
         outputBlock.examSparsity();
     }
 
+    /**
+     * Perform channel sums operation
+     * 
+     * @param input input matrix block
+     * @param outputBlock output matrix block
+     * @param C number of channels
+     * @param HW height times width
+     */
+    public static void channelSums(MatrixBlock input, MatrixBlock outputBlock, int C, int HW) {
+        double [] output = outputBlock.getDenseBlockValues();
+        if(input.isInSparseFormat()) {
+            SparseBlock sblock = input.getSparseBlock();
+            for(int n = 0; n < input.getNumRows(); n++) {
+                if( sblock.isEmpty(n) )
+                    continue;
+                int apos = sblock.pos(n);
+                int alen = sblock.size(n);
+                int[] aix = sblock.indexes(n);
+                double[] avals = sblock.values(n);
+                
+                // Iterate over the sparse block
+                for(int j=apos; j<apos+alen; j++) {
+                    // Note: the input is of shape [N, CHW]
+                    int chw = aix[j];
+                    
+                    // Get the zero-based channel index c from the zero-based 'chw' index
+                    int c = chw / HW;
+                    output[c] += avals[j];
+                }
+            }
+        }
+        else {
+            double [] inArr = input.getDenseBlockValues();
+            if(inArr != null) {
+                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
+                for(int c = 0; c < C; c++) {
+                    KahanObject sum = new KahanObject(0.0, 0.0);
+                    for(int n = 0; n < input.getNumRows(); n++) {
+                        int index = n*C*HW + c*HW;
+                        for(int hw = 0; hw < HW; hw++, index++) {
+                            kplus.execute2(sum, inArr[index]);
+                        }
+                    }
+                    output[c] = sum._sum;
+                }
+            }
+        }
+        outputBlock.recomputeNonZeros();
+    }
+    
+    public static void batchNorm2DBackward(MatrixBlock image, MatrixBlock dout, MatrixBlock scale, double epsilon,
+            MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance,
+            MatrixBlock dX, MatrixBlock dScale, MatrixBlock dBias) {
+        int N = image.getNumRows();
+        int K = scale.getNumRows();
+        int PQ = image.getNumColumns() / K;
+        channelSums(image, dBias, K, PQ);
+        // Since the outputs are expected in dense format
+        if(dBias.isInSparseFormat())
+            dBias.sparseToDense();
+        if(dScale.isInSparseFormat())
+            dScale.sparseToDense();
+        if(dX.isInSparseFormat())
+            dX.sparseToDense();
+        // Very small matrices
+        if(resultSaveMean.isInSparseFormat())
+            resultSaveMean.sparseToDense();
+        if(resultSaveInvVariance.isInSparseFormat())
+            resultSaveInvVariance.sparseToDense();
+        if(scale.isInSparseFormat())
+            scale.sparseToDense();
+        double [] dBiasArr = dBias.getDenseBlockValues();
+        double [] dScaleArr = dScale.getDenseBlockValues();
+        double [] dXArr = dX.getDenseBlockValues();
+        double [] mean = resultSaveMean.getDenseBlockValues();
+        double [] invVar = resultSaveInvVariance.getDenseBlockValues();
+        double [] scaleArr = scale.getDenseBlockValues();
+        // Since K is relatively small, materializing empty inputs as zero arrays reduces code complexity. We can avoid this in subsequent commits.
+        mean = (mean==null) ? new double[K] : mean;
+        invVar = (invVar==null) ? new double[K] : invVar;
+        scaleArr = (scaleArr == null) ? new double[K] : scaleArr;
+        
+        // TODO: Handle the sparse image and dout cases:
+        if(image.isInSparseFormat())
+            image.sparseToDense();
+        if(dout.isInSparseFormat())
+            dout.sparseToDense();
+        
+        if(!image.isInSparseFormat() && !dout.isInSparseFormat()) {
+            double [] imageArr = image.getDenseBlockValues();
+            double [] doutArr = dout.getDenseBlockValues();
+            double constant1 = Math.pow(N*PQ, -1);
+            int KPQ = K*PQ;
+            for(int k = 0; k < K; k++) {
+                double dvar = 0;
+                double dmean_norm_branch = 0; double dmean_var_branch = 0;
+                double sumDout = 0; double sum = 0;
+                for(int n = 0; n < N; n++) {
+                    int index = n*KPQ + k*PQ;
+                    for(int pq = 0; pq < PQ; pq++, index++) {
+                        double doutVal = doutArr != null ? doutArr[index] : 0;
+                        double centered = imageArr != null ? imageArr[index] : 0;
+                        centered -= mean[k];
+                        double dnorm = doutVal*scaleArr[k];
+                        dvar -= 0.5*centered*Math.pow(invVar[k], 3)*dnorm;
+                        dmean_norm_branch -= dnorm*invVar[k];
+                        sum += centered * invVar[k] * doutVal;
+                        sumDout += doutVal;
+                        dmean_var_branch -= 2*constant1*centered;
+                    }
+                }
+                dBiasArr[k] = sumDout;
+                dScaleArr[k] = sum;
+                dmean_var_branch *= dvar;
+                double dmean = dmean_norm_branch + dmean_var_branch;
+                double dX_mean_branch = constant1*dmean;
+                
+                for(int n = 0; n < N; n++) {
+                    int index = n*KPQ + k*PQ;
+                    for(int pq = 0; pq < PQ; pq++, index++) {
+                        double doutVal = doutArr != null ? doutArr[index] : 0;
+                        double centered = imageArr != null ? imageArr[index] : 0;
+                        centered -= mean[k];
+                        double dnorm = doutVal*scaleArr[k];
+                        double dX_norm_branch = dnorm*invVar[k];
+                        double dX_var_branch = 2*constant1*centered*dvar;
+                        dXArr[index] = dX_norm_branch + dX_mean_branch + dX_var_branch;
+                    }
+                }
+            }
+        }
+        else {
+            throw new DMLRuntimeException("Sparse format is not yet supported for batch norm backward");
+        }
+        dBias.recomputeNonZeros();
+        dScale.recomputeNonZeros();
+        dX.recomputeNonZeros();
+    }
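+    
+    // For reference, the two passes above implement the standard batch-norm
+    // gradients; with m = N*PQ, v_k = 1/sqrt(var_k + eps), xhat = (x - mean_k)*v_k,
+    // and d = dout:
+    //   dBias_k  = sum(d)
+    //   dScale_k = sum(xhat * d)
+    //   dVar_k   = -(scale_k * v_k^3 / 2) * sum((x - mean_k) * d)
+    //   dMean_k  = -scale_k * v_k * sum(d) - (2/m) * dVar_k * sum(x - mean_k)
+    //   dX       = scale_k * v_k * d + dMean_k / m + (2/m) * (x - mean_k) * dVar_k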
+    
+    public static void batchNorm2D(MatrixBlock image, MatrixBlock scale, MatrixBlock bias, MatrixBlock runningMean,
+            MatrixBlock runningVar, String phase, double epsilon, double mu,
+            MatrixBlock ret, MatrixBlock retRunningMean, MatrixBlock retRunningVar,
+            MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance) {
+        // Since bias, scale, runningMean, and runningVar are extremely small arrays
+        if(bias.isInSparseFormat())
+            bias.sparseToDense();
+        double [] biasArr = bias.getDenseBlockValues();
+        if(scale.isInSparseFormat())
+            scale.sparseToDense();
+        double [] scaleArr = scale.getDenseBlockValues();
+        if(runningMean.isInSparseFormat())
+            runningMean.sparseToDense();
+        double [] runningMeanArr = runningMean.getDenseBlockValues(); // ema_mean
+        if(runningVar.isInSparseFormat())
+            runningVar.sparseToDense();
+        double [] runningVarArr = runningVar.getDenseBlockValues(); // ema_var
+        
+        double [] retRunningMeanArr = retRunningMean.getDenseBlockValues(); // ema_mean_upd
+        double [] retRunningVarArr = retRunningVar.getDenseBlockValues(); // ema_var_upd
+        double [] resultSaveMeanArr = resultSaveMean.getDenseBlockValues(); // cache_mean
+        double [] resultSaveInvVarianceArr = resultSaveInvVariance.getDenseBlockValues(); // cache_inv_var
+        
+        int N = image.getNumRows();
+        int K = bias.getNumRows(); // number of output channels
+        int PQ = image.getNumColumns() / K; // output height X output width
+        
+        if(phase.equalsIgnoreCase("train")) {
+            computeBiasSumAndSumSquares(image, resultSaveMeanArr, resultSaveInvVarianceArr, K, PQ);
+            int NPQ = N*PQ;
+            for(int k = 0; k < K; k++) {
+                double mean = resultSaveMeanArr[k] / NPQ;
+                double var = resultSaveInvVarianceArr[k]/NPQ - Math.pow(mean, 2.0);
+                resultSaveMeanArr[k] = mean;
+                resultSaveInvVarianceArr[k] = Math.pow(Math.sqrt(var + epsilon), -1.0);
+                retRunningMeanArr[k] = mu*runningMeanArr[k] + (1-mu)*mean;
+                retRunningVarArr[k] = mu*runningVarArr[k] + (1-mu)*var;
+            }
+        }
+        else if(phase.equalsIgnoreCase("test")) {
+            copy(runningMean, retRunningMeanArr); // ema_mean_upd = ema_mean
+            copy(runningVar, retRunningVarArr); // ema_var_upd = ema_var
+            copy(runningMean, resultSaveMeanArr); // cache_mean = ema_mean
+            double invSqrtEps = Math.pow(Math.sqrt(epsilon), -1.0);
+            double [] inArr = runningVar.getDenseBlockValues();
+            if(inArr != null) {
+                for(int i = 0; i < inArr.length; i++) {
+                    resultSaveInvVarianceArr[i] = Math.pow(Math.sqrt(inArr[i] + epsilon), -1.0);
+                }
+            }
+            else {
+                Arrays.fill(resultSaveInvVarianceArr, invSqrtEps);
+            }
+        }
+        else {
+            throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
+        }
+        
+        // Normalize, shift, and scale
+        double [] retArr = ret.getDenseBlockValues();
+        copy(image, retArr);
+        if(resultSaveMeanArr != null && resultSaveInvVarianceArr != null && biasArr != null && scaleArr != null) {
+            // Common scenario:
+            int index = 0;
+            for(int n = 0; n < N; n++) {
+                for(int k = 0; k < K; k++) {
+                    for(int pq = 0; pq < PQ; pq++, index++) {
+                        retArr[index] = (retArr[index]-resultSaveMeanArr[k])*resultSaveInvVarianceArr[k]*scaleArr[k] + biasArr[k];
+                    }
+                }
+            }
+        }
+        else {
+            addBias(retArr, resultSaveMeanArr, -1, N, K, PQ);
+            multiplyBias(retArr, resultSaveInvVarianceArr, N, K, PQ);
+            multiplyBias(retArr, scaleArr, N, K, PQ);
+            addBias(retArr, biasArr, 1, N, K, PQ);
+        }
+        ret.recomputeNonZeros();
+        retRunningMean.recomputeNonZeros();
+        retRunningVar.recomputeNonZeros();
+        resultSaveMean.recomputeNonZeros();
+        resultSaveInvVariance.recomputeNonZeros();
+    }
+    
+    private static void copy(MatrixBlock input, double [] output) {
+        if(input.isInSparseFormat()) {
+            SparseBlock sblock = input.getSparseBlock();
+            int numCols = input.getNumColumns();
+            for(int n = 0; n < input.getNumRows(); n++) {
+                if( sblock.isEmpty(n) )
+                    continue;
+                int apos = sblock.pos(n);
+                int alen = sblock.size(n);
+                int[] aix = sblock.indexes(n);
+                double[] avals = sblock.values(n);
+                
+                // Iterate over the sparse block
+                for(int j=apos; j<apos+alen; j++) {
+                    output[n*numCols + aix[j]] = avals[j];
+                }
+            }
+        }
+        else {
+            double [] inputArr = input.getDenseBlockValues();
+            if(inputArr != null) {
+                System.arraycopy(inputArr, 0, output, 0, inputArr.length);
+            }
+        }
+    }
+    
+    private static void addBias(double [] arr, double [] bias, double biasMultiplier, int N, int K, int PQ) {
+        int index = 0;
+        if(bias != null) {
+            for(int n = 0; n < N; n++) {
+                for(int k = 0; k < K; k++) {
+                    for(int pq = 0; pq < PQ; pq++, index++) {
+                        arr[index] += biasMultiplier*bias[k];
+                    }
+                }
+            }
+        }
+    }
+    
+    private static void multiplyBias(double [] arr, double [] bias, int N, int K, int PQ) {
+        int index = 0;
+        if(bias != null) {
+            for(int n = 0; n < N; n++) {
+                for(int k = 0; k < K; k++) {
+                    for(int pq = 0; pq < PQ; pq++, index++) {
+                        arr[index] *= bias[k];
+                    }
+                }
+            }
+        }
+        else {
+            Arrays.fill(arr, 0);
+        }
+    }
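+    
+    // The train phase above derives per-channel statistics in a single pass over
+    // the image: with m = N*PQ, s_k = sum(x) and q_k = sum(x^2) accumulated by
+    // computeBiasSumAndSumSquares below,
+    //   mean_k   = s_k / m
+    //   var_k    = q_k / m - mean_k^2      (the E[x^2] - E[x]^2 identity)
+    //   invVar_k = 1 / sqrt(var_k + eps)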
+    private static void computeBiasSumAndSumSquares(MatrixBlock image, double [] sumArr, double [] sumSquaresArr, int K, int PQ) {
+        if(sumArr.length != K) {
+            throw new DMLRuntimeException("Expected the length of array to be " + K + ", but instead is " + sumArr.length);
+        }
+        if(sumSquaresArr.length != K) {
+            throw new DMLRuntimeException("Expected the length of array to be " + K + ", but instead is " + sumSquaresArr.length);
+        }
+        if(image.isInSparseFormat()) {
+            SparseBlock sblock = image.getSparseBlock();
+            for(int r = 0; r < image.getNumRows(); r++) {
+                if( sblock.isEmpty(r) )
+                    continue;
+                int apos = sblock.pos(r);
+                int alen = sblock.size(r);
+                int[] aix = sblock.indexes(r);
+                double[] avals = sblock.values(r);
+                for(int j=apos; j<apos+alen; j++) {
+                    int k = aix[j] / PQ;
+                    sumArr[k] += avals[j];
+                    sumSquaresArr[k] += Math.pow(avals[j], 2.0);
+                }
+            }
+        }
+        else {
+            double [] X = image.getDenseBlockValues();
+            int N = image.getNumRows();
+            if(X != null) {
+                int index = 0;
+                for(int n = 0; n < N; n++) {
+                    for(int k = 0; k < K; k++) {
+                        for(int pq = 0; pq < PQ; pq++, index++) {
+                            sumArr[k] += X[index];
+                            sumSquaresArr[k] += Math.pow(X[index], 2.0);
+                        }
+                    }
+                }
+            }
+        }
+    }
+    
     /**
      * Performs the operation corresponding to the DML script:
