[SYSTEMML-2385] New simplification rewrites for comparison chains

This patch introduces new simplification rewrites for chains of binary
comparisons, where a comparison result is again compared against the
literal 0 or 1, e.g., outer(v1,v2,">") == 0 --> outer(v1,v2,"<=") and
outer(v1,v2,">") == 1 --> outer(v1,v2,">"). These rewrites are especially
useful in combination with the uaggouterchain rewrite. That rewrite fuses
rowSums or rowIndexMax with the outer operation for better asymptotic
behavior, but only applies to row aggregates directly over outer
comparison operations.

Furthermore, this patch updates the expected numbers of distributed jobs
in the predicate recompilation tests. After the recent cleanup of
constant folding, these tests produce fewer jobs.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1e5984cc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1e5984cc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1e5984cc

Branch: refs/heads/master
Commit: 1e5984cca10132603af7c638f8bd4ec6139b7061
Parents: 303a2d3
Author: Matthias Boehm <[email protected]>
Authored: Thu Jun 14 19:16:52 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Jun 14 19:46:37 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/BinaryOp.java    |  22 ++-
 .../sysml/hops/rewrite/HopRewriteUtils.java     |   9 ++
 .../RewriteAlgebraicSimplificationStatic.java   |  35 ++++-
 .../misc/RewriteRemoveComparisonChainsTest.java | 106 ++++++++++++++
 .../recompile/PredicateRecompileTest.java       | 140 +++++++------------
 .../functions/misc/RewriteComparisons.dml       |  29 ++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 7 files changed, 244 insertions(+), 98 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java 
b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index d66ac12..1a65130 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -1595,10 +1595,40 @@ public class BinaryOp extends Hop
                                ||op==OpOp2.BITWSHIFTL ||op==OpOp2.BITWSHIFTR);
        }
        
-       public boolean isPPredOperation()
-       {
-               return (   op==OpOp2.LESS    ||op==OpOp2.LESSEQUAL
-                        ||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL
-                        ||op==OpOp2.EQUAL   ||op==OpOp2.NOTEQUAL);
+       /**
+        * Indicates if this binary operation is a comparison (ppred) operation,
+        * i.e., one of LESS, LESSEQUAL, GREATER, GREATEREQUAL, EQUAL, NOTEQUAL.
+        * 
+        * @return true if the operation is a comparison operation
+        */
+       public boolean isPPredOperation() {
+               return (op==OpOp2.LESS    ||op==OpOp2.LESSEQUAL
+                       ||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL
+                       ||op==OpOp2.EQUAL   ||op==OpOp2.NOTEQUAL);
+       }
+       
+       /**
+        * Returns the complementary comparison operation, i.e., the comparison
+        * that evaluates to true exactly where this comparison evaluates to false
+        * (e.g., GREATER vs. LESSEQUAL, EQUAL vs. NOTEQUAL).
+        * 
+        * NOTE: for LESS, LESSEQUAL, GREATER, and GREATEREQUAL this only holds for
+        * non-NaN inputs (assuming IEEE-754 comparison semantics, every ordered
+        * comparison with NaN is false); EQUAL and NOTEQUAL remain complements.
+        * 
+        * @return the complementary comparison operation
+        * @throws HopsException if this operation is not a comparison operation
+        */
+       public OpOp2 getComplementPPredOperation() {
+               switch( op ) {
+                       case LESS:         return OpOp2.GREATEREQUAL;
+                       case LESSEQUAL:    return OpOp2.GREATER;
+                       case GREATER:      return OpOp2.LESSEQUAL;
+                       case GREATEREQUAL: return OpOp2.LESS;
+                       case EQUAL:        return OpOp2.NOTEQUAL;
+                       case NOTEQUAL:     return OpOp2.EQUAL;
+                       default:
+                               throw new HopsException("BinaryOp is not a ppred operation: "+op);
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 269e9e6..9765fc8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -163,6 +164,20 @@ public class HopRewriteUtils
                }
        }
        
+       /**
+        * Indicates if the given hop is a numeric literal whose value equals any
+        * of the given values (see {@link #isLiteralOfValue(Hop, double)}).
+        * 
+        * @param hop high-level operator
+        * @param val candidate literal values
+        * @return true if hop is a literal of one of the given values
+        */
+       public static boolean isLiteralOfValue( Hop hop, Double... val ) {
+               //explicit unboxing dispatches to isLiteralOfValue(Hop, double)
+               //rather than suggesting a recursive call of this varargs method
+               return Arrays.stream(val).anyMatch(d -> isLiteralOfValue(hop, d.doubleValue()));
+       }
+       
        public static boolean isLiteralOfValue( Hop hop, double val ) {
                return (hop instanceof LiteralOp 
                        && (hop.getValueType()==ValueType.DOUBLE || 
hop.getValueType()==ValueType.INT)
@@ -914,6 +919,18 @@ public class HopRewriteUtils
                return isBinary(hop, type) && hop.getParent().size() <= maxParents;
        }
        
+       /**
+        * Indicates if the given hop is a binary comparison (ppred) operation,
+        * see {@link BinaryOp#isPPredOperation()}.
+        * 
+        * @param hop high-level operator
+        * @return true if hop is a BinaryOp with a comparison operation
+        */
+       public static boolean isBinaryPPred(Hop hop) {
+               return hop instanceof BinaryOp
+                       && ((BinaryOp) hop).isPPredOperation();
+       }
+       
        public static boolean isBinarySparseSafe(Hop hop) {
                if( !(hop instanceof BinaryOp) )
                        return false;

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 8f9aad9..4396c7b 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -175,13 +175,15 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyTransposeAggBinBinaryChains(hop, hi, 
i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
                        hi = simplifyReplaceZeroOperation(hop, hi, i);       
//e.g., X + (X==0) * s -> replace(X, 0, s)
                        hi = removeUnnecessaryMinus(hop, hi, i);             
//e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
-                       hi = simplifyGroupedAggregate(hi);                   
//e.g., aggregate(target=X,groups=y,fn="count") -> 
aggregate(target=y,groups=y,fn="count")
+                       hi = simplifyGroupedAggregate(hi);                   
//e.g., aggregate(target=X,groups=y,fn="count") -> 
aggregate(target=y,groups=y,fn="count")
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
                                hi = fuseMinusNzBinaryOperation(hop, hi, i);    
     //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
                                hi = fuseLogNzUnaryOperation(hop, hi, i);       
     //e.g., ppred(X,0,"!=")*log(X) -> log_nz(X)
                                hi = fuseLogNzBinaryOperation(hop, hi, i);      
     //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
                        }
                        hi = simplifyOuterSeqExpand(hop, hi, i);             
//e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, 
cast=false)
+                       hi = simplifyBinaryComparisonChain(hop, hi, i);      
//e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> 
outer(v1,v2,"!="), 
+                       
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
@@ -1781,7 +1783,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        {
                                //determine variable parameters for pattern a/b
                                boolean isPatternB = 
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0));
-                               boolean isTransposeRight = 
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1));                     
     
+                               boolean isTransposeRight = 
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1));
                                Hop trgt = isPatternB ? (isTransposeRight ? 
                                                
hi.getInput().get(1).getInput().get(0) :                  //get v from t(v)
                                                
HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v')
@@ -1813,6 +1815,61 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                return hi;
        }
        
+       /**
+        * Removes unnecessary comparison chains, where the result of a comparison
+        * (ppred) operation X is again compared against the literal 0 or 1. Since X
+        * only contains 0/1 values, such an outer == or != is either an identity
+        * or a complement of X:
+        * (a) {@code X==1, X!=0 -> X}, e.g., {@code outer(v1,v2,"!=")==1 -> outer(v1,v2,"!=")}
+        * (b) {@code X==0, X!=1 -> complement(X)}, e.g., {@code outer(v1,v2,">")==0 -> outer(v1,v2,"<=")}
+        * Other outer comparisons (e.g., {@code X>0, X<1, X>=0}) are neither, and
+        * are hence left unchanged.
+        * 
+        * NOTE: for the ordered inner comparisons (LESS, LESSEQUAL, GREATER,
+        * GREATEREQUAL), pattern (b) assumes non-NaN inputs: assuming IEEE-754
+        * comparison semantics, {@code (NaN>1)==0} is 1 whereas {@code NaN<=1} is 0.
+        * 
+        * @param parent parent hop of hi
+        * @param hi candidate outer comparison
+        * @param pos position of hi in the inputs of parent
+        * @return the (potentially rewritten) hop at position pos
+        */
+       private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi, int pos) {
+               if( HopRewriteUtils.isBinaryPPred(hi)
+                       && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d)
+                       && HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) ) {
+                       BinaryOp bop = (BinaryOp) hi;
+                       BinaryOp bop2 = (BinaryOp) hi.getInput().get(0);
+                       //only == and != are identity/complement of the 0/1 inner result
+                       boolean eq = HopRewriteUtils.isBinary(bop, OpOp2.EQUAL);
+                       boolean neq = HopRewriteUtils.isBinary(bop, OpOp2.NOTEQUAL);
+                       boolean one = HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 1);
+                       
+                       //pattern: outer(v1,v2,"!=") == 1 -> outer(v1,v2,"!=")
+                       //    and: outer(v1,v2,"!=") != 0 -> outer(v1,v2,"!=")
+                       if( (eq && one) || (neq && !one) ) {
+                               HopRewriteUtils.replaceChildReference(parent, bop, bop2, pos);
+                               HopRewriteUtils.cleanupUnreferenced(bop);
+                               hi = bop2;
+                               LOG.debug("Applied simplifyBinaryComparisonChain1 (line "+hi.getBeginLine()+")");
+                       }
+                       //pattern: outer(v1,v2,"!=") == 0 -> outer(v1,v2,"==")
+                       //    and: outer(v1,v2,"!=") != 1 -> outer(v1,v2,"==")
+                       else if( (eq && !one) || (neq && one) ) {
+                               OpOp2 optr = bop2.getComplementPPredOperation();
+                               BinaryOp tmp = HopRewriteUtils.createBinary(bop2.getInput().get(0),
+                                       bop2.getInput().get(1), optr, bop2.isOuterVectorOperator());
+                               HopRewriteUtils.replaceChildReference(parent, bop, tmp, pos);
+                               HopRewriteUtils.cleanupUnreferenced(bop, bop2);
+                               hi = tmp;
+                               LOG.debug("Applied simplifyBinaryComparisonChain0 (line "+hi.getBeginLine()+")");
+                       }
+                       //other comparisons against 0/1 (e.g., X > 0, X < 1) are not rewritten
+               }
+               
+               return hi;
+       }
+       
        /**
         * NOTE: currently disabled since this rewrite is INVALID in the
         * presence of NaNs (because (NaN!=NaN) is true). 

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java
new file mode 100644
index 0000000..43fd4f9
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+/** Tests the comparison-chain rewrites on rowIndexMax(outer(A, B, op) == compare) by checking for the fused uaggouterchain operator. */
+public class RewriteRemoveComparisonChainsTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME1 = "RewriteComparisons"; 
+       //test cases (outer op, compared literal): a) >, == 0; b) <=, == 1; c) ==, == 0; d) !=, == 1
+       //with rewrites, == 1 is removed and == 0 is replaced by the complemented outer comparison
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
RewriteRemoveComparisonChainsTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testComparisonGt0() {
+               runComparisonChainTest( ">", 0, false );
+       }
+       
+       @Test
+       public void testComparisonGt0Rewrites() {
+               runComparisonChainTest( ">", 0, true );
+       }
+       
+       @Test
+       public void testComparisonLte1() {
+               runComparisonChainTest( "<=", 1, false );
+       }
+       
+       @Test
+       public void testComparisonLte1Rewrites() {
+               runComparisonChainTest( "<=", 1, true );
+       }
+       
+       @Test
+       public void testComparisonEq0() {
+               runComparisonChainTest( "==", 0, false );
+       }
+       
+       @Test
+       public void testComparisonEq0Rewrites() {
+               runComparisonChainTest( "==", 0, true );
+       }
+       
+       @Test
+       public void testComparisonNeq1() {
+               runComparisonChainTest( "!=", 1, false );
+       }
+       
+       @Test
+       public void testComparisonNeq1Rewrites() {
+               runComparisonChainTest( "!=", 1, true );
+       }
+       /** Runs RewriteComparisons.dml with outer comparison op and literal compare, with or without algebraic simplification rewrites. */
+       private void runComparisonChainTest( String op, int compare, boolean 
rewrites )
+       {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               //the global rewrite flag is restored in the finally block below
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME1);
+                       loadTestConfiguration(config);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+                       programArgs = new String[]{"-stats","-args", op, 
String.valueOf(compare)};
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+                       runTest(true, false, null, -1);
+                       
+                       //check for applied rewrites: uaggouterchain iff rewrites enabled, and no remaining == for compare 1
+                       Assert.assertEquals(rewrites, 
heavyHittersContainsString("uaggouterchain"));
+                       if( compare == 1 && rewrites )
+                               
Assert.assertTrue(!heavyHittersContainsString("=="));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java
index 29660aa..9f8c82f 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java
@@ -32,7 +32,6 @@ import org.apache.sysml.utils.Statistics;
 
 public class PredicateRecompileTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME1 = "while_recompile";
        private final static String TEST_NAME2 = "if_recompile";
        private final static String TEST_NAME3 = "for_recompile";
@@ -41,9 +40,8 @@ public class PredicateRecompileTest extends AutomatedTestBase
        private final static String TEST_CLASS_DIR = TEST_DIR + 
PredicateRecompileTest.class.getSimpleName() + "/";
        
        private final static int rows = 10;
-       private final static int cols = 15;    
-       private final static int val = 7;    
-       
+       private final static int cols = 15;
+       private final static int val = 7;
        
        @Override
        public void setUp() 
@@ -59,225 +57,188 @@ public class PredicateRecompileTest extends 
AutomatedTestBase
        }
 
        @Test
-       public void testWhileRecompile() 
-       {
+       public void testWhileRecompile() {
                runRecompileTest(TEST_NAME1, true, false, false, false);
        }
        
        @Test
-       public void testWhileNoRecompile() 
-       {
+       public void testWhileNoRecompile() {
                runRecompileTest(TEST_NAME1, false, false, false, false);
        }
        
        @Test
-       public void testIfRecompile() 
-       {
+       public void testIfRecompile() {
                runRecompileTest(TEST_NAME2, true, false, false, false);
        }
        
        @Test
-       public void testIfNoRecompile() 
-       {
+       public void testIfNoRecompile() {
                runRecompileTest(TEST_NAME2, false, false, false, false);
        }
        
        @Test
-       public void testForRecompile() 
-       {
+       public void testForRecompile() {
                runRecompileTest(TEST_NAME3, true, false, false, false);
        }
        
        @Test
-       public void testForNoRecompile() 
-       {
+       public void testForNoRecompile() {
                runRecompileTest(TEST_NAME3, false, false, false, false);
        }
        
        @Test
-       public void testParForRecompile() 
-       {
+       public void testParForRecompile() {
                runRecompileTest(TEST_NAME4, true, false, false, false);
        }
        
        @Test
-       public void testParForNoRecompile() 
-       {
+       public void testParForNoRecompile() {
                runRecompileTest(TEST_NAME4, false, false, false, false);
        }
 
        @Test
-       public void testWhileRecompileExprEval() 
-       {
+       public void testWhileRecompileExprEval() {
                runRecompileTest(TEST_NAME1, true, true, false, false);
        }
        
        @Test
-       public void testWhileNoRecompileExprEval() 
-       {
+       public void testWhileNoRecompileExprEval() {
                runRecompileTest(TEST_NAME1, false, true, false, false);
        }
        
        @Test
-       public void testIfRecompileExprEval() 
-       {
+       public void testIfRecompileExprEval() {
                runRecompileTest(TEST_NAME2, true, true, false, false);
        }
        
        @Test
-       public void testIfNoRecompileExprEval() 
-       {
+       public void testIfNoRecompileExprEval() {
                runRecompileTest(TEST_NAME2, false, true, false, false);
        }
        
        @Test
-       public void testForRecompileExprEval() 
-       {
+       public void testForRecompileExprEval() {
                runRecompileTest(TEST_NAME3, true, true, false, false);
        }
        
        @Test
-       public void testForNoRecompileExprEval() 
-       {
+       public void testForNoRecompileExprEval() {
                runRecompileTest(TEST_NAME3, false, true, false, false);
        }
        
        @Test
-       public void testParForRecompileExprEval() 
-       {
+       public void testParForRecompileExprEval() {
                runRecompileTest(TEST_NAME4, true, true, false, false);
        }
        
        @Test
-       public void testParForNoRecompileExprEval() 
-       {
+       public void testParForNoRecompileExprEval() {
                runRecompileTest(TEST_NAME4, false, true, false, false);
        }
 
        @Test
-       public void testWhileRecompileConstFold() 
-       {
+       public void testWhileRecompileConstFold() {
                runRecompileTest(TEST_NAME1, true, false, true, false);
        }
        
        @Test
-       public void testWhileNoRecompileConstFold() 
-       {
+       public void testWhileNoRecompileConstFold() {
                runRecompileTest(TEST_NAME1, false, false, true, false);
        }
        
        @Test
-       public void testIfRecompileConstFold() 
-       {
+       public void testIfRecompileConstFold() {
                runRecompileTest(TEST_NAME2, true, false, true, false);
        }
        
        @Test
-       public void testIfNoRecompileConstFold() 
-       {
+       public void testIfNoRecompileConstFold() {
                runRecompileTest(TEST_NAME2, false, false, true, false);
        }
        
        @Test
-       public void testForRecompileConstFold() 
-       {
+       public void testForRecompileConstFold() {
                runRecompileTest(TEST_NAME3, true, false, true, false);
        }
        
        @Test
-       public void testForNoRecompileConstFold() 
-       {
+       public void testForNoRecompileConstFold() {
                runRecompileTest(TEST_NAME3, false, false, true, false);
        }
        
        @Test
-       public void testParForRecompileConstFold() 
-       {
+       public void testParForRecompileConstFold() {
                runRecompileTest(TEST_NAME4, true, false, true, false);
        }
        
        @Test
-       public void testParForNoRecompileConstFold() 
-       {
+       public void testParForNoRecompileConstFold() {
                runRecompileTest(TEST_NAME4, false, false, true, false);
        }
 
        @Test
-       public void testWhileNoRecompileIPA() 
-       {
+       public void testWhileNoRecompileIPA() {
                runRecompileTest(TEST_NAME1, false, false, false, true);
        }
        
        @Test
-       public void testIfNoRecompileIPA() 
-       {
+       public void testIfNoRecompileIPA() {
                runRecompileTest(TEST_NAME2, false, false, false, true);
        }
 
        @Test
-       public void testForNoRecompileIPA() 
-       {
+       public void testForNoRecompileIPA() {
                runRecompileTest(TEST_NAME3, false, false, false, true);
        }
        
        @Test
-       public void testParForNoRecompileIPA() 
-       {
+       public void testParForNoRecompileIPA() {
                runRecompileTest(TEST_NAME4, false, false, false, true);
        }
        
        @Test
-       public void testWhileNoRecompileExprEvalIPA() 
-       {
+       public void testWhileNoRecompileExprEvalIPA() {
                runRecompileTest(TEST_NAME1, false, true, false, true);
        }
 
        @Test
-       public void testIfNoRecompileExprEvalIPA() 
-       {
+       public void testIfNoRecompileExprEvalIPA() {
                runRecompileTest(TEST_NAME2, false, true, false, true);
        }
        
        @Test
-       public void testForNoRecompileExprEvalIPA() 
-       {
+       public void testForNoRecompileExprEvalIPA() {
                runRecompileTest(TEST_NAME3, false, true, false, true);
        }
        
        @Test
-       public void testParForNoRecompileExprEvalIPA() 
-       {
+       public void testParForNoRecompileExprEvalIPA() {
                runRecompileTest(TEST_NAME4, false, true, false, true);
        }
 
        @Test
-       public void testWhileNoRecompileConstFoldIPA() 
-       {
+       public void testWhileNoRecompileConstFoldIPA() {
                runRecompileTest(TEST_NAME1, false, false, true, true);
        }
 
        @Test
-       public void testIfNoRecompileConstFoldIPA() 
-       {
+       public void testIfNoRecompileConstFoldIPA() {
                runRecompileTest(TEST_NAME2, false, false, true, true);
        }
 
        
        @Test
-       public void testForNoRecompileConstFoldIPA() 
-       {
+       public void testForNoRecompileConstFoldIPA() {
                runRecompileTest(TEST_NAME3, false, false, true, true);
        }
        
        @Test
-       public void testParForNoRecompileConstFoldIPA() 
-       {
+       public void testParForNoRecompileConstFoldIPA() {
                runRecompileTest(TEST_NAME4, false, false, true, true);
        }
        
-       
        private void runRecompileTest( String testname, boolean recompile, 
boolean evalExpr, boolean constFold, boolean IPA )
-       {       
+       {
                boolean oldFlagRecompile = CompilerConfig.FLAG_DYN_RECOMPILE;
                boolean oldFlagEval = 
OptimizerUtils.ALLOW_SIZE_EXPRESSION_EVALUATION;
                boolean oldFlagFold = OptimizerUtils.ALLOW_CONSTANT_FOLDING;
@@ -295,7 +256,7 @@ public class PredicateRecompileTest extends 
AutomatedTestBase
                        /* This is for running the junit test the new way, 
i.e., construct the arguments directly */
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{"-args",
+                       programArgs = new String[]{"-explain","-args",
                                Integer.toString(rows),
                                Integer.toString(cols),
                                Integer.toString(val),
@@ -312,35 +273,32 @@ public class PredicateRecompileTest extends 
AutomatedTestBase
                        
OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION = false;
                        
                        boolean exceptionExpected = false;
-                       runTest(true, exceptionExpected, null, -1); 
+                       runTest(true, exceptionExpected, null, -1);
                        
                        //check expected number of compiled and executed MR jobs
-                       if( recompile )
-                       {
+                       if( recompile ) {
                                Assert.assertEquals("Unexpected number of 
executed MR jobs.", 
-                                                 1 - ((evalExpr || 
constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand        
+                                       1 - ((evalExpr || constFold)?1:0), 
Statistics.getNoOfExecutedMRJobs()); //rand
                        }
                        else
                        {
-                               if( IPA )
-                               {
+                               if( IPA ) {
                                        //old expected numbers before IPA
                                        if( testname.equals(TEST_NAME1) )
                                                Assert.assertEquals("Unexpected 
number of executed MR jobs.", 
-                                                   4 - 
((evalExpr||constFold)?4:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr 
while pred, 1x gmr while body                           
+                                                       4 - 
((evalExpr||constFold)?4:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr 
while pred, 1x gmr while body
                                        else //if( testname.equals(TEST_NAME2) )
                                                Assert.assertEquals("Unexpected 
number of executed MR jobs.", 
-                                                   3 - 
((evalExpr||constFold)?3:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr 
if pred, 1x gmr if body 
+                                                       3 - 
((evalExpr||constFold)?3:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr 
if pred, 1x gmr if body
                                }
-                               else
-                               {
+                               else {
                                        //old expected numbers before IPA
                                        if( testname.equals(TEST_NAME1) )
                                                Assert.assertEquals("Unexpected 
number of executed MR jobs.", 
-                                                   4 - ((evalExpr)?1:0), 
Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr while pred, 1x gmr while 
body                              
+                                                       4 - 
((evalExpr||constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr 
while pred, 1x gmr while body
                                        else //if( testname.equals(TEST_NAME2) )
                                                Assert.assertEquals("Unexpected 
number of executed MR jobs.", 
-                                                   3 - ((evalExpr)?1:0), 
Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr if pred, 1x gmr if body
+                                                       3 - 
((evalExpr||constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr 
if pred, 1x gmr if body
                                }
                        }
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test/scripts/functions/misc/RewriteComparisons.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteComparisons.dml 
b/src/test/scripts/functions/misc/RewriteComparisons.dml
new file mode 100644
index 0000000..d84149a
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteComparisons.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Comparison chain over an outer comparison: $1 is the outer operator (e.g., ">"), $2 the literal it is compared against (0 or 1 in the tests)
+A = seq(1,100);
+B = t(seq(5,15));
+while(FALSE){}
+# rowIndexMax is fused with outer (uaggouterchain) only if the == $2 is rewritten away
+X = rowIndexMax(outer(A, B, $1) == $2)
+# NOTE(review): the while(FALSE){} statements presumably act as statement block cuts (e.g., to prevent constant folding); confirm
+while(FALSE){}
+print(sum(X))

http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 8b08155..0b73edb 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -75,6 +75,7 @@ import org.junit.runners.Suite;
        RewritePushdownSumBinaryMult.class,
        RewritePushdownSumOnBinaryTest.class,
        RewritePushdownUaggTest.class,
+       RewriteRemoveComparisonChainsTest.class,
        RewriteSimplifyRowColSumMVMultTest.class,
        RewriteSlicedMatrixMultTest.class,
        ScalarAssignmentTest.class,

Reply via email to