Repository: incubator-systemml Updated Branches: refs/heads/master 7065b7f39 -> 55c8ee7d6
[SYSTEMML-762] Bug fix for memory estimates of ConvolutionOp Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/55c8ee7d Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/55c8ee7d Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/55c8ee7d Branch: refs/heads/master Commit: 55c8ee7d6e3c1fcdf5c2583eee3f0a287d4baac9 Parents: 7065b7f Author: Niketan Pansare <npan...@us.ibm.com> Authored: Wed Jun 15 16:17:05 2016 -0700 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Wed Jun 15 16:20:26 2016 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/ConvolutionOp.java | 22 +++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/55c8ee7d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java index 7f53e2e..07a45b6 100644 --- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java +++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java @@ -315,7 +315,26 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop // input_shape1, input_shape2, input_shape3, input_shape4, // filter_shape1, filter_shape2, filter_shape3, filter_shape4 ConvolutionParameters parseInput() throws DMLRuntimeException { - ConvolutionParameters params = new ConvolutionParameters( + ConvolutionParameters params = null; + if(op == ConvOp.MAX_POOLING_BACKWARD + || op == ConvOp.DIRECT_CONV2D + || op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER + || op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) { + params = new ConvolutionParameters( + extractValue(getInput().get(6)), + extractValue(getInput().get(7)), + extractValue(getInput().get(8)), + extractValue(getInput().get(9)), + extractValue(getInput().get(10)), + extractValue(getInput().get(12)), + extractValue(getInput().get(13)), + extractValue(getInput().get(2)), + extractValue(getInput().get(3)), + extractValue(getInput().get(4)), + extractValue(getInput().get(5)), _maxNumThreads); + } + else { + params = new ConvolutionParameters( extractValue(getInput().get(5)), extractValue(getInput().get(6)), extractValue(getInput().get(7)), @@ -327,6 +346,7 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop extractValue(getInput().get(2)), extractValue(getInput().get(3)), extractValue(getInput().get(4)), _maxNumThreads); + } return params; }