Repository: systemml Updated Branches: refs/heads/master 51db735eb -> c61b94c97
[SYSTEMML-2375] Improved size inference nary ops (cbind/rbind/min/max) This patch improves the size propagation for nary ops by (1) adding the missing worst-case size inference, and (2) computing the nnz for exact size propagation. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c61b94c9 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c61b94c9 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c61b94c9 Branch: refs/heads/master Commit: c61b94c975ac9019c9e9f0187dda5b23dbac61e7 Parents: 51db735 Author: Matthias Boehm <[email protected]> Authored: Mon Jun 18 18:50:50 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jun 18 18:50:50 2018 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/NaryOp.java | 22 ++++++++++ .../sysml/hops/rewrite/HopRewriteUtils.java | 43 +++++++++++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c61b94c9/src/main/java/org/apache/sysml/hops/NaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/NaryOp.java b/src/main/java/org/apache/sysml/hops/NaryOp.java index 0847659..db03a23 100644 --- a/src/main/java/org/apache/sysml/hops/NaryOp.java +++ b/src/main/java/org/apache/sysml/hops/NaryOp.java @@ -25,6 +25,7 @@ import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.lops.Nary; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; /** * The NaryOp Hop allows for a variable number of operands. 
Functionality @@ -180,7 +181,26 @@ public class NaryOp extends Hop { } @Override + @SuppressWarnings("incomplete-switch") protected long[] inferOutputCharacteristics(MemoTable memo) { + if( !getDataType().isScalar() ) { + MatrixCharacteristics[] mc = memo.getAllInputStats(getInput()); + + switch( _op ) { + case CBIND: return new long[]{ + HopRewriteUtils.getMaxInputDim(mc, true), + HopRewriteUtils.getSumValidInputDims(mc, false), + HopRewriteUtils.getSumValidInputNnz(mc, true)}; + case RBIND: return new long[]{ + HopRewriteUtils.getSumValidInputDims(mc, true), + HopRewriteUtils.getMaxInputDim(mc, false), + HopRewriteUtils.getSumValidInputNnz(mc, true)}; + case MIN: + case MAX: return new long[]{ + HopRewriteUtils.getMaxInputDim(mc, true), + HopRewriteUtils.getMaxInputDim(mc, false), -1}; + } + } return null; //do nothing } @@ -190,10 +210,12 @@ public class NaryOp extends Hop { case CBIND: setDim1(HopRewriteUtils.getMaxInputDim(this, true)); setDim2(HopRewriteUtils.getSumValidInputDims(this, false)); + setNnz(HopRewriteUtils.getSumValidInputNnz(this)); break; case RBIND: setDim1(HopRewriteUtils.getSumValidInputDims(this, true)); setDim2(HopRewriteUtils.getMaxInputDim(this, false)); + setNnz(HopRewriteUtils.getSumValidInputNnz(this)); break; case MIN: case MAX: http://git-wip-us.apache.org/repos/asf/systemml/blob/c61b94c9/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 7872c91..48b95cc 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -73,6 +73,7 @@ import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import 
org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.UtilFunctions; @@ -1357,14 +1358,14 @@ public class HopRewriteUtils public static long getMaxInputDim(Hop hop, boolean dim1) { return hop.getInput().stream().mapToLong( - h -> (dim1?h.getDim1():h.getDim2())).max().orElse(-1); + h -> (dim1 ? h.getDim1() : h.getDim2())).max().orElse(-1); } public static long getSumValidInputDims(Hop hop, boolean dim1) { if( !hasValidInputDims(hop, dim1) ) return -1; return hop.getInput().stream().mapToLong( - h -> (dim1?h.getDim1():h.getDim2())).sum(); + h -> (dim1 ? h.getDim1() : h.getDim2())).sum(); } public static boolean hasValidInputDims(Hop hop, boolean dim1) { @@ -1372,6 +1373,44 @@ public class HopRewriteUtils h -> dim1 ? h.rowsKnown() : h.colsKnown()); } + public static long getSumValidInputNnz(Hop hop) { + if( !hasValidInputNnz(hop) ) + return -1; + return hop.getInput().stream().mapToLong(h -> h.getNnz()).sum(); + } + + public static boolean hasValidInputNnz(Hop hop) { + return hop.getInput().stream().allMatch(h -> h.getNnz() >= 0); + } + + public static long getMaxInputDim(MatrixCharacteristics[] mc, boolean dim1) { + return Arrays.stream(mc).mapToLong( + h -> (dim1 ? h.getRows() : h.getCols())).max().orElse(-1); + } + + public static long getSumValidInputDims(MatrixCharacteristics[] mc, boolean dim1) { + if( !hasValidInputDims(mc, dim1) ) + return -1; + return Arrays.stream(mc).mapToLong( + h -> (dim1 ? h.getRows() : h.getCols())).sum(); + } + + public static boolean hasValidInputDims(MatrixCharacteristics[] mc, boolean dim1) { + return Arrays.stream(mc).allMatch( + h -> dim1 ? 
h.rowsKnown() : h.colsKnown()); + } + + public static long getSumValidInputNnz(MatrixCharacteristics[] mc, boolean worstcase) { + if( !hasValidInputNnz(mc, worstcase) ) + return -1; + return Arrays.stream(mc).mapToLong(h -> h.nnzKnown() ? + h.getNonZeros() : h.getLength()).sum(); + } + + public static boolean hasValidInputNnz(MatrixCharacteristics[] mc, boolean worstcase) { + return Arrays.stream(mc).allMatch(h -> h.nnzKnown() || (worstcase && h.dimsKnown())); + } + public static boolean containsSecondOrderBuiltin(ArrayList<Hop> roots) { Hop.resetVisitStatus(roots); return roots.stream().anyMatch(r -> containsSecondOrderBuiltin(r));
