[SYSTEMML-1514] Fix codegen cost estimation (two-level memoization) This patch fixes the cost estimation of fusion plans for complex fused operators with internal DAG structures, where we mistakenly double-counted compute costs. Similarly, we incorrectly double-counted the costs of materialized intermediates. The core idea is a two-level memoization, i.e., memoization over pairs of hop IDs and cost vectors, which prevents double counting within a single fused operator while still accounting for the costs of redundant computation across overlapping fused operators.
Additionally, this patch also hardens the compilation of multi-aggregates to ensure matching input dimensions and to exclude partial rowwise fusion plans that cover matrix multiplications. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/fb82482b Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/fb82482b Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/fb82482b Branch: refs/heads/master Commit: fb82482b09f851b428fc1c0e70994e9e6d94c007 Parents: f9f70b3 Author: Matthias Boehm <[email protected]> Authored: Tue Apr 11 23:08:30 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Apr 11 23:08:30 2017 -0700 ---------------------------------------------------------------------- .../template/PlanSelectionFuseCostBased.java | 47 ++++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/fb82482b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java index 3c98090..8ba2490 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java @@ -31,6 +31,7 @@ import java.util.Iterator; import java.util.List; import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.hops.AggBinaryOp; @@ -48,7 +49,7 @@ import 
org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; - +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; /** * This cost-based plan selection algorithm chooses fused operators @@ -68,7 +69,8 @@ public class PlanSelectionFuseCostBased extends PlanSelection private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core * InfrastructureAnalyzer.getLocalParallelism(); - private final static TemplateRow ROW_TPL = new TemplateRow(); + private static final IDSequence COST_ID = new IDSequence(); + private static final TemplateRow ROW_TPL = new TemplateRow(); @Override public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) @@ -315,10 +317,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection LOG.trace(info); } + //filter aggregations w/ matmults to ensure consistent dims //sort aggregations by num dependencies to simplify merging //clusters of aggregations with parallel dependencies - aggInfos = aggInfos.stream().sorted(Comparator.comparing( - a -> a._inputAggs.size())).collect(Collectors.toList()); + aggInfos = aggInfos.stream().filter(a -> !a.containsMatMult) + .sorted(Comparator.comparing(a -> a._inputAggs.size())) + .collect(Collectors.toList()); //greedy grouping of multi-agg candidates boolean converged = false; @@ -409,6 +413,10 @@ public class PlanSelectionFuseCostBased extends PlanSelection aggInfo.addInputAggregate(current.getHopID()); } + //collect included matrix multiplications + if( type != null && HopRewriteUtils.isMatrixMultiply(current) ) + aggInfo.setContainsMatMult(); + //recursively process children MemoTableEntry me = (type!=null) ? 
memo.getBest(current.getHopID()) : null; for( int i=0; i< current.getInput().size(); i++ ) { @@ -612,7 +620,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection //READ costs by the input sizes, and COMPUTE by operation specific FLOP //counts times number of cells of main input, disregarding sparsity for now. - HashSet<Long> visited = new HashSet<Long>(); + HashSet<Pair<Long,Long>> visited = new HashSet<Pair<Long,Long>>(); double costs = 0; for( Long hopID : R ) costs += rGetPlanCosts(memo, memo._hopRefs.get(hopID), @@ -620,11 +628,17 @@ public class PlanSelectionFuseCostBased extends PlanSelection return costs; } - private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, - ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, OperatorStats costsCurrent, TemplateType currentType) + private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long,Long>> visited, HashSet<Long> partition, + ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateType currentType) { - if( visited.contains(current.getHopID()) ) - return 0; //dont double count + //memoization per hop id and cost vector to account for redundant + //computation without double counting materialized results or compute + //costs of complex operation DAGs within a single fused operator + Pair<Long,Long> tag = Pair.of(current.getHopID(), + (costsCurrent==null)?0:costsCurrent.ID); + if( visited.contains(tag) ) + return 0; + visited.add(tag); //open template if necessary, including memoization //under awareness of current plan choice @@ -637,7 +651,6 @@ public class PlanSelectionFuseCostBased extends PlanSelection .filter(p -> hasNoRefToMaterialization(p, M, plan)) .min(new BasicPlanComparator()).orElse(null); opened = true; - visited.add(current.getHopID()); } else { best = memo.get(current.getHopID()).stream() @@ -649,8 +662,8 @@ public class 
PlanSelectionFuseCostBased extends PlanSelection } //create new cost vector if opened, initialized with write costs - OperatorStats costVect = !opened ? costsCurrent : - new OperatorStats(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1)); + CostVector costVect = !opened ? costsCurrent : + new CostVector(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1)); //add compute costs of current operator to costs vector if( partition.contains(current.getHopID()) ) @@ -821,12 +834,14 @@ public class PlanSelectionFuseCostBased extends PlanSelection return ret; } - private static class OperatorStats { + private static class CostVector { + public final long ID; public final double outSize; public double computeCosts = 0; public final HashMap<Long, Double> inSizes = new HashMap<Long, Double>(); - public OperatorStats(double outputSize) { + public CostVector(double outputSize) { + ID = COST_ID.getNextID(); outSize = outputSize; } public void addInputSize(long hopID, double inputSize) { @@ -853,6 +868,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection public final HashMap<Long,Hop> _aggregates; public final HashSet<Long> _inputAggs = new HashSet<Long>(); public final HashSet<Long> _fusedInputs = new HashSet<Long>(); + public boolean containsMatMult = false; public AggregateInfo(Hop aggregate) { _aggregates = new HashMap<Long, Hop>(); _aggregates.put(aggregate.getHopID(), aggregate); @@ -863,6 +879,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection public void addFusedInput(long hopID) { _fusedInputs.add(hopID); } + public void setContainsMatMult() { + containsMatMult = true; + } public boolean isMergable(AggregateInfo that) { //check independence boolean ret = _aggregates.size()<3
