[SYSTEMML-766] Fix rewrite 'fuse binary axpy' (missing blocksize info)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2b7fdb2b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2b7fdb2b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2b7fdb2b

Branch: refs/heads/master
Commit: 2b7fdb2b36df0aedc6b92bf138e7f0074eed7762
Parents: cbc4509
Author: Matthias Boehm <[email protected]>
Authored: Sat Jul 16 20:32:28 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jul 16 20:32:28 2016 -0700

----------------------------------------------------------------------
 .../rewrite/RewriteAlgebraicSimplificationStatic.java  | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2b7fdb2b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 43d5791..816b55a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -165,7 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        hi = fuseLogNzBinaryOperation(hop, hi, i);           
//e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
                        hi = simplifyOuterSeqExpand(hop, hi, i);             
//e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, 
cast=false)
                        hi = simplifyTableSeqExpand(hop, hi, i);             
//e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, 
ignore=false, cast=true)
-                       hi = fuseBinaryOperationChain(hop, hi, i);              
         //e.g., X + lamda*Y -> X +* lambda Y   
+                       hi = fuseBinaryOperationChain(hop, hi, i);              
         //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)       
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
@@ -1922,11 +1922,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                //pattern: X + lamda*Y -> X +* lambda Y         
                if( hi instanceof BinaryOp
                                && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) 
-                               && 
((BinaryOp)hi).getInput().get(0).getDataType()==DataType.MATRIX && 
((BinaryOp)hi).getInput().get(1) instanceof BinaryOp 
+                               && 
hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) 
instanceof BinaryOp 
                                && (DMLScript.rtplatform == 
RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) )
                {
                        //Check that the inner binary Op is a product of Scalar 
times Matrix or viceversa
-                       Hop innerBinaryOp =  ((BinaryOp)hi).getInput().get(1);
+                       Hop innerBinaryOp =  hi.getInput().get(1);
                        if ( 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && 
innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) 
                                        || 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && 
innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR))
                        {
@@ -1934,8 +1934,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                                Hop lamda = 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? 
innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); 
                                Hop matrix = 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? 
innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
 
-                               OpOp3 operator = 
(((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
-                               TernaryOp ternOp=new TernaryOp("tmp", 
DataType.MATRIX, ValueType.DOUBLE, operator, ((BinaryOp)hi).getInput().get(0), 
lamda, matrix);
+                               OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) 
? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
+                               TernaryOp ternOp = new TernaryOp("tmp", 
DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix);
+                               HopRewriteUtils.refreshOutputParameters(ternOp, 
hi.getInput().get(0));
                                
                                
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
                                HopRewriteUtils.addChildReference(parent, 
ternOp, pos);
@@ -1944,7 +1945,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                                return ternOp;
                        }
                }
+               
                return hi;
-       
        }
 }

Reply via email to