Repository: systemml Updated Branches: refs/heads/master c3601d419 -> ae98864e5
[SYSTEMML-2120] New plan cache for cost-based codegen optimizer So far, the codegen framework used a plan cache to reuse generated and compiled operators. Since this operator plan cache determined operator equivalence based on hashing cplans, the optimizer still had to repeatedly enumerate and cost plans (which cover multiple cplans). This patch introduces an additional plan cache in the cost-based optimizer to reduce the overhead of repeated optimization with equivalent inputs. In detail, we compute plan partition signatures and use these to match equivalent optimization problems, which allows us to reuse optimization results. These signatures are cheap to compute but inexact; hence, we only use this plan cache for partitions with at least 10 interesting points, which makes collisions practically impossible. On a scenario of 1000 iterations over lenet (a mini-batch algorithm that requires dynamic recompilation, with a large partition of 14 interesting points in the inner loop), this patch improved the total runtime from 442s to 183s (239s without codegen). This improvement is due to a cache hit rate of 990/1013 for amenable plans, which reduced the number of costed plans from 11M to 150K (out of a total plan space of 1,064,354,993,445) and the total codegen overhead from 236s to 8.6s. 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d79e1c56 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d79e1c56 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d79e1c56 Branch: refs/heads/master Commit: d79e1c565a88e03e16469d1b7a92b51835db089e Parents: c3601d4 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sun Feb 4 20:53:05 2018 -0800 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sun Feb 4 23:52:36 2018 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 6 +- .../opt/PlanSelectionFuseCostBasedV2.java | 99 +++++++++++++++++++- .../java/org/apache/sysml/utils/Statistics.java | 23 ++++- 3 files changed, 122 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/d79e1c56/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 7f67b43..28d3821 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -421,14 +421,14 @@ public class SpoofCompiler planCache.putPlan(tmp.getValue(), cla); } else if( DMLScript.STATISTICS ) { - Statistics.incrementCodegenPlanCacheHits(); + Statistics.incrementCodegenOpCacheHits(); } //make class available and maintain hits if(cla != null) clas.put(cplan.getKey(), new Pair<Hop[],Class<?>>(tmp.getKey(),cla)); if( DMLScript.STATISTICS ) - Statistics.incrementCodegenPlanCacheTotal(); + Statistics.incrementCodegenOpCacheTotal(); } //create modified hop dag (operator replacement and CSE) @@ -438,7 +438,7 @@ public class SpoofCompiler ret = 
constructModifiedHopDag(roots, cplans, clas); //run common subexpression elimination and other rewrites - ret = rewriteCSE.rewriteHopDAG(ret, new ProgramRewriteStatus()); + ret = rewriteCSE.rewriteHopDAG(ret, new ProgramRewriteStatus()); //explain after modification if( LOG.isTraceEnabled() ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/d79e1c56/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 89bb1e4..aa1082e 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -27,6 +27,7 @@ import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map.Entry; import java.util.stream.Collectors; @@ -91,7 +92,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private static final double READ_BANDWIDTH_MEM = 32d*1024*1024*1024; //32GB/s private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH_IO/4; private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //1GFLOPs/core - * InfrastructureAnalyzer.getLocalParallelism(); + * InfrastructureAnalyzer.getLocalParallelism(); //sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans private static final double SPARSE_SAFE_SPARSITY_EST = 0.1; @@ -100,11 +101,20 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //remaining candidate plans of large partitions (w/ >= COST_MIN_EPS_NUM_POINTS) are //only evaluated if the current costs are > (1+COST_MIN_EPS) * static (i.e., minimal) costs. 
public static final double COST_MIN_EPS = 0.01; //1% - public static final double COST_MIN_EPS_NUM_POINTS = 20; //2^20 = 1M plans + public static final int COST_MIN_EPS_NUM_POINTS = 20; //2^20 = 1M plans + + //In order to avoid unnecessary repeated reoptimization we use a plan cache for + //mapping partition signatures (including input sizes) to optimal plans. However, + //since hop ids change during dynamic recompilation, we use an approximate signature + //that is cheap to compute and therefore only use this for large partitions. + private static final int PLAN_CACHE_NUM_POINTS = 10; //2^10 = 1024 + private static final int PLAN_CACHE_SIZE = 1024; + private static final LinkedHashMap<PartitionSignature, boolean[]> _planCache = new LinkedHashMap<>(); //optimizer configuration public static boolean COST_PRUNING = true; public static boolean STRUCTURAL_PRUNING = true; + public static boolean PLAN_CACHING = true; private static final TemplateRow ROW_TPL = new TemplateRow(); //cost vector id generator, whose ids are only used for memoization per call to getPlanCost; @@ -229,6 +239,17 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection if( !evalRemain ) LOG.warn("Skip enum for |M|="+Mlen+", C="+bestC+", Cmin="+costs.getMinCosts()); + //probe plan cache for existing optimized plan + PartitionSignature pKey = null; + if( probePlanCache(matPoints) ) { + pKey = new PartitionSignature(part, matPoints.length, costs, C0, CN); + boolean[] plan = getPlan(pKey); + if( plan != null ) { + Statistics.incrementCodegenEnumAllP((rgraph!=null||!STRUCTURAL_PRUNING)?len:0); + return plan; + } + } + //evaluate remaining plans, except already evaluated heuristics for( long i=1; i<len-1 & evalRemain; i++ ) { //construct assignment @@ -300,6 +321,10 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection if( LOG.isTraceEnabled() ) LOG.trace("Enum: Optimal plan: "+Arrays.toString(bestPlan)); + //keep large plans + if( probePlanCache(matPoints) ) + putPlan(pKey, 
bestPlan); + //copy best plan w/o fixed offset plan return (bestPlan==null) ? new boolean[Mlen] : Arrays.copyOfRange(bestPlan, off, bestPlan.length); @@ -1081,6 +1106,38 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index)); } + private static boolean probePlanCache(InterestingPoint[] matPoints) { + return matPoints.length >= PLAN_CACHE_NUM_POINTS; + } + + private static boolean[] getPlan(PartitionSignature pKey) { + boolean[] plan = null; + synchronized( _planCache ) { + plan = _planCache.get(pKey); + } + if( DMLScript.STATISTICS ) { + if( plan != null ) + Statistics.incrementCodegenPlanCacheHits(); + Statistics.incrementCodegenPlanCacheTotal(); + } + return plan; + } + + private static void putPlan(PartitionSignature pKey, boolean[] plan) { + synchronized( _planCache ) { + //maintain size of plan cache (remove first) + if( _planCache.size() >= PLAN_CACHE_SIZE ) { + Iterator<Entry<PartitionSignature, boolean[]>> iter = + _planCache.entrySet().iterator(); + iter.next(); + iter.remove(); + } + + //add last entry + _planCache.put(pKey, plan); + } + } + private class CostVector { public final long ID; public final double outSize; @@ -1188,4 +1245,42 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection +"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]"; } } + + private class PartitionSignature { + private final int partNodes, inputNodes, rootNodes, matPoints; + private final double cCompute, cRead, cWrite, cPlan0, cPlanN; + + public PartitionSignature(PlanPartition part, int M, StaticCosts costs, double cP0, double cPN) { + partNodes = part.getPartition().size(); + inputNodes = part.getInputs().size(); + rootNodes = part.getRoots().size(); + matPoints = M; + cCompute = costs._compute; + cRead = costs._read; + cWrite = costs._write; + cPlan0 = cP0; + cPlanN = cPN; + } + @Override + public int hashCode() { + return UtilFunctions.intHashCode( + Arrays.hashCode(new 
int[]{partNodes, inputNodes, rootNodes, matPoints}), + Arrays.hashCode(new double[]{cCompute, cRead, cWrite, cPlan0, cPlanN})); + } + @Override + public boolean equals(Object o) { + if( !(o instanceof PartitionSignature) ) + return false; + PartitionSignature that = (PartitionSignature) o; + return partNodes == that.partNodes + && inputNodes == that.inputNodes + && rootNodes == that.rootNodes + && matPoints == that.matPoints + && cCompute == that.cCompute + && cRead == that.cRead + && cWrite == that.cWrite + && cPlan0 == that.cPlan0 + && cPlanN == that.cPlanN; + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d79e1c56/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 0371966..8cf3f02 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -89,6 +89,8 @@ public class Statistics private static final LongAdder codegenEnumAllP = new LongAdder(); //count private static final LongAdder codegenEnumEval = new LongAdder(); //count private static final LongAdder codegenEnumEvalP = new LongAdder(); //count + private static final LongAdder codegenOpCacheHits = new LongAdder(); //count + private static final LongAdder codegenOpCacheTotal = new LongAdder(); //count private static final LongAdder codegenPlanCacheHits = new LongAdder(); //count private static final LongAdder codegenPlanCacheTotal = new LongAdder(); //count @@ -288,6 +290,14 @@ public class Statistics codegenClassCompileTime.add(delta); } + public static void incrementCodegenOpCacheHits() { + codegenOpCacheHits.increment(); + } + + public static void incrementCodegenOpCacheTotal() { + codegenOpCacheTotal.increment(); + } + public static void incrementCodegenPlanCacheHits() { codegenPlanCacheHits.increment(); } @@ -329,6 +339,14 @@ 
public class Statistics return codegenClassCompileTime.longValue(); } + public static long getCodegenOpCacheHits() { + return codegenOpCacheHits.longValue(); + } + + public static long getCodegenOpCacheTotal() { + return codegenOpCacheTotal.longValue(); + } + public static long getCodegenPlanCacheHits() { return codegenPlanCacheHits.longValue(); } @@ -418,6 +436,8 @@ public class Statistics codegenEnumEvalP.reset(); codegenCompileTime.reset(); codegenClassCompileTime.reset(); + codegenOpCacheHits.reset(); + codegenOpCacheTotal.reset(); codegenPlanCacheHits.reset(); codegenPlanCacheTotal.reset(); @@ -814,7 +834,8 @@ public class Statistics getCodegenEnumAllP() + "/" + getCodegenEnumEval() + "/" + getCodegenEnumEvalP() + ".\n"); sb.append("Codegen compile times (DAG,JC):\t" + String.format("%.3f", (double)getCodegenCompileTime()/1000000000) + "/" + String.format("%.3f", (double)getCodegenClassCompileTime()/1000000000) + " sec.\n"); - sb.append("Codegen plan cache hits:\t" + getCodegenPlanCacheHits() + "/" + getCodegenPlanCacheTotal() + ".\n"); + sb.append("Codegen enum plan cache hits:\t" + getCodegenPlanCacheHits() + "/" + getCodegenPlanCacheTotal() + ".\n"); + sb.append("Codegen op plan cache hits:\t" + getCodegenOpCacheHits() + "/" + getCodegenOpCacheTotal() + ".\n"); } if( OptimizerUtils.isSparkExecutionMode() ){ String lazy = SparkExecutionContext.isLazySparkContextCreation() ? "(lazy)" : "(eager)";