Fix Hop visit-status bug in element-wise multiply chain rewrite (keep non-emult parents of leaves)
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0a8936cd Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0a8936cd Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0a8936cd Branch: refs/heads/master Commit: 0a8936cd849d74baced732f45f1c53812abce537 Parents: d6d3795 Author: Dylan Hutchison <[email protected]> Authored: Sun Jun 11 03:55:25 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Sun Jun 18 17:43:48 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/rewrite/HopDagValidator.java | 5 ++++- .../RewriteElementwiseMultChainOptimization.java | 17 +++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0a8936cd/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java index 8cb5e1e..9ac21fc 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java @@ -35,6 +35,8 @@ import org.apache.sysml.parser.Expression; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.utils.Explain; +import com.google.common.collect.Lists; + import static org.apache.sysml.hops.HopsException.check; /** @@ -89,7 +91,8 @@ public class HopDagValidator { //check visit status final boolean seen = !state.seen.add(id); check(seen == hop.isVisited(), hop, - "seen previously is %b but does not match hop visit status", seen); + "(parents: %s) seen previously is %b but does not match hop visit status", + Lists.transform(hop.getParent(), Hop::getHopID), seen); if (seen) 
return; // we saw the Hop previously, no need to re-validate //check parent linking http://git-wip-us.apache.org/repos/asf/systemml/blob/0a8936cd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java index 91b7306..9ca0932 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java @@ -91,7 +91,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right)); if (okay) { // 4. Construct replacement EMults for the leaves - final Hop replacement = constructReplacement(leaves); + final Hop replacement = constructReplacement(emults, leaves); if (LOG.isDebugEnabled()) LOG.debug(String.format( "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d", @@ -123,13 +123,14 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { } } - private static Hop constructReplacement(final Multiset<Hop> leaves) { + private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) { // Sort by data type final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType); for (final Multiset.Entry<Hop> entry : leaves.entrySet()) { final Hop h = entry.getElement(); - // unlink parents (the EMults, which we are throwing away) - h.getParent().clear(); + // unlink parents that are in the emult set(we are throwing them away) + // keep other parents + h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent)); sorted.put(h, entry.getCount()); } // sorted contains all 
leaves, sorted by data type, stripped from their parents @@ -146,12 +147,13 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { return first; } - private static Hop constructPower(Map.Entry<Hop, Integer> entry) { + private static Hop constructPower(final Map.Entry<Hop, Integer> entry) { final Hop hop = entry.getKey(); final int cnt = entry.getValue(); assert(cnt >= 1); + hop.setVisited(); // we will visit the leaves' children next if (cnt == 1) - return hop; // don't set this visited... we will visit this next + return hop; Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); pow.setVisited(); return pow; @@ -222,8 +224,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule { final ArrayList<Hop> parents = child.getParent(); if (parents.size() > 1) for (final Hop parent : parents) - //noinspection SuspiciousMethodCalls (for Intellij, which checks when - if (!emults.contains(parent)) + if (parent instanceof BinaryOp && !emults.contains(parent)) return false; // child does not have foreign parents
