Repository: systemml
Updated Branches:
  refs/heads/master e3a51f72a -> bc7b4961a


[SYSTEMML-2479] Finalized AVG and WC sparsity estimators, API cleanups

Closes #818.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bc7b4961
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bc7b4961
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bc7b4961

Branch: refs/heads/master
Commit: bc7b4961a20c47c8064bc9ce7d30ba071b44a748
Parents: e3a51f7
Author: Johanna Sommer <[email protected]>
Authored: Mon Aug 6 16:05:26 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Aug 6 21:34:06 2018 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/OptimizerUtils.java   | 13 +++-
 .../sysml/hops/estim/EstimatorBasicAvg.java     | 66 ++++++++++----------
 .../sysml/hops/estim/EstimatorBasicWorst.java   | 64 ++++++++++++++-----
 .../sysml/hops/estim/EstimatorBitsetMM.java     |  8 ++-
 .../sysml/hops/estim/EstimatorDensityMap.java   |  7 ++-
 .../sysml/hops/estim/EstimatorLayeredGraph.java |  3 +-
 .../hops/estim/EstimatorMatrixHistogram.java    | 10 +--
 .../sysml/hops/estim/EstimatorSample.java       |  3 +-
 .../org/apache/sysml/hops/estim/MMNode.java     | 17 ++++-
 .../sysml/hops/estim/SparsityEstimator.java     |  4 +-
 .../runtime/matrix/MatrixCharacteristics.java   | 17 ++++-
 .../estim/SquaredProductChainTest.java          |  7 ++-
 12 files changed, 146 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index fb83df0..e78d809 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -85,7 +85,7 @@ public class OptimizerUtils
         * e.g., when input/output dimensions are unknown. The default is set 
to a large 
         * value so that operations are scheduled on MR while avoiding 
overflows as well.  
         */
-       public static double DEFAULT_SIZE;      
+       public static double DEFAULT_SIZE;
        
        
        public static final long DOUBLE_SIZE = 8;
@@ -970,6 +970,10 @@ public class OptimizerUtils
        // Sparsity Estimates //
        ////////////////////////
        
+       public static long getMatMultNnz(double sp1, double sp2, long m, long 
k, long n, boolean worstcase) {
+               return getNnz( m, n, getMatMultSparsity(sp1, sp2, m, k, n, 
worstcase));
+       }
+       
        /**
         * Estimates the result sparsity for Matrix Multiplication A %*% B. 
         *  
@@ -981,8 +985,7 @@ public class OptimizerUtils
         * @param worstcase true if worst case
         * @return the sparsity
         */
-       public static double getMatMultSparsity(double sp1, double sp2, long m, 
long k, long n, boolean worstcase) 
-       {
+       public static double getMatMultSparsity(double sp1, double sp2, long m, 
long k, long n, boolean worstcase) {
                if( worstcase ){
                        double nnz1 = sp1 * m * k;
                        double nnz2 = sp2 * k * n;
@@ -1159,6 +1162,10 @@ public class OptimizerUtils
                }
        }
        
+       public static long getNnz(long dim1, long dim2, double sp) {
+               return (long) Math.round(sp * dim1 * dim2);
+       }
+       
        public static double getSparsity( MatrixCharacteristics mc ) {
                return getSparsity(mc.getRows(), mc.getCols(), 
mc.getNonZeros());
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
index baa1fc4..6448f70 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
@@ -21,75 +21,75 @@ package org.apache.sysml.hops.estim;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
 /**
  * Basic average case estimator for matrix sparsity:
  * sp = 1 - Math.pow(1-sp1*sp2, k)
  */
-public class EstimatorBasicAvg extends SparsityEstimator {
+public class EstimatorBasicAvg extends SparsityEstimator
+{
        @Override
-       public double estim(MMNode root) {
-               // recursive sparsity evaluation of non-leaf nodes
-               double sp1 = !root.getLeft().isLeaf() ? estim(root.getLeft()) :
-                       
OptimizerUtils.getSparsity(root.getLeft().getMatrixCharacteristics());
-               double sp2 = !root.getRight().isLeaf() ? estim(root.getRight()) 
:
-                       
OptimizerUtils.getSparsity(root.getRight().getMatrixCharacteristics());
-               return estimInternMM(sp1, sp2, root.getRows(), 
root.getLeft().getCols(), root.getCols());
+       public MatrixCharacteristics estim(MMNode root) {
+               MatrixCharacteristics mc1 = !root.getLeft().isLeaf() ?
+                       estim(root.getLeft()) : 
root.getLeft().getMatrixCharacteristics();
+               MatrixCharacteristics mc2 = !root.getRight().isLeaf() ?
+                       estim(root.getRight()) : 
root.getRight().getMatrixCharacteristics();
+               return root.setMatrixCharacteristics(
+                       estimIntern(mc1, mc2, root.getOp()));
        }
 
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2) {
-               return estimInternMM(m1.getSparsity(), m2.getSparsity(),
-                       m1.getNumRows(), m1.getNumColumns(), 
m2.getNumColumns());
+               return estim(m1, m2, OpCode.MM);
        }
 
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
-               return estimIntern(m1, m2, op);
+               return estimIntern(m1.getMatrixCharacteristics(), 
m2.getMatrixCharacteristics(), op).getSparsity();
        }
 
        @Override
        public double estim(MatrixBlock m, OpCode op) {
-               return estimIntern(m, null, op);
+               return estimIntern(m.getMatrixCharacteristics(), null, 
op).getSparsity();
        }
 
-       private double estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+       private MatrixCharacteristics estimIntern(MatrixCharacteristics mc1, 
MatrixCharacteristics mc2, OpCode op) {
                switch (op) {
                        case MM:
-                               return estimInternMM(m1.getSparsity(), 
m2.getSparsity(),
-                                       m1.getNumRows(), m1.getNumColumns(), 
m2.getNumColumns());
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc2.getCols(),
+                                       
OptimizerUtils.getMatMultNnz(mc1.getSparsity(), mc2.getSparsity(),
+                                       mc1.getRows(), mc1.getCols(), 
mc2.getCols(), false));
                        case MULT:
-                               return m1.getSparsity() * m2.getSparsity();
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(), 
+                                       OptimizerUtils.getNnz(mc1.getRows(), 
mc1.getCols(),
+                                               mc1.getSparsity() * 
mc2.getSparsity()));
                        case PLUS:
-                               return m1.getSparsity() + m2.getSparsity() - 
m1.getSparsity() * m2.getSparsity();
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(), 
+                                       OptimizerUtils.getNnz(mc1.getRows(), 
mc1.getCols(), 
+                                               mc1.getSparsity() + 
mc2.getSparsity() - mc1.getSparsity() * mc2.getSparsity()));
                        case EQZERO:
-                               return 
OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumColumns(),
-                                       (long) m1.getNumRows() * 
m1.getNumColumns() - m1.getNonZeros());
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(),
+                                       (long) mc1.getRows() * mc1.getCols() - 
mc1.getNonZeros());
                        case DIAG:
-                               return (m1.getNumColumns() == 1) ?
-                                       
OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumRows(), m1.getNonZeros()) :
-                                       
OptimizerUtils.getSparsity(m1.getNumRows(), 1, Math.min(m1.getNumRows(), 
m1.getNonZeros()));
+                               return (mc1.getCols() == 1) ?
+                                       new 
MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) :
+                                       new 
MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), 
mc1.getNonZeros()));
                        // binary operations that preserve sparsity exactly
                        case CBIND:
-                               return 
OptimizerUtils.getSparsity(m1.getNumRows(),
-                                       m1.getNumColumns() + 
m1.getNumColumns(), m1.getNonZeros() + m2.getNonZeros());
+                               return new MatrixCharacteristics(mc1.getRows(), 
+                                       mc1.getCols() + mc2.getCols(), 
mc1.getNonZeros() + mc2.getNonZeros());
                        case RBIND:
-                               return 
OptimizerUtils.getSparsity(m1.getNumRows() + m2.getNumRows(),
-                                       m1.getNumColumns(), m1.getNonZeros() + 
m2.getNonZeros());
+                               return new MatrixCharacteristics(mc1.getRows() 
+ mc2.getRows(), 
+                                       mc1.getCols(), mc1.getNonZeros() + 
mc2.getNonZeros());
                        // unary operation that preserve sparsity exactly
                        case NEQZERO:
-                               return m1.getSparsity();
                        case TRANS:
-                               return m1.getSparsity();
                        case RESHAPE:
-                               return m1.getSparsity();
+                               return mc1;
                        default:
                                throw new NotImplementedException();
                }
        }
-
-       private double estimInternMM(double sp1, double sp2, long m, long k, 
long n) {
-               return OptimizerUtils.getMatMultSparsity(sp1, sp2, m, k, n, 
false);
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
index 736affe..e99b55d 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
 /**
@@ -34,32 +35,65 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 public class EstimatorBasicWorst extends SparsityEstimator
 {
        @Override
-       public double estim(MMNode root) {
-               //recursive sparsity evaluation of non-leaf nodes
-               double sp1 = !root.getLeft().isLeaf() ? estim(root.getLeft()) :
-                       
OptimizerUtils.getSparsity(root.getLeft().getMatrixCharacteristics());
-               double sp2 = !root.getRight().isLeaf() ? estim(root.getRight()) 
:
-                       
OptimizerUtils.getSparsity(root.getRight().getMatrixCharacteristics());
-               return estimIntern(sp1, sp2, root.getRows(), 
root.getLeft().getCols(), root.getCols());
+       public MatrixCharacteristics estim(MMNode root) {
+               MatrixCharacteristics mc1 = !root.getLeft().isLeaf() ?
+                       estim(root.getLeft()) : 
root.getLeft().getMatrixCharacteristics();
+               MatrixCharacteristics mc2 = !root.getRight().isLeaf() ?
+                       estim(root.getRight()) : 
root.getRight().getMatrixCharacteristics();
+               return root.setMatrixCharacteristics(
+                       estimIntern(mc1, mc2, root.getOp()));
        }
 
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2) {
-               return estimIntern(m1.getSparsity(), m2.getSparsity(),
-                       m1.getNumRows(), m1.getNumColumns(), 
m2.getNumColumns());
+               return estim(m1, m2, OpCode.MM);
        }
-       
+
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
-               throw new NotImplementedException();
+               return estimIntern(m1.getMatrixCharacteristics(), 
m2.getMatrixCharacteristics(), op).getSparsity();
        }
-       
+
        @Override
        public double estim(MatrixBlock m, OpCode op) {
-               throw new NotImplementedException();
+               return estimIntern(m.getMatrixCharacteristics(), null, 
op).getSparsity();
        }
 
-       private double estimIntern(double sp1, double sp2, long m, long k, long 
n) {
-               return OptimizerUtils.getMatMultSparsity(sp1, sp2, m, k, n, 
true);
+       private MatrixCharacteristics estimIntern(MatrixCharacteristics mc1, 
MatrixCharacteristics mc2, OpCode op) {
+               switch (op) {
+                       case MM:
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc2.getCols(),
+                                       
OptimizerUtils.getMatMultNnz(mc1.getSparsity(), mc2.getSparsity(),
+                                       mc1.getRows(), mc1.getCols(), 
mc2.getCols(), true));
+                       case MULT:
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(), 
+                                       OptimizerUtils.getNnz(mc1.getRows(), 
mc1.getCols(),
+                                               Math.min(mc1.getSparsity(), 
mc2.getSparsity())));
+                       case PLUS:
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(), 
+                                       OptimizerUtils.getNnz(mc1.getRows(), 
mc1.getCols(), 
+                                               Math.min(mc1.getSparsity() + 
mc2.getSparsity(), 1)));
+                       case EQZERO:
+                               return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(),
+                                       (long) mc1.getRows() * mc1.getCols() - 
mc1.getNonZeros());
+                       case DIAG:
+                               return (mc1.getCols() == 1) ?
+                                       new 
MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) :
+                                       new 
MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), 
mc1.getNonZeros()));
+                       // binary operations that preserve sparsity exactly
+                       case CBIND:
+                               return new MatrixCharacteristics(mc1.getRows(), 
+                                       mc1.getCols() + mc2.getCols(), 
mc1.getNonZeros() + mc2.getNonZeros());
+                       case RBIND:
+                               return new MatrixCharacteristics(mc1.getRows() 
+ mc2.getRows(), 
+                                       mc1.getCols(), mc1.getNonZeros() + 
mc2.getNonZeros());
+                       // unary operation that preserve sparsity exactly
+                       case NEQZERO:
+                       case TRANS:
+                       case RESHAPE:
+                               return mc1;
+                       default:
+                               throw new NotImplementedException();
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
index 38de751..a898147 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
@@ -25,6 +25,7 @@ import java.util.stream.IntStream;
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysml.hops.OptimizerUtils;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.DenseBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
@@ -42,7 +43,7 @@ import org.apache.sysml.runtime.matrix.data.SparseBlock;
  */
 public class EstimatorBitsetMM extends SparsityEstimator {
        @Override
-       public double estim(MMNode root) {
+       public MatrixCharacteristics estim(MMNode root) {
                // recursive density map computation of non-leaf nodes
                if (!root.getLeft().isLeaf())
                        estim(root.getLeft()); // obtain synopsis
@@ -56,7 +57,8 @@ public class EstimatorBitsetMM extends SparsityEstimator {
                // estimate output density map and sparsity via boolean matrix 
mult
                BitsetMatrix outMap = m1Map.matMult(m2Map);
                root.setSynopsis(outMap); // memoize boolean matrix
-               return OptimizerUtils.getSparsity(outMap.getNumRows(), 
outMap.getNumColumns(), outMap.getNonZeros());
+               return root.setMatrixCharacteristics(new MatrixCharacteristics(
+                       outMap.getNumRows(), outMap.getNumColumns(), 
outMap.getNonZeros()));
        }
 
        @Override
@@ -277,7 +279,7 @@ public class EstimatorBitsetMM extends SparsityEstimator {
                                c[ci+0] |= b[bi+0]; c[ci+1] |= b[bi+1];
                                c[ci+2] |= b[bi+2]; c[ci+3] |= b[bi+3];
                                c[ci+4] |= b[bi+4]; c[ci+5] |= b[bi+5];
-                               c[ci+6] |= b[bi+4]; c[ci+7] |= b[bi+7];
+                               c[ci+6] |= b[bi+6]; c[ci+7] |= b[bi+7];
                        }
                }
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
index c86ad21..1246978 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.DenseBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
@@ -52,7 +53,7 @@ public class EstimatorDensityMap extends SparsityEstimator
        }
        
        @Override
-       public double estim(MMNode root) {
+       public MatrixCharacteristics estim(MMNode root) {
                //recursive density map computation of non-leaf nodes
                if( !root.getLeft().isLeaf() )
                        estim(root.getLeft()); //obtain synopsis
@@ -67,8 +68,8 @@ public class EstimatorDensityMap extends SparsityEstimator
                MatrixBlock outMap = estimIntern(m1Map, m2Map,
                        false, root.getRows(), root.getLeft().getCols(), 
root.getCols());
                root.setSynopsis(outMap); //memoize density map
-               return OptimizerUtils.getSparsity( //aggregate output histogram
-                       root.getRows(), root.getCols(), (long)outMap.sum());
+               return root.setMatrixCharacteristics(new MatrixCharacteristics(
+                       root.getLeft().getRows(), root.getRight().getCols(), 
(long)outMap.sum()));
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
index b646970..e24886a 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
@@ -22,6 +22,7 @@ import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.math3.distribution.ExponentialDistribution;
 import org.apache.commons.math3.random.Well1024a;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.DenseBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.SparseBlock;
@@ -52,7 +53,7 @@ public class EstimatorLayeredGraph extends SparsityEstimator {
        }
        
        @Override
-       public double estim(MMNode root) {
+       public MatrixCharacteristics estim(MMNode root) {
                throw new NotImplementedException();
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
index 6b4c898..7f0047c 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
@@ -25,6 +25,7 @@ import java.util.stream.IntStream;
 
 import org.apache.directory.api.util.exception.NotImplementedException;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.DenseBlock;
 import org.apache.sysml.runtime.matrix.data.LibMatrixAgg;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -52,7 +53,7 @@ public class EstimatorMatrixHistogram extends 
SparsityEstimator
        }
        
        @Override
-       public double estim(MMNode root) {
+       public MatrixCharacteristics estim(MMNode root) {
                //recursive histogram computation of non-leaf nodes
                if( !root.getLeft().isLeaf() )
                        estim(root.getLeft()); //obtain synopsis
@@ -69,9 +70,10 @@ public class EstimatorMatrixHistogram extends 
SparsityEstimator
                double ret = estimIntern(h1, h2, OpCode.MM);
                
                //derive and memoize output histogram
-               root.setSynopsis(MatrixHistogram.deriveOutputHistogram(h1, h2, 
ret));
-               
-               return ret;
+               MatrixHistogram outMap = 
MatrixHistogram.deriveOutputHistogram(h1, h2, ret);
+               root.setSynopsis(outMap);
+               return root.setMatrixCharacteristics(new MatrixCharacteristics(
+                       outMap.getRows(), outMap.getCols(), 
outMap.getNonZeros()));
        }
        
        @Override 

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
index ad2f6b7..1df23dd 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
@@ -22,6 +22,7 @@ package org.apache.sysml.hops.estim;
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.DenseBlock;
 import org.apache.sysml.runtime.matrix.data.LibMatrixAgg;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -55,7 +56,7 @@ public class EstimatorSample extends SparsityEstimator
        }
        
        @Override
-       public double estim(MMNode root) {
+       public MatrixCharacteristics estim(MMNode root) {
                LOG.warn("Recursive estimates not supported by EstimatorSample, 
falling back to EstimatorBasicAvg.");
                return new EstimatorBasicAvg().estim(root);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/MMNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/MMNode.java 
b/src/main/java/org/apache/sysml/hops/estim/MMNode.java
index 55aee3d..542449a 100644
--- a/src/main/java/org/apache/sysml/hops/estim/MMNode.java
+++ b/src/main/java/org/apache/sysml/hops/estim/MMNode.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysml.hops.estim;
 
+import org.apache.sysml.hops.estim.SparsityEstimator.OpCode;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
@@ -33,20 +34,22 @@ public class MMNode
        private final MatrixBlock _data;
        private final MatrixCharacteristics _mc;
        private Object _synops = null;
+       private final OpCode _op;
        
        public MMNode(MatrixBlock in) {
                _m1 = null;
                _m2 = null;
                _data = in;
                _mc = in.getMatrixCharacteristics();
+               _op = null;
        }
        
-       public MMNode(MMNode left, MMNode right) {
+       public MMNode(MMNode left, MMNode right, OpCode op) {
                _m1 = left;
                _m2 = right;
                _data = null;
-               _mc = new MatrixCharacteristics(
-                       _m1.getRows(), _m2.getCols(), -1, -1);
+               _mc = new MatrixCharacteristics(-1, -1, -1, -1);
+               _op = op;
        }
        
        public int getRows() {
@@ -61,6 +64,10 @@ public class MMNode
                return _mc;
        }
        
+       public MatrixCharacteristics 
setMatrixCharacteristics(MatrixCharacteristics mc) {
+               return _mc.set(mc); //implicit copy
+       }
+       
        public MMNode getLeft() {
                return _m1;
        }
@@ -84,4 +91,8 @@ public class MMNode
        public Object getSynopsis() {
                return _synops;
        }
+       
+       public OpCode getOp() {
+               return _op;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java 
b/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
index 411d9cd..a37372f 100644
--- a/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
+++ b/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
 public abstract class SparsityEstimator 
@@ -46,7 +47,8 @@ public abstract class SparsityEstimator
         * @param root
         * @return
         */
-       public abstract double estim(MMNode root);
+       public abstract MatrixCharacteristics estim(MMNode root);
+       
        
        /**
         * Estimates the output sparsity for a single matrix multiplication.

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java 
b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java
index 6d816b6..183cd6e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java
@@ -81,6 +81,10 @@ public class MatrixCharacteristics implements Serializable
        
        public MatrixCharacteristics() {}
        
+       public MatrixCharacteristics(long nr, long nc, long nnz) {
+               set(nr, nc, -1, -1, nnz);
+       }
+       
        public MatrixCharacteristics(long nr, long nc, int bnr, int bnc) {
                set(nr, nc, bnr, bnc);
        }
@@ -93,29 +97,32 @@ public class MatrixCharacteristics implements Serializable
                set(that.numRows, that.numColumns, that.numRowsPerBlock, 
that.numColumnsPerBlock, that.nonZero);
        }
 
-       public void set(long nr, long nc, int bnr, int bnc) {
+       public MatrixCharacteristics set(long nr, long nc, int bnr, int bnc) {
                numRows = nr;
                numColumns = nc;
                numRowsPerBlock = bnr;
                numColumnsPerBlock = bnc;
+               return this;
        }
        
-       public void set(long nr, long nc, int bnr, int bnc, long nnz) {
+       public MatrixCharacteristics set(long nr, long nc, int bnr, int bnc, 
long nnz) {
                numRows = nr;
                numColumns = nc;
                numRowsPerBlock = bnr;
                numColumnsPerBlock = bnc;
                nonZero = nnz;
                ubNnz = false;
+               return this;
        }
        
-       public void set(MatrixCharacteristics that) {
+       public MatrixCharacteristics set(MatrixCharacteristics that) {
                numRows = that.numRows;
                numColumns = that.numColumns;
                numRowsPerBlock = that.numRowsPerBlock;
                numColumnsPerBlock = that.numColumnsPerBlock;
                nonZero = that.nonZero;
                ubNnz = that.ubNnz;
+               return this;
        }
        
        public long getRows(){
@@ -207,6 +214,10 @@ public class MatrixCharacteristics implements Serializable
                return nonZero;
        }
        
+       public double getSparsity() {
+               return OptimizerUtils.getSparsity(this);
+       }
+       
        public boolean dimsKnown() {
                return ( numRows >= 0 && numColumns >= 0 );
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
index 97e7fd3..0f4dad5 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java
@@ -19,18 +19,19 @@
 
 package org.apache.sysml.test.integration.functions.estim;
 
-import org.junit.Test;
 import org.apache.sysml.hops.estim.EstimatorBasicAvg;
 import org.apache.sysml.hops.estim.EstimatorBasicWorst;
 import org.apache.sysml.hops.estim.EstimatorBitsetMM;
 import org.apache.sysml.hops.estim.EstimatorDensityMap;
 import org.apache.sysml.hops.estim.EstimatorMatrixHistogram;
 import org.apache.sysml.hops.estim.MMNode;
+import org.apache.sysml.hops.estim.SparsityEstimator.OpCode;
 import org.apache.sysml.hops.estim.SparsityEstimator;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
 
 /**
  * This is a basic sanity check for all estimator, which need
@@ -135,8 +136,8 @@ public class SquaredProductChainTest extends 
AutomatedTestBase
                        new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                
                //compare estimated and real sparsity
-               double est = estim.estim(new MMNode(
-                       new MMNode(new MMNode(m1), new MMNode(m2)), new 
MMNode(m3)));
+               double est = estim.estim(new MMNode(new MMNode(new MMNode(m1), 
new MMNode(m2),
+                       OpCode.MM), new MMNode(m3), OpCode.MM)).getSparsity();
                TestUtils.compareScalars(est, m5.getSparsity(),
                        (estim instanceof EstimatorBitsetMM) ? eps3 : //exact
                        (estim instanceof EstimatorBasicWorst) ? eps1 : eps2);

Reply via email to