Add the hop name as a secondary sort key in the EMult rewrite's comparator. Handle the ternary A*A*B case: AggUnaryOp now constructs the ternary aggregate operator with inputs (A, A, B) instead of (A^2, B, 1).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ff8c836c Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ff8c836c Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ff8c836c Branch: refs/heads/master Commit: ff8c836c7b736dbd7b7651ac792a6d8c23989c98 Parents: eb0599d Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Fri Jun 9 13:18:19 2017 -0700 Committer: Dylan Hutchison <dhutc...@cs.washington.edu> Committed: Sun Jun 18 17:43:18 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 62 ++++++++++++++------ .../apache/sysml/hops/rewrite/RewriteEMult.java | 4 +- 2 files changed, 48 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index ee4ded2..4573b66 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -516,7 +516,6 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop } } } - return ret; } @@ -631,24 +630,53 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop Hop input11 = input1.getInput().get(0); Hop input12 = input1.getInput().get(1); - Lop in1 = null; - Lop in2 = null; - Lop in3 = null; + Lop in1 = null, in2 = null, in3 = null; + boolean handled = false; - if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) - { - in1 = input11.getInput().get(0).constructLops(); - in2 = input11.getInput().get(1).constructLops(); - in3 = input12.constructLops(); - } - else if( input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT ) - { - in1 = input11.constructLops(); - in2 = input12.getInput().get(0).constructLops(); - in3 = input12.getInput().get(1).constructLops(); + if( input11 instanceof BinaryOp ) { + BinaryOp b11 = (BinaryOp)input11; + switch (b11.getOp()) { + case MULT: // A*B*C case + in1 = input11.getInput().get(0).constructLops(); + in2 = input11.getInput().get(1).constructLops(); + in3 = input12.constructLops(); + handled = true; + break; + case POW: // A*A*B case + Hop b112 = b11.getInput().get(1); + if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT) + && b112 instanceof LiteralOp + && ((LiteralOp)b112).getLongValue() == 2) { + in1 = b11.getInput().get(0).constructLops(); + in2 = in1; + in3 = input12.constructLops(); + handled = true; + } + break; + } + } else if( input12 instanceof BinaryOp ) { + BinaryOp b12 = (BinaryOp)input12; + switch (b12.getOp()) { + case MULT: // A*B*C case + in1 = input11.constructLops(); + in2 = input12.getInput().get(0).constructLops(); + in3 = input12.getInput().get(1).constructLops(); + handled = true; + break; + case POW: // A*B*B case + Hop b112 = b12.getInput().get(1); + if ( b112 instanceof LiteralOp + && ((LiteralOp)b112).getLongValue() == 2) { + in1 = b12.getInput().get(0).constructLops(); + in2 = in1; + in3 = input11.constructLops(); + handled = true; + } + break; + } } - else - { + + if (!handled) { in1 = input11.constructLops(); in2 = input12.constructLops(); in3 = new LiteralOp(1).constructLops(); http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java index 2c9e5cb..66da6fa 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java +++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java @@ -145,7 +145,9 @@ public class RewriteEMult extends HopRewriteRule { return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); } - private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode); + private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType) + .thenComparing(Hop::getName) + .thenComparingInt(Object::hashCode); private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { final ArrayList<Hop> parents = child.getParent();