[SYSTEMML-2385] New simplification rewrites for comparison chains. This patch introduces new rewrites for binary comparison chains such as outer(v1,v2,">") == 0 --> outer(v1,v2,"<="). They are especially useful in combination with other rewrites such as uaggouterchain, which fuses rowSums or rowIndexMax with the outer operation; this fused operator has better asymptotic behavior but only applies to row aggregates directly over outer comparison operations.
Furthermore, this also includes a fix for the recompilation tests, which after the recent cleanup of constant folding produce fewer distributed jobs. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1e5984cc Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1e5984cc Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1e5984cc Branch: refs/heads/master Commit: 1e5984cca10132603af7c638f8bd4ec6139b7061 Parents: 303a2d3 Author: Matthias Boehm <[email protected]> Authored: Thu Jun 14 19:16:52 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Jun 14 19:46:37 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/BinaryOp.java | 22 ++- .../sysml/hops/rewrite/HopRewriteUtils.java | 9 ++ .../RewriteAlgebraicSimplificationStatic.java | 35 ++++- .../misc/RewriteRemoveComparisonChainsTest.java | 106 ++++++++++++++ .../recompile/PredicateRecompileTest.java | 140 +++++++------------ .../functions/misc/RewriteComparisons.dml | 29 ++++ .../functions/misc/ZPackageSuite.java | 1 + 7 files changed, 244 insertions(+), 98 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/main/java/org/apache/sysml/hops/BinaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java index d66ac12..1a65130 100644 --- a/src/main/java/org/apache/sysml/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java @@ -1595,10 +1595,22 @@ public class BinaryOp extends Hop ||op==OpOp2.BITWSHIFTL ||op==OpOp2.BITWSHIFTR); } - public boolean isPPredOperation() - { - return ( op==OpOp2.LESS ||op==OpOp2.LESSEQUAL - ||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL - ||op==OpOp2.EQUAL ||op==OpOp2.NOTEQUAL); 
+ public boolean isPPredOperation() { + return (op==OpOp2.LESS ||op==OpOp2.LESSEQUAL + ||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL + ||op==OpOp2.EQUAL ||op==OpOp2.NOTEQUAL); + } + + public OpOp2 getComplementPPredOperation() { + switch( op ) { + case LESS: return OpOp2.GREATEREQUAL; + case LESSEQUAL: return OpOp2.GREATER; + case GREATER: return OpOp2.LESSEQUAL; + case GREATEREQUAL: return OpOp2.LESS; + case EQUAL: return OpOp2.NOTEQUAL; + case NOTEQUAL: return OpOp2.EQUAL; + default: + throw new HopsException("BinaryOp is not a ppred operation."); + } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 269e9e6..9765fc8 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -20,6 +20,7 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; @@ -163,6 +164,10 @@ public class HopRewriteUtils } } + public static boolean isLiteralOfValue( Hop hop, Double... 
val ) { + return Arrays.stream(val).anyMatch(d -> isLiteralOfValue(hop, d)); + } + public static boolean isLiteralOfValue( Hop hop, double val ) { return (hop instanceof LiteralOp && (hop.getValueType()==ValueType.DOUBLE || hop.getValueType()==ValueType.INT) @@ -914,6 +919,10 @@ public class HopRewriteUtils return isBinary(hop, type) && hop.getParent().size() <= maxParents; } + public static boolean isBinaryPPred(Hop hop) { + return hop instanceof BinaryOp && ((BinaryOp) hop).isPPredOperation(); + } + public static boolean isBinarySparseSafe(Hop hop) { if( !(hop instanceof BinaryOp) ) return false; http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 8f9aad9..4396c7b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -175,13 +175,15 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C) hi = simplifyReplaceZeroOperation(hop, hi, i); //e.g., X + (X==0) * s -> replace(X, 0, s) hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites - hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count") + hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count") if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X 
-nz mean hi = fuseLogNzUnaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X) -> log_nz(X) hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) } hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) + hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), + //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -1781,7 +1783,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { //determine variable parameters for pattern a/b boolean isPatternB = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)); - boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)); + boolean isTransposeRight = HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)); Hop trgt = isPatternB ? (isTransposeRight ? 
hi.getInput().get(1).getInput().get(0) : //get v from t(v) HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v') @@ -1813,6 +1815,35 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi, int pos) { + if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) //only X==0 or X==1 over a ppred X (not e.g. X!=0 or X>1) + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d) + && HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) ) { + BinaryOp bop = (BinaryOp) hi; + BinaryOp bop2 = (BinaryOp) hi.getInput().get(0); + + //pattern: outer(v1,v2,"!=") == 1 -> outer(v1,v2,"!=") + if( HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 1) ) { + HopRewriteUtils.replaceChildReference(parent, bop, bop2, pos); + HopRewriteUtils.cleanupUnreferenced(bop); + hi = bop2; + LOG.debug("Applied simplifyBinaryComparisonChain1 (line "+hi.getBeginLine()+")"); + } + //pattern: outer(v1,v2,"!=") == 0 -> outer(v1,v2,"==") (NOTE(review): for <,<=,>,>= the complement differs from the original on NaN inputs, e.g. (NaN>x)==0 is 1 but NaN<=x is 0) + else { + OpOp2 optr = bop2.getComplementPPredOperation(); + BinaryOp tmp = HopRewriteUtils.createBinary(bop2.getInput().get(0), + bop2.getInput().get(1), optr, bop2.isOuterVectorOperator()); + HopRewriteUtils.replaceChildReference(parent, bop, tmp, pos); + HopRewriteUtils.cleanupUnreferenced(bop, bop2); + hi = tmp; + LOG.debug("Applied simplifyBinaryComparisonChain0 (line "+hi.getBeginLine()+")"); + } + } + + return hi; + } + /** * NOTE: currently disabled since this rewrite is INVALID in the * presence of NaNs (because (NaN!=NaN) is true).
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java new file mode 100644 index 0000000..43fd4f9 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteRemoveComparisonChainsTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class RewriteRemoveComparisonChainsTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "RewriteComparisons"; + //a) >, == 0; b) <=, == 1; c) ==, == 0; d) !=, == 1 + + private final static String TEST_DIR = "functions/misc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + RewriteRemoveComparisonChainsTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test + public void testComparisonGt0() { + runComparisonChainTest( ">", 0, false ); + } + + @Test + public void testComparisonGt0Rewrites() { + runComparisonChainTest( ">", 0, true ); + } + + @Test + public void testComparisonLte1() { + runComparisonChainTest( "<=", 1, false ); + } + + @Test + public void testComparisonLte1Rewrites() { + runComparisonChainTest( "<=", 1, true ); + } + + @Test + public void testComparisonEq0() { + runComparisonChainTest( "==", 0, false ); + } + + @Test + public void testComparisonEq0Rewrites() { + runComparisonChainTest( "==", 0, true ); + } + + @Test + public void testComparisonNeq1() { + runComparisonChainTest( "!=", 1, false ); + } + + @Test + public void testComparisonNeq1Rewrites() { + runComparisonChainTest( "!=", 1, true ); + } + + private void runComparisonChainTest( String op, int compare, boolean rewrites ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(TEST_NAME1); + loadTestConfiguration(config); + String HOME = SCRIPT_DIR + TEST_DIR; + 
fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[]{"-stats","-args", op, String.valueOf(compare)}; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + runTest(true, false, null, -1); + + //check for applied rewrites + Assert.assertEquals(rewrites, heavyHittersContainsString("uaggouterchain")); + if( compare == 1 && rewrites ) + Assert.assertTrue(!heavyHittersContainsString("==")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java index 29660aa..9f8c82f 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/PredicateRecompileTest.java @@ -32,7 +32,6 @@ import org.apache.sysml.utils.Statistics; public class PredicateRecompileTest extends AutomatedTestBase { - private final static String TEST_NAME1 = "while_recompile"; private final static String TEST_NAME2 = "if_recompile"; private final static String TEST_NAME3 = "for_recompile"; @@ -41,9 +40,8 @@ public class PredicateRecompileTest extends AutomatedTestBase private final static String TEST_CLASS_DIR = TEST_DIR + PredicateRecompileTest.class.getSimpleName() + "/"; private final static int rows = 10; - private final static int cols = 15; - private final static int val = 7; - + private final static int cols = 15; + private final static int val = 7; @Override public void setUp() @@ -59,225 +57,188 @@ public class PredicateRecompileTest extends AutomatedTestBase } @Test - public void 
testWhileRecompile() - { + public void testWhileRecompile() { runRecompileTest(TEST_NAME1, true, false, false, false); } @Test - public void testWhileNoRecompile() - { + public void testWhileNoRecompile() { runRecompileTest(TEST_NAME1, false, false, false, false); } @Test - public void testIfRecompile() - { + public void testIfRecompile() { runRecompileTest(TEST_NAME2, true, false, false, false); } @Test - public void testIfNoRecompile() - { + public void testIfNoRecompile() { runRecompileTest(TEST_NAME2, false, false, false, false); } @Test - public void testForRecompile() - { + public void testForRecompile() { runRecompileTest(TEST_NAME3, true, false, false, false); } @Test - public void testForNoRecompile() - { + public void testForNoRecompile() { runRecompileTest(TEST_NAME3, false, false, false, false); } @Test - public void testParForRecompile() - { + public void testParForRecompile() { runRecompileTest(TEST_NAME4, true, false, false, false); } @Test - public void testParForNoRecompile() - { + public void testParForNoRecompile() { runRecompileTest(TEST_NAME4, false, false, false, false); } @Test - public void testWhileRecompileExprEval() - { + public void testWhileRecompileExprEval() { runRecompileTest(TEST_NAME1, true, true, false, false); } @Test - public void testWhileNoRecompileExprEval() - { + public void testWhileNoRecompileExprEval() { runRecompileTest(TEST_NAME1, false, true, false, false); } @Test - public void testIfRecompileExprEval() - { + public void testIfRecompileExprEval() { runRecompileTest(TEST_NAME2, true, true, false, false); } @Test - public void testIfNoRecompileExprEval() - { + public void testIfNoRecompileExprEval() { runRecompileTest(TEST_NAME2, false, true, false, false); } @Test - public void testForRecompileExprEval() - { + public void testForRecompileExprEval() { runRecompileTest(TEST_NAME3, true, true, false, false); } @Test - public void testForNoRecompileExprEval() - { + public void testForNoRecompileExprEval() { 
runRecompileTest(TEST_NAME3, false, true, false, false); } @Test - public void testParForRecompileExprEval() - { + public void testParForRecompileExprEval() { runRecompileTest(TEST_NAME4, true, true, false, false); } @Test - public void testParForNoRecompileExprEval() - { + public void testParForNoRecompileExprEval() { runRecompileTest(TEST_NAME4, false, true, false, false); } @Test - public void testWhileRecompileConstFold() - { + public void testWhileRecompileConstFold() { runRecompileTest(TEST_NAME1, true, false, true, false); } @Test - public void testWhileNoRecompileConstFold() - { + public void testWhileNoRecompileConstFold() { runRecompileTest(TEST_NAME1, false, false, true, false); } @Test - public void testIfRecompileConstFold() - { + public void testIfRecompileConstFold() { runRecompileTest(TEST_NAME2, true, false, true, false); } @Test - public void testIfNoRecompileConstFold() - { + public void testIfNoRecompileConstFold() { runRecompileTest(TEST_NAME2, false, false, true, false); } @Test - public void testForRecompileConstFold() - { + public void testForRecompileConstFold() { runRecompileTest(TEST_NAME3, true, false, true, false); } @Test - public void testForNoRecompileConstFold() - { + public void testForNoRecompileConstFold() { runRecompileTest(TEST_NAME3, false, false, true, false); } @Test - public void testParForRecompileConstFold() - { + public void testParForRecompileConstFold() { runRecompileTest(TEST_NAME4, true, false, true, false); } @Test - public void testParForNoRecompileConstFold() - { + public void testParForNoRecompileConstFold() { runRecompileTest(TEST_NAME4, false, false, true, false); } @Test - public void testWhileNoRecompileIPA() - { + public void testWhileNoRecompileIPA() { runRecompileTest(TEST_NAME1, false, false, false, true); } @Test - public void testIfNoRecompileIPA() - { + public void testIfNoRecompileIPA() { runRecompileTest(TEST_NAME2, false, false, false, true); } @Test - public void testForNoRecompileIPA() - { + 
public void testForNoRecompileIPA() { runRecompileTest(TEST_NAME3, false, false, false, true); } @Test - public void testParForNoRecompileIPA() - { + public void testParForNoRecompileIPA() { runRecompileTest(TEST_NAME4, false, false, false, true); } @Test - public void testWhileNoRecompileExprEvalIPA() - { + public void testWhileNoRecompileExprEvalIPA() { runRecompileTest(TEST_NAME1, false, true, false, true); } @Test - public void testIfNoRecompileExprEvalIPA() - { + public void testIfNoRecompileExprEvalIPA() { runRecompileTest(TEST_NAME2, false, true, false, true); } @Test - public void testForNoRecompileExprEvalIPA() - { + public void testForNoRecompileExprEvalIPA() { runRecompileTest(TEST_NAME3, false, true, false, true); } @Test - public void testParForNoRecompileExprEvalIPA() - { + public void testParForNoRecompileExprEvalIPA() { runRecompileTest(TEST_NAME4, false, true, false, true); } @Test - public void testWhileNoRecompileConstFoldIPA() - { + public void testWhileNoRecompileConstFoldIPA() { runRecompileTest(TEST_NAME1, false, false, true, true); } @Test - public void testIfNoRecompileConstFoldIPA() - { + public void testIfNoRecompileConstFoldIPA() { runRecompileTest(TEST_NAME2, false, false, true, true); } @Test - public void testForNoRecompileConstFoldIPA() - { + public void testForNoRecompileConstFoldIPA() { runRecompileTest(TEST_NAME3, false, false, true, true); } @Test - public void testParForNoRecompileConstFoldIPA() - { + public void testParForNoRecompileConstFoldIPA() { runRecompileTest(TEST_NAME4, false, false, true, true); } - private void runRecompileTest( String testname, boolean recompile, boolean evalExpr, boolean constFold, boolean IPA ) - { + { boolean oldFlagRecompile = CompilerConfig.FLAG_DYN_RECOMPILE; boolean oldFlagEval = OptimizerUtils.ALLOW_SIZE_EXPRESSION_EVALUATION; boolean oldFlagFold = OptimizerUtils.ALLOW_CONSTANT_FOLDING; @@ -295,7 +256,7 @@ public class PredicateRecompileTest extends AutomatedTestBase /* This is for running 
the junit test the new way, i.e., construct the arguments directly */ String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{"-args", + programArgs = new String[]{"-explain","-args", Integer.toString(rows), Integer.toString(cols), Integer.toString(val), @@ -312,35 +273,32 @@ public class PredicateRecompileTest extends AutomatedTestBase OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION = false; boolean exceptionExpected = false; - runTest(true, exceptionExpected, null, -1); + runTest(true, exceptionExpected, null, -1); //check expected number of compiled and executed MR jobs - if( recompile ) - { + if( recompile ) { Assert.assertEquals("Unexpected number of executed MR jobs.", - 1 - ((evalExpr || constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand + 1 - ((evalExpr || constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand } else { - if( IPA ) - { + if( IPA ) { //old expected numbers before IPA if( testname.equals(TEST_NAME1) ) Assert.assertEquals("Unexpected number of executed MR jobs.", - 4 - ((evalExpr||constFold)?4:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr while pred, 1x gmr while body + 4 - ((evalExpr||constFold)?4:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr while pred, 1x gmr while body else //if( testname.equals(TEST_NAME2) ) Assert.assertEquals("Unexpected number of executed MR jobs.", - 3 - ((evalExpr||constFold)?3:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr if pred, 1x gmr if body + 3 - ((evalExpr||constFold)?3:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr if pred, 1x gmr if body } - else - { + else { //old expected numbers before IPA if( testname.equals(TEST_NAME1) ) Assert.assertEquals("Unexpected number of executed MR jobs.", - 4 - ((evalExpr)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr while pred, 1x gmr while body + 4 - ((evalExpr||constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand, 2xgmr while pred, 1x gmr 
while body else //if( testname.equals(TEST_NAME2) ) Assert.assertEquals("Unexpected number of executed MR jobs.", - 3 - ((evalExpr)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr if pred, 1x gmr if body + 3 - ((evalExpr||constFold)?1:0), Statistics.getNoOfExecutedMRJobs()); //rand, 1xgmr if pred, 1x gmr if body } } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test/scripts/functions/misc/RewriteComparisons.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteComparisons.dml b/src/test/scripts/functions/misc/RewriteComparisons.dml new file mode 100644 index 0000000..d84149a --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteComparisons.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +A = seq(1,100); +B = t(seq(5,15)); +while(FALSE){} + +X = rowIndexMax(outer(A, B, $1) == $2) + +while(FALSE){} +print(sum(X)) http://git-wip-us.apache.org/repos/asf/systemml/blob/1e5984cc/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index 8b08155..0b73edb 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -75,6 +75,7 @@ import org.junit.runners.Suite; RewritePushdownSumBinaryMult.class, RewritePushdownSumOnBinaryTest.class, RewritePushdownUaggTest.class, + RewriteRemoveComparisonChainsTest.class, RewriteSimplifyRowColSumMVMultTest.class, RewriteSlicedMatrixMultTest.class, ScalarAssignmentTest.class,
