Repository: systemml Updated Branches: refs/heads/master d01d13c4b -> 631079c43
[SYSTEMML-1904] Improved codegen for recompile-once functions This patch improves the reset of recompilation flags on dynamic recompilation of "recompile-once" functions. If codegen is enabled, we now only reset these flags if the execution type is CP and at least the dimensions are known. This increases the opportunities for code generation, because the codegen optimizer requires this size information to check validity constraints and compute cost estimates. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b27cbf20 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b27cbf20 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b27cbf20 Branch: refs/heads/master Commit: b27cbf20ccdbcf44612cd4bda61f0685880a32c6 Parents: d01d13c Author: Matthias Boehm <[email protected]> Authored: Mon Sep 25 14:28:16 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Sep 25 14:28:16 2017 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 16 ++++--- .../sysml/hops/globalopt/GDFEnumOptimizer.java | 10 +++-- .../apache/sysml/hops/recompile/Recompiler.java | 47 ++++++++++++-------- .../controlprogram/FunctionProgramBlock.java | 4 +- .../parfor/opt/OptimizationWrapper.java | 22 ++++----- 5 files changed, 58 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 7e172ac..165a831 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -28,6 +28,7 @@ import org.apache.commons.logging.LogFactory; import 
org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.hops.recompile.Recompiler.ResetType; import org.apache.sysml.lops.Binary; import org.apache.sysml.lops.BinaryScalar; import org.apache.sysml.lops.CSVReBlock; @@ -938,30 +939,31 @@ public abstract class Hop implements ParseInfo memo.add(getHopID()); } - public static void resetRecompilationFlag( ArrayList<Hop> hops, ExecType et ) + public static void resetRecompilationFlag( ArrayList<Hop> hops, ExecType et, ResetType reset ) { resetVisitStatus( hops ); for( Hop hopRoot : hops ) - hopRoot.resetRecompilationFlag( et ); + hopRoot.resetRecompilationFlag( et, reset ); } - public static void resetRecompilationFlag( Hop hops, ExecType et ) + public static void resetRecompilationFlag( Hop hops, ExecType et, ResetType reset ) { hops.resetVisitStatus(); - hops.resetRecompilationFlag( et ); + hops.resetRecompilationFlag( et, reset ); } - private void resetRecompilationFlag( ExecType et ) + private void resetRecompilationFlag( ExecType et, ResetType reset ) { if( isVisited() ) return; //process child hops for (Hop h : getInput()) - h.resetRecompilationFlag( et ); + h.resetRecompilationFlag( et, reset ); //reset recompile flag - if( et == null || getExecType() == et || getExecType()==null ) + if( (et == null || getExecType() == et || getExecType() == null) + && (reset==ResetType.RESET || (reset==ResetType.RESET_KNOWN_DIMS && dimsKnown())) ) _requiresRecompile = false; setVisited(); http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java b/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java index 82a3376..7e043be 100644 --- 
a/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java +++ b/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java @@ -41,6 +41,7 @@ import org.apache.sysml.hops.globalopt.gdfresolve.GDFMismatchHeuristic.MismatchH import org.apache.sysml.hops.globalopt.gdfresolve.MismatchHeuristicFactory; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.hops.recompile.Recompiler.ResetType; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.runtime.DMLRuntimeException; @@ -137,7 +138,8 @@ public class GDFEnumOptimizer extends GlobalOptimizer long finalPlanMismatch = getPlanMismatches(); //generate final runtime plan (w/ optimal config) - Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new LocalVariableMap(), 0, false); + Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), + new LocalVariableMap(), 0, ResetType.NO_RESET); ec = ExecutionContextFactory.createContext(prog); double optCosts = CostEstimationWrapper.getTimeEstimate(prog, ec); @@ -430,7 +432,8 @@ public class GDFEnumOptimizer extends GlobalOptimizer (p.getNode().getHop()==null || p.getNode().getProgramBlock()==null) ) { //recompile entire runtime program - Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new LocalVariableMap(), 0, false); + Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), + new LocalVariableMap(), 0, ResetType.NO_RESET); _compiledPlans++; //cost entire runtime program @@ -456,7 +459,8 @@ public class GDFEnumOptimizer extends GlobalOptimizer } //recompile modified runtime program - Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new LocalVariableMap(), 0, false); + Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), + new LocalVariableMap(), 0, ResetType.NO_RESET); _compiledPlans++; //cost partial runtime program up to current 
hop http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java index f22d44a..a4f8e0f 100644 --- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java @@ -118,8 +118,7 @@ import org.apache.sysml.utils.MLContextProxy; * */ public class Recompiler -{ - +{ private static final Log LOG = LogFactory.getLog(Recompiler.class.getName()); //Max threshold for in-memory reblock of text input [in bytes] @@ -133,7 +132,16 @@ public class Recompiler /** Local DML configuration for thread-local config updates */ private static ThreadLocal<ProgramRewriter> _rewriter = new ThreadLocal<ProgramRewriter>() { @Override protected ProgramRewriter initialValue() { return new ProgramRewriter(false, true); } - }; + }; + + public enum ResetType { + RESET, + RESET_KNOWN_DIMS, + NO_RESET; + public boolean isReset() { + return this != NO_RESET; + } + } /** * Re-initializes the recompiler according to the current optimizer flags. 
@@ -547,7 +555,7 @@ public class Recompiler return newInst; } - public static void recompileProgramBlockHierarchy( ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean resetRecompile ) + public static void recompileProgramBlockHierarchy( ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, ResetType resetRecompile ) throws DMLRuntimeException { try @@ -797,7 +805,8 @@ public class Recompiler // private helper functions // ////////////////////////////// - private static void rRecompileProgramBlock( ProgramBlock pb, LocalVariableMap vars, RecompileStatus status, long tid, boolean resetRecompile ) + private static void rRecompileProgramBlock( ProgramBlock pb, LocalVariableMap vars, + RecompileStatus status, long tid, ResetType resetRecompile ) throws HopsException, DMLRuntimeException, LopsException, IOException { if (pb instanceof WhileProgramBlock) @@ -884,11 +893,11 @@ public class Recompiler Recompiler.extractDAGOutputStatistics(sb.get_hops(), vars); //reset recompilation flags (w/ special handling functions) - if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs + if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs && !containsRootFunctionOp(sb.get_hops()) - && resetRecompile ) + && resetRecompile.isReset() ) { - Hop.resetRecompilationFlag(sb.get_hops(), ExecType.CP); + Hop.resetRecompilationFlag(sb.get_hops(), ExecType.CP, resetRecompile); sb.updateRecompilationFlag(); } } @@ -1123,7 +1132,7 @@ public class Recompiler //helper functions for predicate recompile - private static void recompileIfPredicate( IfProgramBlock ipb, IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status, long tid, boolean resetRecompile ) + private static void recompileIfPredicate( IfProgramBlock ipb, IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status, long tid, ResetType resetRecompile ) throws DMLRuntimeException, HopsException, LopsException, IOException { if( isb == null ) @@ -1135,14 +1144,14 @@ public class Recompiler hops, vars, 
status, true, false, tid); ipb.setPredicate( tmp ); if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs - && resetRecompile ) { - Hop.resetRecompilationFlag(hops, ExecType.CP); + && resetRecompile.isReset() ) { + Hop.resetRecompilationFlag(hops, ExecType.CP, resetRecompile); isb.updatePredicateRecompilationFlag(); } } } - private static void recompileWhilePredicate( WhileProgramBlock wpb, WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status, long tid, boolean resetRecompile ) + private static void recompileWhilePredicate( WhileProgramBlock wpb, WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status, long tid, ResetType resetRecompile ) throws DMLRuntimeException, HopsException, LopsException, IOException { if( wsb == null ) @@ -1154,14 +1163,14 @@ public class Recompiler hops, vars, status, true, false, tid); wpb.setPredicate( tmp ); if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs - && resetRecompile ) { - Hop.resetRecompilationFlag(hops, ExecType.CP); + && resetRecompile.isReset() ) { + Hop.resetRecompilationFlag(hops, ExecType.CP, resetRecompile); wsb.updatePredicateRecompilationFlag(); } } } - private static void recompileForPredicates( ForProgramBlock fpb, ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status, long tid, boolean resetRecompile ) + private static void recompileForPredicates( ForProgramBlock fpb, ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status, long tid, ResetType resetRecompile ) throws DMLRuntimeException, HopsException, LopsException, IOException { if( fsb != null ) @@ -1172,25 +1181,25 @@ public class Recompiler //handle recompilation flags if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs - && resetRecompile ) + && resetRecompile.isReset() ) { if( fromHops != null ) { ArrayList<Instruction> tmp = recompileHopsDag( fromHops, vars, status, true, false, tid); fpb.setFromInstructions(tmp); - Hop.resetRecompilationFlag(fromHops,ExecType.CP); + 
Hop.resetRecompilationFlag(fromHops,ExecType.CP, resetRecompile); } if( toHops != null ) { ArrayList<Instruction> tmp = recompileHopsDag( toHops, vars, status, true, false, tid); fpb.setToInstructions(tmp); - Hop.resetRecompilationFlag(toHops,ExecType.CP); + Hop.resetRecompilationFlag(toHops,ExecType.CP, resetRecompile); } if( incrHops != null ) { ArrayList<Instruction> tmp = recompileHopsDag( incrHops, vars, status, true, false, tid); fpb.setIncrementInstructions(tmp); - Hop.resetRecompilationFlag(incrHops,ExecType.CP); + Hop.resetRecompilationFlag(incrHops,ExecType.CP, resetRecompile); } fsb.updatePredicateRecompilationFlags(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java index 830251e..d402502 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.hops.recompile.Recompiler.ResetType; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; @@ -95,7 +96,8 @@ public class FunctionProgramBlock extends ProgramBlock // function will be recompiled for every execution. 
// (2) without reset, there would be no benefit in recompiling the entire function LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone(); - Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, true); + ResetType reset = ConfigurationManager.isCodegenEnabled() ? ResetType.RESET_KNOWN_DIMS : ResetType.RESET; + Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, reset); if( DMLScript.STATISTICS ){ long t1 = System.nanoTime(); http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java index 92dd567..75dfc3c 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java @@ -32,6 +32,7 @@ import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.ipa.InterProceduralAnalysis; import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.hops.recompile.Recompiler.ResetType; import org.apache.sysml.hops.rewrite.HopRewriteRule; import org.apache.sysml.hops.rewrite.ProgramRewriteStatus; import org.apache.sysml.hops.rewrite.ProgramRewriter; @@ -200,7 +201,9 @@ public class OptimizationWrapper //* clone of variables in order to allow for statistics propagation across DAGs //(tid=0, because deep copies created after opt) LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone(); - Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true); + ResetType reset = ConfigurationManager.isCodegenEnabled() ? 
+ ResetType.RESET_KNOWN_DIMS : ResetType.RESET; + Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset); //inter-procedural optimization (based on previous recompilation) if( pb.hasFunctions() ) { @@ -215,8 +218,9 @@ public class OptimizationWrapper FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]); //reset recompilation flags according to recompileOnce because it is only safe if function is recompileOnce //because then recompiled for every execution (otherwise potential issues if func also called outside parfor) - Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, fpb.isRecompileOnce()); - } + ResetType reset2 = fpb.isRecompileOnce() ? reset : ResetType.NO_RESET; + Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0, reset2); + } } } } @@ -230,8 +234,7 @@ public class OptimizationWrapper tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec); LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false)); } - catch(Exception ex) - { + catch(Exception ex) { throw new DMLRuntimeException("Unable to create opt tree.", ex); } @@ -244,14 +247,12 @@ public class OptimizationWrapper LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false)); //assert plan correctness - if( CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled() ) - { + if( CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled() ) { try{ OptTreePlanChecker.checkProgramCorrectness(pb, sb, new HashSet<String>()); LOG.debug("ParFOR Opt: Checked plan and program correctness."); } - catch(Exception ex) - { + catch(Exception ex) { throw new DMLRuntimeException("Failed to check program correctness.", ex); } } @@ -265,8 +266,7 @@ public class OptimizationWrapper OptTreeConverter.clear(); //monitor stats - if( monitor ) - { + if( monitor ) { StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_OPTIMIZER, 
otype.ordinal()); StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_NUMTPLANS, opt.getNumTotalPlans()); StatisticMonitor.putPFStat( pb.getID() , Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
