Repository: systemml
Updated Branches:
  refs/heads/master d01d13c4b -> 631079c43


[SYSTEMML-1904] Improved codegen for recompile-once functions 

This patch improves how recompilation flags are reset during dynamic
recompilation of "recompile-once" functions. If codegen is enabled, we
now reset the flag of a hop only if its execution type is CP and at
least its dimensions are known; hops with unknown dimensions keep their
recompilation flag. This increases the potential for code generation
because the codegen optimizer requires this size information for
checking validity constraints and computing cost estimates.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b27cbf20
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b27cbf20
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b27cbf20

Branch: refs/heads/master
Commit: b27cbf20ccdbcf44612cd4bda61f0685880a32c6
Parents: d01d13c
Author: Matthias Boehm <[email protected]>
Authored: Mon Sep 25 14:28:16 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Sep 25 14:28:16 2017 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    | 16 ++++---
 .../sysml/hops/globalopt/GDFEnumOptimizer.java  | 10 +++--
 .../apache/sysml/hops/recompile/Recompiler.java | 47 ++++++++++++--------
 .../controlprogram/FunctionProgramBlock.java    |  4 +-
 .../parfor/opt/OptimizationWrapper.java         | 22 ++++-----
 5 files changed, 58 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 7e172ac..165a831 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -28,6 +28,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.recompile.Recompiler.ResetType;
 import org.apache.sysml.lops.Binary;
 import org.apache.sysml.lops.BinaryScalar;
 import org.apache.sysml.lops.CSVReBlock;
@@ -938,30 +939,31 @@ public abstract class Hop implements ParseInfo
                memo.add(getHopID());
        }
 
-       public static void resetRecompilationFlag( ArrayList<Hop> hops, 
ExecType et )
+       public static void resetRecompilationFlag( ArrayList<Hop> hops, 
ExecType et, ResetType reset )
        {
                resetVisitStatus( hops );
                for( Hop hopRoot : hops )
-                       hopRoot.resetRecompilationFlag( et );
+                       hopRoot.resetRecompilationFlag( et, reset );
        }
        
-       public static void resetRecompilationFlag( Hop hops, ExecType et )
+       public static void resetRecompilationFlag( Hop hops, ExecType et, 
ResetType reset )
        {
                hops.resetVisitStatus();
-               hops.resetRecompilationFlag( et );
+               hops.resetRecompilationFlag( et, reset );
        }
        
-       private void resetRecompilationFlag( ExecType et ) 
+       private void resetRecompilationFlag( ExecType et, ResetType reset ) 
        {
                if( isVisited() )
                        return;
                
                //process child hops
                for (Hop h : getInput())
-                       h.resetRecompilationFlag( et );
+                       h.resetRecompilationFlag( et, reset );
                
                //reset recompile flag
-               if( et == null || getExecType() == et || getExecType()==null )
+               if( (et == null || getExecType() == et || getExecType() == null)
+                       && (reset==ResetType.RESET || 
(reset==ResetType.RESET_KNOWN_DIMS && dimsKnown())) )
                        _requiresRecompile = false;
                
                setVisited();

http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java 
b/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java
index 82a3376..7e043be 100644
--- a/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java
+++ b/src/main/java/org/apache/sysml/hops/globalopt/GDFEnumOptimizer.java
@@ -41,6 +41,7 @@ import 
org.apache.sysml.hops.globalopt.gdfresolve.GDFMismatchHeuristic.MismatchH
 import org.apache.sysml.hops.globalopt.gdfresolve.MismatchHeuristicFactory;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.recompile.Recompiler.ResetType;
 import org.apache.sysml.lops.LopsException;
 import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -137,7 +138,8 @@ public class GDFEnumOptimizer extends GlobalOptimizer
                long finalPlanMismatch = getPlanMismatches();
                
                //generate final runtime plan (w/ optimal config)
-               
Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new 
LocalVariableMap(), 0, false);
+               
Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(),
+                       new LocalVariableMap(), 0, ResetType.NO_RESET);
                
                ec = ExecutionContextFactory.createContext(prog);
                double optCosts = CostEstimationWrapper.getTimeEstimate(prog, 
ec);
@@ -430,7 +432,8 @@ public class GDFEnumOptimizer extends GlobalOptimizer
                   (p.getNode().getHop()==null || 
p.getNode().getProgramBlock()==null) )
                {
                        //recompile entire runtime program
-                       
Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new 
LocalVariableMap(), 0, false);
+                       
Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(),
+                               new LocalVariableMap(), 0, ResetType.NO_RESET);
                        _compiledPlans++;
                        
                        //cost entire runtime program
@@ -456,7 +459,8 @@ public class GDFEnumOptimizer extends GlobalOptimizer
                                }
                                
                                //recompile modified runtime program
-                               
Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new 
LocalVariableMap(), 0, false);
+                               
Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(),
+                                       new LocalVariableMap(), 0, 
ResetType.NO_RESET);
                                _compiledPlans++;
                                
                                //cost partial runtime program up to current hop

http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
index f22d44a..a4f8e0f 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
@@ -118,8 +118,7 @@ import org.apache.sysml.utils.MLContextProxy;
  * 
  */
 public class Recompiler 
-{      
-       
+{
        private static final Log LOG = 
LogFactory.getLog(Recompiler.class.getName());
        
        //Max threshold for in-memory reblock of text input [in bytes]
@@ -133,7 +132,16 @@ public class Recompiler
        /** Local DML configuration for thread-local config updates */
        private static ThreadLocal<ProgramRewriter> _rewriter = new 
ThreadLocal<ProgramRewriter>() {
                @Override protected ProgramRewriter initialValue() { return new 
ProgramRewriter(false, true); }
-    };
+       };
+       
+       public enum ResetType {
+               RESET,
+               RESET_KNOWN_DIMS,
+               NO_RESET;
+               public boolean isReset() {
+                       return this != NO_RESET;
+               }
+       }
        
        /**
         * Re-initializes the recompiler according to the current optimizer 
flags.
@@ -547,7 +555,7 @@ public class Recompiler
                return newInst;
        }
 
-       public static void recompileProgramBlockHierarchy( 
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, boolean 
resetRecompile ) 
+       public static void recompileProgramBlockHierarchy( 
ArrayList<ProgramBlock> pbs, LocalVariableMap vars, long tid, ResetType 
resetRecompile ) 
                throws DMLRuntimeException
        {
                try 
@@ -797,7 +805,8 @@ public class Recompiler
        // private helper functions //
        //////////////////////////////
        
-       private static void rRecompileProgramBlock( ProgramBlock pb, 
LocalVariableMap vars, RecompileStatus status, long tid, boolean resetRecompile 
) 
+       private static void rRecompileProgramBlock( ProgramBlock pb, 
LocalVariableMap vars, 
+               RecompileStatus status, long tid, ResetType resetRecompile ) 
                throws HopsException, DMLRuntimeException, LopsException, 
IOException
        {
                if (pb instanceof WhileProgramBlock)
@@ -884,11 +893,11 @@ public class Recompiler
                                
Recompiler.extractDAGOutputStatistics(sb.get_hops(), vars);
                                
                                //reset recompilation flags (w/ special 
handling functions)
-                               if(    
ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
+                               if( 
ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
                                        && 
!containsRootFunctionOp(sb.get_hops())  
-                                       && resetRecompile ) 
+                                       && resetRecompile.isReset() ) 
                                {
-                                       
Hop.resetRecompilationFlag(sb.get_hops(), ExecType.CP);
+                                       
Hop.resetRecompilationFlag(sb.get_hops(), ExecType.CP, resetRecompile);
                                        sb.updateRecompilationFlag();
                                }
                        }
@@ -1123,7 +1132,7 @@ public class Recompiler
        
        //helper functions for predicate recompile
        
-       private static void recompileIfPredicate( IfProgramBlock ipb, 
IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status, long tid, 
boolean resetRecompile ) 
+       private static void recompileIfPredicate( IfProgramBlock ipb, 
IfStatementBlock isb, LocalVariableMap vars, RecompileStatus status, long tid, 
ResetType resetRecompile ) 
                throws DMLRuntimeException, HopsException, LopsException, 
IOException
        {
                if( isb == null )
@@ -1135,14 +1144,14 @@ public class Recompiler
                                hops, vars, status, true, false, tid);
                        ipb.setPredicate( tmp );
                        if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs
-                               && resetRecompile ) {
-                               Hop.resetRecompilationFlag(hops, ExecType.CP);
+                               && resetRecompile.isReset() ) {
+                               Hop.resetRecompilationFlag(hops, ExecType.CP, 
resetRecompile);
                                isb.updatePredicateRecompilationFlag();
                        }
                }
        }
        
-       private static void recompileWhilePredicate( WhileProgramBlock wpb, 
WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status, long 
tid, boolean resetRecompile ) 
+       private static void recompileWhilePredicate( WhileProgramBlock wpb, 
WhileStatementBlock wsb, LocalVariableMap vars, RecompileStatus status, long 
tid, ResetType resetRecompile ) 
                throws DMLRuntimeException, HopsException, LopsException, 
IOException
        {
                if( wsb == null )
@@ -1154,14 +1163,14 @@ public class Recompiler
                                hops, vars, status, true, false, tid);
                        wpb.setPredicate( tmp );
                        if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
-                               && resetRecompile ) {
-                               Hop.resetRecompilationFlag(hops, ExecType.CP);
+                               && resetRecompile.isReset() ) {
+                               Hop.resetRecompilationFlag(hops, ExecType.CP, 
resetRecompile);
                                wsb.updatePredicateRecompilationFlag();
                        }
                }
        }
        
-       private static void recompileForPredicates( ForProgramBlock fpb, 
ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status, long tid, 
boolean resetRecompile ) 
+       private static void recompileForPredicates( ForProgramBlock fpb, 
ForStatementBlock fsb, LocalVariableMap vars, RecompileStatus status, long tid, 
ResetType resetRecompile ) 
                throws DMLRuntimeException, HopsException, LopsException, 
IOException
        {
                if( fsb != null )
@@ -1172,25 +1181,25 @@ public class Recompiler
                        
                        //handle recompilation flags
                        if( ParForProgramBlock.RESET_RECOMPILATION_FLAGs 
-                               && resetRecompile ) 
+                               && resetRecompile.isReset() ) 
                        {
                                if( fromHops != null ) {
                                        ArrayList<Instruction> tmp = 
recompileHopsDag(
                                                fromHops, vars, status, true, 
false, tid);
                                        fpb.setFromInstructions(tmp);
-                                       
Hop.resetRecompilationFlag(fromHops,ExecType.CP);
+                                       
Hop.resetRecompilationFlag(fromHops,ExecType.CP, resetRecompile);
                                }
                                if( toHops != null ) {
                                        ArrayList<Instruction> tmp = 
recompileHopsDag(
                                                toHops, vars, status, true, 
false, tid);
                                        fpb.setToInstructions(tmp);
-                                       
Hop.resetRecompilationFlag(toHops,ExecType.CP);
+                                       
Hop.resetRecompilationFlag(toHops,ExecType.CP, resetRecompile);
                                }
                                if( incrHops != null ) {
                                        ArrayList<Instruction> tmp = 
recompileHopsDag(
                                                incrHops, vars, status, true, 
false, tid);
                                        fpb.setIncrementInstructions(tmp);
-                                       
Hop.resetRecompilationFlag(incrHops,ExecType.CP);
+                                       
Hop.resetRecompilationFlag(incrHops,ExecType.CP, resetRecompile);
                                }
                                fsb.updatePredicateRecompilationFlags();
                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
index 830251e..d402502 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
@@ -24,6 +24,7 @@ import java.util.ArrayList;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.recompile.Recompiler.ResetType;
 import org.apache.sysml.parser.DataIdentifier;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.DMLScriptException;
@@ -95,7 +96,8 @@ public class FunctionProgramBlock extends ProgramBlock
                                //     function will be recompiled for every 
execution.
                                // (2) without reset, there would be no benefit 
in recompiling the entire function
                                LocalVariableMap tmp = (LocalVariableMap) 
ec.getVariables().clone();
-                               
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, true);
+                               ResetType reset = 
ConfigurationManager.isCodegenEnabled() ? ResetType.RESET_KNOWN_DIMS : 
ResetType.RESET;
+                               
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, reset);
                                
                                if( DMLScript.STATISTICS ){
                                        long t1 = System.nanoTime();

http://git-wip-us.apache.org/repos/asf/systemml/blob/b27cbf20/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index 92dd567..75dfc3c 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -32,6 +32,7 @@ import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ipa.InterProceduralAnalysis;
 import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.recompile.Recompiler.ResetType;
 import org.apache.sysml.hops.rewrite.HopRewriteRule;
 import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
 import org.apache.sysml.hops.rewrite.ProgramRewriter;
@@ -200,7 +201,9 @@ public class OptimizationWrapper
                                //* clone of variables in order to allow for 
statistics propagation across DAGs
                                //(tid=0, because deep copies created after opt)
                                LocalVariableMap tmp = (LocalVariableMap) 
ec.getVariables().clone();
-                               
Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, true);
+                               ResetType reset = 
ConfigurationManager.isCodegenEnabled() ? 
+                                       ResetType.RESET_KNOWN_DIMS : 
ResetType.RESET;
+                               
Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0, reset);
                                
                                //inter-procedural optimization (based on 
previous recompilation)
                                if( pb.hasFunctions() ) {
@@ -215,8 +218,9 @@ public class OptimizationWrapper
                                                        FunctionProgramBlock 
fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                                                        //reset recompilation 
flags according to recompileOnce because it is only safe if function is 
recompileOnce 
                                                        //because then 
recompiled for every execution (otherwise potential issues if func also called 
outside parfor)
-                                                       
Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new 
LocalVariableMap(), 0, fpb.isRecompileOnce());
-                                               }               
+                                                       ResetType reset2 = 
fpb.isRecompileOnce() ? reset : ResetType.NO_RESET;
+                                                       
Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new 
LocalVariableMap(), 0, reset2);
+                                               }
                                        }
                                }
                        }
@@ -230,8 +234,7 @@ public class OptimizationWrapper
                        tree = OptTreeConverter.createOptTree(ck, cm, 
opt.getPlanInputType(), sb, pb, ec); 
                        LOG.debug("ParFOR Opt: Input plan (before 
optimization):\n" + tree.explain(false));
                }
-               catch(Exception ex)
-               {
+               catch(Exception ex) {
                        throw new DMLRuntimeException("Unable to create opt 
tree.", ex);
                }
                
@@ -244,14 +247,12 @@ public class OptimizationWrapper
                LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" 
+ tree.explain(false));
                
                //assert plan correctness
-               if( CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled() )
-               {
+               if( CHECK_PLAN_CORRECTNESS && LOG.isDebugEnabled() ) {
                        try{
                                OptTreePlanChecker.checkProgramCorrectness(pb, 
sb, new HashSet<String>());
                                LOG.debug("ParFOR Opt: Checked plan and program 
correctness.");
                        }
-                       catch(Exception ex)
-                       {
+                       catch(Exception ex) {
                                throw new DMLRuntimeException("Failed to check 
program correctness.", ex);
                        }
                }
@@ -265,8 +266,7 @@ public class OptimizationWrapper
                OptTreeConverter.clear();
                
                //monitor stats
-               if( monitor )
-               {
+               if( monitor ) {
                        StatisticMonitor.putPFStat( pb.getID() , 
Stat.OPT_OPTIMIZER, otype.ordinal());
                        StatisticMonitor.putPFStat( pb.getID() , 
Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
                        StatisticMonitor.putPFStat( pb.getID() , 
Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());

Reply via email to