[SYSTEMML-2185] Phase ordering of statement block rewrites (merge/split)

There are a number of rewrites that modify the program structure (e.g., branch removal, merging of block sequences, inlining, splitting of DAGs after data-dependent operators, checkpointing). The order of these rewrites matters because artificially split DAGs are no longer eligible for the merging of block sequences, which can lose optimization opportunities.
Accordingly, this patch improves the existing static rewrites with a proper phase ordering: we first apply all rewrites that merge and consolidate statement blocks (for maximum CSE), and then apply the splitting rewrites that create the necessary recompilation hooks. For example, on stratstats 100K x 1K, this patch improved end-to-end performance from 1296s to 1080s. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e1fd3431 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e1fd3431 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e1fd3431 Branch: refs/heads/master Commit: e1fd34314e8f68d5626407cdd9d8197948605573 Parents: bcaa140 Author: Matthias Boehm <[email protected]> Authored: Thu Mar 15 14:53:39 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Mar 15 14:53:39 2018 -0700 ---------------------------------------------------------------------- .../hops/ipa/IPAPassApplyStaticHopRewrites.java | 2 +- .../sysml/hops/rewrite/ProgramRewriter.java | 61 ++++++++++---------- .../hops/rewrite/RewriteCompressedReblock.java | 5 ++ .../rewrite/RewriteForLoopVectorization.java | 4 ++ .../RewriteInjectSparkLoopCheckpointing.java | 5 ++ .../RewriteInjectSparkPReadCheckpointing.java | 5 +- .../RewriteMarkLoopVariablesUpdateInPlace.java | 5 ++ .../hops/rewrite/RewriteMergeBlockSequence.java | 5 ++ .../rewrite/RewriteRemoveEmptyBasicBlocks.java | 5 ++ .../RewriteRemoveUnnecessaryBranches.java | 5 ++ .../RewriteSplitDagDataDependentOperators.java | 5 ++ .../rewrite/RewriteSplitDagUnknownCSVRead.java | 5 ++ .../hops/rewrite/StatementBlockRewriteRule.java | 8 +++ .../org/apache/sysml/parser/DMLTranslator.java | 4 +- .../parfor/opt/OptimizationWrapper.java | 2 +- .../parfor/opt/OptimizerRuleBased.java | 2 +- 16 files changed, 92 insertions(+), 36 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java index cae2e35..57bff00 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java @@ -49,7 +49,7 @@ public class IPAPassApplyStaticHopRewrites extends IPAPass rewriter.removeStatementBlockRewrite(RewriteInjectSparkLoopCheckpointing.class); //rewrite program hop dags and statement blocks - rewriter.rewriteProgramHopDAGs(prog); + rewriter.rewriteProgramHopDAGs(prog, true); //rewrite and split } catch (LanguageException ex) { throw new HopsException(ex); http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index a73ea6d..65687e3 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -71,7 +71,7 @@ public class ProgramRewriter this( true, true ); } - public ProgramRewriter( boolean staticRewrites, boolean dynamicRewrites ) + public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<>(); @@ -145,12 +145,11 @@ public class ProgramRewriter * * @param rewrites the HOP rewrite rules */ - public ProgramRewriter( HopRewriteRule... rewrites ) { + public ProgramRewriter(HopRewriteRule... 
rewrites) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<>(); for( HopRewriteRule rewrite : rewrites ) _dagRuleSet.add( rewrite ); - _sbRuleSet = new ArrayList<>(); } @@ -159,10 +158,9 @@ public class ProgramRewriter * * @param rewrites the statement block rewrite rules */ - public ProgramRewriter( StatementBlockRewriteRule... rewrites ) { + public ProgramRewriter(StatementBlockRewriteRule... rewrites) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<>(); - _sbRuleSet = new ArrayList<>(); for( StatementBlockRewriteRule rewrite : rewrites ) _sbRuleSet.add( rewrite ); @@ -191,9 +189,13 @@ public class ProgramRewriter _sbRuleSet.removeIf(r -> r.getClass().equals(clazz)); } - public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) + public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp) throws LanguageException, HopsException { + return rewriteProgramHopDAGs(dmlp, true); + } + + public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) throws LanguageException, HopsException - { + { ProgramRewriteStatus state = new ProgramRewriteStatus(); // for each namespace, handle function statement blocks @@ -201,7 +203,8 @@ public class ProgramRewriter for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); rRewriteStatementBlockHopDAGs(fsblock, state); - rRewriteStatementBlock(fsblock, state); + if( !_sbRuleSet.isEmpty() ) + rRewriteStatementBlock(fsblock, state, splitDags); } // handle regular statement blocks in "main" method @@ -209,7 +212,9 @@ public class ProgramRewriter StatementBlock current = dmlp.getStatementBlock(i); rRewriteStatementBlockHopDAGs(current, state); } - dmlp.setStatementBlocks( rRewriteStatementBlocks(dmlp.getStatementBlocks(), state) ); + if( !_sbRuleSet.isEmpty() ) + 
dmlp.setStatementBlocks(rRewriteStatementBlocks( + dmlp.getStatementBlocks(), state, splitDags)); return state; } @@ -289,7 +294,7 @@ public class ProgramRewriter return root; } - public ArrayList<StatementBlock> rRewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus status ) + public ArrayList<StatementBlock> rRewriteStatementBlocks(ArrayList<StatementBlock> sbs, ProgramRewriteStatus status, boolean splitDags) throws HopsException { //ensure robustness for calls from outside @@ -299,16 +304,18 @@ public class ProgramRewriter //apply rewrite rules to list of statement blocks List<StatementBlock> tmp = sbs; for( StatementBlockRewriteRule r : _sbRuleSet ) - tmp = r.rewriteStatementBlocks(tmp, status); + if( splitDags || !r.createsSplitDag() ) + tmp = r.rewriteStatementBlocks(tmp, status); //recursively rewrite statement blocks (with potential expansion) List<StatementBlock> tmp2 = new ArrayList<>(); for( StatementBlock sb : tmp ) - tmp2.addAll( rRewriteStatementBlock(sb, status) ); + tmp2.addAll( rRewriteStatementBlock(sb, status, splitDags) ); //apply rewrite rules to list of statement blocks (with potential contraction) for( StatementBlockRewriteRule r : _sbRuleSet ) - tmp2 = r.rewriteStatementBlocks(tmp2, status); + if( splitDags || !r.createsSplitDag() ) + tmp2 = r.rewriteStatementBlocks(tmp2, status); //prepare output list sbs.clear(); @@ -316,48 +323,44 @@ public class ProgramRewriter return sbs; } - public ArrayList<StatementBlock> rRewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status ) + public ArrayList<StatementBlock> rRewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status, boolean splitDags) throws HopsException { ArrayList<StatementBlock> ret = new ArrayList<>(); ret.add(sb); //recursive invocation - if (sb instanceof FunctionStatementBlock) - { + if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0); - fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), status) ); + fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags)); } - else if (sb instanceof WhileStatementBlock) - { + else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - wstmt.setBody( rRewriteStatementBlocks( wstmt.getBody(), status ) ); + wstmt.setBody(rRewriteStatementBlocks(wstmt.getBody(), status, splitDags)); } - else if (sb instanceof IfStatementBlock) - { + else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); - istmt.setIfBody( rRewriteStatementBlocks( istmt.getIfBody(), status ) ); - istmt.setElseBody( rRewriteStatementBlocks( istmt.getElseBody(), status ) ); + istmt.setIfBody(rRewriteStatementBlocks(istmt.getIfBody(), status, splitDags)); + istmt.setElseBody(rRewriteStatementBlocks(istmt.getElseBody(), status, splitDags)); } - else if (sb instanceof ForStatementBlock) //incl parfor - { + else if (sb instanceof ForStatementBlock) { //incl parfor //maintain parfor context information (e.g., for checkpointing) boolean prestatus = status.isInParforContext(); if( sb instanceof ParForStatementBlock ) status.setInParforContext(true); - ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fstmt = (ForStatement)fsb.getStatement(0); - fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), status) ); - + fstmt.setBody(rRewriteStatementBlocks(fstmt.getBody(), status, splitDags)); status.setInParforContext(prestatus); } //apply rewrite rules to individual statement blocks for( StatementBlockRewriteRule r : _sbRuleSet ) { + if( !splitDags && r.createsSplitDag() ) + continue; ArrayList<StatementBlock> tmp = new ArrayList<>(); for( StatementBlock sbc : ret ) tmp.addAll( r.rewriteStatementBlock(sbc, status) ); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java index fdaad10..b24a178 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteCompressedReblock.java @@ -63,6 +63,11 @@ public class RewriteCompressedReblock extends StatementBlockRewriteRule private static final String TMP_PREFIX = "__cmtx"; @Override + public boolean createsSplitDag() { + return false; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus sate) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java index fce1fa1..cade354 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java @@ -53,6 +53,10 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule private static final OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT, OpOp2.MIN, OpOp2.MAX}; private static final AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = new AggOp[]{AggOp.SUM, AggOp.PROD, AggOp.MIN, AggOp.MAX}; + @Override + public boolean createsSplitDag() { + return false; + } @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java index ddee8ab..50a5579 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java @@ -54,6 +54,11 @@ public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRu } @Override + public boolean createsSplitDag() { + return true; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java index 755c2e8..7ee3506 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkPReadCheckpointing.java @@ -35,7 +35,6 @@ import org.apache.sysml.hops.OptimizerUtils; */ public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule { - @Override public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException @@ -67,8 +66,8 @@ public class RewriteInjectSparkPReadCheckpointing extends HopRewriteRule if(hop.isVisited()) return; - // The reblocking is performed after transform, and hence checkpoint only 
non-transformed reads. - if( (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTREAD) + // The reblocking is performed after transform, and hence checkpoint only non-transformed reads. + if( (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTREAD) || hop.requiresReblock() ) { //make given hop for checkpointing (w/ default storage level) http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java index c4b3340..79683fb 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java @@ -49,6 +49,11 @@ import org.apache.sysml.parser.Expression.DataType; public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewriteRule { @Override + public boolean createsSplitDag() { + return false; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java index cf7701c..9c2d715 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java @@ -45,6 +45,11 @@ public class 
RewriteMergeBlockSequence extends StatementBlockRewriteRule new RewriteCommonSubexpressionElimination(true)); @Override + public boolean createsSplitDag() { + return false; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { return Arrays.asList(sb); http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java index 5010618..b03763c 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveEmptyBasicBlocks.java @@ -33,6 +33,11 @@ import org.apache.sysml.parser.StatementBlock; public class RewriteRemoveEmptyBasicBlocks extends StatementBlockRewriteRule { @Override + public boolean createsSplitDag() { + return false; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java index 52aa7a5..094adb0 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemoveUnnecessaryBranches.java @@ -38,6 +38,11 @@ import org.apache.sysml.parser.StatementBlock; public class RewriteRemoveUnnecessaryBranches extends 
StatementBlockRewriteRule { @Override + public boolean createsSplitDag() { + return false; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index a3e037f..13f4a50 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -73,6 +73,11 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite private static IDSequence _seq = new IDSequence(); @Override + public boolean createsSplitDag() { + return true; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java index 5ec7cb0..5351c0d 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java @@ -46,6 +46,11 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule { @Override + public boolean 
createsSplitDag() { + return true; + } + + @Override public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException { http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java index eba2276..ff2d5bf 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java @@ -37,6 +37,14 @@ public abstract class StatementBlockRewriteRule protected static final Log LOG = LogFactory.getLog(StatementBlockRewriteRule.class.getName()); /** + * Indicates if the rewrite potentially splits dags, which is used + * for phase ordering of rewrites. + * + * @return true if dag splits are possible. + */ + public abstract boolean createsSplitDag(); + + /** * Handle an arbitrary statement block. Specific type constraints have to be ensured * within the individual rewrites. If a rewrite does not apply to individual blocks, it * should simply return the input block. 
http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 220be8a..8a0c5df 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -263,7 +263,9 @@ public class DMLTranslator { //apply hop rewrites (static rewrites) ProgramRewriter rewriter = new ProgramRewriter(true, false); - rewriter.rewriteProgramHopDAGs(dmlp); + rewriter.rewriteProgramHopDAGs(dmlp, false); //rewrite and merge + resetHopsDAGVisitStatus(dmlp); + rewriter.rewriteProgramHopDAGs(dmlp, true); //rewrite and split resetHopsDAGVisitStatus(dmlp); //propagate size information from main into functions (but conservatively) http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java index c77d4d0..5f2f030 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java @@ -185,7 +185,7 @@ public class OptimizationWrapper ProgramRewriter rewriter = createProgramRewriterWithRuleSets(); ProgramRewriteStatus state = new ProgramRewriteStatus(); rewriter.rRewriteStatementBlockHopDAGs( sb, state ); - fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state)); + fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true)); if( state.getRemovedBranches() ){ LOG.debug("ParFOR 
Opt: Removed branches during program rewrites, rebuilding runtime program"); pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody())); http://git-wip-us.apache.org/repos/asf/systemml/blob/e1fd3431/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index 584a1fa..d5b09e1 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -1913,7 +1913,7 @@ public class OptimizerRuleBased extends Optimizer ProgramRewriter rewriter = new ProgramRewriter(rewrite); ProgramRewriteStatus state = new ProgramRewriteStatus(); rewriter.rRewriteStatementBlockHopDAGs( pfsb, state ); - fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state)); + fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true)); //recompile if additional checkpoints introduced if( state.getInjectedCheckpoints() ) {
