Repository: systemml
Updated Branches:
  refs/heads/master e3a51f72a -> bc7b4961a
[SYSTEMML-2479] Finalized AVG and WC sparsity estimators, API cleanups Closes #818. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/bc7b4961 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/bc7b4961 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/bc7b4961 Branch: refs/heads/master Commit: bc7b4961a20c47c8064bc9ce7d30ba071b44a748 Parents: e3a51f7 Author: Johanna Sommer <[email protected]> Authored: Mon Aug 6 16:05:26 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Aug 6 21:34:06 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/OptimizerUtils.java | 13 +++- .../sysml/hops/estim/EstimatorBasicAvg.java | 66 ++++++++++---------- .../sysml/hops/estim/EstimatorBasicWorst.java | 64 ++++++++++++++----- .../sysml/hops/estim/EstimatorBitsetMM.java | 8 ++- .../sysml/hops/estim/EstimatorDensityMap.java | 7 ++- .../sysml/hops/estim/EstimatorLayeredGraph.java | 3 +- .../hops/estim/EstimatorMatrixHistogram.java | 10 +-- .../sysml/hops/estim/EstimatorSample.java | 3 +- .../org/apache/sysml/hops/estim/MMNode.java | 17 ++++- .../sysml/hops/estim/SparsityEstimator.java | 4 +- .../runtime/matrix/MatrixCharacteristics.java | 17 ++++- .../estim/SquaredProductChainTest.java | 7 ++- 12 files changed, 146 insertions(+), 73 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index fb83df0..e78d809 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -85,7 +85,7 @@ public class OptimizerUtils * 
e.g., when input/output dimensions are unknown. The default is set to a large * value so that operations are scheduled on MR while avoiding overflows as well. */ - public static double DEFAULT_SIZE; + public static double DEFAULT_SIZE; public static final long DOUBLE_SIZE = 8; @@ -970,6 +970,10 @@ public class OptimizerUtils // Sparsity Estimates // //////////////////////// + public static long getMatMultNnz(double sp1, double sp2, long m, long k, long n, boolean worstcase) { + return getNnz( m, n, getMatMultSparsity(sp1, sp2, m, k, n, worstcase)); + } + /** * Estimates the result sparsity for Matrix Multiplication A %*% B. * @@ -981,8 +985,7 @@ public class OptimizerUtils * @param worstcase true if worst case * @return the sparsity */ - public static double getMatMultSparsity(double sp1, double sp2, long m, long k, long n, boolean worstcase) - { + public static double getMatMultSparsity(double sp1, double sp2, long m, long k, long n, boolean worstcase) { if( worstcase ){ double nnz1 = sp1 * m * k; double nnz2 = sp2 * k * n; @@ -1159,6 +1162,10 @@ public class OptimizerUtils } } + public static long getNnz(long dim1, long dim2, double sp) { + return (long) Math.round(sp * dim1 * dim2); + } + public static double getSparsity( MatrixCharacteristics mc ) { return getSparsity(mc.getRows(), mc.getCols(), mc.getNonZeros()); } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java index baa1fc4..6448f70 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicAvg.java @@ -21,75 +21,75 @@ package org.apache.sysml.hops.estim; import org.apache.commons.lang.NotImplementedException; import 
org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; /** * Basic average case estimator for matrix sparsity: * sp = 1 - Math.pow(1-sp1*sp2, k) */ -public class EstimatorBasicAvg extends SparsityEstimator { +public class EstimatorBasicAvg extends SparsityEstimator +{ @Override - public double estim(MMNode root) { - // recursive sparsity evaluation of non-leaf nodes - double sp1 = !root.getLeft().isLeaf() ? estim(root.getLeft()) : - OptimizerUtils.getSparsity(root.getLeft().getMatrixCharacteristics()); - double sp2 = !root.getRight().isLeaf() ? estim(root.getRight()) : - OptimizerUtils.getSparsity(root.getRight().getMatrixCharacteristics()); - return estimInternMM(sp1, sp2, root.getRows(), root.getLeft().getCols(), root.getCols()); + public MatrixCharacteristics estim(MMNode root) { + MatrixCharacteristics mc1 = !root.getLeft().isLeaf() ? + estim(root.getLeft()) : root.getLeft().getMatrixCharacteristics(); + MatrixCharacteristics mc2 = !root.getRight().isLeaf() ? 
+ estim(root.getRight()) : root.getRight().getMatrixCharacteristics(); + return root.setMatrixCharacteristics( + estimIntern(mc1, mc2, root.getOp())); } @Override public double estim(MatrixBlock m1, MatrixBlock m2) { - return estimInternMM(m1.getSparsity(), m2.getSparsity(), - m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns()); + return estim(m1, m2, OpCode.MM); } @Override public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) { - return estimIntern(m1, m2, op); + return estimIntern(m1.getMatrixCharacteristics(), m2.getMatrixCharacteristics(), op).getSparsity(); } @Override public double estim(MatrixBlock m, OpCode op) { - return estimIntern(m, null, op); + return estimIntern(m.getMatrixCharacteristics(), null, op).getSparsity(); } - private double estimIntern(MatrixBlock m1, MatrixBlock m2, OpCode op) { + private MatrixCharacteristics estimIntern(MatrixCharacteristics mc1, MatrixCharacteristics mc2, OpCode op) { switch (op) { case MM: - return estimInternMM(m1.getSparsity(), m2.getSparsity(), - m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns()); + return new MatrixCharacteristics(mc1.getRows(), mc2.getCols(), + OptimizerUtils.getMatMultNnz(mc1.getSparsity(), mc2.getSparsity(), + mc1.getRows(), mc1.getCols(), mc2.getCols(), false)); case MULT: - return m1.getSparsity() * m2.getSparsity(); + return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(), + OptimizerUtils.getNnz(mc1.getRows(), mc1.getCols(), + mc1.getSparsity() * mc2.getSparsity())); case PLUS: - return m1.getSparsity() + m2.getSparsity() - m1.getSparsity() * m2.getSparsity(); + return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(), + OptimizerUtils.getNnz(mc1.getRows(), mc1.getCols(), + mc1.getSparsity() + mc2.getSparsity() - mc1.getSparsity() * mc2.getSparsity())); case EQZERO: - return OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumColumns(), - (long) m1.getNumRows() * m1.getNumColumns() - m1.getNonZeros()); + return new MatrixCharacteristics(mc1.getRows(), 
mc1.getCols(), + (long) mc1.getRows() * mc1.getCols() - mc1.getNonZeros()); case DIAG: - return (m1.getNumColumns() == 1) ? - OptimizerUtils.getSparsity(m1.getNumRows(), m1.getNumRows(), m1.getNonZeros()) : - OptimizerUtils.getSparsity(m1.getNumRows(), 1, Math.min(m1.getNumRows(), m1.getNonZeros())); + return (mc1.getCols() == 1) ? + new MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) : + new MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), mc1.getNonZeros())); // binary operations that preserve sparsity exactly case CBIND: - return OptimizerUtils.getSparsity(m1.getNumRows(), - m1.getNumColumns() + m1.getNumColumns(), m1.getNonZeros() + m2.getNonZeros()); + return new MatrixCharacteristics(mc1.getRows(), + mc1.getCols() + mc2.getCols(), mc1.getNonZeros() + mc2.getNonZeros()); case RBIND: - return OptimizerUtils.getSparsity(m1.getNumRows() + m2.getNumRows(), - m1.getNumColumns(), m1.getNonZeros() + m2.getNonZeros()); + return new MatrixCharacteristics(mc1.getRows() + mc2.getRows(), + mc1.getCols(), mc1.getNonZeros() + mc2.getNonZeros()); // unary operation that preserve sparsity exactly case NEQZERO: - return m1.getSparsity(); case TRANS: - return m1.getSparsity(); case RESHAPE: - return m1.getSparsity(); + return mc1; default: throw new NotImplementedException(); } } - - private double estimInternMM(double sp1, double sp2, long m, long k, long n) { - return OptimizerUtils.getMatMultSparsity(sp1, sp2, m, k, n, false); - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java index 736affe..e99b55d 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java +++ 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorBasicWorst.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim; import org.apache.commons.lang.NotImplementedException; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; /** @@ -34,32 +35,65 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; public class EstimatorBasicWorst extends SparsityEstimator { @Override - public double estim(MMNode root) { - //recursive sparsity evaluation of non-leaf nodes - double sp1 = !root.getLeft().isLeaf() ? estim(root.getLeft()) : - OptimizerUtils.getSparsity(root.getLeft().getMatrixCharacteristics()); - double sp2 = !root.getRight().isLeaf() ? estim(root.getRight()) : - OptimizerUtils.getSparsity(root.getRight().getMatrixCharacteristics()); - return estimIntern(sp1, sp2, root.getRows(), root.getLeft().getCols(), root.getCols()); + public MatrixCharacteristics estim(MMNode root) { + MatrixCharacteristics mc1 = !root.getLeft().isLeaf() ? + estim(root.getLeft()) : root.getLeft().getMatrixCharacteristics(); + MatrixCharacteristics mc2 = !root.getRight().isLeaf() ? 
+ estim(root.getRight()) : root.getRight().getMatrixCharacteristics(); + return root.setMatrixCharacteristics( + estimIntern(mc1, mc2, root.getOp())); } @Override public double estim(MatrixBlock m1, MatrixBlock m2) { - return estimIntern(m1.getSparsity(), m2.getSparsity(), - m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns()); + return estim(m1, m2, OpCode.MM); } - + @Override public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) { - throw new NotImplementedException(); + return estimIntern(m1.getMatrixCharacteristics(), m2.getMatrixCharacteristics(), op).getSparsity(); } - + @Override public double estim(MatrixBlock m, OpCode op) { - throw new NotImplementedException(); + return estimIntern(m.getMatrixCharacteristics(), null, op).getSparsity(); } - private double estimIntern(double sp1, double sp2, long m, long k, long n) { - return OptimizerUtils.getMatMultSparsity(sp1, sp2, m, k, n, true); + private MatrixCharacteristics estimIntern(MatrixCharacteristics mc1, MatrixCharacteristics mc2, OpCode op) { + switch (op) { + case MM: + return new MatrixCharacteristics(mc1.getRows(), mc2.getCols(), + OptimizerUtils.getMatMultNnz(mc1.getSparsity(), mc2.getSparsity(), + mc1.getRows(), mc1.getCols(), mc2.getCols(), true)); + case MULT: + return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(), + OptimizerUtils.getNnz(mc1.getRows(), mc1.getCols(), + Math.min(mc1.getSparsity(), mc2.getSparsity()))); + case PLUS: + return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(), + OptimizerUtils.getNnz(mc1.getRows(), mc1.getCols(), + Math.min(mc1.getSparsity() + mc2.getSparsity(), 1))); + case EQZERO: + return new MatrixCharacteristics(mc1.getRows(), mc1.getCols(), + (long) mc1.getRows() * mc1.getCols() - mc1.getNonZeros()); + case DIAG: + return (mc1.getCols() == 1) ? 
+ new MatrixCharacteristics(mc1.getRows(), mc1.getRows(), mc1.getNonZeros()) : + new MatrixCharacteristics(mc1.getRows(), 1, Math.min(mc1.getRows(), mc1.getNonZeros())); + // binary operations that preserve sparsity exactly + case CBIND: + return new MatrixCharacteristics(mc1.getRows(), + mc1.getCols() + mc2.getCols(), mc1.getNonZeros() + mc2.getNonZeros()); + case RBIND: + return new MatrixCharacteristics(mc1.getRows() + mc2.getRows(), + mc1.getCols(), mc1.getNonZeros() + mc2.getNonZeros()); + // unary operation that preserve sparsity exactly + case NEQZERO: + case TRANS: + case RESHAPE: + return mc1; + default: + throw new NotImplementedException(); + } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java index 38de751..a898147 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorBitsetMM.java @@ -25,6 +25,7 @@ import java.util.stream.IntStream; import org.apache.commons.lang.NotImplementedException; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.DenseBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.SparseBlock; @@ -42,7 +43,7 @@ import org.apache.sysml.runtime.matrix.data.SparseBlock; */ public class EstimatorBitsetMM extends SparsityEstimator { @Override - public double estim(MMNode root) { + public MatrixCharacteristics estim(MMNode root) { // recursive density map computation of non-leaf nodes if (!root.getLeft().isLeaf()) estim(root.getLeft()); // 
obtain synopsis @@ -56,7 +57,8 @@ public class EstimatorBitsetMM extends SparsityEstimator { // estimate output density map and sparsity via boolean matrix mult BitsetMatrix outMap = m1Map.matMult(m2Map); root.setSynopsis(outMap); // memoize boolean matrix - return OptimizerUtils.getSparsity(outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros()); + return root.setMatrixCharacteristics(new MatrixCharacteristics( + outMap.getNumRows(), outMap.getNumColumns(), outMap.getNonZeros())); } @Override @@ -277,7 +279,7 @@ public class EstimatorBitsetMM extends SparsityEstimator { c[ci+0] |= b[bi+0]; c[ci+1] |= b[bi+1]; c[ci+2] |= b[bi+2]; c[ci+3] |= b[bi+3]; c[ci+4] |= b[bi+4]; c[ci+5] |= b[bi+5]; - c[ci+6] |= b[bi+4]; c[ci+7] |= b[bi+7]; + c[ci+6] |= b[bi+6]; c[ci+7] |= b[bi+7]; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java index c86ad21..1246978 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorDensityMap.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim; import org.apache.commons.lang.NotImplementedException; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.DenseBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.SparseBlock; @@ -52,7 +53,7 @@ public class EstimatorDensityMap extends SparsityEstimator } @Override - public double estim(MMNode root) { + public MatrixCharacteristics estim(MMNode root) { //recursive density map computation of non-leaf nodes if( !root.getLeft().isLeaf() ) estim(root.getLeft()); 
//obtain synopsis @@ -67,8 +68,8 @@ public class EstimatorDensityMap extends SparsityEstimator MatrixBlock outMap = estimIntern(m1Map, m2Map, false, root.getRows(), root.getLeft().getCols(), root.getCols()); root.setSynopsis(outMap); //memoize density map - return OptimizerUtils.getSparsity( //aggregate output histogram - root.getRows(), root.getCols(), (long)outMap.sum()); + return root.setMatrixCharacteristics(new MatrixCharacteristics( + root.getLeft().getRows(), root.getRight().getCols(), (long)outMap.sum())); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java index b646970..e24886a 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorLayeredGraph.java @@ -22,6 +22,7 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.commons.math3.distribution.ExponentialDistribution; import org.apache.commons.math3.random.Well1024a; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.DenseBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.SparseBlock; @@ -52,7 +53,7 @@ public class EstimatorLayeredGraph extends SparsityEstimator { } @Override - public double estim(MMNode root) { + public MatrixCharacteristics estim(MMNode root) { throw new NotImplementedException(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java index 6b4c898..7f0047c 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java @@ -25,6 +25,7 @@ import java.util.stream.IntStream; import org.apache.directory.api.util.exception.NotImplementedException; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.DenseBlock; import org.apache.sysml.runtime.matrix.data.LibMatrixAgg; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -52,7 +53,7 @@ public class EstimatorMatrixHistogram extends SparsityEstimator } @Override - public double estim(MMNode root) { + public MatrixCharacteristics estim(MMNode root) { //recursive histogram computation of non-leaf nodes if( !root.getLeft().isLeaf() ) estim(root.getLeft()); //obtain synopsis @@ -69,9 +70,10 @@ public class EstimatorMatrixHistogram extends SparsityEstimator double ret = estimIntern(h1, h2, OpCode.MM); //derive and memoize output histogram - root.setSynopsis(MatrixHistogram.deriveOutputHistogram(h1, h2, ret)); - - return ret; + MatrixHistogram outMap = MatrixHistogram.deriveOutputHistogram(h1, h2, ret); + root.setSynopsis(outMap); + return root.setMatrixCharacteristics(new MatrixCharacteristics( + outMap.getRows(), outMap.getCols(), outMap.getNonZeros())); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java index ad2f6b7..1df23dd 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java +++ 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java @@ -22,6 +22,7 @@ package org.apache.sysml.hops.estim; import org.apache.commons.lang.NotImplementedException; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.DenseBlock; import org.apache.sysml.runtime.matrix.data.LibMatrixAgg; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -55,7 +56,7 @@ public class EstimatorSample extends SparsityEstimator } @Override - public double estim(MMNode root) { + public MatrixCharacteristics estim(MMNode root) { LOG.warn("Recursive estimates not supported by EstimatorSample, falling back to EstimatorBasicAvg."); return new EstimatorBasicAvg().estim(root); } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/MMNode.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/MMNode.java b/src/main/java/org/apache/sysml/hops/estim/MMNode.java index 55aee3d..542449a 100644 --- a/src/main/java/org/apache/sysml/hops/estim/MMNode.java +++ b/src/main/java/org/apache/sysml/hops/estim/MMNode.java @@ -19,6 +19,7 @@ package org.apache.sysml.hops.estim; +import org.apache.sysml.hops.estim.SparsityEstimator.OpCode; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -33,20 +34,22 @@ public class MMNode private final MatrixBlock _data; private final MatrixCharacteristics _mc; private Object _synops = null; + private final OpCode _op; public MMNode(MatrixBlock in) { _m1 = null; _m2 = null; _data = in; _mc = in.getMatrixCharacteristics(); + _op = null; } - public MMNode(MMNode left, MMNode right) { + public MMNode(MMNode left, MMNode right, OpCode op) { _m1 = left; _m2 = right; _data = null; - _mc = new MatrixCharacteristics( - 
_m1.getRows(), _m2.getCols(), -1, -1); + _mc = new MatrixCharacteristics(-1, -1, -1, -1); + _op = op; } public int getRows() { @@ -61,6 +64,10 @@ public class MMNode return _mc; } + public MatrixCharacteristics setMatrixCharacteristics(MatrixCharacteristics mc) { + return _mc.set(mc); //implicit copy + } + public MMNode getLeft() { return _m1; } @@ -84,4 +91,8 @@ public class MMNode public Object getSynopsis() { return _synops; } + + public OpCode getOp() { + return _op; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java b/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java index 411d9cd..a37372f 100644 --- a/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java +++ b/src/main/java/org/apache/sysml/hops/estim/SparsityEstimator.java @@ -21,6 +21,7 @@ package org.apache.sysml.hops.estim; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; public abstract class SparsityEstimator @@ -46,7 +47,8 @@ public abstract class SparsityEstimator * @param root * @return */ - public abstract double estim(MMNode root); + public abstract MatrixCharacteristics estim(MMNode root); + /** * Estimates the output sparsity for a single matrix multiplication. 
http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java index 6d816b6..183cd6e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java @@ -81,6 +81,10 @@ public class MatrixCharacteristics implements Serializable public MatrixCharacteristics() {} + public MatrixCharacteristics(long nr, long nc, long nnz) { + set(nr, nc, -1, -1, nnz); + } + public MatrixCharacteristics(long nr, long nc, int bnr, int bnc) { set(nr, nc, bnr, bnc); } @@ -93,29 +97,32 @@ public class MatrixCharacteristics implements Serializable set(that.numRows, that.numColumns, that.numRowsPerBlock, that.numColumnsPerBlock, that.nonZero); } - public void set(long nr, long nc, int bnr, int bnc) { + public MatrixCharacteristics set(long nr, long nc, int bnr, int bnc) { numRows = nr; numColumns = nc; numRowsPerBlock = bnr; numColumnsPerBlock = bnc; + return this; } - public void set(long nr, long nc, int bnr, int bnc, long nnz) { + public MatrixCharacteristics set(long nr, long nc, int bnr, int bnc, long nnz) { numRows = nr; numColumns = nc; numRowsPerBlock = bnr; numColumnsPerBlock = bnc; nonZero = nnz; ubNnz = false; + return this; } - public void set(MatrixCharacteristics that) { + public MatrixCharacteristics set(MatrixCharacteristics that) { numRows = that.numRows; numColumns = that.numColumns; numRowsPerBlock = that.numRowsPerBlock; numColumnsPerBlock = that.numColumnsPerBlock; nonZero = that.nonZero; ubNnz = that.ubNnz; + return this; } public long getRows(){ @@ -207,6 +214,10 @@ public class MatrixCharacteristics implements Serializable return nonZero; } + public double getSparsity() 
{ + return OptimizerUtils.getSparsity(this); + } + public boolean dimsKnown() { return ( numRows >= 0 && numColumns >= 0 ); } http://git-wip-us.apache.org/repos/asf/systemml/blob/bc7b4961/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java index 97e7fd3..0f4dad5 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductChainTest.java @@ -19,18 +19,19 @@ package org.apache.sysml.test.integration.functions.estim; -import org.junit.Test; import org.apache.sysml.hops.estim.EstimatorBasicAvg; import org.apache.sysml.hops.estim.EstimatorBasicWorst; import org.apache.sysml.hops.estim.EstimatorBitsetMM; import org.apache.sysml.hops.estim.EstimatorDensityMap; import org.apache.sysml.hops.estim.EstimatorMatrixHistogram; import org.apache.sysml.hops.estim.MMNode; +import org.apache.sysml.hops.estim.SparsityEstimator.OpCode; import org.apache.sysml.hops.estim.SparsityEstimator; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.utils.TestUtils; +import org.junit.Test; /** * This is a basic sanity check for all estimator, which need @@ -135,8 +136,8 @@ public class SquaredProductChainTest extends AutomatedTestBase new MatrixBlock(), InstructionUtils.getMatMultOperator(1)); //compare estimated and real sparsity - double est = estim.estim(new MMNode( - new MMNode(new MMNode(m1), new MMNode(m2)), new MMNode(m3))); + double est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new 
MMNode(m2), + OpCode.MM), new MMNode(m3), OpCode.MM)).getSparsity(); TestUtils.compareScalars(est, m5.getSparsity(), (estim instanceof EstimatorBitsetMM) ? eps3 : //exact (estim instanceof EstimatorBasicWorst) ? eps1 : eps2);
