Repository: incubator-systemml
Updated Branches:
  refs/heads/master 7065b7f39 -> 55c8ee7d6


[SYSTEMML-762] Bug fix for memory estimates of ConvolutionOp


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/55c8ee7d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/55c8ee7d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/55c8ee7d

Branch: refs/heads/master
Commit: 55c8ee7d6e3c1fcdf5c2583eee3f0a287d4baac9
Parents: 7065b7f
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Wed Jun 15 16:17:05 2016 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Wed Jun 15 16:20:26 2016 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 22 +++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/55c8ee7d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 7f53e2e..07a45b6 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -315,7 +315,26 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
        // input_shape1, input_shape2, input_shape3, input_shape4, 
        // filter_shape1, filter_shape2, filter_shape3, filter_shape4
        ConvolutionParameters parseInput() throws DMLRuntimeException {
-               ConvolutionParameters params = new ConvolutionParameters(
+               ConvolutionParameters params = null;
+               if(op == ConvOp.MAX_POOLING_BACKWARD 
+                               || op == ConvOp.DIRECT_CONV2D 
+                               || op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
+                               || op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
+                       params = new ConvolutionParameters(
+                                       extractValue(getInput().get(6)),
+                                       extractValue(getInput().get(7)), 
+                                       extractValue(getInput().get(8)), 
+                                       extractValue(getInput().get(9)), 
+                                       extractValue(getInput().get(10)), 
+                                       extractValue(getInput().get(12)), 
+                                       extractValue(getInput().get(13)), 
+                                       extractValue(getInput().get(2)), 
+                                       extractValue(getInput().get(3)), 
+                                       extractValue(getInput().get(4)), 
+                                       extractValue(getInput().get(5)), 
_maxNumThreads);
+               }
+               else {
+                       params = new ConvolutionParameters(
                                extractValue(getInput().get(5)),
                                extractValue(getInput().get(6)), 
                                extractValue(getInput().get(7)), 
@@ -327,6 +346,7 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                                extractValue(getInput().get(2)), 
                                extractValue(getInput().get(3)), 
                                extractValue(getInput().get(4)), 
_maxNumThreads);
+               }
                return params;
        }
        

Reply via email to