Repository: systemml
Updated Branches:
  refs/heads/master f1b9d1c08 -> ca24ec564
[SYSTEMML-2329] Extended sampling-based sparsity estimator

This patch fixes the existing sampling-based estimator by optionally
removing its bias via an approach similar to element-wise addition used
in other estimators.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ca24ec56
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ca24ec56
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ca24ec56

Branch: refs/heads/master
Commit: ca24ec5647dedbf6eb50bbc630ccee673b1b3320
Parents: f1b9d1c
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Wed Oct 17 18:29:56 2018 +0200
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Oct 17 18:29:56 2018 +0200

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorSample.java | 36 ++++++++++++++++----
 1 file changed, 29 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ca24ec56/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
index faf7d0e..ec624f0 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
@@ -37,22 +37,29 @@ import org.apache.sysml.runtime.util.UtilFunctions;
  * The basic idea is to draw random samples of aligned columns SA and rows SB,
  * and compute the output nnz as max(nnz(SA_i)*nnz(SB_i)). However, this estimator is
  * biased toward underestimation as the maximum is unlikely sampled and collisions are
- * not accounted for.
+ * not accounted for. Accordingly, we also support an extended estimator that relies
+ * on similar ideas for element-wise addition as the other estimators.
  */
 public class EstimatorSample extends SparsityEstimator
 {
 	private static final double SAMPLE_FRACTION = 0.1; //10%
 	
 	private final double _frac;
+	private final boolean _extended;
 	
 	public EstimatorSample() {
-		this(SAMPLE_FRACTION);
+		this(SAMPLE_FRACTION, false);
 	}
 	
 	public EstimatorSample(double sampleFrac) {
+		this(sampleFrac, false);
+	}
+	
+	public EstimatorSample(double sampleFrac, boolean extended) {
 		if( sampleFrac < 0 || sampleFrac > 1.0 )
 			throw new DMLRuntimeException("Invalid sample fraction: "+sampleFrac);
 		_frac = sampleFrac;
+		_extended = extended;
 	}
 	
 	@Override
@@ -73,12 +80,27 @@ public class EstimatorSample extends SparsityEstimator
 			int k = m1.getNumColumns();
 			int[] ix = UtilFunctions.getSortedSampleIndexes(
 				k, (int)Math.max(k*_frac, 1));
+			int p = ix.length;
 			int[] cnnz = computeColumnNnz(m1, ix);
-			long nnzOut = 0;
-			for(int i=0; i<ix.length; i++)
-				nnzOut = Math.max(nnzOut, cnnz[i] * m2.recomputeNonZeros(ix[i], ix[i]));
-			return OptimizerUtils.getSparsity(
-				m1.getNumRows(), m2.getNumColumns(), nnzOut);
+			if( _extended ) {
+				double ml = (long)m1.getNumRows()*m2.getNumColumns();
+				double sumS = 0, prodS = 1;
+				for(int i=0; i<ix.length; i++) {
+					long rnnz = m2.recomputeNonZeros(ix[i], ix[i]);
+					double v = (double)cnnz[i] * rnnz /ml;
+					sumS += v;
+					prodS *= 1-v;
+				}
+				return 1-Math.pow(1-1d/p * sumS, k - p) * prodS;
+			}
+			else {
+				//biased sampling-based estimator
+				long nnzOut = 0;
+				for(int i=0; i<p; i++)
+					nnzOut = Math.max(nnzOut, cnnz[i] * m2.recomputeNonZeros(ix[i], ix[i]));
+				return OptimizerUtils.getSparsity(
+					m1.getNumRows(), m2.getNumColumns(), nnzOut);
+			}
 		}
 		case MULT: {
 			int k = Math.max(m1.getNumColumns(), m1.getNumRows());