Repository: systemml Updated Branches: refs/heads/master c2124544d -> c170374e7
[MINOR][SYSTEMML-1741] Improved graph traversal on codegen plan costing This patch makes a minor improvement of the codegen graph traversal for plan costing by stopping at partition boundaries instead of traversing to the leaves without taking the costs into account. On LeNet this led to a minor but consistent reduction of codegen compilation overhead from 14.4s to 13.1s (for enumerating 474,840 plans). Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0221fbcd Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0221fbcd Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0221fbcd Branch: refs/heads/master Commit: 0221fbcd0d3c09c0aef37ee7e77863ecdc0446b4 Parents: c212454 Author: Matthias Boehm <[email protected]> Authored: Mon Aug 7 13:17:32 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Aug 9 13:52:49 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/opt/PlanSelection.java | 3 +- .../opt/PlanSelectionFuseCostBasedV2.java | 50 +++++++++----------- 2 files changed, 24 insertions(+), 29 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0221fbcd/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java index 23cc128..d18d156 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java @@ -98,8 +98,7 @@ public abstract class PlanSelection return Integer.compare(o1.type.getRank(), o2.type.getRank()); //for same type, prefer plan with more refs - return Integer.compare( - 3-o1.countPlanRefs(), 
3-o2.countPlanRefs()); + return Integer.compare(-o1.countPlanRefs(), -o2.countPlanRefs()); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0221fbcd/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 2fa0de7..717a059 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -744,11 +744,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //memoization per hop id and cost vector to account for redundant //computation without double counting materialized results or compute //costs of complex operation DAGs within a single fused operator - VisitMarkCost tag = new VisitMarkCost(current.getHopID(), - (costsCurrent==null || currentType==TemplateType.MAGG)?0:costsCurrent.ID); - if( visited.contains(tag) ) - return 0; - visited.add(tag); + if( !visited.add(new VisitMarkCost(current.getHopID(), + (costsCurrent==null || currentType==TemplateType.MAGG)?0:costsCurrent.ID)) ) + return 0; //already existing //open template if necessary, including memoization //under awareness of current plan choice @@ -792,8 +790,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection } //add compute costs of current operator to costs vector - if( part.getPartition().contains(current.getHopID()) ) - costVect.computeCosts += computeCosts.get(current.getHopID()); + costVect.computeCosts += computeCosts.get(current.getHopID()); //process children recursively for( int i=0; i< current.getInput().size(); i++ ) { @@ -803,34 +800,33 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection else if( best!=null && isImplicitlyFused(current, i, 
best.type) ) costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c)); else { //include children and I/O costs - costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null); + if( part.getPartition().contains(c.getHopID()) ) + costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null); if( costVect != null && c.getDataType().isMatrix() ) costVect.addInputSize(c.getHopID(), getSize(c)); } } //add costs for opened fused operator - if( part.getPartition().contains(current.getHopID()) ) { - if( opened ) { - if( LOG.isTraceEnabled() ) { - String type = (best !=null) ? best.type.name() : "HOP"; - LOG.trace("Cost vector ("+type+" "+current.getHopID()+"): "+costVect); - } - double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write - + Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH, - costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH); - //sparsity correction for outer-product template (and sparse-safe cell) - if( best != null && best.type == TemplateType.OUTER ) { - Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); - tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; - } - costs += tmpCosts; + if( opened ) { + if( LOG.isTraceEnabled() ) { + String type = (best !=null) ? 
best.type.name() : "HOP"; + LOG.trace("Cost vector ("+type+" "+current.getHopID()+"): "+costVect); } - //add costs for non-partition read in the middle of fused operator - else if( part.getExtConsumed().contains(current.getHopID()) ) { - costs += rGetPlanCosts(memo, current, visited, - part, matPoints, plan, computeCosts, null, null); + double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write + + Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH, + costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH); + //sparsity correction for outer-product template (and sparse-safe cell) + if( best != null && best.type == TemplateType.OUTER ) { + Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); + tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; } + costs += tmpCosts; + } + //add costs for non-partition read in the middle of fused operator + else if( part.getExtConsumed().contains(current.getHopID()) ) { + costs += rGetPlanCosts(memo, current, visited, + part, matPoints, plan, computeCosts, null, null); } //sanity check non-negative costs
