[SYSTEMML-1852] Fix rewrite outer-rexpand (generalized max handling)

This patch fixes an issue in the static rewrite of outer to rexpand,
which failed whenever the sequence's 'to' parameter was not a literal
after constant folding. We now extend this rewrite to arbitrary hops
and generalize the related size propagation and memory estimates
accordingly.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8c87d2a2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8c87d2a2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8c87d2a2

Branch: refs/heads/master
Commit: 8c87d2a2f3472d5516113a65817defe5ce6989a6
Parents: fc7fcb0
Author: Matthias Boehm <[email protected]>
Authored: Thu Aug 17 22:40:32 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Aug 18 14:15:45 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/ParameterizedBuiltinOp.java  | 10 +++++-----
 .../apache/sysml/hops/rewrite/HopRewriteUtils.java | 17 +++++++++--------
 .../RewriteAlgebraicSimplificationStatic.java      |  2 +-
 3 files changed, 15 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8c87d2a2/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index 7d6fe2a..ab276d7 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -1025,15 +1025,15 @@ public class ParameterizedBuiltinOp extends Hop 
implements MultiThreadedHop
                        //but very good sparsity estimate possible (number of 
non-zeros in input)
                        Hop max = getInput().get(_paramIndexMap.get("max"));
                        Hop dir = getInput().get(_paramIndexMap.get("dir"));
-                       double maxVal = 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)max);
+                       long maxVal = computeDimParameterInformation(max, memo);
                        String dirVal = ((LiteralOp)dir).getStringValue();
                        if( mc.dimsKnown() ) {
                                long lnnz = mc.nnzKnown() ? mc.getNonZeros() : 
mc.getRows();
                                if( "cols".equals(dirVal) ) { //expand 
horizontally
-                                       ret = new long[]{mc.getRows(), 
UtilFunctions.toLong(maxVal), lnnz};
+                                       ret = new long[]{mc.getRows(), maxVal, 
lnnz};
                                }
                                else if( "rows".equals(dirVal) ){ //expand 
vertically
-                                       ret = new 
long[]{UtilFunctions.toLong(maxVal), mc.getRows(), lnnz};
+                                       ret = new long[]{maxVal, mc.getRows(), 
lnnz};
                                }       
                        }
                }
@@ -1156,7 +1156,7 @@ public class ParameterizedBuiltinOp extends Hop 
implements MultiThreadedHop
                                if( isNonZeroReplaceArguments() )
                                        setNnz( target.getNnz() );
                                
-                               break;  
+                               break;
                        }
                        case REXPAND: {
                                //dimensions are exactly known from input, 
sparsity unknown but upper bounded by nrow(v)
@@ -1164,7 +1164,7 @@ public class ParameterizedBuiltinOp extends Hop 
implements MultiThreadedHop
                                Hop target = 
getInput().get(_paramIndexMap.get("target"));
                                Hop max = 
getInput().get(_paramIndexMap.get("max"));
                                Hop dir = 
getInput().get(_paramIndexMap.get("dir"));
-                               double maxVal = 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)max);
+                               double maxVal = computeSizeInformation(max);
                                String dirVal = 
((LiteralOp)dir).getStringValue();
                                
                                if( "cols".equals(dirVal) ) { //expand 
horizontally

http://git-wip-us.apache.org/repos/asf/systemml/blob/8c87d2a2/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 351173f..5a3b9aa 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -456,6 +456,11 @@ public class HopRewriteUtils
                return datagen;
        }
        
+       public static boolean isDataGenOp(Hop hop, DataGenMethod... ops) {
+               return (hop instanceof DataGenOp 
+                       && ArrayUtils.contains(ops, ((DataGenOp)hop).getOp()));
+       }
+       
        public static boolean isDataGenOpWithConstantValue(Hop hop, double 
value) {
                return hop instanceof DataGenOp
                        && ((DataGenOp)hop).getOp()==DataGenMethod.RAND
@@ -989,17 +994,13 @@ public class HopRewriteUtils
                return ret;
        }
 
-       public static LiteralOp getBasic1NSequenceMaxLiteral(Hop hop) 
+       public static Hop getBasic1NSequenceMax(Hop hop) 
                throws HopsException
        {
-               if( hop instanceof DataGenOp )
-               {
+               if( isDataGenOp(hop, DataGenMethod.SEQ) ) {
                        DataGenOp dgop = (DataGenOp) hop;
-                       if( dgop.getOp() == DataGenMethod.SEQ ){
-                               Hop to = 
dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO));
-                               if( to instanceof LiteralOp )
-                                       return (LiteralOp)to;
-                       }
+                       return dgop.getInput()
+                               .get(dgop.getParamIndex(Statement.SEQ_TO));
                }
                
                throw new HopsException("Failed to retrieve 'to' argument from 
basic 1-N sequence.");

http://git-wip-us.apache.org/repos/asf/systemml/blob/8c87d2a2/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 63c6772..c010bc2 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1641,7 +1641,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                //setup input parameter hops
                                HashMap<String,Hop> inputargs = new 
HashMap<String,Hop>();
                                inputargs.put("target", trgt);
-                               inputargs.put("max", 
HopRewriteUtils.getBasic1NSequenceMaxLiteral(seq));
+                               inputargs.put("max", 
HopRewriteUtils.getBasic1NSequenceMax(seq));
                                inputargs.put("dir", new LiteralOp(direction));
                                inputargs.put("ignore", new LiteralOp(true));
                                inputargs.put("cast", new LiteralOp(false));

Reply via email to