Repository: systemml
Updated Branches:
  refs/heads/master 9861d7a3c -> 86f0e3f70


[SYSTEMML-445] Added CP implementations for batch_norm2d and
batch_norm2d_backward.

- This feature is required for the NN tests. A sketch of the forward-pass
math is given below.
- The current version of batch_norm2d_backward only supports a dense image
and a dense dout. This will be addressed in a future commit.
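
For readers coming from the NN tests, here is a minimal, self-contained Java
sketch of the per-channel forward-pass math that the CP implementation follows
in train mode (input flattened to [N, K*PQ], row-major). It illustrates the
formulas only and is not the committed LibMatrixDNN code; the array-based
signature and the names emaMean, emaVar and mu (the exponential-average
factor) are simplifications for this note.

public class BatchNorm2dSketch {
	// image: [N, K*PQ] row-major; scale/bias/emaMean/emaVar: [K]
	public static void forwardTrain(double[] image, int N, int K, int PQ,
			double[] scale, double[] bias, double[] emaMean, double[] emaVar,
			double mu, double eps, double[] out) {
		for (int k = 0; k < K; k++) {
			// per-channel mean and variance over all N*PQ values
			double sum = 0, sumSq = 0;
			for (int n = 0; n < N; n++)
				for (int pq = 0, ix = n*K*PQ + k*PQ; pq < PQ; pq++, ix++) {
					sum += image[ix];
					sumSq += image[ix] * image[ix];
				}
			double mean = sum / (N * PQ);
			double var = sumSq / (N * PQ) - mean * mean;
			double invStd = 1.0 / Math.sqrt(var + eps);
			// normalize, then scale and shift
			for (int n = 0; n < N; n++)
				for (int pq = 0, ix = n*K*PQ + k*PQ; pq < PQ; pq++, ix++)
					out[ix] = (image[ix] - mean) * invStd * scale[k] + bias[k];
			// update exponential moving averages (returned as separate outputs)
			emaMean[k] = mu * emaMean[k] + (1 - mu) * mean;
			emaVar[k]  = mu * emaVar[k]  + (1 - mu) * var;
		}
	}
}

In test mode the committed code instead reuses the running statistics and only
recomputes 1/sqrt(ema_var + epsilon).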

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/86f0e3f7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/86f0e3f7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/86f0e3f7

Branch: refs/heads/master
Commit: 86f0e3f705874877c261a4325a11fa45dfe851cc
Parents: 9861d7a
Author: Niketan Pansare <[email protected]>
Authored: Fri Jun 1 21:25:40 2018 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jun 1 21:26:20 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/FunctionOp.java  |  43 ++-
 .../instructions/CPInstructionParser.java       |   2 +
 .../cp/ConvolutionCPInstruction.java            | 176 +++++++---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 322 +++++++++++++++++++
 4 files changed, 490 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index ec2fda8..da13cd1 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops;
 
 import java.util.ArrayList;
 
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.lops.FunctionCallCP;
 import org.apache.sysml.lops.FunctionCallCPSingle;
 import org.apache.sysml.lops.Lop;
@@ -168,10 +169,22 @@ public class FunctionOp extends Hop
 				long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1, 1.0);
 				return outputVectors+outputValues; 
 			}
-			else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
+			else if ( getFunctionName().equalsIgnoreCase("lstm") ) {
 				// TODO: To allow for initial version to always run on the GPU
 				return 0; 
 			}
+			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") ) {
+				return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
+					OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
+					OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
+					OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) +
+					OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
+			}
+			else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
+				return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
+					OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
+					OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0);
+			}
 			else if ( getFunctionName().equalsIgnoreCase("svd") ) {
 				long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
 				long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0);
@@ -202,7 +215,10 @@ public class FunctionOp extends Hop
 				return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0)
 					+ 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0); 
 			}
-			else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+			else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+				return 0; 
+			}
+			else if ( getFunctionName().equalsIgnoreCase("lstm") ) {
 				// TODO: To allow for initial version to always run on the GPU
 				return 0; 
 			}
@@ -274,15 +290,20 @@ public class FunctionOp extends Hop
 					|| (getMemEstimate() >= OptimizerUtils.getLocalMemBudget()
 						&& OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
 			}
-			else if( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
-//				if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
-//					_etype = findExecTypeByMemEstimate();
-//				}
-//				else {
-//					_etype = ExecType.CP;
-//				}
-//				_etype = _etype == REMOTE ? ExecType.CP : _etype; // lstm not supported on Spark
-				_etype = ExecType.GPU;
+			else if( getFunctionName().equalsIgnoreCase("lstm")) {
+				if(DMLScript.USE_ACCELERATOR)
+					_etype = ExecType.GPU;
+				else
+					throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
+			}
+			else if( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward")) {
+				if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+					_etype = findExecTypeByMemEstimate();
+				}
+				else {
+					_etype = ExecType.CP;
+				}
+				_etype = _etype == REMOTE ? ExecType.CP : _etype; // batch_norm2d and batch_norm2d_backward are not supported on Spark
 			}
 			else {
 				// Since the memory estimate is only conservative, do not throw

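With the FunctionOp change above, lstm now fails fast on non-GPU backends,
while batch_norm2d and batch_norm2d_backward pick CP, or fall back from a
remote plan to CP. A hedged sketch of that selection pattern follows, with a
simplified stand-in for the Hop API; the ExecType enum and the budget
arguments here are illustrative, not the actual SystemML signatures.

// Illustrative only: mirrors the dispatch logic, not the actual Hop API.
enum ExecType { CP, SPARK, GPU }

final class ExecTypeSketch {
	static ExecType chooseExecType(double memEstimateBytes, double localBudgetBytes,
			boolean memoryBasedOptLevel) {
		if (!memoryBasedOptLevel)
			return ExecType.CP; // static plans default to CP
		// Memory-based optimizer: spill to SPARK only if the estimate
		// exceeds the local budget, then map SPARK back to CP because
		// batch_norm2d has no Spark implementation yet.
		ExecType et = (memEstimateBytes >= localBudgetBytes) ? ExecType.SPARK : ExecType.CP;
		return (et == ExecType.SPARK) ? ExecType.CP : et;
	}
}

The SPARK-to-CP remap mirrors the committed comment: there is no Spark
implementation for these two builtins yet.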
http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 82d4418..ece9d65 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -247,6 +247,8 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "bias_add"      , CPType.Convolution);
 		String2CPInstructionType.put( "bias_multiply"      , CPType.Convolution);
 		String2CPInstructionType.put( "channel_sums"      , CPType.Convolution);
+		String2CPInstructionType.put( "batch_norm2d",           CPType.Convolution);
+		String2CPInstructionType.put( "batch_norm2d_backward",  CPType.Convolution);
                
                // Quaternary instruction opcodes
                String2CPInstructionType.put( "wsloss"  , CPType.Quaternary);

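The parser change is purely table-driven: it registers the two new opcodes so
that instruction strings are routed to the Convolution instruction family. A
minimal sketch of that lookup, with illustrative names rather than the exact
SystemML types:

import java.util.HashMap;

final class OpcodeDispatchSketch {
	enum CPType { Convolution, Quaternary /* ... */ }
	static final HashMap<String, CPType> STRING_TO_TYPE = new HashMap<>();
	static {
		STRING_TO_TYPE.put("channel_sums", CPType.Convolution);
		STRING_TO_TYPE.put("batch_norm2d", CPType.Convolution);
		STRING_TO_TYPE.put("batch_norm2d_backward", CPType.Convolution);
	}
	// Look up the instruction family for an opcode.
	static CPType typeOf(String opcode) {
		CPType t = STRING_TO_TYPE.get(opcode);
		if (t == null)
			throw new IllegalArgumentException("Unknown opcode: " + opcode);
		return t;
	}
}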
http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index f1c313b..6addfe4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -20,21 +20,18 @@
 package org.apache.sysml.runtime.instructions.cp;
 
 import java.util.ArrayList;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.SparseBlock;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 import org.apache.sysml.utils.NativeHelper;
 
@@ -44,6 +41,15 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
        
        private final CPOperand _in2;
        private final CPOperand _in3;
+       private final CPOperand _in4;
+       private final CPOperand _in5;
+       private final CPOperand _in6;
+       private final CPOperand _in7;
+       private final CPOperand _in8;
+       private final CPOperand _out2;
+       private final CPOperand _out3;
+       private final CPOperand _out4;
+       private final CPOperand _out5;
        private final ArrayList<CPOperand> _input_shape;
        private final ArrayList<CPOperand> _filter_shape;
        private final ArrayList<CPOperand> _stride;
@@ -57,6 +63,8 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		super(CPType.Convolution, null, in, out, opcode, istr);
                _in2 = in2;
                _in3 = in3;
+               _in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null;
+               _out2 = null; _out3 = null; _out4 = null; _out5 = null;
                _stride = stride;
                _padding = padding;
                _input_shape = input_shape;
@@ -98,6 +106,30 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 			ArrayList<CPOperand> filter_shape, int numThreads, double intermediateMemoryBudget) {
 		this(in, in2, in3, out, stride, padding, input_shape, filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
 	}
+	
+	public ConvolutionCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5,
+			CPOperand in6, CPOperand in7, CPOperand in8,
+			CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr,
+			double intermediateMemoryBudget) throws DMLRuntimeException {
+		super(CPType.Convolution, null, in1, out, opcode, istr);
+		_in2 = in2;
+		_in3 = in3;
+		_in4 = in4;
+		_in5 = in5;
+		_in6 = in6;
+		_in7 = in7;
+		_in8 = in8;
+		_out2 = out2;
+		_out3 = out3;
+		_out4 = out4;
+		_out5 = out5;
+		_stride = null;
+		_padding = null;
+		_input_shape = null;
+		_filter_shape = null;
+		_numThreads = 0;
+		_intermediateMemoryBudget = intermediateMemoryBudget;
+	}
 
        public static ConvolutionCPInstruction parseInstruction(String str) {
 
@@ -214,6 +246,36 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 			CPOperand out = new CPOperand(parts[4]);
 			return new ConvolutionCPInstruction(in, in2, in3, out, opcode, str, -1, 0);
 		}
+		else if (opcode.equalsIgnoreCase("batch_norm2d")) {
+			InstructionUtils.checkNumFields(parts, 13);
+			CPOperand in1 = new CPOperand(parts[1]); // image
+			CPOperand in2 = new CPOperand(parts[2]); // scale
+			CPOperand in3 = new CPOperand(parts[3]); // bias
+			CPOperand in4 = new CPOperand(parts[4]); // runningMean
+			CPOperand in5 = new CPOperand(parts[5]); // runningVar
+			CPOperand in6 = new CPOperand(parts[6]); // mode
+			CPOperand in7 = new CPOperand(parts[7]); // epsilon
+			CPOperand in8 = new CPOperand(parts[8]); // exponentialAverageFactor
+			CPOperand out = new CPOperand(parts[9]);  // ret
+			CPOperand out2 = new CPOperand(parts[10]); // retRunningMean
+			CPOperand out3 = new CPOperand(parts[11]); // retRunningVar
+			CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean
+			CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance
+			return new ConvolutionCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0);
+		}
+		else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
+			InstructionUtils.checkNumFields(parts, 9);
+			CPOperand in1 = new CPOperand(parts[1]); // image
+			CPOperand in2 = new CPOperand(parts[2]); // dout
+			CPOperand in3 = new CPOperand(parts[3]); // scale
+			CPOperand in4 = new CPOperand(parts[4]); // epsilon
+			CPOperand in5 = new CPOperand(parts[5]); // resultSaveMean
+			CPOperand in6 = new CPOperand(parts[6]); // resultSaveInvVariance
+			CPOperand out = new CPOperand(parts[7]);  // dX
+			CPOperand out2 = new CPOperand(parts[8]); // dScale
+			CPOperand out3 = new CPOperand(parts[9]); // dBias
+			return new ConvolutionCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0);
+		}
                else {
 			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
                }
@@ -309,45 +371,7 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		}
 		else {
 			outputBlock = new MatrixBlock(C, 1, false).allocateBlock();
-			double [] output = outputBlock.getDenseBlockValues();
-			if(input.isInSparseFormat()) {
-				SparseBlock sblock = input.getSparseBlock();
-				for(int n = 0; n < input.getNumRows(); n++) {
-					if( sblock.isEmpty(n) )
-						continue;
-					int apos = sblock.pos(n);
-					int alen = sblock.size(n);
-					int[] aix = sblock.indexes(n);
-					double[] avals = sblock.values(n);
-					
-					// Iterate over the sparse block
-					for(int j=apos; j<apos+alen; j++) {
-						// Note: the input is of shape [N, CHW]
-						int chw = aix[j];
-						
-						// Get individual zero-based c,h,w indexes from zero-based 'chw'
-						int c = chw / HW;
-						output[c] += avals[j];
-					}
-				}
-			}
-			else {
-				double [] inArr = input.getDenseBlockValues();
-				if(inArr != null) {
-					KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
-					for(int c = 0; c < C; c++) {
-						KahanObject sum = new KahanObject(0.0, 0.0);
-						for(int n = 0; n < input.getNumRows(); n++) {
-							int index = n*C*HW + c*HW;
-							for(int hw = 0; hw < HW; hw++, index++) {
-								kplus.execute2(sum, inArr[index]);
-							}
-						}
-						output[c] = sum._sum;
-					}
-				}
-			}
-			outputBlock.recomputeNonZeros();
+			LibMatrixDNN.channelSums(input, outputBlock, C, HW);
                }
                
                // release inputs/outputs
@@ -355,6 +379,66 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 		ec.setMatrixOutput(getOutputVariableName(), outputBlock, getExtendedOpcode());
        }
        
+	
+	
+	public void processBatchNorm2dInstruction(ExecutionContext ec) {
+		MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+		MatrixBlock scale = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+		MatrixBlock bias = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+		MatrixBlock runningMean = ec.getMatrixInput(_in4.getName(), getExtendedOpcode());
+		MatrixBlock runningVar = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+		String phase = ec.getScalarInput(_in6.getName(), _in6.getValueType(), _in6.isLiteral()).getStringValue();
+		double epsilon = ec.getScalarInput(_in7.getName(), _in7.getValueType(), _in7.isLiteral()).getDoubleValue();
+		double mu = ec.getScalarInput(_in8.getName(), _in8.getValueType(), _in8.isLiteral()).getDoubleValue();
+		
+		MatrixBlock ret = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
+		MatrixBlock retRunningMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
+		MatrixBlock retRunningVar = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
+		MatrixBlock resultSaveMean = new MatrixBlock(runningMean.getNumRows(), runningMean.getNumColumns(), false).allocateBlock();
+		MatrixBlock resultSaveInvVariance = new MatrixBlock(runningVar.getNumRows(), runningVar.getNumColumns(), false).allocateBlock();
+		
+		LibMatrixDNN.batchNorm2D(image, scale, bias, runningMean, runningVar, phase, epsilon, mu, ret,
+				retRunningMean, retRunningVar, resultSaveMean, resultSaveInvVariance);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in4.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+		ec.setMatrixOutput(output.getName(), ret, getExtendedOpcode());
+		ec.setMatrixOutput(_out2.getName(), retRunningMean, getExtendedOpcode());
+		ec.setMatrixOutput(_out3.getName(), retRunningVar, getExtendedOpcode());
+		ec.setMatrixOutput(_out4.getName(), resultSaveMean, getExtendedOpcode());
+		ec.setMatrixOutput(_out5.getName(), resultSaveInvVariance, getExtendedOpcode());
+	}
+       
+	public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) {
+		MatrixBlock image = ec.getMatrixInput(input1.getName(), getExtendedOpcode());
+		MatrixBlock dout = ec.getMatrixInput(_in2.getName(), getExtendedOpcode());
+		MatrixBlock scale = ec.getMatrixInput(_in3.getName(), getExtendedOpcode());
+		double epsilon = ec.getScalarInput(_in4.getName(), _in4.getValueType(), _in4.isLiteral()).getDoubleValue();
+		MatrixBlock resultSaveMean = ec.getMatrixInput(_in5.getName(), getExtendedOpcode());
+		MatrixBlock resultSaveInvVariance = ec.getMatrixInput(_in6.getName(), getExtendedOpcode());
+		
+		MatrixBlock dX = new MatrixBlock(image.getNumRows(), image.getNumColumns(), false).allocateBlock();
+		MatrixBlock dScale = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
+		MatrixBlock dBias = new MatrixBlock(scale.getNumRows(), scale.getNumColumns(), false).allocateBlock();
+		
+		LibMatrixDNN.batchNorm2DBackward(image, dout, scale, epsilon, resultSaveMean, resultSaveInvVariance, dX, dScale, dBias);
+		
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in2.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in3.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in5.getName(), getExtendedOpcode());
+		ec.releaseMatrixInput(_in6.getName(), getExtendedOpcode());
+		ec.setMatrixOutput(output.getName(), dX, getExtendedOpcode());
+		ec.setMatrixOutput(_out2.getName(), dScale, getExtendedOpcode());
+		ec.setMatrixOutput(_out3.getName(), dBias, getExtendedOpcode());
+	}
+	
+	
 	// Assumption: enableNative && NativeHelper.isNativeLibraryLoaded() is true
 	// This increases the number of native calls, e.g. for the cases where the filter is sparse but the input is dense
 	private static boolean isFilterSparse(MatrixBlock filter) {
@@ -385,6 +469,14 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction {
 			processChannelSumsInstruction(ec);
 			return;
 		}
+		else if (instOpcode.equalsIgnoreCase("batch_norm2d")) {
+			processBatchNorm2dInstruction(ec);
+			return;
+		}
+		else if (instOpcode.equalsIgnoreCase("batch_norm2d_backward")) {
+			processBatchNorm2dBackwardInstruction(ec);
+			return;
+		}
                
                // acquire inputs
                MatrixBlock outputBlock = null;

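For context on the parse branches above: a CP instruction string is an
operand-delimited list, and the new branches map positional fields to named
operands (8 inputs and 5 outputs for batch_norm2d, hence
checkNumFields(parts, 13)). The snippet below is a made-up illustration of
that positional layout only; the real generated strings also encode data and
value types per operand, which is omitted here, and the delimiter is taken
from the lop conventions.

// Illustrative only: shows how parts[] indexes line up with operands.
final class ParseSketch {
	static void demo() {
		// opcode, 8 inputs, 5 outputs => checkNumFields(parts, 13)
		String str = "batch_norm2d" + "\u00b0" + String.join("\u00b0",
			"image", "scale", "bias", "runningMean", "runningVar",
			"train", "1.0E-5", "0.9",
			"ret", "retRunningMean", "retRunningVar",
			"resultSaveMean", "resultSaveInvVariance");
		String[] parts = str.split("\u00b0");
		assert parts.length == 14;                // opcode + 13 fields
		String opcode = parts[0];                 // "batch_norm2d"
		String image = parts[1];                  // first input operand
		String resultSaveInvVariance = parts[13]; // last output operand
	}
}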
http://git-wip-us.apache.org/repos/asf/systemml/blob/86f0e3f7/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 2d2cf63..f8318a5 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -19,6 +19,7 @@
 package org.apache.sysml.runtime.matrix.data;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ExecutorService;
@@ -30,6 +31,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.functionobjects.KahanPlus;
+import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.util.CommonThreadPool;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 
@@ -332,6 +335,325 @@ public class LibMatrixDNN {
                outputBlock.examSparsity();
        }
        
+	/**
+	 * Perform channel sums operation
+	 * 
+	 * @param input input matrix block
+	 * @param outputBlock output matrix block
+	 * @param C number of channels
+	 * @param HW height X width
+	 */
+	public static void channelSums(MatrixBlock input, MatrixBlock outputBlock, int C, int HW) {
+		double [] output = outputBlock.getDenseBlockValues();
+		if(input.isInSparseFormat()) {
+			SparseBlock sblock = input.getSparseBlock();
+			for(int n = 0; n < input.getNumRows(); n++) {
+				if( sblock.isEmpty(n) )
+					continue;
+				int apos = sblock.pos(n);
+				int alen = sblock.size(n);
+				int[] aix = sblock.indexes(n);
+				double[] avals = sblock.values(n);
+				
+				// Iterate over the sparse block
+				for(int j=apos; j<apos+alen; j++) {
+					// Note: the input is of shape [N, CHW]
+					int chw = aix[j];
+					
+					// Get individual zero-based c,h,w indexes from zero-based 'chw'
+					int c = chw / HW;
+					output[c] += avals[j];
+				}
+			}
+		}
+		else {
+			double [] inArr = input.getDenseBlockValues();
+			if(inArr != null) {
+				KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
+				for(int c = 0; c < C; c++) {
+					KahanObject sum = new KahanObject(0.0, 0.0);
+					for(int n = 0; n < input.getNumRows(); n++) {
+						int index = n*C*HW + c*HW;
+						for(int hw = 0; hw < HW; hw++, index++) {
+							kplus.execute2(sum, inArr[index]);
+						}
+					}
+					output[c] = sum._sum;
+				}
+			}
+		}
+		outputBlock.recomputeNonZeros();
+	}
+       
+	public static void batchNorm2DBackward(MatrixBlock image, MatrixBlock dout, MatrixBlock scale, double epsilon,
+			MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance,
+			MatrixBlock dX, MatrixBlock dScale, MatrixBlock dBias) {
+		int N = image.getNumRows();
+		int K = scale.getNumRows();
+		int PQ = image.getNumColumns() / K;
+		channelSums(image, dBias, K, PQ);
+		// Since the outputs are expected to be dense
+		if(dBias.isInSparseFormat())
+			dBias.sparseToDense();
+		if(dScale.isInSparseFormat())
+			dScale.sparseToDense();
+		if(dX.isInSparseFormat())
+			dX.sparseToDense();
+		// Very small matrices
+		if(resultSaveMean.isInSparseFormat())
+			resultSaveMean.sparseToDense();
+		if(resultSaveInvVariance.isInSparseFormat())
+			resultSaveInvVariance.sparseToDense();
+		if(scale.isInSparseFormat())
+			scale.sparseToDense();
+		double [] dBiasArr = dBias.getDenseBlockValues();
+		double [] dScaleArr = dScale.getDenseBlockValues();
+		double [] dXArr = dX.getDenseBlockValues();
+		double [] mean = resultSaveMean.getDenseBlockValues();
+		double [] invVar = resultSaveInvVariance.getDenseBlockValues();
+		double [] scaleArr = scale.getDenseBlockValues();
+		// Since K is relatively small, this reduces code complexity. We can avoid these allocations in subsequent commits.
+		mean = (mean==null) ? new double[K] : mean;
+		invVar = (invVar==null) ? new double[K] : invVar;
+		scaleArr = (scaleArr == null) ? new double[K] : scaleArr;
+		
+		// TODO: Handle sparse image and dout cases:
+		if(image.isInSparseFormat())
+			image.sparseToDense();
+		if(dout.isInSparseFormat())
+			dout.sparseToDense();
+		
+		if(!image.isInSparseFormat() && !dout.isInSparseFormat()) {
+			double [] imageArr = image.getDenseBlockValues();
+			double [] doutArr = dout.getDenseBlockValues();
+			double constant1 = Math.pow(N*PQ, -1);
+			int KPQ = K*PQ;
+			for(int k = 0; k < K; k++) {
+				double dvar = 0;
+				double dmean_norm_branch = 0; double dmean_var_branch = 0;
+				double sumDout = 0; double sum = 0;
+				for(int n = 0; n < N; n++) {
+					int index = n*KPQ + k*PQ;
+					for(int pq = 0; pq < PQ; pq++, index++) {
+						double doutVal = doutArr != null ? doutArr[index] : 0;
+						double centered = imageArr != null ? imageArr[index] : 0;
+						centered -= mean[k];
+						double dnorm = doutVal*scaleArr[k];
+						dvar -= 0.5*centered*Math.pow(invVar[k], 3)*dnorm;
+						dmean_norm_branch -= dnorm*invVar[k];
+						sum += centered * invVar[k] * doutVal;
+						sumDout += doutVal;
+						dmean_var_branch -= 2*constant1*centered;
+					}
+				}
+				dBiasArr[k] = sumDout;
+				dScaleArr[k] = sum;
+				dmean_var_branch *= dvar;
+				double dmean = dmean_norm_branch + dmean_var_branch;
+				double dX_mean_branch = constant1*dmean;
+				
+				for(int n = 0; n < N; n++) {
+					int index = n*KPQ + k*PQ;
+					for(int pq = 0; pq < PQ; pq++, index++) {
+						double doutVal = doutArr != null ? doutArr[index] : 0;
+						double centered = imageArr != null ? imageArr[index] : 0;
+						centered -= mean[k];
+						double dnorm = doutVal*scaleArr[k];
+						double dX_norm_branch = dnorm*invVar[k];
+						double dX_var_branch = 2*constant1*centered*dvar;
+						dXArr[index] = dX_norm_branch + dX_mean_branch + dX_var_branch;
+					}
+				}
+			}
+		}
+		else {
+			throw new DMLRuntimeException("Sparse format is not yet supported for batch norm backward");
+		}
+		dBias.recomputeNonZeros();
+		dScale.recomputeNonZeros();
+		dX.recomputeNonZeros();
+	}
+       
+	public static void batchNorm2D(MatrixBlock image, MatrixBlock scale, MatrixBlock bias, MatrixBlock runningMean,
+			MatrixBlock runningVar, String phase, double epsilon, double mu,
+			MatrixBlock ret, MatrixBlock retRunningMean, MatrixBlock retRunningVar,
+			MatrixBlock resultSaveMean, MatrixBlock resultSaveInvVariance) {
+		// Since bias, scale, runningMean, runningVar are extremely small arrays
+		if(bias.isInSparseFormat())
+			bias.sparseToDense();
+		double [] biasArr = bias.getDenseBlockValues();
+		if(scale.isInSparseFormat())
+			scale.sparseToDense();
+		double [] scaleArr = scale.getDenseBlockValues();
+		if(runningMean.isInSparseFormat())
+			runningMean.sparseToDense();
+		double [] runningMeanArr = runningMean.getDenseBlockValues(); // ema_mean
+		if(runningVar.isInSparseFormat())
+			runningVar.sparseToDense();
+		double [] runningVarArr = runningVar.getDenseBlockValues(); // ema_var
+		
+		double [] retRunningMeanArr = retRunningMean.getDenseBlockValues(); // ema_mean_upd
+		double [] retRunningVarArr = retRunningVar.getDenseBlockValues(); // ema_var_upd
+		double [] resultSaveMeanArr = resultSaveMean.getDenseBlockValues(); // cache_mean
+		double [] resultSaveInvVarianceArr = resultSaveInvVariance.getDenseBlockValues(); // cache_inv_var
+		
+		int N = image.getNumRows();
+		int K = bias.getNumRows(); // number of output channels
+		int PQ = image.getNumColumns() / K; // output height X output width
+		
+		if(phase.equalsIgnoreCase("train")) {
+			computeBiasSumAndSumSquares(image, resultSaveMeanArr, resultSaveInvVarianceArr, K, PQ);
+			int NPQ = N*PQ;
+			for(int k = 0; k < K; k++) {
+				double mean = resultSaveMeanArr[k] / NPQ;
+				double var = resultSaveInvVarianceArr[k]/NPQ - Math.pow(mean, 2.0);
+				resultSaveMeanArr[k] = mean;
+				resultSaveInvVarianceArr[k] = Math.pow(Math.sqrt(var + epsilon), -1.0);
+				retRunningMeanArr[k] = mu*runningMeanArr[k] + (1-mu)*mean;
+				retRunningVarArr[k] = mu*runningVarArr[k] + (1-mu)*var;
+			}
+		}
+		else if(phase.equalsIgnoreCase("test")) {
+			copy(runningMean, retRunningMeanArr); // ema_mean_upd = ema_mean
+			copy(runningVar, retRunningVarArr); // ema_var_upd = ema_var
+			copy(runningMean, resultSaveMeanArr); // cache_mean = ema_mean
+			double invSqrtEps = Math.pow(Math.sqrt(epsilon), -1.0);
+			double [] inArr = runningVar.getDenseBlockValues();
+			if(inArr != null) {
+				for(int i = 0; i < inArr.length; i++) {
+					resultSaveInvVarianceArr[i] = Math.pow(Math.sqrt(inArr[i] + epsilon), -1.0);
+				}
+			}
+			else {
+				Arrays.fill(resultSaveInvVarianceArr, invSqrtEps);
+			}
+		}
+		else {
+			throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + phase);
+		}
+		
+		// Normalize, shift, and scale
+		double [] retArr = ret.getDenseBlockValues();
+		copy(image, retArr);
+		if(resultSaveMeanArr != null && resultSaveInvVarianceArr != null && biasArr != null && scaleArr != null) {
+			// Common scenario:
+			int index = 0;
+			for(int n = 0; n < N; n++) {
+				for(int k = 0; k < K; k++) {
+					for(int pq = 0; pq < PQ; pq++, index++) {
+						retArr[index] = (retArr[index]-resultSaveMeanArr[k])*resultSaveInvVarianceArr[k]*scaleArr[k] + biasArr[k];
+					}
+				}
+			}
+		}
+		else {
+			addBias(retArr, resultSaveMeanArr, -1, N, K, PQ);
+			multiplyBias(retArr, resultSaveInvVarianceArr, N, K, PQ);
+			multiplyBias(retArr, scaleArr, N, K, PQ);
+			addBias(retArr, biasArr, 1, N, K, PQ);
+		}
+		ret.recomputeNonZeros();
+		retRunningMean.recomputeNonZeros();
+		retRunningVar.recomputeNonZeros();
+		resultSaveMean.recomputeNonZeros();
+		resultSaveInvVariance.recomputeNonZeros();
+	}
+       
+	private static void copy(MatrixBlock input, double [] output) {
+		if(input.isInSparseFormat()) {
+			SparseBlock sblock = input.getSparseBlock();
+			int numCols = input.getNumColumns();
+			for(int n = 0; n < input.getNumRows(); n++) {
+				if( sblock.isEmpty(n) )
+					continue;
+				int apos = sblock.pos(n);
+				int alen = sblock.size(n);
+				int[] aix = sblock.indexes(n);
+				double[] avals = sblock.values(n);
+				
+				// Iterate over the sparse block
+				for(int j=apos; j<apos+alen; j++) {
+					output[n*numCols + aix[j]] = avals[j];
+				}
+			}
+		}
+		else {
+			double [] inputArr = input.getDenseBlockValues();
+			if(inputArr != null) {
+				System.arraycopy(inputArr, 0, output, 0, inputArr.length);
+			}
+		}
+	}
+       
+	private static void addBias(double [] arr, double [] bias, double biasMultiplier, int N, int K, int PQ) {
+		int index = 0;
+		if(bias != null) {
+			for(int n = 0; n < N; n++) {
+				for(int k = 0; k < K; k++) {
+					for(int pq = 0; pq < PQ; pq++, index++) {
+						arr[index] += biasMultiplier*bias[k];
+					}
+				}
+			}
+		}
+	}
+	
+	private static void multiplyBias(double [] arr, double [] bias, int N, int K, int PQ) {
+		int index = 0;
+		if(bias != null) {
+			for(int n = 0; n < N; n++) {
+				for(int k = 0; k < K; k++) {
+					for(int pq = 0; pq < PQ; pq++, index++) {
+						arr[index] *= bias[k];
+					}
+				}
+			}
+		}
+		else {
+			Arrays.fill(arr, 0);
+		}
+	}
+       
+	private static void computeBiasSumAndSumSquares(MatrixBlock image, double [] sumArr, double [] sumSquaresArr, int K, int PQ) {
+		if(sumArr.length != K) {
+			throw new DMLRuntimeException("Expected the length of the array to be " + K + ", but instead it is " + sumArr.length);
+		}
+		if(sumSquaresArr.length != K) {
+			throw new DMLRuntimeException("Expected the length of the array to be " + K + ", but instead it is " + sumSquaresArr.length);
+		}
+		if(image.isInSparseFormat()) {
+			SparseBlock sblock = image.getSparseBlock();
+			for(int r = 0; r < image.getNumRows(); r++) {
+				if( sblock.isEmpty(r) )
+					continue;
+				int apos = sblock.pos(r);
+				int alen = sblock.size(r);
+				int[] aix = sblock.indexes(r);
+				double[] avals = sblock.values(r);
+				for(int j=apos; j<apos+alen; j++) {
+					int k = aix[j] / PQ;
+					sumArr[k] += avals[j];
+					sumSquaresArr[k] += Math.pow(avals[j], 2.0);
+				}
+			}
+		}
+		else {
+			double [] X = image.getDenseBlockValues();
+			int N = image.getNumRows();
+			if(X != null) {
+				int index = 0;
+				for(int n = 0; n < N; n++) {
+					for(int k = 0; k < K; k++) {
+						for(int pq = 0; pq < PQ; pq++, index++) {
+							sumArr[k] += X[index];
+							sumSquaresArr[k] += Math.pow(X[index], 2.0);
+						}
+					}
+				}
+			}
+		}
+	}
+       
        
        /**
         * Performs the operation corresponding to the DML script:

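A closing note on numerics: the dense path of channelSums accumulates with
Kahan compensated summation (KahanPlus/KahanObject) so that rounding error
stays bounded over the N*HW additions per channel. Below is a minimal,
self-contained sketch of the same technique in plain Java, without the
SystemML function objects.

public final class KahanSumSketch {
	// Compensated summation: 'comp' carries the low-order bits lost
	// when adding a small value to a large running sum.
	public static double sum(double[] values) {
		double sum = 0.0, comp = 0.0;
		for (double v : values) {
			double corrected = v - comp;
			double t = sum + corrected;
			comp = (t - sum) - corrected; // lost low-order part
			sum = t;
		}
		return sum;
	}
}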