Repository: systemml
Updated Branches:
  refs/heads/master 04bc667f3 -> ff4dbb3ee


[SYSTEMML-2479] Extended bitset estimator for rbind, various cleanups

Closes #824.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ff4dbb3e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ff4dbb3e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ff4dbb3e

Branch: refs/heads/master
Commit: ff4dbb3ee893b2609fa8111717d71f1bbfd46fa2
Parents: 04bc667
Author: Johanna Sommer <[email protected]>
Authored: Thu Aug 9 19:12:28 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Aug 9 23:29:03 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/estim/EstimatorBasicAvg.java     | 13 +---
 .../sysml/hops/estim/EstimatorBasicWorst.java   | 13 +---
 .../sysml/hops/estim/EstimatorBitsetMM.java     | 73 ++++++++++++++++----
 .../sysml/hops/estim/EstimatorDensityMap.java   |  3 +
 .../sysml/hops/estim/EstimatorLayeredGraph.java |  2 +-
 .../hops/estim/EstimatorMatrixHistogram.java    |  6 +-
 .../sysml/hops/estim/SparsityEstimator.java     | 37 ++++++++++
 7 files changed, 108 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
index 6448f70..10ff0f7 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java
@@ -70,24 +70,13 @@ public class EstimatorBasicAvg extends SparsityEstimator
                                        OptimizerUtils.getNnz(mc1.getRows(), mc1.getCols(), 
                                                mc1.getSparsity() + mc2.getSparsity() - mc1.getSparsity() * mc2.getSparsity()));
                        case EQZERO:
-                               return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(),
-                                       (long) mc1.getRows() * mc1.getCols() - mc1.getNonZeros());
                        case DIAG:
-                               return (mc1.getCols() == 1) ?
-                                       new MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) :
-                                       new MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), mc1.getNonZeros()));
-                       // binary operations that preserve sparsity exactly
                        case CBIND:
-                               return new MatrixCharacteristics(mc1.getRows(), 
-                                       mc1.getCols() + mc2.getCols(), mc1.getNonZeros() + mc2.getNonZeros());
                        case RBIND:
-                               return new MatrixCharacteristics(mc1.getRows() + mc2.getRows(), 
-                                       mc1.getCols(), mc1.getNonZeros() + mc2.getNonZeros());
-                       // unary operation that preserve sparsity exactly
                        case NEQZERO:
                        case TRANS:
                        case RESHAPE:
-                               return mc1;
+                               return estimExactMetaData(mc1, mc2, op);
                        default:
                                throw new NotImplementedException();
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
index e99b55d..81877dd 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java
@@ -74,24 +74,13 @@ public class EstimatorBasicWorst extends SparsityEstimator
                                        OptimizerUtils.getNnz(mc1.getRows(), mc1.getCols(), 
                                                Math.min(mc1.getSparsity() + mc2.getSparsity(), 1)));
                        case EQZERO:
-                               return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(),
-                                       (long) mc1.getRows() * mc1.getCols() - mc1.getNonZeros());
                        case DIAG:
-                               return (mc1.getCols() == 1) ?
-                                       new MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) :
-                                       new MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), mc1.getNonZeros()));
-                       // binary operations that preserve sparsity exactly
                        case CBIND:
-                               return new MatrixCharacteristics(mc1.getRows(), 
-                                       mc1.getCols() + mc2.getCols(), mc1.getNonZeros() + mc2.getNonZeros());
                        case RBIND:
-                               return new MatrixCharacteristics(mc1.getRows() + mc2.getRows(), 
-                                       mc1.getCols(), mc1.getNonZeros() + mc2.getNonZeros());
-                       // unary operation that preserve sparsity exactly
                        case NEQZERO:
                        case TRANS:
                        case RESHAPE:
-                               return mc1;
+                               return estimExactMetaData(mc1, mc2, op);
                        default:
                                throw new NotImplementedException();
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
index a898147..0bb4e5d 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java
@@ -23,6 +23,7 @@ import java.util.BitSet;
 import java.util.stream.IntStream;
 
 import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -41,7 +42,8 @@ import org.apache.sysml.runtime.matrix.data.SparseBlock;
  * Multiplication for Irregular Data. In IPDPS, pages 370–381, 2014.
  * 
  */
-public class EstimatorBitsetMM extends SparsityEstimator {
+public class EstimatorBitsetMM extends SparsityEstimator
+{
        @Override
        public MatrixCharacteristics estim(MMNode root) {
                // recursive density map computation of non-leaf nodes
@@ -53,32 +55,51 @@ public class EstimatorBitsetMM extends SparsityEstimator {
                        new BitsetMatrix1(root.getLeft().getData());
                BitsetMatrix m2Map = !root.getRight().isLeaf() ? (BitsetMatrix) root.getRight().getSynopsis() :
                        new BitsetMatrix1(root.getRight().getData());
-
-               // estimate output density map and sparsity via boolean matrix mult
-               BitsetMatrix outMap = m1Map.matMult(m2Map);
-               root.setSynopsis(outMap); // memoize boolean matrix
+               BitsetMatrix outMap = estimInternal(m1Map, m2Map, root.getOp());
+               root.setSynopsis(outMap); // memorize boolean matrix
                return root.setMatrixCharacteristics(new MatrixCharacteristics(
                        outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros()));
        }
 
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2) {
-               BitsetMatrix m1Map = new BitsetMatrix1(m1);
-               BitsetMatrix m2Map = (m1 == m2) ? //self product
-                       m1Map : new BitsetMatrix1(m2);
-               BitsetMatrix outMap = m1Map.matMult(m2Map);
-               return OptimizerUtils.getSparsity( // aggregate output histogram
-                               outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros());
+               return estim(m1, m2, OpCode.MM);
        }
        
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
-               throw new NotImplementedException();
+               if( isExactMetadataOp(op) )
+                       return estimExactMetaData(m1.getMatrixCharacteristics(),
+                               m2.getMatrixCharacteristics(), op).getSparsity();
+               BitsetMatrix m1Map = new BitsetMatrix1(m1);
+               BitsetMatrix m2Map = (m1 == m2) ? //self product
+                       m1Map : new BitsetMatrix1(m2);
+               BitsetMatrix outMap = estimInternal(m1Map, m2Map, op);
+               return OptimizerUtils.getSparsity(outMap.getNumRows(),
+                       outMap.getNumColumns(), outMap.getNonZeros());
        }
        
        @Override
        public double estim(MatrixBlock m, OpCode op) {
-               throw new NotImplementedException();
+               return estim(m, null, op);
+       }
+       
+       private BitsetMatrix estimInternal(BitsetMatrix m1Map, BitsetMatrix m2Map, OpCode op) {
+               switch(op) {
+                       case MM:      return m1Map.matMult(m2Map);
+                       case RBIND:   return m1Map.rbind(m2Map);
+                       //TODO implement all as bitset operations in both BitsetMatrix1 and BitsetMatrix2
+                       case MULT:
+                       case PLUS:
+                       case CBIND:
+                       case TRANS:
+                       case NEQZERO:
+                       case EQZERO:
+                       case DIAG:
+                       case RESHAPE:
+                       default:
+                               throw new NotImplementedException();
+               }
        }
 
        private abstract static class BitsetMatrix {
@@ -143,6 +164,8 @@ public class EstimatorBitsetMM extends SparsityEstimator {
                protected abstract void buildIntern(MatrixBlock in, int rl, int ru);
                
                protected abstract long matMultIntern(BitsetMatrix bsb, BitsetMatrix bsc, int rl, int ru);
+               
+               protected abstract BitsetMatrix rbind(BitsetMatrix bsb);
        }
        
        /**
@@ -240,6 +263,18 @@ public class EstimatorBitsetMM extends SparsityEstimator {
                        return lnnz;
                }
                
+               @Override 
+               public BitsetMatrix rbind(BitsetMatrix bsb) {
+                       if( !(bsb instanceof BitsetMatrix1) )
+                               throw new HopsException("Incompatible bitset types: "
+                                       + getClass().getSimpleName()+" and "+bsb.getClass().getSimpleName());
+                       BitsetMatrix1 b = (BitsetMatrix1) bsb;
+                       BitsetMatrix1 ret = new BitsetMatrix1(getNumRows()+bsb.getNumRows(), getNumColumns());
+                       System.arraycopy(_data, 0, ret._data, 0, _rlen*_rowLen);
+                       System.arraycopy(b._data, 0, ret._data, _rlen*_rowLen, b._rlen*_rowLen);
+                       return ret;
+               }
+               
                private void set(int r, int c) {
                        int off = r * _rowLen;
                        int wordIndex = wordIndex(c); //see BitSet.java
@@ -354,5 +389,17 @@ public class EstimatorBitsetMM extends SparsityEstimator {
                        }
                        return lnnz;
                }
+               
+               @Override 
+               public BitsetMatrix rbind(BitsetMatrix bsb) {
+                       if( !(bsb instanceof BitsetMatrix2) )
+                               throw new HopsException("Incompatible bitset types: "
+                                       + getClass().getSimpleName()+" and "+bsb.getClass().getSimpleName());
+                       BitsetMatrix2 b = (BitsetMatrix2) bsb;
+                       BitsetMatrix2 ret = new BitsetMatrix2(getNumRows()+bsb.getNumRows(), getNumColumns());
+                       System.arraycopy(_data, 0, ret._data, 0, _rlen); //shallow copy
+                       System.arraycopy(b._data, 0, ret._data, _rlen, b._rlen); //shallow copy
+                       return ret;
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
index 66c5826..b6fca0f 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java
@@ -80,6 +80,9 @@ public class EstimatorDensityMap extends SparsityEstimator
        
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+               if( isExactMetadataOp(op) )
+                       return estimExactMetaData(m1.getMatrixCharacteristics(),
+                               m2.getMatrixCharacteristics(), op).getSparsity();
                DensityMap m1Map = new DensityMap(m1, _b);
                DensityMap m2Map = (m1 == m2) ? //self product
                        m1Map : new DensityMap(m2, _b);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
index e24886a..d103359 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java
@@ -68,7 +68,7 @@ public class EstimatorLayeredGraph extends SparsityEstimator {
        }
        
        @Override
-       public double estim(MatrixBlock m1, MatrixBlock m2){
+       public double estim(MatrixBlock m1, MatrixBlock m2) {
                LayeredGraph graph = new LayeredGraph(m1, m2, _rounds);
                return OptimizerUtils.getSparsity(m1.getNumRows(),
                        m2.getNumColumns(), graph.estimateNnz());

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
index a299c45..26b3df0 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
@@ -68,7 +68,6 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
                
                //estimate output sparsity based on input histograms
                double ret = estimIntern(h1, h2, root.getOp());
-
                MatrixHistogram outMap = MatrixHistogram.deriveOutputHistogram(h1, h2, ret, root.getOp());
                root.setSynopsis(outMap);
                return root.setMatrixCharacteristics(new MatrixCharacteristics(
@@ -83,6 +82,9 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
        
        @Override
        public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+               if( isExactMetadataOp(op) )
+                       return estimExactMetaData(m1.getMatrixCharacteristics(),
+                               m2.getMatrixCharacteristics(), op).getSparsity();
                MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts);
                MatrixHistogram h2 = (m1 == m2) ? //self product
                        h1 : new MatrixHistogram(m2, _useExcepts);
@@ -91,6 +93,8 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
        
        @Override
        public double estim(MatrixBlock m1, OpCode op) {
+               if( isExactMetadataOp(op) )
+                       return estimExactMetaData(m1.getMatrixCharacteristics(), null, op).getSparsity();
                MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts);
                return estimIntern(h1, null, op);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff4dbb3e/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java b/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
index a37372f..2941959 100644
--- a/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
+++ b/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java
@@ -19,8 +19,10 @@
 
 package org.apache.sysml.hops.estim;
 
+import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
@@ -33,6 +35,11 @@ public abstract class SparsityEstimator
        public static boolean MULTI_THREADED_ESTIM = false;
        public static final int MIN_PAR_THRESHOLD = 10 * 1024;
        
+       private static OpCode[] EXACT_META_DATA_OPS = new OpCode[] {
+               OpCode.EQZERO, OpCode.NEQZERO, OpCode.CBIND,
+               OpCode.RBIND, OpCode.TRANS, OpCode.DIAG, OpCode.RESHAPE
+       };
+       
        public static enum OpCode {
                MM, 
                MULT, PLUS, EQZERO, NEQZERO,
@@ -77,4 +84,34 @@ public abstract class SparsityEstimator
         * @return sparsity
         */
        public abstract double estim(MatrixBlock m, OpCode op);
+       
+       protected boolean isExactMetadataOp(OpCode op) {
+               return ArrayUtils.contains(EXACT_META_DATA_OPS, op);
+       }
+       
+       protected MatrixCharacteristics estimExactMetaData(MatrixCharacteristics mc1, MatrixCharacteristics mc2, OpCode op) {
+               switch( op ) {
+                       case EQZERO:
+                               return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(),
+                                       (long) mc1.getRows() * mc1.getCols() - mc1.getNonZeros());
+                       case DIAG:
+                               return (mc1.getCols() == 1) ?
+                                       new MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) :
+                                       new MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), mc1.getNonZeros()));
+                       // binary operations that preserve sparsity exactly
+                       case CBIND:
+                               return new MatrixCharacteristics(mc1.getRows(), 
+                                       mc1.getCols() + mc2.getCols(), mc1.getNonZeros() + mc2.getNonZeros());
+                       case RBIND:
+                               return new MatrixCharacteristics(mc1.getRows() + mc2.getRows(), 
+                                       mc1.getCols(), mc1.getNonZeros() + mc2.getNonZeros());
+                       // unary operation that preserve sparsity exactly
+                       case NEQZERO:
+                       case TRANS:
+                       case RESHAPE:
+                               return mc1;
+                       default:
+                               throw new HopsException("Opcode is not an exact meta data operation: "+op.name());
+               }
+       }
 }

Reply via email to