Repository: incubator-systemml
Updated Branches:
  refs/heads/master 20e05458b -> 2ebf885a6
[SYSTEMML-769] Improved performance of LibMatrixDNN's conv2d and conv2d_backward_filter

- Fixed a bug in the iteration over the sparse inputs of conv2d_backward_filter
- Added a vectorized conv2d


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2ebf885a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2ebf885a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2ebf885a

Branch: refs/heads/master
Commit: 2ebf885a6919e1cb0598e2aab4d0ffb46b8e0ab5
Parents: 20e0545
Author: Niketan Pansare <[email protected]>
Authored: Fri Jul 8 20:26:16 2016 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jul 8 20:28:25 2016 -0700

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java | 558 ++++++++++++++-----
 1 file changed, 410 insertions(+), 148 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2ebf885a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index d9faf7e..26e2b8b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -20,6 +20,7 @@ package org.apache.sysml.runtime.matrix.data;
 
 import java.lang.ref.SoftReference;
 import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
@@ -29,12 +30,16 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.atomic.AtomicLong;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.util.ConvolutionUtils;
 
 public class LibMatrixDNN {
+	
+	protected static final Log LOG = LogFactory.getLog(LibMatrixDNN.class.getName());
 	
 	public static final boolean ALLOW_MULTI_THREADED_OPS = true;
 	// Using hashmap to avoid any performance impacts of multimap
@@ -62,13 +67,14 @@ public class LibMatrixDNN {
 	enum TaskType {
 		ReshapeCol, Rotate180, Im2Col, Col2Im, MaxPooling_Forward, MaxPooling_Backward, LoopBasedConv2d
 	}
-	public static final int TASK_SIZE = 64; // to take care of extremely small tasks
 	
 	public static class TemporaryConvolutionData {
 		public int [] minIndexArrR;
 		public int [] minIndexArrS;
 		public int [] maxIndexArrR;
 		public int [] maxIndexArrS;
+		int minCommonIndexS;
+		int maxCommonIndexS;
 	}
 	
 	public static class ConvolutionParameters {
@@ -159,6 +165,9 @@ public class LibMatrixDNN {
 				dout.getNumRows() != params.N || dout.getNumColumns() != params.K*params.P*params.Q) {
 			throw new DMLRuntimeException("Incorrect input to conv2d_backward_filter");
 		}
+		if(params.stride_h <= 0 || params.stride_w <= 0) {
+			throw new DMLRuntimeException("Only positive strides supported");
+		}
 		
 		int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
 		if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
@@ -198,7 +207,7 @@ public class LibMatrixDNN {
 		}
 	}
 	
-	public static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) {
+	private static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
 		double [] inputArray = null;
 		if (!params.input1.isInSparseFormat())
 			inputArray = params.input1.getDenseBlock();
@@ -207,62 +216,125 @@
 		double [] doutArray = null;
 		if (!params.input2.isInSparseFormat())
 			doutArray = params.input2.getDenseBlock();
 		double [] outputArray = params.output.getDenseBlock();
 		
-		long outputVal = 0;
-		if(doutArray != null) {
-			for (int n = 0; n < params.N; n++) {
-				for (int p = 0; p < params.P; p++) {
-					for (int q = 0; q < params.Q; q++) {
-						int h = p*params.stride_h + r - params.pad_h;
-						int w = q*params.stride_w + s - params.pad_w;
-						if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
-							double doutVal = doutArray[n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q + q];
-							if(doutVal != 0) {
-								if(inputArray != null)
-									outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
-								else
-									outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
-							}
-						}
-					}
-				}
-			}
+		double outputVal = 0;
+		if(inputArray == null && doutArray == null) {
+			outputVal = doConv2d_Backward_Filter_SparseSparse(k, c, r, s, params);
+		}
+		else if(inputArray != null && doutArray == null) {
+			outputVal = doConv2d_Backward_Filter_DenseSparse(k, c, r, s, params, inputArray);
+		}
+		else if(inputArray == null && doutArray != null) {
+			outputVal = doConv2d_Backward_Filter_SparseDense(k, c, r, s, params, doutArray);
 		}
 		else {
-			MatrixBlock dout = params.input2;
-			if( !dout.isEmptyBlock(false) ) {
-				int start=0;
-				int rlen = dout.getNumRows();
-				int clen = dout.getNumColumns();
-				for(int r1=0; r1<Math.min(dout.sparseBlock.numRows(), rlen); r1++, start+=clen)
-				{
-					if(dout.sparseBlock.isEmpty(r1))
-						continue;
-					int pos = dout.sparseBlock.pos(r1);
-					int len = dout.sparseBlock.size(r1);
-					int[] aix = dout.sparseBlock.indexes(r1);
-					double[] avals = dout.sparseBlock.values(r1);
-					
-					for(int i=pos; i<pos+len; i++) {
-						int index = start+aix[i];
-						double doutVal = avals[i];
-						int n = index / clen;
-						int p = index / params.Q;
-						int q = index % params.Q;
-						int h = p*params.stride_h + r - params.pad_h;
-						int w = q*params.stride_w + s - params.pad_w;
-						if(h >= 0 && h < params.H && w >= 0 && w < params.W && doutVal != 0) {
-							if(inputArray != null)
-								outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
-							else
-								outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
-						}
-					}
-				}
-			}
+			outputVal = doConv2d_Backward_Filter_DenseDense(k, c, r, s, params, inputArray, doutArray);
 		}
+		
+		outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+	}
+	
+	private static double doConv2d_Backward_Filter_SparseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] doutArray) throws DMLRuntimeException {
+		double outputVal = 0;
+		// pMin/qMin ensure h >= 0 and w >= 0
+		int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+		int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+		// pMax/qMax ensure h < params.H and w < params.W
+		int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+		int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
+		
+		// TODO: Optimize this case
+		for (int n = 0; n < params.N; n++) {
+			int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+			for (int p = pMin; p < pMax; p++) {
+				for (int q = qMin; q < qMax; q++) {
+					int h = p*params.stride_h + r - params.pad_h;
+					int w = q*params.stride_w + s - params.pad_w;
+					outputVal += doutArray[doutOffset + p*params.Q + q]*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
+				}
+			}
+		}
+		return outputVal;
+	}
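----------------------------------------------------------------------
Illustration (not part of the commit): the pMin/pMax and qMin/qMax bounds
above hoist the range checks on h and w out of the p/q loops. Since
h = p*stride_h + r - pad_h, requiring 0 <= h < H is equivalent to
ceil((pad_h - r)/stride_h) <= p < ceil((H + pad_h - r)/stride_h), clamped
to [0, P); the q bounds follow identically from w. A standalone sketch
(hypothetical class name) that checks the closed form against the original
per-element test:

    public class BoundCheck {
        public static void main(String[] args) {
            int H = 5, P = 3, stride_h = 2, pad_h = 1, r = 0;
            int pMin = (int) Math.max(0, Math.ceil(((double)(pad_h - r)) / stride_h));
            int pMax = (int) Math.min(P, Math.ceil(((double)(H + pad_h - r)) / stride_h));
            for (int p = 0; p < P; p++) {
                int h = p * stride_h + r - pad_h;
                boolean inRange = h >= 0 && h < H;         // original per-element check
                boolean inBounds = p >= pMin && p < pMax;  // closed-form bounds
                System.out.println("p=" + p + " h=" + h + " agree=" + (inRange == inBounds));
            }
        }
    }
----------------------------------------------------------------------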
+	
+	private static double doConv2d_Backward_Filter_DenseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray, double [] doutArray) {
+		double outputVal = 0;
+		// pMin/qMin ensure h >= 0 and w >= 0
+		int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h));
+		int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w));
+		// pMax/qMax ensure h < params.H and w < params.W
+		int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h));
+		int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w));
 		
-		outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal;
+		for (int n = 0; n < params.N; n++) {
+			int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+			int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+			for (int p = pMin; p < pMax; p++) {
+				int h = p*params.stride_h + r - params.pad_h;
+				for (int q = qMin; q < qMax; q++) {
+					int w = q*params.stride_w;
+					outputVal += doutArray[doutOffset + p*params.Q + q]*inputArray[inputOffset + h*params.W+w];
+				}
+			}
+		}
+		
+		return outputVal;
+	}
+	
+	private static void computeTensorIndexes(int i, int j, int [] ret, int N, int C, int H, int W) throws DMLRuntimeException {
+		ret[0] = i;
+		ret[1] = j / (H*W);
+		ret[2] = (j - ret[1]*(H*W))/W;
+		ret[3] = j % W;
+	}
+	
+	private static double doConv2d_Backward_Filter_DenseSparse(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+		MatrixBlock dout = params.input2;
+		double outputVal = 0;
+		Iterator<IJV> iter = dout.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+			if(k == tensorIndexes[1]) {
+				int n = tensorIndexes[0];
+				int p = tensorIndexes[2];
+				int q = tensorIndexes[3];
+				
+				double doutVal = ijv.getV();
+				int h = p*params.stride_h + r - params.pad_h;
+				int w = q*params.stride_w + s - params.pad_w;
+				if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+					outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w];
+				}
+			}
+		}
+		return outputVal;
+	}
+	
+	private static double doConv2d_Backward_Filter_SparseSparse(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException {
+		MatrixBlock dout = params.input2;
+		double outputVal = 0;
+		Iterator<IJV> iter = dout.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q);
+			if(k == tensorIndexes[1]) {
+				int n = tensorIndexes[0];
+				int p = tensorIndexes[2];
+				int q = tensorIndexes[3];
+				
+				double doutVal = ijv.getV();
+				int h = p*params.stride_h + r - params.pad_h;
+				int w = q*params.stride_w + s - params.pad_w;
+				if(h >= 0 && h < params.H && w >= 0 && w < params.W) {
+					outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w);
+				}
+			}
+		}
+		return outputVal;
+	}
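----------------------------------------------------------------------
Illustration (not part of the commit): computeTensorIndexes above inverts the
row-major linearization used throughout this file. Cell (i, j) of the
N x (C*H*W) matrix maps to tensor indexes (n, c, h, w) with n = i and
j = c*H*W + h*W + w. A standalone sketch (hypothetical class name) with
C=2, H=3, W=4:

    public class TensorIndexDemo {
        public static void main(String[] args) {
            int C = 2, H = 3, W = 4;
            int c = 1, h = 2, w = 3;
            int j = c*H*W + h*W + w;            // 1*12 + 2*4 + 3 = 23
            int cBack = j / (H*W);              // 23 / 12 = 1
            int hBack = (j - cBack*(H*W)) / W;  // (23 - 12) / 4 = 2
            int wBack = j % W;                  // 23 % 4 = 3
            System.out.println("j=" + j + " -> (c,h,w)=(" + cBack + "," + hBack + "," + wBack + ")");
        }
    }
----------------------------------------------------------------------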
 	
 	private static class ConvBackwardFilterTask implements Callable<Object> {
@@ -294,25 +366,55 @@
 			throw new DMLRuntimeException("Incorrect input to conv2d");
 		}
 		
-		params.tmpData = new TemporaryConvolutionData();
-		params.tmpData.minIndexArrR = new int[params.R];
-		params.tmpData.maxIndexArrR = new int[params.R];
-		params.tmpData.minIndexArrS = new int[params.S];
-		params.tmpData.maxIndexArrS = new int[params.S];
-		for (int r = 0; r < params.R; r++) {
-			params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h);
-			params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+		params.tmpData = new TemporaryConvolutionData();
+		if(input.isInSparseFormat()) {
+			params.tmpData.minIndexArrR = new int[params.H];
+			params.tmpData.minIndexArrS = new int[params.W];
+			for(int h = 0; h < params.H; h++) {
+				for (int r = 0; r < params.R; r++) {
+					// int h = p*params.stride_h + r - params.pad_h;
+					if((h + params.pad_h - r) % params.stride_h == 0) {
+						params.tmpData.minIndexArrR[h] = r;
+						break;
+					}
+				}
+			}
+			for(int w = 0; w < params.W; w++) {
+				for (int s = 0; s < params.S; s++) {
+					// int w = q*params.stride_w + s - params.pad_w;
+					if((w + params.pad_w - s) % params.stride_w == 0) {
+						params.tmpData.minIndexArrS[w] = s;
+						break;
+					}
+				}
+			}
 		}
-		for (int s = 0; s < params.S; s++) {
-			params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w);
-			params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+		else {
+			params.tmpData.minIndexArrR = new int[params.R];
+			params.tmpData.maxIndexArrR = new int[params.R];
+			params.tmpData.minIndexArrS = new int[params.S];
+			params.tmpData.maxIndexArrS = new int[params.S];
+			for (int r = 0; r < params.R; r++) {
+				params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h);
+				params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H);
+			}
+			for (int s = 0; s < params.S; s++) {
+				params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w);
+				params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W);
+			}
+			params.tmpData.minCommonIndexS = params.tmpData.minIndexArrS[0];
+			params.tmpData.maxCommonIndexS = params.tmpData.maxIndexArrS[0];
+			for (int s = 1; s < params.S; s++) {
+				params.tmpData.minCommonIndexS = Math.max(params.tmpData.minCommonIndexS, params.tmpData.minIndexArrS[s]);
+				params.tmpData.maxCommonIndexS = Math.min(params.tmpData.maxCommonIndexS, params.tmpData.maxIndexArrS[s]);
+			}
 		}
 		
 		int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
 		if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) {
 			for (int n = 0; n < params.N; n++) {
 				for (int k = 0; k < params.K; k++) {
-					doLoopBasedConv2d(n, k, params);
+					doLoopBasedConv2d(n, n+1, k, params);
 				}
 			}
 		}
@@ -345,102 +447,255 @@
 		}
 	}
 	
-	/**
-	 * This is essentially memory-less operation and can be used when the memory pressure is extremely high.
-	 * @param n
-	 * @param k
-	 * @param params
-	 */
-	private static void doLoopBasedConv2d(int n, int k, ConvolutionParameters params) {
-		double [] inputArray = null;
-		if (!params.input1.isInSparseFormat())
-			inputArray = params.input1.getDenseBlock();
-		double [] filterArray = null;
-		if (!params.input2.isInSparseFormat())
-			filterArray = params.input2.getDenseBlock();
+	private static void doLoopBasedConv2dDenseDense(int n1, int n2, int k, ConvolutionParameters params,
+			double [] inputArray, double [] filterArray) {
 		double [] outputArray = params.output.getDenseBlock();
-		
-		int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
-		
 		int [] minIndexArrR = params.tmpData.minIndexArrR;
 		int [] maxIndexArrR = params.tmpData.maxIndexArrR;
 		int [] minIndexArrS = params.tmpData.minIndexArrS;
 		int [] maxIndexArrS = params.tmpData.maxIndexArrS;
-		if(inputArray != null && filterArray != null) {
-			for (int c = 0; c < params.C; c++) {
-				for (int r = 0; r < params.R; r++) {
-					int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
-					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = filterArray[filterOffset + s];
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
-							}
-						}
-					}
-				}
-			}
-		}
+		int minCommonIndexS = params.tmpData.minCommonIndexS;
+		int maxCommonIndexS = params.tmpData.maxCommonIndexS;
+		
+		
+		int minS = 0;
+		if(params.S >= 4) {
+			minS = params.S - params.S % 4;
+			for (int n = n1; n < n2; n++) {
+				for (int c = 0; c < params.C; c++) {
+					for (int r = 0; r < params.R; r++) {
+						final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+						for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+							final int h = p*params.stride_h + r - params.pad_h;
+							final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+							final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+							// ------------------------------------------------------------------------
+							// Efficient striding with vectorization
+							for (int q = minCommonIndexS; q < maxCommonIndexS; q++) {
+								final int wOffset = inputOffSet + q*params.stride_w;
+								final int outOffsetWithQ = outputOffset + q;
+								for (int s = 0; s < minS; s += 4) {
+									final int inOffsetWithS = wOffset + s;
+									final int filterOffsetWithS = filterOffset + s;
+									outputArray[outOffsetWithQ] += inputArray[inOffsetWithS]*filterArray[filterOffsetWithS]
+											+ inputArray[inOffsetWithS+1]*filterArray[filterOffsetWithS+1]
+											+ inputArray[inOffsetWithS+2]*filterArray[filterOffsetWithS+2]
+											+ inputArray[inOffsetWithS+3]*filterArray[filterOffsetWithS+3];
+								}
+							}
+							// ------------------------------------------------------------------------
+						}
+					}
+				}
+			}
+		}
-		else if(inputArray != null && filterArray == null) {
-			for (int c = 0; c < params.C; c++) {
-				for (int r = 0; r < params.R; r++) {
-					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += denseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
-							}
-						}
-					}
-				}
-			}
-		}
+		
+		for (int n = n1; n < n2; n++) {
+			for (int c = 0; c < params.C; c++) {
+				for (int r = 0; r < params.R; r++) {
+					final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+						final int h = p*params.stride_h + r - params.pad_h;
+						final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+						final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+						// ------------------------------------------------------------------------
+						// Efficient striding
+						for (int q = minCommonIndexS; q < maxCommonIndexS; q++) {
+							final int wOffset = inputOffSet + q*params.stride_w;
+							for (int s = minS; s < params.S; s++) {
+								outputArray[outputOffset + q] += inputArray[wOffset + s]*filterArray[filterOffset + s];
+							}
+						}
+						// ------------------------------------------------------------------------
+					}
+				}
+			}
-		else if(inputArray == null && filterArray != null) {
-			for (int c = 0; c < params.C; c++) {
-				for (int r = 0; r < params.R; r++) {
-					int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
-					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = filterArray[filterOffset + s];
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
-							}
-						}
-					}
-				}
-			}
-		}
+			
+			for (int c = 0; c < params.C; c++) {
+				for (int r = 0; r < params.R; r++) {
+					final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+						final int h = p*params.stride_h + r - params.pad_h;
+						final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w;
+						final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q;
+						// ------------------------------------------------------------------------
+						// Inefficient striding
+						for (int s = 0; s < params.S; s++) {
+							for (int q = minIndexArrS[s]; q < minCommonIndexS; q++) {
+								final int w = q*params.stride_w + s;
+								outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s];
+							}
+							for (int q = maxCommonIndexS; q < maxIndexArrS[s]; q++) {
+								final int w = q*params.stride_w + s;
+								outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s];
+							}
+						}
+						// ------------------------------------------------------------------------
+					}
+				}
+			}
+		}
-		else if(inputArray == null && filterArray == null) {
-			for (int c = 0; c < params.C; c++) {
-				for (int r = 0; r < params.R; r++) {
-					for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
-						for (int s = 0; s < params.S; s++) {
-							double filterVal = params.input2.quickGetValue(k, c*params.R*params.S + r*params.S + s);
-							if(filterVal != 0) {
-								int h = p*params.stride_h + r - params.pad_h;
-								for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
-									int w = q*params.stride_w + s - params.pad_w;
-									outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(inputArray, filterVal, params, n, c, h, w);
-								}
-							}
-						}
-					}
-				}
-			}
-		}
-	}
+	}
+	
+	private static void doLoopBasedConv2dDenseSparse(int n, int k, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException {
+		double [] outputArray = params.output.getDenseBlock();
+		int [] minIndexArrR = params.tmpData.minIndexArrR;
+		int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+		int [] minIndexArrS = params.tmpData.minIndexArrS;
+		int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+		final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+		
+		Iterator<IJV> iter = params.input2.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.K, params.C, params.R, params.S);
+			if(k == tensorIndexes[0]) {
+				int c = tensorIndexes[1];
+				int r = tensorIndexes[2];
+				int s = tensorIndexes[3];
+				double filterVal = ijv.getV();
+				final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w;
+				for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+					final int hOffset = inputOffset + (p*params.stride_h + r - params.pad_h)*params.W;
+					final int pOffset = outputOffset + p*params.Q;
+					for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
+						final int w = q*params.stride_w;
+						outputArray[pOffset + q] += inputArray[hOffset + w]*filterVal;
+					}
+				}
+			}
+		}
+	}
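----------------------------------------------------------------------
Illustration (not part of the commit): the "efficient striding with
vectorization" block in doLoopBasedConv2dDenseDense above manually unrolls
the dot product over s by a factor of 4 (minS = S - S % 4), leaving a scalar
remainder loop; the unrolled form gives the JIT a better chance to emit SIMD
code. The same pattern in isolation (hypothetical class name):

    public class DotUnroll4 {
        static double dot(double[] a, double[] b, int len) {
            int mainLen = (len >= 4) ? len - len % 4 : 0;
            double sum = 0;
            for (int s = 0; s < mainLen; s += 4) {  // unrolled part
                sum += a[s]*b[s] + a[s+1]*b[s+1] + a[s+2]*b[s+2] + a[s+3]*b[s+3];
            }
            for (int s = mainLen; s < len; s++) {   // remainder part
                sum += a[s]*b[s];
            }
            return sum;
        }
        public static void main(String[] args) {
            double[] a = {1, 2, 3, 4, 5, 6, 7};
            double[] b = {1, 1, 1, 1, 1, 1, 1};
            System.out.println(dot(a, b, a.length)); // prints 28.0
        }
    }
----------------------------------------------------------------------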
+	
+	private static void doLoopBasedConv2dSparseDense(int n, int k, ConvolutionParameters params, double [] filterArray) throws DMLRuntimeException {
+		double [] outputArray = params.output.getDenseBlock();
+		int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+		
+		Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+		int [] tensorIndexes = new int[4];
+		
+		int [] minIndexArrR = params.tmpData.minIndexArrR;
+		int [] minIndexArrS = params.tmpData.minIndexArrS;
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.C, params.H, params.W);
+			if(n == tensorIndexes[0]) {
+				int c = tensorIndexes[1];
+				int h = tensorIndexes[2];
+				int w = tensorIndexes[3];
+				double imgVal = ijv.getV();
+				for (int r = minIndexArrR[h]; r < params.R; r += params.stride_h) {
+					int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S;
+					for (int s = minIndexArrS[w]; s < params.S; s += params.stride_w) {
+						int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+						int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+						if(p >= 0 && p < params.P && q >= 0 && q < params.Q) {
+							double filterVal = filterArray[filterOffset + s];
+							outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
+						}
+					}
+				}
+			}
+		}
+	}
+	
+	private static void doLoopBasedConv2dSparseSparse(int n, int k, ConvolutionParameters params) throws DMLRuntimeException {
+		double [] outputArray = params.output.getDenseBlock();
+		int [] minIndexArrR = params.tmpData.minIndexArrR;
+		int [] maxIndexArrR = params.tmpData.maxIndexArrR;
+		int [] minIndexArrS = params.tmpData.minIndexArrS;
+		int [] maxIndexArrS = params.tmpData.maxIndexArrS;
+		int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q;
+		
+		
+		int [] tensorIndexesImage = new int[4];
+		int [] tensorIndexesFilter = new int[4];
+		
+		Iterator<IJV> iter = params.input1.sparseBlock.getIterator();
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesImage, params.N, params.C, params.H, params.W);
+			if(n == tensorIndexesImage[0]) {
+				int c = tensorIndexesImage[1];
+				int h = tensorIndexesImage[2];
+				int w = tensorIndexesImage[3];
+				double imgVal = ijv.getV();
+				
+				Iterator<IJV> iter1 = params.input2.sparseBlock.getIterator();
+				while(iter1.hasNext()) {
+					IJV ijv1 = iter1.next();
+					computeTensorIndexes(ijv1.getI(), ijv1.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+					if(k == tensorIndexesFilter[0] && c == tensorIndexesFilter[1]) {
+						int r = tensorIndexesFilter[2];
+						int s = tensorIndexesFilter[3];
+						if((r-minIndexArrR[h])%params.stride_h == 0 && (s-minIndexArrS[w])%params.stride_w == 0) {
+							int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h);
+							int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w);
+							if(p >= 0 && p < params.P && q >= 0 && q < params.Q) {
+								double filterVal = ijv1.getV();
+								outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal;
+							}
+						}
+					}
+				}
+			}
+		}
+		
+		while(iter.hasNext()) {
+			IJV ijv = iter.next();
+			computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S);
+			if(k == tensorIndexesFilter[0]) {
+				int c = tensorIndexesFilter[1];
+				int r = tensorIndexesFilter[2];
+				int s = tensorIndexesFilter[3];
+				double filterVal = ijv.getV();
+				for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) {
+					int h = p*params.stride_h + r - params.pad_h;
+					for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) {
+						int w = q*params.stride_w + s - params.pad_w;
+						// TODO: Improve the performance of sparse sparse
+						outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(filterVal, params, n, c, h, w);
+					}
+				}
+			}
+		}
+	}
+	
+	/**
+	 * This is an essentially memory-less operation and can be used when memory pressure is extremely high.
+	 * @param n1
+	 * @param n2
+	 * @param k
+	 * @param params
+	 * @throws DMLRuntimeException
+	 */
+	private static void doLoopBasedConv2d(int n1, int n2, int k, ConvolutionParameters params) throws DMLRuntimeException {
+		double [] inputArray = null;
+		if (!params.input1.isInSparseFormat())
+			inputArray = params.input1.getDenseBlock();
+		double [] filterArray = null;
+		if (!params.input2.isInSparseFormat())
+			filterArray = params.input2.getDenseBlock();
+		
+		if(inputArray != null && filterArray != null) {
+			doLoopBasedConv2dDenseDense(n1, n2, k, params, inputArray, filterArray);
+		}
+		else if(inputArray != null && filterArray == null) {
+			for (int n = n1; n < n2; n++)
+				doLoopBasedConv2dDenseSparse(n, k, params, inputArray);
+		}
+		else if(inputArray == null && filterArray != null) {
+			for (int n = n1; n < n2; n++)
+				doLoopBasedConv2dSparseDense(n, k, params, filterArray);
+		}
+		else if(inputArray == null && filterArray == null) {
+			for (int n = n1; n < n2; n++)
+				doLoopBasedConv2dSparseSparse(n, k, params);
+		}
+	}
 	
 	private static int getMinPQ(int pad, int filterSize, int stride) {
@@ -451,12 +706,7 @@
 		return Math.min(outputSize, (int)Math.ceil(((double)(inputSize + pad - filterSize)) / stride));
 	}
 	
-	private static double denseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
-			int n, int c, int h, int w) {
-		return inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w]*filterVal;
-	}
-	
-	private static double sparseConvMultiply(double [] inputArray, double filterVal, ConvolutionParameters params,
+	private static double sparseConvMultiply(double filterVal, ConvolutionParameters params,
 			int n, int c, int h, int w) {
 		return params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w)*filterVal;
 	}
@@ -635,27 +885,41 @@
 		outputBlock.setNonZeros(input.getNonZeros()); // As number of non-zeros does not change for reshape_col
 	}
 	
-	private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
-		// Total number of compute units available: constrainedNumThreads
-		// Static task allocation. TODO: Do this in dynamic way
-		int taskSize = TASK_SIZE;
-		while(true) {
-			if(params.N * Math.ceil(Z/taskSize) > constrainedNumThreads || taskSize == 1) {
-				doRunParallelConvTask(constrainedNumThreads, Z, type, params, taskSize);
-				return;
+	private static int [] getTaskSize(int constrainedNumThreads, int maxNumTaskSize1, int maxNumTaskSize2) {
+		int taskSize1 = 1; int taskSize2 = 1;
+		// Why this heuristic? It reduces the impact of the thread-creation overhead in case of small tasks
+		int approxNumTasksToCreate = 3*constrainedNumThreads;
+		while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+			// Possibility of creating too many tasks, increase taskSize2
+			taskSize2 *= 2;
+			if(taskSize2 >= maxNumTaskSize2) {
+				taskSize2 = maxNumTaskSize2;
+				break;
 			}
-			taskSize = Math.max(taskSize/2, 1);
 		}
+		while((maxNumTaskSize1*maxNumTaskSize2)/(taskSize1*taskSize2) > approxNumTasksToCreate) {
+			// Possibility of creating too many tasks, increase taskSize1
+			taskSize1 *= 2;
+			if(taskSize1 >= maxNumTaskSize1) {
+				taskSize1 = maxNumTaskSize1;
+				break;
+			}
+		}
+		int [] ret = new int[2];
+		ret[0] = taskSize1;
+		ret[1] = taskSize2;
+		return ret;
 	}
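----------------------------------------------------------------------
Illustration (not part of the commit): getTaskSize above doubles the
per-dimension task sizes until roughly 3*constrainedNumThreads tasks remain,
trading finer load balancing against task-creation overhead. With 16 threads
and N = Z = 64 (hypothetical values), the standalone trace below ends with
taskSize1 = 2 and taskSize2 = 64, i.e. 32 tasks instead of 4096:

    public class TaskSizeDemo {
        public static void main(String[] args) {
            int constrainedNumThreads = 16, max1 = 64, max2 = 64;
            int t1 = 1, t2 = 1;
            int approx = 3 * constrainedNumThreads;    // target upper bound on #tasks
            while ((max1 * max2) / (t1 * t2) > approx) {
                t2 *= 2;                               // coarsen second dimension first
                if (t2 >= max2) { t2 = max2; break; }
            }
            while ((max1 * max2) / (t1 * t2) > approx) {
                t1 *= 2;                               // then coarsen first dimension
                if (t1 >= max1) { t1 = max1; break; }
            }
            System.out.println("taskSize1=" + t1 + ", taskSize2=" + t2
                    + ", tasks=" + ((max1 / t1) * (max2 / t2)));
        }
    }
----------------------------------------------------------------------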
 	
-	private static void doRunParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params, int taskSize) throws DMLRuntimeException {
-		ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
-		
-		for (int n = 0; n < params.N; n++) {
-			for (int z = 0; z < Z; z += taskSize) {
-				tasks.add(new ConvTask(n, n+1, z, Math.min(Z, z+taskSize), type, params));
+	private static void runParallelConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
+		ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
+		int [] taskSizes = getTaskSize(constrainedNumThreads, params.N, Z);
+		for (int n = 0; n < params.N; n += taskSizes[0]) {
+			for (int z = 0; z < Z; z += taskSizes[1]) {
+				tasks.add(new ConvTask(n, Math.min(params.N, n+taskSizes[0]), z, Math.min(Z, z+taskSizes[1]), type, params));
 			}
 		}
+		LOG.debug("Reduced the number of tasks from " + (params.N*Z) + " (" + params.N + "," + Z + ") to " + tasks.size());
 		
 		ExecutorService pool = Executors.newFixedThreadPool( Math.min(constrainedNumThreads, tasks.size()) );
 		List<Future<Object>> taskret;
@@ -727,10 +991,8 @@
 				}
 				break;
 			case LoopBasedConv2d:
-				for (int n = n1; n < n2; n++) {
-					for (int z = z1; z < z2; z++) {
-						LibMatrixDNN.doLoopBasedConv2d(n, z, params);
-					}
+				for (int z = z1; z < z2; z++) {
+					LibMatrixDNN.doLoopBasedConv2d(n1, n2, z, params);
 				}
 				break;
 			default:
