Repository: systemml
Updated Branches:
  refs/heads/master c98e81581 -> f74f5ad4b


[SYSTEMML-2479] Extended AVG estimator for other operations

Closes #813.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f74f5ad4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f74f5ad4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f74f5ad4

Branch: refs/heads/master
Commit: f74f5ad4bf27606a5cb5e27a16eceb65c0bd5f62
Parents: c98e815
Author: Johanna Sommer <[email protected]>
Authored: Fri Aug 3 17:41:22 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Aug 3 17:49:21 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorBasicAvg.java     |  54 ++-
 .../hops/estim/EstimatorMatrixHistogram.java    | 351 ++++++++++---------
 2 files changed, 220 insertions(+), 185 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f74f5ad4/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
index 259ab43..baa1fc4 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
@@ -27,35 +27,69 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
  * Basic average case estimator for matrix sparsity:
  * sp = 1 - Math.pow(1-sp1*sp2, k)
  */
-public class EstimatorBasicAvg extends SparsityEstimator
-{
+public class EstimatorBasicAvg extends SparsityEstimator {
        @Override
        public double estim(MMNode root) {
-               //recursive sparsity evaluation of non-leaf nodes
+               // recursive sparsity evaluation of non-leaf nodes
                double sp1 = !root.getLeft().isLeaf() ? estim(root.getLeft()) :
                        OptimizerUtils.getSparsity(root.getLeft().getMatrixCharacteristics());
                double sp2 = !root.getRight().isLeaf() ? estim(root.getRight()) :
                        OptimizerUtils.getSparsity(root.getRight().getMatrixCharacteristics());
-               return estimIntern(sp1, sp2, root.getRows(), root.getLeft().getCols(), root.getCols());
+               return estimInternMM(sp1, sp2, root.getRows(), root.getLeft().getCols(), root.getCols());
        }
 
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2) {
-               return estimIntern(m1.getSparsity(), m2.getSparsity(),
+               return estimInternMM(m1.getSparsity(), m2.getSparsity(),
                        m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns());
        }
 
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
-               throw new NotImplementedException();
+               return estimIntern(m1, m2, op);
        }
-       
+
        @Override
        public double estim(MatrixBlock m, OpCode op) {
-               throw new NotImplementedException();
+               return estimIntern(m, null, op);
+       }
+
+       private double estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) { // m2 is null for unary ops
+               switch (op) {
+                       case MM:
+                               return estimInternMM(m1.getSparsity(), m2.getSparsity(),
+                                       m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns());
+                       case MULT:
+                               return m1.getSparsity() * m2.getSparsity();
+                       case PLUS:
+                               return m1.getSparsity() + m2.getSparsity() - m1.getSparsity() * m2.getSparsity();
+                       case EQZERO:
+                               return OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumColumns(),
+                                       (long) m1.getNumRows() * m1.getNumColumns() - m1.getNonZeros());
+                       case DIAG:
+                               return (m1.getNumColumns() == 1) ?
+                                       OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumRows(), m1.getNonZeros()) :
+                                       OptimizerUtils.getSparsity(m1.getNumRows(), 1, Math.min(m1.getNumRows(), m1.getNonZeros()));
+                       // binary operations that preserve sparsity exactly
+                       case CBIND:
+                               return OptimizerUtils.getSparsity(m1.getNumRows(),
+                                       m1.getNumColumns() + m2.getNumColumns(), m1.getNonZeros() + m2.getNonZeros()); // output has the columns of both inputs
+                       case RBIND:
+                               return OptimizerUtils.getSparsity(m1.getNumRows() + m2.getNumRows(),
+                                       m1.getNumColumns(), m1.getNonZeros() + m2.getNonZeros());
+                       // unary operations that preserve sparsity exactly
+                       case NEQZERO:
+                               return m1.getSparsity();
+                       case TRANS:
+                               return m1.getSparsity();
+                       case RESHAPE:
+                               return m1.getSparsity();
+                       default:
+                               throw new NotImplementedException();
+               }
        }
-       
-       private double estimIntern(double sp1, double sp2, long m, long k, long n) {
+
+       private double estimInternMM(double sp1, double sp2, long m, long k, long n) {
                return OptimizerUtils.getMatMultSparsity(sp1, sp2, m, k, n, false);
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f74f5ad4/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
index 209e2f0..60e6af4 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
@@ -31,183 +31,184 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
 
 /**
- * This estimator implements a remarkably simple yet effective
- * approach for incorporating structural properties into sparsity
- * estimation. The key idea is to maintain row and column nnz per
- * matrix, along with additional meta data.
+ * This estimator implements a remarkably simple yet effective approach for
+ * incorporating structural properties into sparsity estimation. The key idea is
+ * to maintain row and column nnz per matrix, along with additional meta data.
  */
-public class EstimatorMatrixHistogram extends SparsityEstimator
-{
-       //internal configurations
+public class EstimatorMatrixHistogram extends SparsityEstimator {
+       // internal configurations
        private static final boolean DEFAULT_USE_EXCEPTS = true;
-       
+
        private final boolean _useExcepts;
-       
+
        public EstimatorMatrixHistogram() {
                this(DEFAULT_USE_EXCEPTS);
        }
-       
+
        public EstimatorMatrixHistogram(boolean useExcepts) {
                _useExcepts = useExcepts;
        }
-       
+
        @Override
        public double estim(MMNode root) {
-               //recursive histogram computation of non-leaf nodes
-               if( !root.getLeft().isLeaf() )
-                       estim(root.getLeft()); //obtain synopsis
-               if( !root.getRight().isLeaf() )
-                       estim(root.getLeft()); //obtain synopsis
-               MatrixHistogram h1 = !root.getLeft().isLeaf() ?
-                       (MatrixHistogram)root.getLeft().getSynopsis() :
-                       new MatrixHistogram(root.getLeft().getData(), _useExcepts);
-               MatrixHistogram h2 = !root.getRight().isLeaf() ?
-                       (MatrixHistogram)root.getRight().getSynopsis() :
-                       new MatrixHistogram(root.getRight().getData(), _useExcepts);
-               
-               //estimate output sparsity based on input histograms
+               // recursive histogram computation of non-leaf nodes
+               if (!root.getLeft().isLeaf())
+                       estim(root.getLeft()); // obtain synopsis of the left child
+               if (!root.getRight().isLeaf())
+                       estim(root.getRight()); // obtain synopsis of the right child (read as h2 below)
+               MatrixHistogram h1 = !root.getLeft().isLeaf() ? (MatrixHistogram) root.getLeft().getSynopsis()
+                               : new MatrixHistogram(root.getLeft().getData(), _useExcepts);
+               MatrixHistogram h2 = !root.getRight().isLeaf() ? (MatrixHistogram) root.getRight().getSynopsis()
+                               : new MatrixHistogram(root.getRight().getData(), _useExcepts);
+
+               // estimate output sparsity based on input histograms
                double ret = estimIntern(h1, h2, OpCode.MM);
-               
-               //derive and memoize output histogram
+
+               // derive and memoize output histogram
                root.setSynopsis(MatrixHistogram.deriveOutputHistogram(h1, h2, ret));
-               
+
                return ret;
        }
-       
-       @Override 
+
+       @Override
        public double estim(MatrixBlock m1, MatrixBlock m2) {
                return estim(m1, m2, OpCode.MM);
        }
-       
+
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
                MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts);
-               MatrixHistogram h2 = (m1 == m2) ? //self product
-                       h1 : new MatrixHistogram(m2, _useExcepts);
+               MatrixHistogram h2 = (m1 == m2) ? // self product
+                               h1 : new MatrixHistogram(m2, _useExcepts);
                return estimIntern(h1, h2, op);
        }
-       
+
        @Override
        public double estim(MatrixBlock m1, OpCode op) {
                MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts);
                return estimIntern(h1, null, op);
        }
-       
+
        private double estimIntern(MatrixHistogram h1, MatrixHistogram h2, OpCode op) {
-               double msize = (double)h1.getRows()*h1.getCols();
-               
+               double msize = (double) h1.getRows() * h1.getCols();
+
                switch (op) {
-                       case MM:
-                               return estimInternMM(h1, h2);
-                       case MULT:
-                               return Math.min(
-                                       IntStream.range(0, h1.getRows()).mapToDouble(i -> (double)h1.rNnz[i]/msize * (double)h2.rNnz[i]/msize).sum(),
-                                       IntStream.range(0, h1.getCols()).mapToDouble(i -> (double)h1.cNnz[i]/msize * (double)h2.cNnz[i]/msize).sum());
-                       case PLUS:
-                               return Math.min(
-                                       IntStream.range(0, h1.getRows()).mapToDouble(i -> (double)h1.rNnz[i]/msize 
-                                               + (double)h2.rNnz[i]/msize - (double)h1.rNnz[i]/msize * (double)h2.rNnz[i]/msize).sum(),
-                                       IntStream.range(0, h1.getCols()).mapToDouble(i -> (double)h1.cNnz[i]/msize 
-                                               + (double)h2.cNnz[i]/msize - (double)h1.cNnz[i]/msize * (double)h2.cNnz[i]/msize).sum());
-                       case EQZERO:
-                               return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(),
-                                       (long)h1.getRows() * h1.getCols() - h1.getNonZeros());
-                       case DIAG:
-                               return (h1.getCols()==1) ?
-                                       OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros()) :
-                                       OptimizerUtils.getSparsity(h1.getRows(), 1, Math.min(h1.getRows(), h1.getNonZeros()));
-                       //binary operations that preserve sparsity exactly
-                       case CBIND:
-                               return OptimizerUtils.getSparsity(h1.getRows(),
-                                       h1.getCols()+h2.getCols(), h1.getNonZeros() + h2.getNonZeros());
-                       case RBIND:
-                               return OptimizerUtils.getSparsity(h1.getRows()+h2.getRows(),
-                                       h1.getCols(), h1.getNonZeros() + h2.getNonZeros());
-                       //unary operation that preserve sparsity exactly
-                       case NEQZERO:
-                       case TRANS:
-                       case RESHAPE:
-                               return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
-                       default:
-                               throw new NotImplementedException();
+               case MM:
+                       return estimInternMM(h1, h2);
+               case MULT:
+                       return Math.min(
+                                       IntStream.range(0, h1.getRows()).mapToDouble(i -> h1.rNnz[i] / msize * h2.rNnz[i] / msize).sum(),
+                                       IntStream.range(0, h1.getCols()).mapToDouble(i -> h1.cNnz[i] / msize * h2.cNnz[i] / msize).sum());
+               case PLUS:
+                       return Math.min(
+                                       IntStream.range(0, h1.getRows())
+                                                       .mapToDouble(i -> h1.rNnz[i] / msize + h2.rNnz[i] / msize
+                                                                       - h1.rNnz[i] / msize * h2.rNnz[i] / msize)
+                                                       .sum(),
+                                       IntStream.range(0, h1.getCols()).mapToDouble(
+                                                       i -> h1.cNnz[i] / msize + h2.cNnz[i] / msize - h1.cNnz[i] / msize * h2.cNnz[i] / msize)
+                                                       .sum());
+               case EQZERO:
+                       return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(),
+                                       (long) h1.getRows() * h1.getCols() - h1.getNonZeros());
+               case DIAG:
+                       return (h1.getCols() == 1) ? OptimizerUtils.getSparsity(h1.getRows(), h1.getRows(), h1.getNonZeros())
+                                       : OptimizerUtils.getSparsity(h1.getRows(), 1, Math.min(h1.getRows(), h1.getNonZeros()));
+               // binary operations that preserve sparsity exactly
+               case CBIND:
+                       return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols() + h2.getCols(),
+                                       h1.getNonZeros() + h2.getNonZeros());
+               case RBIND:
+                       return OptimizerUtils.getSparsity(h1.getRows() + h2.getRows(), h1.getCols(),
+                                       h1.getNonZeros() + h2.getNonZeros());
+               // unary operations that preserve sparsity exactly
+               case NEQZERO:
+                       return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
+               case TRANS:
+                       return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
+               case RESHAPE:
+                       return OptimizerUtils.getSparsity(h1.getRows(), h1.getCols(), h1.getNonZeros());
+               default:
+                       throw new NotImplementedException();
                }
        }
-       
+
        private double estimInternMM(MatrixHistogram h1, MatrixHistogram h2) {
                long nnz = 0;
-               //special case, with exact sparsity estimate, where the dot product
-               //dot(h1.cNnz,h2rNnz) gives the exact number of non-zeros in the output
-               if( h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1 ) {
-                       for( int j=0; j<h1.getCols(); j++ )
+               // special case, with exact sparsity estimate, where the dot product
+               // dot(h1.cNnz, h2.rNnz) gives the exact number of non-zeros in the output
+               if (h1.rMaxNnz <= 1 || h2.cMaxNnz <= 1) {
+                       for (int j = 0; j < h1.getCols(); j++)
                                nnz += h1.cNnz[j] * h2.rNnz[j];
                }
-               //special case, with hybrid exact and approximate output
-               else if(h1.cNnz1e!=null && h2.rNnz1e != null) {
-                       //note: normally h1.getRows()*h2.getCols() would define mnOut
-                       //but by leveraging the knowledge of rows/cols w/ <=1 nnz, we account
-                       //that exact and approximate fractions touch different areas
-                       long mnOut = (h1.rNonEmpty-h1.rN1) * (h2.cNonEmpty-h2.cN1);
+               // special case, with hybrid exact and approximate output
+               else if (h1.cNnz1e != null && h2.rNnz1e != null) {
+                       // note: normally h1.getRows()*h2.getCols() would define mnOut,
+                       // but by leveraging the knowledge of rows/cols w/ <=1 nnz,
+                       // we account that exact and approximate fractions touch
+                       // different areas
+                       long mnOut = (long) (h1.rNonEmpty - h1.rN1) * (h2.cNonEmpty - h2.cN1); // long product avoids int overflow
                        double spOutRest = 0;
-                       for( int j=0; j<h1.getCols(); j++ ) {
-                               //exact fractions, w/o double counting
+                       for (int j = 0; j < h1.getCols(); j++) {
+                               // exact fractions, w/o double counting
                                nnz += h1.cNnz1e[j] * h2.rNnz[j];
-                               nnz += (h1.cNnz[j]-h1.cNnz1e[j]) * h2.rNnz1e[j];
-                               //approximate fraction, w/o double counting
-                               double lsp = (double)(h1.cNnz[j]-h1.cNnz1e[j]) 
-                                       * (h2.rNnz[j]-h2.rNnz1e[j]) / mnOut;
-                               spOutRest = spOutRest + lsp - spOutRest*lsp;
+                               nnz += (long) (h1.cNnz[j] - h1.cNnz1e[j]) * h2.rNnz1e[j];
+                               // approximate fraction, w/o double counting
+                               double lsp = (double) (h1.cNnz[j] - h1.cNnz1e[j]) * (h2.rNnz[j] - h2.rNnz1e[j]) / mnOut;
+                               spOutRest = spOutRest + lsp - spOutRest * lsp;
                        }
-                       nnz += (long)(spOutRest * mnOut);
+                       nnz += (long) (spOutRest * mnOut);
                }
-               //general case with approximate output
+               // general case with approximate output
                else {
-                       long mnOut = h1.getRows()*h2.getCols();
+                       long mnOut = (long) h1.getRows() * h2.getCols(); // long product avoids int overflow
                        double spOut = 0;
-                       for( int j=0; j<h1.getCols(); j++ ) {
+                       for (int j = 0; j < h1.getCols(); j++) {
                                double lsp = (double) h1.cNnz[j] * h2.rNnz[j] / mnOut;
-                               spOut = spOut + lsp - spOut*lsp;
+                               spOut = spOut + lsp - spOut * lsp;
                        }
-                       nnz = (long)(spOut * mnOut);
+                       nnz = (long) (spOut * mnOut);
                }
-               
-               //exploit upper bound on nnz based on non-empty rows/cols
-               nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ?
-                       Math.min((long)h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz;
-               
-               //exploit lower bound on nnz based on half-full rows/cols
-               nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ?
-                       Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : nnz;
-               
-               //compute final sparsity
-               return OptimizerUtils.getSparsity(
-                       h1.getRows(), h2.getCols(), nnz);
+
+               // exploit upper bound on nnz based on non-empty rows/cols
+               nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ? Math.min((long) h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz;
+
+               // exploit lower bound on nnz based on half-full rows/cols
+               nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ? Math.max((long) h1.rNdiv2 * h2.cNdiv2, nnz) : nnz;
+
+               // compute final sparsity
+               return OptimizerUtils.getSparsity(h1.getRows(), h2.getCols(), nnz);
        }
-       
+
        private static class MatrixHistogram {
                // count vectors (the histogram)
-               private final int[] rNnz;    //nnz per row
-               private int[] rNnz1e = null; //nnz per row for cols w/ <= 1 non-zeros
-               private final int[] cNnz;    //nnz per col
-               private int[] cNnz1e = null; //nnz per col for rows w/ <= 1 non-zeros
+               private final int[] rNnz; // nnz per row
+               private int[] rNnz1e = null; // nnz per row for cols w/ <= 1 non-zeros
+               private final int[] cNnz; // nnz per col
+               private int[] cNnz1e = null; // nnz per col for rows w/ <= 1 non-zeros
                // additional summary statistics
-               private final int rMaxNnz, cMaxNnz;     //max nnz per row/row
-               private final int rN1, cN1;             //number of rows/cols with nnz=1
-               private final int rNonEmpty, cNonEmpty; //number of non-empty rows/cols (w/ empty is nnz=0)
-               private final int rNdiv2, cNdiv2;       //number of rows/cols with nnz > #cols/2 and #rows/2
-               private boolean fullDiag;               //true if there exists a full diagonal of nonzeros
-               
+               private final int rMaxNnz, cMaxNnz; // max nnz per row/col
+               private final int rN1, cN1; // number of rows/cols with nnz=1
+               // number of non-empty rows/cols (w/ empty is nnz=0)
+               private final int rNonEmpty, cNonEmpty;
+               // number of rows/cols with nnz > #cols/2 and #rows/2
+               private final int rNdiv2, cNdiv2;
+               // true if there exists a full diagonal of nonzeros
+               private boolean fullDiag;
+
                public MatrixHistogram(MatrixBlock in, boolean useExcepts) {
                        // 1) allocate basic synopsis
                        rNnz = new int[in.getNumRows()];
                        cNnz = new int[in.getNumColumns()];
                        fullDiag = in.getNumRows() == in.getNonZeros();
-                       
+
                        // 2) compute basic synopsis details
-                       if( !in.isEmpty() ) {
-                               if( in.isInSparseFormat() ) {
+                       if (!in.isEmpty()) {
+                               if (in.isInSparseFormat()) {
                                        SparseBlock sblock = in.getSparseBlock();
-                                       for( int i=0; i<in.getNumRows(); i++ ) {
-                                               if( sblock.isEmpty(i) ) continue;
+                                       for (int i = 0; i < in.getNumRows(); i++) {
+                                               if (sblock.isEmpty(i))
+                                                       continue;
                                                int apos = sblock.pos(i);
                                                int alen = sblock.size(i);
                                                int[] aix = sblock.indexes(i);
@@ -215,70 +216,69 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
                                                LibMatrixAgg.countAgg(sblock.values(i), cNnz, aix, apos, alen);
                                                fullDiag &= aix[apos] == i;
                                        }
-                               }
-                               else {
+                               } else {
                                        DenseBlock dblock = in.getDenseBlock();
-                                       for( int i=0; i<in.getNumRows(); i++ ) {
+                                       for (int i = 0; i < in.getNumRows(); i++) {
                                                double[] avals = dblock.values(i);
                                                int lnnz = 0, aix = dblock.pos(i);
-                                               for( int j=0; j<in.getNumColumns(); j++ ) {
-                                                       if( avals[aix+j] != 0 ) {
+                                               for (int j = 0; j < in.getNumColumns(); j++) {
+                                                       if (avals[aix + j] != 0) {
                                                                fullDiag &= (i == j);
-                                                               cNnz[j] ++;
-                                                               lnnz ++;
+                                                               cNnz[j]++;
+                                                               lnnz++;
                                                        }
                                                }
                                                rNnz[i] = lnnz;
                                        }
                                }
                        }
-                       
+
                        // 3) compute meta data synopsis
                        rMaxNnz = Arrays.stream(rNnz).max().orElse(0);
-                       cMaxNnz =  Arrays.stream(cNnz).max().orElse(0);
+                       cMaxNnz = Arrays.stream(cNnz).max().orElse(0);
                        rN1 = (int) Arrays.stream(rNnz).filter(item -> item == 1).count();
                        cN1 = (int) Arrays.stream(cNnz).filter(item -> item == 1).count();
-                       rNonEmpty = (int) Arrays.stream(rNnz).filter(v-> v!=0).count();
-                       cNonEmpty = (int) Arrays.stream(cNnz).filter(v-> v!=0).count();
-                       rNdiv2 = (int) Arrays.stream(rNnz).filter(item -> item > getCols()/2).count();
-                       cNdiv2 = (int) Arrays.stream(cNnz).filter(item -> item > getRows()/2).count();
-                       
+                       rNonEmpty = (int) Arrays.stream(rNnz).filter(v -> v != 0).count();
+                       cNonEmpty = (int) Arrays.stream(cNnz).filter(v -> v != 0).count();
+                       rNdiv2 = (int) Arrays.stream(rNnz).filter(item -> item > getCols() / 2).count();
+                       cNdiv2 = (int) Arrays.stream(cNnz).filter(item -> item > getRows() / 2).count();
+
                        // 4) compute exception details if necessary (optional)
-                       if( useExcepts & !in.isEmpty() && (rMaxNnz > 1 || cMaxNnz > 1) ) {
+                       if (useExcepts && !in.isEmpty() && (rMaxNnz > 1 || cMaxNnz > 1)) {
                                rNnz1e = new int[in.getNumRows()];
                                cNnz1e = new int[in.getNumColumns()];
-                               
-                               if( in.isInSparseFormat() ) {
+
+                               if (in.isInSparseFormat()) {
                                        SparseBlock sblock = in.getSparseBlock();
-                                       for( int i=0; i<in.getNumRows(); i++ ) {
-                                               if( sblock.isEmpty(i) ) continue;
+                                       for (int i = 0; i < in.getNumRows(); i++) {
+                                               if (sblock.isEmpty(i))
+                                                       continue;
                                                int alen = sblock.size(i);
                                                int apos = sblock.pos(i);
                                                int[] aix = sblock.indexes(i);
-                                               for( int k=apos; k<apos+alen; k++ )
+                                               for (int k = apos; k < apos + alen; k++)
                                                        rNnz1e[i] += cNnz[aix[k]] <= 1 ? 1 : 0;
-                                               if( alen <= 1 )
-                                                       for( int k=apos; k<apos+alen; k++ )
+                                               if (alen <= 1)
+                                                       for (int k = apos; k < apos + alen; k++)
                                                                cNnz1e[aix[k]]++;
                                        }
-                               }
-                               else {
+                               } else {
                                        DenseBlock dblock = in.getDenseBlock();
-                                       for( int i=0; i<in.getNumRows(); i++ ) {
+                                       for (int i = 0; i < in.getNumRows(); i++) {
                                                double[] avals = dblock.values(i);
                                                int aix = dblock.pos(i);
                                                boolean rNnzlte1 = rNnz[i] <= 1;
-                                               for( int j=0; j<in.getNumColumns(); j++ ) {
-                                                       if( avals[aix + j] != 0 ) {
+                                               for (int j = 0; j < in.getNumColumns(); j++) {
+                                                       if (avals[aix + j] != 0) {
                                                                rNnz1e[i] += cNnz[j] <= 1 ? 1 : 0;
-                                                               cNnz1e[j] +=  rNnzlte1 ? 1 : 0;
+                                                               cNnz1e[j] += rNnzlte1 ? 1 : 0;
                                                        }
                                                }
                                        }
                                }
                        }
                }
-               
+
                public MatrixHistogram(int[] r, int[] r1e, int[] c, int[] c1e, int rmax, int cmax) {
                        rNnz = r;
                        rNnz1e = r1e;
@@ -290,55 +290,56 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
                        rNonEmpty = cNonEmpty = -1;
                        rNdiv2 = cNdiv2 = -1;
                }
-               
+
                public int getRows() {
                        return rNnz.length;
                }
-               
+
                public int getCols() {
                        return cNnz.length;
                }
-               
+
                public long getNonZeros() {
-                       return getRows() < getCols() ?
-                               IntStream.range(0, getRows()).mapToLong(i-> rNnz[i]).sum() :
-                               IntStream.range(0, getRows()).mapToLong(i-> cNnz[i]).sum();
+                       return getRows() < getCols() ? IntStream.range(0, getRows()).mapToLong(i -> rNnz[i]).sum()
+                                       : IntStream.range(0, getCols()).mapToLong(i -> cNnz[i]).sum(); // sum the shorter count vector
                }
-               
+
                public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) {
-                       //exact propagation if lhs or rhs full diag
-                       if( h1.fullDiag ) return h2;
-                       if( h2.fullDiag ) return h1;
-                       
-                       //get input/output nnz for scaling
+                       // exact propagation if lhs or rhs full diag
+                       if (h1.fullDiag)
+                               return h2;
+                       if (h2.fullDiag)
+                               return h1;
+
+                       // get input/output nnz for scaling
                        long nnz1 = Arrays.stream(h1.rNnz).sum();
                        long nnz2 = Arrays.stream(h2.cNnz).sum();
                        double nnzOut = spOut * h1.getRows() * h2.getCols();
-                       
-                       //propagate h1.r and h2.c to output via simple scaling
-                       //(this implies 0s propagate and distribution is preserved)
+
+                       // propagate h1.r and h2.c to output via simple scaling
+                       // (this implies 0s propagate and distribution is preserved)
                        int rMaxNnz = 0, cMaxNnz = 0;
                        int[] rNnz = new int[h1.getRows()];
                        Random rn = new Random();
-                       for( int i=0; i<h1.getRows(); i++ ) {
-                               rNnz[i] = probRound(nnzOut/nnz1 * h1.rNnz[i], rn);
+                       for (int i = 0; i < h1.getRows(); i++) {
+                               rNnz[i] = probRound(nnzOut / nnz1 * h1.rNnz[i], rn);
                                rMaxNnz = Math.max(rMaxNnz, rNnz[i]);
                        }
                        int[] cNnz = new int[h2.getCols()];
-                       for( int i=0; i<h2.getCols(); i++ ) {
-                               cNnz[i] = probRound(nnzOut/nnz2 * h2.cNnz[i], rn);
+                       for (int i = 0; i < h2.getCols(); i++) {
+                               cNnz[i] = probRound(nnzOut / nnz2 * h2.cNnz[i], rn);
                                cMaxNnz = Math.max(cMaxNnz, cNnz[i]);
                        }
-                       
-                       //construct new histogram object
+
+                       // construct new histogram object
                        return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz);
                }
-               
+
                private static int probRound(double inNnz, Random rand) {
                        double temp = Math.floor(inNnz);
-                       double f = inNnz - temp; //non-int fraction [0,1)
-                       double randf = rand.nextDouble(); //uniform [0,1)
-                       return (int)((f > randf) ? temp+1 : temp);
+                       double f = inNnz - temp; // non-int fraction [0,1)
+                       double randf = rand.nextDouble(); // uniform [0,1)
+                       return (int) ((f > randf) ? temp + 1 : temp);
                }
        }
 }

Reply via email to