Repository: systemml
Updated Branches:
  refs/heads/master ccac6dd37 -> 6778a63b0
[HOTFIX] Fix for recently updated validation code for convolution operation - Tested NNTest in local environment. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6778a63b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6778a63b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6778a63b Branch: refs/heads/master Commit: 6778a63b02fc1c644501bae67cd24e639ed3a623 Parents: ccac6dd Author: Niketan Pansare <[email protected]> Authored: Fri Jul 14 13:59:41 2017 -0800 Committer: Niketan Pansare <[email protected]> Committed: Fri Jul 14 14:59:41 2017 -0700 ---------------------------------------------------------------------- .../sysml/parser/BuiltinFunctionExpression.java | 94 ++++++++++++-------- 1 file changed, 57 insertions(+), 37 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/6778a63b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java index 58760bc..54281cc 100644 --- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java @@ -1124,45 +1124,65 @@ public class BuiltinFunctionExpression extends DataIdentifier output.setDataType(DataType.MATRIX); output.setValueType(ValueType.DOUBLE); output.setBlockDimensions(input.getOutput().getRowsInBlock(), input.getOutput().getColumnsInBlock()); - // stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize, - // filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1 - try { - int start = 2; - if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == 
BuiltinFunctionOp.AVG_POOL)) { - start = 1; + + if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD) { + output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2()); + } + else { + // stride1, stride2, padding1, padding2, numImg, numChannels, imgSize, imgSize, + // filter_shape1=1, filter_shape2=1, filterSize/poolSize1, filterSize/poolSize1 + try { + int start = 2; + if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) { + start = 1; + } + long stride_h = (long) getDoubleValue(_args[start++]); + long stride_w = (long) getDoubleValue(_args[start++]); + long pad_h = (long) getDoubleValue(_args[start++]); + long pad_w = (long) getDoubleValue(_args[start++]); + long N = (long) getDoubleValue(_args[start++]); + long C = (long) getDoubleValue(_args[start++]); + long H = (long) getDoubleValue(_args[start++]); + long W = (long) getDoubleValue(_args[start++]); + long K = -1; + if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) { + K = (long) getDoubleValue(_args[start]); + } + start++; start++; // Increment index for K and C + long R = (long) getDoubleValue(_args[start++]); + long S = (long) getDoubleValue(_args[start++]); + + if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER) { + output.setDimensions(K, C*R*S); + } + else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) { + output.setDimensions(N, C*H*W); + } + else if(H > 0 && W > 0 && stride_h > 0 && stride_w > 0 && pad_h >= 0 && pad_w >= 0 && R > 0 && S > 0) { + long P = ConvolutionUtils.getP(H, R, stride_h, pad_h); + long Q = ConvolutionUtils.getQ(W, S, stride_w, pad_w); + + // Try to set both rows and columns + if(this.getOpCode() == BuiltinFunctionOp.CONV2D) + output.setDimensions(N, K*P*Q); + else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL) + output.setDimensions(N, C*P*Q); + else + throw new 
LanguageException(""); + } + else { + // Since columns cannot be computed, set only rows + if(this.getOpCode() == BuiltinFunctionOp.CONV2D) + output.setDimensions(input.getOutput().getDim1(), -1); + else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL) + output.setDimensions(input.getOutput().getDim1(), -1); + else + throw new LanguageException(""); + } } - long stride_h = (long) getDoubleValue(_args[start++]); - long stride_w = (long) getDoubleValue(_args[start++]); - long pad_h = (long) getDoubleValue(_args[start++]); - long pad_w = (long) getDoubleValue(_args[start++]); - long N = (long) getDoubleValue(_args[start++]); - long C = (long) getDoubleValue(_args[start++]); - long H = (long) getDoubleValue(_args[start++]); - long W = (long) getDoubleValue(_args[start++]); - long K = -1; - if(!(this.getOpCode() == BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) { - K = (long) getDoubleValue(_args[start]); + catch(Exception e) { + output.setDimensions(-1, -1); // To make sure that output dimensions are not incorrect even if getDoubleValue doesnot return value } - start++; start++; // Increment index for K and C - long R = (long) getDoubleValue(_args[start++]); - long S = (long) getDoubleValue(_args[start++]); - long P = ConvolutionUtils.getP(H, R, stride_h, pad_h); - long Q = ConvolutionUtils.getP(W, S, stride_w, pad_w); - if(this.getOpCode() == BuiltinFunctionOp.CONV2D) - output.setDimensions(N, K*P*Q); - else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER) - output.setDimensions(K, C*R*S); - else if(this.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA) - output.setDimensions(N, C*H*W); - else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL) - output.setDimensions(N, C*P*Q); - else if(this.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD) - output.setDimensions(N, C*H*W); - else - throw new LanguageException(""); - } - catch(Exception e) { - 
output.setDimensions(input.getOutput().getDim1(), -1); // To make sure that output dimensions are not incorrect } checkMatrixParam(input); if(input2 != null)
