Repository: systemml
Updated Branches:
  refs/heads/master ccac6dd37 -> 6778a63b0


[HOTFIX] Fix for recently updated validation code for convolution operation

- Tested NNTest in a local environment.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6778a63b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6778a63b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6778a63b

Branch: refs/heads/master
Commit: 6778a63b02fc1c644501bae67cd24e639ed3a623
Parents: ccac6dd
Author: Niketan Pansare <[email protected]>
Authored: Fri Jul 14 13:59:41 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Fri Jul 14 14:59:41 2017 -0700

----------------------------------------------------------------------
 .../sysml/parser/BuiltinFunctionExpression.java | 94 ++++++++++++--------
 1 file changed, 57 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6778a63b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index 58760bc..54281cc 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -1124,45 +1124,65 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        output.setDataType(DataType.MATRIX);
                        output.setValueType(ValueType.DOUBLE);
                        
output.setBlockDimensions(input.getOutput().getRowsInBlock(), 
input.getOutput().getColumnsInBlock());
-                       // stride1, stride2, padding1, padding2, numImg, 
numChannels, imgSize, imgSize, 
-                       // filter_shape1=1, filter_shape2=1, 
filterSize/poolSize1, filterSize/poolSize1
-                       try {
-                               int start = 2;
-                               if(!(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
-                                       start = 1;
+                       
+                       if(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL_BACKWARD) {
+                               
output.setDimensions(input.getOutput().getDim1(), input.getOutput().getDim2());
+                       }
+                       else {
+                               // stride1, stride2, padding1, padding2, 
numImg, numChannels, imgSize, imgSize, 
+                               // filter_shape1=1, filter_shape2=1, 
filterSize/poolSize1, filterSize/poolSize1
+                               try {
+                                       int start = 2;
+                                       if(!(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
+                                               start = 1;
+                                       }
+                                       long stride_h = (long) 
getDoubleValue(_args[start++]);
+                                       long stride_w = (long) 
getDoubleValue(_args[start++]);
+                                       long pad_h = (long) 
getDoubleValue(_args[start++]);
+                                       long pad_w = (long) 
getDoubleValue(_args[start++]); 
+                                       long N = (long) 
getDoubleValue(_args[start++]);
+                                       long C = (long) 
getDoubleValue(_args[start++]);
+                                       long H = (long) 
getDoubleValue(_args[start++]);
+                                       long W = (long) 
getDoubleValue(_args[start++]);
+                                       long K = -1;
+                                       if(!(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
+                                               K = (long) 
getDoubleValue(_args[start]);
+                                       }
+                                       start++; start++; // Increment index 
for K and C
+                                       long R = (long) 
getDoubleValue(_args[start++]);
+                                       long S = (long) 
getDoubleValue(_args[start++]);
+                                       
+                                       if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D_BACKWARD_FILTER) {
+                                               output.setDimensions(K, C*R*S);
+                                       }
+                                       else if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D_BACKWARD_DATA) {
+                                               output.setDimensions(N, C*H*W);
+                                       }
+                                       else if(H > 0 && W > 0 && stride_h > 0 
&& stride_w > 0 && pad_h >= 0 && pad_w >= 0 && R > 0 && S > 0) {
+                                               long P = 
ConvolutionUtils.getP(H, R, stride_h, pad_h);
+                                               long Q = 
ConvolutionUtils.getQ(W, S, stride_w, pad_w);
+                                               
+                                               // Try to set both rows and 
columns
+                                               if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D) 
+                                                       output.setDimensions(N, 
K*P*Q);
+                                               else if(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)
+                                                       output.setDimensions(N, 
C*P*Q);
+                                               else
+                                                       throw new 
LanguageException("");
+                                       }
+                                       else {
+                                               // Since columns cannot be 
computed, set only rows
+                                               if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D) 
+                                                       
output.setDimensions(input.getOutput().getDim1(), -1);
+                                               else if(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)
+                                                       
output.setDimensions(input.getOutput().getDim1(), -1);
+                                               else
+                                                       throw new 
LanguageException("");
+                                       }
                                }
-                               long stride_h = (long) 
getDoubleValue(_args[start++]);
-                               long stride_w = (long) 
getDoubleValue(_args[start++]);
-                               long pad_h = (long) 
getDoubleValue(_args[start++]);
-                               long pad_w = (long) 
getDoubleValue(_args[start++]); 
-                               long N = (long) getDoubleValue(_args[start++]);
-                               long C = (long) getDoubleValue(_args[start++]);
-                               long H = (long) getDoubleValue(_args[start++]);
-                               long W = (long) getDoubleValue(_args[start++]);
-                               long K = -1;
-                               if(!(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL || this.getOpCode() == BuiltinFunctionOp.AVG_POOL)) {
-                                       K = (long) getDoubleValue(_args[start]);
+                               catch(Exception e) {
+                                       output.setDimensions(-1, -1); // To 
make sure that output dimensions are not incorrect even if getDoubleValue 
does not return a value
                                }
-                               start++; start++; // Increment index for K and C
-                               long R = (long) getDoubleValue(_args[start++]);
-                               long S = (long) getDoubleValue(_args[start++]);
-                               long P = ConvolutionUtils.getP(H, R, stride_h, 
pad_h);
-                               long Q = ConvolutionUtils.getP(W, S, stride_w, 
pad_w);
-                               if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D) 
-                                       output.setDimensions(N, K*P*Q);
-                               else if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D_BACKWARD_FILTER)
-                                       output.setDimensions(K, C*R*S);
-                               else if(this.getOpCode() == 
BuiltinFunctionOp.CONV2D_BACKWARD_DATA)
-                                       output.setDimensions(N, C*H*W);
-                               else if(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL)
-                                       output.setDimensions(N, C*P*Q);
-                               else if(this.getOpCode() == 
BuiltinFunctionOp.MAX_POOL_BACKWARD)
-                                       output.setDimensions(N, C*H*W);
-                               else
-                                       throw new LanguageException("");
-                       }
-                       catch(Exception e) {
-                               
output.setDimensions(input.getOutput().getDim1(), -1); // To make sure that 
output dimensions are not incorrect
                        }
                        checkMatrixParam(input);
                        if(input2 != null)

Reply via email to