Repository: incubator-systemml
Updated Branches:
  refs/heads/master abbce2bc7 -> 1d1a9fa40


[HOTFIX][SYSTEMML-1459] Fix rewrite 'fuse binary subdag' (multi-matches)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/1d1a9fa4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/1d1a9fa4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/1d1a9fa4

Branch: refs/heads/master
Commit: 1d1a9fa403a9227d1ef56b959132177b532884f6
Parents: abbce2b
Author: Matthias Boehm <[email protected]>
Authored: Tue Apr 4 18:33:05 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Apr 4 19:28:12 2017 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   | 26 ++++++++++++--------
 1 file changed, 16 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1d1a9fa4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index a0ccb0f..5c6b9c8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1028,6 +1028,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        BinaryOp bop = (BinaryOp)hi;
                        Hop left = hi.getInput().get(0);
                        Hop right = hi.getInput().get(1);
+                       boolean applied = false;
                        
                        //sample proportion (sprop) operator
                        if( bop.getOp() == OpOp2.MULT && 
left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
@@ -1051,11 +1052,12 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                                
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(bop, left);
                                                hi = unary;
+                                               applied = true;
                                                
                                                LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-sprop1");
                                        }
                                }                               
-                               if( right instanceof BinaryOp ) //X*(1-X)
+                               if( !applied && right instanceof BinaryOp ) 
//X*(1-X)
                                {
                                        BinaryOp bright = (BinaryOp)right;
                                        Hop right1 = bright.getInput().get(0);
@@ -1069,13 +1071,15 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                                
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(bop, left);
                                                hi = unary;
+                                               applied = true;
                                                
                                                LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-sprop2");
                                        }
                                }
                        }
+                       
                        //sigmoid operator
-                       else if( bop.getOp() == OpOp2.DIV && 
left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX
+                       if( !applied && bop.getOp() == OpOp2.DIV && 
left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX
                                         && left instanceof LiteralOp && 
HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp)
                        {
                                //note: if there are multiple consumers on the 
intermediate,
@@ -1116,20 +1120,20 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                                        
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                                                        
HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop);
                                                        hi = unary;
+                                                       applied = true;
                                                        
                                                        LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-sigmoid1");
                                                }                               
                                        }
                                }               
                        }
-                       //select positive (selp) operator
-                       else if( bop.getOp() == OpOp2.MULT && 
left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
+                       
+                       //select positive (selp) operator (note: same initial 
pattern as sprop)
+                       if( !applied && bop.getOp() == OpOp2.MULT && 
left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX )
                        {
                                //by definition, either left or right or none 
applies. 
                                //note: if there are multiple consumers on the 
intermediate tmp=(X>0), it's still beneficial
                                //to replace the X*tmp with selp(X) due to 
lower memory requirements and simply sparsity propagation 
-                               boolean applied = false;
-                               
                                if( left instanceof BinaryOp ) //(X>0)*X
                                {
                                        BinaryOp bleft = (BinaryOp)left;
@@ -1143,7 +1147,6 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                UnaryOp unary = 
HopRewriteUtils.createUnary(right, OpOp1.SELP);
                                                
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(bop, left);
-                                               
                                                hi = unary;
                                                applied = true;
                                                
@@ -1163,7 +1166,6 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                UnaryOp unary = 
HopRewriteUtils.createUnary(left, OpOp1.SELP);
                                                
HopRewriteUtils.replaceChildReference(parent, bop, unary, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(bop, left);
-                                               
                                                hi = unary;
                                                applied= true;
                                                
@@ -1171,25 +1173,29 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                                        }
                                }
                        }
+                       
                        //select positive (selp) operator; pattern: max(X,0) -> 
selp+
-                       else if( bop.getOp() == OpOp2.MAX && 
left.getDataType()==DataType.MATRIX 
+                       if( !applied && bop.getOp() == OpOp2.MAX && 
left.getDataType()==DataType.MATRIX 
                                        && right instanceof LiteralOp && 
HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 )
                        {
                                UnaryOp unary = 
HopRewriteUtils.createUnary(left, OpOp1.SELP);
                                HopRewriteUtils.replaceChildReference(parent, 
bop, unary, pos);
                                HopRewriteUtils.cleanupUnreferenced(bop);
                                hi = unary;
+                               applied = true;
                                
                                LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-selp3");
                        }
+                       
                        //select positive (selp) operator; pattern: max(0,X) -> 
selp+
-                       else if( bop.getOp() == OpOp2.MAX && 
right.getDataType()==DataType.MATRIX 
+                       if( !applied && bop.getOp() == OpOp2.MAX && 
right.getDataType()==DataType.MATRIX 
                                        && left instanceof LiteralOp && 
HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 )
                        {
                                UnaryOp unary = 
HopRewriteUtils.createUnary(right, OpOp1.SELP);
                                HopRewriteUtils.replaceChildReference(parent, 
bop, unary, pos);
                                HopRewriteUtils.cleanupUnreferenced(bop);
                                hi = unary;
+                               applied = true;
                                
                                LOG.debug("Applied 
fuseBinarySubDAGToUnaryOperation-selp4");
                        }

Reply via email to