Repository: systemml
Updated Branches:
  refs/heads/master 631079c43 -> c5bec0082


[SYSTEMML-1935] Improved rewrite for merging statement block sequences

This patch makes two high-impact improvements to the existing rewrite
for merging statement block sequences.

First, we now apply the rewrites for lists of statement blocks both before
and after the rewrites for individual blocks. For example, on GLM
binomial-probit this allows us to compile the entire function glm_dist
into a single HOP DAG.

Second, we now check for additional cases where intermediates between
blocks can be dead-code eliminated. Again, on GLM binomial-probit this
allows for the automatic removal of large intermediates that are exposed
as dead code after various branch-removal and sequence-merge rewrites.

Overall, this patch reduced the runtime of GLM binomial-probit with
codegen (100M x 10 dense input, 20/10 max outer/inner iterations) from
1,633s to 1,338s. At the same time, it creates more codegen opportunities
but also more challenging optimization problems. For example, the number
of considered plans increased from ~200K to ~1G, but cost-based and
structural pruning reduced the number of plans to 151 while still
guaranteeing plan optimality.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0b04cb16
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0b04cb16
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0b04cb16

Branch: refs/heads/master
Commit: 0b04cb16c5b7ee6f2508fc575330b671b31e2de1
Parents: 631079c
Author: Matthias Boehm <[email protected]>
Authored: Mon Sep 25 19:59:04 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Sep 26 11:51:49 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |   2 +-
 .../apache/sysml/hops/recompile/Recompiler.java |  26 ++---
 .../sysml/hops/rewrite/HopDagValidator.java     |  15 ++-
 .../sysml/hops/rewrite/ProgramRewriter.java     | 112 ++++++++-----------
 .../hops/rewrite/RewriteConstantFolding.java    |   5 +-
 .../hops/rewrite/RewriteMergeBlockSequence.java |   4 +-
 .../parfor/opt/OptimizationWrapper.java         |   4 +-
 .../parfor/opt/OptimizerRuleBased.java          |   4 +-
 .../functions/recompile/BranchRemovalTest.java  |   6 +-
 .../IPAAssignConstantPropagationTest.java       |  36 ++----
 10 files changed, 91 insertions(+), 123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index a374cf1..af6c72b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -434,7 +434,7 @@ public class SpoofCompiler
                                ret = constructModifiedHopDag(roots, cplans, 
clas);
                                
                                //run common subexpression elimination and 
other rewrites
-                               ret = rewriteCSE.rewriteHopDAGs(ret, new 
ProgramRewriteStatus());       
+                               ret = rewriteCSE.rewriteHopDAG(ret, new 
ProgramRewriteStatus());        
                                
                                //explain after modification
                                if( LOG.isTraceEnabled() ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
index a4f8e0f..df6746b 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
@@ -179,10 +179,10 @@ public class Recompiler
                //need for synchronization as we do temp changes in shared 
hops/lops
                //however, we create deep copies for most dags to allow for 
concurrent recompile
                synchronized( hops ) 
-               {       
+               {
                        LOG.debug ("\n**************** Optimizer (Recompile) 
*************\nMemory Budget = " + 
-                                          
OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " MB");
-       
+                                       
OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " MB");
+                       
                        // prepare hops dag for recompile
                        if( !inplace ){ 
                                // deep copy hop dag (for non-reversable 
rewrites)
@@ -194,7 +194,7 @@ public class Recompiler
                                for( Hop hopRoot : hops )
                                        rClearLops( hopRoot );
                        }
-
+                       
                        // replace scalar reads with literals 
                        if( !inplace && litreplace ) {
                                Hop.resetVisitStatus(hops);
@@ -202,14 +202,14 @@ public class Recompiler
                                        rReplaceLiterals( hopRoot, vars, false 
);
                        }
                        
-                       // refresh matrix characteristics (update stats)        
                
+                       // refresh matrix characteristics (update stats)
                        Hop.resetVisitStatus(hops);
                        for( Hop hopRoot : hops )
                                rUpdateStatistics( hopRoot, vars );
                        
                        // dynamic hop rewrites
                        if( !inplace ) {
-                               _rewriter.get().rewriteHopDAGs( hops, null );
+                               _rewriter.get().rewriteHopDAG( hops, null );
                                
                                //update stats after rewrites
                                Hop.resetVisitStatus(hops);
@@ -236,12 +236,12 @@ public class Recompiler
                                        (status==null || 
!status.isInitialCodegen()));
                        }
                        
-                       // construct lops                       
+                       // construct lops
                        Dag<Lop> dag = new Dag<Lop>();
                        for( Hop hopRoot : hops ){
                                Lop lops = hopRoot.constructLops();
                                lops.addToDag(dag);     
-                       }               
+                       }
                        
                        // generate runtime instructions (incl piggybacking)
                        newInst = dag.getJobs(sb, 
ConfigurationManager.getDMLConfig());
@@ -309,7 +309,7 @@ public class Recompiler
                        if( !inplace ) {
                                // deep copy hop dag (for non-reversable 
rewrites)
                                //(this also clears existing lops in the 
created dag) 
-                               hops = deepCopyHopsDag(hops);   
+                               hops = deepCopyHopsDag(hops);
                        }
                        else {
                                // clear existing lops
@@ -323,7 +323,7 @@ public class Recompiler
                                rReplaceLiterals( hops, vars, false );
                        }
                        
-                       // refresh matrix characteristics (update stats)        
                
+                       // refresh matrix characteristics (update stats)
                        hops.resetVisitStatus();
                        rUpdateStatistics( hops, vars );
                        
@@ -341,7 +341,7 @@ public class Recompiler
                        hops.resetVisitStatus();
                        memo.init(hops, status);
                        hops.resetVisitStatus();
-                       hops.refreshMemEstimates(memo);                 
+                       hops.refreshMemEstimates(memo);
                        
                        // codegen if enabled
                        if( ConfigurationManager.isCodegenEnabled()
@@ -351,10 +351,10 @@ public class Recompiler
                                        (status==null || 
!status.isInitialCodegen()));
                        }
                        
-                       // construct lops                       
+                       // construct lops
                        Dag<Lop> dag = new Dag<Lop>();
                        Lop lops = hops.constructLops();
-                       lops.addToDag(dag);             
+                       lops.addToDag(dag);
                        
                        // generate runtime instructions (incl piggybacking)
                        newInst = dag.getJobs(null, 
ConfigurationManager.getDMLConfig());

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
index 39c3afe..7d14532 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
@@ -46,7 +46,9 @@ public class HopDagValidator {
 
        private HopDagValidator() {}
        
-       public static void validateHopDag(final ArrayList<Hop> roots) throws 
HopsException {
+       public static void validateHopDag(ArrayList<Hop> roots, HopRewriteRule 
rule) 
+               throws HopsException 
+       {
                if( roots == null )
                        return;
                try {
@@ -57,13 +59,16 @@ public class HopDagValidator {
                }
                catch(HopsException ex) {
                        try {
-                               LOG.error( "\n"+Explain.explainHops(roots) );
+                               LOG.error("Invalid HOP DAG after rewrite " + 
rule.getClass().getName() 
+                                               + ": \n" + 
Explain.explainHops(roots), ex);
                        }catch(DMLRuntimeException e){}
                        throw ex;
                }
        }
        
-       public static void validateHopDag(final Hop root) throws HopsException {
+       public static void validateHopDag(Hop root, HopRewriteRule rule) 
+               throws HopsException 
+       {
                if( root == null )
                        return;
                try {
@@ -73,7 +78,8 @@ public class HopDagValidator {
                }
                catch(HopsException ex) {
                        try {
-                               LOG.error( "\n"+Explain.explain(root) );
+                               LOG.error("Invalid HOP DAG after rewrite " + 
rule.getClass().getName() 
+                                               + ": \n" + 
Explain.explain(root), ex);
                        }catch(DMLRuntimeException e){}
                        throw ex;
                }
@@ -91,7 +97,6 @@ public class HopDagValidator {
                if (seen != hop.isVisited()) {
                        String parentIDs = hop.getParent().stream()
                                        .map(h -> 
Long.toString(h.getHopID())).collect(Collectors.joining(", "));
-                       //noinspection ConstantConditions
                        check(false, hop, parentIDs, seen);
                }
                if (seen) return; // we saw the Hop previously, no need to 
re-validate

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index acc9da9..737e5e8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -22,8 +22,6 @@ package org.apache.sysml.hops.rewrite;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.sysml.conf.CompilerConfig.ConfigType;
@@ -52,8 +50,6 @@ import org.apache.sysml.parser.WhileStatementBlock;
  */
 public class ProgramRewriter
 {
-       private static final Log LOG = 
LogFactory.getLog(ProgramRewriter.class.getName());
-       
        //internal local debug level
        private static final boolean LDEBUG = false;
        private static final boolean CHECK = false;
@@ -99,19 +95,19 @@ public class ProgramRewriter
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
                                _dagRuleSet.add( new 
RewriteAlgebraicSimplificationStatic()      ); //dependencies: cse
                        if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )             
//dependency: simplifications (no need to merge leafs again)
-                               _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination()     ); 
+                               _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination()     );
                        if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
                                _dagRuleSet.add( new 
RewriteIndexingVectorization()              ); //dependency: cse, 
simplifications
                        _dagRuleSet.add( new 
RewriteInjectSparkPReadCheckpointing()          ); //dependency: reblock
                        
                        //add statement block rewrite rules
                        if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) {
-                               _sbRuleSet.add(  new 
RewriteRemoveUnnecessaryBranches()          ); //dependency: constant folding   
           
+                               _sbRuleSet.add(  new 
RewriteRemoveUnnecessaryBranches()          ); //dependency: constant folding
                                _sbRuleSet.add(  new 
RewriteMergeBlockSequence()                 ); //dependency: remove branches
                        }
                        _sbRuleSet.add(      new RewriteCompressedReblock()     
             );
                        if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )
-                               _sbRuleSet.add(  new 
RewriteSplitDagUnknownCSVRead()             ); //dependency: reblock, merge 
blocks 
+                               _sbRuleSet.add(  new 
RewriteSplitDagUnknownCSVRead()             ); //dependency: reblock, merge 
blocks
                        if( 
ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS)
 )
                                _sbRuleSet.add(  new 
RewriteSplitDagDataDependentOperators()     ); //dependency: merge blocks
                        if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION )
@@ -137,9 +133,9 @@ public class ProgramRewriter
                
                // cleanup after all rewrites applied 
                // (newly introduced operators, introduced redundancy after 
rewrites w/ multiple parents) 
-               _dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()        
     );         
-               if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )     
        
-                       _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination(true) );                     
+               _dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()        
     );
+               if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
+                       _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination(true) );
        }
        
        /**
@@ -200,27 +196,25 @@ public class ProgramRewriter
                
                // for each namespace, handle function statement blocks
                for (String namespaceKey : dmlp.getNamespaces().keySet())
-                       for (String fname : 
dmlp.getFunctionStatementBlocks(namespaceKey).keySet())
-                       {
+                       for (String fname : 
dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
                                FunctionStatementBlock fsblock = 
dmlp.getFunctionStatementBlock(namespaceKey,fname);
-                               rewriteStatementBlockHopDAGs(fsblock, state);
-                               rewriteStatementBlock(fsblock, state);
+                               rRewriteStatementBlockHopDAGs(fsblock, state);
+                               rRewriteStatementBlock(fsblock, state);
                        }
                
                // handle regular statement blocks in "main" method
-               for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) 
-               {
+               for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
                        StatementBlock current = dmlp.getStatementBlock(i);
-                       rewriteStatementBlockHopDAGs(current, state);
+                       rRewriteStatementBlockHopDAGs(current, state);
                }
-               dmlp.setStatementBlocks( 
rewriteStatementBlocks(dmlp.getStatementBlocks(), state) );
+               dmlp.setStatementBlocks( 
rRewriteStatementBlocks(dmlp.getStatementBlocks(), state) );
                
                return state;
        }
        
-       public void rewriteStatementBlockHopDAGs(StatementBlock current, 
ProgramRewriteStatus state) 
+       public void rRewriteStatementBlockHopDAGs(StatementBlock current, 
ProgramRewriteStatus state) 
                throws LanguageException, HopsException
-       {       
+       {
                //ensure robustness for calls from outside
                if( state == null )
                        state = new ProgramRewriteStatus();
@@ -230,7 +224,7 @@ public class ProgramRewriter
                        FunctionStatementBlock fsb = 
(FunctionStatementBlock)current;
                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
                        for (StatementBlock sb : fstmt.getBody())
-                               rewriteStatementBlockHopDAGs(sb, state);
+                               rRewriteStatementBlockHopDAGs(sb, state);
                }
                else if (current instanceof WhileStatementBlock)
                {
@@ -238,7 +232,7 @@ public class ProgramRewriter
                        WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
                        
wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state));
                        for (StatementBlock sb : wstmt.getBody())
-                               rewriteStatementBlockHopDAGs(sb, state);
+                               rRewriteStatementBlockHopDAGs(sb, state);
                }       
                else if (current instanceof IfStatementBlock)
                {
@@ -246,9 +240,9 @@ public class ProgramRewriter
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
                        
isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state));
                        for (StatementBlock sb : istmt.getIfBody())
-                               rewriteStatementBlockHopDAGs(sb, state);
+                               rRewriteStatementBlockHopDAGs(sb, state);
                        for (StatementBlock sb : istmt.getElseBody())
-                               rewriteStatementBlockHopDAGs(sb, state);
+                               rRewriteStatementBlockHopDAGs(sb, state);
                }
                else if (current instanceof ForStatementBlock) //incl parfor
                {
@@ -258,29 +252,22 @@ public class ProgramRewriter
                        fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state));
                        
fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state));
                        for (StatementBlock sb : fstmt.getBody())
-                               rewriteStatementBlockHopDAGs(sb, state);
+                               rRewriteStatementBlockHopDAGs(sb, state);
                }
                else //generic (last-level)
                {
-                       current.set_hops( rewriteHopDAGs(current.get_hops(), 
state) );
+                       current.set_hops( rewriteHopDAG(current.get_hops(), 
state) );
                }
        }
        
-       public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) 
+       public ArrayList<Hop> rewriteHopDAG(ArrayList<Hop> roots, 
ProgramRewriteStatus state) 
                throws HopsException
        {
-               for( HopRewriteRule r : _dagRuleSet )
-               {
+               for( HopRewriteRule r : _dagRuleSet ) {
                        Hop.resetVisitStatus( roots ); //reset for each rule
                        roots = r.rewriteHopDAGs(roots, state);
-               
                        if( CHECK )
-                               try {
-                                       HopDagValidator.validateHopDag(roots);
-                               } catch (HopsException e) {
-                                       LOG.error("Invalid hop after rewriting 
by " + r.getClass().getName(), e);
-                                       throw e;
-                               }
+                               HopDagValidator.validateHopDag(roots, r);
                }
                return roots;
        }
@@ -291,23 +278,16 @@ public class ProgramRewriter
                if( root == null )
                        return null;
                
-               for( HopRewriteRule r : _dagRuleSet )
-               {
+               for( HopRewriteRule r : _dagRuleSet ) {
                        root.resetVisitStatus(); //reset for each rule
                        root = r.rewriteHopDAG(root, state);
-
                        if( CHECK )
-                               try {
-                                       HopDagValidator.validateHopDag(root);
-                               } catch (HopsException e) {
-                                       LOG.error("Invalid hop after rewriting 
by " + r.getClass().getName(), e);
-                                       throw e;
-                               }
+                               HopDagValidator.validateHopDag(root, r);
                }
                return root;
        }
        
-       public ArrayList<StatementBlock> rewriteStatementBlocks( 
ArrayList<StatementBlock> sbs, ProgramRewriteStatus status ) 
+       public ArrayList<StatementBlock> rRewriteStatementBlocks( 
ArrayList<StatementBlock> sbs, ProgramRewriteStatus status ) 
                throws HopsException
        {
                //ensure robustness for calls from outside
@@ -315,22 +295,26 @@ public class ProgramRewriter
                        status = new ProgramRewriteStatus();
                
                //apply rewrite rules to list of statement blocks
-               List<StatementBlock> sbList = sbs; 
-               for( StatementBlockRewriteRule r : _sbRuleSet ) {
-                       sbList = r.rewriteStatementBlocks(sbList, status);
-               }
+               List<StatementBlock> tmp = sbs; 
+               for( StatementBlockRewriteRule r : _sbRuleSet )
+                       tmp = r.rewriteStatementBlocks(tmp, status);
                
-               //rewrite statement blocks (with potential expansion)
-               ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>();
-               for( StatementBlock sb : sbList )
-                       tmp.addAll( rewriteStatementBlock(sb, status) );
-               sbs.clear();
-               sbs.addAll( tmp );
+               //recursively rewrite statement blocks (with potential 
expansion)
+               List<StatementBlock> tmp2 = new ArrayList<StatementBlock>();
+               for( StatementBlock sb : tmp )
+                       tmp2.addAll( rRewriteStatementBlock(sb, status) );
+               
+               //apply rewrite rules to list of statement blocks (with 
potential contraction)
+               for( StatementBlockRewriteRule r : _sbRuleSet )
+                       tmp2 = r.rewriteStatementBlocks(tmp2, status);
                
+               //prepare output list
+               sbs.clear();
+               sbs.addAll(tmp2);
                return sbs;
        }
        
-       private ArrayList<StatementBlock> rewriteStatementBlock( StatementBlock 
sb, ProgramRewriteStatus status ) 
+       public ArrayList<StatementBlock> rRewriteStatementBlock( StatementBlock 
sb, ProgramRewriteStatus status ) 
                throws HopsException
        {
                ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>();
@@ -341,20 +325,20 @@ public class ProgramRewriter
                {
                        FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
-                       fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), 
status) );                       
+                       fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), 
status) );
                }
                else if (sb instanceof WhileStatementBlock)
                {
                        WhileStatementBlock wsb = (WhileStatementBlock) sb;
                        WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
-                       wstmt.setBody( rewriteStatementBlocks( wstmt.getBody(), 
status ) );
+                       wstmt.setBody( rRewriteStatementBlocks( 
wstmt.getBody(), status ) );
                }       
                else if (sb instanceof IfStatementBlock)
                {
                        IfStatementBlock isb = (IfStatementBlock) sb;
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
-                       istmt.setIfBody( rewriteStatementBlocks( 
istmt.getIfBody(), status ) );
-                       istmt.setElseBody( rewriteStatementBlocks( 
istmt.getElseBody(), status ) );
+                       istmt.setIfBody( rRewriteStatementBlocks( 
istmt.getIfBody(), status ) );
+                       istmt.setElseBody( rRewriteStatementBlocks( 
istmt.getElseBody(), status ) );
                }
                else if (sb instanceof ForStatementBlock) //incl parfor
                {
@@ -365,18 +349,18 @@ public class ProgramRewriter
                        
                        ForStatementBlock fsb = (ForStatementBlock) sb;
                        ForStatement fstmt = (ForStatement)fsb.getStatement(0);
-                       fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), 
status) );
+                       fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), 
status) );
                        
                        status.setInParforContext(prestatus);
                }
                
                //apply rewrite rules to individual statement blocks
                for( StatementBlockRewriteRule r : _sbRuleSet ) {
-                       ArrayList<StatementBlock> tmp = new 
ArrayList<StatementBlock>();                        
+                       ArrayList<StatementBlock> tmp = new 
ArrayList<StatementBlock>();
                        for( StatementBlock sbc : ret )
                                tmp.addAll( r.rewriteStatementBlock(sbc, 
status) );
                        
-                       //take over set of rewritten sbs                
+                       //take over set of rewritten sbs
                        ret.clear();
                        ret.addAll(tmp);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
index ec9dcae..a0867e2 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java
@@ -108,15 +108,14 @@ public class RewriteConstantFolding extends HopRewriteRule
                LiteralOp literal = null;
                
                //fold binary op if both are literals / unary op if literal
-               if(    root.getDataType() == DataType.SCALAR //scalar ouput
+               if( root.getDataType() == DataType.SCALAR //scalar output
                        && ( isApplicableBinaryOp(root) || 
isApplicableUnaryOp(root) ) )
                { 
                        //core constant folding via runtime instructions
                        try {
                                literal = evalScalarOperation(root); 
                        }
-                       catch(Exception ex)
-                       {
+                       catch(Exception ex) {
                                LOG.error("Failed to execute constant folding 
instructions. No abort.", ex);
                        }
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
index 659ad2e..9cba102 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java
@@ -98,7 +98,7 @@ public class RewriteMergeBlockSequence extends 
StatementBlockRewriteRule
                                                }
                                                //add remaining roots from s1 
to s2
                                                else if( 
!(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE)
-                                                       && 
twrites.containsKey(root.getName())) ) {
+                                                       && 
(twrites.containsKey(root.getName()) || 
!sb2.liveOut().containsVariable(root.getName()))) ) {
                                                        sb2Hops.add(root);
                                                }
                                        }
@@ -108,7 +108,7 @@ public class RewriteMergeBlockSequence extends 
StatementBlockRewriteRule
                                        
                                        //run common-subexpression elimination
                                        Hop.resetVisitStatus(sb2Hops);
-                                       rewriter.rewriteHopDAGs(sb2Hops, new 
ProgramRewriteStatus());
+                                       rewriter.rewriteHopDAG(sb2Hops, new 
ProgramRewriteStatus());
                                        
                                        //modify live variable sets of s2
                                        sb2.setLiveIn(sb1.liveIn()); //liveOut 
remains unchanged

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index 75dfc3c..eec2fb7 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -184,8 +184,8 @@ public class OptimizationWrapper
                        try {
                                ProgramRewriter rewriter = 
createProgramRewriterWithRuleSets();
                                ProgramRewriteStatus state = new 
ProgramRewriteStatus();
-                               rewriter.rewriteStatementBlockHopDAGs( sb, 
state );
-                               
fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
+                               rewriter.rRewriteStatementBlockHopDAGs( sb, 
state );
+                               
fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state));
                                if( state.getRemovedBranches() ){
                                        LOG.debug("ParFOR Opt: Removed branches 
during program rewrites, rebuilding runtime program");
                                        
pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(),
 fs.getBody()));

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 72b9566..b80e786 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -1909,8 +1909,8 @@ public class OptimizerRuleBased extends Optimizer
                        RewriteInjectSparkLoopCheckpointing rewrite = new 
RewriteInjectSparkLoopCheckpointing(false);
                        ProgramRewriter rewriter = new ProgramRewriter(rewrite);
                        ProgramRewriteStatus state = new ProgramRewriteStatus();
-                       rewriter.rewriteStatementBlockHopDAGs( pfsb, state );
-                       
fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state));
+                       rewriter.rRewriteStatementBlockHopDAGs( pfsb, state );
+                       
fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state));
                        
                        //recompile if additional checkpoints introduced
                        if( state.getInjectedCheckpoints() ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java
index 733df84..2af9ffe 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java
@@ -31,7 +31,6 @@ import org.apache.sysml.test.utils.TestUtils;
 
 public class BranchRemovalTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME = "if_branch_removal";
        private final static String TEST_DIR = "functions/recompile/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
BranchRemovalTest.class.getSimpleName() + "/";
@@ -142,11 +141,8 @@ public class BranchRemovalTest extends AutomatedTestBase
                        //check expected number of compiled and executed MR jobs
                        int expectedNumCompiled = 5; //reblock, 3xGMR (append), 
write
                        int expectedNumExecuted = 0;                    
-                       if( branchRemoval && IPA )
+                       if( branchRemoval )
                                expectedNumCompiled = 1; //reblock
-                       else if( branchRemoval ){
-                               expectedNumCompiled = condition ? 4 : 3; 
//reblock, GMR (append), write
-                       }
                        
                        checkNumCompiledMRJobs(expectedNumCompiled); 
                        checkNumExecutedMRJobs(expectedNumExecuted); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java
index 94e4f73..8d0736b 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java
@@ -31,13 +31,12 @@ import org.apache.sysml.test.utils.TestUtils;
 
 public class IPAAssignConstantPropagationTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME = "constant_propagation_sb";
        private final static String TEST_DIR = "functions/recompile/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
IPAAssignConstantPropagationTest.class.getSimpleName() + "/";
        
        private final static int rows = 10;
-       private final static int cols = 15;    
+       private final static int cols = 15;
        
        
        @Override
@@ -47,40 +46,27 @@ public class IPAAssignConstantPropagationTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME, 
                        new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new 
String[] { "X" }) );
        }
-
-       
        
        @Test
-       public void testAssignConstantPropagationNoBranchRemovalNoIPA() 
-       {
+       public void testAssignConstantPropagationNoBranchRemovalNoIPA() {
                runIPAAssignConstantPropagationTest(false, false);
        }
        
        @Test
-       public void testAssignConstantPropagationNoBranchRemovalIPA() 
-       {
+       public void testAssignConstantPropagationNoBranchRemovalIPA() {
                runIPAAssignConstantPropagationTest(false, true);
        }
        
        @Test
-       public void testAssignConstantPropagationBranchRemovalNoIPA() 
-       {
+       public void testAssignConstantPropagationBranchRemovalNoIPA() {
                runIPAAssignConstantPropagationTest(true, false);
        }
        
        @Test
-       public void testAssignConstantPropagationBranchRemovalIPA() 
-       {
+       public void testAssignConstantPropagationBranchRemovalIPA() {
                runIPAAssignConstantPropagationTest(true, true);
        }
-
-
-       /**
-        * 
-        * @param condition
-        * @param branchRemoval
-        * @param IPA
-        */
+       
        private void runIPAAssignConstantPropagationTest( boolean 
branchRemoval, boolean IPA )
        {       
                boolean oldFlagBranchRemoval = 
OptimizerUtils.ALLOW_BRANCH_REMOVAL;
@@ -98,7 +84,7 @@ public class IPAAssignConstantPropagationTest extends AutomatedTestBase
                        
                        fullRScriptName = HOME + TEST_NAME + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
-                               Integer.toString(rows) + " " + 
Integer.toString(cols) + " " + expectedDir();                    
+                               Integer.toString(rows) + " " + 
Integer.toString(cols) + " " + expectedDir();
 
                        OptimizerUtils.ALLOW_BRANCH_REMOVAL = branchRemoval;
                        OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;
@@ -112,17 +98,15 @@ public class IPAAssignConstantPropagationTest extends AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, 0, 
"Stat-DML", "Stat-R");
                        
                        //check expected number of compiled and executed MR jobs
-                       int expectedNumCompiled = ( branchRemoval && IPA ) ? 0 
: 1; //rand
-                       int expectedNumExecuted = 0;                    
+                       int expectedNumCompiled = branchRemoval ? 0 : 1; //rand
+                       int expectedNumExecuted = 0;
                        
                        checkNumCompiledMRJobs(expectedNumCompiled); 
                        checkNumExecutedMRJobs(expectedNumExecuted); 
                }
-               finally
-               {
+               finally {
                        OptimizerUtils.ALLOW_BRANCH_REMOVAL = 
oldFlagBranchRemoval;
                        OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = 
oldFlagIPA;
                }
        }
-       
 }

Reply via email to