Add the hop name to the sort comparator of the EMult rewrite. Handle the ternary A*A*B case.

AggUnaryOp now constructs the ternary operator with inputs (A, A, B) instead of (A^2, B, 1).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ff8c836c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ff8c836c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ff8c836c

Branch: refs/heads/master
Commit: ff8c836c7b736dbd7b7651ac792a6d8c23989c98
Parents: eb0599d
Author: Dylan Hutchison <dhutc...@cs.washington.edu>
Authored: Fri Jun 9 13:18:19 2017 -0700
Committer: Dylan Hutchison <dhutc...@cs.washington.edu>
Committed: Sun Jun 18 17:43:18 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  | 62 ++++++++++++++------
 .../apache/sysml/hops/rewrite/RewriteEMult.java |  4 +-
 2 files changed, 48 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index ee4ded2..4573b66 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -516,7 +516,6 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                                }
                        }
                }
-               
                return ret;
        }
        
@@ -631,24 +630,53 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                Hop input11 = input1.getInput().get(0);
                Hop input12 = input1.getInput().get(1);
                
-               Lop in1 = null;
-               Lop in2 = null;
-               Lop in3 = null;
+               Lop in1 = null, in2 = null, in3 = null;
+               boolean handled = false;
                
-               if( input11 instanceof BinaryOp && 
((BinaryOp)input11).getOp()==OpOp2.MULT )
-               {
-                       in1 = input11.getInput().get(0).constructLops();
-                       in2 = input11.getInput().get(1).constructLops();
-                       in3 = input12.constructLops();
-               }
-               else if( input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT )
-               {
-                       in1 = input11.constructLops();
-                       in2 = input12.getInput().get(0).constructLops();
-                       in3 = input12.getInput().get(1).constructLops();
+               if( input11 instanceof BinaryOp ) {
+                       BinaryOp b11 = (BinaryOp)input11;
+                       switch (b11.getOp()) {
+                       case MULT: // A*B*C case
+                               in1 = input11.getInput().get(0).constructLops();
+                               in2 = input11.getInput().get(1).constructLops();
+                               in3 = input12.constructLops();
+                               handled = true;
+                               break;
+                       case POW: // A*A*B case
+                               Hop b112 = b11.getInput().get(1);
+                               if ( !(input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT)
+                                               && b112 instanceof LiteralOp
+                                               && 
((LiteralOp)b112).getLongValue() == 2) {
+                                       in1 = 
b11.getInput().get(0).constructLops();
+                                       in2 = in1;
+                                       in3 = input12.constructLops();
+                                       handled = true;
+                               }
+                               break;
+                       }
+               } else if( input12 instanceof BinaryOp ) {
+                       BinaryOp b12 = (BinaryOp)input12;
+                       switch (b12.getOp()) {
+                       case MULT: // A*B*C case
+                               in1 = input11.constructLops();
+                               in2 = input12.getInput().get(0).constructLops();
+                               in3 = input12.getInput().get(1).constructLops();
+                               handled = true;
+                               break;
+                       case POW: // A*B*B case
+                               Hop b112 = b12.getInput().get(1);
+                               if ( b112 instanceof LiteralOp
+                                               && 
((LiteralOp)b112).getLongValue() == 2) {
+                                       in1 = 
b12.getInput().get(0).constructLops();
+                                       in2 = in1;
+                                       in3 = input11.constructLops();
+                                       handled = true;
+                               }
+                               break;
+                       }
                }
-               else
-               {
+
+               if (!handled) {
                        in1 = input11.constructLops();
                        in2 = input12.constructLops();
                        in3 = new LiteralOp(1).constructLops();

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 2c9e5cb..66da6fa 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -145,7 +145,9 @@ public class RewriteEMult extends HopRewriteRule {
                return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), 
Hop.OpOp2.POW);
        }
 
-       private static Comparator<Hop> compareByDataType = 
Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode);
+       private static Comparator<Hop> compareByDataType = 
Comparator.comparing(Hop::getDataType)
+                       .thenComparing(Hop::getName)
+                       .thenComparingInt(Object::hashCode);
 
        private static boolean checkForeignParent(final Set<BinaryOp> emults, 
final BinaryOp child) {
                final ArrayList<Hop> parents = child.getParent();

Reply via email to