[SYSTEMML-1374] Fix codegen candidate exploration (distinct fuse/merge)

This patch fixes the code generator's candidate exploration algorithm by
considering only distinct memo table entries (with respect to template
type and closed status) when fusing and merging operator plans.
Furthermore, it cleans up various configuration flags and logging issues.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a929ae6e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a929ae6e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a929ae6e

Branch: refs/heads/master
Commit: a929ae6e6b59504215403b5b5bc7110c5f180efb
Parents: 86a8e14
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Mar 23 22:14:35 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Mar 23 22:14:35 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/DMLScript.java    |  7 +-
 .../sysml/hops/codegen/SpoofCompiler.java       | 43 +++++++----
 .../hops/codegen/template/CPlanMemoTable.java   | 76 +++++++++++++-------
 .../apache/sysml/hops/recompile/Recompiler.java |  6 +-
 4 files changed, 86 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a929ae6e/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java 
b/src/main/java/org/apache/sysml/api/DMLScript.java
index 5bf5338..c04c321 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -57,6 +57,7 @@ import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.OptimizerUtils.OptimizationLevel;
 import org.apache.sysml.hops.codegen.SpoofCompiler;
+import org.apache.sysml.hops.codegen.SpoofCompiler.PlanCache;
 import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.lops.LopsException;
@@ -596,9 +597,9 @@ public class DMLScript
 
                //Step 5.1: Generate code for the rewrited Hop dags 
                if( dmlconf.getBooleanValue(DMLConfig.CODEGEN) ){
-                       SpoofCompiler.USE_PLAN_CACHE = 
dmlconf.getBooleanValue(DMLConfig.CODEGEN_PLANCACHE);
-                       SpoofCompiler.ALWAYS_COMPILE_LITERALS = 
(dmlconf.getIntValue(DMLConfig.CODEGEN_LITERALS)==2);
-                       
+                       SpoofCompiler.PLAN_CACHE_POLICY = PlanCache.getPolicy(
+                                       
dmlconf.getBooleanValue(DMLConfig.CODEGEN_PLANCACHE),
+                                       
dmlconf.getIntValue(DMLConfig.CODEGEN_LITERALS)==2);
                        dmlt.codgenHopsDAG(prog);
                }
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a929ae6e/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index f1dfb91..8587367 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -78,15 +78,28 @@ public class SpoofCompiler
 {
        private static final Log LOG = 
LogFactory.getLog(SpoofCompiler.class.getName());
        
-       public static boolean OPTIMIZE = true;
-       
        //internal configuration flags
-       public static final boolean LDEBUG = false;
-       public static final boolean SUM_PRODUCT = false;
-       public static final boolean RECOMPILE = true;
-       public static boolean USE_PLAN_CACHE = true;
-       public static boolean ALWAYS_COMPILE_LITERALS = false;
-       public static final boolean ALLOW_SPARK_OPS = false;
+       public static boolean LDEBUG = false;
+       public static final boolean RECOMPILE_CODEGEN = true;
+       public static PlanCache PLAN_CACHE_POLICY = PlanCache.CSLH;
+       public static final PlanSelection PLAN_SEL_POLICY = 
PlanSelection.FUSE_ALL; 
+       public static final boolean PRUNE_REDUNDANT_PLANS = true;
+       
+       public enum PlanSelection {
+               FUSE_ALL,             //maximal fusion, possible w/ redundant 
compute
+               FUSE_NO_REDUNDANCY,   //fusion without redundant compute 
+               FUSE_COST_BASED,      //cost-based decision on materialization 
points
+       }
+
+       public enum PlanCache {
+               CONSTANT, //plan cache, with always compile literals
+               CSLH,     //plan cache, with context-sensitive literal 
replacement heuristic
+               NONE;     //no plan cache
+               
+               public static PlanCache getPolicy(boolean planCache, boolean 
compileLiterals) {
+                       return !planCache ? NONE : compileLiterals ? CONSTANT : 
CSLH;
+               }
+       }
        
        //plan cache for cplan->compiled source to avoid unnecessary 
codegen/source code compile
        //for equal operators from (1) different hop dags and (2) repeated 
recompilation 
@@ -189,7 +202,7 @@ public class SpoofCompiler
        }
        
        public static void cleanupCodeGenerator() {
-               if( USE_PLAN_CACHE ) {
+               if( PLAN_CACHE_POLICY != PlanCache.NONE ) {
                        CodegenUtils.clearClassCache(); //class cache
                        planCache.clear(); //plan cache
                }
@@ -203,11 +216,10 @@ public class SpoofCompiler
         * @return dag root nodes of modified dag 
         * @throws DMLRuntimeException if optimization failed
         */
-       @SuppressWarnings("unused")
        public static ArrayList<Hop> optimize(ArrayList<Hop> roots, boolean 
recompile) 
                throws DMLRuntimeException 
        {
-               if( roots == null || roots.isEmpty() || !OPTIMIZE )
+               if( roots == null || roots.isEmpty() )
                        return roots;
        
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
@@ -216,7 +228,7 @@ public class SpoofCompiler
                try
                {
                        //context-sensitive literal replacement (only integers 
during recompile)
-                       boolean compileLiterals = ALWAYS_COMPILE_LITERALS || 
!recompile;
+                       boolean compileLiterals = 
(PLAN_CACHE_POLICY==PlanCache.CONSTANT) || !recompile;
                        
                        //construct codegen plans
                        HashMap<Long, Pair<Hop[],CNodeTpl>>  cplans = 
constructCPlans(roots, compileLiterals);
@@ -235,7 +247,7 @@ public class SpoofCompiler
                        for( Entry<Long, Pair<Hop[],CNodeTpl>> cplan : 
cplans.entrySet() ) {
                                Pair<Hop[],CNodeTpl> tmp = cplan.getValue();
                                
-                               if( !USE_PLAN_CACHE || 
!planCache.containsKey(tmp.getValue()) ) {
+                               if( PLAN_CACHE_POLICY==PlanCache.NONE || 
!planCache.containsKey(tmp.getValue()) ) {
                                        //generate java source code
                                        String src = 
tmp.getValue().codegen(false);
                                        
@@ -336,7 +348,7 @@ public class SpoofCompiler
                //fuse and merge operator plans
                for( Hop c : hop.getInput() ) {
                        if( memo.contains(c.getHopID()) )
-                               for( MemoTableEntry me : memo.get(c.getHopID()) 
) {
+                               for( MemoTableEntry me : 
memo.getDistinct(c.getHopID()) ) {
                                        BaseTpl tpl = 
TemplateUtils.createTemplate(me.type, me.closed);
                                        if( tpl.fuse(hop, c) ) {
                                                int pos = 
hop.getInput().indexOf(c);
@@ -356,7 +368,8 @@ public class SpoofCompiler
                }
                
                //prune subsumed / redundant plans
-               memo.pruneRedundant(hop.getHopID());
+               if( PRUNE_REDUNDANT_PLANS )
+                       memo.pruneRedundant(hop.getHopID());
                
                //close operator plans, if required
                if( memo.contains(hop.getHopID()) ) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a929ae6e/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index b0bf75b..03cb7e7 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -21,26 +21,26 @@ package org.apache.sysml.hops.codegen.template;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Map.Entry;
+import java.util.stream.Collectors;
 
 import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.SpoofCompiler;
 import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 
-import scala.tools.jline_embedded.internal.Log;
-
-
 public class CPlanMemoTable 
 {
-       public enum PlanSelection {
-               FUSE_ALL,             //maximal fusion, possible w/ redundant 
compute
-               FUSE_NO_REDUNDANCY,   //fusion without redundant compute 
-               FUSE_COST_BASED,      //cost-based decision on materialization 
points
-       }
+       private static final Log LOG = 
LogFactory.getLog(SpoofCompiler.class.getName());
        
        private HashMap<Long, ArrayList<MemoTableEntry>> _plans;
        private HashMap<Long, Hop> _hopRefs;
@@ -128,6 +128,9 @@ public class CPlanMemoTable
        }
 
        public void pruneSuboptimal() {
+               if( SpoofCompiler.LDEBUG )
+                       LOG.info("#1: Memo before plan selection ("+size()+" 
plans)\n"+this);
+               
                //build index of referenced entries
                HashSet<Long> ix = new HashSet<Long>();
                for( Entry<Long, ArrayList<MemoTableEntry>> e : 
_plans.entrySet() )
@@ -160,14 +163,24 @@ public class CPlanMemoTable
                                        if( me.isPlanRef(i) && 
_hopRefs.get(me.intput(i)).getParent().size()==1 )
                                                
_plansBlacklist.add(me.intput(i));
                        }
+               
+               if( SpoofCompiler.LDEBUG )
+                       LOG.info("#2: Memo after plan selection ("+size()+" 
plans)\n"+this);
        }
 
-       public ArrayList<MemoTableEntry> get(long hopID) {
+       public List<MemoTableEntry> get(long hopID) {
                return _plans.get(hopID);
        }
        
+       public List<MemoTableEntry> getDistinct(long hopID) {
+               //return distinct entries wrt type and closed attributes
+               return _plans.get(hopID).stream()
+                       .map(p -> new MemoTableEntry(p.type,-1,-1,-1,p.closed))
+                       .distinct().collect(Collectors.toList());
+       }
+       
        public MemoTableEntry getBest(long hopID) {
-               ArrayList<MemoTableEntry> tmp = get(hopID);
+               List<MemoTableEntry> tmp = get(hopID);
                if( tmp == null || tmp.isEmpty() )
                        return null;
                
@@ -183,38 +196,31 @@ public class CPlanMemoTable
        
        //TODO revisit requirement for preference once cost-based pruning 
(pruneSuboptimal) ready
        public MemoTableEntry getBest(long hopID, TemplateType pref) {
-               ArrayList<MemoTableEntry> tmp = get(hopID);
+               List<MemoTableEntry> tmp = get(hopID);
                if( tmp.size()==1 ) //single plan available
                        return tmp.get(0);
                
                //try to find plan with preferred type
-               Log.warn("Multiple memo table entries available, searching for 
preferred type.");
+               if( SpoofCompiler.LDEBUG )
+                       LOG.warn("Multiple memo table entries available, 
searching for preferred type.");
                ArrayList<MemoTableEntry> tmp2 = new 
ArrayList<MemoTableEntry>();
                for( MemoTableEntry me : tmp )
                        if( me.type == pref )
                                tmp2.add(me);
                if( !tmp2.isEmpty() ) {
-                       if( tmp2.size() > 1 )
-                               Log.warn("Multiple memo table entries w/ 
preferred type available, return max refs entry.");
+                       if( tmp2.size() > 1 && SpoofCompiler.LDEBUG )
+                               LOG.warn("Multiple memo table entries w/ 
preferred type available, return max refs entry.");
                        return getMaxRefsEntry(tmp2);
                }
                else {
-                       Log.warn("Multiple memo table entries available but 
none with preferred type, return max refs entry.");
+                       if( SpoofCompiler.LDEBUG )
+                               LOG.warn("Multiple memo table entries available 
but none with preferred type, return max refs entry.");
                        return getMaxRefsEntry(tmp);
                }
        }
        
-       private static MemoTableEntry getMaxRefsEntry(ArrayList<MemoTableEntry> 
tmp) {
-               int maxPos = 0;
-               int maxRefs = 0;
-               for( int i=0; i<tmp.size(); i++ ) {
-                       int cntRefs = tmp.get(i).countPlanRefs();
-                       if( cntRefs > maxRefs ) {
-                               maxRefs = cntRefs;
-                               maxPos = i;
-                       }
-               }
-               return tmp.get(maxPos);
+       private static MemoTableEntry getMaxRefsEntry(List<MemoTableEntry> tmp) 
{
+               return Collections.max(tmp, Comparator.comparing(p -> 
p.countPlanRefs()));
        }
        
        private static boolean isValid(MemoTableEntry me, Hop hop) {
@@ -224,6 +230,12 @@ public class CPlanMemoTable
                        || (me.type == TemplateType.CellTpl);
        }
        
+       public int size() {
+               return _plans.values().stream()
+                       .map(list -> list.size())
+                       .mapToInt(x -> x.intValue()).sum();
+       }
+       
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder();
@@ -235,6 +247,9 @@ public class CPlanMemoTable
                        sb.append(Arrays.toString(e.getValue().toArray(new 
MemoTableEntry[0]))+"\n");
                }
                sb.append("----------------------------------\n");
+               sb.append("Blacklisted Plans: ");
+               sb.append(Arrays.toString(_plansBlacklist.toArray(new 
Long[0]))+"\n");
+               sb.append("----------------------------------\n");
                return sb.toString();   
        }
        
@@ -246,10 +261,14 @@ public class CPlanMemoTable
                public final long input3;
                public boolean closed = false;
                public MemoTableEntry(TemplateType t, long in1, long in2, long 
in3) {
+                       this(t, in1, in2, in3, false);
+               }
+               public MemoTableEntry(TemplateType t, long in1, long in2, long 
in3, boolean close) {
                        type = t;
                        input1 = in1;
                        input2 = in2;
                        input3 = in3;
+                       closed = close;
                }
                public boolean isPlanRef(int index) {
                        return (index==0 && input1 >=0)
@@ -310,5 +329,10 @@ public class CPlanMemoTable
                                                (pos==1)?ref:me.input2, 
(pos==2)?ref:me.input3));
                        plans = tmp;
                }
+               
+               @Override
+               public String toString() {
+                       return Arrays.toString(plans.toArray(new 
MemoTableEntry[0]));
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a929ae6e/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
index ecf6aac..6b74bf7 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
@@ -213,7 +213,8 @@ public class Recompiler
                        memo.extract(hops, status);
                        
                        // codegen if enabled
-                       if( 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.CODEGEN) && 
SpoofCompiler.RECOMPILE ) {
+                       if( 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.CODEGEN) 
+                                       && SpoofCompiler.RECOMPILE_CODEGEN ) {
                                Hop.resetVisitStatus(hops);
                                hops = SpoofCompiler.optimize(hops, true);
                        }
@@ -313,7 +314,8 @@ public class Recompiler
                        hops.refreshMemEstimates(memo);                 
                        
                        // codegen if enabled
-                       if( 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.CODEGEN) && 
SpoofCompiler.RECOMPILE ) {
+                       if( 
ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.CODEGEN) 
+                                       && SpoofCompiler.RECOMPILE_CODEGEN ) {
                                hops.resetVisitStatus();
                                hops = SpoofCompiler.optimize(hops, false);
                        }

Reply via email to