Repository: systemml
Updated Branches:
  refs/heads/master 2cee9bb9f -> 1995f3569


[SYSTEMML-2077] Extended IPA w/ conditional removal of unused functions

So far, inter-procedural analysis (IPA) removed all unused functions,
i.e., functions that are either inlined in all call locations or simply
not reachable from the main program. This is especially important for
scripts with nested imports, which otherwise cause unnecessary
compilation and memory overheads. However, this removal is invalid in
the presence of second-order eval functions, which can execute
dynamically constructed function names, and hence the removal was
temporarily disabled.

This patch extends IPA to determine if the function call graph contains
at least one call to eval and conditionally disables the removal of
unused functions accordingly. Thus, the common case without eval is not
affected, while eval calls of "unused" functions are supported if a
script requires this flexibility (e.g., for ensemble learning).
 

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1995f356
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1995f356
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1995f356

Branch: refs/heads/master
Commit: 1995f35699b4bd4cb3738b2dc555e6d4c382b9b0
Parents: 108ee7a
Author: Matthias Boehm <[email protected]>
Authored: Thu Mar 8 23:57:04 2018 -0800
Committer: Matthias Boehm <[email protected]>
Committed: Thu Mar 8 23:58:08 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallGraph.java       | 68 ++++++++++++++------
 .../java/org/apache/sysml/hops/ipa/IPAPass.java |  3 +-
 .../hops/ipa/IPAPassApplyStaticHopRewrites.java |  2 +-
 .../ipa/IPAPassFlagFunctionsRecompileOnce.java  |  2 +-
 .../sysml/hops/ipa/IPAPassInlineFunctions.java  |  2 +-
 .../ipa/IPAPassPropagateReplaceLiterals.java    |  2 +-
 .../ipa/IPAPassRemoveConstantBinaryOps.java     |  2 +-
 .../IPAPassRemoveUnnecessaryCheckpoints.java    |  2 +-
 .../hops/ipa/IPAPassRemoveUnusedFunctions.java  |  5 +-
 .../sysml/hops/ipa/InterProceduralAnalysis.java | 10 +--
 10 files changed, 63 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java 
b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
index d1ece9b..5105716 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
@@ -32,6 +32,9 @@ import java.util.stream.Collectors;
 import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.OpOpN;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.parser.DMLProgram;
 import org.apache.sysml.parser.ForStatement;
 import org.apache.sysml.parser.ForStatementBlock;
@@ -61,6 +64,8 @@ public class FunctionCallGraph
        //subset of direct or indirect recursive functions
        private final HashSet<String> _fRecursive;
        
+       private final boolean _containsEval;
+       
        /**
         * Constructs the function call graph for all functions
         * reachable from the main program. 
@@ -72,8 +77,7 @@ public class FunctionCallGraph
                _fCalls = new HashMap<>();
                _fCallsSB = new HashMap<>();
                _fRecursive = new HashSet<>();
-               
-               constructFunctionCallGraph(prog);
+               _containsEval = constructFunctionCallGraph(prog);
        }
        
        /**
@@ -87,8 +91,7 @@ public class FunctionCallGraph
                _fCalls = new HashMap<>();
                _fCallsSB = new HashMap<>();
                _fRecursive = new HashSet<>();
-               
-               constructFunctionCallGraph(sb);
+               _containsEval = constructFunctionCallGraph(sb);
        }
 
        /**
@@ -231,74 +234,88 @@ public class FunctionCallGraph
         */
        public boolean isReachableFunction(String fkey) {
                String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey;
-               return _fGraph.containsKey(lfkey);              
+               return _fGraph.containsKey(lfkey);
+       }
+       
+       /**
+        * Indicates if the function call graph, i.e., functions that are 
transitively
+        * reachable from the main program, contains a second-order eval call, 
which
+        * prohibits the removal of unused functions.
+        * 
+        * @return true if the function call graph contains an eval call.
+        */
+       public boolean containsEvalCall() {
+               return _containsEval;
        }
        
-       private void constructFunctionCallGraph(DMLProgram prog) {
+       private boolean constructFunctionCallGraph(DMLProgram prog) {
                if( !prog.hasFunctionStatementBlocks() )
-                       return; //early abort if prog without functions
+                       return false; //early abort if prog without functions
                
+               boolean ret = false;
                try {
                        Stack<String> fstack = new Stack<>();
                        HashSet<String> lfset = new HashSet<>();
                        _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
                        for( StatementBlock sblk : prog.getStatementBlocks() )
-                               rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, 
sblk, fstack, lfset);
+                               ret |= 
rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset);
                }
                catch(HopsException ex) {
                        throw new RuntimeException(ex);
                }
+               return ret;
        }
        
-       private void constructFunctionCallGraph(StatementBlock sb) {
+       private boolean constructFunctionCallGraph(StatementBlock sb) {
                if( !sb.getDMLProg().hasFunctionStatementBlocks() )
-                       return; //early abort if prog without functions
+                       return false; //early abort if prog without functions
                
                try {
                        Stack<String> fstack = new Stack<>();
                        HashSet<String> lfset = new HashSet<>();
                        _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
-                       rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sb, 
fstack, lfset);
+                       return rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, 
sb, fstack, lfset);
                }
                catch(HopsException ex) {
                        throw new RuntimeException(ex);
                }
        }
        
-       private void rConstructFunctionCallGraph(String fkey, StatementBlock 
sb, Stack<String> fstack, HashSet<String> lfset) 
+       private boolean rConstructFunctionCallGraph(String fkey, StatementBlock 
sb, Stack<String> fstack, HashSet<String> lfset) 
                throws HopsException 
        {
+               boolean ret = false;
                if (sb instanceof WhileStatementBlock) {
                        WhileStatement ws = (WhileStatement)sb.getStatement(0);
                        for (StatementBlock current : ws.getBody())
-                               rConstructFunctionCallGraph(fkey, current, 
fstack, lfset);
+                               ret |= rConstructFunctionCallGraph(fkey, 
current, fstack, lfset);
                } 
                else if (sb instanceof IfStatementBlock) {
                        IfStatement ifs = (IfStatement) sb.getStatement(0);
                        for (StatementBlock current : ifs.getIfBody())
-                               rConstructFunctionCallGraph(fkey, current, 
fstack, lfset);
+                               ret |= rConstructFunctionCallGraph(fkey, 
current, fstack, lfset);
                        for (StatementBlock current : ifs.getElseBody())
-                               rConstructFunctionCallGraph(fkey, current, 
fstack, lfset);
+                               ret |= rConstructFunctionCallGraph(fkey, 
current, fstack, lfset);
                } 
                else if (sb instanceof ForStatementBlock) {
                        ForStatement fs = (ForStatement)sb.getStatement(0);
                        for (StatementBlock current : fs.getBody())
-                               rConstructFunctionCallGraph(fkey, current, 
fstack, lfset);
+                               ret |= rConstructFunctionCallGraph(fkey, 
current, fstack, lfset);
                } 
                else if (sb instanceof FunctionStatementBlock) {
                        FunctionStatement fsb = (FunctionStatement) 
sb.getStatement(0);
                        for (StatementBlock current : fsb.getBody())
-                               rConstructFunctionCallGraph(fkey, current, 
fstack, lfset);
+                               ret |= rConstructFunctionCallGraph(fkey, 
current, fstack, lfset);
                } 
                else {
                        // For generic StatementBlock
                        ArrayList<Hop> hopsDAG = sb.getHops();
                        if( hopsDAG == null || hopsDAG.isEmpty() ) 
-                               return; //nothing to do
+                               return false; //nothing to do
                        
                        //function ops can only occur as root nodes of the dag
                        for( Hop h : hopsDAG ) {
-                               if( h instanceof FunctionOp ){
+                               if( h instanceof FunctionOp ) {
                                        FunctionOp fop = (FunctionOp) h;
                                        String lfkey = fop.getFunctionKey();
                                        //keep all function operators
@@ -322,10 +339,10 @@ public class FunctionCallGraph
                                                _fGraph.get(fkey).add(lfkey);
                                                
                                                FunctionStatementBlock fsb = 
sb.getDMLProg()
-                                                               
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
+                                                       
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
                                                FunctionStatement fs = 
(FunctionStatement) fsb.getStatement(0);
                                                for( StatementBlock csb : 
fs.getBody() )
-                                                       
rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>());
+                                                       ret |= 
rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>());
                                                fstack.pop();
                                        }
                                        //recursive function call
@@ -342,7 +359,16 @@ public class FunctionCallGraph
                                        //mark as visited for current function 
call context
                                        lfset.add( lfkey );
                                }
+                               else if( HopRewriteUtils.isData(h, 
DataOpTypes.TRANSIENTWRITE)
+                                       && 
HopRewriteUtils.isNary(h.getInput().get(0), OpOpN.EVAL) ) {
+                                       //NOTE: after 
RewriteSplitDagDataDependentOperators, eval operators
+                                       //will always appear as childs to root 
nodes which allows for an
+                                       //efficient existence check without DAG 
traversal.
+                                       ret = true;
+                               }
                        }
                }
+               
+               return ret;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
index ced407e..f3a4912 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
@@ -36,9 +36,10 @@ public abstract class IPAPass
         * configuration such as global flags or the chosen execution 
         * mode (e.g., hybrid_spark).
         * 
+        * @param fgraph function call graph
         * @return true if applicable.
         */
-       public abstract boolean isApplicable();
+       public abstract boolean isApplicable(FunctionCallGraph fgraph);
        
        /**
         * Rewrites the given program or its functions in place,

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
index 923259c..cae2e35 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java
@@ -35,7 +35,7 @@ import org.apache.sysml.parser.LanguageException;
 public class IPAPassApplyStaticHopRewrites extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return InterProceduralAnalysis.APPLY_STATIC_REWRITES;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
index 314f0ed..a552df1 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
@@ -44,7 +44,7 @@ import org.apache.sysml.parser.WhileStatementBlock;
 public class IPAPassFlagFunctionsRecompileOnce extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return InterProceduralAnalysis.FLAG_FUNCTION_RECOMPILE_ONCE;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
index db19e26..61bc895 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
@@ -46,7 +46,7 @@ import org.apache.sysml.parser.StatementBlock;
 public class IPAPassInlineFunctions extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return InterProceduralAnalysis.INLINING_MAX_NUM_OPS > 0;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
index 664ac7b..2e4649d 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java
@@ -48,7 +48,7 @@ import 
org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
 public class IPAPassPropagateReplaceLiterals extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return InterProceduralAnalysis.PROPAGATE_SCALAR_LITERALS;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index d948128..d79379b 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -52,7 +52,7 @@ import org.apache.sysml.parser.Expression.DataType;
 public class IPAPassRemoveConstantBinaryOps extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return InterProceduralAnalysis.REMOVE_CONSTANT_BINARY_OPS;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
index 11bce2f..07d694e 100644
--- 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
+++ 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -50,7 +50,7 @@ import org.apache.sysml.parser.WhileStatementBlock;
 public class IPAPassRemoveUnnecessaryCheckpoints extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
+       public boolean isApplicable(FunctionCallGraph fgraph) {
                return InterProceduralAnalysis.REMOVE_UNNECESSARY_CHECKPOINTS 
                        && OptimizerUtils.isSparkExecutionMode();
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
index 9d41ca6..a4c112c 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
@@ -38,8 +38,9 @@ import org.apache.sysml.parser.LanguageException;
 public class IPAPassRemoveUnusedFunctions extends IPAPass
 {
        @Override
-       public boolean isApplicable() {
-               return InterProceduralAnalysis.REMOVE_UNUSED_FUNCTIONS;
+       public boolean isApplicable(FunctionCallGraph fgraph) {
+               return InterProceduralAnalysis.REMOVE_UNUSED_FUNCTIONS
+                       && !fgraph.containsEvalCall();
        }
        
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/1995f356/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 6a5788e..0a182cc 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -84,12 +84,12 @@ public class InterProceduralAnalysis
 {
        private static final boolean LDEBUG = false; //internal local debug 
level
        private static final Log LOG = 
LogFactory.getLog(InterProceduralAnalysis.class.getName());
-    
+
        //internal configuration parameters
        protected static final boolean INTRA_PROCEDURAL_ANALYSIS      = true; 
//propagate statistics across statement blocks (main/functions)   
        protected static final boolean PROPAGATE_KNOWN_UDF_STATISTICS = true; 
//propagate statistics for known external functions 
        protected static final boolean ALLOW_MULTIPLE_FUNCTION_CALLS  = true; 
//propagate consistent statistics from multiple calls 
-       protected static final boolean REMOVE_UNUSED_FUNCTIONS        = false; 
//remove unused functions (inlined or never called)
+       protected static final boolean REMOVE_UNUSED_FUNCTIONS        = true; 
//remove unused functions (inlined or never called)
        protected static final boolean FLAG_FUNCTION_RECOMPILE_ONCE   = true; 
//flag functions which require recompilation inside a loop for full function 
recompile
        protected static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; 
//remove unnecessary checkpoints (unconditionally overwritten intermediates) 
        protected static final boolean REMOVE_CONSTANT_BINARY_OPS     = true; 
//remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) 
@@ -166,7 +166,7 @@ public class InterProceduralAnalysis
         * @throws HopsException in case of compilation errors
         */
        public void analyzeProgram(int repetitions) 
-               throws HopsException    
+               throws HopsException
        {
                //sanity check for valid number of repetitions
                if( repetitions <= 0 )
@@ -201,7 +201,7 @@ public class InterProceduralAnalysis
                        
                        //step 2: apply additional IPA passes
                        for( IPAPass pass : _passes )
-                               if( pass.isApplicable() )
+                               if( pass.isApplicable(_fgraph) )
                                        pass.rewriteProgram(_prog, _fgraph, 
fcallSizes);
                        
                        //early abort without functions or on reached fixpoint
@@ -217,7 +217,7 @@ public class InterProceduralAnalysis
                //cleanup pass: remove unused functions
                FunctionCallGraph graph2 = new FunctionCallGraph(_prog);
                IPAPass rmFuns = new IPAPassRemoveUnusedFunctions();
-               if( rmFuns.isApplicable() )
+               if( rmFuns.isApplicable(graph2) )
                        rmFuns.rewriteProgram(_prog, graph2, null);
        }
        

Reply via email to