[SYSTEMML-2283] Fix performance issue of CSE on DAGs w/ many root nodes

This patch fixes a severe performance issue of the compiler rewrite for
common subexpression elimination (CSE) on DAGs with many root nodes.
The issue showed up on resnet200 because this script contains DAGs with
2000+ root nodes. In detail, the issue was due to an incorrect
interleaving of leaf merging and visit-status resets per root node,
which, depending on the graph structure, could in the worst case lead to
processing the entire DAG once per root node. This patch fixes the issue
by first merging the leaf nodes of all roots, resetting the visit status
once, and only then performing the bottom-up merge of inner nodes. It
further improves performance by using dedicated keys for literal ops
(value type and name) instead of concatenated strings, which avoids
unnecessary string concatenation. Together, these changes improved the
compilation time of resnet200 from 22min to <10s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/686e3831
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/686e3831
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/686e3831

Branch: refs/heads/master
Commit: 686e3831dee5bb87833d0192a098c97b13fe1c20
Parents: b3fef52
Author: Matthias Boehm <[email protected]>
Authored: Thu Apr 26 23:46:39 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Apr 27 00:03:21 2018 -0700

----------------------------------------------------------------------
 .../RewriteCommonSubexpressionElimination.java  | 120 ++++++++++---------
 1 file changed, 66 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/686e3831/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
index f0cc46e..f359422 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCommonSubexpressionElimination.java
@@ -25,7 +25,9 @@ import java.util.HashMap;
 import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.runtime.util.UtilFunctions;
 
 /**
  * Rule: CommonSubexpressionElimination. For all statement blocks, 
@@ -36,16 +38,13 @@ import org.apache.sysml.hops.Hop.DataOpTypes;
  */
 public class RewriteCommonSubexpressionElimination extends HopRewriteRule
 {
+       private final boolean _mergeLeafs;
        
-       private boolean _mergeLeafs = true;
-       
-       public RewriteCommonSubexpressionElimination()
-       {
+       public RewriteCommonSubexpressionElimination() {
                this( true ); //default full CSE
        }
        
-       public RewriteCommonSubexpressionElimination( boolean mergeLeafs )
-       {
+       public RewriteCommonSubexpressionElimination( boolean mergeLeafs ) {
                _mergeLeafs = mergeLeafs;
        }
        
@@ -55,21 +54,23 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                if( roots == null )
                        return null;
                
-               HashMap<String, Hop> dataops = new HashMap<>();
-               HashMap<String, Hop> literalops = new HashMap<>(); //key: 
<VALUETYPE>_<LITERAL>
-               for (Hop h : roots) 
-               {
-                       int cseMerged = 0;
-                       if( _mergeLeafs ) {
+               //CSE pass 1: merge leaf nodes by name
+               int cseMerged = 0;
+               if( _mergeLeafs ) {
+                       HashMap<String, Hop> dataops = new HashMap<>();
+                       HashMap<LiteralKey, Hop> literalops = new HashMap<>();
+                       for (Hop h : roots)
                                cseMerged += 
rule_CommonSubexpressionElimination_MergeLeafs(h, dataops, literalops);
-                               h.resetVisitStatus();
-                       }
-                       cseMerged += rule_CommonSubexpressionElimination(h);
-                               
-                       if( cseMerged > 0 )
-                               LOG.debug("Common Subexpression Elimination - 
removed "+cseMerged+" operators.");
+                       Hop.resetVisitStatus(roots);
                }
                
+               //CSE pass 2: bottom-up merge of inner nodes
+               for (Hop h : roots) 
+                       cseMerged += rule_CommonSubexpressionElimination(h);
+               
+               if( cseMerged > 0 )
+                       LOG.debug("Common Subexpression Elimination - removed 
"+cseMerged+" operators.");
+               
                return roots;
        }
 
@@ -79,13 +80,16 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                if( root == null )
                        return null;
                
-               HashMap<String, Hop> dataops = new HashMap<>();
-               HashMap<String, Hop> literalops = new HashMap<>(); //key: 
<VALUETYPE>_<LITERAL>
+               //CSE pass 1: merge leaf nodes by name
                int cseMerged = 0;
                if( _mergeLeafs ) {
+                       HashMap<String, Hop> dataops = new HashMap<>();
+                       HashMap<LiteralKey, Hop> literalops = new HashMap<>();
                        cseMerged += 
rule_CommonSubexpressionElimination_MergeLeafs(root, dataops, literalops);
                        root.resetVisitStatus();
                }
+               
+               //CSE pass 2: bottom-up merge of inner nodes
                cseMerged += rule_CommonSubexpressionElimination(root);
                
                if( cseMerged > 0 )
@@ -94,25 +98,24 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                return root;
        }
        
-       private int rule_CommonSubexpressionElimination_MergeLeafs( Hop hop, 
HashMap<String, Hop> dataops, HashMap<String, Hop> literalops ) 
+       private int rule_CommonSubexpressionElimination_MergeLeafs( Hop hop,
+               HashMap<String, Hop> dataops, HashMap<LiteralKey, Hop> 
literalops ) 
        {
-               int ret = 0;
                if( hop.isVisited() )
-                       return ret;
-
+                       return 0;
+               
+               int ret = 0;
                if( hop.getInput().isEmpty() //LEAF NODE
                        || HopRewriteUtils.isData(hop, 
DataOpTypes.TRANSIENTREAD) )
                {
-                       if( hop instanceof LiteralOp )
-                       {
-                               String key = 
hop.getValueType()+"_"+hop.getName();
+                       if( hop instanceof LiteralOp ) {
+                               LiteralKey key = new 
LiteralKey(hop.getValueType(), hop.getName());
                                if( !literalops.containsKey(key) )
                                        literalops.put(key, hop);
                        }
-                       else if( hop instanceof DataOp && 
((DataOp)hop).isRead())
-                       {
-                               if(!dataops.containsKey(hop.getName()) )
-                                       dataops.put(hop.getName(), hop);
+                       else if( hop instanceof DataOp && ((DataOp)hop).isRead()
+                               && !dataops.containsKey(hop.getName())) {
+                               dataops.put(hop.getName(), hop);
                        } 
                }
                else //INNER NODE
@@ -121,10 +124,8 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                        for( int i=0; i<hop.getInput().size(); i++ )
                        {
                                Hop hi = hop.getInput().get(i);
-                               String litKey = 
hi.getValueType()+"_"+hi.getName();
-                               if( hi instanceof DataOp && 
((DataOp)hi).isRead() && dataops.containsKey(hi.getName()) )
-                               {
-                                       
+                               LiteralKey litKey = new 
LiteralKey(hi.getValueType(), hi.getName());
+                               if( hi instanceof DataOp && 
((DataOp)hi).isRead() && dataops.containsKey(hi.getName()) ) {
                                        //replace child node ref
                                        Hop tmp = dataops.get(hi.getName());
                                        if( tmp != hi ) { //if required
@@ -134,10 +135,8 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                                                ret++;
                                        }
                                }
-                               else if( hi instanceof LiteralOp && 
literalops.containsKey(litKey) )
-                               {
+                               else if( hi instanceof LiteralOp && 
literalops.containsKey(litKey) ) {
                                        Hop tmp = literalops.get(litKey);
-                                       
                                        //replace child node ref
                                        if( tmp != hi ){ //if required
                                                tmp.getParent().add(hop);
@@ -148,37 +147,33 @@ public class RewriteCommonSubexpressionElimination 
extends HopRewriteRule
                                }
                                
                                //recursive invocation (direct return on merged 
nodes)
-                               ret += 
rule_CommonSubexpressionElimination_MergeLeafs(hi, dataops, literalops);        
 
-                       }       
+                               ret += 
rule_CommonSubexpressionElimination_MergeLeafs(hi, dataops, literalops);
+                       }
                }
-               
                hop.setVisited();
                return ret;
        }
 
        private int rule_CommonSubexpressionElimination( Hop hop ) 
        {
-               int ret = 0;
                if( hop.isVisited() )
-                       return ret;
-
+                       return 0;
+               
                //step 1: merge childs recursively first
+               int ret = 0;
                for(Hop hi : hop.getInput())
                        ret += rule_CommonSubexpressionElimination(hi); 
                
-               
                //step 2: merge parent nodes
                if( hop.getParent().size()>1 ) //multiple consumers
                {
                        //for all pairs 
                        for( int i=0; i<hop.getParent().size()-1; i++ )
-                               for( int j=i+1; j<hop.getParent().size(); j++ )
-                               {
+                               for( int j=i+1; j<hop.getParent().size(); j++ ) 
{
                                        Hop h1 = hop.getParent().get(i);
                                        Hop h2 = hop.getParent().get(j);
                                        
-                                       if( h1==h2 )
-                                       {
+                                       if( h1==h2 ) {
                                                //do nothing, note: we should 
not remove redundant parent links
                                                //(otherwise rewrites would 
need to take this property into account) 
                                                
@@ -186,8 +181,7 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                                                //hop.getParent().remove(j);
                                                //j--;
                                        }
-                                       else if( h1.compare(h2) ) //merge h2 
into h1
-                                       {
+                                       else if( h1.compare(h2) ) { //merge h2 
into h1
                                                //remove h2 from parent list
                                                hop.getParent().remove(j);
                                                
@@ -195,8 +189,7 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                                                ArrayList<Hop> parent = 
h2.getParent();
                                                for( Hop p : parent )
                                                        for( int k=0; 
k<p.getInput().size(); k++ )
-                                                               if( 
p.getInput().get(k)==h2 )
-                                                               {
+                                                               if( 
p.getInput().get(k)==h2 ) {
                                                                        
p.getInput().set(k, h1);
                                                                        
h1.getParent().add(p);
                                                                        
h1.setVisited();
@@ -213,7 +206,26 @@ public class RewriteCommonSubexpressionElimination extends 
HopRewriteRule
                }
                
                hop.setVisited();
-
                return ret;
        }
+       
+       protected static class LiteralKey {
+               private final int _vtType;
+               private final String _name;
+               
+               public LiteralKey(ValueType vt, String name) {
+                       _vtType = vt.ordinal();
+                       _name = name;
+               }
+               @Override
+               public int hashCode() {
+                       return UtilFunctions.longHashCode(_vtType, 
_name.hashCode());
+               }
+               @Override 
+               public boolean equals(Object o) {
+                       return (o instanceof LiteralKey
+                               && _vtType == ((LiteralKey)o)._vtType
+                               && _name.equals(((LiteralKey)o)._name));
+               }
+       }
 }

Reply via email to