Add a new `wumm` pattern so the weighted unary matrix-multiply rewrite also picks up element-wise multiplication by 2.

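The new pattern recognizes an outer multiplication by the literal 2 applied to `W*(U%*%t(V))`, in either operand order: `(W*(U%*%t(V)))*2` or `2*(W*(U%*%t(V)))`. It rewrites that expression into a single `wumm` quaternary operator with `OpOp2.MULT`.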


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/479b9da4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/479b9da4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/479b9da4

Branch: refs/heads/master
Commit: 479b9da4e6c605871a914ccb4b06ab6da5de21ed
Parents: e93c487
Author: Dylan Hutchison <[email protected]>
Authored: Thu Jul 13 01:14:48 2017 -0700
Committer: Dylan Hutchison <[email protected]>
Committed: Thu Jul 13 01:14:48 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |  2 +-
 .../RewriteAlgebraicSimplificationDynamic.java  | 44 +++++++++++++++++++-
 2 files changed, 43 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 59565df..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
        private static final Log LOG = 
LogFactory.getLog(ProgramRewriter.class.getName());
        
        //internal local debug level
-       private static final boolean LDEBUG = false; 
+       private static final boolean LDEBUG = false;
        private static final boolean CHECK = false;
        
        private ArrayList<HopRewriteRule> _dagRuleSet = null;

http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..6246270 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.DataGenOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.ReorgOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,46 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        appliedPattern = true;
                        LOG.debug("Applied simplifyWeightedUnaryMM1 (line 
"+hi.getBeginLine()+")");     
                }
+
+               //Pattern 1.5) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
+               if( !appliedPattern
+                               && hi instanceof BinaryOp && 
HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+                               && 
(HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+                                       || 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+               {
+                       final Hop nl; // non-literal
+                       if( hi.getInput().get(0) instanceof LiteralOp ) {
+                               nl = hi.getInput().get(1);
+                       } else {
+                               nl = hi.getInput().get(0);
+                       }
+
+                       if (       HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+                                       && 
HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) 
//prevent mv
+                                       && nl.getDim2() > 1 //not applied for 
vector-vector mult
+                                       && nl.getInput().get(0).getDataType() 
== DataType.MATRIX
+                                       && nl.getInput().get(0).getDim2() > 
nl.getInput().get(0).getColsInBlock()
+                                       && 
HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+                                       && (((AggBinaryOp) 
nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || 
nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+                                       && 
HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+                       {
+                               final Hop W = nl.getInput().get(0);
+                               final Hop U = 
nl.getInput().get(1).getInput().get(0);
+                               Hop V = nl.getInput().get(1).getInput().get(1);
+                               if( !HopRewriteUtils.isTransposeOperation(V) )
+                                       V = HopRewriteUtils.createTranspose(V);
+                               else
+                                       V = V.getInput().get(0);
+
+                               hnew = new QuaternaryOp(hi.getName(), 
DataType.MATRIX, ValueType.DOUBLE,
+                                               OpOp4.WUMM, W, U, V, true, 
null, OpOp2.MULT);
+                               hnew.setOutputBlocksizes(W.getRowsInBlock(), 
W.getColsInBlock());
+                               hnew.refreshSizeInformation();
+
+                               appliedPattern = true;
+                               LOG.debug("Applied simplifyWeightedUnaryMM2.7 
(line "+hi.getBeginLine()+")");
+                       }
+               }
                
                //Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to 
unary ops
                if( !appliedPattern

Reply via email to