[SYSTEMML-2149] New simplification rewrite for replacing zeros with a scalar. Several scripts emulate the replacement of zeros with a scalar via X + (X==0) * s. We now rewrite this pattern to the builtin function replace(X, 0, s). This avoids an unnecessary intermediate and, for distributed operations, the (partitioning-preserving) joins.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/62e590ce Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/62e590ce Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/62e590ce Branch: refs/heads/master Commit: 62e590ced04900364bdc294538e78de6af3f4988 Parents: 72830f0 Author: Matthias Boehm <[email protected]> Authored: Thu Feb 15 18:49:00 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Feb 15 18:49:00 2018 -0800 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationStatic.java | 32 +++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/62e590ce/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index ac45e77..3a3235d 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12)) hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C) + hi = simplifyReplaceZeroOperation(hop, hi, i); //e.g., X + (X==0) * s -> replace(X, 0, s) hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplify binary 
or dyn rewrites hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count") if(OptimizerUtils.ALLOW_OPERATOR_FUSION) { @@ -1556,14 +1557,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule { if( HopRewriteUtils.isTransposeOperation(hi) && hi.getInput().get(0) instanceof BinaryOp //basic binary - && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) + && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) { Hop left = hi.getInput().get(0).getInput().get(0); Hop C = hi.getInput().get(0).getInput().get(1); //check matrix mult and both inputs transposes w/ single consumer if( left instanceof AggBinaryOp && C.getDataType().isMatrix() - && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) + && HopRewriteUtils.isTransposeOperation(left.getInput().get(0)) && left.getInput().get(0).getParent().size()==1 && HopRewriteUtils.isTransposeOperation(left.getInput().get(1)) && left.getInput().get(1).getParent().size()==1 ) @@ -1578,13 +1579,36 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.replaceChildReference(parent, hi, bop, pos); hi = bop; - LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+")."); - } + LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+")."); + } } return hi; } + // Patterns: X + (X==0) * s -> replace(X, 0, s) + private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int pos) + throws HopsException + { + if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) && hi.getInput().get(0).isMatrix() + && HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) + && hi.getInput().get(1).getInput().get(1).isScalar() + && HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0), OpOp2.EQUAL, 0) + && hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0)) ) + { + HashMap<String, Hop> 
args = new HashMap<>(); + args.put("target", hi.getInput().get(0)); + args.put("pattern", new LiteralOp(0)); + args.put("replacement", hi.getInput().get(1).getInput().get(1)); + Hop replace = HopRewriteUtils.createParameterizedBuiltinOp( + hi.getInput().get(0), args, ParamBuiltinOp.REPLACE); + HopRewriteUtils.replaceChildReference(parent, hi, replace, pos); + hi = replace; + LOG.debug("Applied simplifyReplaceZeroOperation (line "+hi.getBeginLine()+")."); + } + return hi; + } + /** * Pattners: t(t(X)) -> X, rev(rev(X)) -> X *
