Repository: systemml Updated Branches: refs/heads/master e9fb7a028 -> 7602af94f
[SYSTEMML-1686] Fix validate conv2d_backward_data (size propagation) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7602af94 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7602af94 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7602af94 Branch: refs/heads/master Commit: 7602af94fbd9554f097a573bee87d1b51e709ccb Parents: e9fb7a0 Author: Matthias Boehm <[email protected]> Authored: Wed Jun 14 18:17:23 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 14 18:17:57 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java | 1 - .../org/apache/sysml/parser/BuiltinFunctionExpression.java | 9 ++++++--- 2 files changed, 6 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7602af94/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java index f713b7b..3424a52 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java @@ -25,7 +25,6 @@ import java.util.Set; import java.util.Map.Entry; import org.apache.sysml.hops.HopsException; -import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.FunctionStatementBlock; import org.apache.sysml.parser.LanguageException; http://git-wip-us.apache.org/repos/asf/systemml/blob/7602af94/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index e3b3e79..d133ff2 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -1113,7 +1113,7 @@ public class BuiltinFunctionExpression extends DataIdentifier // // Similarly, // conv2d_backward_filter and conv2d_backward_data - Expression input = _args[0]; // For conv2d_backward_filter, this is input and for conv2d_backward_data, this is filter + Expression input = _args[0]; // For conv2d_backward_filter, this is input and for conv2d_backward_data, this is filter Expression filter = null; if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) { @@ -1125,10 +1125,13 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setBlockDimensions(input.getOutput().getRowsInBlock(), input.getOutput().getColumnsInBlock()); // stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize, // filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1 - if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD || - this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) { + if( getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD ) { output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2()); } + else if( getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA ) { + //args[0] .. filter, args[1] .. input + output.setDimensions(_args[1].getOutput().getDim1(), -1); + } else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER) { output.setDimensions(filter.getOutput().getDim1(), filter.getOutput().getDim2()); }
