[SYSTEMML-766] Fix rewrite 'fuse binary axpy' (missing blocksize info) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2b7fdb2b Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2b7fdb2b Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2b7fdb2b
Branch: refs/heads/master Commit: 2b7fdb2b36df0aedc6b92bf138e7f0074eed7762 Parents: cbc4509 Author: Matthias Boehm <[email protected]> Authored: Sat Jul 16 20:32:28 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 16 20:32:28 2016 -0700 ---------------------------------------------------------------------- .../rewrite/RewriteAlgebraicSimplificationStatic.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2b7fdb2b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 43d5791..816b55a 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -165,7 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) - hi = fuseBinaryOperationChain(hop, hi, i); //e.g., X + lamda*Y -> X +* lambda Y + hi = fuseBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -1922,11 +1922,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //pattern: X + lamda*Y -> X +* lambda Y if( hi instanceof BinaryOp && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) - && ((BinaryOp)hi).getInput().get(0).getDataType()==DataType.MATRIX && ((BinaryOp)hi).getInput().get(1) instanceof BinaryOp + && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp && (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) ) { //Check that the inner binary Op is a product of Scalar times Matrix or viceversa - Hop innerBinaryOp = ((BinaryOp)hi).getInput().get(1); + Hop innerBinaryOp = hi.getInput().get(1); if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) || (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR)) { @@ -1934,8 +1934,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); - OpOp3 operator = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT; - TernaryOp ternOp=new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, operator, ((BinaryOp)hi).getInput().get(0), lamda, matrix); + OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT; + TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix); + HopRewriteUtils.refreshOutputParameters(ternOp, hi.getInput().get(0)); HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); HopRewriteUtils.addChildReference(parent, ternOp, pos); @@ -1944,7 +1945,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return ternOp; } } + return hi; - } }
