[SYSTEMML-2393/4] Eliminate side-effect-free functions w/ unused outputs

This patch extends the new IPA pass for dead code elimination with the
removal of side-effect-free function calls whose outputs are unused. As a
necessary precondition, we now track (in a best-effort manner)
side-effect-free functions, i.e., explicitly annotated UDFs and DML-bodied
functions without prints, persistent writes, or function calls (any nested
function call is conservatively treated as a potential side effect).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ea6dc8c5
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ea6dc8c5
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ea6dc8c5

Branch: refs/heads/master
Commit: ea6dc8c58e7f6d9cbe0b3e2c6cd638a6afad99d9
Parents: 2a1e857
Author: Matthias Boehm <[email protected]>
Authored: Wed Jun 13 20:12:00 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jun 13 20:12:32 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallGraph.java       | 89 ++++++++++++++++++++
 .../hops/ipa/IPAPassEliminateDeadCode.java      | 27 ++++--
 .../misc/IPADeadCodeEliminationTest.java        | 32 ++++++-
 .../functions/misc/IPADeadCodeRemoval_Fun2.dml  | 31 +++++++
 .../functions/misc/IPADeadCodeRemoval_Fun3.dml  | 32 +++++++
 5 files changed, 199 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java 
b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
index b1bf301..e889703 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
@@ -32,6 +32,9 @@ import java.util.stream.Collectors;
 import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOpN;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
 import org.apache.sysml.parser.DMLProgram;
 import org.apache.sysml.parser.ExternalFunctionStatement;
@@ -63,6 +66,9 @@ public class FunctionCallGraph
        //subset of direct or indirect recursive functions
        private final HashSet<String> _fRecursive;
 
+       //subset of side-effect-free functions
+       private final HashSet<String> _fSideEffectFree;
+       
        // a boolean value to indicate if exists the second order function 
(e.g. eval, paramserv)
        // and the UDFs that are marked secondorder="true"
        private final boolean _containsSecondOrder;
@@ -78,6 +84,7 @@ public class FunctionCallGraph
                _fCalls = new HashMap<>();
                _fCallsSB = new HashMap<>();
                _fRecursive = new HashSet<>();
+               _fSideEffectFree = new HashSet<>();
                _containsSecondOrder = constructFunctionCallGraph(prog);
        }
        
@@ -92,6 +99,7 @@ public class FunctionCallGraph
                _fCalls = new HashMap<>();
                _fCallsSB = new HashMap<>();
                _fRecursive = new HashSet<>();
+               _fSideEffectFree = new HashSet<>();
                _containsSecondOrder = constructFunctionCallGraph(sb);
        }
 
@@ -190,6 +198,33 @@ public class FunctionCallGraph
        }
        
        /**
+        * Indicates if the given function is side effect free, i.e., has no
+        * prints, no persistent write, and includes no or only calls to
+        * side-effect-free functions.
+        * 
+        * @param fnamespace function namespace
+        * @param fname function name
+        * @return true if the given function is side-effect-free, false 
otherwise
+        */
+       public boolean isSideEffectFreeFunction(String fnamespace, String 
fname) {
+               return isSideEffectFreeFunction(
+                       DMLProgram.constructFunctionKey(fnamespace, fname));
+       }
+       
+       /**
+        * Indicates if the given function is side effect free, i.e., has no
+        * prints, no persistent write, and includes no or only calls to
+        * side-effect-free functions.
+        * 
+        * @param fkey function key of calling function, null indicates the 
main program
+        * @return true if the given function is side-effect-free, false 
otherwise
+        */
+       public boolean isSideEffectFreeFunction(String fkey) {
+               String lfkey = (fkey == null) ? MAIN_FUNCTION_KEY : fkey;
+               return _fSideEffectFree.contains(lfkey);
+       }
+       
+       /**
         * Returns all functions that are reachable either directly or 
indirectly
         * form the main program, except the main program itself.
         * 
@@ -255,11 +290,18 @@ public class FunctionCallGraph
                
                boolean ret = false;
                try {
+                       //construct the main function call graph
                        Stack<String> fstack = new Stack<>();
                        HashSet<String> lfset = new HashSet<>();
                        _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
                        for( StatementBlock sblk : prog.getStatementBlocks() )
                                ret |= 
rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset);
+                       
+                       //analyze all non-recursive functions if free of side 
effects
+                       _fSideEffectFree.addAll(_fCalls.keySet().stream()
+                               .filter(s -> 
!s.startsWith(DMLProgram.INTERNAL_NAMESPACE))
+                               .filter(s -> 
isSideEffectFree(prog.getFunctionStatementBlock(s)))
+                               .collect(Collectors.toList()));
                }
                catch(HopsException ex) {
                        throw new RuntimeException(ex);
@@ -371,4 +413,51 @@ public class FunctionCallGraph
                
                return ret;
        }
+       
+       /**
+        * Determines if the given function is free of side effects.
+        * External functions (UDFs) have no DML body to analyze and are
+        * therefore classified solely by their explicit side-effect
+        * annotation. DML-bodied functions are considered side-effect-free
+        * if their body contains no print/printf, no persistent write, and
+        * no function call (conservative, best-effort analysis).
+        * 
+        * @param fsb function statement block of the function to analyze
+        * @return true if the function is side-effect-free, false otherwise
+        */
+       private static boolean isSideEffectFree(FunctionStatementBlock fsb) {
+               //external functions: rely exclusively on the explicit annotation and
+               //never fall through to the body scan below, otherwise UDFs annotated
+               //with side effects would be misclassified (or fail the cast)
+               if( fsb.getStatement(0) instanceof ExternalFunctionStatement )
+                       return !((ExternalFunctionStatement)fsb.getStatement(0)).hasSideEffects();
+               //check regular dml-bodied function for prints, pwrite, and other functions
+               FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
+               for( StatementBlock csb : fstmt.getBody() )
+                       if( rHasSideEffects(csb) )
+                               return false;
+               return true;
+       }
+       
+       private static boolean rHasSideEffects(StatementBlock sb) {
+               boolean ret = false;
+               if( sb instanceof ForStatementBlock ) {
+                       ForStatement fstmt = (ForStatement) sb.getStatement(0);
+                       for( StatementBlock csb : fstmt.getBody() )
+                               ret |= rHasSideEffects(csb);
+               }
+               else if( sb instanceof WhileStatementBlock ) {
+                       WhileStatement wstmt = (WhileStatement) 
sb.getStatement(0);
+                       for( StatementBlock csb : wstmt.getBody() )
+                               ret |= rHasSideEffects(csb);
+               }
+               else if( sb instanceof IfStatementBlock ) {
+                       IfStatement istmt = (IfStatement) sb.getStatement(0);
+                       for( StatementBlock csb : istmt.getIfBody() )
+                               ret |= rHasSideEffects(csb);
+                       if( istmt.getElseBody() != null )
+                               for( StatementBlock csb : istmt.getElseBody() )
+                                       ret |= rHasSideEffects(csb);
+               }
+               else if( sb.getHops() != null ) {
+                       //check for print, printf, pwrite, function calls, all 
of
+                       //which can only appear as root nodes in the DAG
+                       for( Hop root : sb.getHops() ) {
+                               ret |= HopRewriteUtils.isUnary(root, 
OpOp1.PRINT)
+                                       || HopRewriteUtils.isNary(root, 
OpOpN.PRINTF)
+                                       || HopRewriteUtils.isData(root, 
DataOpTypes.PERSISTENTWRITE)
+                                       || root instanceof FunctionOp;
+                       }
+               }
+               return ret;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java
index 2fb4338..9ce79b0 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java
@@ -19,10 +19,12 @@
 
 package org.apache.sysml.hops.ipa;
 
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
+import org.apache.sysml.hops.FunctionOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.rewrite.HopRewriteUtils;
@@ -52,7 +54,7 @@ public class IPAPassEliminateDeadCode extends IPAPass
        @Override
        public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, 
FunctionCallSizeInfo fcallSizes ) {
                // step 1: backwards pass over main program to track used and 
remove unused vars
-               findAndRemoveDeadCode(prog.getStatementBlocks(), new 
HashSet<>());
+               findAndRemoveDeadCode(prog.getStatementBlocks(), new 
HashSet<>(), fgraph);
                
                // step 2: backwards pass over functions to track used and 
remove unused vars
                for( FunctionStatementBlock fsb : 
prog.getFunctionStatementBlocks() ) {
@@ -62,19 +64,19 @@ public class IPAPassEliminateDeadCode extends IPAPass
                        fstmt.getOutputParams().stream().forEach(d -> 
usedVars.add(d.getName()));
                        
                        // backward pass over function to track used and remove 
unused vars
-                       findAndRemoveDeadCode(fstmt.getBody(), usedVars);
+                       findAndRemoveDeadCode(fstmt.getBody(), usedVars, 
fgraph);
                }
        }
        
-       private void findAndRemoveDeadCode(List<StatementBlock> sbs, 
Set<String> usedVars) {
+       private static void findAndRemoveDeadCode(List<StatementBlock> sbs, 
Set<String> usedVars, FunctionCallGraph fgraph) {
                for( int i=sbs.size()-1; i >= 0; i-- ) {
                        // remove unused assignments
                        if( 
HopRewriteUtils.isLastLevelStatementBlock(sbs.get(i)) ) {
                                List<Hop> roots = sbs.get(i).getHops();
                                for( int j=0; j<roots.size(); j++ ) {
                                        Hop root = roots.get(j);
-                                       if( HopRewriteUtils.isData(root, 
DataOpTypes.TRANSIENTWRITE)
-                                               && 
!usedVars.contains(root.getName()) ) {
+                                       if( (HopRewriteUtils.isData(root, 
DataOpTypes.TRANSIENTWRITE) && !usedVars.contains(root.getName()))
+                                               || 
isFunctionCallWithUnusedOutputs(root, usedVars, fgraph) ) {
                                                roots.remove(j); j--;
                                                rRemoveOpFromDAG(root);
                                        }
@@ -88,7 +90,14 @@ public class IPAPassEliminateDeadCode extends IPAPass
                }
        }
        
-       private void rRemoveOpFromDAG(Hop current) {
+       private static boolean isFunctionCallWithUnusedOutputs(Hop hop, 
Set<String> varNames, FunctionCallGraph fgraph) {
+               return hop instanceof FunctionOp
+                       && 
fgraph.isSideEffectFreeFunction(((FunctionOp)hop).getFunctionKey())
+                       && Arrays.stream(((FunctionOp) 
hop).getOutputVariableNames())
+                               .allMatch(var -> !varNames.contains(var));
+       }
+       
+       private static void rRemoveOpFromDAG(Hop current) {
                for( int i=0; i<current.getInput().size(); i++ ) {
                        Hop c = current.getInput().get(i);
                        HopRewriteUtils.removeChildReference(current, c);
@@ -97,7 +106,7 @@ public class IPAPassEliminateDeadCode extends IPAPass
                }
        }
        
-       private Set<String> rCollectReadVariableNames(StatementBlock sb, 
Set<String> varNames) {
+       private static Set<String> rCollectReadVariableNames(StatementBlock sb, 
Set<String> varNames) {
                if( sb instanceof WhileStatementBlock ) {
                        WhileStatementBlock wsb = (WhileStatementBlock) sb;
                        WhileStatement wstmt = (WhileStatement) 
sb.getStatement(0);
@@ -132,14 +141,14 @@ public class IPAPassEliminateDeadCode extends IPAPass
                return varNames;
        }
        
-       private Set<String> collectReadVariableNames(Hop hop, Set<String> 
varNames) {
+       private static Set<String> collectReadVariableNames(Hop hop, 
Set<String> varNames) {
                if( hop == null )
                        return varNames;
                hop.resetVisitStatus();
                return rCollectReadVariableNames(hop, varNames);
        }
        
-       private Set<String> rCollectReadVariableNames(Hop hop, Set<String> 
varNames) {
+       private static Set<String> rCollectReadVariableNames(Hop hop, 
Set<String> varNames) {
                if( hop.isVisited() )
                        return varNames;
                for( Hop c : hop.getInput() )

http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java
index 1516969..3afe992 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java
@@ -31,6 +31,8 @@ public class IPADeadCodeEliminationTest extends 
AutomatedTestBase
 {
        private final static String TEST_NAME1 = "IPADeadCodeRemoval_Main";
        private final static String TEST_NAME2 = "IPADeadCodeRemoval_Fun";
+       private final static String TEST_NAME3 = "IPADeadCodeRemoval_Fun2";
+       private final static String TEST_NAME4 = "IPADeadCodeRemoval_Fun3"; 
//w/ print
        
        private final static String TEST_DIR = "functions/misc/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
IPADeadCodeEliminationTest.class.getSimpleName() + "/";
@@ -38,8 +40,10 @@ public class IPADeadCodeEliminationTest extends 
AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" })   );
-               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" })   );
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
        }
 
        @Test
@@ -61,6 +65,26 @@ public class IPADeadCodeEliminationTest extends 
AutomatedTestBase
        public void testDeadCodeRemovalFunIPA() {
                runIPALiteralReplacementTest( TEST_NAME2, true );
        }
+       
+       @Test
+       public void testDeadCodeRemovalFun2NoIPA() {
+               runIPALiteralReplacementTest( TEST_NAME3, false );
+       }
+       
+       @Test
+       public void testDeadCodeRemovalFun2IPA() {
+               runIPALiteralReplacementTest( TEST_NAME3, true );
+       }
+       
+       @Test
+       public void testDeadCodeRemovalFun3NoIPA() {
+               runIPALiteralReplacementTest( TEST_NAME4, false );
+       }
+       
+       @Test
+       public void testDeadCodeRemovalFun3IPA() {
+               runIPALiteralReplacementTest( TEST_NAME4, true );
+       }
 
        private void runIPALiteralReplacementTest( String testname, boolean IPA 
)
        {
@@ -75,8 +99,10 @@ public class IPADeadCodeEliminationTest extends 
AutomatedTestBase
                        OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;
                        runTest(true, false, null, -1);
                        
-                       if( IPA ) //check for applied dead code removal
+                       if( IPA && !testname.equals(TEST_NAME4) ) //check for 
applied dead code removal
                                
Assert.assertTrue(!heavyHittersContainsString("uak+"));
+                       if( testname.equals(TEST_NAME4) )
+                               
Assert.assertTrue(heavyHittersContainsString("uak+"));
                }
                finally {
                        OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = 
oldFlagIPA;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml 
b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml
new file mode 100644
index 0000000..e243ed3
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(matrix[double] A) return (matrix[double] B) {
+  C = A + 7;
+  while(FALSE){}
+  B = C + sum(A + C);
+}
+
+A = matrix(7, 10, 10);
+B = foo(A);
+if( 1 == 0 )
+  print("sum = "+sum(B));

http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml 
b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml
new file mode 100644
index 0000000..9fa20db
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(matrix[double] A) return (matrix[double] B) {
+  C = A + 7;
+  while(FALSE){}
+  B = C + sum(A + C);
+  print(sum(B));
+}
+
+A = matrix(7, 10, 10);
+B = foo(A);
+if( 1 == 0 )
+  print("sum = "+sum(B));

Reply via email to