This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 57315fb  [SYSTEMDS-2641] Extended IPA rewrite handling (rebuild fgraph on demand)
57315fb is described below

commit 57315fb542af690fdf718686be2a6fd9296a4842
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Sep 13 18:40:10 2020 +0200

    [SYSTEMDS-2641] Extended IPA rewrite handling (rebuild fgraph on demand)
    
    This patch improves the inter-procedural analysis (IPA), which repeatedly
    propagates scalars and sizes and applies various IPA rewrite passes.
    One of these rewrite passes is a second function-inlining mechanism,
    which inlines functions that are either called only once or small
    (fewer than t=10 operators). However, this condition was evaluated on
    a function call graph that was never rebuilt, so it could become stale
    after other rewrites modified the program.
    
    In slicefinder, after scalar propagation one of the two calls to
    evalSlice is removed (by the rewrite that removes unnecessary branches),
    but inlining did not take place because the stale call graph still
    recorded two calls to this function. We now propagate the information
    about removed branches to IPA and rebuild the function call graph if
    necessary.
---
 src/main/java/org/apache/sysds/hops/ipa/IPAPass.java     |  6 ++++--
 .../ipa/IPAPassApplyStaticAndDynamicHopRewrites.java     | 16 +++++++++++-----
 .../apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java  |  3 ++-
 .../hops/ipa/IPAPassFlagFunctionsRecompileOnce.java      |  5 +++--
 .../apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java |  5 +++--
 .../sysds/hops/ipa/IPAPassForwardFunctionCalls.java      |  3 ++-
 .../apache/sysds/hops/ipa/IPAPassInlineFunctions.java    |  3 ++-
 .../sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java  |  3 ++-
 .../sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java   |  6 +++---
 .../hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java    |  3 ++-
 .../sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java     |  3 ++-
 .../apache/sysds/hops/ipa/InterProceduralAnalysis.java   |  9 +++++++--
 .../org/apache/sysds/hops/rewrite/ProgramRewriter.java   |  6 ++++--
 .../test/functions/builtin/BuiltinSliceFinderTest.java   |  4 ++++
 14 files changed, 51 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java
index 7807a23..74a0b1d 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java
@@ -29,7 +29,7 @@ import org.apache.sysds.parser.DMLProgram;
 public abstract class IPAPass 
 {
        protected static final Log LOG = 
LogFactory.getLog(IPAPass.class.getName());
-    
+
        /**
         * Indicates if an IPA pass is applicable for the current
         * configuration such as global flags or the chosen execution 
@@ -47,6 +47,8 @@ public abstract class IPAPass
         * @param prog dml program
         * @param fgraph function call graph
         * @param fcallSizes function call size infos
+        * @return true if function call graph should be rebuild
         */
-       public abstract void rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes );
+       public abstract boolean rewriteProgram( DMLProgram prog,
+               FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes );
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
index c00b73e..fdd8af0 100644
--- 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
+++ 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.ipa;
 
 
 import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
 import org.apache.sysds.hops.rewrite.RewriteInjectSparkLoopCheckpointing;
 import org.apache.sysds.parser.DMLProgram;
@@ -42,17 +43,22 @@ public class IPAPassApplyStaticAndDynamicHopRewrites 
extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) {
                try {
-                       //construct rewriter w/o checkpoint injection to avoid 
redundancy
+                       // construct rewriter w/o checkpoint injection to avoid 
redundancy
                        ProgramRewriter rewriter = new ProgramRewriter(
                                InterProceduralAnalysis.APPLY_STATIC_REWRITES,
                                InterProceduralAnalysis.APPLY_DYNAMIC_REWRITES);
                        
rewriter.removeStatementBlockRewrite(RewriteInjectSparkLoopCheckpointing.class);
                        
-                       //rewrite program hop dags and statement blocks
-                       rewriter.rewriteProgramHopDAGs(prog, true); //rewrite 
and split
-               } 
+                       // rewrite program hop dags and statement blocks
+                       ProgramRewriteStatus status = new 
ProgramRewriteStatus();
+                       rewriter.rewriteProgramHopDAGs(prog, true, status); 
//rewrite and split
+                       // in case of removed branches entire function calls 
might have been eliminated,
+                       // accordingly, we should rebuild the function call 
graph to allow for inlining
+                       // even large functions, and avoid restrictions of 
scalar/size propagation
+                       return status.getRemovedBranches();
+               }
                catch (LanguageException ex) {
                        throw new HopsException(ex);
                }
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java
index 043c4b2..b3aae5e 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java
@@ -52,7 +52,7 @@ public class IPAPassEliminateDeadCode extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) {
                // step 1: backwards pass over main program to track used and 
remove unused vars
                findAndRemoveDeadCode(prog.getStatementBlocks(), new 
HashSet<>(), fgraph);
                
@@ -66,6 +66,7 @@ public class IPAPassEliminateDeadCode extends IPAPass
                        // backward pass over function to track used and remove 
unused vars
                        findAndRemoveDeadCode(fstmt.getBody(), usedVars, 
fgraph);
                }
+               return false;
        }
        
        private static void findAndRemoveDeadCode(List<StatementBlock> sbs, 
Set<String> usedVars, FunctionCallGraph fgraph) {
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
index c0cf6c7..b6351ba 100644
--- 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
+++ 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
@@ -53,10 +53,10 @@ public class IPAPassFlagFunctionsRecompileOnce extends 
IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) 
        {
                if( !ConfigurationManager.isDynamicRecompilation() )
-                       return;
+                       return false;
                
                try {
                        // flag applicable functions for recompile-once, note 
that this IPA pass
@@ -82,6 +82,7 @@ public class IPAPassFlagFunctionsRecompileOnce extends IPAPass
                catch( LanguageException ex ) {
                        throw new HopsException(ex);
                }
+               return false;
        }
        
        /**
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
index a000096..6275f10 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
@@ -48,10 +48,10 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
        }
 
        @Override
-       public void rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes) 
+       public boolean rewriteProgram (DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes) 
        {
                if (!LineageCacheConfig.isMultiLevelReuse())
-                       return;
+                       return false;
                
                try {
                        // Find the individual functions and statementblocks 
with non-determinism.
@@ -84,6 +84,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
                catch( LanguageException ex ) {
                        throw new HopsException(ex);
                }
+               return false;
        }
 
        private boolean rIsNonDeterministicFnc (String fname, 
ArrayList<StatementBlock> sbs) 
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
index 1605524..8b57742 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
@@ -47,7 +47,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) 
        {
                for( String fkey : fgraph.getReachableFunctions() ) {
                        FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(fkey);
@@ -87,6 +87,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
                                                + fkey +"' with 
'"+call2.getFunctionKey()+"'");
                        }
                }
+               return false;
        }
        
        private static boolean singleFunctionOp(ArrayList<Hop> hops) {
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
index 8c89689..3a465db 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
@@ -53,7 +53,7 @@ public class IPAPassInlineFunctions extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) 
        {
                //NOTE: we inline single-statement-block (i.e., last-level 
block) functions
                //that do not contain other functions, and either are small or 
called once
@@ -133,6 +133,7 @@ public class IPAPassInlineFunctions extends IPAPass
                                }
                        }
                }
+               return false;
        }
        
        private static boolean containsFunctionOp(ArrayList<Hop> hops) {
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
index d9eac84..0f33a45 100644
--- 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
+++ 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
@@ -56,7 +56,7 @@ public class IPAPassPropagateReplaceLiterals extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) 
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) 
        {
                //step 1: propagate final literals across main program
                rReplaceLiterals(prog.getStatementBlocks(), prog, fgraph, 
fcallSizes);
@@ -93,6 +93,7 @@ public class IPAPassPropagateReplaceLiterals extends IPAPass
                                rReplaceLiterals(fstmt.getBody(), prog, fgraph, 
fcallSizes);
                        }
                }
+               return false;
        }
        
        private void rReplaceLiterals(List<StatementBlock> sbs, DMLProgram 
prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index 606c677..ffb7d68 100644
--- 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++ 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -56,13 +56,12 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) {
                //approach: scan over top-level program (guaranteed to be 
unconditional),
                //collect ones=matrix(1,...); remove b(*)ones if not outer 
operation
                HashMap<String, Hop> mOnes = new HashMap<>();
                
-               for( StatementBlock sb : prog.getStatementBlocks() ) 
-               {
+               for( StatementBlock sb : prog.getStatementBlocks() )  {
                        //pruning updated variables
                        for( String var : 
sb.variablesUpdated().getVariableNames() )
                                if( mOnes.containsKey( var ) )
@@ -79,6 +78,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
                                collectMatrixOfOnes(sb.getHops(), mOnes);
                        }
                }
+               return false;
        }
        
        private static void collectMatrixOfOnes(ArrayList<Hop> roots, 
HashMap<String,Hop> mOnes)
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
index 351d099..c78aac6 100644
--- 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
+++ 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -55,7 +55,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends 
IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) {
                //remove unnecessary checkpoint before update 
                removeCheckpointBeforeUpdate(prog);
                
@@ -64,6 +64,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends 
IPAPass
                
                //remove unnecessary checkpoint read-{write|uagg}
                removeCheckpointReadWrite(prog);
+               return false;
        }
        
        private static void removeCheckpointBeforeUpdate(DMLProgram dmlp) {
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java 
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java
index 6d6abc8..9304926 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java
@@ -44,7 +44,7 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass
        }
        
        @Override
-       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
+       public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes ) {
                try {
                        Set<String> fnamespaces = prog.getNamespaces().keySet();
                        for( String fnspace : fnamespaces  ) {
@@ -64,5 +64,6 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass
                catch(LanguageException ex) {
                        throw new HopsException(ex);
                }
+               return false;
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index abb8b81..579af77 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -97,7 +97,7 @@ public class InterProceduralAnalysis
        private final StatementBlock _sb;
        
        //function call graph for functions reachable from main
-       private final FunctionCallGraph _fgraph;
+       private FunctionCallGraph _fgraph;
        
        //set IPA passes to apply in order 
        private final ArrayList<IPAPass> _passes;
@@ -200,9 +200,10 @@ public class InterProceduralAnalysis
                        }
                        
                        //step 2: apply additional IPA passes
+                       boolean rebuildFGraph = false;
                        for( IPAPass pass : _passes )
                                if( pass.isApplicable(_fgraph) )
-                                       pass.rewriteProgram(_prog, _fgraph, 
fcallSizes);
+                                       rebuildFGraph |= 
pass.rewriteProgram(_prog, _fgraph, fcallSizes);
                        
                        //early abort without functions or on reached fixpoint
                        if( _fgraph.getReachableFunctions().isEmpty() 
@@ -212,6 +213,10 @@ public class InterProceduralAnalysis
                                                + " repetitions due to reached 
fixpoint.");
                                break;
                        }
+                       
+                       //step 3: rebuild function call graph if necessary
+                       if( rebuildFGraph && i < repetitions-1 )
+                               _fgraph = new FunctionCallGraph(_prog);
                }
                
                //cleanup pass: remove unused functions
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 87df183..af81e86 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -191,8 +191,10 @@ public class ProgramRewriter
        }
        
        public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, 
boolean splitDags) {
-               ProgramRewriteStatus state = new ProgramRewriteStatus();
-               
+               return rewriteProgramHopDAGs(dmlp, splitDags, new 
ProgramRewriteStatus());
+       }
+       
+       public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, 
boolean splitDags, ProgramRewriteStatus state) {
                // for each namespace, handle function statement blocks
                for (String namespaceKey : dmlp.getNamespaces().keySet())
                        for (String fname : 
dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
index ff9b639..a5dd9a7 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.test.functions.builtin;
 
 
+import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -110,6 +111,9 @@ public class BuiltinSliceFinderTest extends 
AutomatedTestBase {
                        double[][] ret = 
TestUtils.convertHashMapToDoubleArray(readDMLMatrixFromHDFS("R"));
                        for(int i=0; i<K; i++)
                                TestUtils.compareMatrices(EXPECTED_TOPK[i], 
ret[i], 1e-2);
+               
+                       //ensure proper inlining, despite initially multiple 
calls and large function
+                       
Assert.assertFalse(heavyHittersContainsSubString("evalSlice"));
                }
                finally {
                        rtplatform = platformOld;

Reply via email to