Repository: systemml
Updated Branches:
  refs/heads/master 0871f260e -> cc349dc88
[SYSTEMML-2295] Fix robustness bitset sparsity estimator for empty blocks Closes #784. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cc349dc8 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cc349dc8 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cc349dc8 Branch: refs/heads/master Commit: cc349dc88a8e72adc16a6048d6ee48a35834e9aa Parents: 0871f26 Author: Johanna Sommer <[email protected]> Authored: Wed Jun 13 16:55:16 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 13 16:55:17 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/estim/EstimatorBitsetMM.java | 86 ++++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/cc349dc8/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java index 8bb5805..652a4a1 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java @@ -29,29 +29,27 @@ import org.apache.sysml.runtime.matrix.data.SparseBlock; /** * This estimator implements naive but rather common approach of boolean matrix - * multiplies which allows to infer the exact non-zero structure and thus is also - * useful for sparse result preallocation. + * multiplies which allows to infer the exact non-zero structure and thus is + * also useful for sparse result preallocation. 
* */ -public class EstimatorBitsetMM extends SparsityEstimator -{ +public class EstimatorBitsetMM extends SparsityEstimator { @Override public double estim(MMNode root) { - //recursive density map computation of non-leaf nodes - if( !root.getLeft().isLeaf() ) - estim(root.getLeft()); //obtain synopsis - if( !root.getRight().isLeaf() ) - estim(root.getLeft()); //obtain synopsis - BitsetMatrix m1Map = !root.getLeft().isLeaf() ? - (BitsetMatrix)root.getLeft().getSynopsis() : new BitsetMatrix(root.getLeft().getData()); - BitsetMatrix m2Map = !root.getRight().isLeaf() ? - (BitsetMatrix)root.getRight().getSynopsis() : new BitsetMatrix(root.getRight().getData()); - - //estimate output density map and sparsity via boolean matrix mult + // recursive density map computation of non-leaf nodes + if (!root.getLeft().isLeaf()) + estim(root.getLeft()); // obtain synopsis + if (!root.getRight().isLeaf()) + estim(root.getLeft()); // obtain synopsis + BitsetMatrix m1Map = !root.getLeft().isLeaf() ? (BitsetMatrix) root.getLeft().getSynopsis() + : new BitsetMatrix(root.getLeft().getData()); + BitsetMatrix m2Map = !root.getRight().isLeaf() ? 
(BitsetMatrix) root.getRight().getSynopsis() + : new BitsetMatrix(root.getRight().getData()); + + // estimate output density map and sparsity via boolean matrix mult BitsetMatrix outMap = m1Map.matMult(m2Map); - root.setSynopsis(outMap); //memoize boolean matrix - return OptimizerUtils.getSparsity( - outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros()); + root.setSynopsis(outMap); // memoize boolean matrix + return OptimizerUtils.getSparsity(outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros()); } @Override @@ -59,8 +57,8 @@ public class EstimatorBitsetMM extends SparsityEstimator BitsetMatrix m1Map = new BitsetMatrix(m1); BitsetMatrix m2Map = new BitsetMatrix(m2); BitsetMatrix outMap = m1Map.matMult(m2Map); - return OptimizerUtils.getSparsity( //aggregate output histogram - outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros()); + return OptimizerUtils.getSparsity( // aggregate output histogram + outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros()); } @Override @@ -68,79 +66,81 @@ public class EstimatorBitsetMM extends SparsityEstimator LOG.warn("Meta-data-only estimates not supported in EstimatorBitsetMM, falling back to EstimatorBasicAvg."); return new EstimatorBasicAvg().estim(mc1, mc2); } - + private static class BitsetMatrix { private final int _rlen; private final int _clen; private long _nonZeros; private BitSet[] _data; - + public BitsetMatrix(int rlen, int clen) { _rlen = rlen; _clen = clen; _data = new BitSet[_rlen]; - for(int i=0; i<_rlen; i++) + for (int i = 0; i < _rlen; i++) _data[i] = new BitSet(_clen); _nonZeros = 0; } - + public BitsetMatrix(MatrixBlock in) { this(in.getNumRows(), in.getNumColumns()); init(in); } - + public int getNumRows() { return _rlen; } - + public int getNumColumns() { return _clen; } - + public long getNonZeros() { return _nonZeros; } - + private void init(MatrixBlock in) { - if( in.isInSparseFormat() ) { + if (in.isEmptyBlock(false)) + return; + if (in.isInSparseFormat()) 
{ SparseBlock sblock = in.getSparseBlock(); - for(int i=0; i<in.getNumRows(); i++) { - if(sblock.isEmpty(i)) continue; + for (int i = 0; i < in.getNumRows(); i++) { + if (sblock.isEmpty(i)) + continue; BitSet lbs = _data[i]; int alen = sblock.size(i); int apos = sblock.pos(i); int[] aix = sblock.indexes(i); - for(int k=apos; k<apos+alen; k++) + for (int k = apos; k < apos + alen; k++) lbs.set(aix[k]); } - } - else { + } else { DenseBlock dblock = in.getDenseBlock(); - for(int i=0; i<in.getNumRows(); i++) { + for (int i = 0; i < in.getNumRows(); i++) { BitSet lbs = _data[i]; double[] avals = dblock.values(i); int aix = dblock.pos(i); - for(int j=0; j<in.getNumColumns(); j++) - if( avals[aix+j] != 0 ) + for (int j = 0; j < in.getNumColumns(); j++) + if (avals[aix + j] != 0) lbs.set(j); } } _nonZeros = in.getNonZeros(); } - + public BitsetMatrix matMult(BitsetMatrix m2) { final int m = this._rlen; final int cd = this._clen; final int n = m2._clen; - //matrix multiply with IKJ schedule and pure OR ops in inner loop + // matrix multiply with IKJ schedule and pure OR ops in inner loop BitsetMatrix out = new BitsetMatrix(m, n); - for(int i=0; i<m; i++) { + for (int i = 0; i < m; i++) { BitSet a = this._data[i], c = out._data[i]; - for(int k=0; k<cd; k++) { - if( a.get(k) ) + for (int k = 0; k < cd; k++) { + if (a.get(k)) c.or(m2._data[k]); } - //maintain nnz + // maintain nnz out._nonZeros += c.cardinality(); } return out;
