Repository: systemml Updated Branches: refs/heads/master 631079c43 -> c5bec0082
[SYSTEMML-1935] Improved rewrite for merging statement block sequences

This patch makes two high-impact improvements to the existing rewrite for merging statement block sequences. First, we now apply the rewrites for lists of blocks both before and after the rewrites for individual blocks. For example, on GLM binomial-probit this allows us to compile the entire function glm_dist into a single HOP DAG. Second, we now check for additional cases where intermediates between blocks can be eliminated as dead code. Again, on GLM binomial-probit this allows for the automatic removal of large intermediates that are exposed as dead code after various branch removal and sequence merge rewrites. Overall, this patch improved the performance of GLM binomial-probit with codegen, on a 100M x 10 dense input with 20/10 max outer/inner iterations, from 1,633s to 1,338s. At the same time, it creates more codegen opportunities but also more challenging optimization problems. For example, the number of considered plans increased from ~200K to ~1G, but cost-based and structural pruning reduced the number of plans to 151 while still guaranteeing plan optimality. 
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0b04cb16 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0b04cb16 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0b04cb16 Branch: refs/heads/master Commit: 0b04cb16c5b7ee6f2508fc575330b671b31e2de1 Parents: 631079c Author: Matthias Boehm <[email protected]> Authored: Mon Sep 25 19:59:04 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Tue Sep 26 11:51:49 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 2 +- .../apache/sysml/hops/recompile/Recompiler.java | 26 ++--- .../sysml/hops/rewrite/HopDagValidator.java | 15 ++- .../sysml/hops/rewrite/ProgramRewriter.java | 112 ++++++++----------- .../hops/rewrite/RewriteConstantFolding.java | 5 +- .../hops/rewrite/RewriteMergeBlockSequence.java | 4 +- .../parfor/opt/OptimizationWrapper.java | 4 +- .../parfor/opt/OptimizerRuleBased.java | 4 +- .../functions/recompile/BranchRemovalTest.java | 6 +- .../IPAAssignConstantPropagationTest.java | 36 ++---- 10 files changed, 91 insertions(+), 123 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index a374cf1..af6c72b 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -434,7 +434,7 @@ public class SpoofCompiler ret = constructModifiedHopDag(roots, cplans, clas); //run common subexpression elimination and other rewrites - ret = rewriteCSE.rewriteHopDAGs(ret, new ProgramRewriteStatus()); + 
ret = rewriteCSE.rewriteHopDAG(ret, new ProgramRewriteStatus()); //explain after modification if( LOG.isTraceEnabled() ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java index a4f8e0f..df6746b 100644 --- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java @@ -179,10 +179,10 @@ public class Recompiler //need for synchronization as we do temp changes in shared hops/lops //however, we create deep copies for most dags to allow for concurrent recompile synchronized( hops ) - { + { LOG.debug ("\n**************** Optimizer (Recompile) *************\nMemory Budget = " + - OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " MB"); - + OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " MB"); + // prepare hops dag for recompile if( !inplace ){ // deep copy hop dag (for non-reversable rewrites) @@ -194,7 +194,7 @@ public class Recompiler for( Hop hopRoot : hops ) rClearLops( hopRoot ); } - + // replace scalar reads with literals if( !inplace && litreplace ) { Hop.resetVisitStatus(hops); @@ -202,14 +202,14 @@ public class Recompiler rReplaceLiterals( hopRoot, vars, false ); } - // refresh matrix characteristics (update stats) + // refresh matrix characteristics (update stats) Hop.resetVisitStatus(hops); for( Hop hopRoot : hops ) rUpdateStatistics( hopRoot, vars ); // dynamic hop rewrites if( !inplace ) { - _rewriter.get().rewriteHopDAGs( hops, null ); + _rewriter.get().rewriteHopDAG( hops, null ); //update stats after rewrites Hop.resetVisitStatus(hops); @@ -236,12 +236,12 @@ public class Recompiler (status==null || !status.isInitialCodegen())); } - // construct lops + // construct lops Dag<Lop> dag = 
new Dag<Lop>(); for( Hop hopRoot : hops ){ Lop lops = hopRoot.constructLops(); lops.addToDag(dag); - } + } // generate runtime instructions (incl piggybacking) newInst = dag.getJobs(sb, ConfigurationManager.getDMLConfig()); @@ -309,7 +309,7 @@ public class Recompiler if( !inplace ) { // deep copy hop dag (for non-reversable rewrites) //(this also clears existing lops in the created dag) - hops = deepCopyHopsDag(hops); + hops = deepCopyHopsDag(hops); } else { // clear existing lops @@ -323,7 +323,7 @@ public class Recompiler rReplaceLiterals( hops, vars, false ); } - // refresh matrix characteristics (update stats) + // refresh matrix characteristics (update stats) hops.resetVisitStatus(); rUpdateStatistics( hops, vars ); @@ -341,7 +341,7 @@ public class Recompiler hops.resetVisitStatus(); memo.init(hops, status); hops.resetVisitStatus(); - hops.refreshMemEstimates(memo); + hops.refreshMemEstimates(memo); // codegen if enabled if( ConfigurationManager.isCodegenEnabled() @@ -351,10 +351,10 @@ public class Recompiler (status==null || !status.isInitialCodegen())); } - // construct lops + // construct lops Dag<Lop> dag = new Dag<Lop>(); Lop lops = hops.constructLops(); - lops.addToDag(dag); + lops.addToDag(dag); // generate runtime instructions (incl piggybacking) newInst = dag.getJobs(null, ConfigurationManager.getDMLConfig()); http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java index 39c3afe..7d14532 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java @@ -46,7 +46,9 @@ public class HopDagValidator { private HopDagValidator() {} - public static void validateHopDag(final ArrayList<Hop> 
roots) throws HopsException { + public static void validateHopDag(ArrayList<Hop> roots, HopRewriteRule rule) + throws HopsException + { if( roots == null ) return; try { @@ -57,13 +59,16 @@ public class HopDagValidator { } catch(HopsException ex) { try { - LOG.error( "\n"+Explain.explainHops(roots) ); + LOG.error("Invalid HOP DAG after rewrite " + rule.getClass().getName() + + ": \n" + Explain.explainHops(roots), ex); }catch(DMLRuntimeException e){} throw ex; } } - public static void validateHopDag(final Hop root) throws HopsException { + public static void validateHopDag(Hop root, HopRewriteRule rule) + throws HopsException + { if( root == null ) return; try { @@ -73,7 +78,8 @@ public class HopDagValidator { } catch(HopsException ex) { try { - LOG.error( "\n"+Explain.explain(root) ); + LOG.error("Invalid HOP DAG after rewrite " + rule.getClass().getName() + + ": \n" + Explain.explain(root), ex); }catch(DMLRuntimeException e){} throw ex; } @@ -91,7 +97,6 @@ public class HopDagValidator { if (seen != hop.isVisited()) { String parentIDs = hop.getParent().stream() .map(h -> Long.toString(h.getHopID())).collect(Collectors.joining(", ")); - //noinspection ConstantConditions check(false, hop, parentIDs, seen); } if (seen) return; // we saw the Hop previously, no need to re-validate http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index acc9da9..737e5e8 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -22,8 +22,6 @@ package org.apache.sysml.hops.rewrite; import java.util.ArrayList; import java.util.List; -import org.apache.commons.logging.Log; -import 
org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.sysml.conf.CompilerConfig.ConfigType; @@ -52,8 +50,6 @@ import org.apache.sysml.parser.WhileStatementBlock; */ public class ProgramRewriter { - private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName()); - //internal local debug level private static final boolean LDEBUG = false; private static final boolean CHECK = false; @@ -99,19 +95,19 @@ public class ProgramRewriter if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) //dependency: simplifications (no need to merge leafs again) - _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); + _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications _dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock //add statement block rewrite rules if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) { - _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding + _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding _sbRuleSet.add( new RewriteMergeBlockSequence() ); //dependency: remove branches } _sbRuleSet.add( new RewriteCompressedReblock() ); if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS ) - _sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock, merge blocks + _sbRuleSet.add( new RewriteSplitDagUnknownCSVRead() ); //dependency: reblock, merge blocks if( ConfigurationManager.getCompilerConfigFlag(ConfigType.ALLOW_INDIVIDUAL_SB_SPECIFIC_OPS) ) _sbRuleSet.add( new RewriteSplitDagDataDependentOperators() ); //dependency: merge blocks if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) @@ -137,9 +133,9 @@ 
public class ProgramRewriter // cleanup after all rewrites applied // (newly introduced operators, introduced redundancy after rewrites w/ multiple parents) - _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); - if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) - _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); + _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); + if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) + _dagRuleSet.add( new RewriteCommonSubexpressionElimination(true) ); } /** @@ -200,27 +196,25 @@ public class ProgramRewriter // for each namespace, handle function statement blocks for (String namespaceKey : dmlp.getNamespaces().keySet()) - for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) - { + for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); - rewriteStatementBlockHopDAGs(fsblock, state); - rewriteStatementBlock(fsblock, state); + rRewriteStatementBlockHopDAGs(fsblock, state); + rRewriteStatementBlock(fsblock, state); } // handle regular statement blocks in "main" method - for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) - { + for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) { StatementBlock current = dmlp.getStatementBlock(i); - rewriteStatementBlockHopDAGs(current, state); + rRewriteStatementBlockHopDAGs(current, state); } - dmlp.setStatementBlocks( rewriteStatementBlocks(dmlp.getStatementBlocks(), state) ); + dmlp.setStatementBlocks( rRewriteStatementBlocks(dmlp.getStatementBlocks(), state) ); return state; } - public void rewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) + public void rRewriteStatementBlockHopDAGs(StatementBlock current, ProgramRewriteStatus state) throws LanguageException, HopsException - { + { //ensure robustness for calls from outside if( state == null ) state = new ProgramRewriteStatus(); @@ -230,7 
+224,7 @@ public class ProgramRewriter FunctionStatementBlock fsb = (FunctionStatementBlock)current; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (StatementBlock sb : fstmt.getBody()) - rewriteStatementBlockHopDAGs(sb, state); + rRewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof WhileStatementBlock) { @@ -238,7 +232,7 @@ public class ProgramRewriter WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); wsb.setPredicateHops(rewriteHopDAG(wsb.getPredicateHops(), state)); for (StatementBlock sb : wstmt.getBody()) - rewriteStatementBlockHopDAGs(sb, state); + rRewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof IfStatementBlock) { @@ -246,9 +240,9 @@ public class ProgramRewriter IfStatement istmt = (IfStatement)isb.getStatement(0); isb.setPredicateHops(rewriteHopDAG(isb.getPredicateHops(), state)); for (StatementBlock sb : istmt.getIfBody()) - rewriteStatementBlockHopDAGs(sb, state); + rRewriteStatementBlockHopDAGs(sb, state); for (StatementBlock sb : istmt.getElseBody()) - rewriteStatementBlockHopDAGs(sb, state); + rRewriteStatementBlockHopDAGs(sb, state); } else if (current instanceof ForStatementBlock) //incl parfor { @@ -258,29 +252,22 @@ public class ProgramRewriter fsb.setToHops(rewriteHopDAG(fsb.getToHops(), state)); fsb.setIncrementHops(rewriteHopDAG(fsb.getIncrementHops(), state)); for (StatementBlock sb : fstmt.getBody()) - rewriteStatementBlockHopDAGs(sb, state); + rRewriteStatementBlockHopDAGs(sb, state); } else //generic (last-level) { - current.set_hops( rewriteHopDAGs(current.get_hops(), state) ); + current.set_hops( rewriteHopDAG(current.get_hops(), state) ); } } - public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) + public ArrayList<Hop> rewriteHopDAG(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { - for( HopRewriteRule r : _dagRuleSet ) - { + for( HopRewriteRule r : _dagRuleSet ) { Hop.resetVisitStatus( roots ); //reset 
for each rule roots = r.rewriteHopDAGs(roots, state); - if( CHECK ) - try { - HopDagValidator.validateHopDag(roots); - } catch (HopsException e) { - LOG.error("Invalid hop after rewriting by " + r.getClass().getName(), e); - throw e; - } + HopDagValidator.validateHopDag(roots, r); } return roots; } @@ -291,23 +278,16 @@ public class ProgramRewriter if( root == null ) return null; - for( HopRewriteRule r : _dagRuleSet ) - { + for( HopRewriteRule r : _dagRuleSet ) { root.resetVisitStatus(); //reset for each rule root = r.rewriteHopDAG(root, state); - if( CHECK ) - try { - HopDagValidator.validateHopDag(root); - } catch (HopsException e) { - LOG.error("Invalid hop after rewriting by " + r.getClass().getName(), e); - throw e; - } + HopDagValidator.validateHopDag(root, r); } return root; } - public ArrayList<StatementBlock> rewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus status ) + public ArrayList<StatementBlock> rRewriteStatementBlocks( ArrayList<StatementBlock> sbs, ProgramRewriteStatus status ) throws HopsException { //ensure robustness for calls from outside @@ -315,22 +295,26 @@ public class ProgramRewriter status = new ProgramRewriteStatus(); //apply rewrite rules to list of statement blocks - List<StatementBlock> sbList = sbs; - for( StatementBlockRewriteRule r : _sbRuleSet ) { - sbList = r.rewriteStatementBlocks(sbList, status); - } + List<StatementBlock> tmp = sbs; + for( StatementBlockRewriteRule r : _sbRuleSet ) + tmp = r.rewriteStatementBlocks(tmp, status); - //rewrite statement blocks (with potential expansion) - ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); - for( StatementBlock sb : sbList ) - tmp.addAll( rewriteStatementBlock(sb, status) ); - sbs.clear(); - sbs.addAll( tmp ); + //recursively rewrite statement blocks (with potential expansion) + List<StatementBlock> tmp2 = new ArrayList<StatementBlock>(); + for( StatementBlock sb : tmp ) + tmp2.addAll( rRewriteStatementBlock(sb, status) ); + + //apply 
rewrite rules to list of statement blocks (with potential contraction) + for( StatementBlockRewriteRule r : _sbRuleSet ) + tmp2 = r.rewriteStatementBlocks(tmp2, status); + //prepare output list + sbs.clear(); + sbs.addAll(tmp2); return sbs; } - private ArrayList<StatementBlock> rewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status ) + public ArrayList<StatementBlock> rRewriteStatementBlock( StatementBlock sb, ProgramRewriteStatus status ) throws HopsException { ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); @@ -341,20 +325,20 @@ public class ProgramRewriter { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) ); + fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), status) ); } else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - wstmt.setBody( rewriteStatementBlocks( wstmt.getBody(), status ) ); + wstmt.setBody( rRewriteStatementBlocks( wstmt.getBody(), status ) ); } else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); - istmt.setIfBody( rewriteStatementBlocks( istmt.getIfBody(), status ) ); - istmt.setElseBody( rewriteStatementBlocks( istmt.getElseBody(), status ) ); + istmt.setIfBody( rRewriteStatementBlocks( istmt.getIfBody(), status ) ); + istmt.setElseBody( rRewriteStatementBlocks( istmt.getElseBody(), status ) ); } else if (sb instanceof ForStatementBlock) //incl parfor { @@ -365,18 +349,18 @@ public class ProgramRewriter ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fstmt = (ForStatement)fsb.getStatement(0); - fstmt.setBody( rewriteStatementBlocks(fstmt.getBody(), status) ); + fstmt.setBody( rRewriteStatementBlocks(fstmt.getBody(), status) ); 
status.setInParforContext(prestatus); } //apply rewrite rules to individual statement blocks for( StatementBlockRewriteRule r : _sbRuleSet ) { - ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); + ArrayList<StatementBlock> tmp = new ArrayList<StatementBlock>(); for( StatementBlock sbc : ret ) tmp.addAll( r.rewriteStatementBlock(sbc, status) ); - //take over set of rewritten sbs + //take over set of rewritten sbs ret.clear(); ret.addAll(tmp); } http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java index ec9dcae..a0867e2 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteConstantFolding.java @@ -108,15 +108,14 @@ public class RewriteConstantFolding extends HopRewriteRule LiteralOp literal = null; //fold binary op if both are literals / unary op if literal - if( root.getDataType() == DataType.SCALAR //scalar ouput + if( root.getDataType() == DataType.SCALAR //scalar output && ( isApplicableBinaryOp(root) || isApplicableUnaryOp(root) ) ) { //core constant folding via runtime instructions try { literal = evalScalarOperation(root); } - catch(Exception ex) - { + catch(Exception ex) { LOG.error("Failed to execute constant folding instructions. 
No abort.", ex); } http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java index 659ad2e..9cba102 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMergeBlockSequence.java @@ -98,7 +98,7 @@ public class RewriteMergeBlockSequence extends StatementBlockRewriteRule } //add remaining roots from s1 to s2 else if( !(HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) - && twrites.containsKey(root.getName())) ) { + && (twrites.containsKey(root.getName()) || !sb2.liveOut().containsVariable(root.getName()))) ) { sb2Hops.add(root); } } @@ -108,7 +108,7 @@ public class RewriteMergeBlockSequence extends StatementBlockRewriteRule //run common-subexpression elimination Hop.resetVisitStatus(sb2Hops); - rewriter.rewriteHopDAGs(sb2Hops, new ProgramRewriteStatus()); + rewriter.rewriteHopDAG(sb2Hops, new ProgramRewriteStatus()); //modify live variable sets of s2 sb2.setLiveIn(sb1.liveIn()); //liveOut remains unchanged http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java index 75dfc3c..eec2fb7 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java @@ -184,8 +184,8 @@ public 
class OptimizationWrapper try { ProgramRewriter rewriter = createProgramRewriterWithRuleSets(); ProgramRewriteStatus state = new ProgramRewriteStatus(); - rewriter.rewriteStatementBlockHopDAGs( sb, state ); - fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state)); + rewriter.rRewriteStatementBlockHopDAGs( sb, state ); + fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state)); if( state.getRemovedBranches() ){ LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program"); pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody())); http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index 72b9566..b80e786 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -1909,8 +1909,8 @@ public class OptimizerRuleBased extends Optimizer RewriteInjectSparkLoopCheckpointing rewrite = new RewriteInjectSparkLoopCheckpointing(false); ProgramRewriter rewriter = new ProgramRewriter(rewrite); ProgramRewriteStatus state = new ProgramRewriteStatus(); - rewriter.rewriteStatementBlockHopDAGs( pfsb, state ); - fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), state)); + rewriter.rRewriteStatementBlockHopDAGs( pfsb, state ); + fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state)); //recompile if additional checkpoints introduced if( state.getInjectedCheckpoints() ) { 
http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java index 733df84..2af9ffe 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/BranchRemovalTest.java @@ -31,7 +31,6 @@ import org.apache.sysml.test.utils.TestUtils; public class BranchRemovalTest extends AutomatedTestBase { - private final static String TEST_NAME = "if_branch_removal"; private final static String TEST_DIR = "functions/recompile/"; private final static String TEST_CLASS_DIR = TEST_DIR + BranchRemovalTest.class.getSimpleName() + "/"; @@ -142,11 +141,8 @@ public class BranchRemovalTest extends AutomatedTestBase //check expected number of compiled and executed MR jobs int expectedNumCompiled = 5; //reblock, 3xGMR (append), write int expectedNumExecuted = 0; - if( branchRemoval && IPA ) + if( branchRemoval ) expectedNumCompiled = 1; //reblock - else if( branchRemoval ){ - expectedNumCompiled = condition ? 
4 : 3; //reblock, GMR (append), write - } checkNumCompiledMRJobs(expectedNumCompiled); checkNumExecutedMRJobs(expectedNumExecuted); http://git-wip-us.apache.org/repos/asf/systemml/blob/0b04cb16/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java index 94e4f73..8d0736b 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAAssignConstantPropagationTest.java @@ -31,13 +31,12 @@ import org.apache.sysml.test.utils.TestUtils; public class IPAAssignConstantPropagationTest extends AutomatedTestBase { - private final static String TEST_NAME = "constant_propagation_sb"; private final static String TEST_DIR = "functions/recompile/"; private final static String TEST_CLASS_DIR = TEST_DIR + IPAAssignConstantPropagationTest.class.getSimpleName() + "/"; private final static int rows = 10; - private final static int cols = 15; + private final static int cols = 15; @Override @@ -47,40 +46,27 @@ public class IPAAssignConstantPropagationTest extends AutomatedTestBase addTestConfiguration( TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "X" }) ); } - - @Test - public void testAssignConstantPropagationNoBranchRemovalNoIPA() - { + public void testAssignConstantPropagationNoBranchRemovalNoIPA() { runIPAAssignConstantPropagationTest(false, false); } @Test - public void testAssignConstantPropagationNoBranchRemovalIPA() - { + public void testAssignConstantPropagationNoBranchRemovalIPA() { runIPAAssignConstantPropagationTest(false, true); } @Test - public void 
testAssignConstantPropagationBranchRemovalNoIPA() - { + public void testAssignConstantPropagationBranchRemovalNoIPA() { runIPAAssignConstantPropagationTest(true, false); } @Test - public void testAssignConstantPropagationBranchRemovalIPA() - { + public void testAssignConstantPropagationBranchRemovalIPA() { runIPAAssignConstantPropagationTest(true, true); } - - - /** - * - * @param condition - * @param branchRemoval - * @param IPA - */ + private void runIPAAssignConstantPropagationTest( boolean branchRemoval, boolean IPA ) { boolean oldFlagBranchRemoval = OptimizerUtils.ALLOW_BRANCH_REMOVAL; @@ -98,7 +84,7 @@ public class IPAAssignConstantPropagationTest extends AutomatedTestBase fullRScriptName = HOME + TEST_NAME + ".R"; rCmd = "Rscript" + " " + fullRScriptName + " " + - Integer.toString(rows) + " " + Integer.toString(cols) + " " + expectedDir(); + Integer.toString(rows) + " " + Integer.toString(cols) + " " + expectedDir(); OptimizerUtils.ALLOW_BRANCH_REMOVAL = branchRemoval; OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA; @@ -112,17 +98,15 @@ public class IPAAssignConstantPropagationTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, 0, "Stat-DML", "Stat-R"); //check expected number of compiled and executed MR jobs - int expectedNumCompiled = ( branchRemoval && IPA ) ? 0 : 1; //rand - int expectedNumExecuted = 0; + int expectedNumCompiled = branchRemoval ? 0 : 1; //rand + int expectedNumExecuted = 0; checkNumCompiledMRJobs(expectedNumCompiled); checkNumExecutedMRJobs(expectedNumExecuted); } - finally - { + finally { OptimizerUtils.ALLOW_BRANCH_REMOVAL = oldFlagBranchRemoval; OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA; } } - }
