[SYSTEMML-1979] Improved codegen cost model (sparsity, minor fixes)

This patch improves the codegen cost model to correctly account for the compute workload of sparse matrix multiplications, as well as for the sizes of both sparse and dense inputs. In addition, it includes minor fixes to the eviction- and broadcast-aware cost corrections. Overall, these changes address special cases of sparse, large-scale (i.e., distributed) scenarios in which the codegen optimizer previously picked suboptimal plans.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cb1d7928 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cb1d7928 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cb1d7928 Branch: refs/heads/master Commit: cb1d792826411b144909b2168929c4f33620b02a Parents: 381d1d6 Author: Matthias Boehm <mboe...@gmail.com> Authored: Tue Oct 31 22:25:38 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Wed Nov 1 02:25:34 2017 -0700 ---------------------------------------------------------------------- .../opt/PlanSelectionFuseCostBasedV2.java | 29 ++++++++++++++------ 1 file changed, 21 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/cb1d7928/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java index 1f670b3..9302573 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java +++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java @@ -89,7 +89,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private static final double WRITE_BANDWIDTH_IO = 512*1024*1024; //512MB/s private static final double WRITE_BANDWIDTH_MEM = 2d*1024*1024*1024; //2GB/s private static final double READ_BANDWIDTH_MEM = 32d*1024*1024*1024; //32GB/s - private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH_MEM/4; + private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH_IO/4; private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //1GFLOPs/core * InfrastructureAnalyzer.getLocalParallelism(); @@ 
-329,7 +329,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //get partition input reads (at least read once) for( Long hopID : part.getInputs() ) { Hop hop = memo.getHopRefs().get(hopID); - costs += getSize(hop) * 8 / READ_BANDWIDTH_MEM; + costs += getSafeMemEst(hop) / READ_BANDWIDTH_MEM; } return costs; } @@ -355,6 +355,16 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection .mapToDouble(e -> e.getValue()).sum(); } + private static double sumInputMemoryEstimates(CPlanMemoTable memo, CostVector vect) { + return vect.inSizes.keySet().stream() + .mapToDouble(e -> getSafeMemEst(memo.getHopRefs().get(e))).sum(); + } + + private static double getSafeMemEst(Hop hop) { + return !hop.dimsKnown() ? getSize(hop) * 8 + : hop.getMemEstimate(); + } + private static long getSize(Hop hop) { return Math.max(hop.getDim1(),1) * Math.max(hop.getDim2(),1); @@ -603,7 +613,6 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection private static boolean isRowAggOp(Hop hop){ return (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp - || (hop instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop)) || HopRewriteUtils.isBinary(hop, OpOp2.CBIND)); } @@ -840,19 +849,20 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //add costs for opened fused operator if( opened ) { + double memInputs = sumInputMemoryEstimates(memo, costVect); double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH_MEM - + Math.max(costVect.getInputSize() * 8 / READ_BANDWIDTH_MEM, + + Math.max(memInputs / READ_BANDWIDTH_MEM, costVect.computeCosts/ COMPUTE_BANDWIDTH); //read correction for distributed computation - Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); - if( driver.getMemEstimate() > OptimizerUtils.getLocalMemBudget() ) + if( memInputs > OptimizerUtils.getLocalMemBudget() ) tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST; //sparsity correction for outer-product template (and 
sparse-safe cell) + Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID()); if( best != null && best.type == TemplateType.OUTER ) tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST; //write correction for known evictions in CP - else if( driver.getMemEstimate() < OptimizerUtils.getLocalMemBudget() - && sumTmpInputOutputSize(memo, costVect) > LazyWriteBuffer.getWriteBufferSize() ) + else if( memInputs <= OptimizerUtils.getLocalMemBudget() + && sumTmpInputOutputSize(memo, costVect)*8 > LazyWriteBuffer.getWriteBufferSize() ) tmpCosts += costVect.outSize * 8 / WRITE_BANDWIDTH_IO; costs += tmpCosts; if( LOG.isTraceEnabled() ) { @@ -997,6 +1007,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //outer product template w/ matrix-matrix //or row template w/ matrix-vector or matrix-matrix costs = 2 * current.getInput().get(0).getDim2(); + if( current.getInput().get(0).dimsKnown(true) ) + costs *= current.getInput().get(0).getSparsity(); } else if( current instanceof AggUnaryOp) { switch(((AggUnaryOp)current).getOp()) { @@ -1048,6 +1060,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection //ensures that input sizes are not double counted inSizes.put(hopID, inputSize); } + @SuppressWarnings("unused") public double getInputSize() { return inSizes.values().stream() .mapToDouble(d -> d.doubleValue()).sum();