Repository: systemml
Updated Branches:
  refs/heads/master 3c519e739 -> cad7c1e0f


[SYSTEMML-2353] Fix parfor optimizer block partitioning analysis

This patch fixes an issue in the parfor optimizer that caused a crash
during the block partitioning analysis for special cases of linear
indexing expressions with minus and non-existing variables (e.g.,
n-s*i2). Accordingly, we also add a new size propagation test with a
script similar to the failing scenario.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d1000138
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d1000138
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d1000138

Branch: refs/heads/master
Commit: d100013813f15d0e3bcc43b0905aa2a863677073
Parents: 3c519e7
Author: Matthias Boehm <[email protected]>
Authored: Thu May 31 17:39:27 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu May 31 21:05:10 2018 -0700

----------------------------------------------------------------------
 .../sysml/parser/ParForStatementBlock.java      | 113 +++++--------------
 .../functions/misc/SizePropagationTest.java     |  17 ++-
 .../functions/misc/SizePropagationLoopIx3.dml   |  35 ++++++
 3 files changed, 82 insertions(+), 83 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d1000138/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java 
b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
index 1de4bf3..113186f 100644
--- a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
@@ -1665,101 +1665,48 @@ public class ParForStatementBlock extends 
ForStatementBlock
         * @param be binary expression
         * @return linear function
         */
-       private LinearFunction rParseBinaryExpression(BinaryExpression be) 
-       {
-               LinearFunction ret = null;
+       private LinearFunction rParseBinaryExpression(BinaryExpression be) {
                Expression l = be.getLeft();
                Expression r = be.getRight();
-               
-               if( be.getOpCode() == BinaryOp.PLUS )
-               {                       
-                       //parse binary expressions
-                       if( l instanceof BinaryExpression)
-                       {
-                               ret = rParseBinaryExpression((BinaryExpression) 
l);
-                               Long cvalR = parseLongConstant(r);
-                               if( ret != null && cvalR != null )
-                                       ret.addConstant(cvalR);
-                               else 
-                                       return null;
-                       }
-                       else if (r instanceof BinaryExpression)
-                       {
-                               ret = rParseBinaryExpression((BinaryExpression) 
r);
-                               Long cvalL = parseLongConstant(l);
-                               if( ret != null && cvalL != null )
-                                       ret.addConstant(cvalL);
-                               else
-                                       return null;
-                       }
-                       else // atomic case
-                       {
-                               Long cvalL = parseLongConstant(l);
-                               Long cvalR = parseLongConstant(r);
-                               if( cvalL != null )
-                                       ret = new 
LinearFunction(cvalL,1,((DataIdentifier)r)._name);    
-                               else if( cvalR != null )
-                                       ret = new 
LinearFunction(cvalR,1,((DataIdentifier)l)._name);
-                               else
-                                       return null; //let dependency analysis 
fail
-                       }
-               }
-               else if( be.getOpCode() == BinaryOp.MINUS ) 
-               {
+               if( be.getOpCode() == BinaryOp.PLUS || be.getOpCode() == 
BinaryOp.MINUS ) {
+                       boolean plus = be.getOpCode() == BinaryOp.PLUS;
                        //parse binary expressions
                        if( l instanceof BinaryExpression) {
-                               ret = rParseBinaryExpression((BinaryExpression) 
l);
-                               if( ret != null ) //change to plus
-                                       
ret.addConstant(parseLongConstant(r)*(-1));
+                               LinearFunction f = 
rParseBinaryExpression((BinaryExpression) l);
+                               Long cvalR = parseLongConstant(r);
+                               if( f != null && cvalR != null )
+                                       return f.addConstant(cvalR * 
(plus?1:-1));
                        }
                        else if (r instanceof BinaryExpression) {
-                               ret = rParseBinaryExpression((BinaryExpression) 
r);
-                               if( ret != null ) { //change to plus
-                                       ret._a*=(-1);
-                                       for( int i=0; i<ret._b.length; i++ )
-                                               ret._b[i]*=(-1);
-                                       Long cvalL = parseLongConstant(l);
-                                       ret.addConstant(cvalL);
-                               }
+                               LinearFunction f = 
rParseBinaryExpression((BinaryExpression) r);
+                               Long cvalL = parseLongConstant(l);
+                               if( f != null && cvalL != null )
+                                       return 
f.scale(plus?1:-1).addConstant(cvalL);
                        }
                        else { // atomic case
-                               //change everything to plus
+                               //change everything to plus if necessary
                                Long cvalL = parseLongConstant(l);
                                Long cvalR = parseLongConstant(r);
                                if( cvalL != null )
-                                       ret = new 
LinearFunction(cvalL,-1,((DataIdentifier)r)._name);
+                                       return new 
LinearFunction(cvalL,plus?1:-1,((DataIdentifier)r)._name);
                                else if( cvalR != null )
-                                       ret = new 
LinearFunction(cvalR*(-1),1,((DataIdentifier)l)._name);
-                               else
-                                       return null; //let dependency analysis 
fail
+                                       return new 
LinearFunction(cvalR*(plus?1:-1),1,((DataIdentifier)l)._name);
                        }
                }
                else if( be.getOpCode() == BinaryOp.MULT ) {
-                       //NOTE: only recursion for MULT expressions, where one 
side is a constant 
-                       
-                       //atomic case
+                       //atomic case (only recursion for MULT expressions, 
where one side is a constant)
                        Long cvalL = parseLongConstant(l);
                        Long cvalR = parseLongConstant(r);
-                       
                        if( cvalL != null && r instanceof DataIdentifier )
-                               ret = new LinearFunction(0, 
cvalL,((DataIdentifier)r)._name);
+                               return new LinearFunction(0, 
cvalL,((DataIdentifier)r)._name);
                        else if( cvalR != null && l instanceof DataIdentifier )
-                               ret = new LinearFunction(0, 
cvalR,((DataIdentifier)l)._name);
-                       else if( cvalL != null && r instanceof BinaryExpression 
) {
-                               LinearFunction ltmp = 
rParseBinaryExpression((BinaryExpression)r);
-                               return ltmp.scale(cvalL);
-                       }
-                       else if( cvalR != null && l instanceof BinaryExpression 
) {
-                               LinearFunction ltmp = 
rParseBinaryExpression((BinaryExpression)l);
-                               return ltmp.scale(cvalR);
-                       }
-                       else
-                               return null; //let dependency analysis fail
+                               return new LinearFunction(0, 
cvalR,((DataIdentifier)l)._name);
+                       else if( cvalL != null && r instanceof BinaryExpression 
)
+                               return 
rParseBinaryExpression((BinaryExpression)r).scale(cvalL);
+                       else if( cvalR != null && l instanceof BinaryExpression 
)
+                               return 
rParseBinaryExpression((BinaryExpression)l).scale(cvalR);
                }
-               else
-                       return null; //let dependency analysis fail
-                       
-               return ret;
+               return null; //let dependency analysis fail
        }
 
        private static Long parseLongConstant(Expression expr)
@@ -1837,10 +1784,9 @@ public class ParForStatementBlock extends 
ForStatementBlock
         * the applied GCD and Banerjee tests.
         *
         */
-       private class LinearFunction
-       {
-               long     _a;     // intercept
-               long[]   _b;     // slopes 
+       private class LinearFunction {
+               long _a;        // intercept
+               long[] _b;      // slopes 
                String[] _vars; // b variable names
                
                LinearFunction( long a, long b, String name ) {
@@ -1851,11 +1797,12 @@ public class ParForStatementBlock extends 
ForStatementBlock
                        _vars[0] = name;
                }
                
-               public void addConstant(long value) {
+               public LinearFunction addConstant(long value) {
                        _a += value;
+                       return this;
                }
 
-               public void addFunction( LinearFunction f2) {
+               public LinearFunction addFunction( LinearFunction f2) {
                        _a = _a + f2._a;
                        long[] tmpb = new long[_b.length+f2._b.length];
                        System.arraycopy( _b,    0, tmpb, 0,         _b.length  
  );
@@ -1865,9 +1812,10 @@ public class ParForStatementBlock extends 
ForStatementBlock
                        System.arraycopy( _vars,    0, tmpvars, 0,            
_vars.length    );
                        System.arraycopy( f2._vars, 0, tmpvars, _vars.length, 
f2._vars.length );
                        _vars = tmpvars;
+                       return this;
                }
 
-               public void removeVar( int i ) {
+               public LinearFunction removeVar( int i ) {
                        long[] tmpb = new long[_b.length-1];
                        System.arraycopy( _b, 0, tmpb, 0, i );
                        System.arraycopy( _b, i+1, tmpb, i, _b.length-i-1 );
@@ -1876,6 +1824,7 @@ public class ParForStatementBlock extends 
ForStatementBlock
                        System.arraycopy( _vars, 0, tmpvars, 0, i );
                        System.arraycopy( _vars, i+1, tmpvars, i, 
_vars.length-i-1 );
                        _vars = tmpvars;
+                       return this;
                }
                
                public LinearFunction scale( long scale ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d1000138/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java
index 4a8ed98..a2714f0 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/SizePropagationTest.java
@@ -25,6 +25,7 @@ import org.junit.Assert;
 
 import java.util.HashMap;
 
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
@@ -37,6 +38,7 @@ public class SizePropagationTest extends AutomatedTestBase
        private static final String TEST_NAME1 = "SizePropagationRBind";
        private static final String TEST_NAME2 = "SizePropagationLoopIx1";
        private static final String TEST_NAME3 = "SizePropagationLoopIx2";
+       private static final String TEST_NAME4 = "SizePropagationLoopIx3";
        
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
SizePropagationTest.class.getSimpleName() + "/";
@@ -49,6 +51,7 @@ public class SizePropagationTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
        }
 
        @Test
@@ -81,6 +84,16 @@ public class SizePropagationTest extends AutomatedTestBase
                testSizePropagation( TEST_NAME3, true, N-2 );
        }
        
+       @Test
+       public void testSizePropagationLoopIx3NoRewrites() {
+               testSizePropagation( TEST_NAME4, false, N-1 );
+       }
+       
+       @Test
+       public void testSizePropagationLoopIx3Rewrites() {
+               testSizePropagation( TEST_NAME4, true, N-1 );
+       }
+       
        private void testSizePropagation( String testname, boolean rewrites, 
int expect ) {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                RUNTIME_PLATFORM oldPlatform = rtplatform;
@@ -93,7 +106,8 @@ public class SizePropagationTest extends AutomatedTestBase
                        fullDMLScriptName = HOME + testname + ".dml";
                        programArgs = new String[]{ "-explain", "hops", 
"-stats","-args", String.valueOf(N), output("R") };
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
-                       rtplatform = RUNTIME_PLATFORM.SINGLE_NODE;
+                       rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                        
                        runTest(true, false, null, -1); 
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("R");
@@ -101,6 +115,7 @@ public class SizePropagationTest extends AutomatedTestBase
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = false;
                        rtplatform = oldPlatform;
                }
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d1000138/src/test/scripts/functions/misc/SizePropagationLoopIx3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/SizePropagationLoopIx3.dml 
b/src/test/scripts/functions/misc/SizePropagationLoopIx3.dml
new file mode 100644
index 0000000..63ac539
--- /dev/null
+++ b/src/test/scripts/functions/misc/SizePropagationLoopIx3.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+P = 2
+s = 25
+Y = rand(rows=$1, cols=1, min=1, max=1)
+for(i in seq(1,1,1)) {
+  n1 = nrow(Y)+0.0
+  Y = Y[2:n1,] - Y[1:n1-1,]
+}
+n = nrow(Y)
+Z = rand(rows=n, cols=P, min=0, max=0)
+parfor(i2 in seq(1, P, 1)){
+  Z[s*i2+1:n,i2] = Y[1:n-s*i2,]
+}
+R = as.matrix(nrow(Z));
+write(R, $2);

Reply via email to