Repository: systemml
Updated Branches:
  refs/heads/master c98e81581 -> f74f5ad4b
[SYSTEMML-2479] Extended AVG estimator for other operations

Closes #813.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f74f5ad4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f74f5ad4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f74f5ad4

Branch: refs/heads/master
Commit: f74f5ad4bf27606a5cb5e27a16eceb65c0bd5f62
Parents: c98e815
Author: Johanna Sommer <[email protected]>
Authored: Fri Aug 3 17:41:22 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Aug 3 17:49:21 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorBasicAvg.java      |  54 ++-
 .../hops/estim/EstimatorMatrixHistogram.java     | 351 ++++++++++---------
 2 files changed, 220 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f74f5ad4/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
index 259ab43..baa1fc4 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
@@ -27,35 +27,69 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
  * Basic average case estimator for matrix sparsity:
  * sp = 1 - Math.pow(1-sp1*sp2, k)
  */
-public class EstimatorBasicAvg extends SparsityEstimator
-{
+public class EstimatorBasicAvg extends SparsityEstimator {
 	@Override
 	public double estim(MMNode root) {
-		//recursive sparsity evaluation of non-leaf nodes
+		// recursive sparsity evaluation of non-leaf nodes
 		double sp1 = !root.getLeft().isLeaf() ? estim(root.getLeft()) :
 			OptimizerUtils.getSparsity(root.getLeft().getMatrixCharacteristics());
 		double sp2 = !root.getRight().isLeaf() ? estim(root.getRight()) :
 			OptimizerUtils.getSparsity(root.getRight().getMatrixCharacteristics());
-		return estimIntern(sp1, sp2, root.getRows(), root.getLeft().getCols(), root.getCols());
+		return estimInternMM(sp1, sp2, root.getRows(), root.getLeft().getCols(), root.getCols());
 	}
 
 	@Override
 	public double estim(MatrixBlock m1, MatrixBlock m2) {
-		return estimIntern(m1.getSparsity(), m2.getSparsity(),
+		return estimInternMM(m1.getSparsity(), m2.getSparsity(),
 			m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns());
 	}
 
 	@Override
 	public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
-		throw new NotImplementedException();
+		return estimIntern(m1, m2, op);
 	}
-	
+
 	@Override
 	public double estim(MatrixBlock m, OpCode op) {
-		throw new NotImplementedException();
+		return estimIntern(m, null, op);
+	}
+
+	private double estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+		switch (op) {
+			case MM:
+				return estimInternMM(m1.getSparsity(), m2.getSparsity(),
+					m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns());
+			case MULT:
+				return m1.getSparsity() * m2.getSparsity();
+			case PLUS:
+				return m1.getSparsity() + m2.getSparsity() - m1.getSparsity() * m2.getSparsity();
+			case EQZERO:
+				return OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumColumns(),
+					(long) m1.getNumRows() * m1.getNumColumns() - m1.getNonZeros());
+			case DIAG:
+				return (m1.getNumColumns() == 1) ?
+					OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumRows(), m1.getNonZeros()) :
+					OptimizerUtils.getSparsity(m1.getNumRows(), 1, Math.min(m1.getNumRows(), m1.getNonZeros()));
+			// binary operations that preserve sparsity exactly
+			case CBIND:
+				return OptimizerUtils.getSparsity(m1.getNumRows(),
+					m1.getNumColumns() + m1.getNumColumns(), m1.getNonZeros() + m2.getNonZeros());
+			case RBIND:
+				return OptimizerUtils.getSparsity(m1.getNumRows() + m2.getNumRows(),
+					m1.getNumColumns(), m1.getNonZeros() + m2.getNonZeros());
+			// unary operation that preserve sparsity exactly
+			case NEQZERO:
+				return m1.getSparsity();
+			case TRANS:
+				return m1.getSparsity();
+			case RESHAPE:
+				return m1.getSparsity();
+			default:
+				throw new NotImplementedException();
+		}
 	}
-	
-	private double estimIntern(double sp1, double sp2, long m, long k, long n) {
+
+	private double estimInternMM(double sp1, double sp2, long m, long k, long n) {
 		return OptimizerUtils.getMatMultSparsity(sp1, sp2, m, k, n, false);
 	}
 }
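As a quick reference for the closed-form estimates this patch wires up, the following standalone Java sketch (illustrative only; it is not part of the patch, does not use the SystemML API, and all class/method names are made up) reproduces the average-case formulas from the class comment and the new switch: the matrix-multiply formula sp = 1 - Math.pow(1 - sp1*sp2, k), the element-wise multiply/plus estimates, and the nnz bookkeeping behind cbind. The sketch sizes a cbind result with both inputs' column counts, whereas the CBIND case above adds m1's column count twice, which appears unintended.

// Standalone sketch of the average-case estimates above (plain Java, illustrative only).
public class AvgSparsitySketch {
	// sparsity helper: nnz over total cells, capped at 1.0 (assumed behavior, not the SystemML helper)
	static double sparsity(long rows, long cols, long nnz) {
		return Math.min(1.0, (double) nnz / rows / cols);
	}
	// matrix multiply (m x k) %*% (k x n): sp = 1 - Math.pow(1 - sp1*sp2, k)
	static double estimMM(double sp1, double sp2, long k) {
		return 1 - Math.pow(1 - sp1 * sp2, k);
	}
	// element-wise multiply: a cell is nonzero only if both inputs are nonzero
	static double estimMult(double sp1, double sp2) {
		return sp1 * sp2;
	}
	// element-wise plus: nonzero if either input is nonzero (inclusion-exclusion)
	static double estimPlus(double sp1, double sp2) {
		return sp1 + sp2 - sp1 * sp2;
	}
	// cbind: nonzeros add up, columns add up (uses both inputs' column counts)
	static double estimCbind(long rows, long cols1, long cols2, long nnz1, long nnz2) {
		return sparsity(rows, cols1 + cols2, nnz1 + nnz2);
	}
	public static void main(String[] args) {
		System.out.println(estimMM(0.1, 0.1, 1000));           // ~0.99996
		System.out.println(estimMult(0.1, 0.2));               // ~0.02
		System.out.println(estimPlus(0.1, 0.2));               // 0.28
		System.out.println(estimCbind(100, 50, 50, 500, 300)); // 0.08
	}
}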

http://git-wip-us.apache.org/repos/asf/systemml/blob/f74f5ad4/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
index 209e2f0..60e6af4 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
@@ -31,183 +31,184 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
 
 /**
- * This estimator implements a remarkably simple yet effective
- * approach for incorporating structural properties into sparsity
- * estimation. The key idea is to maintain row and column nnz per
- * matrix, along with additional meta data.
+ * This estimator implements a remarkably simple yet effective approach for
+ * incorporating structural properties into sparsity estimation. The key idea is
+ * to maintain row and column nnz per matrix, along with additional meta data.
 */
-public class EstimatorMatrixHistogram extends SparsityEstimator
-{
-	//internal configurations
+public class EstimatorMatrixHistogram extends SparsityEstimator {
+	// internal configurations
 	private static final boolean DEFAULT_USE_EXCEPTS = true;
-	
+
 	private final boolean _useExcepts;
-	
+
 	public EstimatorMatrixHistogram() {
 		this(DEFAULT_USE_EXCEPTS);
 	}
-	
+
 	public EstimatorMatrixHistogram(boolean useExcepts) {
 		_useExcepts = useExcepts;
 	}
-	
+
 	@Override
 	public double estim(MMNode root) {
-		//recursive histogram computation of non-leaf nodes
-		if( !root.getLeft().isLeaf() )
-			estim(root.getLeft()); //obtain synopsis
-		if( !root.getRight().isLeaf() )
-			estim(root.getLeft()); //obtain synopsis
-		MatrixHistogram h1 = !root.getLeft().isLeaf() ?
-			(MatrixHistogram)root.getLeft().getSynopsis() :
-			new MatrixHistogram(root.getLeft().getData(), _useExcepts);
-		MatrixHistogram h2 = !root.getRight().isLeaf() ?
-			(MatrixHistogram)root.getRight().getSynopsis() :
-			new MatrixHistogram(root.getRight().getData(), _useExcepts);
-		
-		//estimate output sparsity based on input histograms
+		// recursive histogram computation of non-leaf nodes
+		if (!root.getLeft().isLeaf())
+			estim(root.getLeft()); // obtain synopsis
+		if (!root.getRight().isLeaf())
+			estim(root.getLeft()); // obtain synopsis
+		MatrixHistogram h1 = !root.getLeft().isLeaf() ? (MatrixHistogram) root.getLeft().getSynopsis()
+				: new MatrixHistogram(root.getLeft().getData(), _useExcepts);
+		MatrixHistogram h2 = !root.getRight().isLeaf() ? (MatrixHistogram) root.getRight().getSynopsis()
+				: new MatrixHistogram(root.getRight().getData(), _useExcepts);
+
+		// estimate output sparsity based on input histograms
 		double ret = estimIntern(h1, h2, OpCode.MM);
-		
-		//derive and memoize output histogram
+
+		// derive and memoize output histogram
 		root.setSynopsis(MatrixHistogram.deriveOutputHistogram(h1, h2, ret));
-		
+
 		return ret;
 	}
-	
-	@Override 
+
+	@Override
 	public double estim(MatrixBlock m1, MatrixBlock m2) {
 		return estim(m1, m2, OpCode.MM);
 	}
-	
+
 	@Override
 	public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
 		MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts);
-		MatrixHistogram h2 = (m1 == m2) ? //self product
-			h1 : new MatrixHistogram(m2, _useExcepts);
+		MatrixHistogram h2 = (m1 == m2) ? // self product
+				h1 : new MatrixHistogram(m2, _useExcepts);
 		return estimIntern(h1, h2, op);
 	}
-	
+
 	@Override
 	public double estim(MatrixBlock m1, OpCode op) {
 		MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts);
 		return estimIntern(h1, null, op);
 	}
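All estim() overloads above reduce to building a per-matrix synopsis of nonzero counts and then combining the two synopses. As a minimal illustration of what such a synopsis contains, this standalone sketch (plain Java, not the SystemML MatrixHistogram class) counts nonzeros per row and per column of a small dense array; in the estimator itself these counts are taken from the MatrixBlock's sparse or dense blocks, as in the MatrixHistogram constructor further down.

import java.util.Arrays;

// Minimal sketch of a row/column nnz synopsis, built from a dense array (illustrative only).
public class NnzCountsSketch {
	public static void main(String[] args) {
		double[][] a = { {1, 0, 0}, {0, 2, 3}, {0, 0, 4} };
		int[] rNnz = new int[a.length];    // nnz per row
		int[] cNnz = new int[a[0].length]; // nnz per column
		for (int i = 0; i < a.length; i++)
			for (int j = 0; j < a[i].length; j++)
				if (a[i][j] != 0) { rNnz[i]++; cNnz[j]++; }
		System.out.println(Arrays.toString(rNnz)); // [1, 2, 1]
		System.out.println(Arrays.toString(cNnz)); // [1, 1, 2]
	}
}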
-	
+
 	private double estimIntern(MatrixHistogram h1, MatrixHistogram h2, OpCode op) {
-		double msize = (double)h1.getRows()*h1.getCols();
-		
+		double msize = (double) h1.getRows() * h1.getCols();
+
 		switch (op) {
-			case MM:
-				return estimInternMM(h1, h2);
-			case MULT:
-				return Math.min(
-					IntStream.range(0, h1.getRows()).mapToDouble(i -> (double)h1.rNnz[i]/msize * (double)h2.rNnz[i]/msize).sum(),
-					IntStream.range(0, h1.getCols()).mapToDouble(i -> (double)h1.cNnz[i]/msize * (double)h2.cNnz[i]/msize).sum());
-			case PLUS:
-				return Math.min(
-					IntStream.range(0, h1.getRows()).mapToDouble(i -> (double)h1.rNnz[i]/msize
-						+ (double)h2.rNnz[i]/msize - (double)h1.rNnz[i]/msize * (double)h2.rNnz[i]/msize).sum(),
-					IntStream.range(0, h1.getCols()).mapToDouble(i -> (double)h1.cNnz[i]/msize
-						+ (double)h2.cNnz[i]/msize - (double)h1.cNnz[i]/msize * (double)h2.cNnz[i]/msize).sum());
-			case EQZERO:
-				return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(),
-					(long)h1.getRows() * h1.getCols() - h1.getNonZeros());
-			case DIAG:
-				return (h1.getCols()==1) ?
-					OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros()) :
-					OptimizerUtils.getSparsity(h1.getRows(), 1, Math.min(h1.getRows(), h1.getNonZeros()));
-			//binary operations that preserve sparsity exactly
-			case CBIND:
-				return OptimizerUtils.getSparsity(h1.getRows(),
-					h1.getCols()+h2.getCols(), h1.getNonZeros() + h2.getNonZeros());
-			case RBIND:
-				return OptimizerUtils.getSparsity(h1.getRows()+h2.getRows(),
-					h1.getCols(), h1.getNonZeros() + h2.getNonZeros());
-			//unary operation that preserve sparsity exactly
-			case NEQZERO:
-			case TRANS:
-			case RESHAPE:
-				return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
-			default:
-				throw new NotImplementedException();
+		case MM:
+			return estimInternMM(h1, h2);
+		case MULT:
+			return Math.min(
+					IntStream.range(0, h1.getRows()).mapToDouble(i -> h1.rNnz[i] / msize * h2.rNnz[i] / msize).sum(),
+					IntStream.range(0, h1.getCols()).mapToDouble(i -> h1.cNnz[i] / msize * h2.cNnz[i] / msize).sum());
+		case PLUS:
+			return Math.min(
+					IntStream.range(0, h1.getRows())
+							.mapToDouble(i -> h1.rNnz[i] / msize + h2.rNnz[i] / msize
+									- h1.rNnz[i] / msize * h2.rNnz[i] / msize)
+							.sum(),
+					IntStream.range(0, h1.getCols()).mapToDouble(
+							i -> h1.cNnz[i] / msize + h2.cNnz[i] / msize - h1.cNnz[i] / msize * h2.cNnz[i] / msize)
+							.sum());
+		case EQZERO:
+			return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(),
+					(long) h1.getRows() * h1.getCols() - h1.getNonZeros());
+		case DIAG:
+			return (h1.getCols() == 1) ? OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros())
+					: OptimizerUtils.getSparsity(h1.getRows(), 1, Math.min(h1.getRows(), h1.getNonZeros()));
+		// binary operations that preserve sparsity exactly
+		case CBIND:
+			return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols() + h2.getCols(),
+					h1.getNonZeros() + h2.getNonZeros());
+		case RBIND:
+			return OptimizerUtils.getSparsity(h1.getRows() + h2.getRows(), h1.getCols(),
+					h1.getNonZeros() + h2.getNonZeros());
+		// unary operation that preserve sparsity exactly
+		case NEQZERO:
+			return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
+		case TRANS:
+			return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
+		case RESHAPE:
+			return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
+		default:
+			throw new NotImplementedException();
 		}
 	}
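Several cases of the switch above only re-label an exactly known nonzero count over a new shape, so the estimates follow from plain nnz bookkeeping. A small sketch, assuming a sparsity helper of the form nnz / (rows * cols); the exact behavior of OptimizerUtils.getSparsity is not shown in this patch, so the helper here is an assumption.

// Sketch of the exact-sparsity cases above, assuming sparsity = nnz / (rows * cols) capped at 1.
public class ExactOpsSketch {
	static double sparsity(long rows, long cols, long nnz) {
		return Math.min(1.0, (double) nnz / rows / cols);
	}
	public static void main(String[] args) {
		long rows = 1000, cols = 100, nnz = 5000;
		// EQZERO: zeros become nonzeros, so the count flips to total - nnz
		System.out.println(sparsity(rows, cols, rows * cols - nnz)); // 0.95
		// TRANS / RESHAPE / NEQZERO: same nnz, same number of cells -> sparsity unchanged
		System.out.println(sparsity(cols, rows, nnz));               // 0.05
		// DIAG of a column vector (rows x 1): same nnz spread over a rows x rows matrix
		System.out.println(sparsity(rows, rows, 800));               // 8.0E-4
		// RBIND of two blocks with the same columns: rows and nnz add up
		System.out.println(sparsity(rows + rows, cols, nnz + 2000)); // 0.035
	}
}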
-	
+
 	private double estimInternMM(MatrixHistogram h1, MatrixHistogram h2) {
 		long nnz = 0;
-		//special case, with exact sparsity estimate, where the dot product
-		//dot(h1.cNnz,h2rNnz) gives the exact number of non-zeros in the output
-		if( h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1 ) {
-			for( int j=0; j<h1.getCols(); j++ )
+		// special case, with exact sparsity estimate, where the dot product
+		// dot(h1.cNnz,h2rNnz) gives the exact number of non-zeros in the output
+		if (h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1) {
+			for (int j = 0; j < h1.getCols(); j++)
 				nnz += h1.cNnz[j] * h2.rNnz[j];
 		}
-		//special case, with hybrid exact and approximate output
-		else if(h1.cNnz1e!=null && h2.rNnz1e != null) {
-			//note: normally h1.getRows()*h2.getCols() would define mnOut
-			//but by leveraging the knowledge of rows/cols w/ <=1 nnz, we account
-			//that exact and approximate fractions touch different areas
-			long mnOut = (h1.rNonEmpty-h1.rN1) * (h2.cNonEmpty-h2.cN1);
+		// special case, with hybrid exact and approximate output
+		else if (h1.cNnz1e != null && h2.rNnz1e != null) {
+			// note: normally h1.getRows()*h2.getCols() would define mnOut
+			// but by leveraging the knowledge of rows/cols w/ <=1 nnz, we
+			// account
+			// that exact and approximate fractions touch different areas
+			long mnOut = (h1.rNonEmpty - h1.rN1) * (h2.cNonEmpty - h2.cN1);
 			double spOutRest = 0;
-			for( int j=0; j<h1.getCols(); j++ ) {
-				//exact fractions, w/o double counting
+			for (int j = 0; j < h1.getCols(); j++) {
+				// exact fractions, w/o double counting
 				nnz += h1.cNnz1e[j] * h2.rNnz[j];
-				nnz += (h1.cNnz[j]-h1.cNnz1e[j]) * h2.rNnz1e[j];
-				//approximate fraction, w/o double counting
-				double lsp = (double)(h1.cNnz[j]-h1.cNnz1e[j])
-					* (h2.rNnz[j]-h2.rNnz1e[j]) / mnOut;
-				spOutRest = spOutRest + lsp - spOutRest*lsp;
+				nnz += (h1.cNnz[j] - h1.cNnz1e[j]) * h2.rNnz1e[j];
+				// approximate fraction, w/o double counting
+				double lsp = (double) (h1.cNnz[j] - h1.cNnz1e[j]) * (h2.rNnz[j] - h2.rNnz1e[j]) / mnOut;
+				spOutRest = spOutRest + lsp - spOutRest * lsp;
 			}
-			nnz += (long)(spOutRest * mnOut);
+			nnz += (long) (spOutRest * mnOut);
 		}
-		//general case with approximate output
+		// general case with approximate output
 		else {
-			long mnOut = h1.getRows()*h2.getCols();
+			long mnOut = h1.getRows() * h2.getCols();
 			double spOut = 0;
-			for( int j=0; j<h1.getCols(); j++ ) {
+			for (int j = 0; j < h1.getCols(); j++) {
 				double lsp = (double) h1.cNnz[j] * h2.rNnz[j] / mnOut;
-				spOut = spOut + lsp - spOut*lsp;
+				spOut = spOut + lsp - spOut * lsp;
 			}
-			nnz = (long)(spOut * mnOut);
+			nnz = (long) (spOut * mnOut);
 		}
-		
-		//exploit upper bound on nnz based on non-empty rows/cols
-		nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ?
-			Math.min((long)h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz;
-		
-		//exploit lower bound on nnz based on half-full rows/cols
-		nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ?
-			Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : nnz;
-		
-		//compute final sparsity
-		return OptimizerUtils.getSparsity(
-			h1.getRows(), h2.getCols(), nnz);
+
+		// exploit upper bound on nnz based on non-empty rows/cols
+		nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ? Math.min((long) h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz;
+
+		// exploit lower bound on nnz based on half-full rows/cols
+		nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ? Math.max((long) h1.rNdiv2 * h2.cNdiv2, nnz) : nnz;
+
+		// compute final sparsity
+		return OptimizerUtils.getSparsity(h1.getRows(), h2.getCols(), nnz);
 	}
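estimInternMM above distinguishes an exact regime, where the dot product of h1's column counts and h2's row counts gives the output nonzero count directly (applicable when every row of the left input, or every column of the right input, has at most one nonzero), from an approximate regime that combines the per-index contributions as if they were independent, followed by upper/lower bounds from non-empty and half-full rows/cols. The following standalone sketch (illustrative inputs, not SystemML code) runs both computations on small count vectors.

// Sketch of the two estimInternMM regimes above, on plain count vectors (illustrative only).
public class MMEstimateSketch {
	public static void main(String[] args) {
		int m = 4, n = 4;            // output is m x n
		int[] h1cNnz = {2, 1, 0, 3}; // nnz per column of the left input
		int[] h2rNnz = {1, 2, 2, 0}; // nnz per row of the right input

		// exact regime: each pair (col j of lhs, row j of rhs) contributes cNnz[j]*rNnz[j] outputs
		long exactNnz = 0;
		for (int j = 0; j < h1cNnz.length; j++)
			exactNnz += (long) h1cNnz[j] * h2rNnz[j];
		System.out.println("exact nnz: " + exactNnz); // 2*1 + 1*2 + 0*2 + 3*0 = 4

		// approximate regime: combine per-j densities as independent events
		long mnOut = (long) m * n;
		double spOut = 0;
		for (int j = 0; j < h1cNnz.length; j++) {
			double lsp = (double) h1cNnz[j] * h2rNnz[j] / mnOut;
			spOut = spOut + lsp - spOut * lsp; // P(A or B) = P(A) + P(B) - P(A)P(B)
		}
		long nnz = (long) (spOut * mnOut);
		System.out.println("approx sparsity: " + spOut + ", nnz: " + nnz); // 0.234375, 3
	}
}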
-	
+
 	private static class MatrixHistogram {
 		// count vectors (the histogram)
-		private final int[] rNnz;    //nnz per row
-		private int[] rNnz1e = null; //nnz per row for cols w/ <= 1 non-zeros
-		private final int[] cNnz;    //nnz per col
-		private int[] cNnz1e = null; //nnz per col for rows w/ <= 1 non-zeros
+		private final int[] rNnz; // nnz per row
+		private int[] rNnz1e = null; // nnz per row for cols w/ <= 1 non-zeros
+		private final int[] cNnz; // nnz per col
+		private int[] cNnz1e = null; // nnz per col for rows w/ <= 1 non-zeros
 		// additional summary statistics
-		private final int rMaxNnz, cMaxNnz; //max nnz per row/row
-		private final int rN1, cN1; //number of rows/cols with nnz=1
-		private final int rNonEmpty, cNonEmpty; //number of non-empty rows/cols (w/ empty is nnz=0)
-		private final int rNdiv2, cNdiv2; //number of rows/cols with nnz > #cols/2 and #rows/2
-		private boolean fullDiag; //true if there exists a full diagonal of nonzeros
-		
+		private final int rMaxNnz, cMaxNnz; // max nnz per row/row
+		private final int rN1, cN1; // number of rows/cols with nnz=1
+		private final int rNonEmpty, cNonEmpty; // number of non-empty rows/cols
+												// (w/ empty is nnz=0)
+		private final int rNdiv2, cNdiv2; // number of rows/cols with nnz >
+											// #cols/2 and #rows/2
+		private boolean fullDiag; // true if there exists a full diagonal of
+									// nonzeros
+
 		public MatrixHistogram(MatrixBlock in, boolean useExcepts) {
 			// 1) allocate basic synopsis
 			rNnz = new int[in.getNumRows()];
 			cNnz = new int[in.getNumColumns()];
 			fullDiag = in.getNumRows() == in.getNonZeros();
-			
+
 			// 2) compute basic synopsis details
-			if( !in.isEmpty() ) {
-				if( in.isInSparseFormat() ) {
+			if (!in.isEmpty()) {
+				if (in.isInSparseFormat()) {
 					SparseBlock sblock = in.getSparseBlock();
-					for( int i=0; i<in.getNumRows(); i++ ) {
-						if( sblock.isEmpty(i) ) continue;
+					for (int i = 0; i < in.getNumRows(); i++) {
+						if (sblock.isEmpty(i))
+							continue;
 						int apos = sblock.pos(i);
 						int alen = sblock.size(i);
 						int[] aix = sblock.indexes(i);
@@ -215,70 +216,69 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
 						LibMatrixAgg.countAgg(sblock.values(i), cNnz, aix, apos, alen);
 						fullDiag &= aix[apos] == i;
 					}
-				}
-				else {
+				} else {
 					DenseBlock dblock = in.getDenseBlock();
-					for( int i=0; i<in.getNumRows(); i++ ) {
+					for (int i = 0; i < in.getNumRows(); i++) {
 						double[] avals = dblock.values(i);
 						int lnnz = 0, aix = dblock.pos(i);
-						for( int j=0; j<in.getNumColumns(); j++ ) {
-							if( avals[aix+j] != 0 ) {
+						for (int j = 0; j < in.getNumColumns(); j++) {
+							if (avals[aix + j] != 0) {
 								fullDiag &= (i == j);
-								cNnz[j] ++;
-								lnnz ++;
+								cNnz[j]++;
+								lnnz++;
 							}
 						}
 						rNnz[i] = lnnz;
 					}
 				}
 			}
-			
+
 			// 3) compute meta data synopsis
 			rMaxNnz = Arrays.stream(rNnz).max().orElse(0);
-			cMaxNnz = Arrays.stream(cNnz).max().orElse(0); 
+			cMaxNnz = Arrays.stream(cNnz).max().orElse(0);
 			rN1 = (int) Arrays.stream(rNnz).filter(item -> item == 1).count();
 			cN1 = (int) Arrays.stream(cNnz).filter(item -> item == 1).count();
-			rNonEmpty = (int) Arrays.stream(rNnz).filter(v-> v!=0).count();
-			cNonEmpty = (int) Arrays.stream(cNnz).filter(v-> v!=0).count();
-			rNdiv2 = (int) Arrays.stream(rNnz).filter(item -> item > getCols()/2).count();
-			cNdiv2 = (int) Arrays.stream(cNnz).filter(item -> item > getRows()/2).count();
-			
+			rNonEmpty = (int) Arrays.stream(rNnz).filter(v -> v != 0).count();
+			cNonEmpty = (int) Arrays.stream(cNnz).filter(v -> v != 0).count();
+			rNdiv2 = (int) Arrays.stream(rNnz).filter(item -> item > getCols() / 2).count();
+			cNdiv2 = (int) Arrays.stream(cNnz).filter(item -> item > getRows() / 2).count();
+
 			// 4) compute exception details if necessary (optional)
-			if( useExcepts & !in.isEmpty() && (rMaxNnz > 1 || cMaxNnz > 1) ) {
+			if (useExcepts & !in.isEmpty() && (rMaxNnz > 1 || cMaxNnz > 1)) {
 				rNnz1e = new int[in.getNumRows()];
 				cNnz1e = new int[in.getNumColumns()];
-				
-				if( in.isInSparseFormat() ) {
+
+				if (in.isInSparseFormat()) {
 					SparseBlock sblock = in.getSparseBlock();
-					for( int i=0; i<in.getNumRows(); i++ ) {
-						if( sblock.isEmpty(i) ) continue;
+					for (int i = 0; i < in.getNumRows(); i++) {
+						if (sblock.isEmpty(i))
+							continue;
 						int alen = sblock.size(i);
 						int apos = sblock.pos(i);
 						int[] aix = sblock.indexes(i);
-						for( int k=apos; k<apos+alen; k++ )
+						for (int k = apos; k < apos + alen; k++)
 							rNnz1e[i] += cNnz[aix[k]] <= 1 ? 1 : 0;
-						if( alen <= 1 )
-							for( int k=apos; k<apos+alen; k++ )
+						if (alen <= 1)
+							for (int k = apos; k < apos + alen; k++)
 								cNnz1e[aix[k]]++;
 					}
-				}
-				else {
+				} else {
 					DenseBlock dblock = in.getDenseBlock();
-					for( int i=0; i<in.getNumRows(); i++ ) {
+					for (int i = 0; i < in.getNumRows(); i++) {
 						double[] avals = dblock.values(i);
 						int aix = dblock.pos(i);
 						boolean rNnzlte1 = rNnz[i] <= 1;
-						for( int j=0; j<in.getNumColumns(); j++ ) {
-							if( avals[aix + j] != 0 ) {
+						for (int j = 0; j < in.getNumColumns(); j++) {
+							if (avals[aix + j] != 0) {
 								rNnz1e[i] += cNnz[j] <= 1 ? 1 : 0;
-								cNnz1e[j] += rNnzlte1 ? 1 : 0; 
+								cNnz1e[j] += rNnzlte1 ? 1 : 0;
 							}
 						}
 					}
 				}
 			}
 		}
-		
+
 		public MatrixHistogram(int[] r, int[] r1e, int[] c, int[] c1e, int rmax, int cmax) {
 			rNnz = r;
 			rNnz1e = r1e;
@@ -290,55 +290,56 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
 			rNonEmpty = cNonEmpty = -1;
 			rNdiv2 = cNdiv2 = -1;
 		}
-		
+
 		public int getRows() {
 			return rNnz.length;
 		}
-		
+
 		public int getCols() {
 			return cNnz.length;
 		}
-		
+
 		public long getNonZeros() {
-			return getRows() < getCols() ?
-				IntStream.range(0, getRows()).mapToLong(i-> rNnz[i]).sum() :
-				IntStream.range(0, getRows()).mapToLong(i-> cNnz[i]).sum();
+			return getRows() < getCols() ? IntStream.range(0, getRows()).mapToLong(i -> rNnz[i]).sum()
+					: IntStream.range(0, getRows()).mapToLong(i -> cNnz[i]).sum();
 		}
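The rNnz1e/cNnz1e "exception" vectors built in the constructor above restrict the counts to rows/columns that carry at most one nonzero, which is what lets the hybrid branch of estimInternMM treat those contributions exactly. A minimal sketch of the same bookkeeping on a dense array (standalone, not the SystemML implementation):

import java.util.Arrays;

// Sketch of the exception counts: nnz per row restricted to columns with <= 1 nonzero,
// and nnz per column restricted to rows with <= 1 nonzero (illustrative only).
public class ExceptionCountsSketch {
	public static void main(String[] args) {
		double[][] a = { {1, 0, 5}, {0, 2, 3}, {0, 0, 4} };
		int rows = a.length, cols = a[0].length;
		int[] rNnz = new int[rows], cNnz = new int[cols];
		for (int i = 0; i < rows; i++)
			for (int j = 0; j < cols; j++)
				if (a[i][j] != 0) { rNnz[i]++; cNnz[j]++; }
		int[] rNnz1e = new int[rows], cNnz1e = new int[cols];
		for (int i = 0; i < rows; i++)
			for (int j = 0; j < cols; j++)
				if (a[i][j] != 0) {
					if (cNnz[j] <= 1) rNnz1e[i]++; // column j has a single nonzero
					if (rNnz[i] <= 1) cNnz1e[j]++; // row i has a single nonzero
				}
		System.out.println(Arrays.toString(rNnz1e)); // [1, 1, 0]
		System.out.println(Arrays.toString(cNnz1e)); // [0, 0, 1]
	}
}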
-		
+
 		public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) {
-			//exact propagation if lhs or rhs full diag
-			if( h1.fullDiag ) return h2;
-			if( h2.fullDiag ) return h1;
-			
-			//get input/output nnz for scaling
+			// exact propagation if lhs or rhs full diag
+			if (h1.fullDiag)
+				return h2;
+			if (h2.fullDiag)
+				return h1;
+
+			// get input/output nnz for scaling
 			long nnz1 = Arrays.stream(h1.rNnz).sum();
 			long nnz2 = Arrays.stream(h2.cNnz).sum();
 			double nnzOut = spOut * h1.getRows() * h2.getCols();
-			
-			//propagate h1.r and h2.c to output via simple scaling
-			//(this implies 0s propagate and distribution is preserved)
+
+			// propagate h1.r and h2.c to output via simple scaling
+			// (this implies 0s propagate and distribution is preserved)
 			int rMaxNnz = 0, cMaxNnz = 0;
 			int[] rNnz = new int[h1.getRows()];
 			Random rn = new Random();
-			for( int i=0; i<h1.getRows(); i++ ) {
-				rNnz[i] = probRound(nnzOut/nnz1 * h1.rNnz[i], rn);
+			for (int i = 0; i < h1.getRows(); i++) {
+				rNnz[i] = probRound(nnzOut / nnz1 * h1.rNnz[i], rn);
 				rMaxNnz = Math.max(rMaxNnz, rNnz[i]);
 			}
 			int[] cNnz = new int[h2.getCols()];
-			for( int i=0; i<h2.getCols(); i++ ) {
-				cNnz[i] = probRound(nnzOut/nnz2 * h2.cNnz[i], rn);
+			for (int i = 0; i < h2.getCols(); i++) {
+				cNnz[i] = probRound(nnzOut / nnz2 * h2.cNnz[i], rn);
 				cMaxNnz = Math.max(cMaxNnz, cNnz[i]);
 			}
-			
-			//construct new histogram object
+
+			// construct new histogram object
 			return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz);
 		}
-		
+
 		private static int probRound(double inNnz, Random rand) {
 			double temp = Math.floor(inNnz);
-			double f = inNnz - temp; //non-int fraction [0,1)
-			double randf = rand.nextDouble(); //uniform [0,1)
-			return (int)((f > randf) ? temp+1 : temp);
+			double f = inNnz - temp; // non-int fraction [0,1)
+			double randf = rand.nextDouble(); // uniform [0,1)
+			return (int) ((f > randf) ? temp + 1 : temp);
 		}
 	}
 }
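deriveOutputHistogram scales the input row/column counts to the estimated output nnz, and probRound keeps the scaled counts integral without a systematic bias: the fractional part becomes the probability of rounding up, so the expected value is preserved. A standalone sketch of that rounding (illustrative only, outside the patch):

import java.util.Random;

// Sketch of expectation-preserving probabilistic rounding, as used by probRound above.
public class ProbRoundSketch {
	static int probRound(double in, Random rand) {
		double floor = Math.floor(in);
		double frac = in - floor; // fractional part in [0,1)
		return (int) (rand.nextDouble() < frac ? floor + 1 : floor);
	}
	public static void main(String[] args) {
		Random rand = new Random(7);
		double value = 2.3; // rounds up to 3 about 30% of the time, down to 2 otherwise
		long sum = 0, trials = 100000;
		for (int i = 0; i < trials; i++)
			sum += probRound(value, rand);
		System.out.println((double) sum / trials); // ~2.3 on average
	}
}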
