[SYSTEMML-1251] Fix recompile-once recursive functions, cleanup, tests This patch fixes our inter-procedural analysis (IPA) so that it no longer marks directly or indirectly recursive functions for recompile-once, because recompile-once can lead to OOMs and even incorrect results in the context of recursion. Recompile-once functions are recompiled a single time on entry with the given input statistics in order to avoid repeated recompilation inside loops, which is often unnecessary once the sizes of the function inputs are known. With recursive functions, however, this scheme breaks down: a recursive subcall recompiles the function's plan again and thereby also modifies the remaining plan of the calling function. If the two recursion levels operate on inputs of different sizes or compile literals into the plan, the caller continues executing a plan that was compiled for the subcall's inputs, with disastrous consequences.
Since multiple places such as EXPLAIN and IPA require information about the function call graph, this patch introduces a clean abstraction, the FunctionCallGraph, that captures this information once and simplifies related compiler passes. Furthermore, this patch also includes a fix for our -stats tool to properly reset the number and time of function recompilations. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0a3d2481 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0a3d2481 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0a3d2481 Branch: refs/heads/master Commit: 0a3d24815258744169367acd8f0287761b5d20de Parents: 696d10b Author: Matthias Boehm <[email protected]> Authored: Sun Feb 12 06:27:58 2017 +0100 Committer: Matthias Boehm <[email protected]> Committed: Sun Feb 12 06:27:58 2017 +0100 ---------------------------------------------------------------------- .../sysml/hops/ipa/FunctionCallGraph.java | 247 +++++++++++++++++++ .../sysml/hops/ipa/InterProceduralAnalysis.java | 38 +-- .../org/apache/sysml/parser/DMLProgram.java | 8 + .../java/org/apache/sysml/utils/Explain.java | 131 +++------- .../java/org/apache/sysml/utils/Statistics.java | 3 + .../RecursiveFunctionRecompileTest.java | 149 +++++++++++ .../recompile/recursive_func_direct.dml | 38 +++ .../recompile/recursive_func_indirect.dml | 46 ++++ .../recompile/recursive_func_indirect2.dml | 55 +++++ .../functions/recompile/recursive_func_none.dml | 39 +++ .../functions/recompile/ZPackageSuite.java | 1 + 11 files changed, 645 insertions(+), 110 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java new file mode 100644 index 0000000..eed6531 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.ipa; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Stack; + +import org.apache.sysml.hops.FunctionOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.ForStatement; +import org.apache.sysml.parser.ForStatementBlock; +import org.apache.sysml.parser.FunctionStatement; +import org.apache.sysml.parser.FunctionStatementBlock; +import org.apache.sysml.parser.IfStatement; +import org.apache.sysml.parser.IfStatementBlock; +import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.WhileStatement; +import org.apache.sysml.parser.WhileStatementBlock; + +public class FunctionCallGraph +{ + //internal function key for main program (underscore + //prevents any conflicts with user-defined functions) + private static final String MAIN_FUNCTION_KEY = "_main"; + + //unrolled function call graph, in call direction + //(mapping from function keys to called function keys) + private final HashMap<String, HashSet<String>> _fGraph; + + //subset of direct or indirect recursive functions + private final HashSet<String> _fRecursive; + + /** + * Constructs the function call graph for all functions + * reachable from the main program. + * + * @param prog dml program of given script + */ + public FunctionCallGraph(DMLProgram prog) { + _fGraph = new HashMap<String, HashSet<String>>(); + _fRecursive = new HashSet<String>(); + + constructFunctionCallGraph(prog); + } + + /** + * Returns all functions called from the given function. 
+ * + * @param fnamespace function namespace + * @param fname function name + * @return list of function keys (namespace and name) + */ + public Collection<String> getCalledFunctions(String fnamespace, String fname) { + return getCalledFunctions( + DMLProgram.constructFunctionKey(fnamespace, fname)); + } + + /** + * Returns all functions called from the given function. + * + * @param fkey function key of calling function, null indicates the main program + * @return list of function keys (namespace and name) + */ + public Collection<String> getCalledFunctions(String fkey) { + String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey; + return _fGraph.get(lfkey); + } + + /** + * Indicates if the given function is either directly or indirectly recursive. + * An example of an indirect recursive function is foo2 in the following call + * chain: foo1 -> foo2 -> foo1. + * + * @param fnamespace function namespace + * @param fname function name + * @return true if the given function is recursive, false otherwise + */ + public boolean isRecursiveFunction(String fnamespace, String fname) { + return isRecursiveFunction( + DMLProgram.constructFunctionKey(fnamespace, fname)); + } + + /** + * Indicates if the given function is either directly or indirectly recursive. + * An example of an indirect recursive function is foo2 in the following call + * chain: foo1 -> foo2 -> foo1. + * + * @param fkey function key of calling function, null indicates the main program + * @return true if the given function is recursive, false otherwise + */ + public boolean isRecursiveFunction(String fkey) { + String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey; + return _fRecursive.contains(lfkey); + } + + /** + * Returns all functions that are reachable either directly or indirectly + * form the main program, except the main program itself and the given + * blacklist of function names. 
+ * + * @param blacklist list of function keys to exclude + * @return list of function keys (namespace and name) + */ + public Collection<String> getReachableFunctions(Collection<String> blacklist) { + HashSet<String> ret = new HashSet<String>(); + for( String tmp : _fGraph.keySet() ) + if( !blacklist.contains(tmp) && !MAIN_FUNCTION_KEY.equals(tmp) ) + ret.add(tmp); + return ret; + } + + /** + * Indicates if the given function is reachable either directly or indirectly + * from the main program. + * + * @param fnamespace function namespace + * @param fname function name + * @return true if the given function is reachable, false otherwise + */ + public boolean isReachableFunction(String fnamespace, String fname) { + return isReachableFunction( + DMLProgram.constructFunctionKey(fnamespace, fname)); + } + + /** + * Indicates if the given function is reachable either directly or indirectly + * from the main program. + * + * @param fkey function key of calling function, null indicates the main program + * @return true if the given function is reachable, false otherwise + */ + public boolean isReachableFunction(String fkey) { + String lfkey = (fkey == null) ? 
MAIN_FUNCTION_KEY : fkey; + return _fGraph.containsKey(lfkey); + } + + private void constructFunctionCallGraph(DMLProgram prog) { + if( !prog.hasFunctionStatementBlocks() ) + return; //early abort if prog without functions + + try { + Stack<String> fstack = new Stack<String>(); + HashSet<String> lfset = new HashSet<String>(); + _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>()); + for( StatementBlock sblk : prog.getStatementBlocks() ) + rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset); + } + catch(HopsException ex) { + throw new RuntimeException(ex); + } + } + + private void rConstructFunctionCallGraph(String fkey, StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) + throws HopsException + { + if (sb instanceof WhileStatementBlock) { + WhileStatement ws = (WhileStatement)sb.getStatement(0); + for (StatementBlock current : ws.getBody()) + rConstructFunctionCallGraph(fkey, current, fstack, lfset); + } + else if (sb instanceof IfStatementBlock) { + IfStatement ifs = (IfStatement) sb.getStatement(0); + for (StatementBlock current : ifs.getIfBody()) + rConstructFunctionCallGraph(fkey, current, fstack, lfset); + for (StatementBlock current : ifs.getElseBody()) + rConstructFunctionCallGraph(fkey, current, fstack, lfset); + } + else if (sb instanceof ForStatementBlock) { + ForStatement fs = (ForStatement)sb.getStatement(0); + for (StatementBlock current : fs.getBody()) + rConstructFunctionCallGraph(fkey, current, fstack, lfset); + } + else if (sb instanceof FunctionStatementBlock) { + FunctionStatement fsb = (FunctionStatement) sb.getStatement(0); + for (StatementBlock current : fsb.getBody()) + rConstructFunctionCallGraph(fkey, current, fstack, lfset); + } + else { + // For generic StatementBlock + ArrayList<Hop> hopsDAG = sb.get_hops(); + if( hopsDAG == null || hopsDAG.isEmpty() ) + return; //nothing to do + + //function ops can only occur as root nodes of the dag + for( Hop h : hopsDAG ) { + if( h instanceof FunctionOp ){ + FunctionOp 
fop = (FunctionOp) h; + String lfkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + //prevent redundant call edges + if( lfset.contains(lfkey) || fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) ) + continue; + + if( !_fGraph.containsKey(lfkey) ) + _fGraph.put(lfkey, new HashSet<String>()); + + //recursively construct function call dag + if( !fstack.contains(lfkey) ) { + fstack.push(lfkey); + _fGraph.get(fkey).add(lfkey); + + FunctionStatementBlock fsb = sb.getDMLProg() + .getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); + FunctionStatement fs = (FunctionStatement) fsb.getStatement(0); + for( StatementBlock csb : fs.getBody() ) + rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>()); + fstack.pop(); + } + //recursive function call + else { + _fGraph.get(fkey).add(lfkey); + _fRecursive.add(lfkey); + + //mark indirectly recursive functions as recursive + int ix = fstack.indexOf(lfkey); + for( int i=ix+1; i<fstack.size(); i++ ) + _fRecursive.add(fstack.get(i)); + } + + //mark as visited for current function call context + lfset.add( lfkey ); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index 5d94a68..caab391 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@ -29,7 +29,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; -import org.apache.commons.collections.CollectionUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; @@ -157,26 
+156,25 @@ public class InterProceduralAnalysis * @throws ParseException if ParseException occurs * @throws LanguageException if LanguageException occurs */ - @SuppressWarnings("unchecked") public void analyzeProgram( DMLProgram dmlp ) throws HopsException, ParseException, LanguageException { + FunctionCallGraph fgraph = new FunctionCallGraph(dmlp); + //step 1: get candidates for statistics propagation into functions (if required) Map<String, Integer> fcandCounts = new HashMap<String, Integer>(); Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>(); Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>(); - Set<String> allFCandKeys = new HashSet<String>(); if( !dmlp.getFunctionStatementBlocks().isEmpty() ) { for ( StatementBlock sb : dmlp.getStatementBlocks() ) //get candidates (over entire program) getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops ); - allFCandKeys.addAll(fcandCounts.keySet()); //cp before pruning pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops ); determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ ); DMLTranslator.resetHopsDAGVisitStatus( dmlp ); } //step 2: get unary dimension-preserving non-candidate functions - Collection<String> unaryFcandTmp = CollectionUtils.subtract(allFCandKeys, fcandCounts.keySet()); + Collection<String> unaryFcandTmp = fgraph.getReachableFunctions(fcandCounts.keySet()); HashSet<String> unaryFcands = new HashSet<String>(); if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) { for( String tmp : unaryFcandTmp ) @@ -194,12 +192,12 @@ public class InterProceduralAnalysis //step 4: remove unused functions (e.g., inlined or never called) if( REMOVE_UNUSED_FUNCTIONS ) { - removeUnusedFunctions( dmlp, allFCandKeys ); + removeUnusedFunctions( dmlp, fgraph ); } //step 5: flag functions with loops for 'recompile-on-entry' if( FLAG_FUNCTION_RECOMPILE_ONCE ) { - flagFunctionsForRecompileOnce( dmlp ); + flagFunctionsForRecompileOnce( 
dmlp, fgraph ); } //step 6: set global data flow properties @@ -871,22 +869,21 @@ public class InterProceduralAnalysis // REMOVE UNUSED FUNCTIONS ////// - public void removeUnusedFunctions( DMLProgram dmlp, Set<String> fcandKeys ) + public void removeUnusedFunctions( DMLProgram dmlp, FunctionCallGraph fgraph ) throws LanguageException { Set<String> fnamespaces = dmlp.getNamespaces().keySet(); - for( String fnspace : fnamespaces ) - { + for( String fnspace : fnamespaces ) { HashMap<String, FunctionStatementBlock> fsbs = dmlp.getFunctionStatementBlocks(fnspace); Iterator<Entry<String, FunctionStatementBlock>> iter = fsbs.entrySet().iterator(); - while( iter.hasNext() ) - { + while( iter.hasNext() ) { Entry<String, FunctionStatementBlock> e = iter.next(); - String fname = e.getKey(); - String fKey = DMLProgram.constructFunctionKey(fnspace, fname); - //probe function candidates, remove if no candidate - if( !fcandKeys.contains(fKey) ) + if( !fgraph.isReachableFunction(fnspace, e.getKey()) ) { iter.remove(); + if( LOG.isDebugEnabled() ) + LOG.debug("IPA: Removed unused function: " + + DMLProgram.constructFunctionKey(fnspace, e.getKey())); + } } } } @@ -902,17 +899,20 @@ public class InterProceduralAnalysis * @param dmlp the DML program * @throws LanguageException if LanguageException occurs */ - public void flagFunctionsForRecompileOnce( DMLProgram dmlp ) + public void flagFunctionsForRecompileOnce( DMLProgram dmlp, FunctionCallGraph fgraph ) throws LanguageException { for (String namespaceKey : dmlp.getNamespaces().keySet()) for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) { FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname); - if( rFlagFunctionForRecompileOnce( fsblock, false ) ) + if( !fgraph.isRecursiveFunction(namespaceKey, fname) && + rFlagFunctionForRecompileOnce( fsblock, false ) ) { fsblock.setRecompileOnce( true ); - LOG.debug("IPA: FUNC flagged for recompile-once: " + 
DMLProgram.constructFunctionKey(namespaceKey, fname)); + if( LOG.isDebugEnabled() ) + LOG.debug("IPA: FUNC flagged for recompile-once: " + + DMLProgram.constructFunctionKey(namespaceKey, fname)); } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/parser/DMLProgram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java b/src/main/java/org/apache/sysml/parser/DMLProgram.java index 292a13c..9fbd63c 100644 --- a/src/main/java/org/apache/sysml/parser/DMLProgram.java +++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java @@ -107,6 +107,14 @@ public class DMLProgram return namespaceProgram._functionBlocks; } + public boolean hasFunctionStatementBlocks() { + boolean ret = false; + for( DMLProgram nsProg : _namespaces.values() ) + ret |= !nsProg._functionBlocks.isEmpty(); + + return ret; + } + public ArrayList<FunctionStatementBlock> getFunctionStatementBlocks() throws LanguageException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index 097e2b7..ccc9853 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -21,12 +21,12 @@ package org.apache.sysml.utils; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.HashSet; import java.util.Map; import java.util.Map.Entry; import org.apache.sysml.api.DMLException; -import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.VisitStatus; import org.apache.sysml.hops.HopsException; @@ -35,6 +35,7 @@ import org.apache.sysml.hops.OptimizerUtils; import 
org.apache.sysml.hops.globalopt.gdfgraph.GDFLoopNode; import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode; import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode.NodeType; +import org.apache.sysml.hops.ipa.FunctionCallGraph; import org.apache.sysml.lops.Lop; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.ForStatement; @@ -206,32 +207,28 @@ public class Explain sb.append("\nPROGRAM\n"); // Explain functions (if exists) - boolean firstFunction = true; - for (String namespace : prog.getNamespaces().keySet()) { - for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) { - if (firstFunction) { - sb.append("--FUNCTIONS\n"); - firstFunction = false; + if( prog.hasFunctionStatementBlocks() ) { + sb.append("--FUNCTIONS\n"); + + //show function call graph + sb.append("----FUNCTION CALL GRAPH\n"); + sb.append("------MAIN PROGRAM\n"); + FunctionCallGraph fgraph = new FunctionCallGraph(prog); + sb.append(explainFunctionCallGraph(fgraph, new HashSet<String>(), null, 3)); + + //show individual functions + for (String namespace : prog.getNamespaces().keySet()) { + for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) { + FunctionStatementBlock fsb = prog.getFunctionStatementBlock(namespace, fname); + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); - //show function call dag - sb.append("----FUNCTION CALL DAG\n"); - sb.append("------MAIN PROGRAM\n"); - HashSet<String> fstack = new HashSet<String>(); - HashSet<String> lfset = new HashSet<String>(); - for( StatementBlock sblk : prog.getStatementBlocks() ) - sb.append(explainFunctionCallDag(sblk, fstack, lfset, 3)); - } - - //show individual functions - FunctionStatementBlock fsb = prog.getFunctionStatementBlock(namespace, fname); - FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); - - if (fstmt instanceof ExternalFunctionStatement) - sb.append("----EXTERNAL FUNCTION " + namespace + "::" + fname + "\n"); - else { - 
sb.append("----FUNCTION " + namespace + "::" + fname + " [recompile="+fsb.isRecompileOnce()+"]\n"); - for (StatementBlock current : fstmt.getBody()) - sb.append(explainStatementBlock(current, 3)); + if (fstmt instanceof ExternalFunctionStatement) + sb.append("----EXTERNAL FUNCTION " + namespace + "::" + fname + "\n"); + else { + sb.append("----FUNCTION " + namespace + "::" + fname + " [recompile="+fsb.isRecompileOnce()+"]\n"); + for (StatementBlock current : fstmt.getBody()) + sb.append(explainStatementBlock(current, 3)); + } } } } @@ -267,17 +264,15 @@ public class Explain { sb.append("--FUNCTIONS\n"); - //show function call dag + //show function call graph if( !rtprog.getProgramBlocks().isEmpty() && rtprog.getProgramBlocks().get(0).getStatementBlock() != null ) { - sb.append("----FUNCTION CALL DAG\n"); + sb.append("----FUNCTION CALL GRAPH\n"); sb.append("------MAIN PROGRAM\n"); DMLProgram prog = rtprog.getProgramBlocks().get(0).getStatementBlock().getDMLProg(); - HashSet<String> fstack = new HashSet<String>(); - HashSet<String> lfset = new HashSet<String>(); - for( StatementBlock sblk : prog.getStatementBlocks() ) - sb.append(explainFunctionCallDag(sblk, fstack, lfset, 3)); + FunctionCallGraph fgraph = new FunctionCallGraph(prog); + sb.append(explainFunctionCallGraph(fgraph, new HashSet<String>(), null, 3)); } //show individual functions @@ -932,68 +927,22 @@ public class Explain return ret; } - private static String explainFunctionCallDag(StatementBlock sb, HashSet<String> fstack, HashSet<String> lfset, int level) + private static String explainFunctionCallGraph(FunctionCallGraph fgraph, HashSet<String> fstack, String fkey, int level) throws HopsException { StringBuilder builder = new StringBuilder(); - - if (sb instanceof WhileStatementBlock) { - WhileStatement ws = (WhileStatement)sb.getStatement(0); - for (StatementBlock current : ws.getBody()) - builder.append(explainFunctionCallDag(current, fstack, lfset, level)); - } - else if (sb instanceof 
IfStatementBlock) { - IfStatement ifs = (IfStatement) sb.getStatement(0); - for (StatementBlock current : ifs.getIfBody()) - builder.append(explainFunctionCallDag(current, fstack, lfset, level)); - for (StatementBlock current : ifs.getElseBody()) - builder.append(explainFunctionCallDag(current, fstack, lfset, level)); - } - else if (sb instanceof ForStatementBlock) { - ForStatement fs = (ForStatement)sb.getStatement(0); - for (StatementBlock current : fs.getBody()) - builder.append(explainFunctionCallDag(current, fstack, lfset, level)); - } - else if (sb instanceof FunctionStatementBlock) { - FunctionStatement fsb = (FunctionStatement) sb.getStatement(0); - for (StatementBlock current : fsb.getBody()) - builder.append(explainFunctionCallDag(current, fstack, lfset, level)); - } - else { - // For generic StatementBlock - ArrayList<Hop> hopsDAG = sb.get_hops(); - if( hopsDAG != null && !hopsDAG.isEmpty() ) { - //function ops can only occur as root nodes of the dag - for( Hop h : hopsDAG ) - if( h instanceof FunctionOp ){ - FunctionOp fop = (FunctionOp) h; - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); - //prevent redundant call edges - if( !lfset.contains(fkey) && !fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) ) - { - //recursively explain function call dag - if( !fstack.contains(fkey) ) { - fstack.add(fkey); - String offset = createOffset(level); - builder.append(offset + "--" + fkey + "\n"); - FunctionStatementBlock fsb = sb.getDMLProg() - .getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); - FunctionStatement fs = (FunctionStatement) fsb.getStatement(0); - HashSet<String> lfset2 = new HashSet<String>(); - for( StatementBlock csb : fs.getBody() ) - builder.append(explainFunctionCallDag(csb, fstack, lfset2, level+1)); - fstack.remove(fkey); - } - //recursive function call - else { - String offset = createOffset(level); - builder.append(offset + "-->" + fkey + " 
(recursive)\n"); - } - - //mark as visited for current function call context - lfset.add( fkey ); - } - } + String offset = createOffset(level); + Collection<String> cfkeys = fgraph.getCalledFunctions(fkey); + if( cfkeys != null ) { + for( String cfkey : cfkeys ) { + if( fstack.contains(cfkey) && fgraph.isRecursiveFunction(cfkey) ) + builder.append(offset + "--" + cfkey + " (recursive)\n"); + else { + fstack.add(cfkey); + builder.append(offset + "--" + cfkey + "\n"); + builder.append(explainFunctionCallGraph(fgraph, fstack, cfkey, level+1)); + fstack.remove(cfkey); + } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 87be90b..cf9b5fb 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -345,6 +345,9 @@ public class Statistics hopRecompilePred.set(0); hopRecompileSB.set(0); + funRecompiles.set(0); + funRecompileTime.set(0); + parforOptCount = 0; parforOptTime = 0; parforInitTime = 0; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java new file mode 100644 index 0000000..99459b0 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/RecursiveFunctionRecompileTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.recompile; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.runtime.util.MapReduceTool; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; + +/** + * This test ensures that recursive functions are not marked for recompile-once + * during IPA because this could potentially lead to incorrect plans that cause + * OOMs or even incorrect results. 
+ *
+ */
+public class RecursiveFunctionRecompileTest extends AutomatedTestBase 
+{
+	private final static String TEST_DIR = "functions/recompile/";
+	private final static String TEST_NAME1 = "recursive_func_direct";
+	private final static String TEST_NAME2 = "recursive_func_indirect";
+	private final static String TEST_NAME3 = "recursive_func_indirect2";
+	private final static String TEST_NAME4 = "recursive_func_none";
+	private final static String TEST_CLASS_DIR = TEST_DIR + 
+		RecursiveFunctionRecompileTest.class.getSimpleName() + "/";
+	
+	private final static long rows = 5000;
+	private final static long cols = 10000;
+	private final static double sparsity = 0.00001d;
+	private final static double val = 7.0;
+	
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, 
+			new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "Rout" }) );
+		addTestConfiguration(TEST_NAME2, 
+			new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "Rout" }) );
+		addTestConfiguration(TEST_NAME3, 
+			new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "Rout" }) );
+		addTestConfiguration(TEST_NAME4, 
+			new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "Rout" }) );
+	}
+	
+	@Test
+	public void testDirectRecursionRecompileIPA() {
+		runRecompileTest(TEST_NAME1, true);
+	}
+	
+	@Test
+	public void testIndirectRecursionRecompileIPA() {
+		runRecompileTest(TEST_NAME2, true);
+	}
+	
+	@Test
+	public void testIndirect2RecursionRecompileIPA() {
+		runRecompileTest(TEST_NAME3, true);
+	}
+	
+	@Test
+	public void testNoRecursionRecompileIPA() {
+		runRecompileTest(TEST_NAME4, true);
+	}
+	
+	@Test
+	public void testDirectRecursionRecompileNoIPA() {
+		runRecompileTest(TEST_NAME1, false);
+	}
+	
+	@Test
+	public void testIndirectRecursionRecompileNoIPA() {
+		runRecompileTest(TEST_NAME2, false);
+	}
+	
+	@Test
+	public void testIndirect2RecursionRecompileNoIPA() {
+		runRecompileTest(TEST_NAME3, false);
+	}
+	
+	@Test
+	public void testNoRecursionRecompileNoIPA() {
+		runRecompileTest(TEST_NAME4, false);
+	}
+	
+	/** Runs the given test script with IPA enabled/disabled and checks the number of function recompilations. */
+	private void runRecompileTest( String testname, boolean IPA )
+	{
+		//remember the global IPA flag in order to restore it in the finally block
+		boolean oldFlagIPA = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{"-explain","-stats","-args",
+				input("V"), Double.toString(val), output("R") };
+			
+			OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;
+			
+			//generate sparse input data
+			MatrixBlock mb = MatrixBlock.randOperations((int)rows, (int)cols, sparsity, 0, 1, "uniform", 732);
+			MatrixCharacteristics mc = new MatrixCharacteristics(rows,cols,OptimizerUtils.DEFAULT_BLOCKSIZE,OptimizerUtils.DEFAULT_BLOCKSIZE,(long)(rows*cols*sparsity));
+			DataConverter.writeMatrixToHDFS(mb, input("V"), OutputInfo.TextCellOutputInfo, mc);
+			MapReduceTool.writeMetaDataFile(input("V.mtd"), ValueType.DOUBLE, mc, OutputInfo.TextCellOutputInfo);
+			
+			//run test
+			runTest(true, false, null, -1);
+			
+			//check number of recompiled functions (recompile-once is not applicable for recursive functions
+			//because the single recompilation on entry would implicitly change the remaining plan of the caller;
+			//if not handled correctly, TEST_NAME1 and TEST_NAME2 would show 1111 function recompilations with IPA)
+			Assert.assertEquals("Unexpected number of function recompilations",
+				testname.equals(TEST_NAME4) && IPA ? 1 : 0, Statistics.getFunRecompiles());
+		}
+		catch(Exception ex) {
+			//rethrow with the original exception as cause to retain its full stack trace in the test report
+			throw new AssertionError("Failed to run test: "+ex.getMessage(), ex);
+		}
+		finally {
+			OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA;
+		}
+	}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_direct.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_direct.dml b/src/test/scripts/functions/recompile/recursive_func_direct.dml
new file mode 100644
index 0000000..368bbc7
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_direct.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Direct recursion: foo1 calls itself on each of its 10 column slices while the slice width (batch) is > 1.
+foo1 = function (Matrix[Double] X)
+	return (Matrix[Double] Y)
+{
+	print(ncol(X)+" cols -> "+sum(X));
+	for( i in 1:10 ) {
+		batch = ncol(X)/10; # width of each of the 10 column slices
+		tmp = X[,((i-1)*batch+1):(i*batch)];
+		if( batch > 1 )
+			tmp = foo1(tmp);
+	}
+	Y = X*2;
+}
+
+V = read($1);
+V = foo1(V); # for a 10000-column input: 1 + 10 + 100 + 1000 = 1111 invocations of foo1
+print(sum(V));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_indirect.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_indirect.dml b/src/test/scripts/functions/recompile/recursive_func_indirect.dml
new file mode 100644
index 0000000..0e7b023
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_indirect.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Indirect recursion: foo1 -> foo2 -> foo1; foo2 stops the recursion once the slice width (batch) is <= 1.
+foo1 = function (Matrix[Double] X)
+	return (Matrix[Double] Y)
+{
+	print(ncol(X)+" cols -> "+sum(X));
+	for( i in 1:10 ) {
+		batch = ncol(X)/10; # width of each of the 10 column slices
+		tmp = X[,((i-1)*batch+1):(i*batch)];
+		tmp = foo2(tmp, batch);
+	}
+	Y = X*2;
+}
+
+foo2 = function (Matrix[Double] X, Integer batch) # closes the cycle back to foo1
+	return (Matrix[Double] Y)
+{
+	if( batch > 1 )
+		Y = foo1(X);
+	else
+		Y = X;
+}
+
+V = read($1);
+V = foo1(V);
+print(sum(V));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_indirect2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_indirect2.dml b/src/test/scripts/functions/recompile/recursive_func_indirect2.dml
new file mode 100644
index 0000000..f30f580
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_indirect2.dml
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Indirect recursion over three functions: foo1 -> foo2 -> foo3 -> foo1, entered via foo3.
+foo1 = function (Matrix[Double] X)
+	return (Matrix[Double] Y)
+{
+	print(ncol(X)+" cols -> "+sum(X));
+	for( i in 1:10 ) {
+		batch = ncol(X)/10; # width of each of the 10 column slices
+		tmp = X[,((i-1)*batch+1):(i*batch)];
+		tmp = foo2(tmp, batch);
+	}
+	Y = X*2;
+}
+
+foo2 = function (Matrix[Double] X, Integer batch) # stops the recursion once batch <= 1
+	return (Matrix[Double] Y)
+{
+	if( batch > 1 )
+		Y = foo3(X);
+	else
+		Y = X;
+}
+
+foo3 = function (Matrix[Double] X) # closes the cycle back to foo1
+	return (Matrix[Double] Y)
+{
+	if( sum(X) >= 0 ) # true for non-negative inputs such as the test's uniform [0,1] data
+		Y = foo1(X);
+	else
+		Y = X;
+}
+
+V = read($1);
+V = foo3(V);
+print(sum(V));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test/scripts/functions/recompile/recursive_func_none.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/recompile/recursive_func_none.dml b/src/test/scripts/functions/recompile/recursive_func_none.dml
new file mode 100644
index 0000000..7548017
--- /dev/null
+++ b/src/test/scripts/functions/recompile/recursive_func_none.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# No recursion (baseline): foo1 calls no other function; with IPA the Java test expects exactly one function recompilation.
+foo1 = function (Matrix[Double] X)
+	return (Matrix[Double] Y)
+{
+	print(ncol(X)+" cols -> "+sum(X));
+	for( i in 1:10 ) {
+		batch = ncol(X)/10; # width of each of the 10 column slices
+		tmp = X[,((i-1)*batch+1):(i*batch)];
+		if( sum(tmp)<0 )
+			print("Sum is negative");
+	}
+	Y = X*2;
+}
+
+V = read($1);
+V = foo1(V);
+print(sum(V));
+	
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a3d2481/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
index c15340a..da37e50 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/recompile/ZPackageSuite.java
@@ -40,6 +40,7 @@ import org.junit.runners.Suite;
 	RandRecompileTest.class,
 	RandSizeExpressionEvalTest.class,
 	ReblockRecompileTest.class,
+	RecursiveFunctionRecompileTest.class,
 	RemoveEmptyPotpourriTest.class,
 	RemoveEmptyRecompileTest.class,
 	RewriteComplexMapMultChainTest.class,
