[SYSTEMML-2185] Phase ordering of statement block rewrites (merge/split)

There are a number of rewrites that modify the program structure (e.g.,
branch removal, merge of block sequences, inlining, split of DAGs after
data-dependent operators, checkpointing). The order of these rewrites
matters because artificially split DAGs are no longer eligible for the
merge of block sequences, which forfeits optimization opportunities
such as common subexpression elimination (CSE) across blocks.

Accordingly, this patch improves the existing static rewrites with a
proper phase ordering: we first apply all rewrites that merge and
consolidate statement blocks (for maximum CSE), and then apply the
splitting rewrites that create the necessary recompilation hooks.

For example, on stratstats (100K x 1K), this patch improved end-to-end
performance from 1296s to 1080s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e1fd3431
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e1fd3431
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e1fd3431

Branch: refs/heads/master
Commit: e1fd34314e8f68d5626407cdd9d8197948605573
Parents: bcaa140
Author: Matthias Boehm <[email protected]>
Authored: Thu Mar 15 14:53:39 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Mar 15 14:53:39 2018 -0700

----------------------------------------------------------------------
 .../hops/ipa/IPAPassApplyStaticHopRewrites.java |  2 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     | 61 ++++++++++----------
 .../hops/rewrite/RewriteCompressedReblock.java  |  5 ++
 .../rewrite/RewriteForLoopVectorization.java    |  4 ++
 .../RewriteInjectSparkLoopCheckpointing.java    |  5 ++
 .../RewriteInjectSparkPReadCheckpointing.java   |  5 +-
 .../RewriteMarkLoopVariablesUpdateInPlace.java  |  5 ++
 .../hops/rewrite/RewriteMergeBlockSequence.java |  5 ++
 .../rewrite/RewriteRemoveEmptyBasicBlocks.java  |  5 ++
 .../RewriteRemoveUnnecessaryBranches.java       |  5 ++
 .../RewriteSplitDagDataDependentOperators.java  |  5 ++
 .../rewrite/RewriteSplitDagUnknownCSVRead.java  |  5 ++
 .../hops/rewrite/StatementBlockRewriteRule.java |  8 +++
 .../org/apache/sysml/parser/DMLTranslator.java  |  4 +-
 .../parfor/opt/OptimizationWrapper.java         |  2 +-
 .../parfor/opt/OptimizerRuleBased.java          |  2 +-
 16 files changed, 92 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
index cae2e35..57bff00 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
@@ -49,7 +49,7 @@ public class IPAPassApplyStaticHopRewrites extends IPAPass
                        
rewriter.removeStatementBlockRewrite(RewriteInjectSparkLoopCheckpointing.class);
                        
                        //rewrite program hop dags and statement blocks
-                       rewriter.rewriteProgramHopDAGs(prog);
+                       rewriter.rewriteProgramHopDAGs(prog, true); //rewrite 
and split
                } 
                catch (LanguageException ex) {
                        throw new HopsException(ex);

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index a73ea6d..65687e3 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -71,7 +71,7 @@ public class ProgramRewriter
                this( true, true );
        }
        
-       public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites 
)
+       public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
        {
                //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
                _dagRuleSet = new ArrayList<>();
@@ -145,12 +145,11 @@ public class ProgramRewriter
         * 
         * @param rewrites the HOP rewrite rules
         */
-       public ProgramRewriter( HopRewriteRule... rewrites ) {
+       public ProgramRewriter(HopRewriteRule... rewrites) {
                //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
                _dagRuleSet = new ArrayList<>();
                for( HopRewriteRule rewrite : rewrites )
                        _dagRuleSet.add( rewrite );
-               
                _sbRuleSet = new ArrayList<>();
        }
        
@@ -159,10 +158,9 @@ public class ProgramRewriter
         * 
         * @param rewrites the statement block rewrite rules
         */
-       public ProgramRewriter( StatementBlockRewriteRule... rewrites ) {
+       public ProgramRewriter(StatementBlockRewriteRule... rewrites) {
                //initialize HOP DAG rewrite ruleSet (with fixed rewrite order)
                _dagRuleSet = new ArrayList<>();
-               
                _sbRuleSet = new ArrayList<>();
                for( StatementBlockRewriteRule rewrite : rewrites )
                        _sbRuleSet.add( rewrite );
@@ -191,9 +189,13 @@ public class ProgramRewriter
                _sbRuleSet.removeIf(r -> r.getClass().equals(clazz));
        }
        
-       public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) 
+       public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) 
throws LanguageException, HopsException {
+               return rewriteProgramHopDAGs(dmlp, true);
+       }
+       
+       public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, 
boolean splitDags) 
                throws LanguageException, HopsException
-       {       
+       {
                ProgramRewriteStatus state = new ProgramRewriteStatus();
                
                // for each namespace, handle function statement blocks
@@ -201,7 +203,8 @@ public class ProgramRewriter
                        for (String fname : 
dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
                                FunctionStatementBlock fsblock = 
dmlp.getFunctionStatementBlock(namespaceKey,fname);
                                rRewriteStatementBlockHopDAGs(fsblock, state);
-                               rRewriteStatementBlock(fsblock, state);
+                               if( !_sbRuleSet.isEmpty() )
+                                       rRewriteStatementBlock(fsblock, state, 
splitDags);
                        }
                
                // handle regular statement blocks in "main" method
@@ -209,7 +212,9 @@ public class ProgramRewriter
                        StatementBlock current = dmlp.getStatementBlock(i);
                        rRewriteStatementBlockHopDAGs(current, state);
                }
-               dmlp.setStatementBlocks( 
rRewriteStatementBlocks(dmlp.getStatementBlocks(), state) );
+               if( !_sbRuleSet.isEmpty() )
+                       dmlp.setStatementBlocks(rRewriteStatementBlocks(
+                               dmlp.getStatementBlocks(), state, splitDags));
                
                return state;
        }
@@ -289,7 +294,7 @@ public class ProgramRewriter
                return root;
        }
        
-       public ArrayList<StatementBlock> rRewriteStatementBlocks( 
ArrayList<StatementBlock> sbs, ProgramRewriteStatus status ) 
+       public ArrayList<StatementBlock> 
rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus 
status, boolean splitDags)
                throws HopsException
        {
                //ensure robustness for calls from outside
@@ -299,16 +304,18 @@ public class ProgramRewriter
                //apply rewrite rules to list of statement blocks
                List<StatementBlock> tmp = sbs; 
                for( StatementBlockRewriteRule r : _sbRuleSet )
-                       tmp = r.rewriteStatementBlocks(tmp, status);
+                       if( splitDags || !r.createsSplitDag() )
+                               tmp = r.rewriteStatementBlocks(tmp, status);
                
                //recursively rewrite statement blocks (with potential 
expansion)
                List<StatementBlock> tmp2 = new ArrayList<>();
                for( StatementBlock sb : tmp )
-                       tmp2.addAll( rRewriteStatementBlock(sb, status) );
+                       tmp2.addAll( rRewriteStatementBlock(sb, status, 
splitDags) );
                
                //apply rewrite rules to list of statement blocks (with 
potential contraction)
                for( StatementBlockRewriteRule r : _sbRuleSet )
-                       tmp2 = r.rewriteStatementBlocks(tmp2, status);
+                       if( splitDags || !r.createsSplitDag() )
+                               tmp2 = r.rewriteStatementBlocks(tmp2, status);
                
                //prepare output list
                sbs.clear();
@@ -316,48 +323,44 @@ public class ProgramRewriter
                return sbs;
        }
        
-       public ArrayList<StatementBlock> rRewriteStatementBlock( StatementBlock 
sb, ProgramRewriteStatus status ) 
+       public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock 
sb, ProgramRewriteStatus status, boolean splitDags)
                throws HopsException
        {
                ArrayList<StatementBlock> ret = new ArrayList<>();
                ret.add(sb);
                
                //recursive invocation
-               if (sb instanceof FunctionStatementBlock)
-               {
+               if (sb instanceof FunctionStatementBlock) {
                        FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
-                       fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), 
status) );
+                       fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), 
status, splitDags));
                }
-               else if (sb instanceof WhileStatementBlock)
-               {
+               else if (sb instanceof WhileStatementBlock) {
                        WhileStatementBlock wsb = (WhileStatementBlock) sb;
                        WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
-                       wstmt.setBody( rRewriteStatementBlocks( 
wstmt.getBody(), status ) );
+                       wstmt.setBody(rRewriteStatementBlocks(wstmt.getBody(), 
status, splitDags));
                }       
-               else if (sb instanceof IfStatementBlock)
-               {
+               else if (sb instanceof IfStatementBlock) {
                        IfStatementBlock isb = (IfStatementBlock) sb;
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
-                       istmt.setIfBody( rRewriteStatementBlocks( 
istmt.getIfBody(), status ) );
-                       istmt.setElseBody( rRewriteStatementBlocks( 
istmt.getElseBody(), status ) );
+                       
istmt.setIfBody(rRewriteStatementBlocks(istmt.getIfBody(), status, splitDags));
+                       
istmt.setElseBody(rRewriteStatementBlocks(istmt.getElseBody(), status, 
splitDags));
                }
-               else if (sb instanceof ForStatementBlock) //incl parfor
-               {
+               else if (sb instanceof ForStatementBlock) { //incl parfor
                        //maintain parfor context information (e.g., for 
checkpointing)
                        boolean prestatus = status.isInParforContext();
                        if( sb instanceof ParForStatementBlock )
                                status.setInParforContext(true);
-                       
                        ForStatementBlock fsb = (ForStatementBlock) sb;
                        ForStatement fstmt = (ForStatement)fsb.getStatement(0);
-                       fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), 
status) );
-                       
+                       fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), 
status, splitDags));
                        status.setInParforContext(prestatus);
                }
                
                //apply rewrite rules to individual statement blocks
                for( StatementBlockRewriteRule r : _sbRuleSet ) {
+                       if( !splitDags && r.createsSplitDag() )
+                               continue;
                        ArrayList<StatementBlock> tmp = new ArrayList<>();
                        for( StatementBlock sbc : ret )
                                tmp.addAll( r.rewriteStatementBlock(sbc, 
status) );

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java
index fdaad10..b24a178 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java
@@ -63,6 +63,11 @@ public class RewriteCompressedReblock extends 
StatementBlockRewriteRule
        private static final String TMP_PREFIX = "__cmtx";
        
        @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus sate)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
index fce1fa1..cade354 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
@@ -53,6 +53,10 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
        private static final OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new 
OpOp2[]{OpOp2.PLUS, OpOp2.MULT, OpOp2.MIN, OpOp2.MAX};
        private static final AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new 
AggOp[]{AggOp.SUM,  AggOp.PROD, AggOp.MIN, AggOp.MAX};
        
+       @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
        
        @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state)

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
index ddee8ab..50a5579 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java
@@ -54,6 +54,11 @@ public class RewriteInjectSparkLoopCheckpointing extends 
StatementBlockRewriteRu
        }
        
        @Override
+       public boolean createsSplitDag() {
+               return true;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus status)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
index 755c2e8..7ee3506 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java
@@ -35,7 +35,6 @@ import org.apache.sysml.hops.OptimizerUtils;
  */
 public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule
 {
-       
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state)
                throws HopsException
@@ -67,8 +66,8 @@ public class RewriteInjectSparkPReadCheckpointing extends 
HopRewriteRule
                if(hop.isVisited())
                        return;
                
-               // The reblocking is performed after transform, and hence 
checkpoint only non-transformed reads.   
-               if(    (hop instanceof DataOp && 
((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTREAD)
+               // The reblocking is performed after transform, and hence 
checkpoint only non-transformed reads.
+               if( (hop instanceof DataOp && 
((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTREAD)
                        || hop.requiresReblock() )
                {
                        //make given hop for checkpointing (w/ default storage 
level)

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
index c4b3340..79683fb 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
@@ -49,6 +49,11 @@ import org.apache.sysml.parser.Expression.DataType;
 public class RewriteMarkLoopVariablesUpdateInPlace extends 
StatementBlockRewriteRule
 {
        @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus status)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
index cf7701c..9c2d715 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
@@ -45,6 +45,11 @@ public class RewriteMergeBlockSequence extends 
StatementBlockRewriteRule
                new RewriteCommonSubexpressionElimination(true));
        
        @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
                        ProgramRewriteStatus state) throws HopsException {
                return Arrays.asList(sb);

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java
index 5010618..b03763c 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java
@@ -33,6 +33,11 @@ import org.apache.sysml.parser.StatementBlock;
 public class RewriteRemoveEmptyBasicBlocks extends StatementBlockRewriteRule
 {
        @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java
index 52aa7a5..094adb0 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java
@@ -38,6 +38,11 @@ import org.apache.sysml.parser.StatementBlock;
 public class RewriteRemoveUnnecessaryBranches extends StatementBlockRewriteRule
 {
        @Override
+       public boolean createsSplitDag() {
+               return false;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index a3e037f..13f4a50 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -73,6 +73,11 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
        private static IDSequence _seq = new IDSequence();
        
        @Override
+       public boolean createsSplitDag() {
+               return true;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
index 5ec7cb0..5351c0d 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java
@@ -46,6 +46,11 @@ import 
org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
 public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule
 {
        @Override
+       public boolean createsSplitDag() {
+               return true;
+       }
+       
+       @Override
        public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state)
                throws HopsException 
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java 
b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
index eba2276..ff2d5bf 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java
@@ -37,6 +37,14 @@ public abstract class StatementBlockRewriteRule
        protected static final Log LOG = 
LogFactory.getLog(StatementBlockRewriteRule.class.getName());
        
        /**
+        * Indicates if the rewrite potentially splits dags, which is used
+        * for phase ordering of rewrites.
+        * 
+        * @return true if dag splits are possible.
+        */
+       public abstract boolean createsSplitDag();
+       
+       /**
         * Handle an arbitrary statement block. Specific type constraints have 
to be ensured
         * within the individual rewrites. If a rewrite does not apply to 
individual blocks, it 
         * should simply return the input block.

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 220be8a..8a0c5df 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -263,7 +263,9 @@ public class DMLTranslator
        {
                //apply hop rewrites (static rewrites)
                ProgramRewriter rewriter = new ProgramRewriter(true, false);
-               rewriter.rewriteProgramHopDAGs(dmlp);
+               rewriter.rewriteProgramHopDAGs(dmlp, false); //rewrite and merge
+               resetHopsDAGVisitStatus(dmlp);
+               rewriter.rewriteProgramHopDAGs(dmlp, true); //rewrite and split
                resetHopsDAGVisitStatus(dmlp);
                
                //propagate size information from main into functions (but 
conservatively)

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index c77d4d0..5f2f030 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -185,7 +185,7 @@ public class OptimizationWrapper
                                ProgramRewriter rewriter = 
createProgramRewriterWithRuleSets();
                                ProgramRewriteStatus state = new 
ProgramRewriteStatus();
                                rewriter.rRewriteStatementBlockHopDAGs( sb, 
state );
-                               
fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state));
+                               
fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
                                if( state.getRemovedBranches() ){
                                        LOG.debug("ParFOR Opt: Removed branches 
during program rewrites, rebuilding runtime program");
                                        
pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(),
 fs.getBody()));

http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 584a1fa..d5b09e1 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -1913,7 +1913,7 @@ public class OptimizerRuleBased extends Optimizer
                        ProgramRewriter rewriter = new ProgramRewriter(rewrite);
                        ProgramRewriteStatus state = new ProgramRewriteStatus();
                        rewriter.rRewriteStatementBlockHopDAGs( pfsb, state );
-                       
fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state));
+                       
fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
                        
                        //recompile if additional checkpoints introduced
                        if( state.getInjectedCheckpoints() ) {

Reply via email to