Repository: incubator-systemml Updated Branches: refs/heads/master b03950c32 -> d79dea926
[SYSTEMML-540] Cleaned up memory-less operators, fixed bufferpool bug and added direct conv2d_backward_data operator Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d79dea92 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d79dea92 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d79dea92 Branch: refs/heads/master Commit: d79dea9266bd1177a8643f2f4525ca918b0e8b97 Parents: b03950c Author: Niketan Pansare <[email protected]> Authored: Fri Aug 12 17:22:23 2016 -0700 Committer: Niketan Pansare <[email protected]> Committed: Fri Aug 12 17:24:19 2016 -0700 ---------------------------------------------------------------------- .../controlprogram/caching/CacheableData.java | 2 +- .../instructions/CPInstructionParser.java | 1 + .../cp/ConvolutionCPInstruction.java | 18 +- .../sysml/runtime/matrix/data/LibMatrixDNN.java | 684 +++++-------------- .../sysml/runtime/util/ConvolutionUtils.java | 73 +- 5 files changed, 233 insertions(+), 545 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d79dea92/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java index d043879..bf22fa3 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java @@ -790,7 +790,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data LOG.trace("Exporting " + this.getDebugName() + " to " + fName + " in format " + outputFormat); //TODO remove - if( getGPUObject() != null ) { + 
if( getGPUObject() != null && getGPUObject().isAllocated() ) { getGPUObject().acquireHostRead(); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d79dea92/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index b6e0c50..ae13d3d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -226,6 +226,7 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "maxpooling_backward" , CPINSTRUCTION_TYPE.Convolution); String2CPInstructionType.put( "conv2d" , CPINSTRUCTION_TYPE.Convolution); String2CPInstructionType.put( "conv2d_backward_filter" , CPINSTRUCTION_TYPE.Convolution); + String2CPInstructionType.put( "conv2d_backward_data" , CPINSTRUCTION_TYPE.Convolution); // Quaternary instruction opcodes String2CPInstructionType.put( "wsloss" , CPINSTRUCTION_TYPE.Quaternary); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d79dea92/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java index c0a1af5..8324ff2 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java @@ -118,7 +118,8 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction { else if (opcode.equalsIgnoreCase("pooling_backward_reshape") || 
opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("conv2d") - || opcode.equalsIgnoreCase("conv2d_backward_filter")) { + || opcode.equalsIgnoreCase("conv2d_backward_filter") + || opcode.equalsIgnoreCase("conv2d_backward_data")) { InstructionUtils.checkNumFields(parts, 16); // dout, stride1, stride2, padding1, padding2 // input_shape1, input_shape2, input_shape3, input_shape4, @@ -236,16 +237,21 @@ public class ConvolutionCPInstruction extends UnaryCPInstruction { MatrixBlock filter = ec.getMatrixInput(_in2.getName()); outputBlock = getDenseOutputBlock(ec, N, K*P*Q, false); params.setReuseNonZeroedOutput(_reuseNonZeroedOutput); - boolean useMemoryLessConvolution = false; - LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params, useMemoryLessConvolution); + LibMatrixDNN.conv2d(matBlock, filter, outputBlock, params); ec.releaseMatrixInput(_in2.getName()); } else if (instOpcode.equalsIgnoreCase("conv2d_backward_filter")) { - MatrixBlock filter = ec.getMatrixInput(_in2.getName()); + MatrixBlock dout = ec.getMatrixInput(_in2.getName()); outputBlock = getDenseOutputBlock(ec, K, C*R*S, false); params.setReuseNonZeroedOutput(_reuseNonZeroedOutput); - boolean useMemoryLessConvolution = false; - LibMatrixDNN.conv2d_backward_filter(matBlock, filter, outputBlock, params, useMemoryLessConvolution); + LibMatrixDNN.conv2d_backward_filter(matBlock, dout, outputBlock, params); + ec.releaseMatrixInput(_in2.getName()); + } + else if (instOpcode.equalsIgnoreCase("conv2d_backward_data")) { + MatrixBlock dout = ec.getMatrixInput(_in2.getName()); + outputBlock = getDenseOutputBlock(ec, N, C * H * W, false); + params.setReuseNonZeroedOutput(_reuseNonZeroedOutput); + LibMatrixDNN.conv2d_backward_data(matBlock, dout, outputBlock, params); ec.releaseMatrixInput(_in2.getName()); } else { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d79dea92/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index 2374931..e657d18 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -75,7 +75,7 @@ public class LibMatrixDNN { enum TaskType { ReshapeCol, Rotate180, Im2Col, Col2Im, MaxPooling_Forward, MaxPooling_Backward, - LoopBasedConv2d, LoopedIm2ColConv2d, LoopBasedConv2dBwdFilter, LoopedIm2ColConv2dBwdFilter + LoopedIm2ColConv2d, LoopedIm2ColConv2dBwdFilter, LoopedIm2ColConv2dBwdData } public static class TemporaryConvolutionData { @@ -99,8 +99,10 @@ public class LibMatrixDNN { private static AtomicLong maxPoolBwdDenseCount = new AtomicLong(0); private static AtomicLong loopedConvMatMultTime = new AtomicLong(0); private static AtomicLong loopedConvIm2ColTime = new AtomicLong(0); - private static AtomicLong loopedConvBwdMatMultTime = new AtomicLong(0); - private static AtomicLong loopedConvBwdIm2ColTime = new AtomicLong(0); + private static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0); + private static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0); + private static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0); + private static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0); public static void appendStatistics(StringBuilder sb) { if(DMLScript.STATISTICS && (conv2dDenseCount.get() != 0 || conv2dSparseCount.get() != 0)) { @@ -117,11 +119,13 @@ public class LibMatrixDNN { + im2colSparseCount.get() + "/" + maxPoolBwdSparseCount.get() + ".\n"); if(loopedConvMatMultTime.get() != 0 || loopedConvIm2ColTime.get() != 0) { - sb.append("LibMatrixDNN conv(im2col/matmult), bwdFil (im2col/matmult) time:\t" + + sb.append("LibMatrixDNN conv(im2col/matmult), bwdF (im2col/matmult), bwdD (col2im/matmult) 
time:\t" + String.format("%.3f", loopedConvIm2ColTime.get()*1e-9) + "/" + String.format("%.3f", loopedConvMatMultTime.get()*1e-9) + "/" + - String.format("%.3f", loopedConvBwdIm2ColTime.get()*1e-9) + "/" + - String.format("%.3f", loopedConvBwdMatMultTime.get()*1e-9) + " sec.\n"); + String.format("%.3f", loopedConvBwdFilterIm2ColTime.get()*1e-9) + "/" + + String.format("%.3f", loopedConvBwdFilterMatMultTime.get()*1e-9) + "/" + + String.format("%.3f", loopedConvBwdDataCol2ImTime.get()*1e-9) + "/" + + String.format("%.3f", loopedConvBwdDataMatMultTime.get()*1e-9) + " sec.\n"); } } } @@ -140,8 +144,10 @@ public class LibMatrixDNN { loopedConvIm2ColTime.set(0); loopedConvMatMultTime.set(0); - loopedConvBwdMatMultTime.set(0); - loopedConvBwdIm2ColTime.set(0); + loopedConvBwdFilterMatMultTime.set(0); + loopedConvBwdFilterIm2ColTime.set(0); + loopedConvBwdDataMatMultTime.set(0); + loopedConvBwdDataCol2ImTime.set(0); } public static class ConvolutionParameters { @@ -173,6 +179,10 @@ public class LibMatrixDNN { return false; } + public String toString() { + return "(" + N + " " + C + " " + H + " " + W + " " + K + " " + R + " " + S + ")"; + } + public ConvolutionParameters(long N, long C, long H, long W, long K, long R, long S, long stride_h, long stride_w, long pad_h, long pad_w, int numThreads) throws DMLRuntimeException { this.N = convertToInt(N); @@ -228,7 +238,44 @@ public class LibMatrixDNN { } } - public static void conv2d_backward_filter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params, boolean useMemoryLessConvolution) throws DMLRuntimeException { + public static void conv2d_backward_data(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException { + params.input1 = filter; + params.input2 = dout; + params.output = outputBlock; + if(filter.getNumRows() != params.K || filter.getNumColumns() != params.C*params.R*params.S || + dout.getNumRows() != params.N || 
dout.getNumColumns() != params.K*params.P*params.Q) { + throw new DMLRuntimeException("Incorrect input to conv2d_backward_filter"); + } + if(params.stride_h <= 0 || params.stride_w <= 0) { + throw new DMLRuntimeException("Only positive strides supported"); + } + + if(DMLScript.STATISTICS) { + if(filter.isInSparseFormat() || dout.isInSparseFormat()) { + conv2dBwdDataSparseCount.addAndGet(1); + } + else { + conv2dBwdDataDenseCount.addAndGet(1); + } + } + + params.reuseNonZeroedOutput = true; + + int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads); + if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) { + warnSingleThreaded(); + MatrixBlock dout_reshaped = new MatrixBlock(params.P*params.Q, params.K, false); + dout_reshaped.allocateDenseBlock(true); + for (int n = 0; n < params.N; n++) { + doLoopedIm2ColConv2dBwdData(n, dout_reshaped, params); + } + } + else { + runConvTask(constrainedNumThreads, 1, TaskType.LoopedIm2ColConv2dBwdData, params); + } + } + + public static void conv2d_backward_filter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException { params.input1 = input; params.input2 = dout; params.output = outputBlock; @@ -249,42 +296,21 @@ public class LibMatrixDNN { } } - if(useMemoryLessConvolution && !useMemoryLessConvolution) { - params.reuseNonZeroedOutput = true; - } + params.reuseNonZeroedOutput = true; int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads); if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) { warnSingleThreaded(); - if(useMemoryLessConvolution) { - for (int c = 0; c < params.C; c++) { - for (int k = 0; k < params.K; k++) { - for (int r = 0; r < params.R; r++) { - for (int s = 0; s < params.S; s++) { - doConv2d_Backward_Filter(k, c, r, s, params); - } - } - } - } - } - else { - MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false); - 
im2ColOutBlock.allocateDenseBlock(true); - MatrixBlock dout_reshaped = new MatrixBlock(params.P*params.Q, params.K, false); - dout_reshaped.allocateDenseBlock(true); - for (int n = 0; n < params.N; n++) { - params.output = doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, dout_reshaped, params.output, params); - } + MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false); + im2ColOutBlock.allocateDenseBlock(true); + MatrixBlock dout_reshaped = new MatrixBlock(params.P*params.Q, params.K, false); + dout_reshaped.allocateDenseBlock(true); + for (int n = 0; n < params.N; n++) { + params.output = doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, dout_reshaped, params.output, params); } } else { - if(useMemoryLessConvolution) { - runConvTask(constrainedNumThreads, params.K*params.C, params.R*params.S, TaskType.LoopBasedConv2dBwdFilter, params); - } - else { - runConvTask(constrainedNumThreads, 1, TaskType.LoopedIm2ColConv2dBwdFilter, params); - } - + runConvTask(constrainedNumThreads, 1, TaskType.LoopedIm2ColConv2dBwdFilter, params); } } @@ -341,6 +367,24 @@ public class LibMatrixDNN { } } + private static void doLoopedIm2ColConv2dBwdData(int n, MatrixBlock dout_reshaped, ConvolutionParameters params) throws DMLRuntimeException { + MatrixBlock filter = params.input1; + MatrixBlock dout = params.input2; + doRotate180(n, 0, dout, dout_reshaped.denseBlock, params, true); + dout_reshaped.recomputeNonZeros(); + + MatrixBlock temp = new MatrixBlock(params.P*params.Q, params.C*params.R*params.S, false); + long t1 = DMLScript.STATISTICS ? System.nanoTime() : 0; + LibMatrixMult.matrixMult(dout_reshaped, filter, temp); + long t2 = DMLScript.STATISTICS ? System.nanoTime() : 0 ; + doCol2imOverSingleImage(n, temp, params); + long t3 = DMLScript.STATISTICS ? 
System.nanoTime() : 0 ; + if(DMLScript.STATISTICS) { + loopedConvBwdDataMatMultTime.addAndGet(t2-t1); + loopedConvBwdDataCol2ImTime.addAndGet(t3-t2); + } + } + private static MatrixBlock doLoopedIm2ColConv2dBwdFilter(int n, MatrixBlock im2ColOutBlock, MatrixBlock dout_reshaped, MatrixBlock partialRetBlock, ConvolutionParameters params) throws DMLRuntimeException { long nnz = 0; @@ -359,88 +403,14 @@ public class LibMatrixDNN { LibMatrixMult.matrixMult(im2ColOutBlock, dout_reshaped, temp); long t4 = DMLScript.STATISTICS ? System.nanoTime() : 0 ; if(DMLScript.STATISTICS) { - loopedConvBwdMatMultTime.addAndGet(t4-t3); - loopedConvBwdIm2ColTime.addAndGet(t2-t1); + loopedConvBwdFilterMatMultTime.addAndGet(t4-t3); + loopedConvBwdFilterIm2ColTime.addAndGet(t2-t1); } elementWiseInPlaceTransposedAddition(partialRetBlock, temp); return partialRetBlock; } - private static void doConv2d_Backward_Filter(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException { - double [] inputArray = null; - if (!params.input1.isInSparseFormat()) - inputArray = params.input1.getDenseBlock(); - double [] doutArray = null; - if (!params.input2.isInSparseFormat()) - doutArray = params.input2.getDenseBlock(); - double [] outputArray = params.output.getDenseBlock(); - - double outputVal = 0; - if(inputArray == null && doutArray == null) { - outputVal = doConv2d_Backward_Filter_SparseSparse(k, c, r, s, params); - } - else if(inputArray != null && doutArray == null) { - outputVal = doConv2d_Backward_Filter_DenseSparse(k, c, r, s, params, inputArray); - } - else if(inputArray == null && doutArray != null) { - outputVal = doConv2d_Backward_Filter_SparseDense(k, c, r, s, params, doutArray); - } - else { - outputVal = doConv2d_Backward_Filter_DenseDense(k, c, r, s, params, inputArray, doutArray); - } - - outputArray[k*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s] = outputVal; - } - - private static double doConv2d_Backward_Filter_SparseDense(int k, int c, 
int r, int s, ConvolutionParameters params, double [] doutArray) throws DMLRuntimeException { - double outputVal = 0; - // To ensure h >= 0 && h < params.H - int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h)); - int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w)); - // To ensure w >= 0 && w < params.W - int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h)); - int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w)); - - // TODO: Optimize this case - for (int n = 0; n < params.N; n++) { - int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q; - for (int p = pMin; p < pMax; p++) { - for (int q = qMin; q < qMax; q++) { - int h = p*params.stride_h + r - params.pad_h; - int w = q*params.stride_w + s - params.pad_w; - outputVal += doutArray[doutOffset + p*params.Q + q]*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w); - } - } - } - - return outputVal; - } - - private static double doConv2d_Backward_Filter_DenseDense(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray, double [] doutArray) { - double outputVal = 0; - // To ensure h >= 0 && h < params.H - int pMin = (int) Math.max(0, Math.ceil(((double)(params.pad_h-r))/params.stride_h)); - int qMin = (int) Math.max(0, Math.ceil(((double)(params.pad_w-s))/params.stride_w)); - // To ensure w >= 0 && w < params.W - int pMax = (int) Math.min(params.P, Math.ceil(((double)(params.H+params.pad_h-r))/params.stride_h)); - int qMax = (int) Math.min(params.Q, Math.ceil(((double)(params.W+params.pad_w-s))/params.stride_w)); - - for (int n = 0; n < params.N; n++) { - int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w; - int doutOffset = n*params.K*params.P*params.Q + k*params.P*params.Q; - for (int p = pMin; p < pMax; p++) { - int h = p*params.stride_h + r - params.pad_h; - for (int q = 
qMin; q < qMax; q++) { - int w = q*params.stride_w; - outputVal += doutArray[doutOffset + p*params.Q + q]*inputArray[inputOffset + h*params.W+w]; - } - } - } - - return outputVal; - } - private static void computeTensorIndexes(int i, int j, int [] ret, int N, int C, int H, int W) throws DMLRuntimeException { ret[0] = i; ret[1] = j / (H*W); @@ -448,56 +418,7 @@ public class LibMatrixDNN { ret[3] = j % W; } - private static double doConv2d_Backward_Filter_DenseSparse(int k, int c, int r, int s, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException { - MatrixBlock dout = params.input2; - double outputVal = 0; - Iterator<IJV> iter = dout.sparseBlock.getIterator(); - int [] tensorIndexes = new int[4]; - while(iter.hasNext()) { - IJV ijv = iter.next(); - computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q); - if(k == tensorIndexes[1]) { - int n = tensorIndexes[0]; - int p = tensorIndexes[2]; - int q = tensorIndexes[3]; - - double doutVal = ijv.getV(); - int h = p*params.stride_h + r - params.pad_h; - int w = q*params.stride_w + s - params.pad_w; - if(h >= 0 && h < params.H && w >= 0 && w < params.W) { - outputVal += doutVal*inputArray[n*params.C*params.H*params.W + c*params.H*params.W + h*params.W+w]; - } - } - } - return outputVal; - } - - private static double doConv2d_Backward_Filter_SparseSparse(int k, int c, int r, int s, ConvolutionParameters params) throws DMLRuntimeException { - MatrixBlock dout = params.input2; - double outputVal = 0; - Iterator<IJV> iter = dout.sparseBlock.getIterator(); - int [] tensorIndexes = new int[4]; - - while(iter.hasNext()) { - IJV ijv = iter.next(); - computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.K, params.P, params.Q); - if(k == tensorIndexes[1]) { - int n = tensorIndexes[0]; - int p = tensorIndexes[2]; - int q = tensorIndexes[3]; - - double doutVal = ijv.getV(); - int h = p*params.stride_h + r - params.pad_h; - int w = 
q*params.stride_w + s - params.pad_w; - if(h >= 0 && h < params.H && w >= 0 && w < params.W) { - outputVal += doutVal*params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w); - } - } - } - return outputVal; - } - - public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params, boolean useMemoryLessConvolution) throws DMLRuntimeException { + public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException { params.input1 = input; params.input2 = filter; params.output = outputBlock; @@ -516,36 +437,18 @@ public class LibMatrixDNN { } } - if(useMemoryLessConvolution) { - fillInTemporaryConvolutionData(input, params); - } - else - params.reuseNonZeroedOutput = true; - + params.reuseNonZeroedOutput = true; int constrainedNumThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads); - if(!ALLOW_MULTI_THREADED_OPS || constrainedNumThreads <= 1) { warnSingleThreaded(); - if(useMemoryLessConvolution) { - for (int n = 0; n < params.N; n++) { - for (int k = 0; k < params.K; k++) { - doLoopBasedConv2d(n, n+1, k, params); - } - } - } - else { - MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false); - im2ColOutBlock.allocateDenseBlock(true); - for (int n = 0; n < params.N; n++) { - doLoopedIm2ColConv2d(n, im2ColOutBlock, params); - } + MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false); + im2ColOutBlock.allocateDenseBlock(true); + for (int n = 0; n < params.N; n++) { + doLoopedIm2ColConv2d(n, im2ColOutBlock, params); } } else { - if(useMemoryLessConvolution) - runConvTask(constrainedNumThreads, params.K, TaskType.LoopBasedConv2d, params); - else - runConvTask(constrainedNumThreads, 1, TaskType.LoopedIm2ColConv2d, params); + runConvTask(constrainedNumThreads, 1, TaskType.LoopedIm2ColConv2d, params); } } @@ -583,52 +486,6 @@ 
public class LibMatrixDNN { } - private static void fillInTemporaryConvolutionData(MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException { - params.tmpData = new TemporaryConvolutionData(); - if(input.isInSparseFormat()) { - params.tmpData.minIndexArrR = new int[params.H]; - params.tmpData.minIndexArrS = new int[params.W]; - for(int h = 0; h < params.H; h++) { - for (int r = 0; r < params.R; r++) { - // int h = p*params.stride_h + r - params.pad_h; - if((h + params.pad_h - r) % params.stride_h == 0) { - params.tmpData.minIndexArrR[h] = r; - break; - } - } - } - for(int w = 0; w < params.W; w++) { - for (int s = 0; s < params.S; s++) { - // int h = p*params.stride_h + r - params.pad_h; - if((w + params.pad_w - s) % params.stride_w == 0) { - params.tmpData.minIndexArrS[w] = s; - break; - } - } - } - } - else { - params.tmpData.minIndexArrR = new int[params.R]; - params.tmpData.maxIndexArrR = new int[params.R]; - params.tmpData.minIndexArrS = new int[params.S]; - params.tmpData.maxIndexArrS = new int[params.S]; - for (int r = 0; r < params.R; r++) { - params.tmpData.minIndexArrR[r] = getMinPQ(params.pad_h, r, params.stride_h); - params.tmpData.maxIndexArrR[r] = getMaxPQ(params.pad_h, r, params.stride_h, params.P, params.H); - } - for (int s = 0; s < params.S; s++) { - params.tmpData.minIndexArrS[s] = getMinPQ(params.pad_w, s, params.stride_w); - params.tmpData.maxIndexArrS[s] = getMaxPQ(params.pad_w, s, params.stride_w, params.Q, params.W); - } - params.tmpData.minCommonIndexS = params.tmpData.minIndexArrS[0]; - params.tmpData.maxCommonIndexS = params.tmpData.maxIndexArrS[0]; - for (int s = 1; s < params.S; s++) { - params.tmpData.minCommonIndexS = Math.max(params.tmpData.minCommonIndexS, params.tmpData.minIndexArrS[s]); - params.tmpData.maxCommonIndexS = Math.min(params.tmpData.maxCommonIndexS, params.tmpData.maxIndexArrS[s]); - } - } - } - public static void maxpooling_backward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, 
ConvolutionParameters params) throws DMLRuntimeException { params.input1 = input; params.input2 = dout; @@ -665,236 +522,6 @@ public class LibMatrixDNN { } } - private static void doLoopBasedConv2dDenseDense(int n1, int n2, int k, ConvolutionParameters params, - double [] inputArray, double [] filterArray) { - double [] outputArray = params.output.getDenseBlock(); - int [] minIndexArrR = params.tmpData.minIndexArrR; - int [] maxIndexArrR = params.tmpData.maxIndexArrR; - int [] minIndexArrS = params.tmpData.minIndexArrS; - int [] maxIndexArrS = params.tmpData.maxIndexArrS; - - final int minCommonIndexS = params.tmpData.minCommonIndexS; - final int maxCommonIndexS = params.tmpData.maxCommonIndexS; - - final int minS = (params.S >= 4) ? (params.S - params.S % 4) : 0; - - for (int n = n1; n < n2; n++) { - for (int c = 0; c < params.C; c++) { - for (int r = 0; r < params.R; r++) { - final int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S; - for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) { - final int h = p*params.stride_h + r - params.pad_h; - final int inputOffSet = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W - params.pad_w; - final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q + p*params.Q; - - for (int q = minCommonIndexS; q < maxCommonIndexS; q++) { - final int wOffset = inputOffSet + q*params.stride_w; - // ------------------------------------------------------------------------ - // Efficient striding with vectorization - final int outOffsetWithQ = outputOffset + q; - for (int s = 0; s < minS; s += 4) { - final int inOffsetWithS = wOffset + s; - final int filterOffsetWithS = filterOffset + s; - outputArray[outOffsetWithQ] += inputArray[inOffsetWithS]*filterArray[filterOffsetWithS] - + inputArray[inOffsetWithS+1]*filterArray[filterOffsetWithS+1] - + inputArray[inOffsetWithS+2]*filterArray[filterOffsetWithS+2] - + inputArray[inOffsetWithS+3]*filterArray[filterOffsetWithS+3]; - } - 
// ------------------------------------------------------------------------ - // Efficient striding without vectorization - for (int s = minS; s < params.S; s++) { - outputArray[outputOffset + q] += inputArray[wOffset + s]*filterArray[filterOffset + s]; - } - // ------------------------------------------------------------------------ - } - // ------------------------------------------------------------------------ - // Inefficient striding - for (int s = 0; s < params.S; s++) { - for (int q = minIndexArrS[s]; q < minCommonIndexS; q++) { - final int w = q*params.stride_w + s; - outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s]; - } - for (int q = maxCommonIndexS; q < maxIndexArrS[s]; q++) { - final int w = q*params.stride_w + s; - outputArray[outputOffset + q] += inputArray[inputOffSet + w]*filterArray[filterOffset + s]; - } - } - // ------------------------------------------------------------------------ - } - } - } - } - } - - private static void doLoopBasedConv2dDenseSparse(int n, int k, ConvolutionParameters params, double [] inputArray) throws DMLRuntimeException { - double [] outputArray = params.output.getDenseBlock(); - int [] minIndexArrR = params.tmpData.minIndexArrR; - int [] maxIndexArrR = params.tmpData.maxIndexArrR; - int [] minIndexArrS = params.tmpData.minIndexArrS; - int [] maxIndexArrS = params.tmpData.maxIndexArrS; - final int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q; - - Iterator<IJV> iter = params.input2.sparseBlock.getIterator(k, k+1); - int [] tensorIndexes = new int[4]; - - while(iter.hasNext()) { - IJV ijv = iter.next(); - computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.K, params.C, params.R, params.S); - int c = tensorIndexes[1]; - int r = tensorIndexes[2]; - int s = tensorIndexes[3]; - double filterVal = ijv.getV(); - final int inputOffset = n*params.C*params.H*params.W + c*params.H*params.W + s - params.pad_w; - for (int p = minIndexArrR[r]; p < 
maxIndexArrR[r]; p++) { - final int hOffset = inputOffset + (p*params.stride_h + r - params.pad_h)*params.W; - final int pOffset = outputOffset + p*params.Q; - for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) { - final int w = q*params.stride_w; - outputArray[pOffset + q] += inputArray[hOffset + w]*filterVal; - } - } - } - } - - private static void doLoopBasedConv2dSparseDense(int n, int k, ConvolutionParameters params, double [] filterArray) throws DMLRuntimeException { - double [] outputArray = params.output.getDenseBlock(); - int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q; - - Iterator<IJV> iter = params.input1.sparseBlock.getIterator(n, n+1); - int [] tensorIndexes = new int[4]; - - int [] minIndexArrR = params.tmpData.minIndexArrR; - int [] minIndexArrS = params.tmpData.minIndexArrS; - while(iter.hasNext()) { - IJV ijv = iter.next(); - computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.N, params.C, params.H, params.W); - - int c = tensorIndexes[1]; - int h = tensorIndexes[2]; - int w = tensorIndexes[3]; - double imgVal = ijv.getV(); - for (int r = minIndexArrR[h]; r < params.R; r += params.stride_h) { - int filterOffset = k*params.C*params.R*params.S + c*params.R*params.S + r*params.S; - for (int s = minIndexArrS[w]; s < params.S; s += params.stride_w) { - int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h); - int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w); - if(p >= 0 && p < params.P && q >= 0 && q < params.Q) { - double filterVal = filterArray[filterOffset + s]; - outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal; - } - } - } - } - } - - private static void doLoopBasedConv2dSparseSparse(int n, int k, ConvolutionParameters params) throws DMLRuntimeException { - double [] outputArray = params.output.getDenseBlock(); - int [] minIndexArrR = params.tmpData.minIndexArrR; - int [] maxIndexArrR = params.tmpData.maxIndexArrR; - int [] minIndexArrS = 
params.tmpData.minIndexArrS; - int [] maxIndexArrS = params.tmpData.maxIndexArrS; - int outputOffset = n*params.K*params.P*params.Q + k*params.P*params.Q; - - - int [] tensorIndexesImage = new int[4]; - int [] tensorIndexesFilter = new int[4]; - - Iterator<IJV> iter = params.input1.sparseBlock.getIterator(n, n+1); - - while(iter.hasNext()) { - IJV ijv = iter.next(); - computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesImage, params.N, params.C, params.H, params.W); - if(n == tensorIndexesImage[0]) { - int c = tensorIndexesImage[1]; - int h = tensorIndexesImage[2]; - int w = tensorIndexesImage[3]; - double imgVal = ijv.getV(); - - Iterator<IJV> iter1 = params.input2.sparseBlock.getIterator(k, k+1); - while(iter1.hasNext()) { - IJV ijv1 = iter1.next(); - computeTensorIndexes(ijv1.getI(), ijv1.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S); - if(c == tensorIndexesFilter[1]) { - int r = tensorIndexesFilter[2]; - int s = tensorIndexesFilter[3]; - if((r-minIndexArrR[h])%params.stride_h == 0 && (s-minIndexArrS[w])%params.stride_w == 0) { - int p = (int)Math.ceil(((double)(h + params.pad_h - r)) / params.stride_h); - int q = (int)Math.ceil(((double)(w + params.pad_w - s)) / params.stride_w); - if(p >= 0 && p < params.P && q >= 0 && q < params.Q) { - double filterVal = ijv1.getV(); - outputArray[outputOffset + p*params.Q + q] += imgVal*filterVal; - } - } - } - } - } - } - - while(iter.hasNext()) { - IJV ijv = iter.next(); - computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexesFilter, params.K, params.C, params.R, params.S); - if(k == tensorIndexesFilter[0]) { - int c = tensorIndexesFilter[1]; - int r = tensorIndexesFilter[2]; - int s = tensorIndexesFilter[3]; - double filterVal = ijv.getV(); - for (int p = minIndexArrR[r]; p < maxIndexArrR[r]; p++) { - int h = p*params.stride_h + r - params.pad_h; - for (int q = minIndexArrS[s]; q < maxIndexArrS[s]; q++) { - int w = q*params.stride_w + s - params.pad_w; - // TODO: Improve the performance of 
sparse sparse - outputArray[outputOffset + p*params.Q + q] += sparseConvMultiply(filterVal, params, n, c, h, w); - } - } - } - } - } - - /** - * This is essentially memory-less operation and can be used when the memory pressure is extremely high. - * @param n - * @param k - * @param params - * @throws DMLRuntimeException - */ - private static void doLoopBasedConv2d(int n1, int n2, int k, ConvolutionParameters params) throws DMLRuntimeException { - double [] inputArray = null; - if (!params.input1.isInSparseFormat()) - inputArray = params.input1.getDenseBlock(); - double [] filterArray = null; - if (!params.input2.isInSparseFormat()) - filterArray = params.input2.getDenseBlock(); - - if(inputArray != null && filterArray != null) { - doLoopBasedConv2dDenseDense(n1, n2, k, params, inputArray, filterArray); - } - else if(inputArray != null && filterArray == null) { - for (int n = n1; n < n2; n++) - doLoopBasedConv2dDenseSparse(n, k, params, inputArray); - } - else if(inputArray == null && filterArray != null) { - for (int n = n1; n < n2; n++) - doLoopBasedConv2dSparseDense(n, k, params, filterArray); - } - else if(inputArray == null && filterArray == null) { - for (int n = n1; n < n2; n++) - doLoopBasedConv2dSparseSparse(n, k, params); - } - } - - private static int getMinPQ(int pad, int filterSize, int stride) { - return Math.max(0, (int)Math.ceil(((double)(pad - filterSize))/stride)); - } - - private static int getMaxPQ(int pad, int filterSize, int stride, int outputSize, int inputSize) { - return Math.min(outputSize, (int)Math.ceil(((double)(inputSize + pad - filterSize)) / stride)); - } - - private static double sparseConvMultiply(double filterVal, ConvolutionParameters params, - int n, int c, int h, int w) { - return params.input1.quickGetValue(n, c*params.H*params.W + h*params.W + w)*filterVal; - } - private static void doPoolingBackward(int n, ConvolutionParameters params) throws DMLRuntimeException { double [] inputArray = null; if 
(!params.input1.isInSparseFormat()) @@ -1251,13 +878,6 @@ public class LibMatrixDNN { } } - private static void runConvTask(int constrainedNumThreads, int NSize, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException { - if (params.isOutputThreadSafe() && constrainedNumThreads > 1) - runParallelConvTask(constrainedNumThreads, NSize, Z, type, params); - else - runSequentialConvTask(NSize, Z, type, params); - } - private static void runConvTask(int constrainedNumThreads, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException { if (params.isOutputThreadSafe() && constrainedNumThreads > 1) runParallelConvTask(constrainedNumThreads, params.N, Z, type, params); @@ -1267,13 +887,21 @@ public class LibMatrixDNN { private static void runParallelConvTask(int constrainedNumThreads, int NSize, int Z, TaskType type, ConvolutionParameters params) throws DMLRuntimeException { ArrayList<ConvTask> tasks = new ArrayList<ConvTask>(); - int [] taskSizes = getTaskSize(constrainedNumThreads, NSize, Z); - for (int n = 0; n < NSize; n += taskSizes[0]) { - for (int z = 0; z < Z; z += taskSizes[1]) { - tasks.add(new ConvTask(n, Math.min(NSize, n+taskSizes[0]), z, Math.min(Z, z+taskSizes[1]), type, params)); + if(NSize >= constrainedNumThreads || Z == 1) { + int numNTasks = (int) Math.ceil(((double) NSize) / constrainedNumThreads); + for (int n = 0; n < NSize; n += numNTasks) { + tasks.add(new ConvTask(n, Math.min(NSize, n+numNTasks), 0, Z, type, params)); + } + } + else { + int [] taskSizes = getTaskSize(constrainedNumThreads, NSize, Z); + for (int n = 0; n < NSize; n += taskSizes[0]) { + for (int z = 0; z < Z; z += taskSizes[1]) { + tasks.add(new ConvTask(n, Math.min(NSize, n+taskSizes[0]), z, Math.min(Z, z+taskSizes[1]), type, params)); + } } + LOG.debug("Reduce number of tasks from " + (NSize*Z) + "(" + NSize + "," + Z + ") to " + tasks.size()); } - LOG.debug("Reduce number of tasks from " + (NSize*Z) + "(" + NSize + "," + Z + ") to " + 
tasks.size()); ExecutorService pool = Executors.newFixedThreadPool( Math.min(constrainedNumThreads, tasks.size()) ); List<Future<Object>> taskret; @@ -1314,19 +942,19 @@ public class LibMatrixDNN { switch(type) { case ReshapeCol: for (int n = n1; n < n2; n++) { - LibMatrixDNN.doReshapeCol(n, params); + doReshapeCol(n, params); } break; case Rotate180: for (int n = n1; n < n2; n++) { - LibMatrixDNN.doRotate180(n, params); + doRotate180(n, params); } break; case Im2Col: long nnz = 0; for (int n = n1; n < n2; n++) { for (int z = z1; z < z2; z++) { - nnz += LibMatrixDNN.doIm2colOverInputPath_NCHW(n, z, params); + nnz += doIm2colOverInputPath_NCHW(n, z, params); } } params.outputNNZ.addAndGet(nnz); @@ -1334,46 +962,31 @@ public class LibMatrixDNN { case Col2Im: for (int n = n1; n < n2; n++) { for (int z = z1; z < z2; z++) { - LibMatrixDNN.doCol2imOverInputPath_NCHW(n, z, params); + doCol2imOverInputPath_NCHW(n, z, params); } } break; case MaxPooling_Forward: for (int n = n1; n < n2; n++) { for (int z = z1; z < z2; z++) { - LibMatrixDNN.doPooling(n, z, params); + doPooling(n, z, params); } } break; case MaxPooling_Backward: for (int n = n1; n < n2; n++) { - LibMatrixDNN.doPoolingBackward(n, params); - } - break; - case LoopBasedConv2d: - for (int z = z1; z < z2; z++) { - LibMatrixDNN.doLoopBasedConv2d(n1, n2, z, params); + doPoolingBackward(n, params); } break; case LoopedIm2ColConv2d: MatrixBlock im2ColOutBlock = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false); im2ColOutBlock.allocateDenseBlock(true); for (int n = n1; n < n2; n++) { - LibMatrixDNN.doLoopedIm2ColConv2d(n, im2ColOutBlock, params); - } - break; - case LoopBasedConv2dBwdFilter: - for (int x = n1; x < n2; x++) { - int k = x / params.C; - int c = x % params.C; - for (int y = z1; y < z2; y++) { - int r = y / params.S; - int s = y % params.S; - doConv2d_Backward_Filter(k, c, r, s, params); - } + doLoopedIm2ColConv2d(n, im2ColOutBlock, params); } break; case LoopedIm2ColConv2dBwdFilter: + { 
MatrixBlock im2ColOutBlock1 = new MatrixBlock(params.C*params.R*params.S, params.P*params.Q, false); im2ColOutBlock1.allocateDenseBlock(true); MatrixBlock partialRetBlock = new MatrixBlock(params.K, params.C*params.R*params.S, false); @@ -1381,9 +994,19 @@ public class LibMatrixDNN { MatrixBlock dout_reshaped = new MatrixBlock(params.P*params.Q, params.K, false); dout_reshaped.allocateDenseBlock(true); for (int n = n1; n < n2; n++) { - partialRetBlock = LibMatrixDNN.doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock1, dout_reshaped, partialRetBlock, params); + partialRetBlock = doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock1, dout_reshaped, partialRetBlock, params); } return partialRetBlock; + } + case LoopedIm2ColConv2dBwdData: + { + MatrixBlock dout_reshaped = new MatrixBlock(params.P*params.Q, params.K, false); + dout_reshaped.allocateDenseBlock(true); + for (int n = n1; n < n2; n++) { + doLoopedIm2ColConv2dBwdData(n, dout_reshaped, params); + } + break; + } default: throw new DMLRuntimeException("Unsupported ConvTask:" + type.name()); } @@ -1470,6 +1093,67 @@ public class LibMatrixDNN { } } + + // Converts input: PQ X CRS matrix and writes to 1 X CHW + private static void doCol2imOverSingleImage(int n, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException { + if(input.rlen != params.P*params.Q || input.clen != params.C*params.R*params.S) { + throw new DMLRuntimeException("Incorrect input dimensions"); + } + + double [] outputArray = null; + if (!params.output.isInSparseFormat()) + outputArray = params.output.getDenseBlock(); + else { + throw new DMLRuntimeException("Only dense output is implemented"); + } + + if(!input.isInSparseFormat()) { + double [] inputArray = input.getDenseBlock(); + doCol2IMDenseInput(n, inputArray, outputArray, params); + } + else { + doCol2IMSparseInput(n, input.getSparseBlockIterator(), outputArray, params); + } + } + + private static void doCol2IMSparseInput(int n, Iterator<IJV> inputIter, double [] outputArray, 
ConvolutionParameters params) throws DMLRuntimeException { + int [] tensorIndexes = new int[4]; + while(inputIter.hasNext()) { + IJV ijv = inputIter.next(); + computeTensorIndexes(ijv.getI(), ijv.getJ(), tensorIndexes, params.P*params.Q, params.C, params.R, params.S); + int c = tensorIndexes[1]; + int r = tensorIndexes[2]; + int s = tensorIndexes[3]; + int p = ijv.getI() / params.Q; + int q = ijv.getI() % params.Q; + int h = p*params.stride_h + r - params.pad_h; + int w = q*params.stride_w + s - params.pad_w; + if(h >= 0 && h < params.H && w >= 0 && w < params.W) { + int outIndex = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w; + outputArray[outIndex] += ijv.getV(); + } + } + } + + private static void doCol2IMDenseInput(int n, double [] inputArray, double [] outputArray, ConvolutionParameters params) throws DMLRuntimeException { + for (int c = 0; c < params.C; c++) { + for (int r = 0; r < params.R; r++) { // Get an input patch of size R X S + for (int s = 0; s < params.S; s++) { + for (int p = 0; p < params.P; p++) { + for (int q = 0; q < params.Q; q++) { + int inputIndex = (p*params.Q + q)*params.C*params.R*params.S + c*params.R*params.S + r*params.S + s; + int h = p*params.stride_h + r - params.pad_h; + int w = q*params.stride_w + s - params.pad_w; + if(h >= 0 && h < params.H && w >= 0 && w < params.W) { + int outIndex = n*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w; + outputArray[outIndex] += inputArray[inputIndex]; + } + } + } + } + } + } + } private static void doCol2imOverInputPath_NCHW(int n, int c, ConvolutionParameters params) { double [] inputArray = null; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d79dea92/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java index 
db244ff..7a83278 100644 --- a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java +++ b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java @@ -190,45 +190,42 @@ public class ConvolutionUtils { } public static Lop constructConvolutionBackwardDataLops(Hop currentHop, ExecType et) throws HopsException, LopsException { - return null; // Until we add CP conv2d_backward_data + if(DMLScript.USE_ACCELERATOR) + et = ExecType.GPU; // TODO: Add memory estimate checks + else if(et == ExecType.MR || et == ExecType.SPARK) + return null; - //TODO: uncomment the following after CP conv2d_backward_data is added -// if(DMLScript.USE_ACCELERATOR) -// et = ExecType.GPU; // TODO: Add memory estimate checks -// else -// return null; -// -// if(currentHop != null && isConvolutionOp(currentHop, ConvOp.COL2IM)) { -// Hop temp = currentHop.getInput().get(0); -// if(temp != null && isTranspose(temp)) { -// Hop matMult = temp.getInput().get(0); -// if(matMult != null && isMatMult(matMult)) { -// Hop rotate180 = matMult.getInput().get(0); -// Hop filter = matMult.getInput().get(1); -// if(isConvolutionOp(rotate180, ConvOp.ROTATE180)) { -// ArrayList<Hop> inputs = new ArrayList<Hop>(); -// inputs.add(filter); -// inputs.add(rotate180.getInput().get(0)); -// for(int i = 1; i < rotate180.getInput().size(); i++) { -// inputs.add(rotate180.getInput().get(i)); -// } -// -// // N, C * H * W -// long N = currentHop.computeSizeInformation(inputs.get(6)); -// long C = currentHop.computeSizeInformation(inputs.get(7)); -// long H = currentHop.computeSizeInformation(inputs.get(8)); -// long W = currentHop.computeSizeInformation(inputs.get(9)); -// long rlen = N; -// long clen = ConvolutionOp.getExtractedVal(C, H, W); -// return ConvolutionOp.constructFusedConvolutionLops(et, inputs, ConvOp.DIRECT_CONV2D_BACKWARD_DATA, (ConvolutionOp) rotate180, rlen, clen); -// -// -// } -// } -// } -// } -// -// return null; + if(currentHop != null && isConvolutionOp(currentHop, 
ConvOp.COL2IM)) { + Hop temp = currentHop.getInput().get(0); + if(temp != null && isTranspose(temp)) { + Hop matMult = temp.getInput().get(0); + if(matMult != null && isMatMult(matMult)) { + Hop rotate180 = matMult.getInput().get(0); + Hop filter = matMult.getInput().get(1); + if(isConvolutionOp(rotate180, ConvOp.ROTATE180)) { + ArrayList<Hop> inputs = new ArrayList<Hop>(); + inputs.add(filter); + inputs.add(rotate180.getInput().get(0)); + for(int i = 1; i < rotate180.getInput().size(); i++) { + inputs.add(rotate180.getInput().get(i)); + } + + // N, C * H * W + long N = currentHop.computeSizeInformation(inputs.get(6)); + long C = currentHop.computeSizeInformation(inputs.get(7)); + long H = currentHop.computeSizeInformation(inputs.get(8)); + long W = currentHop.computeSizeInformation(inputs.get(9)); + long rlen = N; + long clen = ConvolutionOp.getExtractedVal(C, H, W); + return ConvolutionOp.constructFusedConvolutionLops(et, inputs, ConvOp.DIRECT_CONV2D_BACKWARD_DATA, (ConvolutionOp) rotate180, rlen, clen); + + + } + } + } + } + + return null; }
