Repository: systemml Updated Branches: refs/heads/master 6de8f051d -> 311e4aac9
[SYSTEMML-1968] Improved codegen optimizer (cost, mat points, pruning) This patch improves the cost-based codegen optimizer to address wrong fusion decisions for large-scale computations. In detail, this includes: 1) Cost model: The cost model now accounts for the broadcast cost of side inputs in distributed Spark operations. Furthermore, this also includes a fix for the calculation of compute costs in case of a mix of row and cell operations of different dimensions. 2) Interesting points: To enable reasoning about side inputs, we now also consider template switches from cell to row templates as interesting points. 3) Pruning of row templates: The above changes also revealed hidden issues in the pruning of unnecessary row templates (conversion to cell templates), which mistakenly removed necessary row templates and ultimately led to runtime errors. On a large-scale scenario of L2SVM over a 200M x 100 dense input (160GB), this patch improved the end-to-end runtime for 20 outer iterations from 942s to 273s (w/o codegen: 644s). 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/311e4aac Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/311e4aac Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/311e4aac Branch: refs/heads/master Commit: 311e4aac9833397908a083d0a48d5bd3ba086283 Parents: 6de8f05 Author: Matthias Boehm <[email protected]> Authored: Sat Oct 21 16:41:53 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Oct 21 17:15:38 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/opt/PlanAnalyzer.java | 2 +- .../opt/PlanSelectionFuseCostBasedV2.java | 131 ++++++++++--------- .../hops/codegen/template/CPlanMemoTable.java | 25 ++-- .../runtime/codegen/LibSpoofPrimitives.java | 6 +- 4 files changed, 91 insertions(+), 73 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java index 9910814..7d522b3 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java @@ -267,7 +267,7 @@ public class PlanAnalyzer for( int i=0; i<3; i++ ) { if( refs[i] < 0 ) continue; List<TemplateType> tmp = memo.getDistinctTemplateTypes(hopID, i, true); - if( memo.containsNotIn(refs[i], tmp, true, true) ) + if( memo.containsNotIn(refs[i], tmp, true) ) ret.add(new InterestingPoint(DecisionType.TEMPLATE_CHANGE, hopID, refs[i])); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index d2ed3ac..10875e8 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -86,6 +86,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //to cover result allocation, write into main memory, and potential evictions private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024; //2GB/s private static final double READ_BANDWIDTH = 32d*1024*1024*1024; //32GB/s + private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH/4; private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core * InfrastructureAnalyzer.getLocalParallelism(); @@ -146,7 +147,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection getComputeCosts(memo.getHopRefs().get(hopID), computeCosts); //prepare pruning helpers and prune memo table w/ determined mat points - StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), + StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), getReadCost(part, memo), getWriteCost(part.getRoots(), memo)); ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? 
new ReachabilityGraph(part, memo) : null; if( STRUCTURAL_PRUNING ) { @@ -339,14 +340,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection return costs; } - private static double getComputeCost(HashMap<Long, Double> computeCosts, CPlanMemoTable memo) { - double costs = 0; - for( Entry<Long,Double> e : computeCosts.entrySet() ) { - Hop mainInput = memo.getHopRefs() - .get(e.getKey()).getInput().get(0); - costs += getSize(mainInput) * e.getValue() / COMPUTE_BANDWIDTH; - } - return costs; + private static double sumComputeCost(HashMap<Long, Double> computeCosts) { + return computeCosts.values().stream() + .mapToDouble(d -> d/COMPUTE_BANDWIDTH).sum(); } private static long getSize(Hop hop) { @@ -567,33 +563,39 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } } - private static boolean isRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { - //consider all aggregations other than root operation - MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); - boolean ret = true; - for(int i=0; i<3; i++) - if( me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAggOrVects(memo, - current.getInput().get(i), visited); - return ret; + private static HashSet<Long> getRowAggOpsWithRowRef(CPlanMemoTable memo, PlanPartition part) { + HashSet<Long> refAggs = new HashSet<>(); + for( Long hopID : part.getPartition() ) { + if( !memo.contains(hopID, TemplateType.ROW) ) continue; + MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); + for(int i=0; i<3; i++) + if( me.isPlanRef(i) && memo.contains(me.input(i), TemplateType.ROW) + && isRowAggOp(memo.getHopRefs().get(me.input(i)))) + refAggs.add(me.input(i)); + } + return refAggs; } - private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) { + private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited, boolean inclRoot) { if( 
visited.contains(current.getHopID()) ) return true; - boolean ret = true; MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW); - for(int i=0; i<3; i++) + boolean ret = !inclRoot || !isRowAggOp(current); + for(int i=0; i<3 && ret; i++) if( me!=null && me.isPlanRef(i) ) - ret &= rIsRowTemplateWithoutAggOrVects(memo, current.getInput().get(i), visited); - ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp - || HopRewriteUtils.isBinary(current, OpOp2.CBIND)); + ret &= rIsRowTemplateWithoutAggOrVects(memo, + current.getInput().get(i), visited, true); visited.add(current.getHopID()); return ret; } + private static boolean isRowAggOp(Hop hop){ + return (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp + || HopRewriteUtils.isBinary(hop, OpOp2.CBIND)); + } + private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) { //prune invalid row entries w/ violated blocksize constraint @@ -613,9 +615,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection && HopRewriteUtils.isTransposeOperation(in)); if( isSpark && !validNcol ) { List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); - memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist)); - if( !memo.contains(hopID) ) - memo.removeAllRefTo(hopID); + memo.remove(memo.getHopRefs().get(hopID), TemplateType.ROW); + memo.removeAllRefTo(hopID, TemplateType.ROW); if( LOG.isTraceEnabled() ) { LOG.trace("Removed row memo table entries w/ violated blocksize constraint ("+hopID+"): " + Arrays.toString(blacklist.toArray(new MemoTableEntry[0]))); @@ -625,10 +626,11 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } //prune row aggregates with pure cellwise operations + HashSet<Long> refAggs = getRowAggOpsWithRowRef(memo, part); for( Long hopID : part.getPartition() ) { MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW); if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, 
TemplateType.CELL) - && isRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) { + && rIsRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>(), refAggs.contains(hopID)) ) { List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist)); if( LOG.isTraceEnabled() ) { @@ -698,28 +700,25 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //i.e., plans that become invalid after the previous pruning step long hopID = current.getHopID(); if( part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) { - for( MemoTableEntry me : memo.get(hopID) ) { - if( me.type==TemplateType.ROW ) { - //convert leaf node with pure vector inputs - if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) { + for( MemoTableEntry me : memo.get(hopID, TemplateType.ROW) ) { + //convert leaf node with pure vector inputs + if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) { + me.type = TemplateType.CELL; + if( LOG.isTraceEnabled() ) + LOG.trace("Converted leaf memo table entry from row to cell: "+me); + } + + //convert inner node without row template input + if( me.hasPlanRef() && !ROW_TPL.open(current) ) { + boolean hasRowInput = false; + for( int i=0; i<3; i++ ) + if( me.isPlanRef(i) ) + hasRowInput |= memo.contains(me.input(i), TemplateType.ROW); + if( !hasRowInput ) { me.type = TemplateType.CELL; if( LOG.isTraceEnabled() ) - LOG.trace("Converted leaf memo table entry from row to cell: "+me); - } - - //convert inner node without row template input - if( me.hasPlanRef() && !ROW_TPL.open(current) ) { - boolean hasRowInput = false; - for( int i=0; i<3; i++ ) - if( me.isPlanRef(i) ) - hasRowInput |= memo.contains(me.input(i), TemplateType.ROW); - if( !hasRowInput ) { - me.type = TemplateType.CELL; - if( LOG.isTraceEnabled() ) - LOG.trace("Converted inner memo table entry from row to cell: "+me); - } + 
LOG.trace("Converted inner memo table entry from row to cell: "+me); } - } } } @@ -834,14 +833,16 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection String type = (best !=null) ? best.type.name() : "HOP"; LOG.trace("Cost vector ("+type+" "+currentHopId+"): "+costVect); } - double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write - + Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH, - costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH); + double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH + + Math.max(costVect.getInputSize() * 8 / READ_BANDWIDTH, + costVect.computeCosts/ COMPUTE_BANDWIDTH); + //read correction for distributed computation + Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); + if( driver.getMemEstimate() > OptimizerUtils.getLocalMemBudget() ) + tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST; //sparsity correction for outer-product template (and sparse-safe cell) - if( best != null && best.type == TemplateType.OUTER ) { - Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); + if( best != null && best.type == TemplateType.OUTER ) tmpCosts *= driver.dimsKnown(true) ? 
driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; - } costs += tmpCosts; } //add costs for non-partition read in the middle of fused operator @@ -978,12 +979,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection costs = 1; } else if( current instanceof AggBinaryOp ) { - //outer product template - if( HopRewriteUtils.isOuterProductLikeMM(current) ) - costs = 2 * current.getInput().get(0).getDim2(); - //row template w/ matrix-vector or matrix-matrix - else - costs = 2 * current .getDim2(); + //outer product template w/ matrix-matrix + //or row template w/ matrix-vector or matrix-matrix + costs = 2 * current.getInput().get(0).getDim2(); } else if( current instanceof AggUnaryOp) { switch(((AggUnaryOp)current).getOp()) { @@ -993,10 +991,15 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection case MAX: costs = 1; break; default: LOG.warn("Cost model not " - + "implemented yet for: "+((AggUnaryOp)current).getOp()); + + "implemented yet for: "+((AggUnaryOp)current).getOp()); } } + //scale by current output size in order to correctly reflect + //a mix of row and cell operations in the same fused operator + //(e.g., row template with fused column vector operations) + costs *= getSize(current); + computeCosts.put(current.getHopID(), costs); } @@ -1025,8 +1028,14 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //ensures that input sizes are not double counted inSizes.put(hopID, inputSize); } - public double getSumInputSizes() { + public double getInputSize() { + return inSizes.values().stream() + .mapToDouble(d -> d.doubleValue()).sum(); + } + public double getSideInputSize() { + double max = getMaxInputSize(); return inSizes.values().stream() + .filter(d -> d < max) .mapToDouble(d -> d.doubleValue()).sum(); } public double getMaxInputSize() { http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 99ffc8d..5eedc7b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -95,11 +95,10 @@ public class CPlanMemoTable .anyMatch(p -> (!checkClose||!p.isClosed()) && probe.contains(p.type)); } - public boolean containsNotIn(long hopID, Collection<TemplateType> types, - boolean checkChildRefs, boolean excludeCell) { + public boolean containsNotIn(long hopID, + Collection<TemplateType> types, boolean checkChildRefs) { return contains(hopID) && get(hopID).stream() - .anyMatch(p -> (!checkChildRefs || p.hasPlanRef()) - && (!excludeCell || p.type!=TemplateType.CELL) + .anyMatch(p -> (!checkChildRefs || p.hasPlanRef()) && p.isValid() && !types.contains(p.type)); } @@ -153,14 +152,22 @@ public class CPlanMemoTable .removeIf(p -> blackList.contains(p)); } + public void remove(Hop hop, TemplateType type) { + _plans.get(hop.getHopID()) + .removeIf(p -> p.type == type); + } + public void removeAllRefTo(long hopID) { + removeAllRefTo(hopID, null); + } + + public void removeAllRefTo(long hopID, TemplateType type) { //recursive removal of references for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) { - if( !e.getValue().isEmpty() ) { - e.getValue().removeIf(p -> p.hasPlanRefTo(hopID)); - if( e.getValue().isEmpty() ) - removeAllRefTo(e.getKey()); - } + if( e.getValue().isEmpty() || e.getKey()==hopID ) + continue; + e.getValue().removeIf(p -> p.hasPlanRefTo(hopID) + && (type==null || p.type==type)); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index 7624d96..91fde5e 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -1788,11 +1788,13 @@ public class LibSpoofPrimitives //dynamic memory management public static void setupThreadLocalMemory(int numVectors, int len) { - setupThreadLocalMemory(numVectors, len, -1); + if( numVectors > 0 ) + setupThreadLocalMemory(numVectors, len, -1); } public static void setupThreadLocalMemory(int numVectors, int len, int len2) { - memPool.set(new VectorBuffer(numVectors, len, len2)); + if( numVectors > 0 ) + memPool.set(new VectorBuffer(numVectors, len, len2)); } public static void cleanupThreadLocalMemory() {
