Repository: systemml
Updated Branches:
  refs/heads/master cca6356f8 -> 0a957e4c9


[SYSTEMML-2468] Improved MNC estimator (avoid final sketch propagation)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/569806dc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/569806dc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/569806dc

Branch: refs/heads/master
Commit: 569806dcdf3c37bff59ad052884cf5f9af9bd598
Parents: cca6356
Author: Matthias Boehm <[email protected]>
Authored: Sun Oct 21 18:43:32 2018 +0200
Committer: Matthias Boehm <[email protected]>
Committed: Sun Oct 21 18:43:32 2018 +0200

----------------------------------------------------------------------
 .../hops/estim/EstimatorMatrixHistogram.java    | 61 +++++++++++++++++---
 1 file changed, 52 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/569806dc/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
index 83d918a..34f9cac 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
@@ -55,11 +55,15 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
        
        @Override
        public MatrixCharacteristics estim(MMNode root) {
+               return estim(root, true); //top-level call: derive output characteristics directly, skipping final sketch propagation
+       }
+       
+       private MatrixCharacteristics estim(MMNode root, boolean topLevel) {
                //recursive histogram computation of non-leaf nodes
                if( !root.getLeft().isLeaf() )
-                       estim(root.getLeft()); //obtain synopsis
+                       estim(root.getLeft(), false); //obtain synopsis
                if( root.getRight()!=null && !root.getRight().isLeaf() )
-                       estim(root.getRight()); //obtain synopsis
+                       estim(root.getRight(), false); //obtain synopsis
                MatrixHistogram h1 = !root.getLeft().isLeaf() ?
                        (MatrixHistogram)root.getLeft().getSynopsis() :
                        new MatrixHistogram(root.getLeft().getData(), 
_useExcepts);
@@ -69,6 +73,12 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
                
                //estimate output sparsity based on input histograms
                double ret = estimIntern(h1, h2, root.getOp(), root.getMisc());
+               if( topLevel ) { //fast-path final result
+                       return MatrixHistogram.deriveOutputCharacteristics(
+                               h1, h2, ret, root.getOp(), root.getMisc());
+               }
+               
+               //sketch propagation for intermediates other than final result
                MatrixHistogram outMap = MatrixHistogram
                        .deriveOutputHistogram(h1, h2, ret, root.getOp(), 
root.getMisc());
                root.setSynopsis(outMap);
@@ -183,13 +193,15 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
                        nnz = (long)(spOut * mnOut);
                }
                
-               //exploit upper bound on nnz based on non-empty rows/cols
-               nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ?
-                       Math.min((long)h1.rNonEmpty * h2.cNonEmpty, nnz) : nnz;
-               
-               //exploit lower bound on nnz based on half-full rows/cols
-               nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ?
-                       Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : nnz;
+               if( _useExcepts ) {
+                       //exploit upper bound on nnz based on non-empty 
rows/cols
+                       nnz = (h1.rNonEmpty >= 0 && h2.cNonEmpty >= 0) ?
+                               Math.min((long)h1.rNonEmpty * h2.cNonEmpty, 
nnz) : nnz;
+                       
+                       //exploit lower bound on nnz based on half-full 
rows/cols
+                       nnz = (h1.rNdiv2 >= 0 && h2.cNdiv2 >= 0) ?
+                               Math.max((long)h1.rNdiv2 * h2.cNdiv2, nnz) : 
nnz;
+               }
                
                //compute final sparsity
                return OptimizerUtils.getSparsity(
@@ -343,6 +355,37 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
                        }
                }
                
+               public static MatrixCharacteristics deriveOutputCharacteristics(MatrixHistogram h1, MatrixHistogram h2, double spOut, OpCode op, long[] misc) {
+                       switch(op) { //output dims per op; nnz derived from the estimated output sparsity spOut
+                               case MM: //rows of h1, cols of h2
+                                       return new MatrixCharacteristics(h1.getRows(), h2.getCols(),
+                                               OptimizerUtils.getNnz(h1.getRows(), h2.getCols(), spOut));
+                               case MULT: //this op and the three below keep the dims of h1
+                               case PLUS:
+                               case NEQZERO:
+                               case EQZERO:
+                                       return new MatrixCharacteristics(h1.getRows(), h1.getCols(),
+                                               OptimizerUtils.getNnz(h1.getRows(), h1.getCols(), spOut));
+                               case RBIND: //row-wise append: output rows = rows of h1 + rows of h2
+                                       return new MatrixCharacteristics(h1.getRows()+h2.getRows(), h1.getCols(),
+                                               OptimizerUtils.getNnz(h1.getRows()+h2.getRows(), h1.getCols(), spOut));
+                               case CBIND: //column-wise append: output cols = cols of h1 + cols of h2
+                                       return new MatrixCharacteristics(h1.getRows(), h1.getCols()+h2.getCols(),
+                                               OptimizerUtils.getNnz(h1.getRows(), h1.getCols()+h2.getCols(), spOut));
+                               case DIAG: //column vector -> rows x rows matrix; otherwise -> rows x 1 vector
+                                       int ncol = h1.getCols()==1 ? h1.getRows() : 1;
+                                       return new MatrixCharacteristics(h1.getRows(), ncol,
+                                               OptimizerUtils.getNnz(h1.getRows(), ncol, spOut));
+                               case TRANS: //swap dims and carry over the nnz of the input
+                                       return new MatrixCharacteristics(h1.getCols(), h1.getRows(), h1.getNonZeros());
+                               case RESHAPE: //misc[0] and misc[1] give the target rows and cols
+                                       return new MatrixCharacteristics((int)misc[0], (int)misc[1],
+                                               OptimizerUtils.getNnz((int)misc[0], (int)misc[1], spOut));
+                               default:
+                                       throw new NotImplementedException();
+                       }
+               }
+               
                private static MatrixHistogram 
deriveMMHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) {
                        //exact propagation if lhs or rhs full diag
                        if( h1.fullDiag ) return h2;

Reply via email to