[SYSTEMML-2292,2296] Matrix histograms w/ new exception vectors This patch improves the general case of sparsity estimation with matrix histograms where so far we used average-case estimates. We now maintain dedicated exception vectors which enable a hybrid estimator of exact and approximate fractions. Overall this significantly improves accuracy for scenarios where a substantial number of rows or columns have <=1 non-zeros (which allow for exact sparsity inference). For an m x n input matrix, the size of the synopsis remains O(max(n, m)), i.e., linear in the size of the largest dimension.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3d61fddd Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3d61fddd Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3d61fddd Branch: refs/heads/master Commit: 3d61fddd9a3e741e81067c92e41e36567786031f Parents: 53fa046 Author: Matthias Boehm <[email protected]> Authored: Sat May 5 23:07:04 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat May 5 23:07:04 2018 -0700 ---------------------------------------------------------------------- .../hops/estim/EstimatorMatrixHistogram.java | 89 ++++++++++++++++++-- .../functions/estim/OuterProductTest.java | 14 ++- .../estim/SquaredProductChainTest.java | 14 ++- .../functions/estim/SquaredProductTest.java | 14 ++- 4 files changed, 116 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/3d61fddd/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java index 1467130..7fc4bb0 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java @@ -34,6 +34,19 @@ import org.apache.sysml.runtime.matrix.data.SparseBlock; */ public class EstimatorMatrixHistogram extends SparsityEstimator { + //internal configurations + private static final boolean DEFAULT_USE_EXCEPTS = true; + + private final boolean _useExcepts; + + public EstimatorMatrixHistogram() { + this(DEFAULT_USE_EXCEPTS); + } + + public EstimatorMatrixHistogram(boolean useExcepts) { + _useExcepts = useExcepts; + } + @Override public double estim(MMNode root) { //recursive 
histogram computation of non-leaf nodes @@ -42,9 +55,11 @@ public class EstimatorMatrixHistogram extends SparsityEstimator if( !root.getRight().isLeaf() ) estim(root.getLeft()); //obtain synopsis MatrixHistogram h1 = !root.getLeft().isLeaf() ? - (MatrixHistogram)root.getLeft().getSynopsis() : new MatrixHistogram(root.getLeft().getData()); + (MatrixHistogram)root.getLeft().getSynopsis() : + new MatrixHistogram(root.getLeft().getData(), _useExcepts); MatrixHistogram h2 = !root.getRight().isLeaf() ? - (MatrixHistogram)root.getRight().getSynopsis() : new MatrixHistogram(root.getRight().getData()); + (MatrixHistogram)root.getRight().getSynopsis() : + new MatrixHistogram(root.getRight().getData(), _useExcepts); //estimate output sparsity based on input histograms double ret = estimIntern(h1, h2); @@ -57,8 +72,8 @@ public class EstimatorMatrixHistogram extends SparsityEstimator @Override public double estim(MatrixBlock m1, MatrixBlock m2) { - MatrixHistogram h1 = new MatrixHistogram(m1); - MatrixHistogram h2 = new MatrixHistogram(m2); + MatrixHistogram h1 = new MatrixHistogram(m1, _useExcepts); + MatrixHistogram h2 = new MatrixHistogram(m2, _useExcepts); return estimIntern(h1, h2); } @@ -77,6 +92,21 @@ public class EstimatorMatrixHistogram extends SparsityEstimator for( int j=0; j<h1.getCols(); j++ ) nnz += h1.cNnz[j] * h2.rNnz[j]; } + //special case, with hybrid exact and approximate output + else if(h1.cNnz1e!=null && h2.rNnz1e != null) { + int mnOut = h1.getRows()*h2.getCols(); + double spOutRest = 0; + for( int j=0; j<h1.getCols(); j++ ) { + //exact fractions, w/o double counting + nnz += h1.cNnz1e[j] * h2.rNnz[j]; + nnz += (h1.cNnz[j]-h1.cNnz1e[j]) * h2.rNnz1e[j]; + //approximate fraction, w/o double counting + double lsp = (double)(h1.cNnz[j]-h1.cNnz1e[j]) + * (h2.rNnz[j]-h2.rNnz1e[j]) / mnOut; + spOutRest = spOutRest + lsp - spOutRest*lsp; + } + nnz += (long)(spOutRest * mnOut); + } //general case with approximate output else { int mnOut = 
h1.getRows()*h2.getCols(); @@ -94,17 +124,21 @@ public class EstimatorMatrixHistogram extends SparsityEstimator } private static class MatrixHistogram { - private final int[] rNnz; - private final int[] cNnz; + private final int[] rNnz; //row nnz counts + private int[] rNnz1e = null; //row nnz counts for cols w/ <= 1 non-zeros + private final int[] cNnz; //column nnz counts + private int[] cNnz1e = null; //column nnz counts for rows w/ <= 1 non-zeros private int rMaxNnz = 0; private int cMaxNnz = 0; - public MatrixHistogram(MatrixBlock in) { + public MatrixHistogram(MatrixBlock in, boolean useExcepts) { + //allocate basic synopsis rNnz = new int[in.getNumRows()]; cNnz = new int[in.getNumColumns()]; if( in.isEmptyBlock(false) ) return; + //compute basic synopsis details if( in.isInSparseFormat() ) { SparseBlock sblock = in.getSparseBlock(); for( int i=0; i<in.getNumRows(); i++ ) { @@ -132,11 +166,48 @@ public class EstimatorMatrixHistogram extends SparsityEstimator } } cMaxNnz = max(cNnz, 0, in.getNumColumns()); + + //compute exception details if necessary (optional) + if( useExcepts && (rMaxNnz > 1 || cMaxNnz > 1) ) { + rNnz1e = new int[in.getNumRows()]; + cNnz1e = new int[in.getNumColumns()]; + + if( in.isInSparseFormat() ) { + SparseBlock sblock = in.getSparseBlock(); + for( int i=0; i<in.getNumRows(); i++ ) { + if( sblock.isEmpty(i) ) continue; + int alen = sblock.size(i); + int apos = sblock.pos(i); + int[] aix = sblock.indexes(i); + for( int k=apos; k<apos+alen; k++ ) + rNnz1e[i] += cNnz[aix[k]] <= 1 ? 1 : 0; + if( alen <= 1 ) + for( int k=apos; k<apos+alen; k++ ) + cNnz1e[aix[k]]++; + } + } + else { + DenseBlock dblock = in.getDenseBlock(); + for( int i=0; i<in.getNumRows(); i++ ) { + double[] avals = dblock.values(i); + int aix = dblock.pos(i); + boolean rNnzlte1 = rNnz[i] <= 1; + for( int j=0; j<in.getNumColumns(); j++ ) { + if( avals[aix + j] != 0 ) { + rNnz1e[i] += cNnz[j] <= 1 ? 1 : 0; + cNnz1e[j] += rNnzlte1 ? 
1 : 0; + } + } + } + } + } } - public MatrixHistogram(int[] r, int[] c, int rmax, int cmax) { + public MatrixHistogram(int[] r, int[] r1e, int[] c, int[] c1e, int rmax, int cmax) { rNnz = r; + rNnz1e = r1e; cNnz = c; + cNnz1e = c1e; rMaxNnz = rmax; cMaxNnz = cmax; } @@ -170,7 +241,7 @@ public class EstimatorMatrixHistogram extends SparsityEstimator } //construct new histogram object - return new MatrixHistogram(rNnz, cNnz, rMaxNnz, cMaxNnz); + return new MatrixHistogram(rNnz, null, cNnz, null, rMaxNnz, cMaxNnz); } private static int max(int[] a, int ai, int alen) { http://git-wip-us.apache.org/repos/asf/systemml/blob/3d61fddd/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java b/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java index f58a0c4..f45ca70 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java @@ -100,12 +100,22 @@ public class OuterProductTest extends AutomatedTestBase @Test public void testMatrixHistogramCase1() { - runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, n, case1); + runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, k, n, case1); } @Test public void testMatrixHistogramCase2() { - runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, n, case2); + runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, k, n, case2); + } + + @Test + public void testMatrixHistogramExceptCase1() { + runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, case1); + } + + @Test + public void testMatrixHistogramExceptCase2() { + runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, case2); } private void 
runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp) { http://git-wip-us.apache.org/repos/asf/systemml/blob/3d61fddd/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java index 82e45ed..97e7fd3 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java @@ -107,12 +107,22 @@ public class SquaredProductChainTest extends AutomatedTestBase @Test public void testMatrixHistogramCase1() { - runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, n, n2, case1); + runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, k, n, n2, case1); } @Test public void testMatrixHistogramCase2() { - runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, n, n2, case2); + runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, k, n, n2, case2); + } + + @Test + public void testMatrixHistogramExceptCase1() { + runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, n2, case1); + } + + @Test + public void testMatrixHistogramExceptCase2() { + runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, n2, case2); } private void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, int n2, double[] sp) { http://git-wip-us.apache.org/repos/asf/systemml/blob/3d61fddd/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java index 4cbd06c..204842c 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java @@ -105,12 +105,22 @@ public class SquaredProductTest extends AutomatedTestBase @Test public void testMatrixHistogramCase1() { - runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, n, case1); + runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, k, n, case1); } @Test public void testMatrixHistogramCase2() { - runSparsityEstimateTest(new EstimatorMatrixHistogram(), m, k, n, case2); + runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, k, n, case2); + } + + @Test + public void testMatrixHistogramExceptCase1() { + runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, case1); + } + + @Test + public void testMatrixHistogramExceptCase2() { + runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, k, n, case2); } private void runSparsityEstimateTest(SparsityEstimator estim, int m, int k, int n, double[] sp) {
