[SYSTEMML-766] Improved 'fuse axpy' rewrite (more patterns, no overlap)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/973b8635
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/973b8635
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/973b8635

Branch: refs/heads/master
Commit: 973b863579d7bf82505933d3d67fef4517c53eb3
Parents: b233b59
Author: Matthias Boehm <[email protected]>
Authored: Wed Jul 20 22:34:46 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Jul 21 12:54:15 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 29 +++++++++
 .../RewriteAlgebraicSimplificationDynamic.java  | 65 ++++++++++++++++++++
 .../RewriteAlgebraicSimplificationStatic.java   | 41 ------------
 .../misc/RewriteFuseBinaryOpChainTest.java      | 40 ++++++++++--
 .../misc/RewriteFuseBinaryOpChainTest3.R        | 28 +++++++++
 .../misc/RewriteFuseBinaryOpChainTest3.dml      | 27 ++++++++
 6 files changed, 184 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index a5432f1..385a888 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -36,6 +36,7 @@ import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.FileFormatTypes;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.Hop.VisitStatus;
@@ -45,6 +46,7 @@ import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.MemoTable;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.parser.DataExpression;
@@ -644,6 +646,22 @@ public class HopRewriteUtils
                return datagen;
        }
        
+       /**
+        * 
+        * @param mleft
+        * @param smid
+        * @param mright
+        * @param op
+        * @return
+        */
+       public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop 
mright, OpOp3 op) {
+               TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, 
ValueType.DOUBLE, op, mleft, smid, mright);
+               ternOp.setRowsInBlock(mleft.getRowsInBlock());
+               ternOp.setColsInBlock(mleft.getColsInBlock());
+               ternOp.refreshSizeInformation();
+               return ternOp;
+       }
+       
        public static void setOutputBlocksizes( Hop hop, long brlen, long bclen 
)
        {
                hop.setRowsInBlock( brlen );
@@ -878,6 +896,17 @@ public class HopRewriteUtils
         * @param hop
         * @return
         */
+       public static boolean isScalarMatrixBinaryMult( Hop hop ) {
+               return hop instanceof BinaryOp && 
((BinaryOp)hop).getOp()==OpOp2.MULT
+                       && 
((hop.getInput().get(0).getDataType()==DataType.SCALAR && 
hop.getInput().get(1).getDataType()==DataType.MATRIX)
+                       || 
(hop.getInput().get(0).getDataType()==DataType.MATRIX && 
hop.getInput().get(1).getDataType()==DataType.SCALAR));
+       }
+       
+       /**
+        * 
+        * @param hop
+        * @return
+        */
        public static boolean isBasic1NSequence(Hop hop)
        {
                boolean ret = false;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8205e83..dbde506 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -35,6 +35,7 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.HopsException;
@@ -174,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        hi = simplifyWeightedUnaryMM(hop, hi, i);         
//e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
                        hi = simplifyDotProductSum(hop, hi, i);           
//e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 
                        hi = fuseSumSquared(hop, hi, i);                  
//e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
+                       hi = fuseAxpyBinaryOperationChain(hop, hi, i);    
//e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)      
                        hi = reorderMinusMatrixMult(hop, hi, i);          
//e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
                        hi = simplifySumMatrixMult(hop, hi, i);           
//e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
                        hi = simplifyEmptyBinaryOperation(hop, hi, i);    
//e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
@@ -2458,6 +2460,69 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                return hi;
        }
        
+
+       /**
+        * 
+        * @param parent
+        * @param hi
+        * @param pos
+        * @return
+        * @throws HopsException
+        */
+       private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) 
+       {
+               //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X 
- s*Y -> X -* sY                
+               if( hi instanceof BinaryOp 
+                       && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) )
+               {
+                       BinaryOp bop = (BinaryOp) hi;
+                       Hop left = bop.getInput().get(0);
+                       Hop right = bop.getInput().get(1);
+                       Hop ternop = null;
+                       
+                       //pattern (a) X + s*Y -> X +* sY
+                       if( bop.getOp() == OpOp2.PLUS && 
left.getDataType()==DataType.MATRIX 
+                               && 
HopRewriteUtils.isScalarMatrixBinaryMult(right)
+                               && right.getParent().size() == 1 )           
//single consumer s*Y
+                       {
+                               Hop smid = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
+                               Hop mright = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+                               ternop = HopRewriteUtils.createTernaryOp(left, 
smid, mright, OpOp3.PLUS_MULT);
+                               LOG.debug("Applied 
fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")");
+                       }
+                       //pattern (b) s*Y + X -> X +* sY
+                       else if( bop.getOp() == OpOp2.PLUS && 
right.getDataType()==DataType.MATRIX 
+                               && 
HopRewriteUtils.isScalarMatrixBinaryMult(left)
+                               && left.getParent().size() == 1              
//single consumer s*Y
+                               && HopRewriteUtils.isEqualSize(left, right)) 
//correctness matrix-vector
+                       {
+                               Hop smid = left.getInput().get( 
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
+                               Hop mright = left.getInput().get( 
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+                               ternop = HopRewriteUtils.createTernaryOp(right, 
smid, mright, OpOp3.PLUS_MULT);
+                               LOG.debug("Applied 
fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")");      
+                       }
+                       //pattern (c) X - s*Y -> X -* sY
+                       else if( bop.getOp() == OpOp2.MINUS && 
left.getDataType()==DataType.MATRIX 
+                               && 
HopRewriteUtils.isScalarMatrixBinaryMult(right)
+                               && right.getParent().size() == 1 )           
//single consumer s*Y
+                       {
+                               Hop smid = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
+                               Hop mright = right.getInput().get( 
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+                               ternop = HopRewriteUtils.createTernaryOp(left, 
smid, mright, OpOp3.MINUS_MULT);
+                               LOG.debug("Applied 
fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")");
+                       }
+                       
+                       //rewire parent-child operators if rewrite applied
+                       if( ternop != null ) { 
+                               
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+                               HopRewriteUtils.addChildReference(parent, 
ternop, pos);
+                               hi = ternop;
+                       }
+               }
+               
+               return hi;
+       }
+       
        /**
         * 
         * @param parent

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 9ef2c05..ae9c073 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -162,7 +162,6 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = fuseLogNzBinaryOperation(hop, hi, i);           
//e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
                        hi = simplifyOuterSeqExpand(hop, hi, i);             
//e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, 
cast=false)
                        hi = simplifyTableSeqExpand(hop, hi, i);             
//e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, 
ignore=false, cast=true)
-                       hi = fuseBinaryOperationChain(hop, hi, i);              
         //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)       
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
@@ -1906,44 +1905,4 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                
                return hi;
        }
-
-       /**
-        * 
-        * @param parent
-        * @param hi
-        * @param pos
-        * @return
-        * @throws HopsException
-        */
-       private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) {
-               //pattern: X + lamda*Y -> X +* lambda Y         
-               if( hi instanceof BinaryOp 
-                       && (((BinaryOp)hi).getOp()==OpOp2.PLUS || 
((BinaryOp)hi).getOp()==OpOp2.MINUS) 
-                       && hi.getInput().get(0).getDataType()==DataType.MATRIX 
-                       && hi.getInput().get(1) instanceof BinaryOp 
-                       && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT 
)
-               {
-                       //Check that the inner binary Op is a product of Scalar 
times Matrix or viceversa
-                       Hop innerBinaryOp =  hi.getInput().get(1);
-                       if ( 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && 
innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) 
-                                       || 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && 
innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR))
-                       {
-                               //check which operand is the Scalar and which 
is the matrix
-                               Hop lamda = 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? 
innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); 
-                               Hop matrix = 
(innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? 
innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
-
-                               OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) 
? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
-                               TernaryOp ternOp = new TernaryOp("tmp", 
DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix);
-                               HopRewriteUtils.refreshOutputParameters(ternOp, 
hi.getInput().get(0));
-                               
-                               
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
-                               HopRewriteUtils.addChildReference(parent, 
ternOp, pos);
-                               
-                               LOG.debug("Applied fuseBinaryOperationChain. 
(line " +hi.getBeginLine()+")");
-                               return ternOp;
-                       }
-               }
-               
-               return hi;
-       }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index 890a3b2..ff85ebc 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -40,8 +40,9 @@ import org.apache.sysml.utils.Statistics;
  */
 public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase 
 {
-       private static final String TEST_NAME1 = 
"RewriteFuseBinaryOpChainTest1";
-       private static final String TEST_NAME2 = 
"RewriteFuseBinaryOpChainTest2";
+       private static final String TEST_NAME1 = 
"RewriteFuseBinaryOpChainTest1"; //+* (X+s*Y)
+       private static final String TEST_NAME2 = 
"RewriteFuseBinaryOpChainTest2"; //-* (X-s*Y) 
+       private static final String TEST_NAME3 = 
"RewriteFuseBinaryOpChainTest3"; //+* (s*Y+X)
 
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
@@ -53,6 +54,7 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                TestUtils.clearAssertionInformation();
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
        }
        
        @Test
@@ -60,7 +62,6 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
        }
        
-       
        @Test
        public void testFuseBinaryPlusRewriteCP() {
                testFuseBinaryChain( TEST_NAME1, true, ExecType.CP);
@@ -77,6 +78,16 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
        }
        
        @Test
+       public void testFuseBinaryPlus2NoRewriteCP() {
+               testFuseBinaryChain( TEST_NAME3, false, ExecType.CP );
+       }
+       
+       @Test
+       public void testFuseBinaryPlus2RewriteCP() {
+               testFuseBinaryChain( TEST_NAME3, true, ExecType.CP );
+       }
+       
+       @Test
        public void testFuseBinaryPlusNoRewriteSP() {
                testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
        }
@@ -97,6 +108,16 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
        }
        
        @Test
+       public void testFuseBinaryPlus2NoRewriteSP() {
+               testFuseBinaryChain( TEST_NAME3, false, ExecType.SPARK );
+       }
+       
+       @Test
+       public void testFuseBinaryPlus2RewriteSP() {
+               testFuseBinaryChain( TEST_NAME3, true, ExecType.SPARK );
+       }
+       
+       @Test
        public void testFuseBinaryPlusNoRewriteMR() {
                testFuseBinaryChain( TEST_NAME1, false, ExecType.MR );
        }
@@ -116,6 +137,15 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                testFuseBinaryChain( TEST_NAME2, true, ExecType.MR );
        }
        
+       @Test
+       public void testFuseBinaryPlus2NoRewriteMR() {
+               testFuseBinaryChain( TEST_NAME3, false, ExecType.MR );
+       }
+       
+       @Test
+       public void testFuseBinaryPlus2RewriteMR() {
+               testFuseBinaryChain( TEST_NAME3, true, ExecType.MR );
+       }
        
        /**
         * 
@@ -162,8 +192,8 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                        //check for applies rewrites
                        if( rewrites && instType!=ExecType.MR  ) {
                                String prefix = (instType==ExecType.SPARK) ? 
Instruction.SP_INST_PREFIX  : "";
-                               Assert.assertTrue("Rewrite not 
applied.",Statistics.getCPHeavyHitterOpCodes()
-                                               
.contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));
+                               String opcode = 
(testname.equals(TEST_NAME1)||testname.equals(TEST_NAME3)) ? prefix+"+*" : 
prefix+"-*";
+                               Assert.assertTrue("Rewrite not 
applied.",Statistics.getCPHeavyHitterOpCodes().contains(opcode));
                        }
                }
                finally

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
new file mode 100644
index 0000000..5ae1642
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X=matrix(1,10,10)
+Y=matrix(1,10,10)
+lamda=7
+S=lamda*Y+X
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
new file mode 100644
index 0000000..af84884
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X=matrix(1,rows=10,cols=10)
+Y=matrix(1,rows=10,cols=10)
+if(1==1){}
+lamda=7
+S=lamda*Y+X
+write(S,$1)
\ No newline at end of file

Reply via email to