Repository: systemml Updated Branches: refs/heads/master c1db484d6 -> d01d13c4b
[SYSTEMML-1934] Fix codegen optimizer (incorrect structural pruning) This patch fixes a severe codegen optimizer issue where, in special cases, the positions of subproblems were set incorrectly, leading to wrong mappings of the optimal plans for subproblems to the global plan. We now create these mappings in a much simpler and more robust way. With this patch, we now (1) find the optimal plans in scenarios where we previously missed them (e.g., Mlogreg), and (2) structural pruning is more effective. For example, on GLM binomial probit, there are 264,371 plans; cost-based pruning reduces this to 33,388 plans, and additional structural pruning further reduces it to 9,574 plans. Furthermore, this patch improves the trace information of the codegen optimizer and fixes the maintenance of statistics when structural pruning is disabled. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d01d13c4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d01d13c4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d01d13c4 Branch: refs/heads/master Commit: d01d13c4bc7b4da72ea399d76946cefda31fd4a4 Parents: c1db484 Author: Matthias Boehm <[email protected]> Authored: Sun Sep 24 22:01:27 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Sep 24 22:01:27 2017 -0700 ---------------------------------------------------------------------- .../opt/PlanSelectionFuseCostBasedV2.java | 20 +++-- .../hops/codegen/opt/ReachabilityGraph.java | 92 +++++++++----------- 2 files changed, 55 insertions(+), 57 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/d01d13c4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 30631d0..7c27dcf 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -92,8 +92,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private static final double SPARSE_SAFE_SPARSITY_EST = 0.1; //optimizer configuration - public static boolean USE_COST_PRUNING = true; - public static boolean USE_STRUCTURAL_PRUNING = true; + public static boolean COST_PRUNING = true; + public static boolean STRUCTURAL_PRUNING = false; private static final IDSequence COST_ID = new IDSequence(); private static final TemplateRow ROW_TPL = new TemplateRow(); @@ -149,8 +149,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //prepare pruning helpers and prune memo table w/ determined mat points StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), getReadCost(part, memo), getWriteCost(part.getRoots(), memo)); - ReachabilityGraph rgraph = USE_STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null; - if( USE_STRUCTURAL_PRUNING ) { + ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? 
new ReachabilityGraph(part, memo) : null; + if( STRUCTURAL_PRUNING ) { part.setMatPointsExt(rgraph.getSortedSearchSpace()); for( Long hopID : part.getPartition() ) memo.pruneRedundant(hopID, true, part.getMatPointsExt()); @@ -210,15 +210,19 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection long pskip = 0; //skip after costing //skip plans with structural pruning - if( USE_STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) { + if( STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) { //compute skip (which also acts as boundary for subproblems) pskip = rgraph.getNumSkipPlans(plan); + if( LOG.isTraceEnabled() ) + LOG.trace("Enum: Structural pruning for cut set: "+rgraph.getCutSet(plan)); //start increment rgraph get subproblems SubProblem[] prob = rgraph.getSubproblems(plan); //solve subproblems independently and combine into best plan for( int j=0; j<prob.length; j++ ) { + if( LOG.isTraceEnabled() ) + LOG.trace("Enum: Subproblem "+(j+1)+"/"+prob.length+": "+prob[j]); boolean[] bestTmp = enumPlans(memo, part, costs, null, prob[j].freeMat, prob[j].offset, bestC); LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos); @@ -228,7 +232,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //the default code path; hence we postpone the skip after costing } //skip plans with branch and bound pruning (cost) - else if( USE_COST_PRUNING ) { + else if( COST_PRUNING ) { double lbC = Math.max(costs._read, costs._compute) + costs._write + getMaterializationCost(part, matPoints, memo, plan); if( lbC >= bestC ) { @@ -241,7 +245,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } //cost assignment on hops. Stop early if exceeds bestC. - double pCBound = USE_COST_PRUNING ? bestC : Double.MAX_VALUE; + double pCBound = COST_PRUNING ? 
bestC : Double.MAX_VALUE; double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts, pCBound); if (LOG.isTraceEnabled()) LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C); @@ -263,7 +267,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } if( DMLScript.STATISTICS ) { - Statistics.incrementCodegenEnumAllP((rgraph!=null)?len:0); + Statistics.incrementCodegenEnumAllP((rgraph!=null||!STRUCTURAL_PRUNING)?len:0); Statistics.incrementCodegenEnumEval(numEvalPlans); Statistics.incrementCodegenEnumEvalP(numEvalPartPlans); } http://git-wip-us.apache.org/repos/asf/systemml/blob/d01d13c4/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java index 0c829e8..fb7840b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java @@ -118,28 +118,26 @@ public class ReachabilityGraph _cutSets = cutSets.stream() .sorted(Comparator.comparing(p -> p.getRight())) .map(p -> p.getLeft()).toArray(CutSet[]::new); - + + //created sorted order of materialization points + //(cut sets in predetermined order, all other points appended) HashMap<InterestingPoint, Integer> probe = new HashMap<>(); ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>(); for( CutSet cs : _cutSets ) { - cs.updatePos(lsearchSpace.size()); - cs.updatePartitions(probe); CollectionUtils.addAll(lsearchSpace, cs.cut); - for( InterestingPoint p: cs.cut ) - probe.put(p, probe.size()-1); + for( InterestingPoint p : cs.cut ) + probe.put(p, probe.size()); } for( InterestingPoint p : part.getMatPointsExt() ) if( !probe.containsKey(p) ) { lsearchSpace.add(p); - probe.put(p, probe.size()-1); + probe.put(p, probe.size()); } _searchSpace = 
lsearchSpace.toArray(new InterestingPoint[0]); - //materialize partition indices - for( CutSet cs : _cutSets ) { - cs.updatePartitionIndexes(probe); - cs.finalizePartition(); - } + //finalize cut sets (update positions wrt search space) + for( CutSet cs : _cutSets ) + cs.updatePositions(probe); //final sanity check of interesting points if( _searchSpace.length != part.getMatPointsExt().length ) @@ -175,7 +173,7 @@ public class ReachabilityGraph public long getNumSkipPlans(boolean[] plan) { for( CutSet cs : _cutSets ) if( isCutSet(cs, plan) ) { - int pos = cs.posCut[cs.posCut.length-1]; + int pos = cs.posCut[cs.posCut.length-1]; return UtilFunctions.pow(2, plan.length-pos-1); } throw new RuntimeException("Failed to compute " @@ -271,48 +269,44 @@ public class ReachabilityGraph freePos = pos; freeMat = mat; } + + @Override + public String toString() { + return "SubProblem: "+Arrays.toString(freeMat)+"; " + +offset+"; "+Arrays.toString(freePos); + } } - public static class CutSet { - public InterestingPoint[] cut; - public InterestingPoint[] left; - public InterestingPoint[] right; - public int[] posCut; - public int[] posLeft; - public int[] posRight; + private static class CutSet { + private final InterestingPoint[] cut; + private final InterestingPoint[] left; + private final InterestingPoint[] right; + private int[] posCut; + private int[] posLeft; + private int[] posRight; - public CutSet(InterestingPoint[] cutPoints, + private CutSet(InterestingPoint[] cutPoints, InterestingPoint[] l, InterestingPoint[] r) { cut = cutPoints; - left = l; - right = r; + left = (InterestingPoint[]) ArrayUtils.addAll(cut, l); + right = (InterestingPoint[]) ArrayUtils.addAll(cut, r); } - public void updatePos(int index) { - posCut = new int[cut.length]; - for(int i=0; i<posCut.length; i++) - posCut[i] = index + i; - } - - public void updatePartitions(HashMap<InterestingPoint,Integer> blacklist) { - left = Arrays.stream(left).filter(p -> !blacklist.containsKey(p)) - 
.toArray(InterestingPoint[]::new); - right = Arrays.stream(right).filter(p -> !blacklist.containsKey(p)) - .toArray(InterestingPoint[]::new); - } - - public void updatePartitionIndexes(HashMap<InterestingPoint,Integer> probe) { - posLeft = new int[left.length]; - for(int i=0; i<left.length; i++) - posLeft[i] = probe.get(left[i]); - posRight = new int[right.length]; - for(int i=0; i<right.length; i++) - posRight[i] = probe.get(right[i]); - } - - public void finalizePartition() { - left = (InterestingPoint[]) ArrayUtils.addAll(cut, left); - right = (InterestingPoint[]) ArrayUtils.addAll(cut, right); + private void updatePositions(HashMap<InterestingPoint,Integer> probe) { + int lenCut = cut.length; + posCut = new int[lenCut]; + for(int i=0; i<lenCut; i++) + posCut[i] = probe.get(cut[i]); + + int lenLeft = left.length - cut.length; + posLeft = new int[lenLeft]; + for(int i=0; i<lenLeft; i++) + posLeft[i] = probe.get(left[lenCut+i]); + + int lenRight = right.length - cut.length; + posRight = new int[lenRight]; + for(int i=0; i<lenRight; i++) + posRight[i] = probe.get(right[lenCut+i]); } @Override @@ -329,12 +323,12 @@ public class ReachabilityGraph private long _ID; private InterestingPoint _p; - public NodeLink(InterestingPoint p) { + private NodeLink(InterestingPoint p) { _ID = _seqID.getNextID(); _p = p; } - public void addInput(NodeLink in) { + private void addInput(NodeLink in) { _inputs.add(in); }
