Repository: systemml
Updated Branches:
  refs/heads/master c3601d419 -> ae98864e5


[SYSTEMML-2120] New plan cache for cost-based codegen optimizer

So far, the codegen framework used a plan cache to reuse generated and
compiled operators. Since this operator plan cache determined operator
equivalence based on hashing cplans, the optimizer still had to
repeatedly enumerate and cost plans (which cover multiple cplans). This
patch introduces an additional plan cache in the cost-based optimizer to
reduce the overhead of repeated optimization with equivalent inputs.

In detail, we compute plan partition signatures and use these to match
equivalent optimization problems, which allows us to reuse optimization
results. These signatures are cheap to compute but inexact; hence, we
only use this plan cache for partitions with at least 10 interesting
points, which makes collisions practically impossible.

In a scenario of 1000 iterations over LeNet (a mini-batch algorithm that
requires dynamic recompilation, with a large partition of 14 interesting
points in the inner loop), this patch improved the total runtime from
442s to 183s (239s without codegen). This improvement is due to a cache
hit rate of 990/1013 for amenable plans, which reduced the number of costed
plans from 11M to 150K (out of a total plan space of 1,064,354,993,445)
and the total codegen overhead from 236s to 8.6s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d79e1c56
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d79e1c56
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d79e1c56

Branch: refs/heads/master
Commit: d79e1c565a88e03e16469d1b7a92b51835db089e
Parents: c3601d4
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sun Feb 4 20:53:05 2018 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sun Feb 4 23:52:36 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  6 +-
 .../opt/PlanSelectionFuseCostBasedV2.java       | 99 +++++++++++++++++++-
 .../java/org/apache/sysml/utils/Statistics.java | 23 ++++-
 3 files changed, 122 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d79e1c56/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 7f67b43..28d3821 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -421,14 +421,14 @@ public class SpoofCompiler
                                                
planCache.putPlan(tmp.getValue(), cla);
                                }
                                else if( DMLScript.STATISTICS ) {
-                                       
Statistics.incrementCodegenPlanCacheHits();
+                                       
Statistics.incrementCodegenOpCacheHits();
                                }
                                
                                //make class available and maintain hits
                                if(cla != null)
                                        clas.put(cplan.getKey(), new 
Pair<Hop[],Class<?>>(tmp.getKey(),cla));
                                if( DMLScript.STATISTICS )
-                                       
Statistics.incrementCodegenPlanCacheTotal();
+                                       
Statistics.incrementCodegenOpCacheTotal();
                        }
                        
                        //create modified hop dag (operator replacement and CSE)
@@ -438,7 +438,7 @@ public class SpoofCompiler
                                ret = constructModifiedHopDag(roots, cplans, 
clas);
                                
                                //run common subexpression elimination and 
other rewrites
-                               ret = rewriteCSE.rewriteHopDAG(ret, new 
ProgramRewriteStatus());        
+                               ret = rewriteCSE.rewriteHopDAG(ret, new 
ProgramRewriteStatus());
                                
                                //explain after modification
                                if( LOG.isTraceEnabled() ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d79e1c56/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 89bb1e4..aa1082e 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -27,6 +27,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map.Entry;
 import java.util.stream.Collectors;
@@ -91,7 +92,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
        private static final double READ_BANDWIDTH_MEM  = 32d*1024*1024*1024;  
//32GB/s
        private static final double READ_BANDWIDTH_BROADCAST = 
WRITE_BANDWIDTH_IO/4;
        private static final double COMPUTE_BANDWIDTH  =   2d*1024*1024*1024   
//1GFLOPs/core
-                                                               * 
InfrastructureAnalyzer.getLocalParallelism();
+               * InfrastructureAnalyzer.getLocalParallelism();
        
        //sparsity estimate for unknown sparsity to prefer sparse-safe fusion 
plans
        private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;
@@ -100,11 +101,20 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
        //remaining candidate plans of large partitions (w/ >= 
COST_MIN_EPS_NUM_POINTS) are
        //only evaluated if the current costs are > (1+COST_MIN_EPS) * static 
(i.e., minimal) costs.
        public static final double COST_MIN_EPS = 0.01; //1%
-       public static final double COST_MIN_EPS_NUM_POINTS = 20; //2^20 = 1M 
plans
+       public static final int COST_MIN_EPS_NUM_POINTS = 20; //2^20 = 1M plans
+       
+       //In order to avoid unnecessary repeated reoptimization we use a plan 
cache for
+       //mapping partition signatures (including input sizes) to optimal 
plans. However,
+       //since hop ids change during dynamic recompilation, we use an 
approximate signature
+       //that is cheap to compute and therefore only use this for large 
partitions.
+       private static final int PLAN_CACHE_NUM_POINTS = 10; //2^10 = 1024
+       private static final int PLAN_CACHE_SIZE = 1024;
+       private static final LinkedHashMap<PartitionSignature, boolean[]> 
_planCache = new LinkedHashMap<>();
        
        //optimizer configuration
        public static boolean COST_PRUNING = true;
        public static boolean STRUCTURAL_PRUNING = true;
+       public static boolean PLAN_CACHING = true;
        private static final TemplateRow ROW_TPL = new TemplateRow();
        
        //cost vector id generator, whose ids are only used for memoization per 
call to getPlanCost;
@@ -229,6 +239,17 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                if( !evalRemain )
                        LOG.warn("Skip enum for |M|="+Mlen+", C="+bestC+", 
Cmin="+costs.getMinCosts());
                
+               //probe plan cache for existing optimized plan
+               PartitionSignature pKey = null;
+               if( probePlanCache(matPoints) ) {
+                       pKey = new PartitionSignature(part, matPoints.length, 
costs, C0, CN);
+                       boolean[] plan = getPlan(pKey);
+                       if( plan != null ) {
+                               
Statistics.incrementCodegenEnumAllP((rgraph!=null||!STRUCTURAL_PRUNING)?len:0);
+                               return Arrays.copyOfRange(plan, off, plan.length); //copy w/o fixed offset (as regular exit), never expose cached array
+                       }
+               }
+               
                //evaluate remaining plans, except already evaluated heuristics
                for( long i=1; i<len-1 & evalRemain; i++ ) {
                        //construct assignment
@@ -300,6 +321,10 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                if( LOG.isTraceEnabled() )
                        LOG.trace("Enum: Optimal plan: 
"+Arrays.toString(bestPlan));
                
+               //keep large plans (null plans are skipped since getPlan treats null as a miss)
+               if( probePlanCache(matPoints) && bestPlan != null )
+                       putPlan(pKey, bestPlan);
+               
                //copy best plan w/o fixed offset plan
                return (bestPlan==null) ? new boolean[Mlen] :
                        Arrays.copyOfRange(bestPlan, off, bestPlan.length);
@@ -1081,6 +1106,38 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                        && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(index)); 
        }
        
+       private static boolean probePlanCache(InterestingPoint[] matPoints) { //signatures are approximate, hence only large partitions
+               return matPoints.length >= PLAN_CACHE_NUM_POINTS;
+       }
+       
+       private static boolean[] getPlan(PartitionSignature pKey) { //returns null on a cache miss, maintains hit/total statistics
+               boolean[] plan = null;
+               synchronized( _planCache ) { //guard the shared static cache
+                       plan = _planCache.get(pKey);
+               }
+               if( DMLScript.STATISTICS ) {
+                       if( plan != null )
+                               Statistics.incrementCodegenPlanCacheHits();
+                       Statistics.incrementCodegenPlanCacheTotal();
+               }
+               return plan;
+       }
+       
+       private static void putPlan(PartitionSignature pKey, boolean[] plan) { //insertion-order (FIFO) eviction at PLAN_CACHE_SIZE entries
+               synchronized( _planCache ) {
+                       //maintain size of plan cache (remove first = oldest entry), but only for new keys because replacing an existing key does not grow the cache
+                       if( !_planCache.containsKey(pKey) && _planCache.size() >= PLAN_CACHE_SIZE ) {
+                               Iterator<Entry<PartitionSignature, boolean[]>> 
iter =
+                                       _planCache.entrySet().iterator();
+                               iter.next();
+                               iter.remove();
+                       }
+                       
+                       //add last entry 
+                       _planCache.put(pKey, plan);
+               }
+       }
+       
        private class CostVector {
                public final long ID;
                public final double outSize; 
@@ -1188,4 +1245,42 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                                +"{"+Arrays.toString(_fusedInputs.toArray(new 
Long[0]))+"}]"; 
                }
        }
+       
+       private static class PartitionSignature { //static: keys of the static _planCache must not pin optimizer instances
+               private final int partNodes, inputNodes, rootNodes, matPoints;
+               private final double cCompute, cRead, cWrite, cPlan0, cPlanN;
+               
+               public PartitionSignature(PlanPartition part, int M, 
StaticCosts costs, double cP0, double cPN) {
+                       partNodes = part.getPartition().size();
+                       inputNodes = part.getInputs().size();
+                       rootNodes = part.getRoots().size();
+                       matPoints = M;
+                       cCompute = costs._compute;
+                       cRead = costs._read;
+                       cWrite = costs._write;
+                       cPlan0 = cP0;
+                       cPlanN = cPN;
+               }
+               @Override
+               public int hashCode() {
+                       return UtilFunctions.intHashCode(
+                               Arrays.hashCode(new int[]{partNodes, 
inputNodes, rootNodes, matPoints}),
+                               Arrays.hashCode(new double[]{cCompute, cRead, 
cWrite, cPlan0, cPlanN}));
+               }
+               @Override 
+               public boolean equals(Object o) { //doubles via Double.compare, consistent w/ Arrays.hashCode(double[]) for -0.0 and NaN
+                       if( !(o instanceof PartitionSignature) )
+                               return false;
+                       PartitionSignature that = (PartitionSignature) o;
+                       return partNodes == that.partNodes
+                               && inputNodes == that.inputNodes
+                               && rootNodes == that.rootNodes
+                               && matPoints == that.matPoints
+                               && Double.compare(cCompute, that.cCompute) == 0
+                               && Double.compare(cRead, that.cRead) == 0
+                               && Double.compare(cWrite, that.cWrite) == 0
+                               && Double.compare(cPlan0, that.cPlan0) == 0
+                               && Double.compare(cPlanN, that.cPlanN) == 0;
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d79e1c56/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index 0371966..8cf3f02 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -89,6 +89,8 @@ public class Statistics
        private static final LongAdder codegenEnumAllP = new LongAdder(); 
//count
        private static final LongAdder codegenEnumEval = new LongAdder(); 
//count
        private static final LongAdder codegenEnumEvalP = new LongAdder(); 
//count
+       private static final LongAdder codegenOpCacheHits = new LongAdder(); 
//count
+       private static final LongAdder codegenOpCacheTotal = new LongAdder(); 
//count
        private static final LongAdder codegenPlanCacheHits = new LongAdder(); 
//count
        private static final LongAdder codegenPlanCacheTotal = new LongAdder(); 
//count
        
@@ -288,6 +290,14 @@ public class Statistics
                codegenClassCompileTime.add(delta);
        }
        
+       public static void incrementCodegenOpCacheHits() { //called by SpoofCompiler on hits of its compiled-operator cache
+               codegenOpCacheHits.increment();
+       }
+       
+       public static void incrementCodegenOpCacheTotal() { //called by SpoofCompiler once per processed cplan
+               codegenOpCacheTotal.increment();
+       }
+       
        public static void incrementCodegenPlanCacheHits() {
                codegenPlanCacheHits.increment();
        }
@@ -329,6 +339,14 @@ public class Statistics
                return codegenClassCompileTime.longValue();
        }
        
+       public static long getCodegenOpCacheHits() { //numerator of "Codegen op plan cache hits" in the stats output
+               return codegenOpCacheHits.longValue();
+       }
+       
+       public static long getCodegenOpCacheTotal() { //denominator of "Codegen op plan cache hits" in the stats output
+               return codegenOpCacheTotal.longValue();
+       }
+       
        public static long getCodegenPlanCacheHits() {
                return codegenPlanCacheHits.longValue();
        }
@@ -418,6 +436,8 @@ public class Statistics
                codegenEnumEvalP.reset();
                codegenCompileTime.reset();
                codegenClassCompileTime.reset();
+               codegenOpCacheHits.reset();
+               codegenOpCacheTotal.reset();
                codegenPlanCacheHits.reset();
                codegenPlanCacheTotal.reset();
                
@@ -814,7 +834,8 @@ public class Statistics
                                                getCodegenEnumAllP() + "/" + 
getCodegenEnumEval() + "/" + getCodegenEnumEvalP() + ".\n");
                                sb.append("Codegen compile times (DAG,JC):\t" + 
String.format("%.3f", (double)getCodegenCompileTime()/1000000000) + "/" + 
                                                String.format("%.3f", 
(double)getCodegenClassCompileTime()/1000000000)  + " sec.\n");
-                               sb.append("Codegen plan cache hits:\t" + 
getCodegenPlanCacheHits() + "/" + getCodegenPlanCacheTotal() + ".\n");
+                               sb.append("Codegen enum plan cache hits:\t" + 
getCodegenPlanCacheHits() + "/" + getCodegenPlanCacheTotal() + ".\n");
+                               sb.append("Codegen op plan cache hits:\t" + 
getCodegenOpCacheHits() + "/" + getCodegenOpCacheTotal() + ".\n");
                        }
                        if( OptimizerUtils.isSparkExecutionMode() ){
                                String lazy = 
SparkExecutionContext.isLazySparkContextCreation() ? "(lazy)" : "(eager)";

Reply via email to