[SYSTEMML-2393/4] Eliminate side-effect-free functions w/ unused outputs This patch extends the new IPA pass for dead code elimination to also remove calls to side-effect-free functions whose outputs are all unused. As a necessary precondition, we now track (in a best-effort manner) side-effect-free functions, i.e., explicitly annotated UDFs, or DML-bodied functions without prints, persistent writes, or calls to functions with side effects.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ea6dc8c5 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ea6dc8c5 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ea6dc8c5 Branch: refs/heads/master Commit: ea6dc8c58e7f6d9cbe0b3e2c6cd638a6afad99d9 Parents: 2a1e857 Author: Matthias Boehm <[email protected]> Authored: Wed Jun 13 20:12:00 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 13 20:12:32 2018 -0700 ---------------------------------------------------------------------- .../sysml/hops/ipa/FunctionCallGraph.java | 89 ++++++++++++++++++++ .../hops/ipa/IPAPassEliminateDeadCode.java | 27 ++++-- .../misc/IPADeadCodeEliminationTest.java | 32 ++++++- .../functions/misc/IPADeadCodeRemoval_Fun2.dml | 31 +++++++ .../functions/misc/IPADeadCodeRemoval_Fun3.dml | 32 +++++++ 5 files changed, 199 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java index b1bf301..e889703 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java @@ -32,6 +32,9 @@ import java.util.stream.Collectors; import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.Hop.DataOpTypes; +import org.apache.sysml.hops.Hop.OpOp1; +import org.apache.sysml.hops.Hop.OpOpN; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.ExternalFunctionStatement; @@ -63,6 
+66,9 @@ public class FunctionCallGraph //subset of direct or indirect recursive functions private final HashSet<String> _fRecursive; + //subset of side-effect-free functions + private final HashSet<String> _fSideEffectFree; + // a boolean value to indicate if exists the second order function (e.g. eval, paramserv) // and the UDFs that are marked secondorder="true" private final boolean _containsSecondOrder; @@ -78,6 +84,7 @@ public class FunctionCallGraph _fCalls = new HashMap<>(); _fCallsSB = new HashMap<>(); _fRecursive = new HashSet<>(); + _fSideEffectFree = new HashSet<>(); _containsSecondOrder = constructFunctionCallGraph(prog); } @@ -92,6 +99,7 @@ public class FunctionCallGraph _fCalls = new HashMap<>(); _fCallsSB = new HashMap<>(); _fRecursive = new HashSet<>(); + _fSideEffectFree = new HashSet<>(); _containsSecondOrder = constructFunctionCallGraph(sb); } @@ -190,6 +198,33 @@ public class FunctionCallGraph } /** + * Indicates if the given function is side effect free, i.e., has no + * prints, no persistent write, and includes no or only calls to + * side-effect-free functions. + * + * @param fnamespace function namespace + * @param fname function name + * @return true if the given function is side-effect-free, false otherwise + */ + public boolean isSideEffectFreeFunction(String fnamespace, String fname) { + return isSideEffectFreeFunction( + DMLProgram.constructFunctionKey(fnamespace, fname)); + } + + /** + * Indicates if the given function is side effect free, i.e., has no + * prints, no persistent write, and includes no or only calls to + * side-effect-free functions. + * + * @param fkey function key of calling function, null indicates the main program + * @return true if the given function is side-effect-free, false otherwise + */ + public boolean isSideEffectFreeFunction(String fkey) { + String lfkey = (fkey == null) ? 
MAIN_FUNCTION_KEY : fkey; + return _fSideEffectFree.contains(lfkey); + } + + /** * Returns all functions that are reachable either directly or indirectly * form the main program, except the main program itself. * @@ -255,11 +290,18 @@ public class FunctionCallGraph boolean ret = false; try { + //construct the main function call graph Stack<String> fstack = new Stack<>(); HashSet<String> lfset = new HashSet<>(); _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>()); for( StatementBlock sblk : prog.getStatementBlocks() ) ret |= rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset); + + //analyze all non-recursive functions if free of side effects + _fSideEffectFree.addAll(_fCalls.keySet().stream() + .filter(s -> !s.startsWith(DMLProgram.INTERNAL_NAMESPACE)) + .filter(s -> isSideEffectFree(prog.getFunctionStatementBlock(s))) + .collect(Collectors.toList())); } catch(HopsException ex) { throw new RuntimeException(ex); @@ -371,4 +413,51 @@ public class FunctionCallGraph return ret; } + + private static boolean isSideEffectFree(FunctionStatementBlock fsb) { + //check for side-effect-free external functions (explicit annotation) + if( fsb.getStatement(0) instanceof ExternalFunctionStatement + && !((ExternalFunctionStatement)fsb.getStatement(0)).hasSideEffects() ) { + return true; + } + //check regular dml-bodied function for prints, pwrite, and other functions + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + for( StatementBlock csb : fstmt.getBody() ) + if( rHasSideEffects(csb) ) + return false; + return true; + } + + private static boolean rHasSideEffects(StatementBlock sb) { + boolean ret = false; + if( sb instanceof ForStatementBlock ) { + ForStatement fstmt = (ForStatement) sb.getStatement(0); + for( StatementBlock csb : fstmt.getBody() ) + ret |= rHasSideEffects(csb); + } + else if( sb instanceof WhileStatementBlock ) { + WhileStatement wstmt = (WhileStatement) sb.getStatement(0); + for( StatementBlock csb : wstmt.getBody() ) + ret 
|= rHasSideEffects(csb); + } + else if( sb instanceof IfStatementBlock ) { + IfStatement istmt = (IfStatement) sb.getStatement(0); + for( StatementBlock csb : istmt.getIfBody() ) + ret |= rHasSideEffects(csb); + if( istmt.getElseBody() != null ) + for( StatementBlock csb : istmt.getElseBody() ) + ret |= rHasSideEffects(csb); + } + else if( sb.getHops() != null ) { + //check for print, printf, pwrite, function calls, all of + //which can only appear as root nodes in the DAG + for( Hop root : sb.getHops() ) { + ret |= HopRewriteUtils.isUnary(root, OpOp1.PRINT) + || HopRewriteUtils.isNary(root, OpOpN.PRINTF) + || HopRewriteUtils.isData(root, DataOpTypes.PERSISTENTWRITE) + || root instanceof FunctionOp; + } + } + return ret; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java index 2fb4338..9ce79b0 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassEliminateDeadCode.java @@ -19,10 +19,12 @@ package org.apache.sysml.hops.ipa; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; +import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.rewrite.HopRewriteUtils; @@ -52,7 +54,7 @@ public class IPAPassEliminateDeadCode extends IPAPass @Override public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) { // step 1: backwards pass over main program to track used and remove unused vars - findAndRemoveDeadCode(prog.getStatementBlocks(), new HashSet<>()); + findAndRemoveDeadCode(prog.getStatementBlocks(), new 
HashSet<>(), fgraph); // step 2: backwards pass over functions to track used and remove unused vars for( FunctionStatementBlock fsb : prog.getFunctionStatementBlocks() ) { @@ -62,19 +64,19 @@ public class IPAPassEliminateDeadCode extends IPAPass fstmt.getOutputParams().stream().forEach(d -> usedVars.add(d.getName())); // backward pass over function to track used and remove unused vars - findAndRemoveDeadCode(fstmt.getBody(), usedVars); + findAndRemoveDeadCode(fstmt.getBody(), usedVars, fgraph); } } - private void findAndRemoveDeadCode(List<StatementBlock> sbs, Set<String> usedVars) { + private static void findAndRemoveDeadCode(List<StatementBlock> sbs, Set<String> usedVars, FunctionCallGraph fgraph) { for( int i=sbs.size()-1; i >= 0; i-- ) { // remove unused assignments if( HopRewriteUtils.isLastLevelStatementBlock(sbs.get(i)) ) { List<Hop> roots = sbs.get(i).getHops(); for( int j=0; j<roots.size(); j++ ) { Hop root = roots.get(j); - if( HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) - && !usedVars.contains(root.getName()) ) { + if( (HopRewriteUtils.isData(root, DataOpTypes.TRANSIENTWRITE) && !usedVars.contains(root.getName())) + || isFunctionCallWithUnusedOutputs(root, usedVars, fgraph) ) { roots.remove(j); j--; rRemoveOpFromDAG(root); } @@ -88,7 +90,14 @@ public class IPAPassEliminateDeadCode extends IPAPass } } - private void rRemoveOpFromDAG(Hop current) { + private static boolean isFunctionCallWithUnusedOutputs(Hop hop, Set<String> varNames, FunctionCallGraph fgraph) { + return hop instanceof FunctionOp + && fgraph.isSideEffectFreeFunction(((FunctionOp)hop).getFunctionKey()) + && Arrays.stream(((FunctionOp) hop).getOutputVariableNames()) + .allMatch(var -> !varNames.contains(var)); + } + + private static void rRemoveOpFromDAG(Hop current) { for( int i=0; i<current.getInput().size(); i++ ) { Hop c = current.getInput().get(i); HopRewriteUtils.removeChildReference(current, c); @@ -97,7 +106,7 @@ public class IPAPassEliminateDeadCode extends IPAPass } } 
- private Set<String> rCollectReadVariableNames(StatementBlock sb, Set<String> varNames) { + private static Set<String> rCollectReadVariableNames(StatementBlock sb, Set<String> varNames) { if( sb instanceof WhileStatementBlock ) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement) sb.getStatement(0); @@ -132,14 +141,14 @@ public class IPAPassEliminateDeadCode extends IPAPass return varNames; } - private Set<String> collectReadVariableNames(Hop hop, Set<String> varNames) { + private static Set<String> collectReadVariableNames(Hop hop, Set<String> varNames) { if( hop == null ) return varNames; hop.resetVisitStatus(); return rCollectReadVariableNames(hop, varNames); } - private Set<String> rCollectReadVariableNames(Hop hop, Set<String> varNames) { + private static Set<String> rCollectReadVariableNames(Hop hop, Set<String> varNames) { if( hop.isVisited() ) return varNames; for( Hop c : hop.getInput() ) http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java index 1516969..3afe992 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPADeadCodeEliminationTest.java @@ -31,6 +31,8 @@ public class IPADeadCodeEliminationTest extends AutomatedTestBase { private final static String TEST_NAME1 = "IPADeadCodeRemoval_Main"; private final static String TEST_NAME2 = "IPADeadCodeRemoval_Fun"; + private final static String TEST_NAME3 = "IPADeadCodeRemoval_Fun2"; + private final static String TEST_NAME4 = "IPADeadCodeRemoval_Fun3"; //w/ print private 
final static String TEST_DIR = "functions/misc/"; private final static String TEST_CLASS_DIR = TEST_DIR + IPADeadCodeEliminationTest.class.getSimpleName() + "/"; @@ -38,8 +40,10 @@ public class IPADeadCodeEliminationTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); - addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); } @Test @@ -61,6 +65,26 @@ public class IPADeadCodeEliminationTest extends AutomatedTestBase public void testDeadCodeRemovalFunIPA() { runIPALiteralReplacementTest( TEST_NAME2, true ); } + + @Test + public void testDeadCodeRemovalFun2NoIPA() { + runIPALiteralReplacementTest( TEST_NAME3, false ); + } + + @Test + public void testDeadCodeRemovalFun2IPA() { + runIPALiteralReplacementTest( TEST_NAME3, true ); + } + + @Test + public void testDeadCodeRemovalFun3NoIPA() { + runIPALiteralReplacementTest( TEST_NAME4, false ); + } + + @Test + public void testDeadCodeRemovalFun3IPA() { + runIPALiteralReplacementTest( TEST_NAME4, true ); + } private void runIPALiteralReplacementTest( String testname, boolean IPA ) { @@ -75,8 +99,10 @@ public class IPADeadCodeEliminationTest extends AutomatedTestBase OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA; runTest(true, false, null, -1); - if( IPA ) //check for applied dead code removal + if( IPA && !testname.equals(TEST_NAME4) ) //check for applied dead code removal 
Assert.assertTrue(!heavyHittersContainsString("uak+")); + if( testname.equals(TEST_NAME4) ) + Assert.assertTrue(heavyHittersContainsString("uak+")); } finally { OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA; http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml new file mode 100644 index 0000000..e243ed3 --- /dev/null +++ b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun2.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +foo = function(matrix[double] A) return (matrix[double] B) { + C = A + 7; + while(FALSE){} + B = C + sum(A + C); +} + +A = matrix(7, 10, 10); +B = foo(A); +if( 1 == 0 ) + print("sum = "+sum(B)); http://git-wip-us.apache.org/repos/asf/systemml/blob/ea6dc8c5/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml new file mode 100644 index 0000000..9fa20db --- /dev/null +++ b/src/test/scripts/functions/misc/IPADeadCodeRemoval_Fun3.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +foo = function(matrix[double] A) return (matrix[double] B) { + C = A + 7; + while(FALSE){} + B = C + sum(A + C); + print(sum(B)); +} + +A = matrix(7, 10, 10); +B = foo(A); +if( 1 == 0 ) + print("sum = "+sum(B));
