Repository: systemml Updated Branches: refs/heads/master 9481bef4e -> f2fbd99e7
[SYSTEMML-2160] Fix codegen optimizer sparsity-aware cost-based pruning The codegen optimizer computes lower bound costs for cost-based pruning based on static and plan-dependent costs. So far, these costs for compute and I/O did not take potential sparsity exploitation across operations into account, leading to suboptimal plans because of overly eager pruning. This patch makes the lower bound computation sparsity-aware, which improved the runtime of ALS-CG over the Amazon dataset by more than 2x. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/03c050d4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/03c050d4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/03c050d4 Branch: refs/heads/master Commit: 03c050d4b50abc2808b1f844491b3d1e811b1a55 Parents: 9481bef Author: Matthias Boehm <[email protected]> Authored: Thu Feb 22 18:29:36 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Thu Feb 22 18:29:36 2018 -0800 ---------------------------------------------------------------------- .../sysml/hops/codegen/opt/PlanAnalyzer.java | 4 ++- .../sysml/hops/codegen/opt/PlanPartition.java | 9 +++++- .../opt/PlanSelectionFuseCostBasedV2.java | 30 ++++++++++++++++---- .../sysml/hops/rewrite/HopRewriteUtils.java | 10 ++----- 4 files changed, 39 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java index 7d522b3..1822217 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java @@ -59,7 +59,9 @@ 
public class PlanAnalyzer HashSet<Long> Pnpc = getNodesWithNonPartitionConsumers(R, partition, memo); InterestingPoint[] Mext = !ext ? null : getMaterializationPointsExt(R, partition, M, memo); - ret.add(new PlanPartition(partition, R, I, Pnpc, M, Mext)); + boolean hasOuter = partition.stream() + .anyMatch(k -> memo.contains(k, TemplateType.OUTER)); + ret.add(new PlanPartition(partition, R, I, Pnpc, M, Mext, hasOuter)); } return ret; http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java index fbf7d9f..58620cb 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanPartition.java @@ -42,14 +42,17 @@ public class PlanPartition //interesting operator dependencies private InterestingPoint[] _matPointsExt; + //indicator if the partitions contains outer templates + private final boolean _hasOuter; - public PlanPartition(HashSet<Long> P, HashSet<Long> R, HashSet<Long> I, HashSet<Long> Pnpc, ArrayList<Long> M, InterestingPoint[] Mext) { + public PlanPartition(HashSet<Long> P, HashSet<Long> R, HashSet<Long> I, HashSet<Long> Pnpc, ArrayList<Long> M, InterestingPoint[] Mext, boolean hasOuter) { _nodes = P; _roots = R; _inputs = I; _nodesNpc = Pnpc; _matPoints = M; _matPointsExt = Mext; + _hasOuter = hasOuter; } public HashSet<Long> getPartition() { @@ -79,4 +82,8 @@ public class PlanPartition public void setMatPointsExt(InterestingPoint[] points) { _matPointsExt = points; } + + public boolean hasOuter() { + return _hasOuter; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 6ed562a..6f34497 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -176,8 +176,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection getComputeCosts(memo.getHopRefs().get(hopID), computeCosts); //prepare pruning helpers and prune memo table w/ determined mat points - StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), - getReadCost(part, memo), getWriteCost(part.getRoots(), memo)); + StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), + getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo)); ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? 
new ReachabilityGraph(part, memo) : null; if( STRUCTURAL_PRUNING ) { part.setMatPointsExt(rgraph.getSortedSearchSpace()); @@ -287,8 +287,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } //skip plans with branch and bound pruning (cost) else if( COST_PRUNING ) { - double lbC = Math.max(costs._read, costs._compute) + costs._write - + getMaterializationCost(part, matPoints, memo, plan); + double lbC = getLowerBoundCosts(part, matPoints, memo, costs, plan); if( lbC >= bestC ) { long skip = getNumSkipPlans(plan); if( LOG.isTraceEnabled() ) @@ -354,6 +353,18 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection return UtilFunctions.pow(2, plan.length-pos-1); } + private static double getLowerBoundCosts(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, StaticCosts costs, boolean[] plan) { + //compute the lower bound from static and plan-dependent costs + double lb = Math.max(costs._read, costs._compute) + costs._write + + getMaterializationCost(part, M, memo, plan); + + //if the partition contains outer templates, we need to correct the lower bound + if( part.hasOuter() ) + lb *= costs._minSparsity; + + return lb; + } + private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) { double costs = 0; //currently active materialization points @@ -403,6 +414,13 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection .mapToDouble(d -> d/COMPUTE_BANDWIDTH).sum(); } + private static double minOuterSparsity(PlanPartition part, CPlanMemoTable memo) { + return !part.hasOuter() ? 1.0 : part.getPartition().stream() + .map(k -> HopRewriteUtils.getLargestInput(memo.getHopRefs().get(k))) + .mapToDouble(h -> h.dimsKnown(true) ? 
h.getSparsity() : SPARSE_SAFE_SPARSITY_EST) + .min().orElse(SPARSE_SAFE_SPARSITY_EST); + } + private static double sumTmpInputOutputSize(CPlanMemoTable memo, CostVector vect) { //size of intermediate inputs and outputs, i.e., output and inputs other than treads return vect.outSize + vect.inSizes.entrySet().stream() @@ -1234,11 +1252,13 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection public final double _compute; public final double _read; public final double _write; - public StaticCosts(HashMap<Long,Double> allComputeCosts, double computeCost, double readCost, double writeCost) { + public final double _minSparsity; + public StaticCosts(HashMap<Long,Double> allComputeCosts, double computeCost, double readCost, double writeCost, double minSparsity) { _computeCosts = allComputeCosts; _compute = computeCost; _read = readCost; _write = writeCost; + _minSparsity = minSparsity; } public double getMinCosts() { return Math.max(_read, _compute) + _write; http://git-wip-us.apache.org/repos/asf/systemml/blob/03c050d4/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 0484bb3..dce21ab 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -20,6 +20,7 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; @@ -309,13 +310,8 @@ public class HopRewriteUtils } public static Hop getLargestInput(Hop hop) { - Hop max = null; long maxSize = -1; - for(Hop in : hop.getInput()) - if(in.getLength() > maxSize) { - max = in; - maxSize = in.getLength(); - } - return max; + return hop.getInput().stream() + .max(Comparator.comparing(h -> 
h.getLength())).orElse(null); } public static Hop createDataGenOp( Hop input, double value )
