Repository: systemml Updated Branches: refs/heads/master bc7b4961a -> e0187028e
[SYSTEMML-2479] Extended MNC sparsity estimator for other operations Closes #820. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e0187028 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e0187028 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e0187028 Branch: refs/heads/master Commit: e0187028e43b5fbe884e795e2d0742280634ffa7 Parents: bc7b496 Author: Johanna Sommer <joha...@mail-sommer.com> Authored: Tue Aug 7 20:44:09 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Tue Aug 7 20:44:09 2018 -0700 ---------------------------------------------------------------------- .../hops/estim/EstimatorMatrixHistogram.java | 92 ++++++++++++++++++-- 1 file changed, 84 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e0187028/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java index 7f0047c..637c7f3 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java @@ -19,10 +19,10 @@ package org.apache.sysml.hops.estim; -import java.util.Arrays; import java.util.Random; import java.util.stream.IntStream; +import org.apache.commons.lang.ArrayUtils; import org.apache.directory.api.util.exception.NotImplementedException; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -67,13 +67,13 @@ public class EstimatorMatrixHistogram extends SparsityEstimator new MatrixHistogram(root.getRight().getData(), _useExcepts); //estimate output sparsity based on input histograms - double ret = estimIntern(h1, h2, OpCode.MM); - - //derive and memoize output histogram - MatrixHistogram outMap = MatrixHistogram.deriveOutputHistogram(h1, h2, ret); + double ret = estimIntern(h1, h2, root.getOp()); + + MatrixHistogram outMap = MatrixHistogram.deriveOutputHistogram(h1, h2, ret, root.getOp()); root.setSynopsis(outMap); return root.setMatrixCharacteristics(new MatrixCharacteristics( outMap.getRows(), outMap.getCols(), outMap.getNonZeros())); + } @Override @@ -304,14 +304,27 @@ public class EstimatorMatrixHistogram extends SparsityEstimator IntStream.range(0, getRows()).mapToLong(i-> cNnz[i]).sum(); } - public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) { + public static MatrixHistogram deriveOutputHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut, OpCode op) { + switch(op) { + case MM: return deriveMMHistogram(h1, h2, spOut); + case MULT: return deriveMultHistogram(h1, h2); + case PLUS: return derivePlusHistogram(h1, h2); + case RBIND: return deriveRbindHistogram(h1, h2); + case CBIND: return deriveCbindHistogram(h1, h2); + //TODO add missing unary operations + default: + throw new NotImplementedException(); + } + } + + private static MatrixHistogram deriveMMHistogram(MatrixHistogram h1, MatrixHistogram h2, double spOut) { //exact propagation if lhs or rhs full diag if( h1.fullDiag ) return h2; if( h2.fullDiag ) return h1; //get input/output nnz for scaling - long nnz1 = Arrays.stream(h1.rNnz).sum(); - long nnz2 = Arrays.stream(h2.cNnz).sum(); + long nnz1 = h1.getNonZeros(); + long nnz2 = h2.getNonZeros(); double nnzOut = spOut * h1.getRows() * h2.getCols(); //propagate h1.r and h2.c to output via simple scaling @@ -333,6 +346,69 @@ public class EstimatorMatrixHistogram extends SparsityEstimator return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); } + private static MatrixHistogram deriveMultHistogram(MatrixHistogram h1, MatrixHistogram h2) { + final long N1 = h1.getNonZeros(); + final long N2 = h2.getNonZeros(); + final long scaler = IntStream.range(0, h1.getCols()) + .mapToLong(j -> (long)h1.cNnz[j] * h2.cNnz[j]).sum(); + final long scalec = IntStream.range(0, h1.getRows()) + .mapToLong(j -> (long)h1.rNnz[j] * h2.rNnz[j]).sum(); + int rMaxNnz = 0, cMaxNnz = 0; + Random rn = new Random(); + int[] rNnz = new int[h1.getRows()]; + for(int i=0; i<h1.getRows(); i++) { + rNnz[i] = probRound(h1.rNnz[i] * h2.rNnz[i] * scaler / N1 / N2, rn); + rMaxNnz = Math.max(rMaxNnz, rNnz[i]); + } + int[] cNnz = new int[h1.getCols()]; + for(int i=0; i<h1.getCols(); i++) { + cNnz[i] = probRound(h1.cNnz[i] * h2.cNnz[i] * scalec / N1 / N2, rn); + cMaxNnz = Math.max(cMaxNnz, cNnz[i]); + } + return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); + } + + private static MatrixHistogram derivePlusHistogram(MatrixHistogram h1, MatrixHistogram h2) { + double msize = (double)h1.getRows()*h1.getCols(); + int rMaxNnz = 0, cMaxNnz = 0; + Random rn = new Random(); + int[] rNnz = new int[h1.getRows()]; + for(int i=0; i<h1.getRows(); i++) { + rNnz[i] = probRound(h1.rNnz[i]/msize + h2.rNnz[i]/msize - h1.rNnz[i]/msize * h2.rNnz[i]/msize, rn); + rMaxNnz = Math.max(rMaxNnz, rNnz[i]); + } + int[] cNnz = new int[h1.getCols()]; + for(int i=0; i<h1.getCols(); i++) { + cNnz[i] = probRound(h1.cNnz[i]/msize + h2.cNnz[i]/msize - h1.cNnz[i]/msize * h2.cNnz[i]/msize, rn); + cMaxNnz = Math.max(cMaxNnz, cNnz[i]); + } + return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); + } + + private static MatrixHistogram deriveRbindHistogram(MatrixHistogram h1, MatrixHistogram h2) { + int[] rNnz = ArrayUtils.addAll(h1.rNnz, h2.rNnz); + int rMaxNnz = Math.max(h1.rMaxNnz, h2.rMaxNnz); + int[] cNnz = new int[h1.getCols()]; + int cMaxNnz = 0; + for(int i=0; i<h1.getCols(); i++) { + cNnz[i] = h1.cNnz[i] + h2.cNnz[i]; + cMaxNnz = Math.max(cMaxNnz, cNnz[i]); + } + return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); + } + + private static MatrixHistogram deriveCbindHistogram(MatrixHistogram h1, MatrixHistogram h2) { + int[] rNnz = new int[h1.getRows()]; + int rMaxNnz = 0; + for(int i=0; i<h1.getRows(); i++) { + rNnz[i] = h1.rNnz[i] + h2.rNnz[i]; + rMaxNnz = Math.max(rMaxNnz, rNnz[i]); + } + int[] cNnz = ArrayUtils.addAll(h1.cNnz, h2.cNnz); + int cMaxNnz = Math.max(h1.cMaxNnz, h2.cMaxNnz); + return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); + } + private static int probRound(double inNnz, Random rand) { double temp = Math.floor(inNnz); double f = inNnz - temp; //non-int fraction [0,1)