[SYSTEMML-2357] Parfor data partitioning rewrites on hops 

For correctness of size propagation, SYSTEMML-2340 and SYSTEMML-2356
moved part of the constant propagation from the parser to
inter-procedural analysis (IPA). The latter has many advantages as it
works in conjunction with repeated rewrites. So far the parfor optimizer
used the original statements from the parser to determine partitioning
strategies and formats (for reuse). Due to the deferred constant
propagation, there were scenarios were the parfor optimizer now missed
partitioning opportunities, especially for block partitioning that
usually involve complex indexing expressions. We now completely rewrote
this partitioning analysis on hops instead of statements, which
seamlessly benefits all previous rewrites and constant propagation.

Additional minor improvements include better size propagation for right
indexing hops and more efficient data partitioning analysis (avoid
program scans for non-matrix data types).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3e54f1a2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3e54f1a2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3e54f1a2

Branch: refs/heads/master
Commit: 3e54f1a2d3599ba38776f5371c01f65476c86f54
Parents: dbc844c
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 2 20:45:35 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 2 21:45:43 2018 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    |   5 +-
 .../java/org/apache/sysml/hops/IndexingOp.java  |  41 +--
 .../sysml/parser/ParForStatementBlock.java      | 273 +++++++++++++------
 .../org/apache/sysml/parser/VariableSet.java    |   5 +
 .../controlprogram/ParForProgramBlock.java      |   2 +-
 .../parfor/opt/OptimizerRuleBased.java          |  16 +-
 .../parfor/ParForRulebasedOptimizerTest.java    |  85 +++---
 .../recompile/IPAComplexAppendTest.java         |   7 +-
 8 files changed, 264 insertions(+), 170 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index de84088..5d789df 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -917,12 +917,13 @@ public abstract class Hop implements ParseInfo
                }
        }
        
-       public void resetVisitStatus()  {
+       public Hop resetVisitStatus()  {
                if( !isVisited() )
-                       return;
+                       return this;
                for( Hop h : getInput() )
                        h.resetVisitStatus();
                setVisited(false);
+               return this;
        }
        
        public void resetVisitStatusForced(HashSet<Long> memo) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/main/java/org/apache/sysml/hops/IndexingOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/IndexingOp.java 
b/src/main/java/org/apache/sysml/hops/IndexingOp.java
index 8602dee..4bb03b4 100644
--- a/src/main/java/org/apache/sysml/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysml/hops/IndexingOp.java
@@ -411,6 +411,7 @@ public class IndexingOp extends Hop
        @Override
        public void refreshSizeInformation()
        {
+               Hop input1 = getInput().get(0); //matrix
                Hop input2 = getInput().get(1); //inpRowL
                Hop input3 = getInput().get(2); //inpRowU
                Hop input4 = getInput().get(3); //inpColL
@@ -421,12 +422,8 @@ public class IndexingOp extends Hop
                _colLowerEqualsUpper = (input4 == input5);
                
                //parse input information
-               boolean allRows = 
-                       (    input2 instanceof LiteralOp && 
HopRewriteUtils.getIntValueSafe((LiteralOp)input2)==1 
-                         && input3 instanceof UnaryOp && 
((UnaryOp)input3).getOp() == OpOp1.NROW  );
-               boolean allCols = 
-                       (    input4 instanceof LiteralOp && 
HopRewriteUtils.getIntValueSafe((LiteralOp)input4)==1 
-                         && input5 instanceof UnaryOp && 
((UnaryOp)input5).getOp() == OpOp1.NCOL );
+               boolean allRows = isAllRows();
+               boolean allCols = isAllCols();
                boolean constRowRange = (input2 instanceof LiteralOp && input3 
instanceof LiteralOp);
                boolean constColRange = (input4 instanceof LiteralOp && input5 
instanceof LiteralOp);
                
@@ -434,8 +431,7 @@ public class IndexingOp extends Hop
                if( _rowLowerEqualsUpper ) //ROWS
                        setDim1(1);
                else if( allRows ) {
-                       //input3 guaranteed to be a unaryop-nrow
-                       setDim1(input3.getInput().get(0).getDim1());
+                       setDim1(input1.getDim1());
                }
                else if( constRowRange ) {
                        setDim1( 
HopRewriteUtils.getIntValueSafe((LiteralOp)input3)
@@ -452,8 +448,7 @@ public class IndexingOp extends Hop
                if( _colLowerEqualsUpper ) //COLS
                        setDim2(1);
                else if( allCols ) {
-                       //input5 guaranteed to be a unaryop-ncol
-                       setDim2(input5.getInput().get(0).getDim2());
+                       setDim2(input1.getDim2());
                }
                else if( constColRange ) {
                        setDim2( 
HopRewriteUtils.getIntValueSafe((LiteralOp)input5)
@@ -468,16 +463,30 @@ public class IndexingOp extends Hop
                }
        }
        
+       public boolean isAllRows() {
+               Hop input1 = getInput().get(0);
+               Hop input2 = getInput().get(1);
+               Hop input3 = getInput().get(2);
+               return HopRewriteUtils.isLiteralOfValue(input2, 1)
+                       && ((HopRewriteUtils.isUnary(input3, OpOp1.NROW) && 
input3.getInput().get(0) == input1 )
+                       || HopRewriteUtils.isLiteralOfValue(input3, 
input1.getDim1()));
+       }
+       
+       public boolean isAllCols() {
+               Hop input1 = getInput().get(0);
+               Hop input4 = getInput().get(3);
+               Hop input5 = getInput().get(4);
+               return HopRewriteUtils.isLiteralOfValue(input4, 1)
+                       && ((HopRewriteUtils.isUnary(input5, OpOp1.NCOL) && 
input5.getInput().get(0) == input1 )
+                       || HopRewriteUtils.isLiteralOfValue(input5, 
input1.getDim2()));
+       }
+       
        @Override
-       public Object clone() throws CloneNotSupportedException 
-       {
-               IndexingOp ret = new IndexingOp();      
-               
+       public Object clone() throws CloneNotSupportedException  {
+               IndexingOp ret = new IndexingOp();
                //copy generic attributes
                ret.clone(this, false);
-               
                //copy specific attributes
-
                return ret;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java 
b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
index 113186f..9c752ae 100644
--- a/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysml/parser/ParForStatementBlock.java
@@ -32,10 +32,18 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.parser.Expression.BinaryOp;
 import org.apache.sysml.parser.Expression.BuiltinFunctionOp;
 import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.parser.PrintStatement.PRINTTYPE;
 import 
org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
 import 
org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitioner;
@@ -368,12 +376,12 @@ public class ParForStatementBlock extends 
ForStatementBlock
                return vs;
        }
        
-       public List<String> getReadOnlyParentVars() {
+       public List<String> getReadOnlyParentMatrixVars() {
                VariableSet read = variablesRead();
                VariableSet updated = variablesUpdated();
                return liveIn().getVariableNames().stream() //read-only vars
                        .filter(var -> read.containsVariable(var) && 
!updated.containsVariable(var))
-                       .collect(Collectors.toList());
+                       .filter(var -> 
read.isMatrix(var)).collect(Collectors.toList());
        }
 
        /**
@@ -476,100 +484,100 @@ public class ParForStatementBlock extends 
ForStatementBlock
         */
        private void rDeterminePartitioningCandidates(String var, 
ArrayList<StatementBlock> asb, List<PartitionFormat> C) 
        {
-               for(StatementBlock sb : asb ) // foreach statementblock in 
parforbody
-                       for( Statement s : sb._statements ) // foreach 
statement in statement block
-                       {
-                               if( s instanceof ForStatement ) //includes for 
and parfor
-                               {
-                                       ForStatement fs = (ForStatement) s;
-                                       //predicate
-                                       List<DataIdentifier> datsFromRead = 
rGetDataIdentifiers(fs.getIterablePredicate().getFromExpr());
-                                       List<DataIdentifier> datsToRead = 
rGetDataIdentifiers(fs.getIterablePredicate().getToExpr());
-                                       List<DataIdentifier> datsIncrementRead 
= rGetDataIdentifiers(fs.getIterablePredicate().getIncrementExpr());
-                                       rDeterminePartitioningCandidates(var, 
datsFromRead, C);
-                                       rDeterminePartitioningCandidates(var, 
datsToRead, C);
-                                       rDeterminePartitioningCandidates(var, 
datsIncrementRead, C);
-                                       //for / parfor body
-                                       
rDeterminePartitioningCandidates(var,((ForStatement)s).getBody(), C);
-                               }
-                               else if( s instanceof WhileStatement ) 
-                               {
-                                       WhileStatement ws = (WhileStatement) s;
-                                       //predicate
-                                       List<DataIdentifier> datsRead = 
rGetDataIdentifiers(ws.getConditionalPredicate().getPredicate());
-                                       rDeterminePartitioningCandidates(var, 
datsRead, C);
-                                       //while body
-                                       
rDeterminePartitioningCandidates(var,((WhileStatement)s).getBody(), C);
-                               }
-                               else if( s instanceof IfStatement ) 
-                               {
-                                       IfStatement is = (IfStatement) s;
-                                       //predicate
-                                       List<DataIdentifier> datsRead = 
rGetDataIdentifiers(is.getConditionalPredicate().getPredicate());
-                                       rDeterminePartitioningCandidates(var, 
datsRead, C);
-                                       //if and else branch
-                                       
rDeterminePartitioningCandidates(var,((IfStatement)s).getIfBody(), C);
-                                       
rDeterminePartitioningCandidates(var,((IfStatement)s).getElseBody(), C);
-                               }
-                               else if( s instanceof FunctionStatement ) 
-                               {
-                                       
rDeterminePartitioningCandidates(var,((FunctionStatement)s).getBody(), C);
-                               }
-                               else
-                               {
-                                       List<DataIdentifier> datsRead = 
getDataIdentifiers(s, false);
-                                       rDeterminePartitioningCandidates(var, 
datsRead, C);
-                               }
+               for( StatementBlock sb : asb ) {
+                       if( sb instanceof FunctionStatementBlock ) {
+                               FunctionStatement fs = (FunctionStatement) 
sb.getStatement(0);
+                               rDeterminePartitioningCandidates(var, 
fs.getBody(), C);
+                       }
+                       else if( sb instanceof ForStatementBlock ) { //incl 
parfor
+                               ForStatementBlock fsb = (ForStatementBlock) sb;
+                               ForStatement fs = (ForStatement) 
fsb.getStatement(0);
+                               List<Hop> datsRead = new ArrayList<>();
+                               //predicate
+                               
rGetDataIdentifiers(resetVisitStatus(fsb.getFromHops()), datsRead);
+                               
rGetDataIdentifiers(resetVisitStatus(fsb.getToHops()), datsRead);
+                               
rGetDataIdentifiers(resetVisitStatus(fsb.getIncrementHops()), datsRead);
+                               rDeterminePartitioningCandidates(var, datsRead, 
C);
+                               //for / parfor body
+                               rDeterminePartitioningCandidates(var, 
fs.getBody(), C);
+                       }
+                       else if( sb instanceof WhileStatementBlock ) {
+                               WhileStatementBlock wsb = (WhileStatementBlock) 
sb;
+                               WhileStatement ws = (WhileStatement) 
wsb.getStatement(0);
+                               List<Hop> datsRead = new ArrayList<>();
+                               //predicate
+                               
rGetDataIdentifiers(resetVisitStatus(wsb.getPredicateHops()), datsRead);
+                               rDeterminePartitioningCandidates(var, datsRead, 
C);
+                               //while body
+                               rDeterminePartitioningCandidates(var, 
ws.getBody(), C);
+                       }
+                       else if( sb instanceof IfStatementBlock ) {
+                               IfStatementBlock isb = (IfStatementBlock) sb;
+                               IfStatement is = (IfStatement) 
isb.getStatement(0);
+                               List<Hop> datsRead = new ArrayList<>();
+                               //predicate
+                               
rGetDataIdentifiers(resetVisitStatus(isb.getPredicateHops()), datsRead);
+                               rDeterminePartitioningCandidates(var, datsRead, 
C);
+                               //if and else branch
+                               rDeterminePartitioningCandidates(var, 
is.getIfBody(), C);
+                               rDeterminePartitioningCandidates(var, 
is.getElseBody(), C);
                        }
+                       else if( sb.getHops() != null ) {
+                               Hop.resetVisitStatus(sb.getHops());
+                               List<Hop> datsRead = new ArrayList<>();
+                               for( Hop root : sb.getHops() )
+                                       rGetDataIdentifiers(root, datsRead);
+                               rDeterminePartitioningCandidates(var, datsRead, 
C);
+                       }
+               }
        }
 
-       private void rDeterminePartitioningCandidates(String var, 
List<DataIdentifier> datsRead, List<PartitionFormat> C)
-       {
+       private void rDeterminePartitioningCandidates(String var, List<Hop> 
datsRead, List<PartitionFormat> C) {
                if( datsRead == null )
                        return;
-               
-               for(DataIdentifier read : datsRead)
-                       if( var.equals( read.getName() ) ) {
-                               if( read instanceof IndexedIdentifier )
-                                       C.add( 
determineAccessPattern((IndexedIdentifier) read) );
-                               else if( read instanceof DataIdentifier )
-                                       C.add( PartitionFormat.NONE );
-                       }
+               for(Hop read : datsRead) {
+                       if( read instanceof IndexingOp && var.equals( 
read.getInput().get(0).getName() ) )
+                               C.add( determineAccessPattern((IndexingOp) 
read) );
+                       else if( HopRewriteUtils.isData(read, 
DataOpTypes.TRANSIENTREAD) && var.equals(read.getName()) )
+                               C.add( PartitionFormat.NONE );
+               }
        }
        
-       private PartitionFormat determineAccessPattern( IndexedIdentifier dat )
-       {
+       private Hop resetVisitStatus(Hop hop) {
+               return hop == null ? hop :
+                       hop.resetVisitStatus();
+       }
+       
+       private PartitionFormat determineAccessPattern( IndexingOp rix ) {
                boolean isSpark = OptimizerUtils.isSparkExecutionMode();
                int blksz = ConfigurationManager.getBlocksize();
                PartitionFormat dpf = null;
                
                //1) get all bounds expressions for index access
-               Expression rowL = dat.getRowLowerBound();
-               Expression rowU = dat.getRowUpperBound();
-               Expression colL = dat.getColLowerBound();
-               Expression colU = dat.getColUpperBound();
-               boolean allRows = (rowL == null && rowU == null);
-               boolean allCols = (colL == null && colU == null);
+               Hop rowL = rix.getInput().get(1);
+               Hop rowU = rix.getInput().get(2);
+               Hop colL = rix.getInput().get(3);
+               Hop colU = rix.getInput().get(4);
                
                try {
-                       //2) decided on access pattern          
+                       //2) decided on access pattern
                        //COLUMN_WISE if all rows and access to single column
-                       if( allRows && colL!=null && colL.equals(colU) ) {
+                       if( rix.isAllRows() && colL == colU ) {
                                dpf = PartitionFormat.COLUMN_WISE;
-                       }               
+                       }
                        //ROW_WISE if all cols and access to single row
-                       else if( allCols && rowL!=null && rowL.equals(rowU) ) {
+                       else if( rix.isAllCols() && rowL == rowU ) {
                                dpf = PartitionFormat.ROW_WISE;
                        }
                        //COLUMN_BLOCK_WISE
-                       else if( isSpark && allRows && colL != colU ) {
+                       else if( isSpark && rix.isAllRows() && colL != colU ) {
                                LinearFunction l1 = getLinearFunction(colL, 
true);
                                LinearFunction l2 = getLinearFunction(colU, 
true);
                                dpf = !isAlignedBlocking(l1, l2, blksz) ? 
PartitionFormat.NONE :
                                        new 
PartitionFormat(PDataPartitionFormat.COLUMN_BLOCK_WISE_N, (int)l1._b[0]);
                        }
                        //ROW_BLOCK_WISE
-                       else if( isSpark && allCols && rowL != rowU ) {
+                       else if( isSpark && rix.isAllCols() && rowL != rowU ) {
                                LinearFunction l1 = getLinearFunction(rowL, 
true);
                                LinearFunction l2 = getLinearFunction(rowU, 
true);
                                dpf = !isAlignedBlocking(l1, l2, blksz) ?  
PartitionFormat.NONE :
@@ -582,7 +590,6 @@ public class ParForStatementBlock extends ForStatementBlock
                catch(Exception ex) {
                        throw new RuntimeException(ex);
                }
-               
                return dpf;
        }
        
@@ -857,6 +864,29 @@ public class ParForStatementBlock extends ForStatementBlock
                return ret;
        }
        
+       private List<Hop> rGetDataIdentifiers(Hop root, List<Hop> direads) {
+               if( root == null || root.isVisited() )
+                       return direads;
+               //process children recursively (but disregard meta data ops and 
indexing)
+               if( !((HopRewriteUtils.isUnary(root, OpOp1.NROW, OpOp1.NCOL)
+                       && isDataIdentifier(root.getInput().get(0))) || 
isDataIdentifier(root)) ) {
+                       for( Hop c : root.getInput() )
+                               rGetDataIdentifiers(c, direads);
+               }
+               //handle transient read and right indexing over transient read
+               if( isDataIdentifier(root) )
+                       direads.add(root);
+               root.setVisited();
+               return direads;
+       }
+       
+       private boolean isDataIdentifier(Hop hop) {
+               return HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTREAD)
+                       || (hop instanceof IndexingOp && HopRewriteUtils.isData(
+                       hop.getInput().get(0), DataOpTypes.TRANSIENTREAD))
+                       || hop instanceof LiteralOp;
+       }
+       
        private void rDetermineBounds( ArrayList<StatementBlock> sbs, boolean 
flag ) {
                for( StatementBlock sb : sbs )
                        rDetermineBounds(sb, flag);
@@ -1305,7 +1335,7 @@ public class ParForStatementBlock extends 
ForStatementBlock
                                                
                                                _bounds._lower.put(id, 1L);
                                                _bounds._upper.put(id, 
_vsParent.getVariable(idat._name).getDim1()); //row dim
-                                               _bounds._increment.put(id, 1L); 
+                                               _bounds._increment.put(id, 1L);
                                        }
                        }
                        else //range indexing
@@ -1461,9 +1491,9 @@ public class ParForStatementBlock extends 
ForStatementBlock
                                if( sub1 instanceof IntIdentifier )
                                        out = new 
LinearFunction(((IntIdentifier)sub1).getValue(), 0, null);
                                else if( sub1 instanceof DataIdentifier )
-                                       out = new LinearFunction(0, 1, 
((DataIdentifier)sub1)._name); //never use public members
+                                       out = new LinearFunction(0, 1, 
((DataIdentifier)sub1).getName());
                                else
-                                       out = 
rParseBinaryExpression((BinaryExpression)sub1);                   
+                                       out = 
rParseBinaryExpression((BinaryExpression)sub1);
                        }
                }
                catch(Exception ex) {
@@ -1498,7 +1528,7 @@ public class ParForStatementBlock extends 
ForStatementBlock
                                if( sub1 instanceof IntIdentifier )
                                        out = new 
LinearFunction(((IntIdentifier)sub1).getValue(), 0, null);
                                else if( sub1 instanceof DataIdentifier )
-                                       out = new LinearFunction(0, 1, 
((DataIdentifier)sub1)._name); //never use public members
+                                       out = new LinearFunction(0, 1, 
((DataIdentifier)sub1).getName());
                                else
                                        out = 
rParseBinaryExpression((BinaryExpression)sub1);
                        }
@@ -1518,6 +1548,7 @@ public class ParForStatementBlock extends 
ForStatementBlock
                return out;
        }
        
+       @SuppressWarnings("unused")
        private LinearFunction getLinearFunction(Expression expr, boolean 
ignoreMinWithConstant) {
                if( expr instanceof IntIdentifier )
                        return new 
LinearFunction(((IntIdentifier)expr).getValue(), 0, null);
@@ -1534,7 +1565,25 @@ public class ParForStatementBlock extends 
ForStatementBlock
                        }
                }
                else if( expr instanceof DataIdentifier )
-                       return new LinearFunction(0, 1, 
((DataIdentifier)expr)._name); //never use public members
+                       return new LinearFunction(0, 1, 
((DataIdentifier)expr).getName());
+               
+               return null;
+       }
+       
+       private LinearFunction getLinearFunction(Hop hop, boolean 
ignoreMinWithConstant) {
+               if( hop instanceof LiteralOp && 
hop.getValueType()==ValueType.INT )
+                       return new 
LinearFunction(HopRewriteUtils.getIntValue((LiteralOp)hop), 0, null);
+               else if( HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MINUS, 
OpOp2.MULT) )
+                       return rParseBinaryExpression(hop);
+               else if( HopRewriteUtils.isBinary(hop, OpOp2.MIN) && 
ignoreMinWithConstant ) {
+                       //note: builtin function expression is also a data 
identifier and hence order before
+                       if( hop.getInput().get(0) instanceof 
org.apache.sysml.hops.BinaryOp )
+                               return 
rParseBinaryExpression(hop.getInput().get(0));
+                       else if( hop.getInput().get(1) instanceof 
org.apache.sysml.hops.BinaryOp )
+                               return 
rParseBinaryExpression(hop.getInput().get(1));
+               }
+               else if( HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTREAD) 
)
+                       return new LinearFunction(0, 1, hop.getName());
                
                return null;
        }
@@ -1708,22 +1757,74 @@ public class ParForStatementBlock extends 
ForStatementBlock
                }
                return null; //let dependency analysis fail
        }
+       
+       private LinearFunction rParseBinaryExpression(Hop hop) {
+               org.apache.sysml.hops.BinaryOp bop = 
(org.apache.sysml.hops.BinaryOp) hop;
+               Hop l = bop.getInput().get(0);
+               Hop r = bop.getInput().get(1);
+               if( bop.getOp()==OpOp2.PLUS || bop.getOp()==OpOp2.MINUS ) {
+                       boolean plus = bop.getOp() == OpOp2.PLUS;
+                       //parse binary expressions
+                       if( l instanceof org.apache.sysml.hops.BinaryOp) {
+                               LinearFunction f = rParseBinaryExpression(l);
+                               Long cvalR = parseLongConstant(r);
+                               if( f != null && cvalR != null )
+                                       return f.addConstant(cvalR * 
(plus?1:-1));
+                       }
+                       else if (r instanceof org.apache.sysml.hops.BinaryOp) {
+                               LinearFunction f = rParseBinaryExpression(r);
+                               Long cvalL = parseLongConstant(l);
+                               if( f != null && cvalL != null )
+                                       return 
f.scale(plus?1:-1).addConstant(cvalL);
+                       }
+                       else { // atomic case
+                               //change everything to plus if necessary
+                               Long cvalL = parseLongConstant(l);
+                               Long cvalR = parseLongConstant(r);
+                               if( cvalL != null )
+                                       return new LinearFunction(cvalL, 
plus?1:-1, r.getName() );
+                               else if( cvalR != null )
+                                       return new 
LinearFunction(cvalR*(plus?1:-1), 1, l.getName());
+                       }
+               }
+               else if( bop.getOp() == OpOp2.MULT ) {
+                       //atomic case (only recursion for MULT expressions, 
where one side is a constant)
+                       Long cvalL = parseLongConstant(l);
+                       Long cvalR = parseLongConstant(r);
+                       if( cvalL != null && HopRewriteUtils.isData(r, 
DataOpTypes.TRANSIENTREAD) )
+                               return new LinearFunction(0, cvalL, 
r.getName());
+                       else if( cvalR != null && HopRewriteUtils.isData(l, 
DataOpTypes.TRANSIENTREAD) )
+                               return new LinearFunction(0, cvalR, 
l.getName());
+                       else if( cvalL != null && r instanceof 
org.apache.sysml.hops.BinaryOp )
+                               return rParseBinaryExpression(r).scale(cvalL);
+                       else if( cvalR != null && l instanceof 
org.apache.sysml.hops.BinaryOp )
+                               return rParseBinaryExpression(l).scale(cvalR);
+               }
+               return null; //let dependency analysis fail
+       }
 
-       private static Long parseLongConstant(Expression expr)
-       {
-               Long ret = null;
-               
+       private static Long parseLongConstant(Expression expr) {
                if( expr instanceof IntIdentifier ) {
-                       ret = ((IntIdentifier) expr).getValue();
+                       return ((IntIdentifier) expr).getValue();
                }
                else if( expr instanceof DoubleIdentifier ) {
                        double tmp = ((DoubleIdentifier) expr).getValue();
-                       //ensure double represent an integer number
-                       if( tmp == Math.floor(tmp) ) 
-                               ret = UtilFunctions.toLong(tmp);
+                       if( tmp == Math.floor(tmp) ) //ensure int
+                               return UtilFunctions.toLong(tmp);
                }
-               
-               return ret;
+               return null;
+       }
+       
+       private static Long parseLongConstant(Hop hop) {
+               if( hop instanceof LiteralOp && 
hop.getValueType()==ValueType.INT ) {
+                       return HopRewriteUtils.getIntValue((LiteralOp)hop);
+               }
+               else if( hop instanceof LiteralOp && 
hop.getValueType()==ValueType.DOUBLE ) {
+                       double tmp = 
HopRewriteUtils.getDoubleValue((LiteralOp)hop);
+                       if( tmp == Math.floor(tmp) ) //ensure int
+                               return UtilFunctions.toLong(tmp);
+               }
+               return null;
        }
        
        public static class ResultVar {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/main/java/org/apache/sysml/parser/VariableSet.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/VariableSet.java 
b/src/main/java/org/apache/sysml/parser/VariableSet.java
index eb5dc92..2b1bee5 100644
--- a/src/main/java/org/apache/sysml/parser/VariableSet.java
+++ b/src/main/java/org/apache/sysml/parser/VariableSet.java
@@ -73,6 +73,11 @@ public class VariableSet
                return _variables;
        }
        
+       public boolean isMatrix(String name) {
+               return _variables.containsKey(name)
+                       && _variables.get(name).getDataType().isMatrix();
+       }
+       
        @Override
        public String toString() {
                return Arrays.toString(

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index 368737b..83e9cde 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -1142,7 +1142,7 @@ public class ParForProgramBlock extends ForProgramBlock
                        if( sb == null )
                                throw new DMLRuntimeException("ParFor statement 
block required for reasoning about data partitioning.");
                        
-                       for( String var : sb.getReadOnlyParentVars() )
+                       for( String var : sb.getReadOnlyParentMatrixVars() )
                        {
                                Data dat = ec.getVariable(var);
                                //skip non-existing input matrices (which are 
due to unknown sizes marked for

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 043e4ed..e3fb71d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -411,17 +411,13 @@ public class OptimizerRuleBased extends Optimizer
                        && (_N >= PROB_SIZE_THRESHOLD_PARTITIONING || _Nmax >= 
PROB_SIZE_THRESHOLD_PARTITIONING) ) //only if beneficial wrt problem size
                {
                        HashMap<String, PartitionFormat> cand2 = new 
HashMap<>();
-                       for( String c : pfsb.getReadOnlyParentVars() )
-                       {
+                       for( String c : pfsb.getReadOnlyParentMatrixVars() ) {
                                PartitionFormat dpf = 
pfsb.determineDataPartitionFormat( c );
-                               
                                if( dpf != PartitionFormat.NONE 
-                                       && dpf._dpf != 
PDataPartitionFormat.BLOCK_WISE_M_N ) 
-                               {
+                                       && dpf._dpf != 
PDataPartitionFormat.BLOCK_WISE_M_N ) {
                                        cand2.put( c, dpf );
                                }
                        }
-                       
                        apply = rFindDataPartitioningCandidates(n, cand2, vars, 
thetaM);
                        if( apply )
                                partitionedMatrices.putAll(cand2);
@@ -441,7 +437,7 @@ public class OptimizerRuleBased extends Optimizer
        
                _numEvaluatedPlans++;
                LOG.debug(getOptMode()+" OPT: rewrite 'set data partitioner' - 
result="+pdp.toString()+
-                                 " 
("+ProgramConverter.serializeStringCollection(partitionedMatrices.keySet())+")" 
);
+                       " 
("+ProgramConverter.serializeStringCollection(partitionedMatrices.keySet())+")" 
);
                
                return blockwise;
        }
@@ -1195,11 +1191,11 @@ public class OptimizerRuleBased extends Optimizer
                        double mem = (OptimizerUtils.isSparkExecutionMode() && 
!n.isCPOnly()) ? _lm/2 : _lm;
                        double sharedM = 0, nonSharedM = M;
                        if( computeMaxK(M, M, 0, mem) < kMax ) { //account for 
shared read if necessary
-                               sharedM = 
pfsb.getReadOnlyParentVars().stream().map(s -> vars.get(s))
+                               sharedM = 
pfsb.getReadOnlyParentMatrixVars().stream().map(s -> vars.get(s))
                                        .filter(d -> d instanceof 
MatrixObject).mapToDouble(mo -> OptimizerUtils
                                        
.estimateSize(((MatrixObject)mo).getMatrixCharacteristics())).sum();
                                nonSharedM = 
cost.getEstimate(TestMeasure.MEMORY_USAGE, n, true,
-                                       pfsb.getReadOnlyParentVars(), 
ExcludeType.SHARED_READ);
+                                       pfsb.getReadOnlyParentMatrixVars(), 
ExcludeType.SHARED_READ);
                        }
                        
                        //ensure local memory constraint (for spark more 
conservative in order to 
@@ -1909,7 +1905,7 @@ public class OptimizerRuleBased extends Optimizer
                        rCollectZipmmPartitioningCandidates(n, cand);
                        
                        //prune updated candidates
-                       HashSet<String> probe = new 
HashSet<>(pfsb.getReadOnlyParentVars());
+                       HashSet<String> probe = new 
HashSet<>(pfsb.getReadOnlyParentMatrixVars());
                        for( String var : cand )
                                if( probe.contains( var ) )
                                        ret.add( var );

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/test/java/org/apache/sysml/test/integration/functions/parfor/ParForRulebasedOptimizerTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/parfor/ParForRulebasedOptimizerTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/parfor/ParForRulebasedOptimizerTest.java
index 194bcc8..9ce7da3 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/parfor/ParForRulebasedOptimizerTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/parfor/ParForRulebasedOptimizerTest.java
@@ -36,7 +36,7 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
        private final static String TEST_DIR = "functions/parfor/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
ParForRulebasedOptimizerTest.class.getSimpleName() + "/";
        private final static double eps = 1e-10;
-               
+       
        private final static int rows1 = 1000; //small CP
        private final static int rows2 = 10000; //large MR
        
@@ -47,14 +47,10 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
        private final static int cols22 = 50; //large nested parfor
        private final static int cols31 = 2;  //small nested parfor
        private final static int cols32 = 8; //large nested parfor
-       
-       
        private final static double sparsity = 0.7;
        
-       
        @Override
-       public void setUp() 
-       {
+       public void setUp() {
                addTestConfiguration(TEST_NAME1,
                        new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new 
String[] { "Rout" }) );
                addTestConfiguration(TEST_NAME2,
@@ -63,7 +59,6 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
                        new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new 
String[] { "Rout" }) );
        }
 
-       
        @Test
        public void testParForRulebasedOptimizerCorrelationSmallSmall() {
                runParForOptimizerTest(1, false, false, false);
@@ -188,31 +183,28 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
        private void runParForOptimizerTest( int scriptNum, boolean largeRows, 
boolean largeCols, boolean timebasedOpt )
        {
                //find right rows and cols configuration
-               int rows=-1, cols=-1;  
+               int rows=-1, cols=-1;
                if( largeRows )
                        rows = rows2;
                else
                        rows = rows1; 
                if( largeCols ){
-                       switch(scriptNum)
-                       {
+                       switch(scriptNum) {
                                case 1: cols=cols22; break;
                                case 2: cols=cols32; break;
-                               case 3: cols=cols12; break;                     
        
+                               case 3: cols=cols12; break;
                        }
                }
                else{
-                       switch(scriptNum)
-                       {
+                       switch(scriptNum) {
                                case 1: cols=cols21; break;
                                case 2: cols=cols31; break;
-                               case 3: cols=cols11; break;                     
        
+                               case 3: cols=cols11; break;
                        }
                }
 
                //run actual test
-               switch( scriptNum )
-               {
+               switch( scriptNum ) {
                        case 1: 
                                runUnaryTest(scriptNum, timebasedOpt, rows, 
cols);
                                break;
@@ -221,7 +213,7 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
                                break;
                        case 3: 
                                runUnaryTest(scriptNum, timebasedOpt, rows, 
cols);
-                               break;  
+                               break;
                }
        }
        
@@ -229,14 +221,12 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
        {
                TestConfiguration config = null;
                String HOME = SCRIPT_DIR + TEST_DIR;
-               if( scriptNum==1 )
-               {
+               if( scriptNum==1 ) {
                        config=getTestConfiguration(TEST_NAME1);
                        String testname = TEST_NAME1 + (timebasedOpt ? "b" : 
"");
                        fullDMLScriptName = HOME + testname + ".dml";
                }
-               else if( scriptNum==3 )
-               {
+               else if( scriptNum==3 ) {
                        config=getTestConfiguration(TEST_NAME3);
                        String testname = TEST_NAME3 + (timebasedOpt ? "b" : 
"");
                        fullDMLScriptName = HOME + testname + ".dml";
@@ -246,28 +236,24 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
                config.addVariable("cols", cols);
                loadTestConfiguration(config);
                
-               if( scriptNum==1 )
-               {
-                       programArgs = new String[]{ "-args", input("V"), 
-                               Integer.toString(rows), Integer.toString(cols), 
+               if( scriptNum==1 ) {
+                       programArgs = new String[]{ "-args", input("V"),
+                               Integer.toString(rows), Integer.toString(cols),
                                output("R") };
-                       
-                       rCmd = "Rscript" + " " + HOME + TEST_NAME1 + ".R" + " " 
+ 
+                       rCmd = "Rscript" + " " + HOME + TEST_NAME1 + ".R" + " " 
+
                                inputDir() + " " + expectedDir();
-               }       
-               else if( scriptNum==3 )
-               {
-                       programArgs = new String[]{ "-args", input("V"), 
-                               Integer.toString(rows), Integer.toString(cols), 
+               }
+               else if( scriptNum==3 ) {
+                       programArgs = new String[]{ "-args", input("V"),
+                               Integer.toString(rows), Integer.toString(cols),
                                Integer.toString(cols/2), 
                                output("R") };
-                       
-                       rCmd = "Rscript" + " " + HOME + TEST_NAME3 + ".R" + " " 
+ 
+                       rCmd = "Rscript" + " " + HOME + TEST_NAME3 + ".R" + " " 
+
                                inputDir() + " " + expectedDir();
-               }       
+               }
 
                long seed = System.nanoTime();
-        double[][] V = getRandomMatrix(rows, cols, 0, 1, sparsity, seed);
+               double[][] V = getRandomMatrix(rows, cols, 0, 1, sparsity, 
seed);
                writeInputMatrix("V", V, true);
 
                boolean exceptionExpected = false;
@@ -278,7 +264,7 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
                //compare matrices
                HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
                HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("Rout");
-               TestUtils.compareMatrices(dmlfile, rfile, eps, "DML", "R");     
        
+               TestUtils.compareMatrices(dmlfile, rfile, eps, "DML", "R");
        }
        
        private void runNaryTest(int scriptNum, boolean timebasedOpt, int rows, 
int cols)
@@ -317,35 +303,32 @@ public class ParForRulebasedOptimizerTest extends 
AutomatedTestBase
                }
                writeInputMatrix("D", D, true);
                
-               //generate attribute sets               
+               //generate attribute sets
                double[][] S1 = getRandomMatrix(1, cols, 1, cols+1-eps, 1, 
1112);
                double[][] S2 = getRandomMatrix(1, cols, 1, cols+1-eps, 1, 
1113);
                TestUtils.floor(S1);
                TestUtils.floor(S2);
                writeInputMatrix("S1", S1, true);
-               writeInputMatrix("S2", S2, true);       
+               writeInputMatrix("S2", S2, true);
 
                //generate kind for attributes (1,2,3)
-        double[][] K1 = new double[1][cols];
-        double[][] K2 = new double[1][cols];
-        for( int i=0; i<cols; i++ )
-        {
-               K1[0][i] = Dkind[(int)S1[0][i]-1];
-               K2[0][i] = Dkind[(int)S2[0][i]-1];
-        }
-        writeInputMatrix("K1", K1, true);
+               double[][] K1 = new double[1][cols];
+               double[][] K2 = new double[1][cols];
+               for( int i=0; i<cols; i++ ) {
+                       K1[0][i] = Dkind[(int)S1[0][i]-1];
+                       K2[0][i] = Dkind[(int)S2[0][i]-1];
+               }
+               writeInputMatrix("K1", K1, true);
                writeInputMatrix("K2", K2, true);
                
                boolean exceptionExpected = false;
                runTest(true, exceptionExpected, null, -1);
 
-               runRScript(true); 
+               runRScript(true);
                
                //compare matrices 
-               for( String out : new String[]{"bivar.stats", 
"category.counts", "category.means",  "category.variances" } )
-               {
+               for( String out : new String[]{"bivar.stats", 
"category.counts", "category.means",  "category.variances" } ) {
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("bivarstats/"+out);
-                       
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS(out);
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/3e54f1a2/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAComplexAppendTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAComplexAppendTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAComplexAppendTest.java
index 534ab91..58f9f46 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAComplexAppendTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAComplexAppendTest.java
@@ -35,7 +35,6 @@ import org.apache.sysml.test.integration.TestConfiguration;
 
 public class IPAComplexAppendTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME = "append_nnz";
        private final static String TEST_DIR = "functions/recompile/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
IPAComplexAppendTest.class.getSimpleName() + "/";
@@ -101,8 +100,9 @@ public class IPAComplexAppendTest extends AutomatedTestBase
                        runTest(true, false, null, -1); 
                        
                        //check expected number of compiled and executed MR jobs
-                       int expectedNumCompiled = (rewrites&&IPA)?2:3; //(GMR 
mm,) GMR append, GMR sum
-                       int expectedNumExecuted = rewrites?0:1; //(GMR mm)      
                
+                       //TODO investigate IPA side effect
+                       int expectedNumCompiled = (rewrites&&IPA)?1:3; //(GMR 
mm+, GMR append,) GMR sum
+                       int expectedNumExecuted = rewrites?0:IPA?2:1; //(GMR 
mm+, GMR append) 
                        
                        checkNumCompiledMRJobs(expectedNumCompiled); 
                        checkNumExecutedMRJobs(expectedNumExecuted); 
@@ -113,5 +113,4 @@ public class IPAComplexAppendTest extends AutomatedTestBase
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
oldFlagRewrites;
                }
        }
-       
 }

Reply via email to