Repository: systemml Updated Branches: refs/heads/master 9d5918743 -> 14b4d5487
[SYSTEMML-1687] Worst-case size propagation for fused codegen operators This patch extends the generic codegen hop with the ability to propagate worst-case size information according to its propagation types. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14b4d548 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14b4d548 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14b4d548 Branch: refs/heads/master Commit: 14b4d548723a01eb58047993153fa4845ec1450b Parents: 9d59187 Author: Matthias Boehm <[email protected]> Authored: Tue Jun 13 13:00:58 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Jun 13 13:00:58 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/codegen/SpoofFusedOp.java | 64 +++++++++++++++++--- 1 file changed, 57 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/14b4d548/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java index 9f426f6..06be99b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java @@ -32,6 +32,7 @@ import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.SpoofFused; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; public class SpoofFusedOp extends Hop implements MultiThreadedHop { @@ -90,12 +91,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop protected double computeIntermediateMemEstimate(long dim1, long dim2, 
long nnz) { return 0; } - - @Override - protected long[] inferOutputCharacteristics(MemoTable memo) { - return null; - } - + @Override public Lop constructLops() throws HopsException, LopsException { if( getLops() != null ) @@ -136,7 +136,61 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop public String getOpString() { return "spoof("+_class.getSimpleName()+")"; } - + + /** Infers worst-case output dimensions of this fused operator from the memo-table statistics of its inputs, selected by _dimsType. Returns a long[3] of {rows, cols, -1}; the third entry is always -1 here (presumably "nnz unknown" - confirm against the Hop.inferOutputCharacteristics contract). Returns null if the main input (input 0) does not have known dims, or, for ROW_RANK_DIMS and COLUMN_RANK_DIMS, if input 1 does not have known dims. Throws RuntimeException for a _dimsType without a case below. NOTE(review): assumes memo.getAllInputStats(...) never returns null - verify against MemoTable. */ @Override + protected long[] inferOutputCharacteristics( MemoTable memo ) + { + long[] ret = null; + + //get statistics of main input + MatrixCharacteristics mc = memo.getAllInputStats(getInput().get(0)); + + if( mc.dimsKnown() ) { /* NOTE(review): SCALAR and MULTI_SCALAR below do not read mc, yet they are still gated by mc.dimsKnown(); their sizes could be returned even when input 0 dims are unknown. */ + switch(_dimsType) + { + case ROW_DIMS: + ret = new long[]{mc.getRows(), 1, -1}; + break; + case ROW_DIMS2: + ret = new long[]{mc.getRows(), 2, -1}; + break; + case COLUMN_DIMS_ROWS: + ret = new long[]{mc.getCols(), 1, -1}; + break; + case COLUMN_DIMS_COLS: + ret = new long[]{1, mc.getCols(), -1}; + break; + case INPUT_DIMS: + ret = new long[]{mc.getRows(), mc.getCols(), -1}; + break; + case SCALAR: + ret = new long[]{0, 0, -1}; + break; + case MULTI_SCALAR: + //dim2 statically set from outside + ret = new long[]{1, _dim2, -1}; + break; + case ROW_RANK_DIMS: { /* rows(input 0) x cols(input 1); stays null if input 1 dims are unknown */ + MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); + if( mc2.dimsKnown() ) + ret = new long[]{mc.getRows(), mc2.getCols(), -1}; + break; + } + case COLUMN_RANK_DIMS: { /* cols(input 0) x cols(input 1); stays null if input 1 dims are unknown */ + MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1)); + if( mc2.dimsKnown() ) + ret = new long[]{mc.getCols(), mc2.getCols(), -1}; + break; + } + default: + throw new RuntimeException("Failed to infer worst-case size information " + + "for type: "+_dimsType.toString()); + } + } + + return ret; /* null when no size could be inferred */ + } + @Override public void refreshSizeInformation() { switch(_dimsType)
