[SYSTEMML-1514] Fix codegen cost estimation (two-level memoization)

This patch fixes the cost estimation of fusion plans for complex fused
operators with internal DAG structures, where we mistakenly double
counted compute costs. Similarly, we incorrectly double counted the costs
of materialized intermediates. The core idea is a two-level memoization,
i.e., memoization keyed by pairs of (hop, cost vector). This prevents
double counting within a single fused operator, while still allowing the
costs of overlapping fused operators with redundant computation to be
evaluated.

Additionally, this patch hardens the compilation of
multi-aggregates to ensure matching input dimensions and to exclude
partial rowwise fusion plans that cover matrix multiplications.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/fb82482b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/fb82482b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/fb82482b

Branch: refs/heads/master
Commit: fb82482b09f851b428fc1c0e70994e9e6d94c007
Parents: f9f70b3
Author: Matthias Boehm <[email protected]>
Authored: Tue Apr 11 23:08:30 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Apr 11 23:08:30 2017 -0700

----------------------------------------------------------------------
 .../template/PlanSelectionFuseCostBased.java    | 47 ++++++++++++++------
 1 file changed, 33 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/fb82482b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 3c98090..8ba2490 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -31,6 +31,7 @@ import java.util.Iterator;
 import java.util.List;
 
 import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.AggBinaryOp;
@@ -48,7 +49,7 @@ import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 
 /**
  * This cost-based plan selection algorithm chooses fused operators
@@ -68,7 +69,8 @@ public class PlanSelectionFuseCostBased extends PlanSelection
        private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 
//2GFLOPs/core
                * InfrastructureAnalyzer.getLocalParallelism();
        
-       private final static TemplateRow ROW_TPL = new TemplateRow();
+       private static final IDSequence COST_ID = new IDSequence();
+       private static final TemplateRow ROW_TPL = new TemplateRow();
        
        @Override
        public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) 
@@ -315,10 +317,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                LOG.trace(info);
                }
                
+               //filter aggregations w/ matmults to ensure consistent dims
                //sort aggregations by num dependencies to simplify merging
                //clusters of aggregations with parallel dependencies
-               aggInfos = aggInfos.stream().sorted(Comparator.comparing(
-                       a -> a._inputAggs.size())).collect(Collectors.toList());
+               aggInfos = aggInfos.stream().filter(a -> !a.containsMatMult)
+                       .sorted(Comparator.comparing(a -> a._inputAggs.size()))
+                       .collect(Collectors.toList());
                
                //greedy grouping of multi-agg candidates
                boolean converged = false;
@@ -409,6 +413,10 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                        aggInfo.addInputAggregate(current.getHopID());
                }
                
+               //collect included matrix multiplications
+               if( type != null && HopRewriteUtils.isMatrixMultiply(current) )
+                       aggInfo.setContainsMatMult();
+               
                //recursively process children
                MemoTableEntry me = (type!=null) ? 
memo.getBest(current.getHopID()) : null;
                for( int i=0; i< current.getInput().size(); i++ ) {
@@ -612,7 +620,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                //READ costs by the input sizes, and COMPUTE by operation 
specific FLOP
                //counts times number of cells of main input, disregarding 
sparsity for now.
                
-               HashSet<Long> visited = new HashSet<Long>();
+               HashSet<Pair<Long,Long>> visited = new 
HashSet<Pair<Long,Long>>();
                double costs = 0;
                for( Long hopID : R )
                        costs += rGetPlanCosts(memo, memo._hopRefs.get(hopID), 
@@ -620,11 +628,17 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                return costs;
        }
        
-       private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, 
HashSet<Long> visited, HashSet<Long> partition, 
-                       ArrayList<Long> M, boolean[] plan, HashMap<Long, 
Double> computeCosts, OperatorStats costsCurrent, TemplateType currentType) 
+       private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, 
HashSet<Pair<Long,Long>> visited, HashSet<Long> partition, 
+                       ArrayList<Long> M, boolean[] plan, HashMap<Long, 
Double> computeCosts, CostVector costsCurrent, TemplateType currentType) 
        {
-               if( visited.contains(current.getHopID()) )
-                       return 0; //dont double count 
+               //memoization per hop id and cost vector to account for 
redundant
+               //computation without double counting materialized results or 
compute
+               //costs of complex operation DAGs within a single fused operator
+               Pair<Long,Long> tag = Pair.of(current.getHopID(), 
+                       (costsCurrent==null)?0:costsCurrent.ID);
+               if( visited.contains(tag) )
+                       return 0; 
+               visited.add(tag);       
                
                //open template if necessary, including memoization
                //under awareness of current plan choice
@@ -637,7 +651,6 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                                        .filter(p -> 
hasNoRefToMaterialization(p, M, plan))
                                        .min(new 
BasicPlanComparator()).orElse(null);
                                opened = true;
-                               visited.add(current.getHopID());
                        }
                        else {
                                best = memo.get(current.getHopID()).stream()
@@ -649,8 +662,8 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                }
                
                //create new cost vector if opened, initialized with write costs
-               OperatorStats costVect = !opened ? costsCurrent : 
-                       new 
OperatorStats(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1));
+               CostVector costVect = !opened ? costsCurrent : 
+                       new 
CostVector(Math.max(current.getDim1(),1)*Math.max(current.getDim2(),1));
                
                //add compute costs of current operator to costs vector 
                if( partition.contains(current.getHopID()) )
@@ -821,12 +834,14 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                return ret;
        }
        
-       private static class OperatorStats {
+       private static class CostVector {
+               public final long ID;
                public final double outSize; 
                public double computeCosts = 0;
                public final HashMap<Long, Double> inSizes = new HashMap<Long, 
Double>();
                
-               public OperatorStats(double outputSize) {
+               public CostVector(double outputSize) {
+                       ID = COST_ID.getNextID();
                        outSize = outputSize;
                }
                public void addInputSize(long hopID, double inputSize) {
@@ -853,6 +868,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                public final HashMap<Long,Hop> _aggregates;
                public final HashSet<Long> _inputAggs = new HashSet<Long>();
                public final HashSet<Long> _fusedInputs = new HashSet<Long>();
+               public boolean containsMatMult = false;
                public AggregateInfo(Hop aggregate) {
                        _aggregates = new HashMap<Long, Hop>();
                        _aggregates.put(aggregate.getHopID(), aggregate);
@@ -863,6 +879,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection
                public void addFusedInput(long hopID) {
                        _fusedInputs.add(hopID);
                }
+               public void setContainsMatMult() {
+                       containsMatMult = true;
+               }
                public boolean isMergable(AggregateInfo that) {
                        //check independence
                        boolean ret = _aggregates.size()<3 

Reply via email to