Repository: systemml
Updated Branches:
  refs/heads/master 2ca2d8aa7 -> 5adb330de


[SYSTEMML-540] Reduce the number of unknowns in ConvolutionOp

- This commit reduces the number of unknowns during dynamic recompilation by
inferring the input height/width of a ConvolutionOp from the output
height/width of its parent op.
- Additionally, for developer debugging, I have guarded the functionality
with the flag INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP and have added
sufficient documentation to explain how these dimensions are inferred.

Closes #685.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5adb330d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5adb330d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5adb330d

Branch: refs/heads/master
Commit: 5adb330deffa5479475338316bf47193d0c31da4
Parents: 2ca2d8a
Author: Niketan Pansare <[email protected]>
Authored: Mon Oct 16 15:44:37 2017 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Mon Oct 16 15:45:39 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 170 ++++++++++++++++---
 1 file changed, 144 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5adb330d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index e732fb8..e4ed32b 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -32,11 +32,21 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
-
 import java.util.ArrayList;
 
 public class ConvolutionOp extends Hop  implements MultiThreadedHop
 {      
+       // 
-------------------------------------------------------------------------
+       // These flags allow us to compile plans with fewer unknowns and also 
serve as a step toward future tensorblock integration.
+       // By default, these flags are turned on.
+       
+       // When this flag is turned on, we attempt to check the parent 
convolution hop for unknown dimensions.
+       // For example: in case of conv -> maxpool, the input 
channel/height/width of maxpool will match output channel/height/width of conv.
+       private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = 
true;
+       // This guards us from cases where the user provides incorrect C,H,W 
parameters.
+       private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = 
true;
+       // 
-------------------------------------------------------------------------
+       
        private Hop.ConvOp op;
 
        private int _maxNumThreads = -1; //-1 for unlimited
@@ -475,17 +485,21 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
        // input_shape1, input_shape2, input_shape3, input_shape4, 
        // filter_shape1, filter_shape2, filter_shape3, filter_shape4
        ConvolutionParameters parseInput() throws DMLRuntimeException {
+               
+               Hop imageHeightHop = null; Hop filterHeightHop = null;
                if(op == ConvOp.MAX_POOLING_BACKWARD 
                                || op == ConvOp.DIRECT_CONV2D 
                                || op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
                                || op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
+                       imageHeightHop = getInput().get(8);
+                       filterHeightHop = getInput().get(12);
                        _cachedParams.setIfUnknown(
                                        getInput().get(6),
                                        getInput().get(7), 
-                                       getInput().get(8), 
+                                       imageHeightHop, 
                                        getInput().get(9), 
                                        getInput().get(10), 
-                                       getInput().get(12), 
+                                       filterHeightHop, 
                                        getInput().get(13), 
                                        getInput().get(2), 
                                        getInput().get(3), 
@@ -493,22 +507,127 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                                        getInput().get(5), _maxNumThreads);
                }
                else {
+                       imageHeightHop = getInput().get(7);
+                       filterHeightHop = getInput().get(11);
                        _cachedParams.setIfUnknown(
                                        getInput().get(5),
                                        getInput().get(6), 
-                                       getInput().get(7), 
+                                       imageHeightHop, 
                                        getInput().get(8), 
                                        getInput().get(9), 
-                                       getInput().get(11), 
+                                       filterHeightHop, 
                                        getInput().get(12), 
                                        getInput().get(1), 
                                        getInput().get(2), 
                                        getInput().get(3), 
                                        getInput().get(4), _maxNumThreads);
                }
+               
+               if(INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
+                       boolean isMaxPool = getOp() == ConvOp.MAX_POOLING;
+                       boolean isConv = getOp() == ConvOp.DIRECT_CONV2D;
+                       boolean unknownCHWPQ = _cachedParams.C < 0 || 
_cachedParams.H < 0 || _cachedParams.W < 0 || _cachedParams.P < 0 || 
_cachedParams.Q < 0;
+                       if((isMaxPool || isConv) && unknownCHWPQ) {
+                               // Only infer input shape for convolution and 
maxpool
+                               inferCHWPQFromParentOp();
+                       }
+               }
+               
+               if(imageHeightHop == filterHeightHop && _cachedParams.R < 0 && 
_cachedParams.H > 0) {
+                       // Unknown R, but known H and both are equal
+                       // This happens for one-dimensional conv2d where H=R 
and H can be inferred from the parent hop
+                       _cachedParams.R = _cachedParams.H;
+               }
+               
+               // Compute P and Q if unknown. At script level, they are 
computed using following script:
+               // P = as.integer(floor((H + 2*pad_h - R)/stride_h + 1))
+               // Q = as.integer(floor((W + 2*pad_w - S)/stride_w + 1))
+               if(_cachedParams.P < 0 && _cachedParams.H >= 0 && 
_cachedParams.R >= 0 && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 
0) {
+                       _cachedParams.P = (int) 
org.apache.sysml.runtime.util.ConvolutionUtils.getP(_cachedParams.H, 
_cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
+               }
+               if(_cachedParams.Q < 0 && _cachedParams.W >= 0 && 
_cachedParams.S >= 0 && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 
0) {
+                       _cachedParams.Q = (int) 
org.apache.sysml.runtime.util.ConvolutionUtils.getQ(_cachedParams.W, 
_cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
+               }
+               
                return _cachedParams;
        }
        
+       /**
+        * Utility method to check if the given hop is a BIAS_ADD hop
+        * 
+        * @param hop the given hop
+        * @return true if the given hop is BIAS_ADD
+        */
+       private static boolean isInputBiasAdd(Hop hop) {
+               if(hop instanceof ConvolutionOp && ((ConvolutionOp) 
hop).getOp() == ConvOp.BIAS_ADD) {
+                       return true;
+               }
+               return false;
+       }
+       
+       /**
+        * Utility method to check if the inferred shapes are equal to the 
given shape with a guard for unknown
+        * 
+        * @param dim1 inferred shape
+        * @param dim2 given shape
+        * @param paramType string denoting the parameter for pretty printing 
of the error message
+        * @throws DMLRuntimeException if dim1 != dim2
+        */
+       private void throwExceptionIfNotEqual(int dim1, int dim2, String 
paramType) throws DMLRuntimeException {
+               if(dim1 >= 0 && dim2 >= 0 && dim1 != dim2) {
+                       throw new DMLRuntimeException("Inferred " + paramType + 
" from parent doesn't match with given " + paramType + ":" + dim1 + " != " + 
dim2);
+               }
+       }
+       
+       /**
+        * Gets the values for the parameters C, H, W, P, Q from parent hops
+        * 
+        * @throws DMLRuntimeException if error occurs
+        */
+       private void inferCHWPQFromParentOp() throws DMLRuntimeException {Hop 
tmp = getInput().get(0);
+               while(isInputReLU(tmp) || isInputBiasAdd(tmp)) {
+                       // Skip ReLU and bias_add and go to its parent
+                       tmp = tmp.getInput().get(0);
+               }
+               // Cast tmp as parent
+               ConvolutionOp parentOp = (tmp instanceof ConvolutionOp) ? 
((ConvolutionOp) tmp) : null; 
+               
+               if(parentOp == null)
+                       return;
+               else if(parentOp.getOp() == ConvOp.MAX_POOLING) {
+                       ConvolutionParameters parentParam = 
parentOp.parseInput();
+                       int prevC = _cachedParams.C; int prevH = 
_cachedParams.H; int prevW = _cachedParams.W;
+                       // [C, P, Q] from maxpool becomes [C, H, W] of next op
+                       _cachedParams.C = (_cachedParams.C < 0) ? parentParam.C 
: _cachedParams.C;
+                       _cachedParams.H = (_cachedParams.H < 0) ? parentParam.P 
: _cachedParams.H;
+                       _cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q 
: _cachedParams.W;
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug("Inferring [C,H,W] from maxpool 
parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + 
"," + _cachedParams.H + "," + _cachedParams.W + "]");
+                       }
+                       if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
+                               throwExceptionIfNotEqual(prevC, 
_cachedParams.C, "C");
+                               throwExceptionIfNotEqual(prevH, 
_cachedParams.H, "H");
+                               throwExceptionIfNotEqual(prevW, 
_cachedParams.W, "W");
+                       }
+               }
+               else if(parentOp.getOp() == ConvOp.DIRECT_CONV2D) {
+                       ConvolutionParameters parentParam = 
parentOp.parseInput();
+                       int prevC = _cachedParams.C; int prevH = 
_cachedParams.H; int prevW = _cachedParams.W;
+                       // [K, P, Q] from convolution becomes [C, H, W] of next 
op
+                       _cachedParams.C = (_cachedParams.C < 0) ? parentParam.K 
: _cachedParams.C;
+                       _cachedParams.H = (_cachedParams.H < 0) ? parentParam.P 
: _cachedParams.H;
+                       _cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q 
: _cachedParams.W;
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug("Inferring [C,H,W] from maxpool 
parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + _cachedParams.C + 
"," + _cachedParams.H + "," + _cachedParams.W + "]");
+                       }
+                       if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
+                               throwExceptionIfNotEqual(prevC, 
_cachedParams.C, "C");
+                               throwExceptionIfNotEqual(prevH, 
_cachedParams.H, "H");
+                               throwExceptionIfNotEqual(prevW, 
_cachedParams.W, "W");
+                       }
+               }
+       }
+       
        @Override
        public void refreshSizeInformation()
        {
@@ -620,9 +739,8 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                if(op == ConvOp.BIAS_ADD || op == ConvOp.BIAS_MULTIPLY) {
                        throw new RuntimeException("getDim method should not be 
invoked for bias_add and bias_multiply");
                }
-               ConvolutionParameters params;
                try {
-                       params = parseInput();
+                       parseInput();
                } catch (DMLRuntimeException e) {
                        throw new RuntimeException(e);
                }
@@ -653,49 +771,49 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                
                long ret = -1;
                if(dimString.equals("K") && filter != null) {
-                       ret = getNonNegative(ret, getNonNegative(params.K, 
filter._dim1));
+                       ret = getNonNegative(ret, 
getNonNegative(_cachedParams.K, filter._dim1));
                }
                else if(dimString.equals("CRS") && filter != null) {
-                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(params.C, params.R, params.S), 
filter._dim2));
+                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.R, 
_cachedParams.S), filter._dim2));
                }
                else if(dimString.equals("N") && input != null) {
-                       ret = getNonNegative(ret, getNonNegative(params.N, 
input._dim1));
+                       ret = getNonNegative(ret, 
getNonNegative(_cachedParams.N, input._dim1));
                }
                else if(dimString.equals("CHW") && input != null) {
-                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(params.C, params.H, params.W), input._dim2));
+                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.H, 
_cachedParams.W), input._dim2));
                }
                else if(dimString.equals("N") && dout != null) {
-                       ret = getNonNegative(ret, getNonNegative(params.N, 
dout._dim1));
+                       ret = getNonNegative(ret, 
getNonNegative(_cachedParams.N, dout._dim1));
                }
                else if(dimString.equals("KPQ") && dout != null) {
-                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(params.K, params.P, params.Q), dout._dim2));
+                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(_cachedParams.K, _cachedParams.P, 
_cachedParams.Q), dout._dim2));
                }
                else if(dimString.equals("N") && dout1 != null) {
-                       ret = getNonNegative(ret, getNonNegative(params.N, 
dout1._dim1));
+                       ret = getNonNegative(ret, 
getNonNegative(_cachedParams.N, dout1._dim1));
                }
                else if(dimString.equals("CPQ") && dout1 != null) {
-                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(params.C, params.P, params.Q), dout1._dim2));
+                       ret = getNonNegative(ret, 
getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.P, 
_cachedParams.Q), dout1._dim2));
                }
                else if(dimString.equals("K")) {
-                       ret = getNonNegative(ret, params.K >= 0 ? params.K : 
-1);
+                       ret = getNonNegative(ret, _cachedParams.K >= 0 ? 
_cachedParams.K : -1);
                }
                else if(dimString.equals("CRS")) {
-                       ret = getNonNegative(ret, nonNegativeMultiply(params.C, 
params.R, params.S));
+                       ret = getNonNegative(ret, 
nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S));
                }
                else if(dimString.equals("N")) {
-                       ret = getNonNegative(ret, params.N >= 0 ? params.N : 
-1);
+                       ret = getNonNegative(ret, _cachedParams.N >= 0 ? 
_cachedParams.N : -1);
                }
                else if(dimString.equals("CHW")) {
-                       ret = getNonNegative(ret, nonNegativeMultiply(params.C, 
params.H, params.W));
+                       ret = getNonNegative(ret, 
nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W));
                }
                else if(dimString.equals("KPQ")) {
-                       ret = getNonNegative(ret, nonNegativeMultiply(params.K, 
params.P, params.Q));
+                       ret = getNonNegative(ret, 
nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q));
                }
                else if(dimString.equals("PQ")) {
-                       ret = getNonNegative(ret, nonNegativeMultiply(params.P, 
params.Q));
+                       ret = getNonNegative(ret, 
nonNegativeMultiply(_cachedParams.P, _cachedParams.Q));
                }
                else if(dimString.equals("CPQ")) {
-                       ret = getNonNegative(ret, nonNegativeMultiply(params.C, 
params.P, params.Q));
+                       ret = getNonNegative(ret, 
nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q));
                }
                else {
                        throw new RuntimeException("Unsupported dimension:" + 
dimString + " for operator " + getOp().name());
@@ -703,10 +821,10 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                
                if(LOG.isDebugEnabled() && ret < 0) {
                        LOG.debug("Unknown dimension " + dimString + " for 
ConvolutionOp:" + op.name() + 
-                                       " img_dim=[" + params.N + " " + 
params.C + " " + params.H + " " + params.W + "]" +
-                                       " filter_dim=[" + params.K + " " + 
params.C + " " + params.H + " " + params.W + "]" + 
-                                       " output_feature_map=[" + params.P + " 
" + params.Q + "] stride=[" + params.stride_h + " " + params.stride_w + "]" +
-                                       " pad=[" + params.pad_h + " " + 
params.pad_w + "]");
+                                       " img_dim=[" + _cachedParams.N + " " + 
_cachedParams.C + " " + _cachedParams.H + " " + _cachedParams.W + "]" +
+                                       " filter_dim=[" + _cachedParams.K + " " 
+ _cachedParams.C + " " + _cachedParams.R + " " + _cachedParams.S + "]" + 
+                                       " output_feature_map=[" + 
_cachedParams.P + " " + _cachedParams.Q + "] stride=[" + _cachedParams.stride_h 
+ " " + _cachedParams.stride_w + "]" +
+                                       " pad=[" + _cachedParams.pad_h + " " + 
_cachedParams.pad_w + "]");
                }
                return ret;
        }

Reply via email to