Repository: systemml
Updated Branches:
  refs/heads/master 8a1f98e1b -> ec024661a
[SYSTEMML-1990] New rewrites for rand, outer products, and cbind/rbind

This patch generalizes two existing rewrites and introduces a new
rewrite for nary cbind/rbind:

(1) Generalized rand-binary fusion: So far we only fused binary
operations with literals into rand operations. For special cases of
multiply and add, we now also fuse binary operations with variable
scalar inputs.

(2) Generalized outer-product rewrites: Outer products for replication
with subsequent comparison operations are rewritten to binary outer
operations. This patch generalizes the detection of such patterns to
partial unknowns.

(3) New rbind/cbind folding: Exploiting the recently added nary
cbind/rbind operations, we now recursively fold subsequent cbind/rbind
operations with single consumers into nary cbind/rbind.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/578a9869
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/578a9869
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/578a9869

Branch: refs/heads/master
Commit: 578a98697e6219947b748337f9a121acf54afe53
Parents: 8a1f98e
Author: Matthias Boehm <[email protected]>
Authored: Mon Nov 6 13:14:14 2017 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Wed Nov 8 13:22:13 2017 -0800

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/DataGenOp.java | 15 +-
 .../sysml/hops/rewrite/HopRewriteUtils.java | 21 +++
 .../RewriteAlgebraicSimplificationDynamic.java | 5 +-
 .../RewriteAlgebraicSimplificationStatic.java | 140 +++++++++++++++----
 .../java/org/apache/sysml/utils/Statistics.java | 3 +-
 .../test/integration/AutomatedTestBase.java | 15 ++
 .../functions/misc/RewriteFoldRCBindTest.java | 101 +++++++++++++
 .../functions/misc/RewriteFusedRandTest.java | 62 ++++----
 .../scripts/functions/misc/RewriteFoldCBind.dml | 28 ++++
 .../scripts/functions/misc/RewriteFoldRBind.dml | 28 ++++
.../scripts/functions/misc/RewriteFusedRand.dml | 29 ---- .../functions/misc/RewriteFusedRandLit.dml | 29 ++++ .../functions/misc/RewriteFusedRandVar1.dml | 28 ++++ .../functions/misc/RewriteFusedRandVar2.dml | 28 ++++ .../functions/misc/ZPackageSuite.java | 1 + 15 files changed, 446 insertions(+), 87 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/DataGenOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java b/src/main/java/org/apache/sysml/hops/DataGenOp.java index 69f35ed..500bbd9 100644 --- a/src/main/java/org/apache/sysml/hops/DataGenOp.java +++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java @@ -45,7 +45,6 @@ import org.apache.sysml.runtime.util.UtilFunctions; */ public class DataGenOp extends Hop implements MultiThreadedHop { - public static final long UNSPECIFIED_SEED = -1; // defines the specific data generation method @@ -366,15 +365,21 @@ public class DataGenOp extends Hop implements MultiThreadedHop } - public HashMap<String, Integer> getParamIndexMap() - { + public HashMap<String, Integer> getParamIndexMap() { return _paramIndexMap; } - public int getParamIndex(String key) - { + public int getParamIndex(String key) { return _paramIndexMap.get(key); } + + public Hop getInput(String key) { + return getInput().get(getParamIndex(key)); + } + + public void setInput(String key, Hop hop) { + getInput().set(getParamIndex(key), hop); + } public boolean hasConstantValue() { http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index b6db466..15cc2cb 100644 
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -40,6 +40,7 @@ import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.FileFormatTypes; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.Hop.OpOp3; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; @@ -47,6 +48,7 @@ import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.MemoTable; +import org.apache.sysml.hops.NaryOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.ReorgOp; @@ -603,6 +605,16 @@ public class HopRewriteUtils return ix; } + public static NaryOp createNary(OpOpN op, Hop... inputs) throws HopsException { + Hop mainInput = inputs[0]; + NaryOp nop = new NaryOp(mainInput.getName(), mainInput.getDataType(), + mainInput.getValueType(), op, inputs); + nop.setOutputBlocksizes(mainInput.getRowsInBlock(), mainInput.getColsInBlock()); + copyLineNumbers(mainInput, nop); + nop.refreshSizeInformation(); + return nop; + } + public static Hop createValueHop( Hop hop, boolean row ) throws HopsException { @@ -957,6 +969,15 @@ public class HopRewriteUtils return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM_SQ); } + public static boolean isNary(Hop hop, OpOpN type) { + return hop instanceof NaryOp && ((NaryOp)hop).getOp()==type; + } + + public static boolean isNary(Hop hop, OpOpN... 
types) { + return ( hop instanceof NaryOp + && ArrayUtils.contains(types, ((NaryOp) hop).getOp())); + } + public static boolean isNonZeroIndicator(Hop pred, Hop hop ) { if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index eba06fc..0fa1aed 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -432,8 +432,9 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule else if(HopRewriteUtils.isValidOuterBinaryOp(bop) && HopRewriteUtils.isMatrixMultiply(left) && HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1) - && left.getInput().get(0).getDim2() == 1 //column vector - && left.getDim1() != 1 && right.getDim1() == 1 ) //outer vector product + && (left.getInput().get(0).getDim2() == 1 //outer product + || left.getInput().get(1).getDim1() == 1) + && left.getDim1() != 1 && right.getDim1() == 1 ) //outer vector binary { Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true); HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 
64a37d7..2d5d881 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -38,10 +38,12 @@ import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.DataGenMethod; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp3; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.NaryOp; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.ParameterizedBuiltinOp; @@ -147,23 +149,24 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = removeUnnecessaryBinaryOperation(hop, hi, i); //e.g., X*1 -> X (dep: should come after rm unnecessary vectorize) hi = fuseDatagenAndBinaryOperation(hi); //e.g., rand(min=-1,max=1)*7 -> rand(min=-7,max=7) hi = fuseDatagenAndMinusOperation(hi); //e.g., -(rand(min=-2,max=1)) -> rand(min=-1,max=2) - hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X) - hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) - hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) - if(OptimizerUtils.ALLOW_OPERATOR_FUSION) - hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y - hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X - hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) - hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) - hi = removeUnnecessaryAggregates(hi); //e.g., sum(rowSums(X)) -> sum(X) - hi = 
simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s; - hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) - hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) - hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X) - hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor - hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); - if(OptimizerUtils.ALLOW_OPERATOR_FUSION) - hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) + hi = foldMultipleAppendOperations(hi); //e.g., cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z) + hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X) + hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) + hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) + if(OptimizerUtils.ALLOW_OPERATOR_FUSION) + hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y + hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X + hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) + hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) + hi = removeUnnecessaryAggregates(hi); //e.g., sum(rowSums(X)) -> sum(X) + hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s; + hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) + hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) + hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X) + 
hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor + hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); + if(OptimizerUtils.ALLOW_OPERATOR_FUSION) + hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]; hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq; @@ -358,14 +361,13 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //the creation of multiple datagen ops and thus potentially different results if seed not specified) //left input rand and hence output matrix double, right scalar literal - if( left instanceof DataGenOp && ((DataGenOp)left).getOp()==DataGenMethod.RAND && + if( HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) && right instanceof LiteralOp && left.getParent().size()==1 ) { DataGenOp inputGen = (DataGenOp)left; - HashMap<String,Integer> params = inputGen.getParamIndexMap(); - Hop pdf = left.getInput().get(params.get(DataExpression.RAND_PDF)); - Hop min = left.getInput().get(params.get(DataExpression.RAND_MIN)); - Hop max = left.getInput().get(params.get(DataExpression.RAND_MAX)); + Hop pdf = inputGen.getInput(DataExpression.RAND_PDF); + Hop min = inputGen.getInput(DataExpression.RAND_MIN); + Hop max = inputGen.getInput(DataExpression.RAND_MAX); double sval = ((LiteralOp)right).getDoubleValue(); if( HopRewriteUtils.isBinary(bop, OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV) @@ -396,10 +398,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule left instanceof LiteralOp && right.getParent().size()==1 ) { DataGenOp inputGen = (DataGenOp)right; - HashMap<String,Integer> params = inputGen.getParamIndexMap(); - Hop pdf = 
right.getInput().get(params.get(DataExpression.RAND_PDF)); - Hop min = right.getInput().get(params.get(DataExpression.RAND_MIN)); - Hop max = right.getInput().get(params.get(DataExpression.RAND_MAX)); + Hop pdf = inputGen.getInput(DataExpression.RAND_PDF); + Hop min = inputGen.getInput(DataExpression.RAND_MIN); + Hop max = inputGen.getInput(DataExpression.RAND_MAX); double sval = ((LiteralOp)left).getDoubleValue(); if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS) @@ -423,6 +424,44 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule LOG.debug("Applied fuseDatagenAndBinaryOperation2 (line "+bop.getBeginLine()+")."); } } + //left input rand and hence output matrix double, right scalar variable + else if( HopRewriteUtils.isDataGenOp(left, DataGenMethod.RAND) + && right.getDataType().isScalar() && left.getParent().size()==1 ) + { + DataGenOp gen = (DataGenOp)left; + Hop min = gen.getInput(DataExpression.RAND_MIN); + Hop max = gen.getInput(DataExpression.RAND_MAX); + + if( HopRewriteUtils.isBinary(bop, OpOp2.PLUS) + && HopRewriteUtils.isLiteralOfValue(min, 0) + && HopRewriteUtils.isLiteralOfValue(max, 0) ) + { + gen.setInput(DataExpression.RAND_MIN, right); + gen.setInput(DataExpression.RAND_MAX, right); + //rewire all parents (avoid anomalies with replicated datagen) + List<Hop> parents = new ArrayList<>(bop.getParent()); + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, bop, gen); + hi = gen; + LOG.debug("Applied fuseDatagenAndBinaryOperation3a (line "+bop.getBeginLine()+")."); + } + else if( HopRewriteUtils.isBinary(bop, OpOp2.MULT) + && (HopRewriteUtils.isLiteralOfValue(min, 0) + || HopRewriteUtils.isLiteralOfValue(min, 1)) + && HopRewriteUtils.isLiteralOfValue(max, 1) ) + { + if( HopRewriteUtils.isLiteralOfValue(min, 1) ) + gen.setInput(DataExpression.RAND_MIN, right); + gen.setInput(DataExpression.RAND_MAX, right); + //rewire all parents (avoid anomalies with replicated datagen) + List<Hop> parents = new 
ArrayList<>(bop.getParent()); + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, bop, gen); + hi = gen; + LOG.debug("Applied fuseDatagenAndBinaryOperation3b (line "+bop.getBeginLine()+")."); + } + } + } return hi; @@ -478,6 +517,55 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + private static Hop foldMultipleAppendOperations(Hop hi) + throws HopsException + { + if( hi.getDataType().isMatrix() //no string appends + && (HopRewriteUtils.isBinary(hi, OpOp2.CBIND, OpOp2.RBIND) + || HopRewriteUtils.isNary(hi, OpOpN.CBIND, OpOpN.RBIND)) + && !OptimizerUtils.isHadoopExecutionMode() ) + { + OpOp2 bop = (hi instanceof BinaryOp) ? ((BinaryOp)hi).getOp() : + OpOp2.valueOf(((NaryOp)hi).getOp().name()); + OpOpN nop = (hi instanceof NaryOp) ? ((NaryOp)hi).getOp() : + OpOpN.valueOf(((BinaryOp)hi).getOp().name()); + + boolean converged = false; + while( !converged ) { + //get first matching cbind or rbind + Hop first = hi.getInput().stream() + .filter(h -> HopRewriteUtils.isBinary(h, bop) || HopRewriteUtils.isNary(h, nop)) + .findFirst().orElse(null); + + //replace current op with new nary cbind/rbind + if( first != null && first.getParent().size()==1 ) { + //construct new list of inputs (in original order) + ArrayList<Hop> linputs = new ArrayList<>(); + for(Hop in : hi.getInput()) + if( in == first ) + linputs.addAll(first.getInput()); + else + linputs.add(in); + Hop hnew = HopRewriteUtils.createNary(nop, linputs.toArray(new Hop[0])); + //clear dangling references + HopRewriteUtils.removeAllChildReferences(hi); + HopRewriteUtils.removeAllChildReferences(first); + //rewire all parents (avoid anomalies with refs to hi) + List<Hop> parents = new ArrayList<>(hi.getParent()); + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, hi, hnew); + hi = hnew; + LOG.debug("Applied foldMultipleAppendOperations (line "+hi.getBeginLine()+")."); + } + else { + converged = true; + } + } + } + + return hi; + } + /** * Handle 
simplification of binary operations (relies on previous common subexpression elimination). * At the same time this servers as a canonicalization for more complex rewrites. http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 555493c..5cc0650 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -545,7 +545,8 @@ public class Statistics } public static long getCPHeavyHitterCount(String opcode) { - return _cpInstCounts.get(opcode); + Long tmp = _cpInstCounts.get(opcode); + return (tmp != null) ? tmp : 0; } /** http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index d39e513..3bd78b0 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -1808,6 +1808,21 @@ public abstract class AutomatedTestBase return writeInputFrame(name, data, false, schema, oi); } + protected boolean heavyHittersContainsString(String... str) { + for( String opcode : Statistics.getCPHeavyHitterOpCodes()) + for( String s : str ) + if(opcode.equals(s)) + return true; + return false; + } + + protected boolean heavyHittersContainsString(String str, int minCount) { + int count = 0; + for( String opcode : Statistics.getCPHeavyHitterOpCodes()) + count += opcode.equals(str) ? 
1 : 0; + return (count >= minCount); + } + protected boolean heavyHittersContainsSubString(String... str) { for( String opcode : Statistics.getCPHeavyHitterOpCodes()) for( String s : str ) http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java new file mode 100644 index 0000000..83050bd --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFoldRCBindTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; + +public class RewriteFoldRCBindTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteFoldCBind"; + private static final String TEST_NAME2 = "RewriteFoldRBind"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFoldRCBindTest.class.getSimpleName() + "/"; + + private static final int rows = 1932; + private static final int cols = 14; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testRewriteFoldCBindNoRewrite() { + testRewriteFoldRCBind( TEST_NAME1, false ); + } + + @Test + public void testRewriteFoldCBindRewrite() { + testRewriteFoldRCBind( TEST_NAME1, true ); + } + + @Test + public void testRewriteFoldRBindNoRewrite() { + testRewriteFoldRCBind( TEST_NAME2, false ); + } + + @Test + public void testRewriteFoldRBindRewrite() { + testRewriteFoldRCBind( TEST_NAME2, true ); + } + + private void testRewriteFoldRCBind( String testname, boolean rewrites ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-stats", 
"-args", String.valueOf(rows), + String.valueOf(cols), output("R") }; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + //run performance tests + runTest(true, false, null, -1); + + //compare matrices + Double ret = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1)); + Assert.assertEquals("Wrong result", new Double(5*rows*cols), ret); + + //check for applied rewrites + if( rewrites ) { + Assert.assertTrue(!heavyHittersContainsString("append") + && Statistics.getCPHeavyHitterCount("cbind") <= 1 + && Statistics.getCPHeavyHitterCount("rbind") <= 1); + } + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java index 1491aa6..d7fe902 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java @@ -19,8 +19,6 @@ package org.apache.sysml.test.integration.functions.misc; -import java.util.HashMap; - import org.junit.Assert; import org.junit.Test; import org.apache.sysml.hops.OptimizerUtils; @@ -29,13 +27,12 @@ import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; -/** - * - * - */ public class RewriteFusedRandTest extends AutomatedTestBase -{ - private static final String TEST_NAME1 = "RewriteFusedRand"; +{ + private static final String TEST_NAME1 = "RewriteFusedRandLit"; + private static final String TEST_NAME2 = "RewriteFusedRandVar1"; + private static final String TEST_NAME3 = 
"RewriteFusedRandVar2"; + private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFusedRandTest.class.getSimpleName() + "/"; @@ -47,44 +44,50 @@ public class RewriteFusedRandTest extends AutomatedTestBase public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); } @Test - public void testRewriteFusedRandUniformNoRewrite() { + public void testRewriteFusedRandUniformNoRewrite() { testRewriteFusedRand( TEST_NAME1, "uniform", false ); } @Test - public void testRewriteFusedRandNormalNoRewrite() { + public void testRewriteFusedRandNormalNoRewrite() { testRewriteFusedRand( TEST_NAME1, "normal", false ); } @Test - public void testRewriteFusedRandPoissonNoRewrite() { + public void testRewriteFusedRandPoissonNoRewrite() { testRewriteFusedRand( TEST_NAME1, "poisson", false ); } @Test - public void testRewriteFusedRandUniform() { + public void testRewriteFusedRandUniform() { testRewriteFusedRand( TEST_NAME1, "uniform", true ); } @Test - public void testRewriteFusedRandNormal() { + public void testRewriteFusedRandNormal() { testRewriteFusedRand( TEST_NAME1, "normal", true ); } @Test - public void testRewriteFusedRandPoisson() { + public void testRewriteFusedRandPoisson() { testRewriteFusedRand( TEST_NAME1, "poisson", true ); } - /** - * - * @param condition - * @param branchRemoval - * @param IPA - */ + @Test + public void testRewriteFusedZerosPlusVar() { + testRewriteFusedRand( TEST_NAME2, "uniform", true ); + } + + @Test + public void testRewriteFusedOnesMultVar() { + testRewriteFusedRand( TEST_NAME3, "uniform", true ); + } + private void testRewriteFusedRand( String testname, String 
pdf, boolean rewrites ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; @@ -95,7 +98,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{ "-args", String.valueOf(rows), + programArgs = new String[]{ "-stats", "-args", String.valueOf(rows), String.valueOf(cols), pdf, String.valueOf(seed), output("R") }; OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; @@ -103,11 +106,22 @@ public class RewriteFusedRandTest extends AutomatedTestBase runTest(true, false, null, -1); //compare matrices - HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); - Assert.assertEquals("Wrong result, expected: "+rows, new Double(rows), dmlfile.get(new CellIndex(1,1))); + Double ret = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1)); + if( testname.equals(TEST_NAME1) ) + Assert.assertEquals("Wrong result", new Double(rows), ret); + else if( testname.equals(TEST_NAME2) ) + Assert.assertEquals("Wrong result", new Double(Math.pow(rows*cols, 2)), ret); + else if( testname.equals(TEST_NAME3) ) + Assert.assertEquals("Wrong result", new Double(Math.pow(rows*cols, 2)), ret); + + //check for applied rewrites + if( rewrites && pdf.equals("uniform") ) { + Assert.assertTrue(!heavyHittersContainsString("+") + && !heavyHittersContainsString("*")); + } } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; } } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFoldCBind.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFoldCBind.dml b/src/test/scripts/functions/misc/RewriteFoldCBind.dml new file mode 100644 index 0000000..47733fe --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFoldCBind.dml @@ -0,0 +1,28 @@ 
+#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = cbind(cbind(X,X),cbind(X,X,X)) +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $3); http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFoldRBind.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFoldRBind.dml b/src/test/scripts/functions/misc/RewriteFoldRBind.dml new file mode 100644 index 0000000..c489f7c --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFoldRBind.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = rbind(rbind(X,X),rbind(X,X,X)) +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $3); http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRand.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFusedRand.dml b/src/test/scripts/functions/misc/RewriteFusedRand.dml deleted file mode 100644 index ab00f04..0000000 --- a/src/test/scripts/functions/misc/RewriteFusedRand.dml +++ /dev/null @@ -1,29 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; - -while(FALSE){} #prevent cse - -X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; - -R = as.matrix(sum(rowSums(X1)==rowSums(X2))); -write(R, $5); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRandLit.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFusedRandLit.dml b/src/test/scripts/functions/misc/RewriteFusedRandLit.dml new file mode 100644 index 0000000..ab00f04 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFusedRandLit.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; + +while(FALSE){} #prevent cse + +X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; + +R = as.matrix(sum(rowSums(X1)==rowSums(X2))); +write(R, $5); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml b/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml new file mode 100644 index 0000000..f37557f --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFusedRandVar1.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = matrix(0, $1, $2) + sum(X); +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $5); http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml b/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml new file mode 100644 index 0000000..d98ef08 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFusedRandVar2.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = matrix(1, $1, $2) +while(FALSE){} +Y = matrix(1, $1, $2) * sum(X); +while(FALSE){} +R = as.matrix(sum(Y)) + +write(R, $5); http://git-wip-us.apache.org/repos/asf/systemml/blob/578a9869/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index a453cbd..ae4f820 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -56,6 +56,7 @@ import org.junit.runners.Suite; RewriteCTableToRExpandTest.class, RewriteElementwiseMultChainOptimizationTest.class, RewriteEliminateAggregatesTest.class, + RewriteFoldRCBindTest.class, RewriteFuseBinaryOpChainTest.class, RewriteFusedRandTest.class, RewriteLoopVectorization.class,
