Repository: incubator-systemml Updated Branches: refs/heads/master 0a61fe084 -> e3a75d141
[SYSTEMML-1381] Fix worst-case size propagation convolution hop Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/442b9a5b Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/442b9a5b Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/442b9a5b Branch: refs/heads/master Commit: 442b9a5b4aaca57f0d459da267551c885a35e6e1 Parents: 0a61fe0 Author: Matthias Boehm <[email protected]> Authored: Thu Mar 9 17:25:51 2017 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Mar 9 17:25:51 2017 -0800 ---------------------------------------------------------------------- .../org/apache/sysml/hops/ConvolutionOp.java | 33 ++++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/442b9a5b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java index a13de52..c32f227 100644 --- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java +++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java @@ -244,15 +244,14 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop protected long[] inferOutputCharacteristics( MemoTable memo ) { // [numRows, numCols, NNZ] - long[] ret = null; + long[] ret = new long[3]; if(op == ConvOp.BIAS_ADD) { MatrixCharacteristics[] mc = memo.getAllInputStats(getInput()); - ret = new long[3]; ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1; ret[1] = mc[0].colsKnown() ? mc[0].getCols() : -1; ret[2] = -1; - return ret; + return (ret[0]>0 && ret[1]>0) ? ret : null; } ConvolutionParameters params; @@ -264,41 +263,26 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop switch(op) { - case MAX_POOLING: - { - ret = new long[3]; + case MAX_POOLING: { ret[0] = getInput().get(0)._dim1; ret[1] = getExtractedVal(params.C, params.P, params.Q); ret[2] = -1; break; } - case MAX_POOLING_BACKWARD: - { - ret = new long[3]; - ret[0] = getInput().get(0)._dim1; - ret[1] = getInput().get(0)._dim2; - ret[2] = -1; - break; - } - case DIRECT_CONV2D: - { - ret = new long[3]; + case DIRECT_CONV2D: { ret[0] = getInput().get(0)._dim1; ret[1] = getExtractedVal(getInput().get(1)._dim1, params.P, params.Q); ret[2] = -1; break; } - case DIRECT_CONV2D_BACKWARD_FILTER: - { - ret = new long[3]; + case DIRECT_CONV2D_BACKWARD_FILTER: { ret[0] = getInput().get(1)._dim1; ret[1] = getInput().get(1)._dim2; ret[2] = -1; break; } - case DIRECT_CONV2D_BACKWARD_DATA: - { - ret = new long[3]; + case MAX_POOLING_BACKWARD: + case DIRECT_CONV2D_BACKWARD_DATA: { ret[0] = getInput().get(0)._dim1; ret[1] = getInput().get(0)._dim2; ret[2] = -1; @@ -316,7 +300,8 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop " pad=[" + params.pad_h + " " + params.pad_w + "]"); } - return ret; + //safe return (create entry only if at least dims known) + return (ret[0]>0 && ret[1]>0) ? ret : null; }
