[SYSTEMML-540] Initial implementation of conv2d/pooling builtin function
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/c334c2c8 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/c334c2c8 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/c334c2c8 Branch: refs/heads/master Commit: c334c2c85bc9cbb343e63b5b28ff3a1c5098c7fa Parents: 946b663 Author: Niketan Pansare <[email protected]> Authored: Mon May 16 16:43:39 2016 -0800 Committer: Niketan Pansare <[email protected]> Committed: Mon May 16 17:45:39 2016 -0700 ---------------------------------------------------------------------- docs/devdocs/deep-learning.md | 122 ++++ .../java/org/apache/sysml/api/DMLScript.java | 2 + .../org/apache/sysml/hops/ConvolutionOp.java | 472 ++++++++++++++++ src/main/java/org/apache/sysml/hops/Hop.java | 20 + .../apache/sysml/lops/ConvolutionTransform.java | 216 +++++++ src/main/java/org/apache/sysml/lops/Lop.java | 19 + .../java/org/apache/sysml/lops/compile/Dag.java | 41 +- .../sysml/parser/BuiltinFunctionExpression.java | 146 ++++- .../org/apache/sysml/parser/DMLTranslator.java | 158 ++++++ .../org/apache/sysml/parser/DataExpression.java | 3 +- .../org/apache/sysml/parser/Expression.java | 2 + .../org/apache/sysml/parser/ExpressionList.java | 69 +++ .../java/org/apache/sysml/parser/dml/Dml.g4 | 2 +- .../sysml/parser/dml/DmlSyntacticValidator.java | 18 +- .../controlprogram/caching/CacheableData.java | 4 +- .../controlprogram/caching/MatrixObject.java | 16 + .../instructions/CPInstructionParser.java | 14 +- .../runtime/instructions/cp/CPInstruction.java | 2 +- .../cp/ConvolutionCPInstruction.java | 297 ++++++++++ .../sysml/runtime/matrix/data/LibMatrixDNN.java | 564 +++++++++++++++++++ .../sysml/runtime/matrix/data/MatrixBlock.java | 33 +- .../sysml/runtime/util/ConvolutionUtils.java | 43 ++ .../java/org/apache/sysml/utils/Statistics.java | 16 + .../tensor/Conv2DBackwardDataTest.java | 137 +++++ 
.../functions/tensor/Conv2DBackwardTest.java | 145 +++++ .../functions/tensor/Conv2DTest.java | 134 +++++ .../functions/tensor/PoolBackwardTest.java | 129 +++++ .../integration/functions/tensor/PoolTest.java | 134 +++++ .../functions/tensor/Conv2DBackwardDataTest.R | 104 ++++ .../functions/tensor/Conv2DBackwardDataTest.dml | 36 ++ .../functions/tensor/Conv2DBackwardTest.R | 107 ++++ .../functions/tensor/Conv2DBackwardTest.dml | 36 ++ src/test/scripts/functions/tensor/Conv2DTest.R | 98 ++++ .../scripts/functions/tensor/Conv2DTest.dml | 34 ++ .../scripts/functions/tensor/PoolBackwardTest.R | 76 +++ .../functions/tensor/PoolBackwardTest.dml | 43 ++ src/test/scripts/functions/tensor/PoolTest.R | 98 ++++ src/test/scripts/functions/tensor/PoolTest.dml | 38 ++ 38 files changed, 3610 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/docs/devdocs/deep-learning.md ---------------------------------------------------------------------- diff --git a/docs/devdocs/deep-learning.md b/docs/devdocs/deep-learning.md new file mode 100644 index 0000000..46f2502 --- /dev/null +++ b/docs/devdocs/deep-learning.md @@ -0,0 +1,122 @@ +# Initial prototype for Deep Learning + +## Representing tensor and images in SystemML + +In this prototype, we represent a tensor as a matrix stored in a row-major format, +where first dimension of tensor and matrix are exactly the same. 
For example, a tensor (with all zeros)
+of shape [3, 2, 4, 5] can be instantiated by the following DML statement:
+```sh
+A = matrix(0, rows=3, cols=2*4*5)
+```
+### Tensor functions:
+
+#### Element-wise arithmetic operators:
+The following operators work out of the box when both tensors X and Y have the same shape:
+
+* Element-wise exponentiation: `X ^ Y`
+* Element-wise unary minus: `-X`
+* Element-wise integer division: `X %/% Y`
+* Element-wise modulus operation: `X %% Y`
+* Element-wise multiplication: `X * Y`
+* Element-wise division: `X / Y`
+* Element-wise addition: `X + Y`
+* Element-wise subtraction: `X - Y`
+
+SystemML does not support implicit broadcasting for the above tensor operations; however, one can write a DML-bodied function to do so.
+For example, to perform the above operations with broadcasting along the second dimension, one can use the `rep(Z, C)` function below (its loop is written for C >= 2; for C = 1, simply use Z directly):
+``` python
+rep = function(matrix[double] Z, int C) return (matrix[double] ret) {
+  ret = Z
+  for(i in 2:C) {
+    ret = cbind(ret, Z)
+  }
+}
+```
+Using the above `rep(Z, C)` function, we can realize the element-wise arithmetic operations with broadcasting. Here are some examples:
+* X of shape [N, C, H, W] and Y of shape [1, C, H, W]: `X + Y` (Note: SystemML does implicit broadcasting in this case because of the way
+it represents the tensor)
+* X of shape [1, C, H, W] and Y of shape [N, C, H, W]: `X + Y` (Note: SystemML does implicit broadcasting in this case because of the way
+it represents the tensor)
+* X of shape [N, C, H, W] and Y of shape [N, 1, H, W]: `X + rep(Y, C)`
+* X of shape [N, C, H, W] and Y of shape [1, 1, H, W]: `X + rep(Y, C)`
+* X of shape [N, 1, H, W] and Y of shape [N, C, H, W]: `rep(X, C) + Y`
+* X of shape [1, 1, H, W] and Y of shape [N, C, H, W]: `rep(X, C) + Y`
+
+TODO: Map the NumPy tensor calls to DML expressions.
+
+## Representing images in SystemML
+
+The images are assumed to be stored in NCHW format, where N = batch size, C = number of channels, H = height of the image and W = width of the image.
+Hence, the images are internally represented as a matrix with dimensions (N, C * H * W).
+
+## Convolution and Pooling built-in functions
+
+This prototype also contains an initial implementation of the forward/backward functions for 2D convolution and pooling:
+* `conv2d(x, w, ...)`
+* `conv2d_backward_filter(x, dout, ...)` and `conv2d_backward_data(w, dout, ...)`
+* `max_pool(x, ...)` and `max_pool_backward(x, dout, ...)`
+
+The required arguments for all of the above functions are:
+* stride=[stride_h, stride_w]
+* padding=[pad_h, pad_w]
+* input_shape=[numImages, numChannels, height_image, width_image]
+
+The additional required argument for the conv2d/conv2d_backward_filter/conv2d_backward_data functions is:
+* filter_shape=[numFilters, numChannels, height_filter, width_filter]
+
+The additional required argument for the max_pool/max_pool_backward functions is:
+* pool_size=[height_pool, width_pool]
+
+The results of these functions are consistent with NVIDIA's cuDNN library.
+
+### Border mode:
+In the formulas below, `input_shape`, `filter_shape`, `stride` and `padding` refer to one spatial dimension at a time (height or width). Assuming the usual output size `P = (H + 2*pad_h - R)/stride_h + 1` for an image height H and a filter height R (and likewise for the width), these paddings produce output heights of `H - R + 1` (valid), `H + R - 1` (full) and `H` (same) whenever the divisions are exact.
+
+* To perform valid padding, use `padding = (input_shape-filter_shape)*(stride-1)/ 2`. (Hint: for a stride of 1, `padding = [0, 0]` performs valid padding).
+
+* To perform full padding, use `padding = ((stride-1)*input_shape + (stride+1)*filter_shape - 2*stride) / 2`. (Hint: for a stride of 1, `padding = [filter_h-1, filter_w-1]` performs full padding).
+
+* To perform same padding, use `padding = (input_shape*(stride-1) + filter_shape - stride)/2`. (Hint: for a stride of 1, `padding = [(filter_h-1)/2, (filter_w-1)/2]` performs same padding).
+
+### Explanation of backward functions for conv2d
+
+Consider a one-channel 3 X 3 image:
+
+| x1 | x2 | x3 |
+|----|----|----|
+| x4 | x5 | x6 |
+| x7 | x8 | x9 |
+
+and one 2 X 2 filter:
+
+| w1 | w2 |
+|----|----|
+| w3 | w4 |
+
+Then, `conv2d(x, w, stride=[1, 1], padding=[0, 0], input_shape=[1, 1, 3, 3], filter_shape=[1, 1, 2, 2])` produces the following tensor
+of shape `[1, 1, 2, 2]`, which is represented as a `1 X 4` matrix in NCHW format:
+
+| `w1*x1 + w2*x2 + w3*x4 + w4*x5` | `w1*x2 + w2*x3 + w3*x5 + w4*x6` | `w1*x4 + w2*x5 + w3*x7 + w4*x8` | `w1*x5 + w2*x6 + w3*x8 + w4*x9` |
+|---------------------------------|---------------------------------|---------------------------------|---------------------------------|
+
+
+Let the error propagated from the layer above be:
+
+| y1 | y2 | y3 | y4 |
+|----|----|----|----|
+
+Then `conv2d_backward_filter(x, y, stride=[1, 1], padding=[0, 0], input_shape=[1, 1, 3, 3], filter_shape=[1, 1, 2, 2])` produces the following
+updates for the filter:
+
+| `y1*x1 + y2*x2 + y3*x4 + y4*x5` | `y1*x2 + y2*x3 + y3*x5 + y4*x6` |
+|---------------------------------|---------------------------------|
+| `y1*x4 + y2*x5 + y3*x7 + y4*x8` | `y1*x5 + y2*x6 + y3*x8 + y4*x9` |
+
+Note: since the above update is a tensor of shape [1, 1, 2, 2], it is represented as a matrix of dimension [1, 4].
+
+Similarly, `conv2d_backward_data(w, y, stride=[1, 1], padding=[0, 0], input_shape=[1, 1, 3, 3], filter_shape=[1, 1, 2, 2])` produces following
+updates for the image:
+
+
+| `w1*y1` | `w2*y1 + w1*y2` | `w2*y2` |
+|-----------------|---------------------------------|-----------------|
+| `w3*y1 + w1*y3` | `w4*y1 + w3*y2 + w2*y3 + w1*y4` | `w4*y2 + w2*y4` |
+| `w3*y3` | `w4*y3 + w3*y4` | `w4*y4` |
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java
index aa45eb9..570fb07 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -965,4 +965,6 @@ public class DMLScript
             throw new DMLException("Failed to run SystemML workspace cleanup.", ex);
         }
     }
+    // NOTE(review): compile-time flag, disabled (false) here; its readers are outside this hunk - confirm what reusing a non-zeroed output buffer enables before flipping it.
+    public static final boolean REUSE_NONZEROED_OUTPUT = false;
 }
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
new file mode 100644
index 0000000..e9023d4
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.Hop.MultiThreadedHop;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.lops.ConvolutionTransform;
+import org.apache.sysml.lops.Lop;
+import org.apache.sysml.lops.LopsException;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.ConvolutionParameters;
+
+/**
+ * Hop for the convolution and pooling operations listed in {@link Hop.ConvOp}.
+ *
+ * <p>Input layout: input 0 is the primary matrix. MAX_POOLING_BACKWARD and the
+ * DIRECT_CONV2D* operations take a second matrix at input 1. The 12 scalar inputs
+ * follow, in order: stride1, stride2, padding1, padding2, input_shape1..4 and
+ * filter_shape1..4. Output dimensions are inferred only when those scalars are literals.
+ *
+ * <p>Only CP execution is currently implemented (see {@link #optFindExecType()}).
+ */
+public class ConvolutionOp extends Hop implements MultiThreadedHop
+{
+    /** Number of trailing scalar inputs: stride (2), padding (2), input_shape (4), filter_shape (4). */
+    private static final int NUM_SCALAR_INPUTS = 12;
+
+    private Hop.ConvOp op;
+
+    private int _maxNumThreads = -1; //-1 for unlimited
+
+    private ConvolutionOp() {
+        //default constructor for clone
+    }
+
+    public ConvolutionOp(String l, DataType dt, ValueType vt, ConvOp o, Hop inp)
+    {
+        super(l, dt, vt);
+        op = o;
+        getInput().add(0, inp);
+        inp.getParent().add(this);
+
+        //compute unknown dims and nnz (no scalar inputs here, so they are expected to remain unknown)
+        refreshSizeInformation();
+    }
+
+
+    public ConvolutionOp(String l, DataType dt, ValueType vt, ConvOp o, ArrayList<Hop> inp)
+    {
+        super(l, dt, vt);
+        op = o;
+
+        for( int i=0; i<inp.size(); i++ ) {
+            Hop in = inp.get(i);
+            getInput().add(i, in);
+            in.getParent().add(this);
+        }
+
+        //compute unknown dims and nnz
+        refreshSizeInformation();
+    }
+
+    public ConvOp getOp()
+    {
+        return op;
+    }
+
+    @Override
+    public String getOpString() {
+        return "" + HopsConv2Lops.get(op);
+    }
+
+    @Override
+    public Lop constructLops()
+        throws HopsException, LopsException
+    {
+        //return already created lops
+        if( getLops() != null )
+            return getLops();
+
+        ExecType et = optFindExecType();
+        ArrayList<Hop> inputs = getInput();
+        switch( op )
+        {
+            case IM2COL:
+            case RESHAPE_COL:
+            case ROTATE180:
+            case COL2IM:
+            case MAX_POOLING:
+            case MAX_POOLING_BACKWARD:
+            case DIRECT_CONV2D:
+            case DIRECT_CONV2D_BACKWARD_DATA:
+            case DIRECT_CONV2D_BACKWARD_FILTER:
+            {
+                if( et != ExecType.CP ) {
+                    // TODO: Add support for SPARK/MR backends once we are happy with the performance of
+                    // single node Lenet script.
+                    throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
+                }
+                setLops(constructConvolutionLops(et, inputs));
+                break;
+            }
+            default:
+                throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
+        }
+
+        //add reblock/checkpoint lops if necessary
+        constructAndSetLopsDataFlowProperties();
+
+        return getLops();
+    }
+
+    public void setOp(ConvOp op) {
+        this.op = op;
+    }
+
+    /**
+     * Builds the CP ConvolutionTransform lop. Input 0 becomes the lop's first input;
+     * all remaining hop inputs (the optional second matrix, then the 12 scalars) are
+     * appended in order.
+     *
+     * @param et execution type of the new lop
+     * @param inputs the hop inputs, laid out as described in the class comment
+     * @return the new ConvolutionTransform lop
+     * @throws HopsException if the number of inputs does not match the operation
+     */
+    public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
+        int expectedNumInputs = getNumMatrixInputs() + NUM_SCALAR_INPUTS;
+        if(inputs.size() != expectedNumInputs) {
+            throw new HopsException("Incorrect number of inputs for " + op.name());
+        }
+
+        Lop in = inputs.get(0).constructLops();
+        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+        ConvolutionTransform transform1 = new ConvolutionTransform( in,
+                HopsConv2Lops.get(op), getDataType(), getValueType(), et, k);
+        setOutputDimensions(transform1);
+        setLineNumbers(transform1);
+
+        // [second matrix,] stride1, stride2, padding1, padding2,
+        // input_shape1..4, filter_shape1..4
+        for( int i=1; i < expectedNumInputs; i++ )
+        {
+            Lop ltmp = inputs.get(i).constructLops();
+            transform1.addInput(ltmp);
+            ltmp.addOutput(transform1);
+        }
+        transform1.setLevel(); //force order of added lops
+        return transform1;
+    }
+
+    /**
+     * Returns the number of leading matrix inputs: 2 for the operations that take a
+     * second matrix (dout or filter), 1 for all others.
+     */
+    private int getNumMatrixInputs() {
+        switch(op) {
+            case MAX_POOLING_BACKWARD:
+            case DIRECT_CONV2D:
+            case DIRECT_CONV2D_BACKWARD_FILTER:
+            case DIRECT_CONV2D_BACKWARD_DATA:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+
+    @Override
+    protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
+    {
+        // reshape/rotate keep the input's nnz; all other ops assume a dense output
+        double sparsity = 1.0;
+        switch(op)
+        {
+            case RESHAPE_COL:
+            case ROTATE180:
+            {
+                sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
+                break;
+            }
+            case IM2COL:
+            case COL2IM:
+            case MAX_POOLING:
+            case MAX_POOLING_BACKWARD:
+            case DIRECT_CONV2D:
+            case DIRECT_CONV2D_BACKWARD_FILTER:
+            case DIRECT_CONV2D_BACKWARD_DATA:
+                sparsity = 1.0; // worst-case estimate
+                break;
+        }
+        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
+    }
+
+    @Override
+    protected double computeIntermediateMemEstimate( long dim1, long dim2, long nnz )
+    {
+        //default: no intermediate memory requirements
+        return 0;
+    }
+
+    @Override
+    protected long[] inferOutputCharacteristics( MemoTable memo )
+    {
+        // [numRows, numCols, NNZ]; null when nothing is inferred for this op
+        long[] ret = null;
+
+        Hop input1 = getInput().get(0);
+        ConvolutionParameters params;
+        MatrixCharacteristics mc = memo.getAllInputStats(input1);
+        try {
+            params = parseInput();
+        } catch (DMLRuntimeException e) {
+            throw new RuntimeException(e);
+        }
+
+        // reshape/rotate keep the input's nnz; -1 (unknown) if no input statistics are available
+        long inputNnz = (mc != null) ? mc.getNonZeros() : -1;
+
+        switch(op)
+        {
+            case RESHAPE_COL:
+            {
+                ret = new long[3];
+                ret[0] = params.N;
+                ret[1] = getExtractedVal(params.K, params.P, params.Q);
+                ret[2] = inputNnz; // exact estimates
+                break;
+            }
+            case ROTATE180:
+            {
+                ret = new long[3];
+                ret[0] = getExtractedVal(params.N, params.P, params.Q);
+                ret[1] = params.K;
+                ret[2] = inputNnz; // exact estimates
+                break;
+            }
+            case IM2COL:
+            case COL2IM:
+            case MAX_POOLING:
+            case MAX_POOLING_BACKWARD:
+            case DIRECT_CONV2D:
+            case DIRECT_CONV2D_BACKWARD_FILTER:
+            case DIRECT_CONV2D_BACKWARD_DATA:
+                break;
+        }
+
+        return ret;
+    }
+
+
+    @Override
+    public boolean allowsAllExecTypes()
+    {
+        return true;
+    }
+
+    @Override
+    protected ExecType optFindExecType() throws HopsException {
+
+        checkAndSetForcedPlatform();
+
+        ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+
+        if( _etypeForced != null )
+        {
+            _etype = _etypeForced;
+        }
+        else
+        {
+            if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+                _etype = findExecTypeByMemEstimate();
+            }
+            // Choose CP, if the input dimensions are below threshold or if the input is a vector
+            else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector() )
+            {
+                _etype = ExecType.CP;
+            }
+            else
+            {
+                _etype = REMOTE;
+            }
+
+            //check for valid CP dimensions and matrix size
+            checkAndSetInvalidCPDimsAndSize();
+        }
+
+        //mark for recompile (forever)
+        if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE )
+            setRequiresRecompile();
+
+        // Only CP is implemented (constructLops rejects other exec types), so CP is forced
+        // regardless of the choice above. TODO: remove once a Spark backend exists.
+        _etype = ExecType.CP;
+
+        return _etype;
+    }
+
+    /**
+     * Extracts the convolution parameters from the scalar inputs, which start right
+     * after the matrix input(s). Scalars that are missing or not literals are passed as -1.
+     */
+    ConvolutionParameters parseInput() throws DMLRuntimeException {
+        int s = getNumMatrixInputs(); // index of stride1
+        ConvolutionParameters params = new ConvolutionParameters(
+                extractInputValue(s + 4),   // input_shape1
+                extractInputValue(s + 5),   // input_shape2
+                extractInputValue(s + 6),   // input_shape3
+                extractInputValue(s + 7),   // input_shape4
+                extractInputValue(s + 8),   // filter_shape1
+                extractInputValue(s + 10),  // filter_shape3 (filter_shape2 is not passed)
+                extractInputValue(s + 11),  // filter_shape4
+                extractInputValue(s),       // stride1
+                extractInputValue(s + 1),   // stride2
+                extractInputValue(s + 2),   // padding1
+                extractInputValue(s + 3),   // padding2
+                _maxNumThreads);
+        return params;
+    }
+
+    /** Returns the literal value of input {@code index}, or -1 if that input is absent or not a literal. */
+    private long extractInputValue(int index) {
+        return (index < getInput().size()) ? extractValue(getInput().get(index)) : -1;
+    }
+
+    /** Product of two dimensions, or -1 if either is unknown (-1). */
+    long getExtractedVal(long val1, long val2) {
+        if(val1 == -1 || val2 == -1) {
+            return -1;
+        }
+        return val1*val2;
+    }
+
+    /** Product of three dimensions, or -1 if any is unknown (-1). */
+    long getExtractedVal(long val1, long val2, long val3) {
+        if(val1 == -1 || val2 == -1 || val3 == -1) {
+            return -1;
+        }
+        return val1*val2*val3;
+    }
+
+    @Override
+    public void refreshSizeInformation()
+    {
+        Hop input1 = getInput().get(0);
+
+        ConvolutionParameters params;
+        try {
+            params = parseInput();
+        } catch (DMLRuntimeException e) {
+            throw new RuntimeException(e);
+        }
+
+        switch(op)
+        {
+            case IM2COL:
+            {
+                _dim1 = getExtractedVal(params.C, params.R, params.S);
+                _dim2 = getExtractedVal(params.N, params.P, params.Q);
+                _nnz = -1;
+                break;
+            }
+            case COL2IM:
+            {
+                // Set _dim1, _dim2 and if possible _nnz (use input1.getNnz())
+                _dim1 = params.N;
+                _dim2 = getExtractedVal(params.C, params.H, params.W);
+                _nnz = -1; // cannot infer stats
+                break;
+            }
+            case RESHAPE_COL:
+            {
+                _dim1 = params.N;
+                _dim2 = getExtractedVal(params.K, params.P, params.Q);
+                _nnz = input1.getNnz(); // exact estimates
+                break;
+            }
+            case ROTATE180:
+            {
+                _dim1 = getExtractedVal(params.N, params.P, params.Q);
+                _dim2 = params.K;
+                _nnz = input1.getNnz(); // exact estimates
+                break;
+            }
+            case MAX_POOLING:
+            {
+                _dim1 = params.N;
+                _dim2 = getExtractedVal(params.C, params.P, params.Q);
+                _nnz = -1; // cannot infer stats
+                break;
+            }
+            case MAX_POOLING_BACKWARD:
+            {
+                _dim1 = params.N;
+                _dim2 = getExtractedVal(params.C, params.H, params.W);
+                _nnz = -1;
+                break;
+            }
+            case DIRECT_CONV2D:
+            {
+                _dim1 = params.N;
+                _dim2 = getExtractedVal(params.K, params.P, params.Q);
+                _nnz = -1; // cannot infer stats
+                break;
+            }
+            case DIRECT_CONV2D_BACKWARD_DATA:
+            {
+                _dim1 = params.N;
+                _dim2 = getExtractedVal(params.C, params.H, params.W);
+                _nnz = -1; // cannot infer stats
+                break;
+            }
+            case DIRECT_CONV2D_BACKWARD_FILTER:
+            {
+                _dim1 = params.K;
+                _dim2 = getExtractedVal(params.C, params.R, params.S);
+                _nnz = -1; // cannot infer stats
+                break;
+            }
+            default:
+                throw new RuntimeException("The sizes are not refreshed for " + op.name());
+        }
+    }
+
+    /** Returns the value of a literal hop (truncated to long), or -1 for any non-literal hop. */
+    private long extractValue(Hop hop) {
+        if(hop instanceof LiteralOp)
+            return (long) HopRewriteUtils.getDoubleValueSafe((LiteralOp)hop);
+        return -1;
+    }
+
+    @Override
+    public Object clone() throws CloneNotSupportedException
+    {
+        ConvolutionOp ret = new ConvolutionOp();
+
+        //copy generic attributes
+        ret.clone(this, false);
+
+        //copy specific attributes
+        ret.op = op;
+        ret._maxNumThreads = _maxNumThreads;
+        return ret;
+    }
+
+    @Override
+    public boolean compare( Hop that )
+    {
+        if( !(that instanceof ConvolutionOp) )
+            return false;
+
+        ConvolutionOp that2 = (ConvolutionOp)that;
+
+        boolean ret = (op == that2.op)
+                && (getInput().size()==that.getInput().size())
+                && _maxNumThreads == that2._maxNumThreads;
+
+        //compare all childs (by identity)
+        if( ret ) //sizes matched
+            for( int i=0; i<getInput().size(); i++ )
+                ret &= getInput().get(i) == that2.getInput().get(i);
+
+        return ret;
+    }
+
+
+    @Override
+    public void printMe() throws HopsException
+    {
+        if (LOG.isDebugEnabled()){
+            if (getVisited() != VisitStatus.DONE) {
+                super.printMe();
+                LOG.debug("  Operation: " + op);
+                for (Hop h : getInput()) {
+                    h.printMe();
+                }
+            }
+            setVisited(VisitStatus.DONE);
+        }
+    }
+
+    @Override
+    public void setMaxNumThreads( int k ) {
+        _maxNumThreads = k;
+    }
+
+    @Override
+    public int getMaxNumThreads() {
+        return _maxNumThreads;
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 273c1a2..43193a9 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1083,6 +1083,12 @@ public abstract class Hop
         //DIAG_V2M, DIAG_M2V,
     };
 
+    public enum ConvOp { // operations implemented by ConvolutionOp
+        IM2COL, RESHAPE_COL, ROTATE180, COL2IM,
+        MAX_POOLING, MAX_POOLING_BACKWARD,
+        DIRECT_CONV2D, DIRECT_CONV2D_BACKWARD_FILTER, DIRECT_CONV2D_BACKWARD_DATA
+    };
+
     public enum DataGenMethod {
         RAND, SEQ, SINIT, SAMPLE, INVALID
     };
@@ -1147,6 +1153,20 @@ public abstract class Hop
HopsTransf2Lops.put(ReOrgOp.SORT, org.apache.sysml.lops.Transform.OperationTypes.Sort);
     }
 
+    /** Maps each hop-level convolution op to the lop-level operation type it compiles to. */
+    protected static final HashMap<ConvOp, org.apache.sysml.lops.ConvolutionTransform.OperationTypes> HopsConv2Lops;
+    static {
+        // Populate a local map first and publish it to the constant once it is complete.
+        HashMap<ConvOp, org.apache.sysml.lops.ConvolutionTransform.OperationTypes> convToLop =
+                new HashMap<ConvOp, org.apache.sysml.lops.ConvolutionTransform.OperationTypes>();
+        convToLop.put(ConvOp.IM2COL, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.IM2COL);
+        convToLop.put(ConvOp.RESHAPE_COL, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.RESHAPE_COL);
+        convToLop.put(ConvOp.ROTATE180, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.ROTATE180);
+        convToLop.put(ConvOp.COL2IM, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.COL2IM);
+        convToLop.put(ConvOp.MAX_POOLING, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.MAX_POOLING);
+        convToLop.put(ConvOp.MAX_POOLING_BACKWARD, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.MAX_POOLING_BACKWARD);
+        convToLop.put(ConvOp.DIRECT_CONV2D, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.DIRECT_CONV2D);
+        convToLop.put(ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.DIRECT_CONV2D_BACKWARD_FILTER);
+        convToLop.put(ConvOp.DIRECT_CONV2D_BACKWARD_DATA, org.apache.sysml.lops.ConvolutionTransform.OperationTypes.DIRECT_CONV2D_BACKWARD_DATA);
+        HopsConv2Lops = convToLop;
+    }
+
     protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;
     static {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
new file mode 100644
index 0000000..fdf280d
--- /dev/null
+++ b/src/main/java/org/apache/sysml/lops/ConvolutionTransform.java
@@
-0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.lops;
+
+import org.apache.sysml.lops.LopProperties.ExecLocation;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.lops.compile.JobType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+
+/**
+ * Lop for the convolution and pooling operations. It compiles to one instruction made of
+ * the exec type, the opcode, one or two matrix operands, 12 scalar operands (stride,
+ * padding, input shape, filter shape), the output operand and, for CP, the thread count.
+ */
+public class ConvolutionTransform extends Lop
+{
+    public enum OperationTypes {
+        IM2COL,
+        RESHAPE_COL,
+        ROTATE180,
+        COL2IM,
+        MAX_POOLING,
+        MAX_POOLING_BACKWARD,
+        DIRECT_CONV2D, DIRECT_CONV2D_BACKWARD_FILTER, DIRECT_CONV2D_BACKWARD_DATA
+    };
+
+    /** Number of scalar operands: stride (2), padding (2), input shape (4), filter shape (4). */
+    private static final int NUM_SCALAR_OPERANDS = 12;
+
+    private OperationTypes operation = null;
+    private int numThreads = -1;
+
+    /**
+     * Creates a lop over a single matrix input; the remaining (matrix and scalar) inputs
+     * are attached afterwards via addInput.
+     *
+     * @param input first matrix input
+     * @param op convolution/pooling operation
+     * @param dt output data type
+     * @param vt output value type
+     * @param et execution type (MR is rejected with a RuntimeException)
+     * @param k degree of parallelism, appended to CP instructions
+     */
+    public ConvolutionTransform(Lop input, ConvolutionTransform.OperationTypes op, DataType dt, ValueType vt, ExecType et, int k)
+    {
+        super(Lop.Type.Transform, dt, vt);
+        init(input, op, dt, vt, et);
+        numThreads = k;
+    }
+
+    /**
+     * Creates a lop for the MR backend.
+     * NOTE(review): init rejects ExecType.MR, so this constructor currently always throws
+     * a RuntimeException.
+     */
+    public ConvolutionTransform(Lop input, ConvolutionTransform.OperationTypes op, DataType dt, ValueType vt)
+    {
+        super(Lop.Type.Transform, dt, vt);
+        init(input, op, dt, vt, ExecType.MR);
+    }
+
+    private void init (Lop input, ConvolutionTransform.OperationTypes op, DataType dt, ValueType vt, ExecType et)
+    {
+        operation = op;
+
+        this.addInput(input);
+        input.addOutput(this);
+
+        if ( et == ExecType.MR ) {
+            throw new RuntimeException("The execution type is not supported: " + et.name());
+        }
+
+        // CP/SPARK: alignment is not meaningful outside MR, and no MR job type applies.
+        boolean breaksAlignment = false;
+        boolean aligner = false;
+        boolean definesMRJob = false;
+        lps.addCompatibility(JobType.INVALID);
+        lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+    }
+
+    @Override
+    public String toString() {
+        return " Operation: " + operation;
+    }
+
+    /**
+     * @return the convolution/pooling operation of this lop
+     */
+    public OperationTypes getOperationType()
+    {
+        return operation;
+    }
+
+    private String getOpcode() {
+        switch(operation) {
+            case IM2COL:
+                return "im2col";
+            case RESHAPE_COL:
+                return "reshape_col";
+            case ROTATE180:
+                return "rotate180";
+            case COL2IM:
+                return "col2im";
+            case MAX_POOLING:
+                return "maxpooling";
+            case MAX_POOLING_BACKWARD:
+                return "maxpooling_backward";
+            case DIRECT_CONV2D:
+                return "conv2d";
+            case DIRECT_CONV2D_BACKWARD_FILTER:
+                return "conv2d_backward_filter";
+            case DIRECT_CONV2D_BACKWARD_DATA:
+                return "conv2d_backward_data";
+            default:
+                throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Transform operation " + operation);
+        }
+    }
+
+    /**
+     * Generates the instruction for operations with one matrix input (input 0).
+     * The scalar operands are rendered from input lops 1..12; the String scalar
+     * parameters are not read and exist only to match the calling convention.
+     */
+    @Override
+    public String getInstructions(String input, String stride1, String stride2, String padding1, String padding2,
+            String input_shape1, String input_shape2, String input_shape3, String input_shape4,
+            String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
+            String output) throws LopsException {
+        return buildInstruction(new String[]{input}, output);
+    }
+
+    /**
+     * Generates the instruction for operations with two matrix inputs (input 0 and the
+     * dout/filter at input 1). The scalar operands are rendered from input lops 2..13;
+     * the String scalar parameters are not read.
+     */
+    @Override
+    public String getInstructions(String input, String dout, String stride1, String stride2, String padding1, String padding2,
+            String input_shape1, String input_shape2, String input_shape3, String input_shape4,
+            String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
+            String output) throws LopsException {
+        return buildInstruction(new String[]{input, dout}, output);
+    }
+
+    /**
+     * Assembles: exec type, opcode, the matrix operands, NUM_SCALAR_OPERANDS scalar operands
+     * (taken from the input lops that follow the matrix inputs), the output operand and,
+     * for CP, the degree of parallelism, all separated by OPERAND_DELIMITOR.
+     */
+    private String buildInstruction(String[] matrixLabels, String output) throws LopsException {
+        StringBuilder sb = new StringBuilder();
+        sb.append( getExecType() );
+        sb.append( OPERAND_DELIMITOR );
+        sb.append( getOpcode() );
+
+        for( int i=0; i < matrixLabels.length; i++ ) {
+            sb.append( OPERAND_DELIMITOR );
+            sb.append( getInputs().get(i).prepInputOperand(matrixLabels[i]) );
+        }
+
+        int end = matrixLabels.length + NUM_SCALAR_OPERANDS;
+        for( int i=matrixLabels.length; i < end; i++ ) {
+            sb.append( OPERAND_DELIMITOR );
+            sb.append( getInputs().get(i).prepScalarInputOperand(getExecType()) );
+        }
+
+        sb.append( OPERAND_DELIMITOR );
+        sb.append( this.prepOutputOperand(output) );
+
+        if( getExecType()==ExecType.CP ) {
+            sb.append( OPERAND_DELIMITOR );
+            sb.append( numThreads );
+        }
+        return sb.toString();
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index 59d11be..8424e28 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -51,6 +51,7 @@ public abstract class Lop
         Append, //CP/MR append (column append)
         CombineUnary, CombineBinary, CombineTernary, //MR combine (stitch together)
         CentralMoment, CoVariance, GroupedAgg, GroupedAggM,
+        ConvolutionTransform,
         Transform, DataPartition, RepMat, //CP/MR reorganization, partitioning, replication
         ParameterizedBuiltin, //CP/MR parameterized ops (name/value)
         FunctionCallCP, //CP function calls
@@ -529,6 +530,24 @@ public abstract class Lop
         throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
     }
 
+    // Convolution lops with one matrix input: input, stride1, stride2, padding1, padding2,
+    // input_shape1, input_shape2, input_shape3, input_shape4,
+    // filter_shape1, filter_shape2, filter_shape3, filter_shape4, output
+    public String getInstructions(String input, String stride1, String stride2, String padding1, String padding2,
+            String input_shape1, String input_shape2, String input_shape3, String input_shape4,
+            String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
+            String output) throws LopsException {
+        throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
+    }
+
+    // Convolution lops with two matrix inputs (e.g. max_pool_backward, conv2d): input, dout/filter, then as above
+    public String getInstructions(String input, String dout, String stride1, String stride2, String padding1, String padding2,
+            String input_shape1, String input_shape2, String input_shape3, String input_shape4,
+            String filter_shape1, String filter_shape2, String filter_shape3, String filter_shape4,
+            String output) throws LopsException {
+        throw new LopsException(this.printErrorLocation() + "Should never be invoked in Baseclass");
+    }
+
     public String getInstructions(int input, int rowl, int rowu, int coll, int colu, int leftRowDim, int leftColDim, int output) throws LopsException {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/lops/compile/Dag.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java
index 6ec3a5a..2d24dad 100644
--- a/src/main/java/org/apache/sysml/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java
@@ -1468,7 +1468,43 @@ public class Dag<N extends Lop>
                         node.getInputs().get(6).getOutputParameters().getLabel(),
                         node.getOutputParameters().getLabel());
             }
-
+            else if (node.getInputs().size() == 13) {
+                // Used for im2col and reshape_col
+                inst_string = node.getInstructions(
+                        node.getInputs().get(0).getOutputParameters().getLabel(),
+                        node.getInputs().get(1).getOutputParameters().getLabel(),
+                        node.getInputs().get(2).getOutputParameters().getLabel(),
+
node.getInputs().get(3).getOutputParameters().getLabel(), + node.getInputs().get(4).getOutputParameters().getLabel(), + node.getInputs().get(5).getOutputParameters().getLabel(), + node.getInputs().get(6).getOutputParameters().getLabel(), + node.getInputs().get(7).getOutputParameters().getLabel(), + node.getInputs().get(8).getOutputParameters().getLabel(), + node.getInputs().get(9).getOutputParameters().getLabel(), + node.getInputs().get(10).getOutputParameters().getLabel(), + node.getInputs().get(11).getOutputParameters().getLabel(), + node.getInputs().get(12).getOutputParameters().getLabel(), + node.getOutputParameters().getLabel()); + } + else if (node.getInputs().size() == 14) { + // Used for pooling_backward + inst_string = node.getInstructions( + node.getInputs().get(0).getOutputParameters().getLabel(), + node.getInputs().get(1).getOutputParameters().getLabel(), + node.getInputs().get(2).getOutputParameters().getLabel(), + node.getInputs().get(3).getOutputParameters().getLabel(), + node.getInputs().get(4).getOutputParameters().getLabel(), + node.getInputs().get(5).getOutputParameters().getLabel(), + node.getInputs().get(6).getOutputParameters().getLabel(), + node.getInputs().get(7).getOutputParameters().getLabel(), + node.getInputs().get(8).getOutputParameters().getLabel(), + node.getInputs().get(9).getOutputParameters().getLabel(), + node.getInputs().get(10).getOutputParameters().getLabel(), + node.getInputs().get(11).getOutputParameters().getLabel(), + node.getInputs().get(12).getOutputParameters().getLabel(), + node.getInputs().get(13).getOutputParameters().getLabel(), + node.getOutputParameters().getLabel()); + } else { throw new LopsException(node.printErrorLocation() + "Node with " + node.getInputs().size() + " inputs is not supported in CP yet! 
\n"); } @@ -1478,6 +1514,9 @@ public class Dag<N extends Lop> if( LOG.isTraceEnabled() ) LOG.trace("Generating instruction - "+ inst_string); Instruction currInstr = InstructionParser.parseSingleInstruction(inst_string); + if(currInstr == null) { + throw new LopsException("Error parsing the instruction:" + inst_string); + } if (node._beginLine != 0) currInstr.setLocation(node); else if ( !node.getOutputs().isEmpty() ) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 3a3d714..32e529d 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -21,6 +21,7 @@ package org.apache.sysml.parser; import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import org.apache.sysml.parser.LanguageException.LanguageErrorCodes; @@ -33,11 +34,12 @@ public class BuiltinFunctionExpression extends DataIdentifier public BuiltinFunctionExpression(BuiltinFunctionOp bifop, ArrayList<ParameterExpression> args, String fname, int blp, int bcp, int elp, int ecp) { _kind = Kind.BuiltinFunctionOp; _opcode = bifop; + this.setAllPositions(fname, blp, bcp, elp, ecp); + args = expandConvolutionArguments(args); _args = new Expression[args.size()]; for(int i=0; i < args.size(); i++) { _args[i] = args.get(i).getExpr(); } - this.setAllPositions(fname, blp, bcp, elp, ecp); } public BuiltinFunctionExpression(BuiltinFunctionOp bifop, Expression[] args, String fname, int blp, int bcp, int elp, int ecp) { @@ -203,6 +205,105 @@ public class BuiltinFunctionExpression extends DataIdentifier raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } } + + 
private ArrayList<ParameterExpression> orderConvolutionParams(ArrayList<ParameterExpression> paramExpression, + int skip) throws LanguageException { + ArrayList<ParameterExpression> newParams = new ArrayList<ParameterExpression>(); + + for(int i = 0; i < skip; i++) + newParams.add(paramExpression.get(i)); + + String [] orderedParams = { + "stride1", "stride2", "padding1", "padding2", + "input_shape1", "input_shape2", "input_shape3", "input_shape4", + "filter_shape1", "filter_shape2", "filter_shape3", "filter_shape4" + }; + for(int i = 0; i < orderedParams.length; i++) { + boolean found = false; + for(ParameterExpression param : paramExpression) { + if(param.getName() != null && param.getName().equals(orderedParams[i])) { + found = true; + newParams.add(param); + } + } + if(!found) { + throw new LanguageException("Incorrect parameters. Expected " + orderedParams[i] + " to be expanded."); + } + } + + return newParams; + } + + private ArrayList<ParameterExpression> replaceListParams(ArrayList<ParameterExpression> paramExpression, + String inputVarName, String outputVarName, int startIndex) throws LanguageException { + ArrayList<ParameterExpression> newParamExpression = new ArrayList<ParameterExpression>(); + int i = startIndex; + int j = 1; // Assumption: sequential ordering pool_size1, pool_size2 + for (ParameterExpression expr : paramExpression) { + if(expr.getName() != null && expr.getName().equals(inputVarName + j)) { + newParamExpression.add(new ParameterExpression(outputVarName + i, expr.getExpr())); + i++; j++; + } + else { + newParamExpression.add(expr); + } + } + return newParamExpression; + } + + private ArrayList<ParameterExpression> expandListParams(ArrayList<ParameterExpression> paramExpression, + HashSet<String> paramsToExpand) throws LanguageException { + ArrayList<ParameterExpression> newParamExpressions = new ArrayList<ParameterExpression>(); + for(ParameterExpression expr : paramExpression) { + if(paramsToExpand.contains(expr.getName())) { + 
if(expr.getExpr() instanceof ExpressionList) { + int i = 1; + for(Expression e : ((ExpressionList)expr.getExpr()).getValue()) { + newParamExpressions.add(new ParameterExpression(expr.getName() + i, e)); + i++; + } + } + } + else if(expr.getExpr() instanceof ExpressionList) { + throw new LanguageException("The parameter " + expr.getName() + " cannot be list or is not supported for the given function"); + } + else { + newParamExpressions.add(expr); + } + } + return newParamExpressions; + } + + private ArrayList<ParameterExpression> expandConvolutionArguments(ArrayList<ParameterExpression> paramExpression) { + try { + if(_opcode == BuiltinFunctionOp.CONV2D || _opcode == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER + || _opcode == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) { + HashSet<String> expand = new HashSet<String>(); + expand.add("input_shape"); expand.add("filter_shape"); expand.add("stride"); expand.add("padding"); + paramExpression = expandListParams(paramExpression, expand); + paramExpression = orderConvolutionParams(paramExpression, 2); + } + else if(_opcode == BuiltinFunctionOp.MAX_POOL || + _opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD) { + HashSet<String> expand = new HashSet<String>(); + expand.add("input_shape"); expand.add("pool_size"); expand.add("stride"); expand.add("padding"); + paramExpression = expandListParams(paramExpression, expand); + paramExpression.add(new ParameterExpression("filter_shape1", + new IntIdentifier(1, getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn()))); + paramExpression.add(new ParameterExpression("filter_shape2", + new IntIdentifier(1, getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn()))); + paramExpression = replaceListParams(paramExpression, "pool_size", "filter_shape", 3); + if(_opcode == BuiltinFunctionOp.MAX_POOL_BACKWARD) + paramExpression = orderConvolutionParams(paramExpression, 2); + else + paramExpression = orderConvolutionParams(paramExpression, 1); + } + } + 
catch(LanguageException e) { + throw new RuntimeException(e); + } + return paramExpression; + } /** * Validate parse tree : Process BuiltinFunction Expression in an assignment @@ -997,6 +1098,37 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setBlockDimensions(id.getRowsInBlock(), id.getColumnsInBlock()); break; + case CONV2D: + case CONV2D_BACKWARD_FILTER: + case CONV2D_BACKWARD_DATA: + case MAX_POOL: + case AVG_POOL: + case MAX_POOL_BACKWARD: + { + // At DML level: + // output = conv2d(input, filter, input_shape=[3, 2, 2], filter_shape=[3, 2, 2], + // strides=[1, 1], border_mode="valid") + // + // Converted to following in constructor (only supported NCHW): + // output = conv2d(input, filter, stride1, stride2, padding1,padding2, + // input_shape1, input_shape2, input_shape3, input_shape4, + // filter_shape1, filter_shape2, filter_shape3, filter_shape4) + // + // Similarly, + // conv2d_backward_filter and conv2d_backward_data + Expression input = _args[0]; // For conv2d_backward_filter, this is input and for conv2d_backward_data, this is filter + + if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) { + Expression filter = _args[1]; // For conv2d_backward functions, this is dout + checkMatrixParam(filter); + } + output.setDataType(DataType.MATRIX); + output.setValueType(ValueType.DOUBLE); + output.setBlockDimensions(input.getOutput().getRowsInBlock(), input.getOutput().getColumnsInBlock()); + + checkMatrixParam(input); + break; + } default: if (this.isMathFunction()) { // datatype and dimensions are same as this.getExpr() @@ -1473,6 +1605,18 @@ public class BuiltinFunctionExpression extends DataIdentifier bifop = Expression.BuiltinFunctionOp.LU; else if (functionName.equals("eigen")) bifop = Expression.BuiltinFunctionOp.EIGEN; + else if (functionName.equals("conv2d")) + bifop = Expression.BuiltinFunctionOp.CONV2D; + else if (functionName.equals("conv2d_backward_filter")) + bifop = 
Expression.BuiltinFunctionOp.CONV2D_BACKWARD_FILTER; + else if (functionName.equals("conv2d_backward_data")) + bifop = Expression.BuiltinFunctionOp.CONV2D_BACKWARD_DATA; + else if (functionName.equals("max_pool")) + bifop = Expression.BuiltinFunctionOp.MAX_POOL; + else if (functionName.equals("max_pool_backward")) + bifop = Expression.BuiltinFunctionOp.MAX_POOL_BACKWARD; + else if (functionName.equals("avg_pool")) + bifop = Expression.BuiltinFunctionOp.AVG_POOL; else if (functionName.equals("solve")) bifop = Expression.BuiltinFunctionOp.SOLVE; else if (functionName.equals("ceil")) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 4daeef7..c587b6f 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -64,6 +64,9 @@ import org.apache.sysml.parser.Expression.ParameterizedBuiltinFunctionOp; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.PrintStatement.PRINTTYPE; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.hops.ConvolutionOp; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; +import org.apache.sysml.parser.Expression.BuiltinFunctionOp; public class DMLTranslator @@ -2798,6 +2801,76 @@ public class DMLTranslator currBuiltinOp.refreshSizeInformation(); //force size reevaluation according to 'outer' flag otherwise danger of incorrect dims break; + case CONV2D: + { + Hop filter = expr2; + // Step 1: IM2COL + Hop image = expr; + ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 2, hops); + Hop loweredMat = new ConvolutionOp(image.getName(), image.getDataType(), image.getValueType(), Hop.ConvOp.IM2COL, inHops1); + + // Step 2: 
Matrix multiplication + Hop temp = new AggBinaryOp("temp" + target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, filter, loweredMat); + + // Step 3: Reshape col + ArrayList<Hop> inHops2 = getALHopsForConvOp(temp, source, 2, hops); + currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.RESHAPE_COL, inHops2); + setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); + break; + } + case AVG_POOL: + case MAX_POOL: + { + Hop image = expr; + ArrayList<Hop> inHops1 = getALHopsForPoolingForwardIM2COL(image, source, 1, hops); + if(source.getOpCode() == BuiltinFunctionOp.MAX_POOL) + currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING, inHops1); + else + throw new HopsException("Average pooling is not implemented"); + setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); + break; + } + case MAX_POOL_BACKWARD: + { + Hop image = expr; + ArrayList<Hop> inHops1 = getALHopsForConvOpPoolingCOL2IM(image, source, 1, hops); // process dout as well + currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.MAX_POOLING_BACKWARD, inHops1); + setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); + break; + } + case CONV2D_BACKWARD_FILTER: + { + Hop image = expr; + Hop dout = expr2; + + ArrayList<Hop> inHops1 = getALHopsForConvOp(image, source, 2, hops); + Hop x_col = new ConvolutionOp(image.getName(), image.getDataType(), image.getValueType(), Hop.ConvOp.IM2COL, inHops1); + + ArrayList<Hop> inHops2 = getALHopsForConvOp(dout, source, 2, hops); + Hop dout_reshaped = new ConvolutionOp(dout.getName(), dout.getDataType(), dout.getValueType(), Hop.ConvOp.ROTATE180, inHops2); + + Hop dfilter1 = new AggBinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, x_col, dout_reshaped); + currBuiltinOp = new ReorgOp("tempTranspose" + image.getName(), 
image.getDataType(), image.getValueType(), Hop.ReOrgOp.TRANSPOSE, dfilter1); + setBlockSizeAndRefreshSizeInfo(image, currBuiltinOp); + break; + } + case CONV2D_BACKWARD_DATA: + { + Hop filter = expr; + Hop dout = expr2; + + ArrayList<Hop> inHops1 = getALHopsForConvOp(dout, source, 2, hops); + Hop dout_reshaped = new ConvolutionOp(dout.getName(), dout.getDataType(), dout.getValueType(), Hop.ConvOp.ROTATE180, inHops1); + + Hop temp1 = new AggBinaryOp("temp" + target.getName(), target.getDataType(), target.getValueType(), OpOp2.MULT, AggOp.SUM, dout_reshaped, filter); + Hop temp2 = new ReorgOp("tempTranspose" + target.getName(), target.getDataType(), target.getValueType(), Hop.ReOrgOp.TRANSPOSE, temp1); + + ArrayList<Hop> inHops2 = getALHopsForConvOp(temp2, source, 2, hops); + currBuiltinOp = new ConvolutionOp(target.getName(), target.getDataType(), target.getValueType(), Hop.ConvOp.COL2IM, inHops2); + setBlockSizeAndRefreshSizeInfo(filter, currBuiltinOp); + break; + } + default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } @@ -2806,6 +2879,91 @@ public class DMLTranslator currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); return currBuiltinOp; } + + private void setBlockSizeAndRefreshSizeInfo(Hop in, Hop out) { + HopRewriteUtils.setOutputBlocksizes(out, in.getRowsInBlock(), in.getColsInBlock()); + HopRewriteUtils.copyLineNumbers(in, out); + out.refreshSizeInformation(); + } + + private ArrayList<Hop> getALHopsForConvOpPoolingCOL2IM(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { + ArrayList<Hop> ret = new ArrayList<Hop>(); + ret.add(first); + Expression[] allExpr = source.getAllExpr(); + + for(int i = skip; i < allExpr.length; i++) { + if(i == 11) { + ret.add(processExpression(allExpr[7], null, hops)); // Make number of channels of images and filter the same + } + else + 
ret.add(processExpression(allExpr[i], null, hops)); + } + return ret; + } + + private ArrayList<Hop> getALHopsForPoolingForwardIM2COL(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { + ArrayList<Hop> ret = new ArrayList<Hop>(); + ret.add(first); + Expression[] allExpr = source.getAllExpr(); + if(skip != 1) { + throw new ParseException("Unsupported skip"); + } + + Expression numChannels = allExpr[6]; + + for(int i = skip; i < allExpr.length; i++) { + if(i == 10) { + ret.add(processExpression(numChannels, null, hops)); + } + else + ret.add(processExpression(allExpr[i], null, hops)); + } + return ret; + } + + private ArrayList<Hop> getALHopsForConvOpPoolingIM2COL(Hop first, BuiltinFunctionExpression source, int skip, HashMap<String, Hop> hops) throws ParseException { + ArrayList<Hop> ret = new ArrayList<Hop>(); + ret.add(first); + Expression[] allExpr = source.getAllExpr(); + int numImgIndex = -1; + if(skip == 1) { + numImgIndex = 5; + } + else if(skip == 2) { + numImgIndex = 6; + } + else { + throw new ParseException("Unsupported skip"); + } + + for(int i = skip; i < allExpr.length; i++) { + if(i == numImgIndex) { // skip=1 ==> i==5 and skip=2 => i==6 + Expression numImg = allExpr[numImgIndex]; + Expression numChannels = allExpr[numImgIndex+1]; + BinaryExpression tmp = new BinaryExpression(org.apache.sysml.parser.Expression.BinaryOp.MULT, + numImg.getFilename(), numImg.getBeginLine(), numImg.getBeginColumn(), numImg.getEndLine(), numImg.getEndColumn()); + tmp.setLeft(numImg); + tmp.setRight(numChannels); + ret.add(processTempIntExpression(tmp, hops)); + ret.add(processExpression(new IntIdentifier(1, numImg.getFilename(), numImg.getBeginLine(), numImg.getBeginColumn(), + numImg.getEndLine(), numImg.getEndColumn()), null, hops)); + i++; + } + else + ret.add(processExpression(allExpr[i], null, hops)); + } + return ret; + } + + private ArrayList<Hop> getALHopsForConvOp(Hop first, BuiltinFunctionExpression source, 
int skip, HashMap<String, Hop> hops) throws ParseException { + ArrayList<Hop> ret = new ArrayList<Hop>(); + ret.add(first); + Expression[] allExpr = source.getAllExpr(); + for(int i = skip; i < allExpr.length; i++) { + ret.add(processExpression(allExpr[i], null, hops)); + } + return ret; + } public void setIdentifierParams(Hop h, Identifier id) { if( id.getDim1()>= 0 ) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/DataExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DataExpression.java b/src/main/java/org/apache/sysml/parser/DataExpression.java index be2f569..c33f965 100644 --- a/src/main/java/org/apache/sysml/parser/DataExpression.java +++ b/src/main/java/org/apache/sysml/parser/DataExpression.java @@ -130,7 +130,8 @@ public class DataExpression extends DataIdentifier public void setCheckMetadata(boolean checkMetadata) { _checkMetadata = checkMetadata; } - + + public static DataExpression getDataExpression(String functionName, ArrayList<ParameterExpression> passedParamExprs, String filename, int blp, int bcp, int elp, int ecp) throws LanguageException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index d6347a1..bf6b5b9 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -91,6 +91,8 @@ public abstract class Expression CUMSUM, DIAG, EIGEN, + CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA, + MAX_POOL, AVG_POOL, MAX_POOL_BACKWARD, EXP, FLOOR, INTERQUANTILE, 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/ExpressionList.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ExpressionList.java b/src/main/java/org/apache/sysml/parser/ExpressionList.java new file mode 100644 index 0000000..653293c --- /dev/null +++ b/src/main/java/org/apache/sysml/parser/ExpressionList.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.parser; + +import java.util.ArrayList; + +public class ExpressionList extends Expression { + + protected String _name; + protected ArrayList<Expression> _value; + + public ExpressionList(ArrayList<Expression> value) { + this._name = "tmp"; + this._value = value; + } + + public String getName() { + return _name; + } + public void setName(String _name) { + this._name = _name; + } + public ArrayList<Expression> getValue() { + return _value; + } + public void setValue(ArrayList<Expression> _value) { + this._value = _value; + } + + @Override + public Expression rewriteExpression(String prefix) throws LanguageException { + throw new LanguageException("ExpressionList should not be exposed beyond parser layer."); + } + @Override + public VariableSet variablesRead() { + VariableSet result = new VariableSet(); + for( Expression expr : _value ) { + result.addVariables ( expr.variablesRead() ); + } + return result; + } + @Override + public VariableSet variablesUpdated() { + VariableSet result = new VariableSet(); + for( Expression expr : _value ) { + result.addVariables ( expr.variablesUpdated() ); + } + return result; + } + + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/dml/Dml.g4 ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 index 64598ff..5fc63b8 100644 --- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 +++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4 @@ -161,7 +161,7 @@ expression returns [ org.apache.sysml.parser.common.ExpressionInfo info ] | '(' left=expression ')' # AtomicExpression // Should you allow indexed expression here ? 
- // | '[' targetList+=expression (',' targetList+=expression)* ']' # MultiIdExpression + | '[' targetList+=expression (',' targetList+=expression)* ']' # MultiIdExpression // | BOOLEAN # ConstBooleanIdExpression | 'TRUE' # ConstTrueExpression http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java index 6bbaf10..321ba7a 100644 --- a/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java +++ b/src/main/java/org/apache/sysml/parser/dml/DmlSyntacticValidator.java @@ -40,6 +40,7 @@ import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.Expression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.parser.ExpressionList; import org.apache.sysml.parser.ExternalFunctionStatement; import org.apache.sysml.parser.ForStatement; import org.apache.sysml.parser.FunctionCallIdentifier; @@ -94,6 +95,7 @@ import org.apache.sysml.parser.dml.DmlParser.MatrixMulExpressionContext; import org.apache.sysml.parser.dml.DmlParser.Ml_typeContext; import org.apache.sysml.parser.dml.DmlParser.ModIntDivExpressionContext; import org.apache.sysml.parser.dml.DmlParser.MultDivExpressionContext; +import org.apache.sysml.parser.dml.DmlParser.MultiIdExpressionContext; import org.apache.sysml.parser.dml.DmlParser.ParForStatementContext; import org.apache.sysml.parser.dml.DmlParser.ParameterizedExpressionContext; import org.apache.sysml.parser.dml.DmlParser.PathStatementContext; @@ -481,9 +483,9 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D String namespace = fnNames[0]; String functionName = fnNames[1]; ArrayList<ParameterExpression> 
paramExpression = getParameterExpressionList(ctx.paramExprs); - + castAsScalarDeprecationCheck(functionName, ctx); - + boolean hasLHS = ctx.targetList != null; functionCallAssignmentStatementHelper(ctx, printStatements, outputStatements, hasLHS ? ctx.targetList.dataInfo.expr : null, ctx.info, ctx.name, hasLHS ? ctx.targetList.start : null, namespace, functionName, paramExpression, hasLHS); @@ -1045,4 +1047,16 @@ public class DmlSyntacticValidator extends CommonSyntacticValidator implements D @Override public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) {} + @Override + public void enterMultiIdExpression(MultiIdExpressionContext ctx) { } + + @Override + public void exitMultiIdExpression(MultiIdExpressionContext ctx) { + ArrayList<Expression> values = new ArrayList<Expression>(); + for(ExpressionContext elem : ctx.targetList) { + values.add(elem.info.expr); + } + ctx.info.expr = new ExpressionList(values); + } + } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java index 716e528..51c4de5 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java @@ -47,7 +47,6 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.runtime.util.LocalFileUtils; import org.apache.sysml.runtime.util.MapReduceTool; - /** * Each object of this class is a cache envelope for some large piece of data * called "cache block". For example, the body of a matrix can be the cache block. 
@@ -643,6 +642,12 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 		}
 	}
 	
+	/**
+	 * Hook called right before the in-memory cache block (_data) is dropped.
+	 * No-op by default; subclasses may override it to recycle the block's buffers.
+	 */
+	protected void clearReusableData() {}
+	
 	/**
 	 * Sets the cache block reference to <code>null</code>, abandons the old block.
 	 * Makes the "envelope" empty. Run it to finalize the object (otherwise the
@@ -671,6 +676,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 		freeEvictedBlob();
 		
 		// clear the in-memory data
+		clearReusableData();
 		_data = null;
 		clearCache();
 		
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
index 639653a..ca6f1c7 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
@@ -40,6 +40,7 @@ import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
 import org.apache.sysml.runtime.matrix.MetaData;
 import org.apache.sysml.runtime.matrix.data.FileFormatProperties;
 import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.NumItemsByEachReducerMetaData;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
@@ -205,6 +206,29 @@ public class MatrixObject extends CacheableData<MatrixBlock>
 		return ((double)mc.getNonZeros())/mc.getRows()/mc.getCols();
 	}
 	
+	/**
+	 * Offers the dense array of this matrix to LibMatrixDNN.cacheReuseableData before
+	 * the in-memory data is dropped (only when DMLScript.REUSE_NONZEROED_OUTPUT is set).
+	 */
+	@Override
+	protected void clearReusableData() {
+		if(DMLScript.REUSE_NONZEROED_OUTPUT) {
+			if(_data == null) {
+				// NOTE(review): assumes getCache() restores _data if the block is still cached
+				getCache();
+			}
+			if(_data != null && 
+					// Neither a row vector nor a column vector
+					_data.getNumRows() != 1 && _data.getNumColumns() != 1) {
+				double[] arr = ((MatrixBlock)_data).getDenseBlock();
+				// getDenseBlock() may return null (e.g. for a sparse block): nothing to hand over
+				if(arr != null) {
+					LibMatrixDNN.cacheReuseableData(arr);
+				}
+			}
+		}
+	}
+	
 	public String toString()
 	{ 
 		StringBuilder str = new StringBuilder();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 48ef9a8..23ca5ef 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -37,6 +37,7 @@ import org.apache.sysml.runtime.instructions.cp.BuiltinBinaryCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.BuiltinUnaryCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.CPInstruction;
 import org.apache.sysml.runtime.instructions.cp.CentralMomentCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ConvolutionCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.CovarianceCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.DataGenCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.DataPartitionCPInstruction;
@@ -211,7 +212,15 @@ public class CPInstructionParser extends InstructionParser
 		String2CPInstructionType.put( "rdiag" , CPINSTRUCTION_TYPE.Reorg);
 		String2CPInstructionType.put( "rshape" , CPINSTRUCTION_TYPE.MatrixReshape);
 		String2CPInstructionType.put( "rsort" , CPINSTRUCTION_TYPE.Reorg);
-		
+
+		// Opcodes related to convolutions
+		String2CPInstructionType.put( "im2col" , CPINSTRUCTION_TYPE.Convolution);
+		String2CPInstructionType.put( "reshape_col" , CPINSTRUCTION_TYPE.Convolution);
+		String2CPInstructionType.put( "rotate180" , CPINSTRUCTION_TYPE.Convolution);
+		String2CPInstructionType.put( "col2im" , CPINSTRUCTION_TYPE.Convolution);
+		String2CPInstructionType.put( "maxpooling" , CPINSTRUCTION_TYPE.Convolution);
+		String2CPInstructionType.put( "maxpooling_backward" , CPINSTRUCTION_TYPE.Convolution);
+
 		// Quaternary instruction opcodes
 		String2CPInstructionType.put( "wsloss" , CPINSTRUCTION_TYPE.Quaternary);
 		String2CPInstructionType.put( "wsigmoid", CPINSTRUCTION_TYPE.Quaternary);
@@ -318,6 +327,9 @@ public class CPInstructionParser extends InstructionParser
 			case Reorg:
 				return ReorgCPInstruction.parseInstruction(str);
 				
+			case Convolution:
+				return ConvolutionCPInstruction.parseInstruction(str);
+				
 			case UaggOuterChain:
 				return UaggOuterChainCPInstruction.parseInstruction(str);
 				
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
index 4dbd9a3..38e92bb 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
@@ -30,7 +30,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
 
 public abstract class CPInstruction extends Instruction 
 {
-	public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, StringInit, CentralMoment, Covariance, UaggOuterChain }; 
+	public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution }; 
 	
 	protected CPINSTRUCTION_TYPE _cptype;
 	protected Operator _optr;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c334c2c8/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
new file mode 100644
index 0000000..24f24dc
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.cp;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.ConvolutionParameters;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysml.runtime.util.ConvolutionUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * CP instruction for the convolution and pooling opcodes (im2col, reshape_col,
+ * rotate180, col2im, maxpooling, maxpooling_backward).
+ * <p>
+ * Besides its matrix input(s), every instruction carries scalar operands for the
+ * stride [stride_h, stride_w], padding [pad_h, pad_w], input shape [N, C, H, W]
+ * and filter shape [K, C, R, S]; they are resolved against the execution context
+ * each time the instruction is processed.
+ */
+public class ConvolutionCPInstruction extends UnaryCPInstruction {
+
+	private CPOperand _in2; // used for pooling backward
+	private ArrayList<CPOperand> _input_shape;
+	private ArrayList<CPOperand> _filter_shape;
+	private ArrayList<CPOperand> _stride;
+	private ArrayList<CPOperand> _padding;
+	// whether the output of the current execution was allocated via the reuse path (see getDenseOutputBlock)
+	private boolean _reuseNonZeroedOutput = false;
+	private int _numThreads = -1;
+
+	/**
+	 * Creates a single-input convolution instruction (no second matrix input); the
+	 * parameters have the same meaning as in the two-input constructor.
+	 */
+	public ConvolutionCPInstruction(CPOperand in, CPOperand out, String opcode,
+			String istr, ArrayList<CPOperand> stride,
+			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
+			ArrayList<CPOperand> filter_shape, int numThreads) {
+		this(in, null, out, opcode, istr, stride, padding, input_shape,
+				filter_shape, numThreads);
+	}
+
+	/**
+	 * Creates a convolution instruction with an optional second matrix input.
+	 *
+	 * @param in first matrix input
+	 * @param in2 second matrix input (dout for maxpooling_backward), or null
+	 * @param out matrix output
+	 * @param opcode instruction opcode
+	 * @param istr original instruction string
+	 * @param stride scalar operands [stride_h, stride_w]
+	 * @param padding scalar operands [pad_h, pad_w]
+	 * @param input_shape scalar operands [N, C, H, W]
+	 * @param filter_shape scalar operands [K, C, R, S]
+	 * @param numThreads number of threads, forwarded via ConvolutionParameters
+	 */
+	public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode,
+			String istr, ArrayList<CPOperand> stride,
+			ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape,
+			ArrayList<CPOperand> filter_shape, int numThreads) {
+		super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out,
+				opcode, istr);
+		_in2 = in2;
+		_cptype = CPINSTRUCTION_TYPE.Convolution;
+		_stride = stride;
+		_padding = padding;
+		_input_shape = input_shape;
+		_filter_shape = filter_shape;
+		_numThreads = numThreads;
+	}
+
+	/**
+	 * Parses a convolution instruction string.
+	 * <p>
+	 * Single-input opcodes carry 15 fields: input, stride (2), padding (2),
+	 * input_shape (4), filter_shape (4), output, k. The two-input opcodes
+	 * pooling_backward_reshape and maxpooling_backward carry 16 fields: input, dout,
+	 * and then the same layout.
+	 *
+	 * @param str instruction string
+	 * @return the parsed instruction
+	 * @throws DMLRuntimeException if the opcode is not a convolution opcode
+	 */
+	public static ConvolutionCPInstruction parseInstruction(String str)
+			throws DMLRuntimeException {
+		CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+		CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		if (opcode.equalsIgnoreCase("reshape_col")
+				|| opcode.equalsIgnoreCase("rotate180")
+				|| opcode.equalsIgnoreCase("im2col")
+				|| opcode.equalsIgnoreCase("col2im")
+				|| opcode.equalsIgnoreCase("pooling_pre_reshape")
+				|| opcode.equalsIgnoreCase("pooling_post_reshape")
+				|| opcode.equalsIgnoreCase("maxpooling")) {
+			InstructionUtils.checkNumFields(parts, 15);
+			// stride1, stride2, padding1, padding2
+			// input_shape1, input_shape2, input_shape3, input_shape4,
+			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
+			in.split(parts[1]);
+			out.split(parts[14]);
+
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[2]));
+			stride.add(new CPOperand(parts[3]));
+			padding.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			input_shape.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			filter_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+			int k = Integer.parseInt(parts[15]);
+
+			return new ConvolutionCPInstruction(in, out, opcode, str, stride,
+					padding, input_shape, filter_shape, k);
+		}
+		else if (opcode.equalsIgnoreCase("pooling_backward_reshape")
+				|| opcode.equalsIgnoreCase("maxpooling_backward")) {
+			InstructionUtils.checkNumFields(parts, 16);
+			// dout, stride1, stride2, padding1, padding2
+			// input_shape1, input_shape2, input_shape3, input_shape4,
+			// filter_shape1, filter_shape2, filter_shape3, filter_shape4, k
+			in.split(parts[1]);
+			CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+			in2.split(parts[2]);
+			out.split(parts[15]);
+
+			ArrayList<CPOperand> stride = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> padding = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> input_shape = new ArrayList<CPOperand>();
+			ArrayList<CPOperand> filter_shape = new ArrayList<CPOperand>();
+			stride.add(new CPOperand(parts[3]));
+			stride.add(new CPOperand(parts[4]));
+			padding.add(new CPOperand(parts[5]));
+			padding.add(new CPOperand(parts[6]));
+			input_shape.add(new CPOperand(parts[7]));
+			input_shape.add(new CPOperand(parts[8]));
+			input_shape.add(new CPOperand(parts[9]));
+			input_shape.add(new CPOperand(parts[10]));
+			filter_shape.add(new CPOperand(parts[11]));
+			filter_shape.add(new CPOperand(parts[12]));
+			filter_shape.add(new CPOperand(parts[13]));
+			filter_shape.add(new CPOperand(parts[14]));
+			int k = Integer.parseInt(parts[16]);
+
+			return new ConvolutionCPInstruction(in, in2, out, opcode, str, stride,
+					padding, input_shape, filter_shape, k);
+		}
+		else {
+			throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
+		}
+	}
+
+	/**
+	 * Resolves the scalar operand at position index of aL and narrows it to int.
+	 */
+	private int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL,
+			int index) throws DMLRuntimeException {
+		return (int) ec.getScalarInput(aL.get(index).getName(),
+				aL.get(index).getValueType(), aL.get(index).isLiteral())
+				.getLongValue();
+	}
+
+	// TODO: optimize "Sparse operations" once we are happy with the performance of single node Lenet script on dense MNIST dataset
+	/**
+	 * Resolves the scalar parameters, validates them (im2col, reshape_col, rotate180
+	 * and col2im only), allocates a dense output block and delegates to the matching
+	 * LibMatrixDNN kernel.
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec)
+			throws DMLRuntimeException {
+		// acquire inputs
+		MatrixBlock outputBlock = null;
+		MatrixBlock matBlock = ec.getMatrixInput(input1.getName());
+		int pad_h = getScalarInput(ec, _padding, 0);
+		int pad_w = getScalarInput(ec, _padding, 1);
+		int stride_h = getScalarInput(ec, _stride, 0);
+		int stride_w = getScalarInput(ec, _stride, 1);
+
+		int N = getScalarInput(ec, _input_shape, 0);
+		int C = getScalarInput(ec, _input_shape, 1);
+		int H = getScalarInput(ec, _input_shape, 2);
+		int W = getScalarInput(ec, _input_shape, 3);
+
+		int K = getScalarInput(ec, _filter_shape, 0);
+
+		int R = getScalarInput(ec, _filter_shape, 2);
+		int S = getScalarInput(ec, _filter_shape, 3);
+		int P = (int) ConvolutionUtils.getP(H, R, stride_h, pad_h);
+		int Q = (int) ConvolutionUtils.getQ(W, S, stride_w, pad_w);
+
+		ConvolutionParameters params = new ConvolutionParameters(N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w, _numThreads);
+
+		if (instOpcode.equalsIgnoreCase("im2col")) {
+			checkHeightWidth(ec, params);
+			checkInputDimensionForIm2col(matBlock, params);
+			outputBlock = getDenseOutputBlock(ec, C * R * S, N * P * Q, true);
+			params.setReuseNonZeroedOutput(_reuseNonZeroedOutput);
+			LibMatrixDNN.im2col(matBlock, outputBlock, params);
+		}
+		else if (instOpcode.equalsIgnoreCase("reshape_col")) {
+			checkHeightWidth(ec, params);
+			// Is eligible for REUSE_NONZEROED_OUTPUT but cannot guarantee that previous output has been rmvar-ed
+			// without somewhat expensive HashMap checks
+			outputBlock = getDenseOutputBlock(ec, N, K * P * Q, true);
+			params.setReuseNonZeroedOutput(_reuseNonZeroedOutput);
+			LibMatrixDNN.reshape_col(matBlock, outputBlock, params);
+		}
+		else if (instOpcode.equalsIgnoreCase("rotate180")) {
+			checkHeightWidth(ec, params);
+			// Is eligible for REUSE_NONZEROED_OUTPUT and always an intermediate instruction
+			outputBlock = getDenseOutputBlock(ec, N * P * Q, K, true);
+			params.setReuseNonZeroedOutput(_reuseNonZeroedOutput);
+			LibMatrixDNN.rotate180(matBlock, outputBlock, params);
+		}
+		else if (instOpcode.equalsIgnoreCase("col2im")) {
+			checkHeightWidth(ec, params);
+			checkInputDimensionForCol2im(matBlock, params);
+			// needs to be zeroed-out
+			outputBlock = getDenseOutputBlock(ec, N, C * H * W, false);
+			params.setReuseNonZeroedOutput(_reuseNonZeroedOutput);
+			LibMatrixDNN.col2im(matBlock, outputBlock, params);
+		}
+		else if (instOpcode.equalsIgnoreCase("maxpooling")) {
+			// Is eligible for REUSE_NONZEROED_OUTPUT but cannot guarantee that previous output has been rmvar-ed
+			// without somewhat expensive HashMap checks
+			outputBlock = getDenseOutputBlock(ec, N, C*P*Q, true);
+			params.setReuseNonZeroedOutput(_reuseNonZeroedOutput);
+			LibMatrixDNN.maxpooling(matBlock, outputBlock, params);
+		}
+		else if (instOpcode.equalsIgnoreCase("maxpooling_backward")) {
+			MatrixBlock dout = ec.getMatrixInput(_in2.getName());
+			// Is eligible for REUSE_NONZEROED_OUTPUT but cannot guarantee that previous output has been rmvar-ed
+			// without somewhat expensive HashMap checks
+			outputBlock = getDenseOutputBlock(ec, N, C*H*W, false);
+			params.setReuseNonZeroedOutput(_reuseNonZeroedOutput);
+			LibMatrixDNN.maxpooling_backward(matBlock, dout, outputBlock, params);
+			ec.releaseMatrixInput(_in2.getName());
+		}
+		else {
+			throw new DMLRuntimeException("Unsupported op code " + instOpcode);
+		}
+
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName());
+		ec.setMatrixOutput(getOutputVariableName(), outputBlock);
+	}
+
+	/**
+	 * Allocates a dense numRows x numCols output block.
+	 * <p>
+	 * If reuseNonZeroedOutput1 is true and DMLScript.REUSE_NONZEROED_OUTPUT is enabled,
+	 * the block is allocated via allocateDenseBlock(true, false), presumably without
+	 * zeroing, so only callers that overwrite every cell should request it (col2im
+	 * does not, since its output "needs to be zeroed-out"). The chosen path is stored
+	 * in _reuseNonZeroedOutput and the output's non-zero count is set to -1.
+	 */
+	private MatrixBlock getDenseOutputBlock(ExecutionContext ec, int numRows, int numCols, boolean reuseNonZeroedOutput1) throws DMLRuntimeException {
+		long start = -1;
+		if(DMLScript.STATISTICS)
+			start = System.nanoTime();
+
+		// widen before multiplying so the nnz estimate cannot overflow int
+		MatrixBlock outputBlock = new MatrixBlock(numRows, numCols, (long) numRows * numCols);
+		_reuseNonZeroedOutput = false;
+		if(reuseNonZeroedOutput1 && DMLScript.REUSE_NONZEROED_OUTPUT) {
+			_reuseNonZeroedOutput = true;
+			outputBlock.allocateDenseBlock(true, !_reuseNonZeroedOutput);
+		}
+		else {
+			outputBlock.allocateDenseBlock();
+		}
+		outputBlock.setNonZeros(-1);
+
+		if(DMLScript.STATISTICS)
+			Statistics.incrementAllocationTime(System.nanoTime()-start, false);
+		return outputBlock;
+	}
+
+	/**
+	 * Validates the parameters used by im2col, reshape_col, rotate180 and col2im:
+	 * the filter's channel count (filter_shape[1]) must equal C, the filter and stride
+	 * must tile the padded input exactly in both dimensions, and the output height P
+	 * and width Q must be positive.
+	 */
+	private void checkHeightWidth(ExecutionContext ec, ConvolutionParameters params) throws DMLRuntimeException {
+		int numChannelsInFilter = getScalarInput(ec, _filter_shape, 1);
+
+		if (numChannelsInFilter != params.C) { 
+			throw new DMLRuntimeException("The number of channels of input and filter should match");
+		}
+		if((params.W + 2 * params.pad_w - params.S) % params.stride_w != 0) {
+			throw new DMLRuntimeException("The width does not work (Hint: (W + 2 * pad_w - S) % stride_w should be 0 [ ==> (" + params.W + "+" + " 2*" + params.pad_w + "-" + params.S + ") % " + params.stride_w + "!= 0] ");
+		}
+		if((params.H + 2 * params.pad_h - params.R) % params.stride_h != 0) {
+			throw new DMLRuntimeException("The height does not work (Hint: (H + 2 * pad_h - R) % stride_h should be 0 [ ==> (" + params.H + "+" + " 2*" + params.pad_h + "-" + params.R + ") % " + params.stride_h + "!= 0] ");
+		}
+		// check the output height P and width Q held by params
+		if(params.P <= 0) {
+			throw new DMLRuntimeException("Height of output patch should be greater than zero");
+		}
+		if(params.Q <= 0) {
+			throw new DMLRuntimeException("Width of output patch should be greater than zero");
+		}
+	}
+
+	/**
+	 * Checks that the im2col input is N x (C*H*W).
+	 */
+	private void checkInputDimensionForIm2col(MatrixBlock matBlock, ConvolutionParameters params) throws DMLRuntimeException {
+		if((params.N != matBlock.getNumRows() || params.C*params.H*params.W != matBlock.getNumColumns())) {
+			throw new DMLRuntimeException("Incorrect input shape in conv2d");
+		}
+	}
+
+	/**
+	 * Checks that the col2im input is (C*R*S) x (N*P*Q).
+	 */
+	private void checkInputDimensionForCol2im(MatrixBlock matBlock, ConvolutionParameters params) throws DMLRuntimeException {
+		if((params.C*params.R*params.S != matBlock.getNumRows() || params.N*params.P*params.Q != matBlock.getNumColumns())) {
+			throw new DMLRuntimeException("Incorrect input shape in conv2d_backward_data");
+		}
+	}
+}
