[SYSTEMML-833] Additional cleanup rewrites (unnecessary cast, reorg, agg) This patch adds various additional cleanup rewrites in order to simplify debugging. In detail, this includes:
(1) Unnecessary data type casts (e.g., as.scalar(as.matrix)) (2) Unnecessary reorg operations (e.g., t(X), iff X 1x1 dims) (3) Unnecessary aggregation (e.g., sum(X) iff X 1x1 dims) (4) Pushdown of scalar casts (e.g., as.scalar(X*s)->as.scalar(X)*s) Note that these rewrites enable each other; e.g., once (2), (3), and (4) are performed, unnecessary casts (1) can be removed avoiding long chains of unnecessary operations like sum(t(as.matrix(t(X))*7)) -> as.scalar(X)*7. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/11a85775 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/11a85775 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/11a85775 Branch: refs/heads/master Commit: 11a85775f11e4490d957fe4f9fab4bfd8ea7a138 Parents: 461184a Author: Matthias Boehm <[email protected]> Authored: Sat Jul 30 23:48:51 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Jul 31 19:16:59 2016 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 3 +- .../sysml/hops/rewrite/ProgramRewriter.java | 11 ++-- .../RewriteAlgebraicSimplificationDynamic.java | 67 ++++++++++++++++---- .../RewriteAlgebraicSimplificationStatic.java | 47 ++++++++++++++ .../rewrite/RewriteRemoveUnnecessaryCasts.java | 21 +++++- .../cp/ArithmeticBinaryCPInstruction.java | 14 ++-- 6 files changed, 134 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index f7e4656..3bfdcb5 100644 --- 
a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -512,7 +512,8 @@ public class HopRewriteUtils public static UnaryOp createUnary(Hop input, OpOp1 type) throws HopsException { - UnaryOp unary = new UnaryOp(input.getName(), input.getDataType(), input.getValueType(), type, input); + DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : input.getDataType(); + UnaryOp unary = new UnaryOp(input.getName(), dt, input.getValueType(), type, input); HopRewriteUtils.setOutputBlocksizes(unary, input.getRowsInBlock(), input.getColsInBlock()); HopRewriteUtils.copyLineNumbers(input, unary); unary.refreshSizeInformation(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 8e645dc..e7b03c4 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -130,12 +130,13 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse } - - //reapply cse after rewrites because (1) applied rewrites on operators w/ multiple parents, and - //(2) newly introduced operators potentially created redundancy (incl leaf merge to allow for cse) - if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) - _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); //dependency: simplifications } + + // cleanup after all rewrites applied + // (newly introduced operators, introduced redundancy after rewrites w/ multiple parents) + _dagRuleSet.add( new 
RewriteRemoveUnnecessaryCasts() ); + if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) + _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 10953f5..793bc25 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -69,6 +69,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //valid aggregation operation types for empty (sparse-safe) operations (not all operations apply) //AggOp.MEAN currently not due to missing count/corrections private static AggOp[] LOOKUP_VALID_EMPTY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE}; + private static AggOp[] LOOKUP_VALID_UNNECESSARY_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.PROD, AggOp.TRACE}; //valid unary operation types for empty (sparse-safe) operations (not all operations apply) private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM}; @@ -149,13 +150,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule hi = removeUnnecessaryLeftIndexing(hop, hi, i); //e.g., X[,1]=Y -> Y, if output == input dims hi = fuseLeftIndexingChainToAppend(hop, hi, i); //e.g., X[,1]=A; X[,2]=B -> X=cbind(A,B), iff ncol(X)==2 and col1/2 lix hi = removeUnnecessaryCumulativeOp(hop, hi, i); //e.g., cumsum(X) -> X, if nrow(X)==1; - hi = removeUnnecessaryReorgOperation(hop, 
hi, i); //e.g., matrix(X) -> X, if output == input dims + hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., matrix(X) -> X, if dims(in)==dims(out); r(X)->X, if 1x1 dims hi = removeUnnecessaryOuterProduct(hop, hi, i); //e.g., X*(Y%*%matrix(1,...) -> X*Y, if Y col vector hi = fuseDatagenAndReorgOperation(hop, hi, i); //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1 hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colsums(X) -> sum(X) or X, if col/row vector hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowsums(X) -> sum(X) or X, if row/col vector hi = simplifyColSumsMVMult(hop, hi, i); //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector hi = simplifyRowSumsMVMult(hop, hi, i); //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector + hi = simplifyUnnecessaryAggregate(hop, hi, i); //e.g., sum(X) -> as.scalar(X), if 1x1 dims hi = simplifyEmptyAggregate(hop, hi, i); //e.g., sum(X) -> 0, if nnz(X)==0 hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0 hi = simplifyEmptyReorgOperation(hop, hi, i); //e.g., t(X) -> matrix(0, ncol(X), nrow(X)) @@ -428,22 +430,26 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule */ private Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) { - if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp() == ReOrgOp.RESHAPE ) //reshape operation + if( hi instanceof ReorgOp ) { + ReorgOp rop = (ReorgOp) hi; Hop input = hi.getInput().get(0); - - if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims - { - //equal dims of reshape input and output -> no need for reshape because - //byrow always refers to both input/output and hence gives the same result - - //remove unnecessary right indexing - HopRewriteUtils.removeChildReference(parent, hi); + boolean apply = false; + + //equal dims of reshape input and output -> no need for reshape because + //byrow always refers to both input/output and hence gives the same result + apply |= 
(rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input)); + + //1x1 dimensions of transpose/reshape -> no need for reorg + apply |= ((rop.getOp()==ReOrgOp.TRANSPOSE || rop.getOp()==ReOrgOp.RESHAPE) + && rop.getDim1()==1 && rop.getDim2()==1); + + if( apply ) { + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, input, pos); parent.refreshSizeInformation(); hi = input; - - LOG.debug("Applied removeUnnecessaryReshape"); + LOG.debug("Applied removeUnnecessaryReorg."); } } @@ -841,6 +847,43 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule * @return * @throws HopsException */ + private Hop simplifyUnnecessaryAggregate(Hop parent, Hop hi, int pos) + throws HopsException + { + //e.g., sum(X) -> as.scalar(X) if 1x1 (applies to sum, min, max, prod, trace) + if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol ) + { + AggUnaryOp uhi = (AggUnaryOp)hi; + Hop input = uhi.getInput().get(0); + + if( HopRewriteUtils.isValidOp(uhi.getOp(), LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){ + + if( input.getDim1()==1 && input.getDim2()==1 ) + { + UnaryOp cast = HopRewriteUtils.createUnary(input, OpOp1.CAST_AS_SCALAR); + + //remove unnecessary aggregation + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); + HopRewriteUtils.addChildReference(parent, cast, pos); + parent.refreshSizeInformation(); + hi = cast; + + LOG.debug("Applied simplifyUnncessaryAggregate"); + } + } + } + + return hi; + } + + /** + * + * @param parent + * @param hi + * @param pos + * @return + * @throws HopsException + */ private Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) throws HopsException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index f23686c..784d678 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -144,6 +144,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X hi = simplifyBushyBinaryOperation(hop, hi, i); //e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v) hi = simplifyUnaryAggReorgOperation(hop, hi, i); //e.g., sum(t(X)) -> sum(X) + hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s; hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X) hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor @@ -890,6 +891,52 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule * @param hi * @param pos * @return + * @throws HopsException + */ + private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) + throws HopsException + { + if( hi instanceof UnaryOp && ((UnaryOp)hi).getOp()==OpOp1.CAST_AS_SCALAR + && hi.getInput().get(0) instanceof BinaryOp ) + { + BinaryOp bin = (BinaryOp) hi.getInput().get(0); + BinaryOp bout = null; + + //as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y) + if( bin.getInput().get(0).getDataType()==DataType.MATRIX + && bin.getInput().get(1).getDataType()==DataType.MATRIX ) { + UnaryOp cast1 = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); + UnaryOp cast2 = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); + bout = HopRewriteUtils.createBinary(cast1, cast2, bin.getOp()); + } + 
//as.scalar(X*s) -> as.scalar(X) * s + else if( bin.getInput().get(0).getDataType()==DataType.MATRIX ) { + UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR); + bout = HopRewriteUtils.createBinary(cast, bin.getInput().get(1), bin.getOp()); + } + //as.scalar(s*X) -> s * as.scalar(X) + else if ( bin.getInput().get(1).getDataType()==DataType.MATRIX ) { + UnaryOp cast = HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR); + bout = HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp()); + } + + if( bout != null ) { + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); + HopRewriteUtils.addChildReference(parent, bout, pos); + + LOG.debug("Applied simplifyBinaryMatrixScalarOperation."); + } + } + + return hi; + } + + /** + * + * @param parent + * @param hi + * @param pos + * @return */ private Hop pushdownUnaryAggTransposeOperation( Hop parent, Hop hi, int pos ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java index 36d8712..a8001f8 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryCasts.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.Hop.VisitStatus; import org.apache.sysml.hops.UnaryOp; @@ -73,6 +74,7 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule * * @param hop */ + @SuppressWarnings("unchecked") private void 
rule_RemoveUnnecessaryCasts( Hop hop ) { //check mark processed @@ -84,7 +86,7 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule for( int i=0; i<inputs.size(); i++ ) rule_RemoveUnnecessaryCasts( inputs.get(i) ); - //remove cast if unnecessary + //remove unnecessary value type cast if( hop instanceof UnaryOp && HopRewriteUtils.isValueTypeCast(((UnaryOp)hop).getOp()) ) { Hop in = hop.getInput().get(0); @@ -116,6 +118,23 @@ public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule } } + //remove unnecessary data type casts + if( hop instanceof UnaryOp && hop.getInput().get(0) instanceof UnaryOp ) { + UnaryOp uop1 = (UnaryOp) hop; + UnaryOp uop2 = (UnaryOp) hop.getInput().get(0); + if( (uop1.getOp()==OpOp1.CAST_AS_MATRIX && uop2.getOp()==OpOp1.CAST_AS_SCALAR) + || (uop1.getOp()==OpOp1.CAST_AS_SCALAR && uop2.getOp()==OpOp1.CAST_AS_MATRIX) ) { + Hop input = uop2.getInput().get(0); + //rewire parents + ArrayList<Hop> parents = (ArrayList<Hop>) hop.getParent().clone(); + for( Hop p : parents ) { + int ix = HopRewriteUtils.getChildReferencePos(p, hop); + HopRewriteUtils.removeChildReference(p, hop); + HopRewriteUtils.addChildReference(p, input, ix); + } + } + } + //mark processed hop.setVisited( VisitStatus.DONE ); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/11a85775/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java index 38ba9dd..c9545ac 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ArithmeticBinaryCPInstruction.java @@ -60,16 +60,10 @@ public abstract class ArithmeticBinaryCPInstruction extends 
BinaryCPInstruction //make sure these checks belong here //if either input is a matrix, then output //has to be a matrix - if((dt1 == DataType.MATRIX - || dt2 == DataType.MATRIX) - && dt3 != DataType.MATRIX) - throw new DMLRuntimeException("Element-wise matrix operations between variables " - + in1.getName() - + " and " - + in2.getName() - + " must produce a matrix, which " - + out.getName() - + "is not"); + if((dt1 == DataType.MATRIX || dt2 == DataType.MATRIX) && dt3 != DataType.MATRIX) { + throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() + + " and " + in2.getName() + " must produce a matrix, which " + out.getName() + "is not"); + } Operator operator = (dt1 != dt2) ? InstructionUtils.parseScalarBinaryOperator(opcode, (dt1 == DataType.SCALAR)) :
