Repository: systemml
Updated Branches:
  refs/heads/master 9481bef4e -> f2fbd99e7


[SYSTEMML-2160] Fix codegen optimizer sparsity-aware cost-based pruning

The codegen optimizer computes lower bound costs for cost-based pruning
based on static and plan-dependent costs. So far, these costs for compute
and I/O did not take potential sparsity exploitation across operations
into account, leading to suboptimal plans because of overly eager pruning.
This patch makes the lower bound computation sparsity-aware, which
improved the runtime of ALS-CG over the amazon dataset by more than 2x.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/03c050d4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/03c050d4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/03c050d4

Branch: refs/heads/master
Commit: 03c050d4b50abc2808b1f844491b3d1e811b1a55
Parents: 9481bef
Author: Matthias Boehm <[email protected]>
Authored: Thu Feb 22 18:29:36 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Feb 22 18:29:36 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/opt/PlanAnalyzer.java    |  4 ++-
 .../sysml/hops/codegen/opt/PlanPartition.java   |  9 +++++-
 .../opt/PlanSelectionFuseCostBasedV2.java       | 30 ++++++++++++++++----
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 10 ++-----
 4 files changed, 39 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
index 7d522b3..1822217 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
@@ -59,7 +59,9 @@ public class PlanAnalyzer
                        HashSet<Long> Pnpc = 
getNodesWithNonPartitionConsumers(R, partition, memo);
                        InterestingPoint[] Mext = !ext ? null : 
                                getMaterializationPointsExt(R, partition, M, 
memo);
-                       ret.add(new PlanPartition(partition, R, I, Pnpc, M, 
Mext));
+                       boolean hasOuter = partition.stream()
+                               .anyMatch(k -> memo.contains(k, 
TemplateType.OUTER));
+                       ret.add(new PlanPartition(partition, R, I, Pnpc, M, 
Mext, hasOuter));
                }
                
                return ret;

http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
index fbf7d9f..58620cb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java
@@ -42,14 +42,17 @@ public class PlanPartition
        //interesting operator dependencies
        private InterestingPoint[] _matPointsExt;
        
+       //indicator if the partitions contains outer templates
+       private final boolean _hasOuter;
        
-       public PlanPartition(HashSet<Long> P, HashSet<Long> R, HashSet<Long> I, 
HashSet<Long> Pnpc, ArrayList<Long> M, InterestingPoint[] Mext) {
+       public PlanPartition(HashSet<Long> P, HashSet<Long> R, HashSet<Long> I, 
HashSet<Long> Pnpc, ArrayList<Long> M, InterestingPoint[] Mext, boolean 
hasOuter) {
                _nodes = P;
                _roots = R;
                _inputs = I;
                _nodesNpc = Pnpc;
                _matPoints = M;
                _matPointsExt = Mext;
+               _hasOuter = hasOuter;
        }
        
        public HashSet<Long> getPartition() {
@@ -79,4 +82,8 @@ public class PlanPartition
        public void setMatPointsExt(InterestingPoint[] points) {
                _matPointsExt = points;
        }
+       
+       public boolean hasOuter() {
+               return _hasOuter;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 6ed562a..6f34497 100644
--- 
a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ 
b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -176,8 +176,8 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                                getComputeCosts(memo.getHopRefs().get(hopID), 
computeCosts);
                        
                        //prepare pruning helpers and prune memo table w/ 
determined mat points
-                       StaticCosts costs = new StaticCosts(computeCosts, 
sumComputeCost(computeCosts), 
-                               getReadCost(part, memo), 
getWriteCost(part.getRoots(), memo));
+                       StaticCosts costs = new StaticCosts(computeCosts, 
sumComputeCost(computeCosts),
+                               getReadCost(part, memo), 
getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));
                        ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? new 
ReachabilityGraph(part, memo) : null;
                        if( STRUCTURAL_PRUNING ) {
                                
part.setMatPointsExt(rgraph.getSortedSearchSpace());
@@ -287,8 +287,7 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                        }
                        //skip plans with branch and bound pruning (cost)
                        else if( COST_PRUNING ) {
-                               double lbC = Math.max(costs._read, 
costs._compute) + costs._write
-                                       + getMaterializationCost(part, 
matPoints, memo, plan);
+                               double lbC = getLowerBoundCosts(part, 
matPoints, memo, costs, plan);
                                if( lbC >= bestC ) {
                                        long skip = getNumSkipPlans(plan);
                                        if( LOG.isTraceEnabled() )
@@ -354,6 +353,18 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                return UtilFunctions.pow(2, plan.length-pos-1);
        }
        
+       private static double getLowerBoundCosts(PlanPartition part, 
InterestingPoint[] M, CPlanMemoTable memo, StaticCosts costs, boolean[] plan) {
+               //compute the lower bound from static and plan-dependent costs
+               double lb = Math.max(costs._read, costs._compute) + costs._write
+                       + getMaterializationCost(part, M, memo, plan);
+               
+               //if the partition contains outer templates, we need to correct 
the lower bound
+               if( part.hasOuter() )
+                       lb *= costs._minSparsity;
+               
+               return lb;
+       }
+       
        private static double getMaterializationCost(PlanPartition part, 
InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) {
                double costs = 0;
                //currently active materialization points
@@ -403,6 +414,13 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                        .mapToDouble(d -> d/COMPUTE_BANDWIDTH).sum();
        }
        
+       private static double minOuterSparsity(PlanPartition part, 
CPlanMemoTable memo) {
+               return !part.hasOuter() ? 1.0 : part.getPartition().stream()
+                       .map(k -> 
HopRewriteUtils.getLargestInput(memo.getHopRefs().get(k)))
+                       .mapToDouble(h -> h.dimsKnown(true) ? h.getSparsity() : 
SPARSE_SAFE_SPARSITY_EST)
+                       .min().orElse(SPARSE_SAFE_SPARSITY_EST);
+       }
+       
        private static double sumTmpInputOutputSize(CPlanMemoTable memo, 
CostVector vect) {
                //size of intermediate inputs and outputs, i.e., output and 
inputs other than treads
                return vect.outSize + vect.inSizes.entrySet().stream()
@@ -1234,11 +1252,13 @@ public class PlanSelectionFuseCostBasedV2 extends 
PlanSelection
                public final double _compute;
                public final double _read;
                public final double _write;
-               public StaticCosts(HashMap<Long,Double> allComputeCosts, double 
computeCost, double readCost, double writeCost) {
+               public final double _minSparsity;
+               public StaticCosts(HashMap<Long,Double> allComputeCosts, double 
computeCost, double readCost, double writeCost, double minSparsity) {
                        _computeCosts = allComputeCosts;
                        _compute = computeCost;
                        _read = readCost;
                        _write = writeCost;
+                       _minSparsity = minSparsity;
                }
                public double getMinCosts() {
                        return Math.max(_read, _compute) + _write;

http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 0484bb3..dce21ab 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 
@@ -309,13 +310,8 @@ public class HopRewriteUtils
        }
        
        public static Hop getLargestInput(Hop hop) {
-               Hop max = null; long maxSize = -1;
-               for(Hop in : hop.getInput())
-                       if(in.getLength() > maxSize) {
-                               max = in;
-                               maxSize = in.getLength();
-                       }
-               return max;
+               return hop.getInput().stream()
+                       .max(Comparator.comparing(h -> 
h.getLength())).orElse(null);
        }
        
        public static Hop createDataGenOp( Hop input, double value ) 

Reply via email to