Repository: incubator-systemml Updated Branches: refs/heads/master abbce2bc7 -> 1d1a9fa40
[HOTFIX][SYSTEMML-1459] Fix rewrite 'fuse binary subdag' (multi-matches) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/1d1a9fa4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/1d1a9fa4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/1d1a9fa4 Branch: refs/heads/master Commit: 1d1a9fa403a9227d1ef56b959132177b532884f6 Parents: abbce2b Author: Matthias Boehm <[email protected]> Authored: Tue Apr 4 18:33:05 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Apr 4 19:28:12 2017 -0700 ---------------------------------------------------------------------- .../RewriteAlgebraicSimplificationStatic.java | 26 ++++++++++++-------- 1 file changed, 16 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/1d1a9fa4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index a0ccb0f..5c6b9c8 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -1028,6 +1028,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule BinaryOp bop = (BinaryOp)hi; Hop left = hi.getInput().get(0); Hop right = hi.getInput().get(1); + boolean applied = false; //sample proportion (sprop) operator if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) @@ -1051,11 +1052,12 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; + applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop1"); } } - if( right instanceof BinaryOp ) //X*(1-X) + if( !applied && right instanceof BinaryOp ) //X*(1-X) { BinaryOp bright = (BinaryOp)right; Hop right1 = bright.getInput().get(0); @@ -1069,13 +1071,15 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); hi = unary; + applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sprop2"); } } } + //sigmoid operator - else if( bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX + if( !applied && bop.getOp() == OpOp2.DIV && left.getDataType()==DataType.SCALAR && right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==1 && right instanceof BinaryOp) { //note: if there are multiple consumers on the intermediate, @@ -1116,20 +1120,20 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, bop2, uop); hi = unary; + applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-sigmoid1"); } } } } - //select positive (selp) operator - else if( bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) + + //select positive (selp) operator (note: same initial pattern as sprop) + if( !applied && bop.getOp() == OpOp2.MULT && left.getDataType()==DataType.MATRIX && right.getDataType()==DataType.MATRIX ) { //by definition, either left or right or none applies. //note: if there are multiple consumers on the intermediate tmp=(X>0), it's still beneficial //to replace the X*tmp with selp(X) due to lower memory requirements and simply sparsity propagation - boolean applied = false; - if( left instanceof BinaryOp ) //(X>0)*X { BinaryOp bleft = (BinaryOp)left; @@ -1143,7 +1147,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); - hi = unary; applied = true; @@ -1163,7 +1166,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop, left); - hi = unary; applied= true; @@ -1171,25 +1173,29 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } } } + //select positive (selp) operator; pattern: max(X,0) -> selp+ - else if( bop.getOp() == OpOp2.MAX && left.getDataType()==DataType.MATRIX + if( !applied && bop.getOp() == OpOp2.MAX && left.getDataType()==DataType.MATRIX && right instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)right)==0 ) { UnaryOp unary = HopRewriteUtils.createUnary(left, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop); hi = unary; + applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp3"); } + //select positive (selp) operator; pattern: max(0,X) -> selp+ - else if( bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX + if( !applied && bop.getOp() == OpOp2.MAX && right.getDataType()==DataType.MATRIX && left instanceof LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left)==0 ) { UnaryOp unary = HopRewriteUtils.createUnary(right, OpOp1.SELP); HopRewriteUtils.replaceChildReference(parent, bop, unary, pos); HopRewriteUtils.cleanupUnreferenced(bop); hi = unary; + applied = true; LOG.debug("Applied fuseBinarySubDAGToUnaryOperation-selp4"); }
