[SYSTEMML-540] Reduce the number of unknowns in ConvolutionOp

- This commit reduces the number of unknowns during dynamic recompilation by inferring the input channel/height/width of a ConvolutionOp from its parent's output channel/height/width.
- Additionally, for developer debugging, I have guarded the functionality with the flag INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP and added documentation explaining how these dimensions are inferred.
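To make the inference rule concrete, here is a minimal standalone sketch (an illustration of the rule stated above, not SystemML code; the class and method names are hypothetical): a conv2d produces output dimensions [N, K, P, Q], and a downstream maxpool with unknown input dimensions can take C = K, H = P, W = Q.

    // Hypothetical illustration of the shape-inference rule: unknown (< 0)
    // maxpool input dims [C, H, W] are filled in from the parent conv2d's
    // output dims [K, P, Q].
    public class ShapeInferenceSketch {
        static long fill(long given, long fromParent) {
            return given < 0 ? fromParent : given;
        }
        public static void main(String[] args) {
            long K = 32, P = 26, Q = 26; // parent conv2d output [K, P, Q]
            long C = -1, H = -1, W = -1; // maxpool input, all unknown
            System.out.printf("C=%d H=%d W=%d%n",
                fill(C, K), fill(H, P), fill(W, Q)); // prints C=32 H=26 W=26
        }
    }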
Closes #685.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5adb330d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5adb330d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5adb330d

Branch: refs/heads/master
Commit: 5adb330deffa5479475338316bf47193d0c31da4
Parents: 2ca2d8a
Author: Niketan Pansare <[email protected]>
Authored: Mon Oct 16 15:44:37 2017 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Mon Oct 16 15:45:39 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java | 170 ++++++++++++++++---
 1 file changed, 144 insertions(+), 26 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/systemml/blob/5adb330d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index e732fb8..e4ed32b 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -32,11 +32,21 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
-
 import java.util.ArrayList;
 
 public class ConvolutionOp extends Hop implements MultiThreadedHop
 {
+    // -------------------------------------------------------------------------
+    // These flags allow us to compile plans with fewer unknowns and also serve as a stepping stone for future tensorblock integration.
+    // By default, these flags are turned on.
+
+    // When this flag is turned on, we attempt to infer unknown dimensions by checking the parent convolution hop.
+    // For example: in the case of conv -> maxpool, the input channel/height/width of maxpool will match the output channel/height/width of conv.
+    private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;
+    // This guards us against cases where the user provides incorrect C, H, W parameters.
+    private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
+    // -------------------------------------------------------------------------
+
     private Hop.ConvOp op;
 
     private int _maxNumThreads = -1; //-1 for unlimited
@@ -475,17 +485,21 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
     // input_shape1, input_shape2, input_shape3, input_shape4,
     // filter_shape1, filter_shape2, filter_shape3, filter_shape4
     ConvolutionParameters parseInput() throws DMLRuntimeException {
+
+        Hop imageHeightHop = null; Hop filterHeightHop = null;
         if(op == ConvOp.MAX_POOLING_BACKWARD
                 || op == ConvOp.DIRECT_CONV2D
                 || op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
                 || op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
+            imageHeightHop = getInput().get(8);
+            filterHeightHop = getInput().get(12);
             _cachedParams.setIfUnknown(
                     getInput().get(6),
                     getInput().get(7),
-                    getInput().get(8),
+                    imageHeightHop,
                     getInput().get(9),
                     getInput().get(10),
-                    getInput().get(12),
+                    filterHeightHop,
                     getInput().get(13),
                     getInput().get(2),
                     getInput().get(3),
@@ -493,22 +507,127 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
                     getInput().get(5), _maxNumThreads);
         }
         else {
+            imageHeightHop = getInput().get(7);
+            filterHeightHop = getInput().get(11);
             _cachedParams.setIfUnknown(
                     getInput().get(5),
                     getInput().get(6),
-                    getInput().get(7),
+                    imageHeightHop,
                     getInput().get(8),
                     getInput().get(9),
-                    getInput().get(11),
+                    filterHeightHop,
                     getInput().get(12),
                     getInput().get(1),
                     getInput().get(2),
                     getInput().get(3),
                     getInput().get(4), _maxNumThreads);
         }
+
+        if(INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP) {
+            boolean isMaxPool = getOp() == ConvOp.MAX_POOLING;
+            boolean isConv = getOp() == ConvOp.DIRECT_CONV2D;
+            boolean unknownCHWPQ = _cachedParams.C < 0 || _cachedParams.H < 0 || _cachedParams.W < 0 || _cachedParams.P < 0 || _cachedParams.Q < 0;
+            if((isMaxPool || isConv) && unknownCHWPQ) {
+                // Only infer the input shape for convolution and maxpool
+                inferCHWPQFromParentOp();
+            }
+        }
+
+        if(imageHeightHop == filterHeightHop && _cachedParams.R < 0 && _cachedParams.H > 0) {
+            // R is unknown, H is known, and both come from the same hop.
+            // This happens for one-dimensional conv2d where H=R, so R can be inferred from H.
+            _cachedParams.R = _cachedParams.H;
+        }
+
+        // Compute P and Q if unknown. At the script level, they are computed as follows:
+        // P = as.integer(floor((H + 2*pad_h - R)/stride_h + 1))
+        // Q = as.integer(floor((W + 2*pad_w - S)/stride_w + 1))
+        if(_cachedParams.P < 0 && _cachedParams.H >= 0 && _cachedParams.R >= 0 && _cachedParams.stride_h >= 0 && _cachedParams.pad_h >= 0) {
+            _cachedParams.P = (int) org.apache.sysml.runtime.util.ConvolutionUtils.getP(_cachedParams.H, _cachedParams.R, _cachedParams.stride_h, _cachedParams.pad_h);
+        }
+        if(_cachedParams.Q < 0 && _cachedParams.W >= 0 && _cachedParams.S >= 0 && _cachedParams.stride_w >= 0 && _cachedParams.pad_w >= 0) {
+            _cachedParams.Q = (int) org.apache.sysml.runtime.util.ConvolutionUtils.getQ(_cachedParams.W, _cachedParams.S, _cachedParams.stride_w, _cachedParams.pad_w);
+        }
+
         return _cachedParams;
     }
 
+    /**
+     * Utility method to check if the given hop is a BIAS_ADD hop
+     *
+     * @param hop the given hop
+     * @return true if the given hop is BIAS_ADD
+     */
+    private static boolean isInputBiasAdd(Hop hop) {
+        if(hop instanceof ConvolutionOp && ((ConvolutionOp) hop).getOp() == ConvOp.BIAS_ADD) {
+            return true;
+        }
+        return false;
+    }
+
+    /**
+     * Utility method to check whether the inferred shape equals the given shape, with a guard for unknowns
+     *
+     * @param dim1 inferred shape
+     * @param dim2 given shape
+     * @param paramType string denoting the parameter, for pretty printing of the error message
+     * @throws DMLRuntimeException if dim1 != dim2
+     */
+    private void throwExceptionIfNotEqual(int dim1, int dim2, String paramType) throws DMLRuntimeException {
+        if(dim1 >= 0 && dim2 >= 0 && dim1 != dim2) {
+            throw new DMLRuntimeException("Inferred " + paramType + " from parent doesn't match the given " + paramType + ": " + dim1 + " != " + dim2);
+        }
+    }
+
+    /**
+     * Infers the values of the parameters C, H, W (and hence P, Q) from the parent hops
+     *
+     * @throws DMLRuntimeException if error occurs
+     */
+    private void inferCHWPQFromParentOp() throws DMLRuntimeException {
+        Hop tmp = getInput().get(0);
+        while(isInputReLU(tmp) || isInputBiasAdd(tmp)) {
+            // Skip ReLU and bias_add and go to their parent
+            tmp = tmp.getInput().get(0);
+        }
+        // Cast tmp as the parent convolution op, if possible
+        ConvolutionOp parentOp = (tmp instanceof ConvolutionOp) ? ((ConvolutionOp) tmp) : null;
+
+        if(parentOp == null)
+            return;
+        else if(parentOp.getOp() == ConvOp.MAX_POOLING) {
+            ConvolutionParameters parentParam = parentOp.parseInput();
+            int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
+            // [C, P, Q] from maxpool becomes [C, H, W] of the next op
+            _cachedParams.C = (_cachedParams.C < 0) ? parentParam.C : _cachedParams.C;
+            _cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
+            _cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
+            if(LOG.isDebugEnabled()) {
+                LOG.debug("Inferring [C,H,W] from maxpool parent: [" + prevC + "," + prevH + "," + prevW + "] -> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
+            }
+            if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
+                throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
+                throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
+                throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
+            }
+        }
+        else if(parentOp.getOp() == ConvOp.DIRECT_CONV2D) {
+            ConvolutionParameters parentParam = parentOp.parseInput();
+            int prevC = _cachedParams.C; int prevH = _cachedParams.H; int prevW = _cachedParams.W;
+            // [K, P, Q] from convolution becomes [C, H, W] of the next op
+            _cachedParams.C = (_cachedParams.C < 0) ? parentParam.K : _cachedParams.C;
+            _cachedParams.H = (_cachedParams.H < 0) ? parentParam.P : _cachedParams.H;
+            _cachedParams.W = (_cachedParams.W < 0) ? parentParam.Q : _cachedParams.W;
+            if(LOG.isDebugEnabled()) {
+                LOG.debug("Inferring [C,H,W] from conv2d parent: [" + prevC + "," + prevH + "," + prevW + "] -> [" + _cachedParams.C + "," + _cachedParams.H + "," + _cachedParams.W + "]");
+            }
+            if(THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
+                throwExceptionIfNotEqual(prevC, _cachedParams.C, "C");
+                throwExceptionIfNotEqual(prevH, _cachedParams.H, "H");
+                throwExceptionIfNotEqual(prevW, _cachedParams.W, "W");
+            }
+        }
+    }
+
     @Override
     public void refreshSizeInformation()
     {
@@ -620,9 +739,8 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
         if(op == ConvOp.BIAS_ADD || op == ConvOp.BIAS_MULTIPLY) {
             throw new RuntimeException("getDim method should not be invoked for bias_add and bias_multiply");
         }
-        ConvolutionParameters params;
         try {
-            params = parseInput();
+            parseInput();
         } catch (DMLRuntimeException e) {
             throw new RuntimeException(e);
         }
@@ -653,49 +771,49 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
 
         long ret = -1;
         if(dimString.equals("K") && filter != null) {
-            ret = getNonNegative(ret, getNonNegative(params.K, filter._dim1));
+            ret = getNonNegative(ret, getNonNegative(_cachedParams.K, filter._dim1));
         }
         else if(dimString.equals("CRS") && filter != null) {
-            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.C, params.R, params.S), filter._dim2));
+            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S), filter._dim2));
         }
         else if(dimString.equals("N") && input != null) {
-            ret = getNonNegative(ret, getNonNegative(params.N, input._dim1));
+            ret = getNonNegative(ret, getNonNegative(_cachedParams.N, input._dim1));
         }
         else if(dimString.equals("CHW") && input != null) {
-            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.C, params.H, params.W), input._dim2));
+            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W), input._dim2));
         }
         else if(dimString.equals("N") && dout != null) {
-            ret = getNonNegative(ret, getNonNegative(params.N, dout._dim1));
+            ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout._dim1));
         }
         else if(dimString.equals("KPQ") && dout != null) {
-            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.K, params.P, params.Q), dout._dim2));
+            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q), dout._dim2));
         }
         else if(dimString.equals("N") && dout1 != null) {
-            ret = getNonNegative(ret, getNonNegative(params.N, dout1._dim1));
+            ret = getNonNegative(ret, getNonNegative(_cachedParams.N, dout1._dim1));
         }
         else if(dimString.equals("CPQ") && dout1 != null) {
-            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(params.C, params.P, params.Q), dout1._dim2));
+            ret = getNonNegative(ret, getNonNegative(nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q), dout1._dim2));
         }
         else if(dimString.equals("K")) {
-            ret = getNonNegative(ret, params.K >= 0 ? params.K : -1);
+            ret = getNonNegative(ret, _cachedParams.K >= 0 ? _cachedParams.K : -1);
         }
         else if(dimString.equals("CRS")) {
-            ret = getNonNegative(ret, nonNegativeMultiply(params.C, params.R, params.S));
+            ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.R, _cachedParams.S));
        }
         else if(dimString.equals("N")) {
-            ret = getNonNegative(ret, params.N >= 0 ? params.N : -1);
+            ret = getNonNegative(ret, _cachedParams.N >= 0 ? _cachedParams.N : -1);
         }
         else if(dimString.equals("CHW")) {
-            ret = getNonNegative(ret, nonNegativeMultiply(params.C, params.H, params.W));
+            ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.H, _cachedParams.W));
         }
         else if(dimString.equals("KPQ")) {
-            ret = getNonNegative(ret, nonNegativeMultiply(params.K, params.P, params.Q));
+            ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.K, _cachedParams.P, _cachedParams.Q));
         }
         else if(dimString.equals("PQ")) {
-            ret = getNonNegative(ret, nonNegativeMultiply(params.P, params.Q));
+            ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.P, _cachedParams.Q));
         }
         else if(dimString.equals("CPQ")) {
-            ret = getNonNegative(ret, nonNegativeMultiply(params.C, params.P, params.Q));
+            ret = getNonNegative(ret, nonNegativeMultiply(_cachedParams.C, _cachedParams.P, _cachedParams.Q));
         }
         else {
             throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + getOp().name());
@@ -703,10 +821,10 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
 
         if(LOG.isDebugEnabled() && ret < 0) {
             LOG.debug("Unknown dimension " + dimString + " for ConvolutionOp:" + op.name() +
-                    " img_dim=[" + params.N + " " + params.C + " " + params.H + " " + params.W + "]" +
-                    " filter_dim=[" + params.K + " " + params.C + " " + params.H + " " + params.W + "]" +
-                    " output_feature_map=[" + params.P + " " + params.Q + "] stride=[" + params.stride_h + " " + params.stride_w + "]" +
-                    " pad=[" + params.pad_h + " " + params.pad_w + "]");
+                    " img_dim=[" + _cachedParams.N + " " + _cachedParams.C + " " + _cachedParams.H + " " + _cachedParams.W + "]" +
+                    " filter_dim=[" + _cachedParams.K + " " + _cachedParams.C + " " + _cachedParams.R + " " + _cachedParams.S + "]" +
+                    " output_feature_map=[" + _cachedParams.P + " " + _cachedParams.Q + "] stride=[" + _cachedParams.stride_h + " " + _cachedParams.stride_w + "]" +
+                    " pad=[" + _cachedParams.pad_h + " " + _cachedParams.pad_w + "]");
         }
         return ret;
     }
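As a sanity check of the P/Q computation in parseInput() (a hand-worked sketch using the same floor formula, independent of the SystemML classes): for H = W = 28, R = S = 3, stride 1, and no padding, P = floor((28 + 2*0 - 3)/1 + 1) = 26, and likewise Q = 26.

    // Standalone sketch of the output-shape formula used above:
    //   P = floor((H + 2*pad_h - R)/stride_h + 1)
    //   Q = floor((W + 2*pad_w - S)/stride_w + 1)
    public class OutputShapeSketch {
        static int outDim(int in, int filter, int stride, int pad) {
            // Integer division equals floor here because all terms are non-negative.
            return (in + 2 * pad - filter) / stride + 1;
        }
        public static void main(String[] args) {
            System.out.println(outDim(28, 3, 1, 0)); // conv2d 3x3, stride 1 -> 26
            System.out.println(outDim(26, 2, 2, 0)); // maxpool 2x2, stride 2 -> 13
        }
    }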
