Repository: incubator-systemml
Updated Branches:
  refs/heads/master 20e05458b -> 2ebf885a6


[SYSTEMML-769] Improved performance of LibMatrixDNN's conv2d and
conv2d_backward_filter

- Fixed a bug when iterating over sparse inputs in conv2d_backward_filter
- Also added a vectorized conv2d (sketched briefly below)
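
A minimal sketch of the 4-way loop unrolling that the vectorized dense-dense conv2d path relies on. The class, method, and arrays below are illustrative only and not part of this commit; the real code applies the same pattern to image/filter offsets in doLoopBasedConv2dDenseDense:

    // Illustrative only: unroll-by-4 dot product, plus a remainder loop,
    // mirroring the inner s-loop of the new dense-dense conv2d kernel.
    public class UnrolledDotSketch {
        static double dot(double[] a, double[] b, int len) {
            int mainLen = len - (len % 4); // portion handled by the unrolled loop
            double sum = 0;
            for (int i = 0; i < mainLen; i += 4) {
                sum += a[i]   * b[i]
                     + a[i+1] * b[i+1]
                     + a[i+2] * b[i+2]
                     + a[i+3] * b[i+3];
            }
            for (int i = mainLen; i < len; i++) // leftover elements (len % 4)
                sum += a[i] * b[i];
            return sum;
        }

        public static void main(String[] args) {
            double[] x = {1, 2, 3, 4, 5};
            double[] y = {5, 4, 3, 2, 1};
            System.out.println(dot(x, y, x.length)); // prints 35.0
        }
    }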

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2ebf885a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2ebf885a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2ebf885a

Branch: refs/heads/master
Commit: 2ebf885a6919e1cb0598e2aab4d0ffb46b8e0ab5
Parents: 20e0545
Author: Niketan Pansare <[email protected]>
Authored: Fri Jul 8 20:26:16 2016 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jul 8 20:28:25 2016 -0700

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 558 ++++++++++++++-----
 1 file changed, 410 insertions(+), 148 deletions(-)
----------------------------------------------------------------------
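
Background for the bug fix: the previous sparse path in conv2d_backward_filter decoded the flattened cell index directly (e.g. p = index / params.Q) without first removing the n and k components, whereas the patch decodes NCHW-style indexes explicitly via the new computeTensorIndexes helper. A standalone, illustrative sketch of that decode (the wrapper class and main() are hypothetical, not part of the commit):

    // Decode column j of an [N x C*H*W] row-major matrix into (n, c, h, w).
    public class TensorIndexSketch {
        static int[] decode(int i, int j, int C, int H, int W) {
            int c = j / (H * W);               // channel
            int h = (j - c * (H * W)) / W;     // row within the channel
            int w = j % W;                     // column within the channel
            return new int[] {i, c, h, w};     // i is the matrix row index, i.e. n
        }

        public static void main(String[] args) {
            int C = 3, H = 4, W = 5;
            int j = 2 * (H * W) + 3 * W + 1;   // encodes c=2, h=3, w=1
            System.out.println(java.util.Arrays.toString(decode(0, j, C, H, W))); // [0, 2, 3, 1]
        }
    }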


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2ebf885a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index d9faf7e..26e2b8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.lang.ref.SoftReference;
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
@@ -29,12 +30,16 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicLong;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 
 
 public class LibMatrixDNN {
+       
+       protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());
 
        public static final boolean ALLOW_MULTI_THREADED_OPS = true;
        // Using hashmap to avoid any performance impacts of multimap
@@ -62,13 +67,14 @@ public class LibMatrixDNN {
        enum TaskType {
                ReshapeCol, Rotate180, Im2Col, Col2Im, MaxPooling_Forward, MaxPooling_Backward, LoopBasedConv2d
        }
-       public static final int TASK_SIZE = 64; // to take care of extremely small tasks
        
        public static class TemporaryConvolutionData {
                public int [] minIndexArrR;
                public int [] minIndexArrS;
                public int [] maxIndexArrR;
                public int [] maxIndexArrS;
+               int minCommonIndexS;
+               int maxCommonIndexS;
        }
        
        public static class ConvolutionParameters {
@@ -159,6 +165,9 @@ public class LibMatrixDNN {
                                dout.getNumRows() != params.N || dout.getNumColumns() != params.K*params.P*params.Q) {
                        throw new DMLRuntimeException("Incorrect input to conv2d_backward_filter");
                }
+               if(params.stride_h <= 0 || params.stride_w <= 0) {
+                       throw new DMLRuntimeException("Only positive strides supported");
+               }
                
                int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
                if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
@@ -198,7 +207,7 @@ public class LibMatrixDNN {
                }
        }
        
-       public static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) {
+       private static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
                double [] inputArray = null;
                if (!params.input1.isInSparseFormat())
                        inputArray = params.input1.getDenseBlock();
@@ -207,62 +216,125 @@ public class LibMatrixDNN {
                        doutArray = params.input2.getDenseBlock();
                double [] outputArray = params.output.getDenseBlock();
                
-               long outputVal = 0;
-               if(doutArray != null) {
-                       for (int n = 0; n < params.N; n++) {
-                               for (int p = 0; p < params.P; p++) {
-                                       for (int q = 0; q < params.Q; q++) {
-                                               int h = p*params.stride_h + r - 
params.pad_h;
-                                               int w = q*params.stride_w + s - 
params.pad_w;
-                                               if(h >= 0 && h < params.H && w 
>= 0 && w < params.W) {
-                                                       double doutVal = 
doutArray[n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q + q];
-                                                       if(doutVal != 0) {
-                                                               if(inputArray 
!= null)
-                                                                       
outputVal += doutVal*inputArray[n*params.C*params.H*params.W + 
c*params.H*params.W + h*params.W+w];
-                                                               else 
-                                                                       
outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + 
h*params.W + w);
-                                                       }
-                                               }
-                                       }
-                               }
-                       }
+               double outputVal = 0;
+               if(inputArray == null && doutArray == null) {
+                       outputVal = doConv2d_Backward_Filter_SparseSparse(k, c, r, s, params);
+               }
+               else if(inputArray != null && doutArray == null) {
+                       outputVal = doConv2d_Backward_Filter_DenseSparse(k, c, r, s, params, inputArray);
+               }
+               else if(inputArray == null && doutArray != null) {
+                       outputVal = doConv2d_Backward_Filter_SparseDense(k, c, r, s, params, doutArray);
                }
                else {
-                       MatrixBlock dout = params.input2;
-                       if( !dout.isEmptyBlock(false) ) {
-                               int start=0;
-                               int rlen = dout.getNumRows();
-                               int clen = dout.getNumColumns();
-                               for(int r1=0; 
r1<Math.min(dout.sparseBlock.numRows(), rlen); r1++, start+=clen)
-                               {
-                                       if(dout.sparseBlock.isEmpty(r1)) 
-                                               continue;
-                                       int pos = dout.sparseBlock.pos(r1);
-                                       int len = dout.sparseBlock.size(r1);
-                                       int[] aix = 
dout.sparseBlock.indexes(r1);
-                                       double[] avals = 
dout.sparseBlock.values(r1);
-                                       
-                                       for(int i=pos; i<pos+len; i++) {
-                                               int index = start+aix[i];
-                                               double doutVal = avals[i];
-                                               int n = index / clen; 
-                                               int p = index / params.Q;
-                                               int q = index % params.Q;
-                                               int h = p*params.stride_h + r - 
params.pad_h;
-                                               int w = q*params.stride_w + s - 
params.pad_w;
-                                               if(h >= 0 && h < params.H && w 
>= 0 && w < params.W && doutVal != 0) {
-                                                       if(inputArray != null)
-                                                               outputVal += 
doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + 
h*params.W+w];
-                                                       else 
-                                                               outputVal += 
doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
-                                               }
-                                       }
+                       outputVal = doConv2d_Backward_Filter_DenseDense(k, c, r, s, params, inputArray, doutArray);
+               }
+               
+               outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+       }
+       
+       private static double doConv2d_Backward_Filter_SparseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] doutArray) throws DMLRuntimeException {
+               double outputVal = 0;
+               // To ensure h >= 0 && h < params.H 
+               int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+               int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+               // To ensure w >= 0 && w < params.W 
+               int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+               int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
+               
+               // TODO: Optimize this case
+               for (int n = 0; n < params.N; n++) {
+                       int doutOffset = n*params.K*params.P*params.Q + 
k*params.P*params.Q;
+                       for (int p = pMin; p < pMax; p++) {
+                               for (int q = qMin; q < qMax; q++) {
+                                       int h = p*params.stride_h + r - 
params.pad_h;
+                                       int w = q*params.stride_w + s - 
params.pad_w;
+                                       outputVal += doutArray[doutOffset + 
p*params.Q + q]*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W 
+ w);
                                }
-                       }       
+                       }
                }
                
+               return outputVal;
+       }
+       
+       private static double doConv2d_Backward_Filter_DenseDense(int k, int c, 
int r, int s, ConvolutionParameters params, double [] inputArray, double [] 
doutArray) {
+               double outputVal = 0;
+               // To ensure h >= 0 && h < params.H 
+               int pMin = (int) Math.max(0, 
Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+               int qMin = (int) Math.max(0, 
Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+               // To ensure w >= 0 && w < params.W 
+               int pMax = (int) Math.min(params.P, 
Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+               int qMax = (int) Math.min(params.Q, 
Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
                
-               outputArray[k*params.C*params.R*params.S + c*params.R*params.S 
+ r*params.S + s] = outputVal;
+               for (int n = 0; n < params.N; n++) {
+                       int inputOffset =  n*params.C*params.H*params.W + 
c*params.H*params.W + s - params.pad_w;
+                       int doutOffset = n*params.K*params.P*params.Q + 
k*params.P*params.Q;
+                       for (int p = pMin; p < pMax; p++) {
+                               int h = p*params.stride_h + r - params.pad_h;
+                               for (int q = qMin; q < qMax; q++) {
+                                       int w = q*params.stride_w;
+                                       outputVal += doutArray[doutOffset + 
p*params.Q + q]*inputArray[inputOffset + h*params.W+w];
+                               }
+                       }
+               }
+                               
+               return outputVal;
+       }
+       
+       private static void computeTensorIndexes(int i, int j, int [] ret, int N, int C, int H, int W) throws DMLRuntimeException {
+               ret[0] = i;
+               ret[1] = j / (H*W);
+               ret[2] = (j - ret[1]*(H*W))/W;
+               ret[3] = j % W;
+       }
+       
+       private static double doConv2d_Backward_Filter_DenseSparse(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+               MatrixBlock dout = params.input2;
+               double outputVal = 0;
+               Iterator<IJV> iter = dout.sparseBlock.getIterator();
+               int [] tensorIndexes = new int[4];
+               while(iter.hasNext()) {
+                       IJV ijv = iter.next();
+                       computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+                       if(k == tensorIndexes[1]) {
+                               int n = tensorIndexes[0];
+                               int p = tensorIndexes[2];
+                               int q = tensorIndexes[3];
+                               
+                               double doutVal = ijv.getV();
+                               int h = p*params.stride_h + r - params.pad_h;
+                               int w = q*params.stride_w + s - params.pad_w;
+                               if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+                                       outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
+                               }
+                       }
+               }
+               return outputVal;
+       }
+       
+       private static double doConv2d_Backward_Filter_SparseSparse(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
+               MatrixBlock dout = params.input2;
+               double outputVal = 0;
+               Iterator<IJV> iter = dout.sparseBlock.getIterator();
+               int [] tensorIndexes = new int[4];
+               
+               while(iter.hasNext()) {
+                       IJV ijv = iter.next();
+                       computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+                       if(k == tensorIndexes[1]) {
+                               int n = tensorIndexes[0];
+                               int p = tensorIndexes[2];
+                               int q = tensorIndexes[3];
+                               
+                               double doutVal = ijv.getV();
+                               int h = p*params.stride_h + r - params.pad_h;
+                               int w = q*params.stride_w + s - params.pad_w;
+                               if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+                                       outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
+                               }
+                       }
+               }
+               return outputVal;
        }
        
        private static class ConvBackwardFilterTask implements Callable<Object> 
{
@@ -294,25 +366,55 @@ public class LibMatrixDNN {
                        throw new DMLRuntimeException("Incorrect input to 
conv2d");
                }
                
-               params.tmpData = new TemporaryConvolutionData();                
-               params.tmpData.minIndexArrR = new int[params.R];
-               params.tmpData.maxIndexArrR = new int[params.R];
-               params.tmpData.minIndexArrS = new int[params.S];
-               params.tmpData.maxIndexArrS = new int[params.S];
-               for (int r = 0; r < params.R; r++) {
-                       params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, 
r, params.stride_h);
-                       params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, 
r, params.stride_h, params.P, params.H);
+               params.tmpData = new TemporaryConvolutionData();
+               if(input.isInSparseFormat()) {
+                       params.tmpData.minIndexArrR = new int[params.H];
+                       params.tmpData.minIndexArrS = new int[params.W];
+                       for(int h = 0; h < params.H; h++) {
+                               for (int r = 0; r < params.R; r++) {
+                                       // int h = p*params.stride_h + r - 
params.pad_h;
+                                       if((h + params.pad_h - r) % 
params.stride_h == 0) {
+                                               params.tmpData.minIndexArrR[h] 
= r;
+                                               break;
+                                       }
+                               }
+                       }
+                       for(int w = 0; w < params.W; w++) {
+                               for (int s = 0; s < params.S; s++) {
+                                       // int h = p*params.stride_h + r - 
params.pad_h;
+                                       if((w + params.pad_w - s) % 
params.stride_w == 0) {
+                                               params.tmpData.minIndexArrS[w] 
= s;
+                                               break;
+                                       }
+                               }
+                       }
                }
-               for (int s = 0; s < params.S; s++) {
-                       params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, 
s, params.stride_w);
-                       params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, 
s, params.stride_w, params.Q, params.W);
+               else {
+                       params.tmpData.minIndexArrR = new int[params.R];
+                       params.tmpData.maxIndexArrR = new int[params.R];
+                       params.tmpData.minIndexArrS = new int[params.S];
+                       params.tmpData.maxIndexArrS = new int[params.S];
+                       for (int r = 0; r < params.R; r++) {
+                               params.tmpData.minIndexArrR[r] = 
getMinPQ(params.pad_h, r, params.stride_h);
+                               params.tmpData.maxIndexArrR[r] = 
getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+                       }
+                       for (int s = 0; s < params.S; s++) {
+                               params.tmpData.minIndexArrS[s] = 
getMinPQ(params.pad_w, s, params.stride_w);
+                               params.tmpData.maxIndexArrS[s] = 
getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+                       }
+                       params.tmpData.minCommonIndexS = 
params.tmpData.minIndexArrS[0];
+                       params.tmpData.maxCommonIndexS = 
params.tmpData.maxIndexArrS[0];
+                       for (int s = 1; s < params.S; s++) {
+                               params.tmpData.minCommonIndexS = 
Math.max(params.tmpData.minCommonIndexS, params.tmpData.minIndexArrS[s]);
+                               params.tmpData.maxCommonIndexS = 
Math.min(params.tmpData.maxCommonIndexS, params.tmpData.maxIndexArrS[s]);
+                       }
                }
                
                int constrainedNumThreads = 
OptimizerUtils.getConstrainedNumThreads(params.numThreads);
                if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
                        for (int n = 0; n < params.N; n++) {
                                for (int k = 0; k < params.K; k++) {
-                                       doLoopBasedConv2d(n, k, params);
+                                       doLoopBasedConv2d(n, n+1, k, params);
                                }
                        }
                }
@@ -345,102 +447,255 @@ public class LibMatrixDNN {
                }
        }
        
-       /**
-        * This is essentially memory-less operation and can be used when the 
memory pressure is extremely high.
-        * @param n
-        * @param k
-        * @param params
-        */
-       private static void doLoopBasedConv2d(int n, int k, 
ConvolutionParameters params) {
-               double [] inputArray = null;
-               if (!params.input1.isInSparseFormat())
-                       inputArray = params.input1.getDenseBlock();
-               double [] filterArray = null;
-               if (!params.input2.isInSparseFormat())
-                       filterArray = params.input2.getDenseBlock();
+       private static void doLoopBasedConv2dDenseDense(int n1, int n2, int k, 
ConvolutionParameters params, 
+                       double [] inputArray, double [] filterArray) {
                double [] outputArray = params.output.getDenseBlock();
-               
-               int outputOffset = n*params.K*params.P*params.Q + 
k*params.P*params.Q;
-               
                int [] minIndexArrR = params.tmpData.minIndexArrR;
                int [] maxIndexArrR = params.tmpData.maxIndexArrR;
                int [] minIndexArrS = params.tmpData.minIndexArrS;
                int [] maxIndexArrS = params.tmpData.maxIndexArrS;
                
-               if(inputArray != null && filterArray != null) {
-                       for (int c = 0; c < params.C; c++) {
-                               for (int r = 0; r < params.R; r++) {
-                                       int filterOffset = 
k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
-                                       for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) {
-                                               for (int s = 0; s < params.S; 
s++) {
-                                                       double filterVal = 
filterArray[filterOffset + s];
-                                                       if(filterVal != 0) {
-                                                               int h = 
p*params.stride_h + r - params.pad_h;
-                                                               for (int q = 
minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-                                                                       int w = 
q*params.stride_w + s - params.pad_w;
-                                                                       
outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, 
filterVal, params, n, c, h, w);
+               int minCommonIndexS = params.tmpData.minCommonIndexS;
+               int maxCommonIndexS = params.tmpData.maxCommonIndexS;
+               
+               
+               int minS = 0;
+               if(params.S >= 4) {
+                       minS = params.S - params.S % 4;
+                       for (int n = n1; n < n2; n++) {
+                               for (int c = 0; c < params.C; c++) {
+                                       for (int r = 0; r < params.R; r++) {
+                                               final int filterOffset = 
k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+                                               for (int p = minIndexArrR[r]; p 
< maxIndexArrR[r]; p++) {
+                                                       final int h = 
p*params.stride_h + r - params.pad_h;
+                                                       final int inputOffSet = 
n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+                                                       final int outputOffset 
= n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+                                                       // 
------------------------------------------------------------------------
+                                                       // Efficient striding 
with vectorization
+                                                       for (int q = 
minCommonIndexS; q < maxCommonIndexS; q++) {
+                                                               final int 
wOffset = inputOffSet + q*params.stride_w;
+                                                               final int 
outOffsetWithQ = outputOffset + q;
+                                                               for (int s = 0; 
s < minS; s += 4) {
+                                                                       final 
int inOffsetWithS = wOffset + s;
+                                                                       final 
int filterOffsetWithS = filterOffset + s;
+                                                                       
outputArray[outOffsetWithQ] += 
inputArray[inOffsetWithS]*filterArray[filterOffsetWithS]
+                                                                               
        + inputArray[inOffsetWithS+1]*filterArray[filterOffsetWithS+1]
+                                                                               
        + inputArray[inOffsetWithS+2]*filterArray[filterOffsetWithS+2]
+                                                                               
        + inputArray[inOffsetWithS+3]*filterArray[filterOffsetWithS+3];
                                                                }
                                                        }
+                                                       // 
------------------------------------------------------------------------
                                                }
                                        }
                                }
                        }
                }
-               else if(inputArray != null && filterArray == null) {
+               
+               for (int n = n1; n < n2; n++) {
                        for (int c = 0; c < params.C; c++) {
                                for (int r = 0; r < params.R; r++) {
+                                       final int filterOffset = 
k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
                                        for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) {
-                                               for (int s = 0; s < params.S; 
s++) {
-                                                       double filterVal = 
params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
-                                                       if(filterVal != 0) {
-                                                               int h = 
p*params.stride_h + r - params.pad_h;
-                                                               for (int q = 
minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-                                                                       int w = 
q*params.stride_w + s - params.pad_w;
-                                                                       
outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, 
filterVal, params, n, c, h, w);
-                                                               }
+                                               final int h = p*params.stride_h 
+ r - params.pad_h;
+                                               final int inputOffSet = 
n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+                                               final int outputOffset = 
n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+                                               // 
------------------------------------------------------------------------
+                                               // Efficient striding
+                                               for (int q = minCommonIndexS; q 
< maxCommonIndexS; q++) {
+                                                       final int wOffset = 
inputOffSet + q*params.stride_w;
+                                                       for (int s = minS; s < 
params.S; s++) {
+                                                               
outputArray[outputOffset + q] += inputArray[wOffset + 
s]*filterArray[filterOffset + s];
                                                        }
                                                }
+                                               // 
------------------------------------------------------------------------
                                        }
                                }
                        }
-               }
-               else if(inputArray == null && filterArray != null) {
+                       
+                       
                        for (int c = 0; c < params.C; c++) {
                                for (int r = 0; r < params.R; r++) {
-                                       int filterOffset = 
k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+                                       final int filterOffset = 
k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
                                        for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) {
+                                               final int h = p*params.stride_h 
+ r - params.pad_h;
+                                               final int inputOffSet = 
n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+                                               final int outputOffset = 
n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+                                               // 
------------------------------------------------------------------------
+                                               // Inefficient striding
                                                for (int s = 0; s < params.S; 
s++) {
-                                                       double filterVal = 
filterArray[filterOffset + s];
-                                                       if(filterVal != 0) {
-                                                               int h = 
p*params.stride_h + r - params.pad_h;
-                                                               for (int q = 
minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-                                                                       int w = 
q*params.stride_w + s - params.pad_w;
-                                                                       
outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, 
filterVal, params, n, c, h, w);
-                                                               }
+                                                       for (int q = 
minIndexArrS[s]; q < minCommonIndexS; q++) {
+                                                               final int w = 
q*params.stride_w + s;
+                                                               
outputArray[outputOffset + q] += inputArray[inputOffSet + 
w]*filterArray[filterOffset + s];
+                                                       }
+                                                       for (int q = 
maxCommonIndexS; q < maxIndexArrS[s]; q++) {
+                                                               final int w = 
q*params.stride_w + s;
+                                                               
outputArray[outputOffset + q] += inputArray[inputOffSet + 
w]*filterArray[filterOffset + s];
                                                        }
                                                }
+                                               // 
------------------------------------------------------------------------
                                        }
                                }
                        }
                }
-               else if(inputArray == null && filterArray == null) {
-                       for (int c = 0; c < params.C; c++) {
-                               for (int r = 0; r < params.R; r++) {
-                                       for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) {
-                                               for (int s = 0; s < params.S; 
s++) {
-                                                       double filterVal = 
params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
-                                                       if(filterVal != 0) {
-                                                               int h = 
p*params.stride_h + r - params.pad_h;
-                                                               for (int q = 
minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-                                                                       int w = 
q*params.stride_w + s - params.pad_w;
-                                                                       
outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, 
filterVal, params, n, c, h, w);
-                                                               }
+       }
+       
+       private static void doLoopBasedConv2dDenseSparse(int n, int k, 
ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+               double [] outputArray = params.output.getDenseBlock();
+               int [] minIndexArrR = params.tmpData.minIndexArrR;
+               int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+               int [] minIndexArrS = params.tmpData.minIndexArrS;
+               int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+               final int outputOffset = n*params.K*params.P*params.Q + 
k*params.P*params.Q;
+               
+               Iterator<IJV> iter = params.input2.sparseBlock.getIterator();
+               int [] tensorIndexes = new int[4];
+               
+               while(iter.hasNext()) {
+                       IJV ijv = iter.next();
+                       computeTensorIndexes(ijv.getI(), ijv.getJ(), 
tensorIndexes, params.K, params.C, params.R, params.S);
+                       if(k == tensorIndexes[0]) {
+                               int c = tensorIndexes[1];
+                               int r = tensorIndexes[2];
+                               int s = tensorIndexes[3];
+                               double filterVal = ijv.getV();
+                               final int inputOffset = 
n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+                               for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) {
+                                       final int hOffset = inputOffset + 
(p*params.stride_h + r - params.pad_h)*params.W;
+                                       final int pOffset = outputOffset + 
p*params.Q;
+                                       for (int q = minIndexArrS[s]; q < 
maxIndexArrS[s]; q++) {
+                                               final int w = q*params.stride_w;
+                                               outputArray[pOffset + q] += 
inputArray[hOffset + w]*filterVal;
+                                       }
+                               }
+                       }
+               }
+       }
+       
+       private static void doLoopBasedConv2dSparseDense(int n, int k, 
ConvolutionParameters params, double [] filterArray) throws DMLRuntimeException 
{
+               double [] outputArray = params.output.getDenseBlock();
+               int outputOffset = n*params.K*params.P*params.Q + 
k*params.P*params.Q;
+               
+               Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+               int [] tensorIndexes = new int[4];
+               
+               int [] minIndexArrR = params.tmpData.minIndexArrR;
+               int [] minIndexArrS = params.tmpData.minIndexArrS;
+               while(iter.hasNext()) {
+                       IJV ijv = iter.next();
+                       computeTensorIndexes(ijv.getI(), ijv.getJ(), 
tensorIndexes, params.N, params.C, params.H, params.W);
+                       if(n == tensorIndexes[0]) {
+                               int c = tensorIndexes[1];
+                               int h = tensorIndexes[2];
+                               int w = tensorIndexes[3];
+                               double imgVal = ijv.getV();
+                               for (int r = minIndexArrR[h]; r < params.R; r 
+= params.stride_h) {
+                                       int filterOffset = 
k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+                                       for (int s = minIndexArrS[w]; s < 
params.S; s += params.stride_w) {
+                                               int p = 
(int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+                                               int q = 
(int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+                                               if(p >= 0 && p < params.P && q 
>= 0 && q < params.Q) {
+                                                       double filterVal = 
filterArray[filterOffset + s];
+                                                       
outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
+                                               }
+                                       }
+                               }       
+                       }
+               }
+       }
+       
+       private static void doLoopBasedConv2dSparseSparse(int n, int k, 
ConvolutionParameters params) throws DMLRuntimeException {
+               double [] outputArray = params.output.getDenseBlock();
+               int [] minIndexArrR = params.tmpData.minIndexArrR;
+               int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+               int [] minIndexArrS = params.tmpData.minIndexArrS;
+               int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+               int outputOffset = n*params.K*params.P*params.Q + 
k*params.P*params.Q;
+               
+               
+               int [] tensorIndexesImage = new int[4];
+               int [] tensorIndexesFilter = new int[4];
+
+               Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+               
+               while(iter.hasNext()) {
+                       IJV ijv = iter.next();
+                       computeTensorIndexes(ijv.getI(), ijv.getJ(), 
tensorIndexesImage, params.N, params.C, params.H, params.W);
+                       if(n == tensorIndexesImage[0]) {
+                               int c = tensorIndexesImage[1];
+                               int h = tensorIndexesImage[2];
+                               int w = tensorIndexesImage[3];
+                               double imgVal = ijv.getV();
+               
+                               Iterator<IJV> iter1 = 
params.input2.sparseBlock.getIterator();
+                               while(iter1.hasNext()) {
+                                       IJV ijv1 = iter1.next();
+                                       computeTensorIndexes(ijv1.getI(), 
ijv1.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+                                       if(k == tensorIndexesFilter[0] && c == 
tensorIndexesFilter[1]) {
+                                               int r =  tensorIndexesFilter[2];
+                                               int s =  tensorIndexesFilter[3];
+                                               
if((r-minIndexArrR[h])%params.stride_h == 0 && 
(s-minIndexArrS[w])%params.stride_w == 0) {
+                                                       int p = 
(int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+                                                       int q = 
(int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+                                                       if(p >= 0 && p < 
params.P && q >= 0 && q < params.Q) {
+                                                               double 
filterVal =  ijv1.getV();
+                                                               
outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
                                                        }
                                                }
                                        }
                                }
                        }
                }
+               
+               while(iter.hasNext()) {
+                       IJV ijv = iter.next();
+                       computeTensorIndexes(ijv.getI(), ijv.getJ(), 
tensorIndexesFilter, params.K, params.C, params.R, params.S);
+                       if(k == tensorIndexesFilter[0]) {
+                               int c = tensorIndexesFilter[1];
+                               int r = tensorIndexesFilter[2];
+                               int s = tensorIndexesFilter[3];
+                               double filterVal = ijv.getV();
+                               for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) {
+                                       int h = p*params.stride_h + r - 
params.pad_h;
+                                       for (int q = minIndexArrS[s]; q < 
maxIndexArrS[s]; q++) {
+                                               int w = q*params.stride_w + s - 
params.pad_w;
+                                               // TODO: Improve the 
performance of sparse sparse 
+                                               outputArray[outputOffset + 
p*params.Q + q] += sparseConvMultiply(filterVal, params, n, c, h, w);
+                                       }
+                               }
+                       }
+               }
+       }
+       
+       /**
+        * This is essentially a memory-less operation and can be used when the memory pressure is extremely high.
+        * @param n
+        * @param k
+        * @param params
+        * @throws DMLRuntimeException 
+        */
+       private static void doLoopBasedConv2d(int n1, int n2, int k, ConvolutionParameters params) throws DMLRuntimeException {
+               double [] inputArray = null;
+               if (!params.input1.isInSparseFormat())
+                       inputArray = params.input1.getDenseBlock();
+               double [] filterArray = null;
+               if (!params.input2.isInSparseFormat())
+                       filterArray = params.input2.getDenseBlock();
+               
+               if(inputArray != null && filterArray != null) {
+                       doLoopBasedConv2dDenseDense(n1, n2, k, params, inputArray, filterArray);
+               }
+               else if(inputArray != null && filterArray == null) {
+                       for (int n = n1; n < n2; n++) 
+                               doLoopBasedConv2dDenseSparse(n, k, params, inputArray);
+               }
+               else if(inputArray == null && filterArray != null) {
+                       for (int n = n1; n < n2; n++)
+                               doLoopBasedConv2dSparseDense(n, k, params, filterArray);
+               }
+               else if(inputArray == null && filterArray == null) {
+                       for (int n = n1; n < n2; n++)
+                               doLoopBasedConv2dSparseSparse(n, k, params);
+               }
        }
        
        private static int getMinPQ(int pad, int filterSize, int stride) {
@@ -451,12 +706,7 @@ public class LibMatrixDNN {
                return Math.min(outputSize, (int)Math.ceil(((double)(inputSize + pad - filterSize)) / stride));
        }
        
-       private static double denseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
-                       int n, int c, int h, int w) {
-               return inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w]*filterVal;
-       }
-       
-       private static double sparseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
+       private static double sparseConvMultiply(double filterVal, ConvolutionParameters params,
                        int n, int c, int h, int w) {
                return params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w)*filterVal;
        }
@@ -635,27 +885,41 @@ public class LibMatrixDNN {
                outputBlock.setNonZeros(input.getNonZeros()); // As number of non-zeros doesnot change for reshape_col
        }
        
-       private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
-               // Total number of compute units available: constrainedNumThreads
-               // Static task allocation. TODO: Do this in dynamic way
-               int taskSize = TASK_SIZE;
-               while(true) {
-                       if(params.N * Math.ceil(Z/taskSize) > constrainedNumThreads || taskSize == 1) {
-                               doRunParallelConvTask(constrainedNumThreads, Z, type, params, taskSize);
-                               return;
+       private static int [] getTaskSize(int constrainedNumThreads, int maxNumTaskSize1, int maxNumTaskSize2) {
+               int taskSize1 = 1; int taskSize2 = 1;
+               // Why this heuristic? To reduce the impact of the thread-creation overhead in case of small tasks
+               int approxNumTasksToCreate = 3*constrainedNumThreads;
+               while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+                       // Possibility of creating too many tasks, increase taskSize2
+                       taskSize2 *= 2;
+                       if(taskSize2 >= maxNumTaskSize2) {
+                               taskSize2 = maxNumTaskSize2;
+                               break;
                        }
-                       taskSize = Math.max(taskSize/2, 1);
                }
+               while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+                       // Possibility of creating too many tasks, increase taskSize1
+                       taskSize1 *= 2;
+                       if(taskSize1 >= maxNumTaskSize1) {
+                               taskSize1 = maxNumTaskSize1;
+                               break;
+                       }
+               }
+               int [] ret = new int[2];
+               ret[0] = taskSize1;
+               ret[1] = taskSize2;
+               return ret;
        }
        
-       private static void doRunParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params, int taskSize) throws DMLRuntimeException {
-               ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
-               
-               for (int n = 0; n < params.N; n++) {
-                       for (int z = 0; z < Z; z += taskSize) {
-                               tasks.add(new ConvTask(n, n+1, z, Math.min(Z, z+taskSize), type, params));
+       private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
+               ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
+               int [] taskSizes = getTaskSize(constrainedNumThreads, params.N, Z);
+               for (int n = 0; n < params.N; n += taskSizes[0]) {
+                       for (int z = 0; z < Z; z += taskSizes[1]) {
+                               tasks.add(new ConvTask(n, Math.min(params.N, n+taskSizes[0]), z, Math.min(Z, z+taskSizes[1]), type, params));
                        }
                }
+               LOG.debug("Reduce number of tasks from " + (params.N*Z)  + "(" + params.N + "," + Z + ") to " + tasks.size());

                ExecutorService pool = Executors.newFixedThreadPool( Math.min(constrainedNumThreads, tasks.size()) );
                List<Future<Object>> taskret;
@@ -727,10 +991,8 @@ public class LibMatrixDNN {
                                        }
                                        break;
                                case LoopBasedConv2d:
-                                       for (int n = n1; n < n2; n++) {
-                                               for (int z = z1; z < z2; z++) {
-                                                       LibMatrixDNN.doLoopBasedConv2d(n, z, params);
-                                               }
+                                       for (int z = z1; z < z2; z++) {
+                                               LibMatrixDNN.doLoopBasedConv2d(n1, n2, z, params);
                                        }
                                        break;
                                default:

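
As a closing note, the new getTaskSize heuristic above can be read as: double the per-task block sizes until roughly 3x the constrained thread count of tasks remain, which amortizes task-creation overhead for small inputs. A self-contained, illustrative restatement (the wrapper class and main() are hypothetical, not part of the commit):

    public class TaskSizeSketch {
        // dim1/dim2 correspond to params.N and Z in runParallelConvTask.
        static int[] taskSizes(int numThreads, int dim1, int dim2) {
            int t1 = 1, t2 = 1;
            int target = 3 * numThreads; // aim for a few tasks per thread
            while ((long) dim1 * dim2 / ((long) t1 * t2) > target) {
                t2 *= 2;                 // grow the second block size first
                if (t2 >= dim2) { t2 = dim2; break; }
            }
            while ((long) dim1 * dim2 / ((long) t1 * t2) > target) {
                t1 *= 2;                 // then grow the first block size
                if (t1 >= dim1) { t1 = dim1; break; }
            }
            return new int[] {t1, t2};
        }

        public static void main(String[] args) {
            int[] ts = taskSizes(8, 64, 32); // e.g. 64 images, K=32, 8 threads
            System.out.println(ts[0] + " x " + ts[1]); // 4 x 32 -> 16 tasks
        }
    }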