[SYSTEMML-1251] Fix recompile-once recursive functions, cleanup, tests

This patch fixes our inter-procedural analysis to no longer mark
directly or indirectly recursive functions for recompile-once, because
recompile-once can lead to OOMs and even incorrect results in the
context of recursion. Recompile-once functions are recompiled on entry
with the given input statistics in order to avoid repeated recompilation
in loops, which is often unnecessary once the sizes of the function
inputs are known. With recursive functions, however, this breaks down:
a recursive subcall recompiles the plan again, thereby also modifying
the remaining plan of the still-active calling function. If both
recursion levels work on different sizes or compile literals into the
plan, the consequences are disastrous.

Since multiple places such as EXPLAIN and IPA require information about
the function call graph, this patch introduces a clean abstraction, the
FunctionCallGraph, that captures this information once and simplifies
related compiler passes.

Furthermore, this patch fixes our -stats tool to properly reset the
number and time of function recompilations.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0a3d2481
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0a3d2481
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0a3d2481

Branch: refs/heads/master
Commit: 0a3d24815258744169367acd8f0287761b5d20de
Parents: 696d10b
Author: Matthias Boehm <[email protected]>
Authored: Sun Feb 12 06:27:58 2017 +0100
Committer: Matthias Boehm <[email protected]>
Committed: Sun Feb 12 06:27:58 2017 +0100

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallGraph.java       | 247 +++++++++++++++++++
 .../sysml/hops/ipa/InterProceduralAnalysis.java |  38 +--
 .../org/apache/sysml/parser/DMLProgram.java     |   8 +
 .../java/org/apache/sysml/utils/Explain.java    | 131 +++-------
 .../java/org/apache/sysml/utils/Statistics.java |   3 +
 .../RecursiveFunctionRecompileTest.java         | 149 +++++++++++
 .../recompile/recursive_func_direct.dml         |  38 +++
 .../recompile/recursive_func_indirect.dml       |  46 ++++
 .../recompile/recursive_func_indirect2.dml      |  55 +++++
 .../functions/recompile/recursive_func_none.dml |  39 +++
 .../functions/recompile/ZPackageSuite.java      |   1 +
 11 files changed, 645 insertions(+), 110 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
new file mode 100644
index 0000000..eed6531
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Stack;
+
+import org.apache.sysml.hops.FunctionOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.ForStatement;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.FunctionStatement;
+import org.apache.sysml.parser.FunctionStatementBlock;
+import org.apache.sysml.parser.IfStatement;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.WhileStatement;
+import org.apache.sysml.parser.WhileStatementBlock;
+
+public class FunctionCallGraph 
+{
+       //internal function key for main program (underscore 
+       //prevents any conflicts with user-defined functions)
+       private static final String MAIN_FUNCTION_KEY = "_main"; 
+       
+       //unrolled function call graph, in call direction
+       //(mapping from function keys to called function keys)
+       private final HashMap<String, HashSet<String>> _fGraph;
+       
+       //subset of directly or indirectly recursive functions (function keys)
+       private final HashSet<String> _fRecursive;
+       
+       /**
+        * Constructs the function call graph for all functions
+        * reachable from the main program. 
+        * 
+        * @param prog dml program of given script
+        */
+       public FunctionCallGraph(DMLProgram prog) {
+               _fGraph = new HashMap<String, HashSet<String>>();
+               _fRecursive = new HashSet<String>();
+               
+               constructFunctionCallGraph(prog);
+       }
+
+       /**
+        * Returns all functions called from the given function. 
+        * 
+        * @param fnamespace function namespace
+        * @param fname function name
+        * @return called function keys (namespace and name), or null if the function is not in the graph
+        */
+       public Collection<String> getCalledFunctions(String fnamespace, String fname) {
+               return getCalledFunctions(
+                       DMLProgram.constructFunctionKey(fnamespace, fname));
+       }
+       
+       /**
+        * Returns all functions called from the given function. 
+        * 
+        * @param fkey function key of calling function, null indicates the main program
+        * @return called function keys (namespace and name) as the internal set (not a copy), or null if fkey is not in the graph
+        */
+       public Collection<String> getCalledFunctions(String fkey) {
+               String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey;
+               return _fGraph.get(lfkey);
+       }
+       
+       /**
+        * Indicates if the given function is either directly or indirectly recursive.
+        * An example of an indirectly recursive function is foo2 in the following call
+        * chain: foo1 -> foo2 -> foo1 (foo1 itself is marked recursive as well).  
+        * 
+        * @param fnamespace function namespace
+        * @param fname function name
+        * @return true if the given function is recursive, false otherwise
+        */
+       public boolean isRecursiveFunction(String fnamespace, String fname) {
+               return isRecursiveFunction(
+                       DMLProgram.constructFunctionKey(fnamespace, fname));
+       }
+       
+       /**
+        * Indicates if the given function is either directly or indirectly recursive.
+        * An example of an indirectly recursive function is foo2 in the following call
+        * chain: foo1 -> foo2 -> foo1 (foo1 itself is marked recursive as well).  
+        * 
+        * @param fkey function key to probe, null indicates the main program
+        * @return true if the given function is recursive, false otherwise
+        */
+       public boolean isRecursiveFunction(String fkey) {
+               String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey;
+               return _fRecursive.contains(lfkey);
+       }
+       
+       /**
+        * Returns all functions that are reachable either directly or indirectly
+        * from the main program, except the main program itself and the given 
+        * blacklist of function keys.
+        * 
+        * @param blacklist function keys to exclude (must not be null)
+        * @return new set of function keys (namespace and name)
+        */
+       public Collection<String> getReachableFunctions(Collection<String> blacklist) {
+               HashSet<String> ret = new HashSet<String>();
+               for( String tmp : _fGraph.keySet() )
+                       if( !blacklist.contains(tmp) && !MAIN_FUNCTION_KEY.equals(tmp) )
+                               ret.add(tmp);
+               return ret;
+       }
+       
+       /**
+        * Indicates if the given function is reachable either directly or indirectly
+        * from the main program.
+        * 
+        * @param fnamespace function namespace
+        * @param fname function name
+        * @return true if the given function is reachable, false otherwise
+        */
+       public boolean isReachableFunction(String fnamespace, String fname) {
+               return isReachableFunction(
+                       DMLProgram.constructFunctionKey(fnamespace, fname));
+       }
+       
+       /**
+        * Indicates if the given function is reachable either directly or indirectly
+        * from the main program.
+        * 
+        * @param fkey function key to probe, null indicates the main program
+        * @return true if the given function is reachable, false otherwise
+        */
+       public boolean isReachableFunction(String fkey) {
+               String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey;
+               return _fGraph.containsKey(lfkey);              
+       }
+       
+       private void constructFunctionCallGraph(DMLProgram prog) { //fills _fGraph and _fRecursive, starting at main
+               if( !prog.hasFunctionStatementBlocks() )
+                       return; //early abort if prog without functions (graph stays empty, incl. main)
+                       
+               try {
+                       Stack<String> fstack = new Stack<String>(); //current call chain
+                       HashSet<String> lfset = new HashSet<String>(); //callees already visited from main
+                       _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
+                       for( StatementBlock sblk : prog.getStatementBlocks() )
+                               rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset);
+               }
+               catch(HopsException ex) {
+                       throw new RuntimeException(ex);
+               }
+       }
+       
+       private void rConstructFunctionCallGraph(String fkey, StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) 
+               throws HopsException //fkey: calling function, fstack: current call chain, lfset: callees visited in this call context
+       {
+               if (sb instanceof WhileStatementBlock) {
+                       WhileStatement ws = (WhileStatement)sb.getStatement(0);
+                       for (StatementBlock current : ws.getBody())
+                               rConstructFunctionCallGraph(fkey, current, fstack, lfset);
+               } 
+               else if (sb instanceof IfStatementBlock) {
+                       IfStatement ifs = (IfStatement) sb.getStatement(0);
+                       for (StatementBlock current : ifs.getIfBody())
+                               rConstructFunctionCallGraph(fkey, current, fstack, lfset);
+                       for (StatementBlock current : ifs.getElseBody())
+                               rConstructFunctionCallGraph(fkey, current, fstack, lfset);
+               } 
+               else if (sb instanceof ForStatementBlock) {
+                       ForStatement fs = (ForStatement)sb.getStatement(0);
+                       for (StatementBlock current : fs.getBody())
+                               rConstructFunctionCallGraph(fkey, current, fstack, lfset);
+               } 
+               else if (sb instanceof FunctionStatementBlock) {
+                       FunctionStatement fsb = (FunctionStatement) sb.getStatement(0);
+                       for (StatementBlock current : fsb.getBody())
+                               rConstructFunctionCallGraph(fkey, current, fstack, lfset);
+               } 
+               else {
+                       // For generic StatementBlock
+                       ArrayList<Hop> hopsDAG = sb.get_hops();
+                       if( hopsDAG == null || hopsDAG.isEmpty() ) 
+                               return; //nothing to do
+                       
+                       //function ops can only occur as root nodes of the dag
+                       for( Hop h : hopsDAG ) {
+                               if( h instanceof FunctionOp ){
+                                       FunctionOp fop = (FunctionOp) h;
+                                       String lfkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName());
+                                       //prevent redundant call edges and skip internal-namespace functions
+                                       if( lfset.contains(lfkey) || fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) )
+                                               continue;
+                                               
+                                       if( !_fGraph.containsKey(lfkey) )
+                                               _fGraph.put(lfkey, new HashSet<String>());
+                                               
+                                       //not yet on the call chain: recursively construct function call graph of lfkey
+                                       if( !fstack.contains(lfkey) ) {
+                                               fstack.push(lfkey);
+                                               _fGraph.get(fkey).add(lfkey);
+                                               
+                                               FunctionStatementBlock fsb = sb.getDMLProg()
+                                                               .getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
+                                               FunctionStatement fs = (FunctionStatement) fsb.getStatement(0);
+                                               for( StatementBlock csb : fs.getBody() )
+                                                       rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>());
+                                               fstack.pop();
+                                       }
+                                       //recursive function call (lfkey is already on the current call chain)
+                                       else {
+                                               _fGraph.get(fkey).add(lfkey);
+                                               _fRecursive.add(lfkey);
+                                       
+                                               //mark all functions pushed after lfkey on the call chain as (indirectly) recursive
+                                               int ix = fstack.indexOf(lfkey);
+                                               for( int i=ix+1; i<fstack.size(); i++ )
+                                                       _fRecursive.add(fstack.get(i));
+                                       }
+                                       
+                                       //mark as visited for current function call context
+                                       lfset.add( lfkey );
+                               }
+                       }
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 5d94a68..caab391 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -29,7 +29,6 @@ import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 
-import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
@@ -157,26 +156,25 @@ public class InterProceduralAnalysis
         * @throws ParseException if ParseException occurs
         * @throws LanguageException if LanguageException occurs
         */
-       @SuppressWarnings("unchecked")
        public void analyzeProgram( DMLProgram dmlp ) 
                throws HopsException, ParseException, LanguageException
        {
+               FunctionCallGraph fgraph = new FunctionCallGraph(dmlp);
+               
                //step 1: get candidates for statistics propagation into 
functions (if required)
                Map<String, Integer> fcandCounts = new HashMap<String, 
Integer>();
                Map<String, FunctionOp> fcandHops = new HashMap<String, 
FunctionOp>();
                Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, 
Set<Long>>(); 
-               Set<String> allFCandKeys = new HashSet<String>();
                if( !dmlp.getFunctionStatementBlocks().isEmpty() ) {
                        for ( StatementBlock sb : dmlp.getStatementBlocks() ) 
//get candidates (over entire program)
                                getFunctionCandidatesForStatisticPropagation( 
sb, fcandCounts, fcandHops );
-                       allFCandKeys.addAll(fcandCounts.keySet()); //cp before 
pruning
                        pruneFunctionCandidatesForStatisticPropagation( 
fcandCounts, fcandHops );       
                        determineFunctionCandidatesNNZPropagation( fcandHops, 
fcandSafeNNZ );
                        DMLTranslator.resetHopsDAGVisitStatus( dmlp );
                }
                
                //step 2: get unary dimension-preserving non-candidate functions
-               Collection<String> unaryFcandTmp = 
CollectionUtils.subtract(allFCandKeys, fcandCounts.keySet());
+               Collection<String> unaryFcandTmp = 
fgraph.getReachableFunctions(fcandCounts.keySet());
                HashSet<String> unaryFcands = new HashSet<String>();
                if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) {
                        for( String tmp : unaryFcandTmp )
@@ -194,12 +192,12 @@ public class InterProceduralAnalysis
                
                //step 4: remove unused functions (e.g., inlined or never 
called)
                if( REMOVE_UNUSED_FUNCTIONS ) {
-                       removeUnusedFunctions( dmlp, allFCandKeys );
+                       removeUnusedFunctions( dmlp, fgraph );
                }
                
                //step 5: flag functions with loops for 'recompile-on-entry'
                if( FLAG_FUNCTION_RECOMPILE_ONCE ) {
-                       flagFunctionsForRecompileOnce( dmlp );
+                       flagFunctionsForRecompileOnce( dmlp, fgraph );
                }
                
                //step 6: set global data flow properties
@@ -871,22 +869,21 @@ public class InterProceduralAnalysis
        // REMOVE UNUSED FUNCTIONS
        //////
 
-       public void removeUnusedFunctions( DMLProgram dmlp, Set<String> fcandKeys )
+       public void removeUnusedFunctions( DMLProgram dmlp, FunctionCallGraph fgraph )
                throws LanguageException
        {
                Set<String> fnamespaces = dmlp.getNamespaces().keySet();
-               for( String fnspace : fnamespaces  )
-               {
+               for( String fnspace : fnamespaces  ) {
                        HashMap<String, FunctionStatementBlock> fsbs = dmlp.getFunctionStatementBlocks(fnspace);
                        Iterator<Entry<String, FunctionStatementBlock>> iter = fsbs.entrySet().iterator();
-                       while( iter.hasNext() )
-                       {
+                       while( iter.hasNext() ) {
                                Entry<String, FunctionStatementBlock> e = iter.next();
-                               String fname = e.getKey();
-                               String fKey = DMLProgram.constructFunctionKey(fnspace, fname);
-                               //probe function candidates, remove if no candidate
-                               if( !fcandKeys.contains(fKey) )
+                               if( !fgraph.isReachableFunction(fnspace, e.getKey()) ) { //not reachable from the main program
                                        iter.remove();
+                                       if( LOG.isDebugEnabled() ) //avoid building the message unless debug logging is on
+                                               LOG.debug("IPA: Removed unused function: " + 
+                                                       DMLProgram.constructFunctionKey(fnspace, e.getKey()));
+                               }
                        }
                }
        }
@@ -902,17 +899,20 @@ public class InterProceduralAnalysis
         * @param dmlp the DML program
         * @throws LanguageException if LanguageException occurs
         */
-       public void flagFunctionsForRecompileOnce( DMLProgram dmlp ) 
+       public void flagFunctionsForRecompileOnce( DMLProgram dmlp, FunctionCallGraph fgraph ) 
                throws LanguageException
        {
                for (String namespaceKey : dmlp.getNamespaces().keySet())
                        for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet())
                        {
                                FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname);
-                               if( rFlagFunctionForRecompileOnce( fsblock, false ) ) 
+                               if( !fgraph.isRecursiveFunction(namespaceKey, fname) && //never flag (in)directly recursive functions
+                                       rFlagFunctionForRecompileOnce( fsblock, false ) ) 
                                {
                                        fsblock.setRecompileOnce( true ); 
-                                       LOG.debug("IPA: FUNC flagged for recompile-once: " + DMLProgram.constructFunctionKey(namespaceKey, fname));
+                                       if( LOG.isDebugEnabled() ) //avoid building the message unless debug logging is on
+                                               LOG.debug("IPA: FUNC flagged for recompile-once: " + 
+                                                       DMLProgram.constructFunctionKey(namespaceKey, fname));
                                }
                        }
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/parser/DMLProgram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java b/src/main/java/org/apache/sysml/parser/DMLProgram.java
index 292a13c..9fbd63c 100644
--- a/src/main/java/org/apache/sysml/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java
@@ -107,6 +107,14 @@ public class DMLProgram
                return namespaceProgram._functionBlocks;
        }
        
+       public boolean hasFunctionStatementBlocks() {
+               boolean found = false; //becomes true once any namespace defines a function
+               for( DMLProgram ns : _namespaces.values() )
+                       if( !ns._functionBlocks.isEmpty() )
+                               found = true;
+               return found;
+       }
+       
        public ArrayList<FunctionStatementBlock> getFunctionStatementBlocks() 
                throws LanguageException
        {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 097e2b7..ccc9853 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -21,12 +21,12 @@ package org.apache.sysml.utils;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Map.Entry;
 
 import org.apache.sysml.api.DMLException;
-import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.VisitStatus;
 import org.apache.sysml.hops.HopsException;
@@ -35,6 +35,7 @@ import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.globalopt.gdfgraph.GDFLoopNode;
 import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode;
 import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode.NodeType;
+import org.apache.sysml.hops.ipa.FunctionCallGraph;
 import org.apache.sysml.lops.Lop;
 import org.apache.sysml.parser.DMLProgram;
 import org.apache.sysml.parser.ForStatement;
@@ -206,32 +207,28 @@ public class Explain
                sb.append("\nPROGRAM\n");
                                                
                // Explain functions (if exists)
-               boolean firstFunction = true;
-               for (String namespace : prog.getNamespaces().keySet()) {
-                       for (String fname : 
prog.getFunctionStatementBlocks(namespace).keySet()) {
-                               if (firstFunction) {
-                                       sb.append("--FUNCTIONS\n");
-                                       firstFunction = false;
+               if( prog.hasFunctionStatementBlocks() ) {
+                       sb.append("--FUNCTIONS\n");
+                       
+                       //show function call graph
+                       sb.append("----FUNCTION CALL GRAPH\n");
+                       sb.append("------MAIN PROGRAM\n");
+                       FunctionCallGraph fgraph = new FunctionCallGraph(prog);
+                       sb.append(explainFunctionCallGraph(fgraph, new 
HashSet<String>(), null, 3));
+               
+                       //show individual functions
+                       for (String namespace : prog.getNamespaces().keySet()) {
+                               for (String fname : 
prog.getFunctionStatementBlocks(namespace).keySet()) {
+                                       FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(namespace, fname);
+                                       FunctionStatement fstmt = 
(FunctionStatement) fsb.getStatement(0);
                                        
-                                       //show function call dag
-                                       sb.append("----FUNCTION CALL DAG\n");
-                                       sb.append("------MAIN PROGRAM\n");
-                                       HashSet<String> fstack = new 
HashSet<String>();
-                                       HashSet<String> lfset = new 
HashSet<String>();
-                                       for( StatementBlock sblk : 
prog.getStatementBlocks() )
-                                               
sb.append(explainFunctionCallDag(sblk, fstack, lfset, 3));
-                               }
-                               
-                               //show individual functions
-                               FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(namespace, fname);
-                               FunctionStatement fstmt = (FunctionStatement) 
fsb.getStatement(0);
-                               
-                               if (fstmt instanceof ExternalFunctionStatement)
-                                       sb.append("----EXTERNAL FUNCTION " + 
namespace + "::" + fname + "\n");
-                               else {
-                                       sb.append("----FUNCTION " + namespace + 
"::" + fname + " [recompile="+fsb.isRecompileOnce()+"]\n");
-                                       for (StatementBlock current : 
fstmt.getBody())
-                                               
sb.append(explainStatementBlock(current, 3));
+                                       if (fstmt instanceof 
ExternalFunctionStatement)
+                                               sb.append("----EXTERNAL 
FUNCTION " + namespace + "::" + fname + "\n");
+                                       else {
+                                               sb.append("----FUNCTION " + 
namespace + "::" + fname + " [recompile="+fsb.isRecompileOnce()+"]\n");
+                                               for (StatementBlock current : 
fstmt.getBody())
+                                                       
sb.append(explainStatementBlock(current, 3));
+                                       }
                                }
                        }
                }
@@ -267,17 +264,15 @@ public class Explain
                {
                        sb.append("--FUNCTIONS\n");
                        
-                       //show function call dag
+                       //show function call graph
                        if( !rtprog.getProgramBlocks().isEmpty() &&
                                
rtprog.getProgramBlocks().get(0).getStatementBlock() != null )
                        {
-                               sb.append("----FUNCTION CALL DAG\n");
+                               sb.append("----FUNCTION CALL GRAPH\n");
                                sb.append("------MAIN PROGRAM\n");
                                DMLProgram prog = 
rtprog.getProgramBlocks().get(0).getStatementBlock().getDMLProg();
-                               HashSet<String> fstack = new HashSet<String>();
-                               HashSet<String> lfset = new HashSet<String>();
-                               for( StatementBlock sblk : 
prog.getStatementBlocks() )
-                                       sb.append(explainFunctionCallDag(sblk, 
fstack, lfset, 3));
+                               FunctionCallGraph fgraph = new 
FunctionCallGraph(prog);
+                               sb.append(explainFunctionCallGraph(fgraph, new 
HashSet<String>(), null, 3));
                        }
                        
                        //show individual functions
@@ -932,68 +927,22 @@ public class Explain
                return ret;
        }
 
-       private static String explainFunctionCallDag(StatementBlock sb, 
HashSet<String> fstack, HashSet<String> lfset, int level) 
+       private static String explainFunctionCallGraph(FunctionCallGraph 
fgraph, HashSet<String> fstack, String fkey, int level) 
                throws HopsException 
        {
                StringBuilder builder = new StringBuilder();
-               
-               if (sb instanceof WhileStatementBlock) {
-                       WhileStatement ws = (WhileStatement)sb.getStatement(0);
-                       for (StatementBlock current : ws.getBody())
-                               builder.append(explainFunctionCallDag(current, 
fstack, lfset, level));
-               } 
-               else if (sb instanceof IfStatementBlock) {
-                       IfStatement ifs = (IfStatement) sb.getStatement(0);
-                       for (StatementBlock current : ifs.getIfBody())
-                               builder.append(explainFunctionCallDag(current, 
fstack, lfset, level));
-                       for (StatementBlock current : ifs.getElseBody())
-                               builder.append(explainFunctionCallDag(current, 
fstack, lfset, level));
-               } 
-               else if (sb instanceof ForStatementBlock) {
-                       ForStatement fs = (ForStatement)sb.getStatement(0);
-                       for (StatementBlock current : fs.getBody())
-                               builder.append(explainFunctionCallDag(current, 
fstack, lfset, level));
-               } 
-               else if (sb instanceof FunctionStatementBlock) {
-                       FunctionStatement fsb = (FunctionStatement) 
sb.getStatement(0);
-                       for (StatementBlock current : fsb.getBody())
-                               builder.append(explainFunctionCallDag(current, 
fstack, lfset, level));
-               } 
-               else {
-                       // For generic StatementBlock
-                       ArrayList<Hop> hopsDAG = sb.get_hops();
-                       if( hopsDAG != null && !hopsDAG.isEmpty() ) {
-                               //function ops can only occur as root nodes of 
the dag
-                               for( Hop h : hopsDAG )
-                                       if( h instanceof FunctionOp ){
-                                               FunctionOp fop = (FunctionOp) h;
-                                               String fkey = 
DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), 
fop.getFunctionName());
-                                               //prevent redundant call edges
-                                               if( !lfset.contains(fkey) && 
!fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) )
-                                               {
-                                                       //recursively explain 
function call dag
-                                                       if( 
!fstack.contains(fkey) ) {
-                                                               
fstack.add(fkey);
-                                                               String offset = 
createOffset(level);
-                                                               
builder.append(offset + "--" + fkey + "\n");
-                                                               
FunctionStatementBlock fsb = sb.getDMLProg()
-                                                                               
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
-                                                               
FunctionStatement fs = (FunctionStatement) fsb.getStatement(0);
-                                                               HashSet<String> 
lfset2 = new HashSet<String>(); 
-                                                               for( 
StatementBlock csb : fs.getBody() )
-                                                                       
builder.append(explainFunctionCallDag(csb, fstack, lfset2, level+1));
-                                                               
fstack.remove(fkey);
-                                                       }
-                                                       //recursive function 
call
-                                                       else {
-                                                               String offset = 
createOffset(level);
-                                                               
builder.append(offset + "-->" + fkey + " (recursive)\n");
-                                                       }
-                                                       
-                                                       //mark as visited for 
current function call context
-                                                       lfset.add( fkey );
-                                               }
-                                       }
+               String offset = createOffset(level);
+               Collection<String> cfkeys = fgraph.getCalledFunctions(fkey);
+               if( cfkeys != null ) {
+                       for( String cfkey : cfkeys ) {
+                               if( fstack.contains(cfkey) && 
fgraph.isRecursiveFunction(cfkey) )
+                                       builder.append(offset + "--" + cfkey + 
" (recursive)\n");
+                               else {
+                                       fstack.add(cfkey);
+                                       builder.append(offset + "--" + cfkey + 
"\n");
+                                       
builder.append(explainFunctionCallGraph(fgraph, fstack, cfkey, level+1));
+                                       fstack.remove(cfkey);
+                               }
                        }
                }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index 87be90b..cf9b5fb 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -345,6 +345,9 @@ public class Statistics
                hopRecompilePred.set(0);
                hopRecompileSB.set(0);
                
+               funRecompiles.set(0);
+               funRecompileTime.set(0);
+               
                parforOptCount = 0;
                parforOptTime = 0;
                parforInitTime = 0;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java
new file mode 100644
index 0000000..99459b0
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.recompile;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.runtime.util.MapReduceTool;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * This test ensures that recursive functions are not marked for recompile-once
+ * during IPA because this could potentially lead to incorrect plans that cause
+ * OOMs or even incorrect results. 
+ * Covers direct, two-function and three-function recursion plus a non-recursive baseline, each with and without IPA.
+ */
+public class RecursiveFunctionRecompileTest extends AutomatedTestBase 
+{
+       private final static String TEST_DIR = "functions/recompile/";
+       private final static String TEST_NAME1 = "recursive_func_direct"; //foo1 calls itself
+       private final static String TEST_NAME2 = "recursive_func_indirect"; //foo1 -> foo2 -> foo1
+       private final static String TEST_NAME3 = "recursive_func_indirect2"; //foo1 -> foo2 -> foo3 -> foo1
+       private final static String TEST_NAME4 = "recursive_func_none"; //no recursion (baseline)
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
+               RecursiveFunctionRecompileTest.class.getSimpleName() + "/";
+       
+       private final static long rows = 5000;
+       private final static long cols = 10000;    //10000 cols => 1+10+100+1000 = 1111 foo1 calls in the recursive scripts
+       private final static double sparsity = 0.00001d;    
+       private final static double val = 7.0; //passed as $2, but not referenced by the current scripts
+       
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, 
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new 
String[] { "Rout" }) );
+               addTestConfiguration(TEST_NAME2, 
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new 
String[] { "Rout" }) );
+               addTestConfiguration(TEST_NAME3, 
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new 
String[] { "Rout" }) );
+               addTestConfiguration(TEST_NAME4, 
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new 
String[] { "Rout" }) );
+       }
+
+       @Test
+       public void testDirectRecursionRecompileIPA() {
+               runRecompileTest(TEST_NAME1, true);
+       }
+       
+       @Test
+       public void testIndirectRecursionRecompileIPA() {
+               runRecompileTest(TEST_NAME2, true);
+       }
+       
+       @Test
+       public void testIndirect2RecursionRecompileIPA() {
+               runRecompileTest(TEST_NAME3, true);
+       }
+       
+       @Test
+       public void testNoRecursionRecompileIPA() {
+               runRecompileTest(TEST_NAME4, true);
+       }
+       
+       @Test
+       public void testDirectRecursionRecompileNoIPA() {
+               runRecompileTest(TEST_NAME1, false);
+       }
+       
+       @Test
+       public void testIndirectRecursionRecompileNoIPA() {
+               runRecompileTest(TEST_NAME2, false);
+       }
+       
+       @Test
+       public void testIndirect2RecursionRecompileNoIPA() {
+               runRecompileTest(TEST_NAME3, false);
+       }
+       
+       @Test
+       public void testNoRecursionRecompileNoIPA() {
+               runRecompileTest(TEST_NAME4, false);
+       }
+       
+       private void runRecompileTest( String testname, boolean IPA ) //run <testname>.dml with IPA on/off and check the number of function recompilations
+       {       
+               boolean oldFlagIPA = 
OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;
+               
+               try
+               {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{"-explain","-stats","-args",
+                               input("V"), Double.toString(val), output("R") }; //$1, $2, $3 in the DML scripts
+
+                       OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;
+                       
+                       //generate sparse input data
+                       MatrixBlock mb = MatrixBlock.randOperations((int)rows, 
(int)cols, sparsity, 0, 1, "uniform", 732);
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(rows,cols,OptimizerUtils.DEFAULT_BLOCKSIZE,OptimizerUtils.DEFAULT_BLOCKSIZE,(long)(rows*cols*sparsity));
+                       DataConverter.writeMatrixToHDFS(mb, input("V"), 
OutputInfo.TextCellOutputInfo, mc);
+                       MapReduceTool.writeMetaDataFile(input("V.mtd"), 
ValueType.DOUBLE, mc, OutputInfo.TextCellOutputInfo);
+                       
+                       //run test
+                       runTest(true, false, null, -1); 
+                       
+                       //check number of recompiled functions (recompile-once is not applicable to recursive functions
+                       //because the single recompilation on entry would implicitly change the remaining plan of the caller);
+                       //if not handled correctly, TEST_NAME1 and TEST_NAME2 would show 1111 function recompilations with IPA.
+                       Assert.assertEquals(testname.equals(TEST_NAME4) && IPA 
? 1 : 0, Statistics.getFunRecompiles());
+               }
+               catch(Exception ex) { //NOTE(review): failing with only getMessage() drops the cause from the JUnit report
+                       ex.printStackTrace();
+                       Assert.fail("Failed to run test: "+ex.getMessage());
+               }
+               finally {
+                       OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = 
oldFlagIPA;
+               }
+       }       
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_direct.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_direct.dml 
b/src/test/scripts/functions/recompile/recursive_func_direct.dml
new file mode 100644
index 0000000..368bbc7
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_direct.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo1 = function (Matrix[Double] X)
+    return (Matrix[Double] Y)
+{  # directly recursive: foo1 calls itself on each of 10 column slices of X
+   print(ncol(X)+" cols -> "+sum(X)); 
+   for( i in 1:10 ) {
+      batch = ncol(X)/10;  # width of each column slice
+      tmp = X[,((i-1)*batch+1):(i*batch)];      
+      if( batch > 1 )  # recurse only while slices are wider than one column
+         tmp = foo1(tmp);  # recursive call; tmp is not used after the loop
+   }
+   Y = X*2;  # output depends only on X, not on the recursive calls
+}
+
+V = read($1);  # $1: path of the input matrix
+V = foo1(V); # for a 10000-column input: 1+10+100+1000 = 1111 calls of foo1
+print(sum(V));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_indirect.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_indirect.dml 
b/src/test/scripts/functions/recompile/recursive_func_indirect.dml
new file mode 100644
index 0000000..0e7b023
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_indirect.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo1 = function (Matrix[Double] X)
+    return (Matrix[Double] Y)
+{  # indirectly recursive through foo2 (foo1 -> foo2 -> foo1)
+   print(ncol(X)+" cols -> "+sum(X)); 
+   for( i in 1:10 ) {
+      batch = ncol(X)/10;  # width of each column slice
+      tmp = X[,((i-1)*batch+1):(i*batch)];      
+      tmp = foo2(tmp, batch);  # foo2 calls back into foo1 while batch > 1
+   }
+   Y = X*2;  # output depends only on X, not on the nested calls
+}
+
+foo2 = function (Matrix[Double] X, Integer batch)
+    return (Matrix[Double] Y)
+{  # closes the foo1 <-> foo2 cycle
+   if( batch > 1 )
+      Y = foo1(X);  # back-edge of the recursion
+   else
+      Y = X;   # base case: return the slice unchanged
+}
+
+V = read($1);  # $1: path of the input matrix
+V = foo1(V); # enters the foo1 <-> foo2 cycle through foo1
+print(sum(V));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_indirect2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_indirect2.dml 
b/src/test/scripts/functions/recompile/recursive_func_indirect2.dml
new file mode 100644
index 0000000..f30f580
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_indirect2.dml
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo1 = function (Matrix[Double] X)
+    return (Matrix[Double] Y)
+{  # part of the 3-cycle foo1 -> foo2 -> foo3 -> foo1
+   print(ncol(X)+" cols -> "+sum(X)); 
+   for( i in 1:10 ) {
+      batch = ncol(X)/10;  # width of each column slice
+      tmp = X[,((i-1)*batch+1):(i*batch)];      
+      tmp = foo2(tmp, batch);  # foo2 -> foo3 -> foo1 while batch > 1
+   }
+   Y = X*2;  # output depends only on X, not on the nested calls
+}
+
+foo2 = function (Matrix[Double] X, Integer batch)
+    return (Matrix[Double] Y)
+{  # middle link of the cycle: calls foo3 rather than foo1
+   if( batch > 1 )
+      Y = foo3(X);  # continues the cycle via foo3
+   else
+      Y = X;   # base case: return the slice unchanged
+}
+
+foo3 = function (Matrix[Double] X)
+    return (Matrix[Double] Y)
+{  # closes the cycle foo1 -> foo2 -> foo3 -> foo1
+   if( sum(X) >= 0 )  # data-dependent guard; always true for non-negative X
+      Y = foo1(X);
+   else
+      Y = X;      
+}
+
+V = read($1);  # $1: path of the input matrix
+V = foo3(V); # enters the cycle through foo3 instead of foo1
+print(sum(V));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_none.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_none.dml 
b/src/test/scripts/functions/recompile/recursive_func_none.dml
new file mode 100644
index 0000000..7548017
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_none.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo1 = function (Matrix[Double] X)
+    return (Matrix[Double] Y)
+{  # non-recursive baseline: same slicing loop, but no function calls
+   print(ncol(X)+" cols -> "+sum(X)); 
+   for( i in 1:10 ) {
+      batch = ncol(X)/10;  # width of each column slice
+      tmp = X[,((i-1)*batch+1):(i*batch)];      
+      if( sum(tmp)<0 )  # never true for non-negative input
+         print("Sum is negative");
+   }
+   Y = X*2;  
+}
+
+V = read($1);  # $1: path of the input matrix
+V = foo1(V); # single, non-recursive call
+print(sum(V));
+   
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
index c15340a..da37e50 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
@@ -40,6 +40,7 @@ import org.junit.runners.Suite;
        RandRecompileTest.class,
        RandSizeExpressionEvalTest.class,
        ReblockRecompileTest.class,
+       RecursiveFunctionRecompileTest.class,
        RemoveEmptyPotpourriTest.class,
        RemoveEmptyRecompileTest.class,
        RewriteComplexMapMultChainTest.class,


Reply via email to