[SYSTEMML-2407] Fix size inference of reshapes w/ zero rows or columns

This patch improves the robustness of size inference and statistics updates
for reshapes with zero rows or columns, which previously led to arithmetic
exceptions due to division by zero.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/51db735e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/51db735e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/51db735e

Branch: refs/heads/master
Commit: 51db735ebb9c7d183c02446b5328f18007bfec7e
Parents: 2982c73
Author: Matthias Boehm <[email protected]>
Authored: Mon Jun 18 17:41:27 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jun 18 17:41:27 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/ReorgOp.java     | 37 ++++++++--------
 .../functions/misc/FunctionPotpourriTest.java   |  7 ++++
 .../scripts/functions/misc/FunPotpourriEval.dml | 44 ++++++++++++++++++++
 3 files changed, 71 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/51db735e/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java 
b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index 22867c2..eb5d825 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -453,28 +453,31 @@ public class ReorgOp extends MultiThreadedHop
                                
                                // CASE b) DIAG M2V
                                // input is [k,k] matrix and output is [k,1] 
matrix
-                               // #nnz in the output is likely to be k (a 
dense matrix)                
+                               // #nnz in the output is likely to be k (a 
dense matrix)
                                if( k > 1 )
                                        ret = new long[]{k, 1, 
((mc.getNonZeros()>=0) ? Math.min(k,mc.getNonZeros()) : k) };
                                
-                               break;          
+                               break;
                        }
                        case RESHAPE:
                        {
-                               // input is a [k1,k2] matrix and output is a 
[k3,k4] matrix with k1*k2=k3*k4
-                               // #nnz in output is exactly the same as in 
input               
+                               // input is a [k1,k2] matrix and output is a 
[k3,k4] matrix with k1*k2=k3*k4, except for
+                               // special cases where an input or output 
dimension is zero (i.e., 0x5 -> 1x0 is valid)
+                               // #nnz in output is exactly the same as in 
input
                                if( mc.dimsKnown() ) {
-                                       if( _dim1 >= 0  )
-                                               ret = new long[]{ _dim1, 
mc.getRows()*mc.getCols()/_dim1, mc.getNonZeros()};
-                                       else if( _dim2 >= 0 ) 
-                                               ret = new long[]{ 
mc.getRows()*mc.getCols()/_dim2, _dim2, mc.getNonZeros()};
+                                       if( _dim1 > 0  )
+                                               ret = new long[]{_dim1, 
mc.getRows()*mc.getCols()/_dim1, mc.getNonZeros()};
+                                       else if( _dim2 > 0 ) 
+                                               ret = new 
long[]{mc.getRows()*mc.getCols()/_dim2, _dim2, mc.getNonZeros()};
+                                       else if( _dim1 >= 0 && _dim2 >= 0 )
+                                               ret = new long[]{_dim1, _dim2, 
-1};
                                }
                                break;
                        }
                        case SORT:
                        {
                                // input is a [k1,k2] matrix and output is a 
[k1,k3] matrix, where k3=k2 if no index return;
-                               // otherwise k3=1 (for the index vector)        
+                               // otherwise k3=1 (for the index vector)
                                Hop input4 = getInput().get(3); //indexreturn
                                boolean unknownIxRet = !(input4 instanceof 
LiteralOp);
                                
@@ -577,27 +580,27 @@ public class ReorgOp extends MultiThreadedHop
                                
                                // CASE b) DIAG_M2V
                                // input is [k,k] matrix and output is [k,1] 
matrix
-                               // #nnz in the output is likely to be k (a 
dense matrix)                
+                               // #nnz in the output is likely to be k (a 
dense matrix)
                                if( input1.getDim2()>1 ){
-                                       setDim2(1);     
+                                       setDim2(1);
                                        setNnz( (input1.getNnz()>=0) ? 
Math.min(k,input1.getNnz()) : k );
                                }
                                
-                               break;          
+                               break;
                        }
                        case RESHAPE:
                        {
                                // input is a [k1,k2] matrix and output is a 
[k3,k4] matrix with k1*k2=k3*k4
-                               // #nnz in output is exactly the same as in 
input               
+                               // #nnz in output is exactly the same as in 
input
                                Hop input2 = getInput().get(1); //rows 
                                Hop input3 = getInput().get(2); //cols 
                                refreshRowsParameterInformation(input2); 
//refresh rows
                                refreshColsParameterInformation(input3); 
//refresh cols
                                setNnz(input1.getNnz());
-                               if( !dimsKnown() &&input1.dimsKnown() ) { 
//reshape allows to infer dims, if input and 1 dim known
-                                       if(_dim1 >= 0) 
+                               if( !dimsKnown() && input1.dimsKnown() ) { 
//reshape allows to infer dims, if input and 1 dim known
+                                       if(_dim1 > 0) 
                                                _dim2 = 
(input1._dim1*input1._dim2)/_dim1;
-                                       else if(_dim2 >= 0)
+                                       else if(_dim2 > 0)
                                                _dim1 = 
(input1._dim1*input1._dim2)/_dim2; 
                                }
                                break;
@@ -681,5 +684,5 @@ public class ReorgOp extends MultiThreadedHop
                }
                
                return ret;
-       }       
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/51db735e/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
index 21f06ae..bcf7c46 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
@@ -31,6 +31,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
        private final static String TEST_NAME1 = "FunPotpourriNoReturn";
        private final static String TEST_NAME2 = "FunPotpourriComments";
        private final static String TEST_NAME3 = "FunPotpourriNoReturn2";
+       private final static String TEST_NAME4 = "FunPotpourriEval";
        
        private final static String TEST_DIR = "functions/misc/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FunctionPotpourriTest.class.getSimpleName() + "/";
@@ -41,6 +42,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
        }
 
        @Test
@@ -58,6 +60,11 @@ public class FunctionPotpourriTest extends AutomatedTestBase
                runFunctionTest( TEST_NAME3, false );
        }
        
+       @Test
+       public void testFunctionEval() {
+               runFunctionTest( TEST_NAME4, false );
+       }
+       
        private void runFunctionTest(String testName, boolean error) {
                TestConfiguration config = getTestConfiguration(testName);
                loadTestConfiguration(config);

http://git-wip-us.apache.org/repos/asf/systemml/blob/51db735e/src/test/scripts/functions/misc/FunPotpourriEval.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/FunPotpourriEval.dml 
b/src/test/scripts/functions/misc/FunPotpourriEval.dml
new file mode 100644
index 0000000..9af31fd
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunPotpourriEval.dml
@@ -0,0 +1,44 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(Matrix[Double] weights, Matrix[Double] X, Integer p, Integer P, 
Integer q, Integer Q, Integer s) return(matrix[double] grad) {
+  combined_weights = rbind (weights, matrix(2, p*P, 1))
+  res_A = matrix(1, rows=p+P, cols=1)
+  grad = matrix(0, rows=p+P, cols=1)
+  if (p > 0) grad[1:p,] = res_A[1:p,]
+  if (P > 0) grad[p+1:p+P,] = res_A[p+1:p+P,]
+  if (p>0 & P>0){
+    res_A = res_A[p+P+1:nrow(res_A),]
+    for(i in seq(1, p, 1)){
+      permut = matrix(0, rows=p, cols=P)
+      permut[i,] = t(combined_weights[p+1:p+P,])
+      grad[i,1] = grad[i,1] + sum(res_A * matrix(permut, rows=p*P, cols=1))
+    }
+    for(i in seq(1, P, 1)){
+      permut = matrix(0, rows=p, cols=P)
+      permut[,i] = combined_weights[1:p,]
+      grad[p+i,1] = grad[p+i,1] + sum(res_A * matrix(permut, rows=p*P, cols=1))
+    }
+  }
+}
+
+best_point = eval ("foo", matrix(1, 2, 1),matrix(0, 998, 3), 2, 0, 0, 0, 10)
+print(toString(best_point))

Reply via email to