[SYSTEMML-1852] Fix rewrite outer-rexpand (generalized max handling) This patch fixes an issue in the static rewrite of outer to rexpand, which failed whenever the 'to' parameter of the sequence was not a literal after constant folding. We have now extended this rewrite to arbitrary hops and generalized the related size propagation and memory estimates accordingly.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8c87d2a2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8c87d2a2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8c87d2a2 Branch: refs/heads/master Commit: 8c87d2a2f3472d5516113a65817defe5ce6989a6 Parents: fc7fcb0 Author: Matthias Boehm <[email protected]> Authored: Thu Aug 17 22:40:32 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Aug 18 14:15:45 2017 -0700 ---------------------------------------------------------------------- .../apache/sysml/hops/ParameterizedBuiltinOp.java | 10 +++++----- .../apache/sysml/hops/rewrite/HopRewriteUtils.java | 17 +++++++++-------- .../RewriteAlgebraicSimplificationStatic.java | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/8c87d2a2/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java index 7d6fe2a..ab276d7 100644 --- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java @@ -1025,15 +1025,15 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop //but very good sparsity estimate possible (number of non-zeros in input) Hop max = getInput().get(_paramIndexMap.get("max")); Hop dir = getInput().get(_paramIndexMap.get("dir")); - double maxVal = HopRewriteUtils.getDoubleValueSafe((LiteralOp)max); + long maxVal = computeDimParameterInformation(max, memo); String dirVal = ((LiteralOp)dir).getStringValue(); if( mc.dimsKnown() ) { long lnnz = mc.nnzKnown() ? 
mc.getNonZeros() : mc.getRows(); if( "cols".equals(dirVal) ) { //expand horizontally - ret = new long[]{mc.getRows(), UtilFunctions.toLong(maxVal), lnnz}; + ret = new long[]{mc.getRows(), maxVal, lnnz}; } else if( "rows".equals(dirVal) ){ //expand vertically - ret = new long[]{UtilFunctions.toLong(maxVal), mc.getRows(), lnnz}; + ret = new long[]{maxVal, mc.getRows(), lnnz}; } } } @@ -1156,7 +1156,7 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop if( isNonZeroReplaceArguments() ) setNnz( target.getNnz() ); - break; + break; } case REXPAND: { //dimensions are exactly known from input, sparsity unknown but upper bounded by nrow(v) @@ -1164,7 +1164,7 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop Hop target = getInput().get(_paramIndexMap.get("target")); Hop max = getInput().get(_paramIndexMap.get("max")); Hop dir = getInput().get(_paramIndexMap.get("dir")); - double maxVal = HopRewriteUtils.getDoubleValueSafe((LiteralOp)max); + double maxVal = computeSizeInformation(max); String dirVal = ((LiteralOp)dir).getStringValue(); if( "cols".equals(dirVal) ) { //expand horizontally http://git-wip-us.apache.org/repos/asf/systemml/blob/8c87d2a2/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 351173f..5a3b9aa 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -456,6 +456,11 @@ public class HopRewriteUtils return datagen; } + public static boolean isDataGenOp(Hop hop, DataGenMethod... 
ops) { + return (hop instanceof DataGenOp + && ArrayUtils.contains(ops, ((DataGenOp)hop).getOp())); + } + public static boolean isDataGenOpWithConstantValue(Hop hop, double value) { return hop instanceof DataGenOp && ((DataGenOp)hop).getOp()==DataGenMethod.RAND @@ -989,17 +994,13 @@ public class HopRewriteUtils return ret; } - public static LiteralOp getBasic1NSequenceMaxLiteral(Hop hop) + public static Hop getBasic1NSequenceMax(Hop hop) throws HopsException { - if( hop instanceof DataGenOp ) - { + if( isDataGenOp(hop, DataGenMethod.SEQ) ) { DataGenOp dgop = (DataGenOp) hop; - if( dgop.getOp() == DataGenMethod.SEQ ){ - Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO)); - if( to instanceof LiteralOp ) - return (LiteralOp)to; - } + return dgop.getInput() + .get(dgop.getParamIndex(Statement.SEQ_TO)); } throw new HopsException("Failed to retrieve 'to' argument from basic 1-N sequence."); http://git-wip-us.apache.org/repos/asf/systemml/blob/8c87d2a2/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 63c6772..c010bc2 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -1641,7 +1641,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //setup input parameter hops HashMap<String,Hop> inputargs = new HashMap<String,Hop>(); inputargs.put("target", trgt); - inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMaxLiteral(seq)); + inputargs.put("max", HopRewriteUtils.getBasic1NSequenceMax(seq)); inputargs.put("dir", new LiteralOp(direction)); inputargs.put("ignore", new LiteralOp(true)); 
inputargs.put("cast", new LiteralOp(false));
