Document RewriteEMult. Add smart recursion. RewriteEMult now rewrites emult chains deeper than the top-most one.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d88f867f Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d88f867f Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d88f867f Branch: refs/heads/master Commit: d88f867fd0384954dce9e6ce4d65f02f1054bc5e Parents: a5846bb Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Sat Jun 10 01:17:36 2017 -0700 Committer: Dylan Hutchison <dhutc...@cs.washington.edu> Committed: Sun Jun 18 17:43:33 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 2 + .../apache/sysml/hops/rewrite/RewriteEMult.java | 90 +++++++++++++------- .../org/apache/sysml/parser/Expression.java | 1 - 3 files changed, 62 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 4d23cb9..17ac4ec 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -246,6 +246,8 @@ public class HopRewriteUtils * Replace an old Hop with a replacement Hop. * If the old Hop has no parents, then return the replacement. * Otherwise rewire each of the Hop's parents into the replacement and return the replacement. 
+ * @param old To be replaced + * @param replacement The replacement * @return replacement */ public static Hop replaceHop(final Hop old, final Hop replacement) { http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java index d483a08..5cd1471 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java @@ -42,6 +42,7 @@ import com.google.common.collect.Multiset; * * Rewrite a chain of element-wise multiply hops that contain identical elements. * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power. + * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity. * * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain, * since foreign parents may rely on the individual results. @@ -74,18 +75,15 @@ public class RewriteEMult extends HopRewriteRule { return root; root.setVisited(); - final ArrayList<Hop> rootInputs = root.getInput(); - // 1. Find immediate subtree of EMults. if (isBinaryMult(root)) { - final Hop left = rootInputs.get(0), right = rootInputs.get(1); - final BinaryOp r = (BinaryOp)root; + final Hop left = root.getInput().get(0), right = root.getInput().get(1); final Set<BinaryOp> emults = new HashSet<>(); final Multiset<Hop> leaves = HashMultiset.create(); - findEMultsAndLeaves(r, emults, leaves); + findEMultsAndLeaves((BinaryOp)root, emults, leaves); // 2. Ensure it is profitable to do a rewrite. - if (isOptimizable(leaves)) { + if (isOptimizable(emults, leaves)) { // 3. Check for foreign parents. // A foreign parent is a parent of some EMult that is not in the set. 
// Foreign parents destroy correctness of this rewrite. @@ -94,25 +92,35 @@ public class RewriteEMult extends HopRewriteRule { if (okay) { // 4. Construct replacement EMults for the leaves final Hop replacement = constructReplacement(leaves); - // 5. Replace root with replacement if (LOG.isDebugEnabled()) LOG.debug(String.format( "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d", emults.size(), root.getHopID(), replacement.getHopID())); - replacement.setVisited(); - return HopRewriteUtils.replaceHop(root, replacement); + + // 5. Replace root with replacement + final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement); + + // 6. Recurse at leaves (no need to repeat the interior emults) + for (final Hop leaf : leaves.elementSet()) { + recurseInputs(leaf); + } + return newRoot; } } } - // This rewrite is not applicable to the current root. // Try the root's children. - for (int i = 0; i < rootInputs.size(); i++) { - final Hop input = rootInputs.get(i); + recurseInputs(root); + return root; + } + + private static void recurseInputs(final Hop parent) { + final ArrayList<Hop> inputs = parent.getInput(); + for (int i = 0; i < inputs.size(); i++) { + final Hop input = inputs.get(i); final Hop newInput = rule_RewriteEMult(input); - rootInputs.set(i, newInput); + inputs.set(i, newInput); } - return root; } private static Hop constructReplacement(final Multiset<Hop> leaves) { @@ -133,6 +141,7 @@ public class RewriteEMult extends HopRewriteRule { for (int i = 1; i < sorted.size(); i++) { final Hop second = constructPower(iterator.next()); first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT); + first.setVisited(); } return first; } @@ -141,16 +150,21 @@ public class RewriteEMult extends HopRewriteRule { final Hop hop = entry.getKey(); final int cnt = entry.getValue(); assert(cnt >= 1); - if (cnt == 1) return hop; - return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); + if (cnt == 1) + return hop; 
// don't set this visited... we will visit this next + Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); + pow.setVisited(); + return pow; } - - - // Order: scalars > row vectors > col vectors > - // non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > - // other data types - // disambiguate by Hop ID + /** + * A Comparator that orders Hops by their data type, dimension, and sparsity. + * The order is as follows: + * scalars > row vectors > col vectors > + * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) > + * other data types. + * Disambiguate by Hop ID. + */ private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { @Override public final int compare(Hop o1, Hop o2) { @@ -211,6 +225,12 @@ public class RewriteEMult extends HopRewriteRule { } }; + /** + * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults. + * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. + * @param child An interior emult hop in the emult chain dag. + * @return Whether this interior emult or any child emult has a foreign parent. + */ private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { final ArrayList<Hop> parents = child.getParent(); if (parents.size() > 1) @@ -227,7 +247,10 @@ public class RewriteEMult extends HopRewriteRule { } /** - * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root. + * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively. + * @param root Root of sub-dag + * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root). + * @param leaves Out parameter. The multiset of multiplicands in the emult chain. 
*/ private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) { // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality. @@ -243,12 +266,19 @@ public class RewriteEMult extends HopRewriteRule { else leaves.add(right); } - /** Only optimize a subtree of EMults if at least one leaf occurs more than once. */ - private static boolean isOptimizable(final Multiset<Hop> set) { - for (Multiset.Entry<Hop> hopEntry : set.entrySet()) { - if (hopEntry.getCount() > 1) - return true; - } - return false; + /** + * Only optimize a subtree of emults if there are at least two emults. + * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. + * @param leaves The multiset of multiplicands in the emult chain. + * @return If the multiset is worth optimizing. + */ + private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) { + // Old criterion: there should be at least one repeated leaf +// for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) { +// if (hopEntry.getCount() > 1) +// return true; +// } +// return false; + return emults.size() >= 2; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/parser/Expression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java index b944e29..9ee3fba 100644 --- a/src/main/java/org/apache/sysml/parser/Expression.java +++ b/src/main/java/org/apache/sysml/parser/Expression.java @@ -162,7 +162,6 @@ public abstract class Expression * Data types (matrix, scalar, frame, object, unknown). */ public enum DataType { - // Careful: the order of these enums is significant! See RewriteEMult.comparatorByDataType MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN; public boolean isMatrix() {