[SYSTEMML-2149] New simplification rewrite for replacing zeros w/ a scalar

Several scripts emulate replacing zeros with a scalar via the expression
X + (X==0) * s. We now rewrite this pattern to the builtin function
replace(X, 0, s), which avoids an unnecessary intermediate and, for
distributed operations, avoids (partitioning-preserving) joins.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/62e590ce
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/62e590ce
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/62e590ce

Branch: refs/heads/master
Commit: 62e590ced04900364bdc294538e78de6af3f4988
Parents: 72830f0
Author: Matthias Boehm <[email protected]>
Authored: Thu Feb 15 18:49:00 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Feb 15 18:49:00 2018 -0800

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   | 32 +++++++++++++++++---
 1 file changed, 28 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/62e590ce/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ac45e77..3a3235d 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -175,6 +175,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = fuseOrderOperationChain(hi);                    
//e.g., order(order(X,2),1) -> order(X,(12))
                        hi = removeUnnecessaryReorgOperation(hop, hi, i);    
//e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
                        hi = simplifyTransposeAggBinBinaryChains(hop, hi, 
i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
+                       hi = simplifyReplaceZeroOperation(hop, hi, i);       
//e.g., X + (X==0) * s -> replace(X, 0, s)
                        hi = removeUnnecessaryMinus(hop, hi, i);             
//e.g., -(-X)->X; potentially introduced by simplify binary or dyn rewrites
                        hi = simplifyGroupedAggregate(hi);                   
//e.g., aggregate(target=X,groups=y,fn="count") -> 
aggregate(target=y,groups=y,fn="count")
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1556,14 +1557,14 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
        {
                if( HopRewriteUtils.isTransposeOperation(hi)
                        && hi.getInput().get(0) instanceof BinaryOp             
          //basic binary
-                       && 
((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()) 
+                       && 
((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
                {
                        Hop left = hi.getInput().get(0).getInput().get(0);
                        Hop C = hi.getInput().get(0).getInput().get(1);
                        
                        //check matrix mult and both inputs transposes w/ 
single consumer
                        if( left instanceof AggBinaryOp && 
C.getDataType().isMatrix()
-                               && 
HopRewriteUtils.isTransposeOperation(left.getInput().get(0))     
+                               && 
HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
                                && left.getInput().get(0).getParent().size()==1 
                                && 
HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
                                && left.getInput().get(1).getParent().size()==1 
)
@@ -1578,13 +1579,36 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                HopRewriteUtils.replaceChildReference(parent, 
hi, bop, pos);
                                
                                hi = bop;
-                               LOG.debug("Applied 
simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");            
                             
-                       }  
+                               LOG.debug("Applied 
simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
+                       }
                }
                
                return hi;
        }
        
+       // Patterns: X + (X==0) * s -> replace(X, 0, s)
+       private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int 
pos) 
+               throws HopsException
+       {
+               if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) && 
hi.getInput().get(0).isMatrix()
+                       && HopRewriteUtils.isBinary(hi.getInput().get(1), 
OpOp2.MULT)
+                       && hi.getInput().get(1).getInput().get(1).isScalar()
+                       && 
HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0), 
OpOp2.EQUAL, 0)
+                       && 
hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0))
 )
+               {
+                       HashMap<String, Hop> args = new HashMap<>();
+                       args.put("target", hi.getInput().get(0));
+                       args.put("pattern", new LiteralOp(0));
+                       args.put("replacement", 
hi.getInput().get(1).getInput().get(1));
+                       Hop replace = 
HopRewriteUtils.createParameterizedBuiltinOp(
+                               hi.getInput().get(0), args, 
ParamBuiltinOp.REPLACE);
+                       HopRewriteUtils.replaceChildReference(parent, hi, 
replace, pos);
+                       hi = replace;
+                       LOG.debug("Applied simplifyReplaceZeroOperation (line 
"+hi.getBeginLine()+").");
+               }
+               return hi;
+       }
+       
        /**
         * Pattners: t(t(X)) -> X, rev(rev(X)) -> X
         * 

Reply via email to