[SYSTEMML-1078] Fix missing nnz maintenance conv2d ops, incl cleanups

This patch extends all conv2d operations with (so far unoptimized) nnz
maintenance in order to prevent side effects with update-in-place and
other operations that incrementally maintain the number of non-zeros.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d0b23d60
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d0b23d60
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d0b23d60

Branch: refs/heads/master
Commit: d0b23d607998e0bebd3d9d051faf748dd5530ce8
Parents: 827cdba
Author: Matthias Boehm <[email protected]>
Authored: Fri Feb 10 05:50:52 2017 +0100
Committer: Matthias Boehm <[email protected]>
Committed: Fri Feb 10 07:55:52 2017 +0100

----------------------------------------------------------------------
 .../runtime/controlprogram/ProgramBlock.java    |   4 +-
 .../cp/ConvolutionCPInstruction.java            |  90 ++++---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 233 ++++++++++---------
 3 files changed, 171 insertions(+), 156 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d0b23d60/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
index eb504ca..739b1cf 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
@@ -400,10 +400,10 @@ public class ProgramBlock
                                        
                                        if( nnz1 != nnz2 )
                                                throw new 
DMLRuntimeException("Matrix nnz meta data was incorrect: ("+varname+", 
actual="+nnz1+", expected="+nnz2+", inst="+lastInst+")");
-                                                       
                                        
                                        if( sparse1 != sparse2 )
-                                               throw new 
DMLRuntimeException("Matrix was in wrong data representation: ("+varname+", 
actual="+sparse1+", expected="+sparse2+", nnz="+nnz1+", inst="+lastInst+")");
+                                               throw new 
DMLRuntimeException("Matrix was in wrong data representation: ("+varname+", 
actual="+sparse1+", expected="+sparse2 + 
+                                                               ", 
nrow="+mb.getNumRows()+", ncol="+mb.getNumColumns()+", nnz="+nnz1+", 
inst="+lastInst+")");
                                }
                        }
                }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d0b23d60/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index ed0b548..3513201 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -21,8 +21,6 @@ package org.apache.sysml.runtime.instructions.cp;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
@@ -33,8 +31,8 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 
-public class ConvolutionCPInstruction extends UnaryCPInstruction {
-       
+public class ConvolutionCPInstruction extends UnaryCPInstruction 
+{      
        private CPOperand _in2;
        private CPOperand _in3; 
        private ArrayList<CPOperand> _input_shape;
@@ -101,8 +99,6 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
 
        public static ConvolutionCPInstruction parseInstruction(String str)
                        throws DMLRuntimeException {
-               CPOperand in = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-               CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
 
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
@@ -111,8 +107,8 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        // stride1, stride2, padding1, padding2
                        // input_shape1, input_shape2, input_shape3, 
input_shape4,
                        // filter_shape1, filter_shape2, filter_shape3, 
filter_shape4, k
-                       in.split(parts[1]);
-                       out.split(parts[14]);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand out = new CPOperand(parts[14]);
 
                        ArrayList<CPOperand> stride = new 
ArrayList<CPOperand>();
                        ArrayList<CPOperand> padding = new 
ArrayList<CPOperand>();
@@ -143,10 +139,9 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        // dout, stride1, stride2, padding1, padding2
                        // input_shape1, input_shape2, input_shape3, 
input_shape4,
                        // filter_shape1, filter_shape2, filter_shape3, 
filter_shape4, k
-                       in.split(parts[1]);
-                       CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       in2.split(parts[2]);
-                       out.split(parts[15]);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand out = new CPOperand(parts[15]);
 
                        ArrayList<CPOperand> stride = new 
ArrayList<CPOperand>();
                        ArrayList<CPOperand> padding = new 
ArrayList<CPOperand>();
@@ -174,12 +169,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        // dout, stride1, stride2, padding1, padding2
                        // input_shape1, input_shape2, input_shape3, 
input_shape4,
                        // filter_shape1, filter_shape2, filter_shape3, 
filter_shape4, k
-                       in.split(parts[1]);
-                       CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       in2.split(parts[2]);
-                       CPOperand in3 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       in3.split(parts[3]);
-                       out.split(parts[16]);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand in3 = new CPOperand(parts[3]);
+                       CPOperand out = new CPOperand(parts[16]);
 
                        ArrayList<CPOperand> stride = new 
ArrayList<CPOperand>();
                        ArrayList<CPOperand> padding = new 
ArrayList<CPOperand>();
@@ -204,10 +197,9 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                }
                else if (opcode.equalsIgnoreCase("bias_add") || 
opcode.equals("relu_backward")) {
                        InstructionUtils.checkNumFields(parts, 4);
-                       in.split(parts[1]);
-                       CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
-                       in2.split(parts[2]);
-                       out.split(parts[3]);
+                       CPOperand in = new CPOperand(parts[1]);
+                       CPOperand in2 = new CPOperand(parts[2]);
+                       CPOperand out = new CPOperand(parts[3]);
                        int k = Integer.parseInt(parts[4]);
                        return new ConvolutionCPInstruction(in, in2, out, 
opcode, str, k);
                }
@@ -216,24 +208,23 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                }
        }
 
-       private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL,
-                       int index) throws DMLRuntimeException {
+       private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> 
aL, int index) 
+                       throws DMLRuntimeException {
                return (int) ec.getScalarInput(aL.get(index).getName(),
                                aL.get(index).getValueType(), 
aL.get(index).isLiteral())
                                .getLongValue();
        }
        
+       @SuppressWarnings("unused")
        public void processReluBackwardInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
                // (X > 0) * dout
-               MatrixBlock outputBlock = null;
                MatrixBlock input = ec.getMatrixInput(input1.getName());
                MatrixBlock dout = ec.getMatrixInput(_in2.getName());
+               MatrixBlock outputBlock =  new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), 
+                       LibMatrixDNN.SUPPORTS_SPARSE_OUTPUTS && 
(input.isInSparseFormat() || dout.isInSparseFormat()));
                
-               if(input.isEmptyBlock() || dout.isEmptyBlock()) {
-                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), true, 0);
-               }
-               else {
-                       outputBlock = getDenseOutputBlock(ec, 
input.getNumRows(), input.getNumColumns());
+               if( !input.isEmptyBlock() && !dout.isEmptyBlock() ) {
+                       outputBlock.allocateDenseOrSparseBlock();
                        LibMatrixDNN.reluBackward(input, dout, outputBlock, 
_numThreads);
                }
                
@@ -244,24 +235,24 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
        }
        
        public void processBiasAddInstruction(ExecutionContext ec) throws 
DMLRuntimeException {
-               MatrixBlock outputBlock = null;
                MatrixBlock input = ec.getMatrixInput(input1.getName());
                MatrixBlock bias = ec.getMatrixInput(_in2.getName());
+               MatrixBlock outputBlock = null;
                
                if(bias.getNumColumns() != 1) {
                        throw new DMLRuntimeException("Expected the number of 
columns of bias matrix to be 1, but found " + bias.getNumColumns());
                }
                
                if(input.isEmptyBlock() && bias.isEmptyBlock()) {
-                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), true, 0);
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), true);
                }
                else if(bias.isEmptyBlock()) {
-                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), input.isInSparseFormat());
-                       outputBlock.copy(input);
+                       outputBlock = new MatrixBlock(input);
                }
                else {
                        // As we always fill the output first with bias
-                       outputBlock = getDenseOutputBlock(ec, 
input.getNumRows(), input.getNumColumns());
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), false);
+                       outputBlock.allocateDenseBlock();
                        LibMatrixDNN.biasAdd(input, bias, outputBlock, 
_numThreads);
                }
                
@@ -307,10 +298,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                ConvolutionParameters params = new ConvolutionParameters(N, C, 
H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, _numThreads);
                if (instOpcode.equalsIgnoreCase("maxpooling") || 
instOpcode.equalsIgnoreCase("relu_maxpooling")) {
                        if(matBlock.isEmptyBlock()) {
-                               outputBlock = new MatrixBlock(N, C*P*Q, true, 
0);
+                               outputBlock = new MatrixBlock(N, C*P*Q, true);
                        }
                        else {
-                               outputBlock = getDenseOutputBlock(ec, N, C*P*Q);
+                               outputBlock = getDenseOutputBlock(N, C*P*Q);
                                if(instOpcode.equalsIgnoreCase("maxpooling"))
                                        
Arrays.fill(outputBlock.getDenseBlock(), -Double.MAX_VALUE);
                                LibMatrixDNN.maxpooling(matBlock, outputBlock, 
params);
@@ -319,10 +310,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
                        MatrixBlock dout = ec.getMatrixInput(_in2.getName());
                        if(matBlock.isEmptyBlock() || dout.isEmptyBlock()) {
-                               outputBlock = new MatrixBlock(N, C*H*W, true, 
0);
+                               outputBlock = new MatrixBlock(N, C*H*W, true);
                        }
                        else {
-                               outputBlock = getDenseOutputBlock(ec, N, C*H*W);
+                               outputBlock = getDenseOutputBlock(N, C*H*W);
                                LibMatrixDNN.maxpoolingBackward(matBlock, dout, 
outputBlock, params);
                        }
                        ec.releaseMatrixInput(_in2.getName());
@@ -330,10 +321,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                else if (instOpcode.equalsIgnoreCase("conv2d")) {
                        MatrixBlock filter = ec.getMatrixInput(_in2.getName());
                        if(filter.isEmptyBlock() || matBlock.isEmptyBlock()) {
-                               outputBlock = new MatrixBlock(N, K*P*Q, true, 
0);
+                               outputBlock = new MatrixBlock(N, K*P*Q, true);
                        }
                        else {
-                               outputBlock = getDenseOutputBlock(ec, N, K*P*Q);
+                               outputBlock = getDenseOutputBlock(N, K*P*Q);
                                LibMatrixDNN.conv2d(matBlock, filter, 
outputBlock, params);
                        }
                        ec.releaseMatrixInput(_in2.getName());
@@ -342,10 +333,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        MatrixBlock filter = ec.getMatrixInput(_in3.getName());
                        MatrixBlock bias = ec.getMatrixInput(_in2.getName());
                        if((filter.isEmptyBlock() || matBlock.isEmptyBlock()) 
&& bias.isEmptyBlock()) {
-                               outputBlock = new MatrixBlock(N, K*P*Q, true, 
0);
+                               outputBlock = new MatrixBlock(N, K*P*Q, true);
                        }
                        else {
-                               outputBlock = getDenseOutputBlock(ec, N, K*P*Q);
+                               outputBlock = getDenseOutputBlock(N, K*P*Q);
                                if(!bias.isEmptyBlock())
                                        params.bias = bias;
                                LibMatrixDNN.conv2d(matBlock, filter, 
outputBlock, params);
@@ -356,10 +347,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) 
{
                        MatrixBlock dout = ec.getMatrixInput(_in2.getName());
                        if(dout.isEmptyBlock() || matBlock.isEmptyBlock()) {
-                               outputBlock = new MatrixBlock(K, C*R*S, true, 
0);
+                               outputBlock = new MatrixBlock(K, C*R*S, true);
                        }
                        else {
-                               outputBlock = getDenseOutputBlock(ec, K, C*R*S);
+                               outputBlock = getDenseOutputBlock(K, C*R*S);
                                LibMatrixDNN.conv2dBackwardFilter(matBlock, 
dout, outputBlock, params);
                        }
                        ec.releaseMatrixInput(_in2.getName());
@@ -367,10 +358,10 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) {
                        MatrixBlock dout = ec.getMatrixInput(_in2.getName());
                        if(dout.isEmptyBlock() || matBlock.isEmptyBlock()) {
-                               outputBlock = new MatrixBlock(N, C * H * W, 
true, 0);
+                               outputBlock = new MatrixBlock(N, C * H * W, 
true);
                        }
                        else {
-                               outputBlock = getDenseOutputBlock(ec, N, C * H 
* W);
+                               outputBlock = getDenseOutputBlock(N, C * H * W);
                                LibMatrixDNN.conv2dBackwardData(matBlock, dout, 
outputBlock, params);
                        }
                        ec.releaseMatrixInput(_in2.getName());
@@ -384,10 +375,9 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                ec.setMatrixOutput(getOutputVariableName(), outputBlock);
        }
        
-       private MatrixBlock getDenseOutputBlock(ExecutionContext ec, int 
numRows, int numCols) throws DMLRuntimeException {
-               MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, 
false, numRows * numCols);
+       private MatrixBlock getDenseOutputBlock(int numRows, int numCols) 
throws DMLRuntimeException {
+               MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, 
false);
                outputBlock.allocateDenseBlock();
-               outputBlock.setNonZeros(-1);
                return outputBlock;
        }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d0b23d60/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 9207171..29e59bd 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -24,7 +24,6 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -48,13 +47,14 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 public class LibMatrixDNN {
        
        protected static final Log LOG =  
LogFactory.getLog(LibMatrixDNN.class.getName());
-       // 
------------------------------------------------------------------------------------------------
-       // Useful flags for performance testing:
-       private static boolean DISPLAY_STATISTICS = false;
-       private static final boolean ALLOW_MULTI_THREADED_OPS = true;
-       // 
------------------------------------------------------------------------------------------------
        
-       enum TaskType {
+       //library configurations and external contracts
+       public static final boolean SUPPORTS_SPARSE_OUTPUTS = false; 
//operations able to handle sparse outputs 
+       private static final boolean DISPLAY_STATISTICS = false; //conv2d 
summaries in stats output
+       private static final boolean ALLOW_MULTI_THREADED_OPS = true; //enable 
multi-threading in cp
+       private static final int NUM_TASK_FACTOR = 2; //number of tasks is 
vcores scaled by this factor
+       
+       private enum TaskType {
                MaxPooling_Forward, MaxPooling_Backward, 
                // Alternate approaches that we tried but the performance was 
unsatisfactory be included: direct, non-looped im2col
                LoopedIm2ColConv2d, LoopedIm2ColConv2dBwdFilter, 
LoopedIm2ColConv2dBwdData,
@@ -79,6 +79,7 @@ public class LibMatrixDNN {
        private static AtomicLong loopedConvBwdDataMatMultTime = new 
AtomicLong(0);
        private static AtomicLong loopedConvBwdDataCol2ImTime = new 
AtomicLong(0);
        
+       @SuppressWarnings("unused")
        public static void appendStatistics(StringBuilder sb) {
                if(DMLScript.STATISTICS && DISPLAY_STATISTICS && 
(conv2dDenseCount.get() != 0 || conv2dSparseCount.get() != 0)) {
                        sb.append("LibMatrixDNN dense count 
(conv/bwdF/bwdD/im2col/maxBwd):\t" 
@@ -135,6 +136,7 @@ public class LibMatrixDNN {
         * @param params convolution parameters
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
+       @SuppressWarnings("unused")
        public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock 
dout, MatrixBlock outputBlock, ConvolutionParameters params) throws 
DMLRuntimeException {
                params.input1 = filter;
                params.input2 = dout;
@@ -157,6 +159,9 @@ public class LibMatrixDNN {
                }
                
                runConvTask(TaskType.LoopedIm2ColConv2dBwdData, params);
+               
+               //post-processing: maintain nnz
+               outputBlock.recomputeNonZeros();
        }
        
        /**
@@ -168,6 +173,7 @@ public class LibMatrixDNN {
         * @param params convolution parameters
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
+       @SuppressWarnings("unused")
        public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock 
dout, MatrixBlock outputBlock, ConvolutionParameters params) throws 
DMLRuntimeException {
                params.input1 = input;
                params.input2 = dout;
@@ -190,6 +196,9 @@ public class LibMatrixDNN {
                }
                
                runConvTask(TaskType.LoopedIm2ColConv2dBwdFilter, params);
+               
+               //post-processing: maintain nnz
+               outputBlock.recomputeNonZeros();
        }
        
        /**
@@ -259,6 +268,7 @@ public class LibMatrixDNN {
                }
        }
        
+       @SuppressWarnings("unused")
        private static void doLoopedIm2ColConv2dBwdData(int n, MatrixBlock 
dout_reshaped, ConvolutionParameters params) throws DMLRuntimeException {
                MatrixBlock filter = params.input1;
                MatrixBlock dout = params.input2;
@@ -277,6 +287,7 @@ public class LibMatrixDNN {
                }
        }
        
+       @SuppressWarnings("unused")
        private static MatrixBlock doLoopedIm2ColConv2dBwdFilter(int n, 
                        MatrixBlock im2ColOutBlock, MatrixBlock dout_reshaped, 
MatrixBlock partialRetBlock, ConvolutionParameters params) throws 
DMLRuntimeException {
                long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? 
System.nanoTime() : 0;
@@ -306,6 +317,7 @@ public class LibMatrixDNN {
                ret[2] = j % W;
        }
        
+       @SuppressWarnings("unused")
        public static void conv2d(MatrixBlock input, MatrixBlock filter, 
MatrixBlock outputBlock, ConvolutionParameters params) throws 
DMLRuntimeException {
                params.input1 = input;
                params.input2 = filter;
@@ -333,8 +345,12 @@ public class LibMatrixDNN {
                }
                
                runConvTask(TaskType.LoopedIm2ColConv2d, params);
+               
+               //post-processing: maintain nnz
+               outputBlock.recomputeNonZeros();
        }
        
+       @SuppressWarnings("unused")
        private static void doLoopedIm2ColConv2d(int n, MatrixBlock 
im2ColOutBlock, ConvolutionParameters params) throws DMLRuntimeException {
                long t1 = DMLScript.STATISTICS && DISPLAY_STATISTICS ? 
System.nanoTime() : 0;
                doIm2col(n, im2ColOutBlock, params);
@@ -372,6 +388,9 @@ public class LibMatrixDNN {
                                System.arraycopy(matMultOutBlock.denseBlock, 0, 
params.output.denseBlock, destPos, length);
                }
                // 
-----------------------------------------------------------------------------
+               
+               //post-processing: maintain nnz
+               params.output.recomputeNonZeros(); 
        }
        
        /**
@@ -383,6 +402,7 @@ public class LibMatrixDNN {
         * @param params convolution parameters
         * @throws DMLRuntimeException if DMLRuntimeException occurs
         */
+       @SuppressWarnings("unused")
        public static void maxpoolingBackward(MatrixBlock input, MatrixBlock 
dout, MatrixBlock outputBlock, ConvolutionParameters params) throws 
DMLRuntimeException {
                params.input1 = input;
                params.input2 = dout;
@@ -409,6 +429,9 @@ public class LibMatrixDNN {
 
                fillIndexesArray(params);
                runConvTask(TaskType.MaxPooling_Backward, params);
+               
+               //post-processing: maintain nnz 
+               outputBlock.recomputeNonZeros();
        }
        
        private static void fillIndexesArray(ConvolutionParameters params) {
@@ -611,10 +634,13 @@ public class LibMatrixDNN {
                        throw new DMLRuntimeException("Incorrect dimensions for 
relu_backward:" + 
                                input.getNumRows() + " != " + dout.getNumRows() 
+ " || " + input.getNumColumns() + " != " + dout.getNumColumns());
                }
+               
                runConvTask(TaskType.ReluBackward, params);
+               
+               //note: no post-processing as nnz maintained per task
        }
        
-       private static void doReluBackward(int n, ConvolutionParameters params) 
throws DMLRuntimeException {
+       private static long doReluBackward(ConvolutionParameters params, int 
rl, int ru) throws DMLRuntimeException {
                // (X > 0) * dout
                double [] outputArray = params.output.getDenseBlock();
                int numOutCols = params.input1.getNumColumns();
@@ -622,14 +648,14 @@ public class LibMatrixDNN {
                if(!params.input1.isInSparseFormat() && 
!params.input2.isInSparseFormat()) {
                        double [] inputArr = params.input1.getDenseBlock();
                        double [] doutArr = params.input2.getDenseBlock();
-                       for(int i = n*numOutCols; i < (n+1)*numOutCols; i++) {
+                       for(int i = rl*numOutCols; i < ru*numOutCols; i++) {
                                outputArray[i] = inputArr[i] > 0 ? doutArr[i] : 
0;
                        }
                }
                else {
                        // Perform (X > 0)
                        if(params.input1.isInSparseFormat()) {
-                               Iterator<IJV> iter = 
params.input1.sparseBlock.getIterator(n, n+1);
+                               Iterator<IJV> iter = 
params.input1.sparseBlock.getIterator(rl, ru);
                                while(iter.hasNext()) {
                                        IJV ijv = iter.next();
                                        int i = ijv.getI();
@@ -639,13 +665,13 @@ public class LibMatrixDNN {
                        }
                        else {
                                double [] inputArr = 
params.input1.getDenseBlock();
-                               for(int i = n*numOutCols; i < (n+1)*numOutCols; 
i++) {
+                               for(int i = rl*numOutCols; i < ru*numOutCols; 
i++) {
                                        outputArray[i] = inputArr[i] > 0 ? 1 : 
0;
                                }
                        }
                        // Then perform (X > 0) * dout
                        if(params.input2.isInSparseFormat()) {
-                               Iterator<IJV> iter = 
params.input2.sparseBlock.getIterator(n, n+1);
+                               Iterator<IJV> iter = 
params.input2.sparseBlock.getIterator(rl, ru);
                                while(iter.hasNext()) {
                                        IJV ijv = iter.next();
                                        int i = ijv.getI();
@@ -655,11 +681,14 @@ public class LibMatrixDNN {
                        }
                        else {
                                double [] doutArr = 
params.input2.getDenseBlock();
-                               for(int i = n*numOutCols; i < (n+1)*numOutCols; 
i++) {
+                               for(int i = rl*numOutCols; i < ru*numOutCols; 
i++) {
                                        outputArray[i] *= doutArr[i];
                                }
                        }
                }
+               
+               //post-processing: maintain nnz
+               return params.output.recomputeNonZeros(rl, ru-1, 0, 
numOutCols-1);
        }
        
        
@@ -704,9 +733,12 @@ public class LibMatrixDNN {
                else {
                        runConvTask(TaskType.BiasAdd, params);
                }
+               
+               //post-processing: maintain nnz
+               params.output.recomputeNonZeros();
        }
        
-       private static void doBiasAdd(int n1, int n2, ConvolutionParameters 
params) throws DMLRuntimeException {
+       private static void doBiasAdd(ConvolutionParameters params, int rl, int 
ru) throws DMLRuntimeException {
                double [] outputArray = params.output.getDenseBlock();
                int PQ = params.C;
                int numOutCols = params.input1.getNumColumns();
@@ -715,8 +747,8 @@ public class LibMatrixDNN {
                        double [] inputArr = params.input1.getDenseBlock();
                        double [] biasArr = params.input2.getDenseBlock();
                        int K = params.K;
-                       int index = n1*K*PQ;
-                       for(int n = n1; n < n2; n++) {
+                       int index = rl*K*PQ;
+                       for(int n = rl; n < ru; n++) {
                                for(int k = 0; k < K; k++) {
                                        for(int pq = 0; pq < PQ; pq++, index++) 
{
                                                outputArray[index] = 
inputArr[index] + biasArr[k];
@@ -725,9 +757,9 @@ public class LibMatrixDNN {
                        }
                }
                else {
-                       fillBias(params.input2, outputArray, n1, n2, params.N, 
params.K, PQ);
+                       fillBias(params.input2, outputArray, rl, ru, params.N, 
params.K, PQ);
                        if(params.input1.isInSparseFormat()) {
-                               Iterator<IJV> iter = 
params.input1.sparseBlock.getIterator(n1, n2);
+                               Iterator<IJV> iter = 
params.input1.sparseBlock.getIterator(rl, ru);
                                while(iter.hasNext()) {
                                        IJV ijv = iter.next();
                                        int i = ijv.getI();
@@ -737,7 +769,7 @@ public class LibMatrixDNN {
                        }
                        else {
                                double [] inputArr = 
params.input1.getDenseBlock();
-                               for(int i = n1*numOutCols; i < n2*numOutCols; 
i++) {
+                               for(int i = rl*numOutCols; i < ru*numOutCols; 
i++) {
                                        outputArray[i] += inputArr[i];
                                }
                        }
@@ -780,6 +812,9 @@ public class LibMatrixDNN {
                
                fillIndexesArray(params);
                runConvTask(TaskType.MaxPooling_Forward, params);
+               
+               //post-processing: maintain nnz
+               outputBlock.recomputeNonZeros();
        }
 
        private static void doPooling(int n, ConvolutionParameters params) 
throws DMLRuntimeException {
@@ -872,75 +907,63 @@ public class LibMatrixDNN {
                for(int i = 0; i < poolSize; i++) {
                        if(type == TaskType.LoopedIm2ColConv2d || type == 
TaskType.LoopedIm2ColConv2dBwdFilter) {
                                MatrixBlock im2ColOutBlock = new 
MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false);
-                               im2ColOutBlock.allocateDenseBlock(true);
+                               im2ColOutBlock.allocateDenseBlock();
                                im2ColOutBlocks.add(im2ColOutBlock);
                        }
                        
                        if(type == TaskType.LoopedIm2ColConv2dBwdFilter) {
                                MatrixBlock partialRetBlock = new 
MatrixBlock(params.C*params.R*params.S, params.K, false);
-                               partialRetBlock.allocateDenseBlock(true);
+                               partialRetBlock.allocateDenseBlock();
                                partialRetBlocks.add(partialRetBlock);
                        }
                        
                        if(type == TaskType.LoopedIm2ColConv2dBwdData || type 
== TaskType.LoopedIm2ColConv2dBwdFilter) {
                                MatrixBlock doutReshapedBlock = new 
MatrixBlock(params.P*params.Q, params.K, false);
-                               doutReshapedBlock.allocateDenseBlock(true);
+                               doutReshapedBlock.allocateDenseBlock();
                                doutReshapedBlocks.add(doutReshapedBlock);
                        }
                }
        }
        // Methods to execute convolution-related tasks using multiple threads.
        private static void runConvTask(TaskType type, ConvolutionParameters 
params) throws DMLRuntimeException {
-               int constrainedNumThreads = 
OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+               int k = 
OptimizerUtils.getConstrainedNumThreads(params.numThreads);
                ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks = new 
ConcurrentLinkedQueue<MatrixBlock>();
                ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks = new 
ConcurrentLinkedQueue<MatrixBlock>();
                ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks = new 
ConcurrentLinkedQueue<MatrixBlock>();
-               if (ALLOW_MULTI_THREADED_OPS && params.isOutputThreadSafe() && 
constrainedNumThreads > 1) {
-                       int poolSize = Math.min(constrainedNumThreads, 
params.N);
+               
+               if (ALLOW_MULTI_THREADED_OPS && params.isOutputThreadSafe() && 
k > 1) {
+                       int poolSize = Math.min(k, params.N);
                        addMatrixBlocks(poolSize, type, params, 
im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks);
+                       
                        ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
-                       int NSize = params.N - poolSize;
-                       if(NSize >= constrainedNumThreads) {
-                               for(int n = 0; n < params.N; n++) 
-                                       tasks.add(new ConvTask(n, n+1, type, 
params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks));
-                       }
-                       else {
-                               int numNTasks = (int) Math.ceil(((double) 
NSize) / constrainedNumThreads);
-                               for (int n = 0; n < NSize; n += numNTasks) {
-                                       tasks.add(new ConvTask(n, 
Math.min(NSize, n+numNTasks), type, params, im2ColOutBlocks, 
doutReshapedBlocks, partialRetBlocks));
-                               }
-                               for (int n = NSize; n < params.N; n++)
-                                       tasks.add(new ConvTask(n, n+1, type, 
params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks));
-                       }
+                       int blklen = 
(int)(Math.ceil((double)params.N/poolSize/NUM_TASK_FACTOR));
+                       for( int i=0; i<poolSize*NUM_TASK_FACTOR && 
i*blklen<params.N; i++ )
+                               tasks.add(new ConvTask(i*blklen, 
Math.min((i+1)*blklen, params.N), 
+                                               type, params, im2ColOutBlocks, 
doutReshapedBlocks, partialRetBlocks));
                        
-                       ExecutorService pool = Executors.newFixedThreadPool( 
poolSize );
-                       List<Future<Object>> taskret;
                        try {
-                               taskret = pool.invokeAll(tasks);
+                               ExecutorService pool = 
Executors.newFixedThreadPool( poolSize );
+                               List<Future<Long>> taskret = 
pool.invokeAll(tasks);
                                pool.shutdown();
-                               for( Future<Object> task : taskret ) {
-                                       task.get();
-                               }
+                               for( Future<Long> task : taskret )
+                                       params.output.nonZeros += task.get();
                                if(type == 
TaskType.LoopedIm2ColConv2dBwdFilter) {
                                        for(MatrixBlock partialRetBlock : 
partialRetBlocks) {
                                                
elementWiseInPlaceTransposedAddition(params.output, partialRetBlock);
                                        }
                                }
-                       } catch (InterruptedException e) {
-                               throw new DMLRuntimeException("Error while 
executing multi-threaded " + type.name(), e);
-                       } catch (ExecutionException e) {
+                       } 
+                       catch (Exception e) {
                                throw new DMLRuntimeException("Error while 
executing multi-threaded " + type.name(), e);
                        }
                }
                else {
                        addMatrixBlocks(1, type, params, im2ColOutBlocks, 
doutReshapedBlocks, partialRetBlocks);
-                       ConvTask task = new ConvTask(0, 0, type, params, 
im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks);
                        try {
-                               for(int n = 0; n < params.N; n++) {
-                                       task.n1 = n;
-                                       task.n2 = n+1;
-                                       task.call();
-                               }
+                               //execute single task and maintain nnz if 
supported
+                               params.output.setNonZeros(new ConvTask(0, 
params.N, type, params, im2ColOutBlocks, 
+                                               doutReshapedBlocks, 
partialRetBlocks).call());
+                               
                                if(type == 
TaskType.LoopedIm2ColConv2dBwdFilter) {
                                        for(MatrixBlock partialRetBlock : 
partialRetBlocks) {
                                                
elementWiseInPlaceTransposedAddition(params.output, partialRetBlock);
@@ -958,92 +981,94 @@ public class LibMatrixDNN {
         * to be executed in multi-thread manner.
         * 
         */
-       private static class ConvTask implements Callable<Object> {
-               public int n1; public int n2; 
-               ConvolutionParameters params;
-               TaskType type;
-               ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks;
-               ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks;
-               ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks;
-               public ConvTask(int n1, int n2, TaskType type, 
ConvolutionParameters params, 
+       private static class ConvTask implements Callable<Long> 
+       {
+               public int _rl; 
+               public int _ru; 
+               private final ConvolutionParameters _params;
+               private final TaskType _type;
+               private final ConcurrentLinkedQueue<MatrixBlock> 
_im2ColOutBlocks;
+               private final ConcurrentLinkedQueue<MatrixBlock> 
_partialRetBlocks;
+               private final ConcurrentLinkedQueue<MatrixBlock> 
_doutReshapedBlocks;
+               
+               public ConvTask(int rl, int ru, TaskType type, 
ConvolutionParameters params, 
                                ConcurrentLinkedQueue<MatrixBlock> 
im2ColOutBlocks,
                                ConcurrentLinkedQueue<MatrixBlock> 
doutReshapedBlocks,
                                ConcurrentLinkedQueue<MatrixBlock> 
partialRetBlocks) {
-                       this.n1 = n1;
-                       this.n2 = n2;
-                       this.type = type;
-                       this.params = params;
-                       this.im2ColOutBlocks = im2ColOutBlocks;
-                       this.partialRetBlocks = partialRetBlocks;
-                       this.doutReshapedBlocks = doutReshapedBlocks;
+                       _rl = rl;
+                       _ru = ru;
+                       _type = type;
+                       _params = params;
+                       _im2ColOutBlocks = im2ColOutBlocks;
+                       _partialRetBlocks = partialRetBlocks;
+                       _doutReshapedBlocks = doutReshapedBlocks;
                }
                
                @Override
-               public Object call() throws DMLRuntimeException {
-                       switch(type) {
+               public Long call() throws DMLRuntimeException {
+                       long lnnz = 0; //nnz per partition
+                       
+                       switch(_type) {
                                case MaxPooling_Forward:
-                               {
-                                       for(int n = n1; n < n2; n++) {
-                                               doPooling(n, params);
-                                       }
+                                       for(int n = _rl; n < _ru; n++)
+                                               doPooling(n, _params);
                                        break;
-                               }
                                case MaxPooling_Backward:
-                                       for(int n = n1; n < n2; n++) 
-                                               doPoolingBackward(n, params);
+                                       for(int n = _rl; n < _ru; n++) 
+                                               doPoolingBackward(n, _params);
                                        break;
                                case BiasAdd:
-                                       doBiasAdd(n1, n2, params);
+                                       doBiasAdd(_params, _rl, _ru);
                                        break;
                                case ReluBackward:
-                                       for(int n = n1; n < n2; n++) 
-                                               doReluBackward(n, params);
+                                       lnnz = doReluBackward(_params, _rl, 
_ru);
                                        break;
                                case LoopedIm2ColConv2d:
                                {       
-                                       MatrixBlock im2ColOutBlock = 
im2ColOutBlocks.remove();
-                                       for(int n = n1; n < n2; n++) 
-                                               doLoopedIm2ColConv2d(n, 
im2ColOutBlock, params);
-                                       im2ColOutBlocks.add(im2ColOutBlock);
-                                       if(params.bias != null)
-                                               addBias(n1, n2, params);
+                                       MatrixBlock im2ColOutBlock = 
_im2ColOutBlocks.remove();
+                                       for(int n = _rl; n < _ru; n++) 
+                                               doLoopedIm2ColConv2d(n, 
im2ColOutBlock, _params);
+                                       _im2ColOutBlocks.add(im2ColOutBlock);
+                                       if(_params.bias != null)
+                                               addBias(_params, _rl, _ru);
                                        break;
                                }
                                case LoopedIm2ColConv2dBwdFilter:
                                {
-                                       MatrixBlock im2ColOutBlock = 
im2ColOutBlocks.remove();
-                                       MatrixBlock partialRetBlock = 
partialRetBlocks.remove();
-                                       MatrixBlock doutReshapedBlock = 
doutReshapedBlocks.remove();
-                                       for(int n = n1; n < n2; n++) 
-                                               partialRetBlock = 
doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, doutReshapedBlock, 
partialRetBlock, params);
-                                       im2ColOutBlocks.add(im2ColOutBlock);
-                                       partialRetBlocks.add(partialRetBlock);
-                                       
doutReshapedBlocks.add(doutReshapedBlock);
+                                       MatrixBlock im2ColOutBlock = 
_im2ColOutBlocks.remove();
+                                       MatrixBlock partialRetBlock = 
_partialRetBlocks.remove();
+                                       MatrixBlock doutReshapedBlock = 
_doutReshapedBlocks.remove();
+                                       for(int n = _rl; n < _ru; n++) 
+                                               partialRetBlock = 
doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, doutReshapedBlock, 
partialRetBlock, _params);
+                                       _im2ColOutBlocks.add(im2ColOutBlock);
+                                       _partialRetBlocks.add(partialRetBlock);
+                                       
_doutReshapedBlocks.add(doutReshapedBlock);
                                        break;
                                }
                                case LoopedIm2ColConv2dBwdData:
                                {
-                                       MatrixBlock doutReshapedBlock = 
doutReshapedBlocks.remove();
-                                       for(int n = n1; n < n2; n++) 
-                                               doLoopedIm2ColConv2dBwdData(n, 
doutReshapedBlock, params);
-                                       
doutReshapedBlocks.add(doutReshapedBlock);
+                                       MatrixBlock doutReshapedBlock = 
_doutReshapedBlocks.remove();
+                                       for(int n = _rl; n < _ru; n++) 
+                                               doLoopedIm2ColConv2dBwdData(n, 
doutReshapedBlock, _params);
+                                       
_doutReshapedBlocks.add(doutReshapedBlock);
                                        break;
                                }
                                default:
-                                       throw new 
DMLRuntimeException("Unsupported ConvTask:" + type.name());
+                                       throw new 
DMLRuntimeException("Unsupported ConvTask:" + _type.name());
                        }
-                       return null;
+                       
+                       return lnnz;
                }
        }
        
-       private static void addBias(int n1, int n2, ConvolutionParameters 
params) {
+       private static void addBias(ConvolutionParameters params, int rl, int 
ru) {
                int PQ = params.P*params.Q;
                int K = params.K;
                double [] outputArr = params.output.getDenseBlock();
                if(!params.bias.isInSparseFormat()) {
                        double [] biasArr = params.bias.getDenseBlock();
-                       int index = n1*K*PQ;
-                       for(int n = n1; n < n2; n++) {
+                       int index = rl*K*PQ;
+                       for(int n = rl; n < ru; n++) {
                                for(int k = 0; k < K; k++) {
                                        for(int pq = 0; pq < PQ; pq++, index++) 
{
                                                outputArr[index] += biasArr[k];
@@ -1057,7 +1082,7 @@ public class LibMatrixDNN {
                                IJV ijv = iter.next();
                                int k = ijv.getI();
                                double val = ijv.getV();
-                               for(int n = n1; n < n2; n++) {
+                               for(int n = rl; n < ru; n++) {
                                        int index = n*K*PQ + k*PQ;
                                        for(int pq = 0; pq < PQ; pq++, index++) 
{
                                                outputArr[index] += val;

Reply via email to