Repository: systemml Updated Branches: refs/heads/master f74f5ad4b -> e11ae6af3
[SYSTEMML-2479] Improved MNC estimator for element-wise multiply Closes #815. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e11ae6af Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e11ae6af Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e11ae6af Branch: refs/heads/master Commit: e11ae6af3c09678a5ab0241407e552bdfaa897c0 Parents: f74f5ad Author: Johanna Sommer <joha...@mail-sommer.com> Authored: Fri Aug 3 18:46:49 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Aug 3 18:46:50 2018 -0700 ---------------------------------------------------------------------- .../hops/estim/EstimatorMatrixHistogram.java | 354 ++++++++++--------- 1 file changed, 178 insertions(+), 176 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e11ae6af/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java index 60e6af4..270d198 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java @@ -31,184 +31,186 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.SparseBlock; /** - * This estimator implements a remarkably simple yet effective approach for - * incorporating structural properties into sparsity estimation. The key idea is - * to maintain row and column nnz per matrix, along with additional meta data. + * This estimator implements a remarkably simple yet effective + * approach for incorporating structural properties into sparsity + * estimation. The key idea is to maintain row and column nnz per + * matrix, along with additional meta data. */ -public class EstimatorMatrixHistogram extends SparsityEstimator { - // internal configurations +public class EstimatorMatrixHistogram extends SparsityEstimator +{ + //internal configurations private static final boolean DEFAULT_USE_EXCEPTS = true; - + private final boolean _useExcepts; - + public EstimatorMatrixHistogram() { this(DEFAULT_USE_EXCEPTS); } - + public EstimatorMatrixHistogram(boolean useExcepts) { _useExcepts = useExcepts; } - + @Override public double estim(MMNode root) { - // recursive histogram computation of non-leaf nodes - if (!root.getLeft().isLeaf()) - estim(root.getLeft()); // obtain synopsis - if (!root.getRight().isLeaf()) - estim(root.getLeft()); // obtain synopsis - MatrixHistogram h1 = !root.getLeft().isLeaf() ? (MatrixHistogram) root.getLeft().getSynopsis() - : new MatrixHistogram(root.getLeft().getData(), _useExcepts); - MatrixHistogram h2 = !root.getRight().isLeaf() ? (MatrixHistogram) root.getRight().getSynopsis() - : new MatrixHistogram(root.getRight().getData(), _useExcepts); - - // estimate output sparsity based on input histograms + //recursive histogram computation of non-leaf nodes + if( !root.getLeft().isLeaf() ) + estim(root.getLeft()); //obtain synopsis + if( !root.getRight().isLeaf() ) + estim(root.getLeft()); //obtain synopsis + MatrixHistogram h1 = !root.getLeft().isLeaf() ? + (MatrixHistogram)root.getLeft().getSynopsis() : + new MatrixHistogram(root.getLeft().getData(), _useExcepts); + MatrixHistogram h2 = !root.getRight().isLeaf() ? + (MatrixHistogram)root.getRight().getSynopsis() : + new MatrixHistogram(root.getRight().getData(), _useExcepts); + + //estimate output sparsity based on input histograms double ret = estimIntern(h1, h2, OpCode.MM); - - // derive and memoize output histogram + + //derive and memoize output histogram root.setSynopsis(MatrixHistogram.deriveOutputHistogram(h1, h2, ret)); - + return ret; } - - @Override + + @Override public double estim(MatrixBlock m1, MatrixBlock m2) { return estim(m1, m2, OpCode.MM); } - + @Override public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) { MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts); - MatrixHistogram h2 = (m1 == m2) ? // self product - h1 : new MatrixHistogram(m2, _useExcepts); + MatrixHistogram h2 = (m1 == m2) ? //self product + h1 : new MatrixHistogram(m2, _useExcepts); return estimIntern(h1, h2, op); } - + @Override public double estim(MatrixBlock m1, OpCode op) { MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts); return estimIntern(h1, null, op); } - + private double estimIntern(MatrixHistogram h1, MatrixHistogram h2, OpCode op) { - double msize = (double) h1.getRows() * h1.getCols(); - + double msize = (double)h1.getRows()*h1.getCols(); + switch (op) { - case MM: - return estimInternMM(h1, h2); - case MULT: - return Math.min( - IntStream.range(0, h1.getRows()).mapToDouble(i -> h1.rNnz[i] / msize * h2.rNnz[i] / msize).sum(), - IntStream.range(0, h1.getCols()).mapToDouble(i -> h1.cNnz[i] / msize * h2.cNnz[i] / msize).sum()); - case PLUS: - return Math.min( - IntStream.range(0, h1.getRows()) - .mapToDouble(i -> h1.rNnz[i] / msize + h2.rNnz[i] / msize - - h1.rNnz[i] / msize * h2.rNnz[i] / msize) - .sum(), - IntStream.range(0, h1.getCols()).mapToDouble( - i -> h1.cNnz[i] / msize + h2.cNnz[i] / msize - h1.cNnz[i] / msize * h2.cNnz[i] / msize) - .sum()); - case EQZERO: - return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), - (long) h1.getRows() * h1.getCols() - h1.getNonZeros()); - case DIAG: - return (h1.getCols() == 1) ? OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros()) - : OptimizerUtils.getSparsity(h1.getRows(), 1, Math.min(h1.getRows(), h1.getNonZeros())); - // binary operations that preserve sparsity exactly - case CBIND: - return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols() + h2.getCols(), - h1.getNonZeros() + h2.getNonZeros()); - case RBIND: - return OptimizerUtils.getSparsity(h1.getRows() + h2.getRows(), h1.getCols(), - h1.getNonZeros() + h2.getNonZeros()); - // unary operation that preserve sparsity exactly - case NEQZERO: - return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros()); - case TRANS: - return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros()); - case RESHAPE: - return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros()); - default: - throw new NotImplementedException(); + case MM: + return estimInternMM(h1, h2); + case MULT: + final long N1 = h1.getNonZeros(); + final long N2 = h2.getNonZeros(); + final long scale = IntStream.range(0, h1.getCols()) + .mapToLong(j -> (long)h1.cNnz[j] * h2.cNnz[j]).sum(); + return IntStream.range(0, h1.getRows()).mapToLong( + i -> (long)h1.rNnz[i] * h2.rNnz[i] * scale / N1 / N2).sum() / msize; + case PLUS: + return Math.min( + IntStream.range(0, h1.getRows()).mapToDouble(i -> h1.rNnz[i]/msize + + h2.rNnz[i]/msize - h1.rNnz[i]/msize * h2.rNnz[i]/msize).sum(), + IntStream.range(0, h1.getCols()).mapToDouble(i -> h1.cNnz[i]/msize + + h2.cNnz[i]/msize - h1.cNnz[i]/msize * h2.cNnz[i]/msize).sum()); + case EQZERO: + return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), + (long)h1.getRows() * h1.getCols() - h1.getNonZeros()); + case DIAG: + return (h1.getCols()==1) ? + OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros()) : + OptimizerUtils.getSparsity(h1.getRows(), 1, Math.min(h1.getRows(), h1.getNonZeros())); + //binary operations that preserve sparsity exactly + case CBIND: + return OptimizerUtils.getSparsity(h1.getRows(), + h1.getCols()+h2.getCols(), h1.getNonZeros() + h2.getNonZeros()); + case RBIND: + return OptimizerUtils.getSparsity(h1.getRows()+h2.getRows(), + h1.getCols(), h1.getNonZeros() + h2.getNonZeros()); + //unary operation that preserve sparsity exactly + case NEQZERO: + case TRANS: + case RESHAPE: + return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros()); + default: + throw new NotImplementedException(); } } - + private double estimInternMM(MatrixHistogram h1, MatrixHistogram h2) { long nnz = 0; - // special case, with exact sparsity estimate, where the dot product - // dot(h1.cNnz,h2rNnz) gives the exact number of non-zeros in the output - if (h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1) { - for (int j = 0; j < h1.getCols(); j++) + //special case, with exact sparsity estimate, where the dot product + //dot(h1.cNnz,h2rNnz) gives the exact number of non-zeros in the output + if( h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1 ) { + for( int j=0; j<h1.getCols(); j++ ) nnz += h1.cNnz[j] * h2.rNnz[j]; } - // special case, with hybrid exact and approximate output - else if (h1.cNnz1e != null && h2.rNnz1e != null) { - // note: normally h1.getRows()*h2.getCols() would define mnOut - // but by leveraging the knowledge of rows/cols w/ <=1 nnz, we - // account - // that exact and approximate fractions touch different areas - long mnOut = (h1.rNonEmpty - h1.rN1) * (h2.cNonEmpty - h2.cN1); + //special case, with hybrid exact and approximate output + else if(h1.cNnz1e!=null && h2.rNnz1e != null) { + //note: normally h1.getRows()*h2.getCols() would define mnOut + //but by leveraging the knowledge of rows/cols w/ <=1 nnz, we account + //that exact and approximate fractions touch different areas + long mnOut = (h1.rNonEmpty-h1.rN1) * (h2.cNonEmpty-h2.cN1); double spOutRest = 0; - for (int j = 0; j < h1.getCols(); j++) { - // exact fractions, w/o double counting + for( int j=0; j<h1.getCols(); j++ ) { + //exact fractions, w/o double counting nnz += h1.cNnz1e[j] * h2.rNnz[j]; - nnz += (h1.cNnz[j] - h1.cNnz1e[j]) * h2.rNnz1e[j]; - // approximate fraction, w/o double counting - double lsp = (double) (h1.cNnz[j] - h1.cNnz1e[j]) * (h2.rNnz[j] - h2.rNnz1e[j]) / mnOut; - spOutRest = spOutRest + lsp - spOutRest * lsp; + nnz += (h1.cNnz[j]-h1.cNnz1e[j]) * h2.rNnz1e[j]; + //approximate fraction, w/o double counting + double lsp = (double)(h1.cNnz[j]-h1.cNnz1e[j]) + * (h2.rNnz[j]-h2.rNnz1e[j]) / mnOut; + spOutRest = spOutRest + lsp - spOutRest*lsp; } - nnz += (long) (spOutRest * mnOut); + nnz += (long)(spOutRest * mnOut); } - // general case with approximate output + //general case with approximate output else { - long mnOut = h1.getRows() * h2.getCols(); + long mnOut = h1.getRows()*h2.getCols(); double spOut = 0; - for (int j = 0; j < h1.getCols(); j++) { + for( int j=0; j<h1.getCols(); j++ ) { double lsp = (double) h1.cNnz[j] * h2.rNnz[j] / mnOut; - spOut = spOut + lsp - spOut * lsp; + spOut = spOut + lsp - spOut*lsp; } - nnz = (long) (spOut * mnOut); + nnz = (long)(spOut * mnOut); } - - // exploit upper bound on nnz based on non-empty rows/cols - nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ? Math.min((long) h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz; - - // exploit lower bound on nnz based on half-full rows/cols - nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ? Math.max((long) h1.rNdiv2 * h2.cNdiv2, nnz) : nnz; - - // compute final sparsity - return OptimizerUtils.getSparsity(h1.getRows(), h2.getCols(), nnz); + + //exploit upper bound on nnz based on non-empty rows/cols + nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ? + Math.min((long)h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz; + + //exploit lower bound on nnz based on half-full rows/cols + nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ? + Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : nnz; + + //compute final sparsity + return OptimizerUtils.getSparsity( + h1.getRows(), h2.getCols(), nnz); } - + private static class MatrixHistogram { // count vectors (the histogram) - private final int[] rNnz; // nnz per row - private int[] rNnz1e = null; // nnz per row for cols w/ <= 1 non-zeros - private final int[] cNnz; // nnz per col - private int[] cNnz1e = null; // nnz per col for rows w/ <= 1 non-zeros + private final int[] rNnz; //nnz per row + private int[] rNnz1e = null; //nnz per row for cols w/ <= 1 non-zeros + private final int[] cNnz; //nnz per col + private int[] cNnz1e = null; //nnz per col for rows w/ <= 1 non-zeros // additional summary statistics - private final int rMaxNnz, cMaxNnz; // max nnz per row/row - private final int rN1, cN1; // number of rows/cols with nnz=1 - private final int rNonEmpty, cNonEmpty; // number of non-empty rows/cols - // (w/ empty is nnz=0) - private final int rNdiv2, cNdiv2; // number of rows/cols with nnz > - // #cols/2 and #rows/2 - private boolean fullDiag; // true if there exists a full diagonal of - // nonzeros - + private final int rMaxNnz, cMaxNnz; //max nnz per row/row + private final int rN1, cN1; //number of rows/cols with nnz=1 + private final int rNonEmpty, cNonEmpty; //number of non-empty rows/cols (w/ empty is nnz=0) + private final int rNdiv2, cNdiv2; //number of rows/cols with nnz > #cols/2 and #rows/2 + private boolean fullDiag; //true if there exists a full diagonal of nonzeros + public MatrixHistogram(MatrixBlock in, boolean useExcepts) { // 1) allocate basic synopsis rNnz = new int[in.getNumRows()]; cNnz = new int[in.getNumColumns()]; fullDiag = in.getNumRows() == in.getNonZeros(); - + // 2) compute basic synopsis details - if (!in.isEmpty()) { - if (in.isInSparseFormat()) { + if( !in.isEmpty() ) { + if( in.isInSparseFormat() ) { SparseBlock sblock = in.getSparseBlock(); - for (int i = 0; i < in.getNumRows(); i++) { - if (sblock.isEmpty(i)) - continue; + for( int i=0; i<in.getNumRows(); i++ ) { + if( sblock.isEmpty(i) ) continue; int apos = sblock.pos(i); int alen = sblock.size(i); int[] aix = sblock.indexes(i); @@ -216,69 +218,70 @@ public class EstimatorMatrixHistogram extends SparsityEstimator { LibMatrixAgg.countAgg(sblock.values(i), cNnz, aix, apos, alen); fullDiag &= aix[apos] == i; } - } else { + } + else { DenseBlock dblock = in.getDenseBlock(); - for (int i = 0; i < in.getNumRows(); i++) { + for( int i=0; i<in.getNumRows(); i++ ) { double[] avals = dblock.values(i); int lnnz = 0, aix = dblock.pos(i); - for (int j = 0; j < in.getNumColumns(); j++) { - if (avals[aix + j] != 0) { + for( int j=0; j<in.getNumColumns(); j++ ) { + if( avals[aix+j] != 0 ) { fullDiag &= (i == j); - cNnz[j]++; - lnnz++; + cNnz[j] ++; + lnnz ++; } } rNnz[i] = lnnz; } } } - + // 3) compute meta data synopsis rMaxNnz = Arrays.stream(rNnz).max().orElse(0); - cMaxNnz = Arrays.stream(cNnz).max().orElse(0); + cMaxNnz = Arrays.stream(cNnz).max().orElse(0); rN1 = (int) Arrays.stream(rNnz).filter(item -> item == 1).count(); cN1 = (int) Arrays.stream(cNnz).filter(item -> item == 1).count(); - rNonEmpty = (int) Arrays.stream(rNnz).filter(v -> v != 0).count(); - cNonEmpty = (int) Arrays.stream(cNnz).filter(v -> v != 0).count(); - rNdiv2 = (int) Arrays.stream(rNnz).filter(item -> item > getCols() / 2).count(); - cNdiv2 = (int) Arrays.stream(cNnz).filter(item -> item > getRows() / 2).count(); - + rNonEmpty = (int) Arrays.stream(rNnz).filter(v-> v!=0).count(); + cNonEmpty = (int) Arrays.stream(cNnz).filter(v-> v!=0).count(); + rNdiv2 = (int) Arrays.stream(rNnz).filter(item -> item > getCols()/2).count(); + cNdiv2 = (int) Arrays.stream(cNnz).filter(item -> item > getRows()/2).count(); + // 4) compute exception details if necessary (optional) - if (useExcepts & !in.isEmpty() && (rMaxNnz > 1 || cMaxNnz > 1)) { + if( useExcepts & !in.isEmpty() && (rMaxNnz > 1 || cMaxNnz > 1) ) { rNnz1e = new int[in.getNumRows()]; cNnz1e = new int[in.getNumColumns()]; - - if (in.isInSparseFormat()) { + + if( in.isInSparseFormat() ) { SparseBlock sblock = in.getSparseBlock(); - for (int i = 0; i < in.getNumRows(); i++) { - if (sblock.isEmpty(i)) - continue; + for( int i=0; i<in.getNumRows(); i++ ) { + if( sblock.isEmpty(i) ) continue; int alen = sblock.size(i); int apos = sblock.pos(i); int[] aix = sblock.indexes(i); - for (int k = apos; k < apos + alen; k++) + for( int k=apos; k<apos+alen; k++ ) rNnz1e[i] += cNnz[aix[k]] <= 1 ? 1 : 0; - if (alen <= 1) - for (int k = apos; k < apos + alen; k++) + if( alen <= 1 ) + for( int k=apos; k<apos+alen; k++ ) cNnz1e[aix[k]]++; } - } else { + } + else { DenseBlock dblock = in.getDenseBlock(); - for (int i = 0; i < in.getNumRows(); i++) { + for( int i=0; i<in.getNumRows(); i++ ) { double[] avals = dblock.values(i); int aix = dblock.pos(i); boolean rNnzlte1 = rNnz[i] <= 1; - for (int j = 0; j < in.getNumColumns(); j++) { - if (avals[aix + j] != 0) { + for( int j=0; j<in.getNumColumns(); j++ ) { + if( avals[aix + j] != 0 ) { rNnz1e[i] += cNnz[j] <= 1 ? 1 : 0; - cNnz1e[j] += rNnzlte1 ? 1 : 0; + cNnz1e[j] += rNnzlte1 ? 1 : 0; } } } } } } - + public MatrixHistogram(int[] r, int[] r1e, int[] c, int[] c1e, int rmax, int cmax) { rNnz = r; rNnz1e = r1e; @@ -290,56 +293,55 @@ public class EstimatorMatrixHistogram extends SparsityEstimator { rNonEmpty = cNonEmpty = -1; rNdiv2 = cNdiv2 = -1; } - + public int getRows() { return rNnz.length; } - + public int getCols() { return cNnz.length; } - + public long getNonZeros() { - return getRows() < getCols() ? IntStream.range(0, getRows()).mapToLong(i -> rNnz[i]).sum() - : IntStream.range(0, getRows()).mapToLong(i -> cNnz[i]).sum(); + return getRows() < getCols() ? + IntStream.range(0, getRows()).mapToLong(i-> rNnz[i]).sum() : + IntStream.range(0, getRows()).mapToLong(i-> cNnz[i]).sum(); } - + public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) { - // exact propagation if lhs or rhs full diag - if (h1.fullDiag) - return h2; - if (h2.fullDiag) - return h1; - - // get input/output nnz for scaling + //exact propagation if lhs or rhs full diag + if( h1.fullDiag ) return h2; + if( h2.fullDiag ) return h1; + + //get input/output nnz for scaling long nnz1 = Arrays.stream(h1.rNnz).sum(); long nnz2 = Arrays.stream(h2.cNnz).sum(); double nnzOut = spOut * h1.getRows() * h2.getCols(); - - // propagate h1.r and h2.c to output via simple scaling - // (this implies 0s propagate and distribution is preserved) + + //propagate h1.r and h2.c to output via simple scaling + //(this implies 0s propagate and distribution is preserved) int rMaxNnz = 0, cMaxNnz = 0; int[] rNnz = new int[h1.getRows()]; Random rn = new Random(); - for (int i = 0; i < h1.getRows(); i++) { - rNnz[i] = probRound(nnzOut / nnz1 * h1.rNnz[i], rn); + for( int i=0; i<h1.getRows(); i++ ) { + rNnz[i] = probRound(nnzOut/nnz1 * h1.rNnz[i], rn); rMaxNnz = Math.max(rMaxNnz, rNnz[i]); } int[] cNnz = new int[h2.getCols()]; - for (int i = 0; i < h2.getCols(); i++) { - cNnz[i] = probRound(nnzOut / nnz2 * h2.cNnz[i], rn); + for( int i=0; i<h2.getCols(); i++ ) { + cNnz[i] = probRound(nnzOut/nnz2 * h2.cNnz[i], rn); cMaxNnz = Math.max(cMaxNnz, cNnz[i]); } - - // construct new histogram object + + //construct new histogram object return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); } - + private static int probRound(double inNnz, Random rand) { double temp = Math.floor(inNnz); - double f = inNnz - temp; // non-int fraction [0,1) - double randf = rand.nextDouble(); // uniform [0,1) - return (int) ((f > randf) ? temp + 1 : temp); + double f = inNnz - temp; //non-int fraction [0,1) + double randf = rand.nextDouble(); //uniform [0,1) + return (int)((f > randf) ? temp+1 : temp); } } }