[SYSTEMML-694] Improved wdivmm rewrites (outer-product-like mm only)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/382df653
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/382df653
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/382df653

Branch: refs/heads/master
Commit: 382df653a5438124b1cdff8c676a55b9c0d50976
Parents: a528f5e
Author: Matthias Boehm <[email protected]>
Authored: Fri Jul 29 23:25:01 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jul 30 16:23:21 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 11 +++++++++
 .../RewriteAlgebraicSimplificationDynamic.java  | 24 +++++++++++---------
 2 files changed, 24 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 385a888..f7e4656 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -785,6 +785,17 @@ public class HopRewriteUtils
                                    : (hop.getDim1()>0 && 
hop.getDim1()<=hop.getRowsInBlock());
        }
        
+       /**
+        * Indicates whether the given hop is an outer-product-like matrix multiplication (used to restrict the wdivmm rewrites): an AggBinaryOp whose left input has more rows than columns and whose right input has fewer rows than columns (e.g., U %*% t(V) with a tall U and a wide t(V)).
+        * @param hop the hop to check; its first two inputs are only accessed if it is an AggBinaryOp
+        * @return true if the hop matches the pattern above, false otherwise. NOTE(review): unknown dimensions are not explicitly guarded here (presumably encoded as non-positive values, which could let a comparison succeed when only one side is known) -- confirm whether an explicit known-dimensions check is intended.
+        */
+       public static boolean isOuterProductLikeMM( Hop hop ) {
+               return hop instanceof AggBinaryOp
+                       && hop.getInput().get(0).getDim1() > 
hop.getInput().get(0).getDim2()
+                       && hop.getInput().get(1).getDim1() < 
hop.getInput().get(1).getDim2();
+       }
+       
        public static boolean isEqualValue( LiteralOp hop1, LiteralOp hop2 ) 
                throws HopsException
        {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/382df653/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index dbde506..10953f5 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1852,7 +1852,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        //alternative pattern: t(U) %*% (W*(U%*%t(V)))
                        if( right instanceof BinaryOp && 
HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY) 
       
                                && 
HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) 
//prevent mv
-                               && right.getInput().get(1) instanceof 
AggBinaryOp
+                               && 
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
                                && 
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) 
//BLOCKSIZE CONSTRAINT
                        {
                                Hop W = right.getInput().get(0); 
@@ -1886,6 +1886,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                && right.getInput().get(1) instanceof BinaryOp
                                && ((BinaryOp) right.getInput().get(1)).getOp() 
== Hop.OpOp2.PLUS
                                && 
right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
+                               && 
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
                                && 
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true)
 ) //BLOCKSIZE CONSTRAINT
                        {
                                Hop W = right.getInput().get(0); 
@@ -1917,7 +1918,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        if( !appliedPattern
                                && left instanceof BinaryOp && 
HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY) 
 
                                && 
HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) 
//prevent mv
-                               && left.getInput().get(1) instanceof AggBinaryOp
+                               && 
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
                                && 
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) 
//BLOCKSIZE CONSTRAINT
                        {
                                Hop W = left.getInput().get(0); 
@@ -1948,6 +1949,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                && left.getInput().get(1) instanceof BinaryOp
                                && ((BinaryOp) left.getInput().get(1)).getOp() 
== Hop.OpOp2.PLUS
                                && 
left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
+                               && 
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
                                && 
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true)
 ) //BLOCKSIZE CONSTRAINT
                        {
                                Hop W = left.getInput().get(0); 
@@ -1975,8 +1977,8 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        if( !appliedPattern
                                && right instanceof BinaryOp && 
((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
                                && right.getInput().get(1) instanceof BinaryOp 
&& ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS      
-                               && right.getInput().get(1).getInput().get(0) 
instanceof AggBinaryOp
-                && right.getInput().get(1).getInput().get(1).getDataType() == 
DataType.MATRIX
+                               && 
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
+                               && 
right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
                                && 
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true)
 ) //BLOCKSIZE CONSTRAINT
                        {
                                Hop W = right.getInput().get(0); 
@@ -2008,8 +2010,8 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        if( !appliedPattern
                                && left instanceof BinaryOp && 
((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT   
                                && left.getInput().get(1) instanceof BinaryOp 
&& ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS        
-                               && left.getInput().get(1).getInput().get(0) 
instanceof AggBinaryOp
-                && left.getInput().get(1).getInput().get(1).getDataType() == 
DataType.MATRIX
+                               && 
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
+                               && 
left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
                                && 
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true)
 ) //BLOCKSIZE CONSTRAINT
                        {
                                Hop W = left.getInput().get(0); 
@@ -2038,8 +2040,8 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        if( !appliedPattern
                                && right instanceof BinaryOp && 
((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
                                && right.getInput().get(1) instanceof BinaryOp 
&& ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS      
-                               && right.getInput().get(1).getInput().get(0) 
instanceof AggBinaryOp
-                && right.getInput().get(1).getInput().get(1).getDataType() == 
DataType.MATRIX
+                               && 
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
+                               && 
right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
                                && 
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true)
 ) //BLOCKSIZE CONSTRAINT
                        {
                                Hop W = right.getInput().get(0); 
@@ -2071,8 +2073,8 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        if( !appliedPattern
                                && left instanceof BinaryOp && 
((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT   
                                && left.getInput().get(1) instanceof BinaryOp 
&& ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS        
-                               && left.getInput().get(1).getInput().get(0) 
instanceof AggBinaryOp
-                && left.getInput().get(1).getInput().get(1).getDataType() == 
DataType.MATRIX
+                               && 
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
+                               && 
left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
                                && 
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true)
 ) //BLOCKSIZE CONSTRAINT
                        {
                                Hop W = left.getInput().get(0); 
@@ -2105,7 +2107,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        && hi.getDim2() > 1 //not applied for vector-vector mult
                        && hi.getInput().get(0).getDataType() == 
DataType.MATRIX 
                        && hi.getInput().get(0).getDim2() > 
hi.getInput().get(0).getColsInBlock()
-                       && hi.getInput().get(1) instanceof AggBinaryOp
+                       && 
HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
                        && (((AggBinaryOp) 
hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || 
hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
                        && 
HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) 
//BLOCKSIZE CONSTRAINT
                {

Reply via email to