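Minor optimization in the RewriteElementwiseMultChainOptimization rewrite rule: build each leaf's power before sorting, and inline the isOptimizable check.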

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e93c487e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e93c487e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e93c487e

Branch: refs/heads/master
Commit: e93c487ef1778934c94fac291c6e76651041c961
Parents: d18a4c8
Author: Dylan Hutchison <[email protected]>
Authored: Tue Jul 11 23:04:21 2017 -0700
Committer: Dylan Hutchison <[email protected]>
Committed: Tue Jul 11 23:04:21 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 38 ++++++++------------
 1 file changed, 15 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e93c487e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1f85bbf..41fc61d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
 
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
@@ -85,12 +85,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                // 1. Find immediate subtree of EMults. Check dimsKnown.
                if (isBinaryMult(root) && root.dimsKnown()) {
                        final Hop left = root.getInput().get(0), right = 
root.getInput().get(1);
+                       // The set of BinaryOp element-wise multiply hops in 
the emult chain.
                        final Set<BinaryOp> emults = new HashSet<>();
+                       // The multiset of multiplicands in the emult chain.
                        final Map<Hop, Integer> leaves = new HashMap<>(); // 
poor man's HashMultiset
                        findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
                        // 2. Ensure it is profitable to do a rewrite.
-                       if (isOptimizable(emults, leaves)) {
+                       // Only optimize a subtree of emults if there are at 
least two emults (meaning, at least 3 multiplicands).
+                       if (emults.size() >= 2) {
                                // 3. Check for foreign parents.
                                // A foreign parent is a parent of some EMult 
that is not in the set.
                                // Foreign parents destroy correctness of this 
rewrite.
@@ -132,20 +135,20 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 
        private static Hop constructReplacement(final Set<BinaryOp> emults, 
final Map<Hop, Integer> leaves) {
                // Sort by data type
-               final SortedMap<Hop,Integer> sorted = new 
TreeMap<>(compareByDataType);
+               final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
                for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
                        final Hop h = entry.getKey();
                        // unlink parents that are in the emult set(we are 
throwing them away)
                        // keep other parents
                        h.getParent().removeIf(parent -> parent instanceof 
BinaryOp && emults.contains(parent));
-                       sorted.put(h, entry.getValue());
+                       sorted.add(constructPower(h, entry.getValue()));
                }
                // sorted contains all leaves, sorted by data type, stripped 
from their parents
 
                // Construct right-deep EMult tree
-               final Iterator<Map.Entry<Hop, Integer>> iterator = 
sorted.entrySet().iterator();
+               final Iterator<Hop> iterator = sorted.iterator();
 
-               Hop next = iterator.hasNext() ? constructPower(iterator.next()) 
: null;
+               Hop next = iterator.hasNext() ? iterator.next() : null;
                Hop colVectorsScalars = null;
                while(next != null &&
                                (next.getDataType() == 
Expression.DataType.SCALAR
@@ -157,7 +160,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                                colVectorsScalars = 
HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
                                colVectorsScalars.setVisited();
                        }
-                       next = iterator.hasNext() ? 
constructPower(iterator.next()) : null;
+                       next = iterator.hasNext() ? iterator.next() : null;
                }
                // next is not processed and is either null or past col vectors
 
@@ -171,7 +174,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                                rowVectors = 
HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
                                rowVectors.setVisited();
                        }
-                       next = iterator.hasNext() ? 
constructPower(iterator.next()) : null;
+                       next = iterator.hasNext() ? iterator.next() : null;
                }
                // next is not processed and is either null or past row vectors
 
@@ -185,7 +188,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                                matrices = 
HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
                                matrices.setVisited();
                        }
-                       next = iterator.hasNext() ? 
constructPower(iterator.next()) : null;
+                       next = iterator.hasNext() ? iterator.next() : null;
                }
                // next is not processed and is either null or past matrices
 
@@ -198,7 +201,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                                other = HopRewriteUtils.createBinary(other, 
next, Hop.OpOp2.MULT);
                                other.setVisited();
                        }
-                       next = iterator.hasNext() ? 
constructPower(iterator.next()) : null;
+                       next = iterator.hasNext() ? iterator.next() : null;
                }
                // finished
 
@@ -230,9 +233,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                return top;
        }
 
-       private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
-               final Hop hop = entry.getKey();
-               final int cnt = entry.getValue();
+       private static Hop constructPower(final Hop hop, final int cnt) {
                assert(cnt >= 1);
                hop.setVisited(); // we will visit the leaves' children next
                if (cnt == 1)
@@ -345,13 +346,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
                map.put(k, map.getOrDefault(k, 0) + 1);
        }
 
-       /**
-        * Only optimize a subtree of emults if there are at least two emults 
(meaning, at least 3 multiplicands).
-        * @param emults The set of BinaryOp element-wise multiply hops in the 
emult chain.
-        * @param leaves The multiset of multiplicands in the emult chain.
-        * @return If the multiset is worth optimizing.
-        */
-       private static boolean isOptimizable(Set<BinaryOp> emults, final 
Map<Hop, Integer> leaves) {
-               return emults.size() >= 2;
-       }
 }

Reply via email to