Repository: systemml Updated Branches: refs/heads/master cca6356f8 -> 0a957e4c9
[SYSTEMML-2468] Improved MNC estimator (avoid final sketch propagation) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/569806dc Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/569806dc Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/569806dc Branch: refs/heads/master Commit: 569806dcdf3c37bff59ad052884cf5f9af9bd598 Parents: cca6356 Author: Matthias Boehm <[email protected]> Authored: Sun Oct 21 18:43:32 2018 +0200 Committer: Matthias Boehm <[email protected]> Committed: Sun Oct 21 18:43:32 2018 +0200 ---------------------------------------------------------------------- .../hops/estim/EstimatorMatrixHistogram.java | 61 +++++++++++++++++--- 1 file changed, 52 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/569806dc/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java index 83d918a..34f9cac 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java @@ -55,11 +55,15 @@ public class EstimatorMatrixHistogram extends SparsityEstimator @Override public MatrixCharacteristics estim(MMNode root) { + return estim(root, true); + } + + private MatrixCharacteristics estim(MMNode root, boolean topLevel) { //recursive histogram computation of non-leaf nodes if( !root.getLeft().isLeaf() ) - estim(root.getLeft()); //obtain synopsis + estim(root.getLeft(), false); //obtain synopsis if( root.getRight()!=null && !root.getRight().isLeaf() ) - estim(root.getRight()); //obtain synopsis + estim(root.getRight(), false); //obtain synopsis 
MatrixHistogram h1 = !root.getLeft().isLeaf() ? (MatrixHistogram)root.getLeft().getSynopsis() : new MatrixHistogram(root.getLeft().getData(), _useExcepts); @@ -69,6 +73,12 @@ public class EstimatorMatrixHistogram extends SparsityEstimator //estimate output sparsity based on input histograms double ret = estimIntern(h1, h2, root.getOp(), root.getMisc()); + if( topLevel ) { //fast-path final result + return MatrixHistogram.deriveOutputCharacteristics( + h1, h2, ret, root.getOp(), root.getMisc()); + } + + //sketch propagation for intermediates other than final result MatrixHistogram outMap = MatrixHistogram .deriveOutputHistogram(h1, h2, ret, root.getOp(), root.getMisc()); root.setSynopsis(outMap); @@ -183,13 +193,15 @@ public class EstimatorMatrixHistogram extends SparsityEstimator nnz = (long)(spOut * mnOut); } - //exploit upper bound on nnz based on non-empty rows/cols - nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ? - Math.min((long)h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz; - - //exploit lower bound on nnz based on half-full rows/cols - nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ? - Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : nnz; + if( _useExcepts ) { + //exploit upper bound on nnz based on non-empty rows/cols + nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ? + Math.min((long)h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz; + + //exploit lower bound on nnz based on half-full rows/cols + nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ? 
+ Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : nnz; + } //compute final sparsity return OptimizerUtils.getSparsity( @@ -343,6 +355,37 @@ public class EstimatorMatrixHistogram extends SparsityEstimator } } + public static MatrixCharacteristics deriveOutputCharacteristics(MatrixHistogram h1, MatrixHistogram h2, double spOut, OpCode op, long[] misc) { + switch(op) { + case MM: + return new MatrixCharacteristics(h1.getRows(), h2.getCols(), + OptimizerUtils.getNnz(h1.getRows(), h2.getCols(), spOut)); + case MULT: + case PLUS: + case NEQZERO: + case EQZERO: + return new MatrixCharacteristics(h1.getRows(), h1.getCols(), + OptimizerUtils.getNnz(h1.getRows(), h1.getCols(), spOut)); + case RBIND: //output rows = rows(h1) + rows(h2), consistent with nnz below + return new MatrixCharacteristics(h1.getRows()+h2.getRows(), h1.getCols(), + OptimizerUtils.getNnz(h1.getRows()+h2.getRows(), h1.getCols(), spOut)); + case CBIND: + return new MatrixCharacteristics(h1.getRows(), h1.getCols()+h2.getCols(), + OptimizerUtils.getNnz(h1.getRows(), h1.getCols()+h2.getCols(), spOut)); + case DIAG: + int ncol = h1.getCols()==1 ? h1.getRows() : 1; + return new MatrixCharacteristics(h1.getRows(), ncol, + OptimizerUtils.getNnz(h1.getRows(), ncol, spOut)); + case TRANS: + return new MatrixCharacteristics(h1.getCols(), h1.getRows(), h1.getNonZeros()); + case RESHAPE: + return new MatrixCharacteristics((int)misc[0], (int)misc[1], + OptimizerUtils.getNnz((int)misc[0], (int)misc[1], spOut)); + default: + throw new NotImplementedException(); + } + } + private static MatrixHistogram deriveMMHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) { //exact propagation if lhs or rhs full diag if( h1.fullDiag ) return h2;
