Repository: systemml
Updated Branches:
  refs/heads/master 997eb2aa2 -> 47973a905
[SYSTEMML-540] Bugfix for fused ReLU-maxpooling and ReLU-maxpooling backward operators

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/47973a90
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/47973a90
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/47973a90

Branch: refs/heads/master
Commit: 47973a9055976ae9a8a7b294e1868cadfedc50cc
Parents: 997eb2a
Author: Niketan Pansare <[email protected]>
Authored: Mon Jan 22 15:05:16 2018 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Mon Jan 22 15:05:16 2018 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java | 44 ++++++++++++++------
 1 file changed, 31 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/47973a90/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 4c7525d..99e69b9 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -164,10 +164,24 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
 		}
 	}
 	
-	private static boolean isInputReLU(Hop input) {
-		return HopRewriteUtils.isBinary(input, OpOp2.MAX)
-			&& (HopRewriteUtils.isLiteralOfValue(input.getInput().get(0), 0)
-			|| HopRewriteUtils.isLiteralOfValue(input.getInput().get(1), 0));
+	/**
+	 * Returns parent matrix X or null
+	 * @param input input hop
+	 * @return either null or X if input is max(X,0) or max(0,X)
+	 */
+	private static Hop isInputReLU(Hop input) {
+		if(HopRewriteUtils.isBinary(input, OpOp2.MAX)) {
+			if(HopRewriteUtils.isLiteralOfValue(input.getInput().get(0), 0)) {
+				return input.getInput().get(1);
+			}
+			else if(HopRewriteUtils.isLiteralOfValue(input.getInput().get(1), 0)) {
+				return input.getInput().get(0);
+			}
+			else
+				return null;
+		}
+		else
+			return null;
 	}
 	
 	private static boolean isInputConv2d(Hop input) {
@@ -228,12 +242,13 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
 		// RELU_MAX_POOLING and RELU_MAX_POOLING_BACKWARD is extremely useful for CP backend
 		// by reducing unnecessary sparse-to-dense-to-sparse conversion.
 		// For other backends, this operators is not necessary as it reduces an additional relu operator.
-		if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == ConvOp.MAX_POOLING && isInputReLU(inputs.get(0))) {
-			lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
+		Hop parentReLU = isInputReLU(inputs.get(0));
+		if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == ConvOp.MAX_POOLING && parentReLU != null) {
+			lhsInputLop = parentReLU.constructLops();
 			lopOp = OperationTypes.RELU_MAX_POOLING;
 		}
-		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == ConvOp.MAX_POOLING_BACKWARD && isInputReLU(inputs.get(0))) {
-			lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
+		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && et == ExecType.CP && op == ConvOp.MAX_POOLING_BACKWARD && parentReLU != null) {
+			lhsInputLop = parentReLU.constructLops();
 			lopOp = OperationTypes.RELU_MAX_POOLING_BACKWARD;
 		}
 		else if(OptimizerUtils.ALLOW_OPERATOR_FUSION && op == ConvOp.BIAS_ADD && isInputConv2d(inputs.get(0))) {
@@ -651,11 +666,14 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop
 	 * 
 	 * @throws DMLRuntimeException if error occurs
 	 */
-	private void inferCHWPQFromParentOp() throws DMLRuntimeException {Hop tmp = getInput().get(0);
-		while(isInputReLU(tmp) || isInputBiasAdd(tmp)) {
-			// Skip ReLU and bias_add and go to its parent
-			tmp = tmp.getInput().get(0);
-		}
+	private void inferCHWPQFromParentOp() throws DMLRuntimeException {
+		Hop tmp = getInput().get(0);
+		// Skip bias_add and go to its parent
+		tmp = isInputBiasAdd(tmp) ? tmp.getInput().get(0) : tmp;
+		Hop parentReLU = isInputReLU(tmp);
+		// Skip ReLU and go to its parent
+		tmp = (parentReLU != null) ? parentReLU : tmp;
+		// Cast tmp as parent
 		ConvolutionOp parentOp = (tmp instanceof ConvolutionOp) ? ((ConvolutionOp) tmp) : null;
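
The bug being fixed: the old code tested isInputReLU(input) as a boolean and then unconditionally fused on inputs.get(0).getInput().get(0). That picks the correct operand only for the max(X,0) ordering; for max(0,X) the first input is the literal 0, so the fused operator was built over the wrong hop. The rewritten isInputReLU returns the non-literal operand itself (or null), so RELU_MAX_POOLING is always constructed over X. A minimal standalone sketch of that pattern match, using a hypothetical toy Node class rather than SystemML's actual Hop API:

    import java.util.Arrays;
    import java.util.List;

    public class ReLUPatternSketch {
        // Toy expression node: a literal scalar, a variable, or an operator over children.
        static class Node {
            final String op;          // e.g. "max", "lit", "var"
            final double value;       // only meaningful when op == "lit"
            final List<Node> inputs;
            Node(String op, double value, Node... inputs) {
                this.op = op; this.value = value; this.inputs = Arrays.asList(inputs);
            }
        }

        static boolean isLiteralOfValue(Node n, double v) {
            return n.op.equals("lit") && n.value == v;
        }

        // Mirrors the fixed isInputReLU: returns X for max(X,0) or max(0,X), else null.
        static Node reluParent(Node n) {
            if (!n.op.equals("max") || n.inputs.size() != 2)
                return null;
            if (isLiteralOfValue(n.inputs.get(0), 0))
                return n.inputs.get(1);   // max(0, X) -> X is the second input
            if (isLiteralOfValue(n.inputs.get(1), 0))
                return n.inputs.get(0);   // max(X, 0) -> X is the first input
            return null;
        }

        public static void main(String[] args) {
            Node x = new Node("var", 0);
            Node zero = new Node("lit", 0);
            // The old approach effectively returned inputs.get(0), i.e. the literal 0
            // for the max(0, X) ordering; the fixed version returns X for both orderings.
            System.out.println(reluParent(new Node("max", 0, x, zero)) == x); // true
            System.out.println(reluParent(new Node("max", 0, zero, x)) == x); // true
        }
    }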

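The same null-returning helper also drives the shape-inference fix: instead of a while loop that always steps to getInput().get(0) (again wrong for the max(0,X) ordering), the rewritten inferCHWPQFromParentOp skips at most one bias_add and then at most one ReLU by following the non-literal operand before checking for a parent convolution. A sketch of that traversal in the same toy model, added to the ReLUPatternSketch class above (findParentConv and the "bias_add"/"conv2d" tags are illustrative, not the Hop API):

        // Skips at most one bias_add, then at most one ReLU, and returns the
        // remaining node if it is the convolution whose output shape can be reused.
        static Node findParentConv(Node input) {
            Node tmp = input;
            if (tmp.op.equals("bias_add"))
                tmp = tmp.inputs.get(0);      // bias_add(conv2d(...), b) -> conv2d(...)
            Node relu = reluParent(tmp);      // max(X,0) / max(0,X) -> X
            if (relu != null)
                tmp = relu;
            return tmp.op.equals("conv2d") ? tmp : null;
        }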