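Minor optimization inside the element-wise multiply chain rewrite rule (RewriteElementwiseMultChainOptimization): build each leaf's power hop once while sorting leaves, and inline the isOptimizable check.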
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e93c487e Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e93c487e Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e93c487e Branch: refs/heads/master Commit: e93c487ef1778934c94fac291c6e76651041c961 Parents: d18a4c8 Author: Dylan Hutchison <[email protected]> Authored: Tue Jul 11 23:04:21 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Tue Jul 11 23:04:21 2017 -0700 ---------------------------------------------------------------------- ...RewriteElementwiseMultChainOptimization.java | 38 ++++++++------------ 1 file changed, 15 insertions(+), 23 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e93c487e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java index 1f85bbf..41fc61d 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -26,8 +26,8 @@ import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; -import java.util.SortedMap; -import java.util.TreeMap; +import java.util.SortedSet; +import java.util.TreeSet; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; @@ -85,12 +85,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { // 1. Find immediate subtree of EMults. Check dimsKnown. 
if (isBinaryMult(root) && root.dimsKnown()) { final Hop left = root.getInput().get(0), right = root.getInput().get(1); + // The set of BinaryOp element-wise multiply hops in the emult chain. final Set<BinaryOp> emults = new HashSet<>(); + // The multiset of multiplicands in the emult chain. final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset findEMultsAndLeaves((BinaryOp)root, emults, leaves); // 2. Ensure it is profitable to do a rewrite. - if (isOptimizable(emults, leaves)) { + // Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands). + if (emults.size() >= 2) { // 3. Check for foreign parents. // A foreign parent is a parent of some EMult that is not in the set. // Foreign parents destroy correctness of this rewrite. @@ -132,20 +135,20 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) { // Sort by data type - final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); + final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType); for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) { final Hop h = entry.getKey(); // unlink parents that are in the emult set(we are throwing them away) // keep other parents h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent)); - sorted.put(h, entry.getValue()); + sorted.add(constructPower(h, entry.getValue())); } // sorted contains all leaves, sorted by data type, stripped from their parents // Construct right-deep EMult tree - final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator(); + final Iterator<Hop> iterator = sorted.iterator(); - Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null; + Hop next = iterator.hasNext() ? 
iterator.next() : null; Hop colVectorsScalars = null; while(next != null && (next.getDataType() == Expression.DataType.SCALAR @@ -157,7 +160,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT); colVectorsScalars.setVisited(); } - next = iterator.hasNext() ? constructPower(iterator.next()) : null; + next = iterator.hasNext() ? iterator.next() : null; } // next is not processed and is either null or past col vectors @@ -171,7 +174,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT); rowVectors.setVisited(); } - next = iterator.hasNext() ? constructPower(iterator.next()) : null; + next = iterator.hasNext() ? iterator.next() : null; } // next is not processed and is either null or past row vectors @@ -185,7 +188,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT); matrices.setVisited(); } - next = iterator.hasNext() ? constructPower(iterator.next()) : null; + next = iterator.hasNext() ? iterator.next() : null; } // next is not processed and is either null or past matrices @@ -198,7 +201,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT); other.setVisited(); } - next = iterator.hasNext() ? constructPower(iterator.next()) : null; + next = iterator.hasNext() ? 
iterator.next() : null; } // finished @@ -230,9 +233,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { return top; } - private static Hop constructPower(final Map.Entry<Hop, Integer> entry) { - final Hop hop = entry.getKey(); - final int cnt = entry.getValue(); + private static Hop constructPower(final Hop hop, final int cnt) { assert(cnt >= 1); hop.setVisited(); // we will visit the leaves' children next if (cnt == 1) @@ -345,13 +346,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { map.put(k, map.getOrDefault(k, 0) + 1); } - /** - * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands). - * @param emults The set of BinaryOp element-wise multiply hops in the emult chain. - * @param leaves The multiset of multiplicands in the emult chain. - * @return If the multiset is worth optimizing. - */ - private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) { - return emults.size() >= 2; - } }
