[SYSTEMML-2243] Improved spark mapmm broadcast selection (worstcase nnz) This patch improves the operator selection logic for spark mapmm (broadcast-based matrix multiply) operations. For equal input dimensions and unknown nnz, we now also take the worst-case nnz estimates into account, via the input memory estimates. This is important for outer-product-like matrix multiplications with sparse inputs that are extracted with indexing operations in the same DAG and hence have unknown nnz metadata.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/61925ab4 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/61925ab4 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/61925ab4 Branch: refs/heads/master Commit: 61925ab4912af57cdcbf7f293e91340f3127bf35 Parents: 5d149a0 Author: Matthias Boehm <[email protected]> Authored: Sat Apr 14 01:58:53 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Apr 14 01:58:53 2018 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/AggBinaryOp.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/61925ab4/src/main/java/org/apache/sysml/hops/AggBinaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java index 03a1bb6..9a286d2 100644 --- a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java @@ -223,7 +223,7 @@ public class AggBinaryOp extends Hop implements MultiThreadedHop //matrix mult operation selection part 3 (SPARK type) boolean tmmRewrite = HopRewriteUtils.isTransposeOperation(input1); _method = optFindMMultMethodSpark ( - input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), + input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput, tmmRewrite ); @@ -1687,7 +1687,10 @@ public class AggBinaryOp extends Hop implements MultiThreadedHop //apply map mult if one side fits in remote task memory //(if so 
pick smaller input for distributed cache) //TODO relax requirement of valid CP dimensions once we support broadcast creation from files/RDDs - if( m1SizeP < m2SizeP && m1_rows>=0 && m1_cols>=0 + double em1Size = getInput().get(0).getOutputMemEstimate(); //w/ worst-case estimate + double em2Size = getInput().get(1).getOutputMemEstimate(); //w/ worst-case estimate + if( (m1SizeP < m2SizeP || (m1SizeP==m2SizeP && em1Size<em2Size) ) + && m1_rows>=0 && m1_cols>=0 && OptimizerUtils.isValidCPDimensions(m1_rows, m1_cols) ) { _spBroadcastMemEstimate = m1Size+m1SizeP; return MMultMethod.MAPMM_L;
