Repository: incubator-systemml
Updated Branches:
  refs/heads/master 0a61fe084 -> e3a75d141


[SYSTEMML-1381] Fix worst-case size propagation convolution hop

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/442b9a5b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/442b9a5b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/442b9a5b

Branch: refs/heads/master
Commit: 442b9a5b4aaca57f0d459da267551c885a35e6e1
Parents: 0a61fe0
Author: Matthias Boehm <[email protected]>
Authored: Thu Mar 9 17:25:51 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Mar 9 17:25:51 2017 -0800

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 33 ++++++--------------
 1 file changed, 9 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/442b9a5b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index a13de52..c32f227 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -244,15 +244,14 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
        protected long[] inferOutputCharacteristics( MemoTable memo )
        {
                // [numRows, numCols, NNZ] 
-               long[] ret = null;
+               long[] ret = new long[3];
                
                if(op == ConvOp.BIAS_ADD) {
                        MatrixCharacteristics[] mc = 
memo.getAllInputStats(getInput());
-                       ret = new long[3];
                        ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
                        ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1;
                        ret[2] = -1;
-                       return ret;
+                       return (ret[0]>0 && ret[1]>0) ? ret : null;
                }
        
                ConvolutionParameters params;
@@ -264,41 +263,26 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
                
                switch(op) 
                {
-                       case MAX_POOLING:
-                       {
-                               ret = new long[3];
+                       case MAX_POOLING: {
                                ret[0] = getInput().get(0)._dim1;
                                ret[1] = getExtractedVal(params.C, params.P, 
params.Q);
                                ret[2] = -1;
                                break;
                        }
-                       case MAX_POOLING_BACKWARD:
-                       {
-                               ret = new long[3];
-                               ret[0] = getInput().get(0)._dim1;
-                               ret[1] = getInput().get(0)._dim2;
-                               ret[2] = -1;
-                               break;
-                       }
-                       case DIRECT_CONV2D:
-                       {
-                               ret = new long[3];
+                       case DIRECT_CONV2D: {
                                ret[0] = getInput().get(0)._dim1;
                                ret[1] = 
getExtractedVal(getInput().get(1)._dim1, params.P, params.Q);
                                ret[2] = -1;
                                break;
                        }
-                       case DIRECT_CONV2D_BACKWARD_FILTER:
-                       {
-                               ret = new long[3];
+                       case DIRECT_CONV2D_BACKWARD_FILTER: {
                                ret[0] = getInput().get(1)._dim1;
                                ret[1] = getInput().get(1)._dim2;
                                ret[2] = -1;
                                break;
                        }
-                       case DIRECT_CONV2D_BACKWARD_DATA:
-                       {
-                               ret = new long[3];
+                       case MAX_POOLING_BACKWARD:
+                       case DIRECT_CONV2D_BACKWARD_DATA: {
                                ret[0] = getInput().get(0)._dim1;
                                ret[1] = getInput().get(0)._dim2;
                                ret[2] = -1;
@@ -316,7 +300,8 @@ public class ConvolutionOp extends Hop  implements MultiThreadedHop
                                        " pad=[" + params.pad_h + " " + 
params.pad_w + "]");
                }
                
-               return ret;
+               //safe return (create entry only if at least dims known)
+               return (ret[0]>0 && ret[1]>0) ? ret : null;
        }
        
 

Reply via email to