[SYSTEMML-2283] Fix performance issue of CSE on DAGs w/ many root nodes. This patch fixes a severe performance issue of the compiler rewrite for common subexpression elimination (CSE) for the case of DAGs with many root nodes. The issue showed up on resnet200 because this script contains DAGs with 2000+ root nodes. In detail, the issue was due to an incorrect reset-merge of root nodes, which (depending on the graph structure) led in the worst case to processing the entire DAG once per root node. This patch fixes the issue and generally improves performance by using better keys for literal ops (value type and name) to avoid unnecessary string concatenation. Together, these changes improved the compilation time of resnet200 from 22min to <10s.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/686e3831 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/686e3831 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/686e3831 Branch: refs/heads/master Commit: 686e3831dee5bb87833d0192a098c97b13fe1c20 Parents: b3fef52 Author: Matthias Boehm <[email protected]> Authored: Thu Apr 26 23:46:39 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Apr 27 00:03:21 2018 -0700 ---------------------------------------------------------------------- .../RewriteCommonSubexpressionElimination.java | 120 ++++++++++--------- 1 file changed, 66 insertions(+), 54 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/686e3831/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java index f0cc46e..f359422 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java @@ -25,7 +25,9 @@ import java.util.HashMap; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.hops.Hop.DataOpTypes; +import org.apache.sysml.runtime.util.UtilFunctions; /** * Rule: CommonSubexpressionElimination. 
For all statement blocks, @@ -36,16 +38,13 @@ import org.apache.sysml.hops.Hop.DataOpTypes; */ public class RewriteCommonSubexpressionElimination extends HopRewriteRule { + private final boolean _mergeLeafs; - private boolean _mergeLeafs = true; - - public RewriteCommonSubexpressionElimination() - { + public RewriteCommonSubexpressionElimination() { this( true ); //default full CSE } - public RewriteCommonSubexpressionElimination( boolean mergeLeafs ) - { + public RewriteCommonSubexpressionElimination( boolean mergeLeafs ) { _mergeLeafs = mergeLeafs; } @@ -55,21 +54,23 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule if( roots == null ) return null; - HashMap<String, Hop> dataops = new HashMap<>(); - HashMap<String, Hop> literalops = new HashMap<>(); //key: <VALUETYPE>_<LITERAL> - for (Hop h : roots) - { - int cseMerged = 0; - if( _mergeLeafs ) { + //CSE pass 1: merge leaf nodes by name + int cseMerged = 0; + if( _mergeLeafs ) { + HashMap<String, Hop> dataops = new HashMap<>(); + HashMap<LiteralKey, Hop> literalops = new HashMap<>(); + for (Hop h : roots) cseMerged += rule_CommonSubexpressionElimination_MergeLeafs(h, dataops, literalops); - h.resetVisitStatus(); - } - cseMerged += rule_CommonSubexpressionElimination(h); - - if( cseMerged > 0 ) - LOG.debug("Common Subexpression Elimination - removed "+cseMerged+" operators."); + Hop.resetVisitStatus(roots); } + //CSE pass 2: bottom-up merge of inner nodes + for (Hop h : roots) + cseMerged += rule_CommonSubexpressionElimination(h); + + if( cseMerged > 0 ) + LOG.debug("Common Subexpression Elimination - removed "+cseMerged+" operators."); + return roots; } @@ -79,13 +80,16 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule if( root == null ) return null; - HashMap<String, Hop> dataops = new HashMap<>(); - HashMap<String, Hop> literalops = new HashMap<>(); //key: <VALUETYPE>_<LITERAL> + //CSE pass 1: merge leaf nodes by name int cseMerged = 0; if( _mergeLeafs ) { + 
HashMap<String, Hop> dataops = new HashMap<>(); + HashMap<LiteralKey, Hop> literalops = new HashMap<>(); cseMerged += rule_CommonSubexpressionElimination_MergeLeafs(root, dataops, literalops); root.resetVisitStatus(); } + + //CSE pass 2: bottom-up merge of inner nodes cseMerged += rule_CommonSubexpressionElimination(root); if( cseMerged > 0 ) @@ -94,25 +98,24 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule return root; } - private int rule_CommonSubexpressionElimination_MergeLeafs( Hop hop, HashMap<String, Hop> dataops, HashMap<String, Hop> literalops ) + private int rule_CommonSubexpressionElimination_MergeLeafs( Hop hop, + HashMap<String, Hop> dataops, HashMap<LiteralKey, Hop> literalops ) { - int ret = 0; if( hop.isVisited() ) - return ret; - + return 0; + + int ret = 0; if( hop.getInput().isEmpty() //LEAF NODE || HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTREAD) ) { - if( hop instanceof LiteralOp ) - { - String key = hop.getValueType()+"_"+hop.getName(); + if( hop instanceof LiteralOp ) { + LiteralKey key = new LiteralKey(hop.getValueType(), hop.getName()); if( !literalops.containsKey(key) ) literalops.put(key, hop); } - else if( hop instanceof DataOp && ((DataOp)hop).isRead()) - { - if(!dataops.containsKey(hop.getName()) ) - dataops.put(hop.getName(), hop); + else if( hop instanceof DataOp && ((DataOp)hop).isRead() + && !dataops.containsKey(hop.getName())) { + dataops.put(hop.getName(), hop); } } else //INNER NODE @@ -121,10 +124,8 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule for( int i=0; i<hop.getInput().size(); i++ ) { Hop hi = hop.getInput().get(i); - String litKey = hi.getValueType()+"_"+hi.getName(); - if( hi instanceof DataOp && ((DataOp)hi).isRead() && dataops.containsKey(hi.getName()) ) - { - + LiteralKey litKey = new LiteralKey(hi.getValueType(), hi.getName()); + if( hi instanceof DataOp && ((DataOp)hi).isRead() && dataops.containsKey(hi.getName()) ) { //replace child node ref Hop tmp 
= dataops.get(hi.getName()); if( tmp != hi ) { //if required @@ -134,10 +135,8 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule ret++; } } - else if( hi instanceof LiteralOp && literalops.containsKey(litKey) ) - { + else if( hi instanceof LiteralOp && literalops.containsKey(litKey) ) { Hop tmp = literalops.get(litKey); - //replace child node ref if( tmp != hi ){ //if required tmp.getParent().add(hop); @@ -148,37 +147,33 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule } //recursive invocation (direct return on merged nodes) - ret += rule_CommonSubexpressionElimination_MergeLeafs(hi, dataops, literalops); - } + ret += rule_CommonSubexpressionElimination_MergeLeafs(hi, dataops, literalops); + } } - hop.setVisited(); return ret; } private int rule_CommonSubexpressionElimination( Hop hop ) { - int ret = 0; if( hop.isVisited() ) - return ret; - + return 0; + //step 1: merge childs recursively first + int ret = 0; for(Hop hi : hop.getInput()) ret += rule_CommonSubexpressionElimination(hi); - //step 2: merge parent nodes if( hop.getParent().size()>1 ) //multiple consumers { //for all pairs for( int i=0; i<hop.getParent().size()-1; i++ ) - for( int j=i+1; j<hop.getParent().size(); j++ ) - { + for( int j=i+1; j<hop.getParent().size(); j++ ) { Hop h1 = hop.getParent().get(i); Hop h2 = hop.getParent().get(j); - if( h1==h2 ) - { + if( h1==h2 ) { //do nothing, note: we should not remove redundant parent links //(otherwise rewrites would need to take this property into account) @@ -186,8 +181,7 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule //hop.getParent().remove(j); //j--; } - else if( h1.compare(h2) ) //merge h2 into h1 - { + else if( h1.compare(h2) ) { //merge h2 into h1 //remove h2 from parent list hop.getParent().remove(j); @@ -195,8 +189,7 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule ArrayList<Hop> parent = h2.getParent(); for( Hop p : parent ) for( int 
k=0; k<p.getInput().size(); k++ ) - if( p.getInput().get(k)==h2 ) - { + if( p.getInput().get(k)==h2 ) { p.getInput().set(k, h1); h1.getParent().add(p); h1.setVisited(); @@ -213,7 +206,26 @@ public class RewriteCommonSubexpressionElimination extends HopRewriteRule } hop.setVisited(); - return ret; } + + protected static class LiteralKey { + private final int _vtType; + private final String _name; + + public LiteralKey(ValueType vt, String name) { + _vtType = vt.ordinal(); + _name = name; + } + @Override + public int hashCode() { + return UtilFunctions.longHashCode(_vtType, _name.hashCode()); + } + @Override + public boolean equals(Object o) { + return (o instanceof LiteralKey + && _vtType == ((LiteralKey)o)._vtType + && _name.equals(((LiteralKey)o)._name)); + } + } }
