Document RewriteEMult. Add smart recursion.

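RewriteEMult now also rewrites emult chains nested below the top-most chain: after rewriting a chain, it recurses into that chain's leaves. A chain is now rewritten whenever it contains at least two emults, rather than only when some multiplicand repeats.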


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d88f867f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d88f867f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d88f867f

Branch: refs/heads/master
Commit: d88f867fd0384954dce9e6ce4d65f02f1054bc5e
Parents: a5846bb
Author: Dylan Hutchison <dhutc...@cs.washington.edu>
Authored: Sat Jun 10 01:17:36 2017 -0700
Committer: Dylan Hutchison <dhutc...@cs.washington.edu>
Committed: Sun Jun 18 17:43:33 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  2 +
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 90 +++++++++++++-------
 .../org/apache/sysml/parser/Expression.java     |  1 -
 3 files changed, 62 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 4d23cb9..17ac4ec 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,6 +246,8 @@ public class HopRewriteUtils
         * Replace an old Hop with a replacement Hop.
         * If the old Hop has no parents, then return the replacement.
         * Otherwise rewire each of the Hop's parents into the replacement and 
return the replacement.
+        * @param old To be replaced
+        * @param replacement The replacement
         * @return replacement
         */
        public static Hop replaceHop(final Hop old, final Hop replacement) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index d483a08..5cd1471 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -42,6 +42,7 @@ import com.google.common.collect.Multiset;
  *
  * Rewrite a chain of element-wise multiply hops that contain identical 
elements.
  * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), 
where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
  *
  * Does not rewrite in the presence of foreign parents in the middle of the 
e-wise multiply chain,
  * since foreign parents may rely on the individual results.
@@ -74,18 +75,15 @@ public class RewriteEMult extends HopRewriteRule {
                        return root;
                root.setVisited();
 
-               final ArrayList<Hop> rootInputs = root.getInput();
-
                // 1. Find immediate subtree of EMults.
                if (isBinaryMult(root)) {
-                       final Hop left = rootInputs.get(0), right = 
rootInputs.get(1);
-                       final BinaryOp r = (BinaryOp)root;
+                       final Hop left = root.getInput().get(0), right = 
root.getInput().get(1);
                        final Set<BinaryOp> emults = new HashSet<>();
                        final Multiset<Hop> leaves = HashMultiset.create();
-                       findEMultsAndLeaves(r, emults, leaves);
+                       findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
                        // 2. Ensure it is profitable to do a rewrite.
-                       if (isOptimizable(leaves)) {
+                       if (isOptimizable(emults, leaves)) {
                                // 3. Check for foreign parents.
                                // A foreign parent is a parent of some EMult 
that is not in the set.
                                // Foreign parents destroy correctness of this 
rewrite.
@@ -94,25 +92,35 @@ public class RewriteEMult extends HopRewriteRule {
                                if (okay) {
                                        // 4. Construct replacement EMults for 
the leaves
                                        final Hop replacement = 
constructReplacement(leaves);
-                                       // 5. Replace root with replacement
                                        if (LOG.isDebugEnabled())
                                                LOG.debug(String.format(
                                                                "Element-wise 
multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
                                                                emults.size(), 
root.getHopID(), replacement.getHopID()));
-                                       replacement.setVisited();
-                                       return HopRewriteUtils.replaceHop(root, 
replacement);
+
+                                       // 5. Replace root with replacement
+                                       final Hop newRoot = 
HopRewriteUtils.replaceHop(root, replacement);
+
+                                       // 6. Recurse at leaves (no need to 
repeat the interior emults)
+                                       for (final Hop leaf : 
leaves.elementSet()) {
+                                               recurseInputs(leaf);
+                                       }
+                                       return newRoot;
                                }
                        }
                }
-
                // This rewrite is not applicable to the current root.
                // Try the root's children.
-               for (int i = 0; i < rootInputs.size(); i++) {
-                       final Hop input = rootInputs.get(i);
+               recurseInputs(root);
+               return root;
+       }
+
+       private static void recurseInputs(final Hop parent) {
+               final ArrayList<Hop> inputs = parent.getInput();
+               for (int i = 0; i < inputs.size(); i++) {
+                       final Hop input = inputs.get(i);
                        final Hop newInput = rule_RewriteEMult(input);
-                       rootInputs.set(i, newInput);
+                       inputs.set(i, newInput);
                }
-               return root;
        }
 
        private static Hop constructReplacement(final Multiset<Hop> leaves) {
@@ -133,6 +141,7 @@ public class RewriteEMult extends HopRewriteRule {
                for (int i = 1; i < sorted.size(); i++) {
                        final Hop second = constructPower(iterator.next());
                        first = HopRewriteUtils.createBinary(second, first, 
Hop.OpOp2.MULT);
+                       first.setVisited();
                }
                return first;
        }
@@ -141,16 +150,21 @@ public class RewriteEMult extends HopRewriteRule {
                final Hop hop = entry.getKey();
                final int cnt = entry.getValue();
                assert(cnt >= 1);
-               if (cnt == 1) return hop;
-               return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), 
Hop.OpOp2.POW);
+               if (cnt == 1)
+                       return hop; // don't set this visited... we will visit 
this next
+               Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), 
Hop.OpOp2.POW);
+               pow.setVisited();
+               return pow;
        }
 
-
-
-       // Order: scalars > row vectors > col vectors >
-       //        non-vector matrices ordered by sparsity (higher nnz first, 
unknown sparsity last) >
-       //        other data types
-       // disambiguate by Hop ID
+       /**
+        * A Comparator that orders Hops by their data type, dimensions, and sparsity.
+        * The order is as follows:
+        *              scalars > row vectors > col vectors >
+        *      non-vector matrices ordered by sparsity (higher nnz first, 
unknown sparsity last) >
+        *      other data types.
+        * Disambiguate by Hop ID.
+        */
        private static final Comparator<Hop> compareByDataType = new 
Comparator<Hop>() {
                @Override
                public final int compare(Hop o1, Hop o2) {
@@ -211,6 +225,12 @@ public class RewriteEMult extends HopRewriteRule {
                }
        };
 
+       /**
+        * Check if a node has a parent that is not in the set of emults. 
Recursively check children who are also emults.
+        * @param emults The set of BinaryOp element-wise multiply hops in the 
emult chain.
+        * @param child An interior emult hop in the emult chain dag.
+        * @return Whether this interior emult or any child emult has a foreign 
parent.
+        */
        private static boolean checkForeignParent(final Set<BinaryOp> emults, 
final BinaryOp child) {
                final ArrayList<Hop> parents = child.getParent();
                if (parents.size() > 1)
@@ -227,7 +247,10 @@ public class RewriteEMult extends HopRewriteRule {
        }
 
        /**
-        * Create a set of the counts of all BinaryOp MULTs in the immediate 
subtree, starting with root.
+        * Recursively collect the BinaryOp MULTs in the immediate subtree starting with root, and count the multiplicands (leaves) of that emult chain.
+        * @param root Root of sub-dag
+        * @param emults Out parameter. The set of BinaryOp element-wise 
multiply hops in the emult chain (including root).
+        * @param leaves Out parameter. The multiset of multiplicands in the 
emult chain.
         */
        private static void findEMultsAndLeaves(final BinaryOp root, final 
Set<BinaryOp> emults, final Multiset<Hop> leaves) {
                // Because RewriteCommonSubexpressionElimination already ran, 
it is safe to compare by equality.
@@ -243,12 +266,19 @@ public class RewriteEMult extends HopRewriteRule {
                else leaves.add(right);
        }
 
-       /** Only optimize a subtree of EMults if at least one leaf occurs more 
than once. */
-       private static boolean isOptimizable(final Multiset<Hop> set) {
-               for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
-                       if (hopEntry.getCount() > 1)
-                               return true;
-               }
-               return false;
+       /**
+        * Only optimize a subtree of emults if there are at least two emults.
+        * @param emults The set of BinaryOp element-wise multiply hops in the 
emult chain.
+        * @param leaves The multiset of multiplicands in the emult chain.
+        * @return Whether the emult chain is worth rewriting (i.e., it contains at least two emults).
+        */
+       private static boolean isOptimizable(Set<BinaryOp> emults, final 
Multiset<Hop> leaves) {
+               // Old criterion: there should be at least one repeated leaf
+//             for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
+//                     if (hopEntry.getCount() > 1)
+//                             return true;
+//             }
+//             return false;
+               return emults.size() >= 2;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java 
b/src/main/java/org/apache/sysml/parser/Expression.java
index b944e29..9ee3fba 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,7 +162,6 @@ public abstract class Expression
         * Data types (matrix, scalar, frame, object, unknown).
         */
        public enum DataType {
-               // Careful: the order of these enums is significant! See 
RewriteEMult.comparatorByDataType
                MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
                
                public boolean isMatrix() {

Reply via email to