Repository: systemml
Updated Branches:
  refs/heads/master f1b9d1c08 -> ca24ec564


[SYSTEMML-2329] Extended sampling-based sparsity estimator

This patch extends the existing sampling-based estimator with an optional
mode that removes its bias, using an approach similar to the element-wise
addition used in other estimators.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ca24ec56
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ca24ec56
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ca24ec56

Branch: refs/heads/master
Commit: ca24ec5647dedbf6eb50bbc630ccee673b1b3320
Parents: f1b9d1c
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Wed Oct 17 18:29:56 2018 +0200
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Oct 17 18:29:56 2018 +0200

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorSample.java       | 36 ++++++++++++++++----
 1 file changed, 29 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ca24ec56/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
index faf7d0e..ec624f0 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
@@ -37,22 +37,29 @@ import org.apache.sysml.runtime.util.UtilFunctions;
  * The basic idea is to draw random samples of aligned columns SA and rows SB,
  * and compute the output nnz as max(nnz(SA_i)*nnz(SB_i)). However, this estimator is
  * biased toward underestimation as the maximum is unlikely sampled and collisions are
- * not accounted for.
+ * not accounted for. Accordingly, we also support an extended estimator that relies
+ * on similar ideas for element-wise addition as the other estimators.
  */
 public class EstimatorSample extends SparsityEstimator
 {
        private static final double SAMPLE_FRACTION = 0.1; //10%
        
        private final double _frac;
+       private final boolean _extended;
        
        public EstimatorSample() {
-               this(SAMPLE_FRACTION);
+               this(SAMPLE_FRACTION, false);
        }
        
        public EstimatorSample(double sampleFrac) {
+               this(sampleFrac, false);
+       }
+       
+       public EstimatorSample(double sampleFrac, boolean extended) {
                if( sampleFrac < 0 || sampleFrac > 1.0 )
                        throw new DMLRuntimeException("Invalid sample fraction: 
"+sampleFrac);
                _frac = sampleFrac;
+               _extended = extended;
        }
        
        @Override
@@ -73,12 +80,27 @@ public class EstimatorSample extends SparsityEstimator
                                int k =  m1.getNumColumns();
                                int[] ix = UtilFunctions.getSortedSampleIndexes(
                                        k, (int)Math.max(k*_frac, 1));
+                               int p = ix.length;
                                int[] cnnz = computeColumnNnz(m1, ix);
-                               long nnzOut = 0;
-                               for(int i=0; i<ix.length; i++)
-                                       nnzOut = Math.max(nnzOut, cnnz[i] * 
m2.recomputeNonZeros(ix[i], ix[i]));
-                               return OptimizerUtils.getSparsity( 
-                                       m1.getNumRows(), m2.getNumColumns(), 
nnzOut);
+                               if( _extended ) {
+                                       double ml = 
(long)m1.getNumRows()*m2.getNumColumns();
+                                       double sumS = 0, prodS = 1;
+                                       for(int i=0; i<ix.length; i++) {
+                                               long rnnz = 
m2.recomputeNonZeros(ix[i], ix[i]);
+                                               double v = (double)cnnz[i] * 
rnnz /ml;
+                                               sumS += v;
+                                               prodS *= 1-v;
+                                       }
+                                       return 1-Math.pow(1-1d/p * sumS, k - p) 
* prodS;
+                               }
+                               else {
+                                       //biased sampling-based estimator
+                                       long nnzOut = 0;
+                                       for(int i=0; i<p; i++)
+                                               nnzOut = Math.max(nnzOut, 
cnnz[i] * m2.recomputeNonZeros(ix[i], ix[i]));
+                                       return OptimizerUtils.getSparsity( 
+                                               m1.getNumRows(), 
m2.getNumColumns(), nnzOut);
+                               }
                        }
                        case MULT: {
                                int k = Math.max(m1.getNumColumns(), 
m1.getNumRows());

Reply via email to