Repository: systemml Updated Branches: refs/heads/master 948943d17 -> ddcb9e019
[SYSTEMML-1691,1692] New IPA passes: literal replacement and rewrites This patch introduces two new passes for inter-procedural analysis (IPA): (1) literal propagation and replacement into functions, and (2) static rewrites, both of which are applied in every IPA iteration, for any number of requested iterations. The internal abstraction for function call summaries has been extended accordingly. The new literal propagation and replacement works at the fine granularity of individual function parameters and propagates any literals that are consistent across all function calls, independent of the remaining inputs. Together with the additional rewrites pass, this enables rewrites such as constant folding and subsequent removal of branches, which can significantly cut down the program size and the number of distributed operations. For example, for GLM poisson.log, this change reduced the size of the initial runtime program from 2132/153 to 1164/45 local/distributed instructions, making the program much easier to debug and profile. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1e6639c7 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1e6639c7 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1e6639c7 Branch: refs/heads/master Commit: 1e6639c754961f51bb53754c1fa8b6dce404294a Parents: 948943d Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Jun 15 20:01:12 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Jun 16 10:01:57 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/ipa/FunctionCallSizeInfo.java | 90 ++++++++++- .../java/org/apache/sysml/hops/ipa/IPAPass.java | 3 +- .../hops/ipa/IPAPassApplyStaticHopRewrites.java | 53 +++++++ .../ipa/IPAPassFlagFunctionsRecompileOnce.java | 2 +- .../ipa/IPAPassPropagateReplaceLiterals.java | 155 +++++++++++++++++++ .../ipa/IPAPassRemoveConstantBinaryOps.java | 2 +- 
.../IPAPassRemoveUnnecessaryCheckpoints.java | 2 +- .../hops/ipa/IPAPassRemoveUnusedFunctions.java | 2 +- .../sysml/hops/ipa/InterProceduralAnalysis.java | 104 +++++++------ .../org/apache/sysml/parser/DMLTranslator.java | 12 +- .../java/org/apache/sysml/utils/Explain.java | 5 +- ...antFoldingScalarVariablePropagationTest.java | 17 +- 12 files changed, 360 insertions(+), 87 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java index 20054a2..402e780 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; import org.apache.sysml.hops.FunctionOp; @@ -52,10 +53,16 @@ public class FunctionCallSizeInfo //to subsequent statement blocks and functions) private final Set<String> _fcandUnary; - //indicators for which function arguments it is safe to propagate nnz + //indicators for which function arguments of valid functions it + //is safe to propagate the number of non-zeros //(mapping from function keys to set of function input HopIDs) private final Map<String, Set<Long>> _fcandSafeNNZ; + //indicators which literal function arguments can be safely + //propagated into and replaced in the respective functions + //(mapping from function keys to set of function input positions) + private final Map<String, Set<Integer>> _fSafeLiterals; + /** * Constructs the function call summary for all functions * reachable from the main program. 
@@ -84,6 +91,7 @@ public class FunctionCallSizeInfo _fcand = new HashSet<String>(); _fcandUnary = new HashSet<String>(); _fcandSafeNNZ = new HashMap<String, Set<Long>>(); + _fSafeLiterals = new HashMap<String, Set<Integer>>(); constructFunctionCallSizeInfo(); } @@ -169,17 +177,44 @@ public class FunctionCallSizeInfo * * @param fkey function key * @param inputHopID hop ID of the input - * @return true if nnz can safely be propageted + * @return true if nnz can safely be propagated */ public boolean isSafeNnz(String fkey, long inputHopID) { return _fcandSafeNNZ.containsKey(fkey) && _fcandSafeNNZ.get(fkey).contains(inputHopID); } + /** + * Indicates if the given function has at least one input + * that allows for safe literal propagation and replacement, + * i.e., all function calls have consistent literal inputs. + * + * @param fkey function key + * @return true if a literal can be safely propagated + */ + public boolean hasSafeLiterals(String fkey) { + return _fSafeLiterals.containsKey(fkey) + && !_fSafeLiterals.get(fkey).isEmpty(); + } + + /** + * Indicates if the given function input allows for safe + * literal propagation and replacement, i.e., all function calls + * have consistent literal inputs. 
+ * + * @param fkey function key + * @param pos function input position + * @return true if literal that can be safely propagated + */ + public boolean isSafeLiteral(String fkey, int pos) { + return _fSafeLiterals.containsKey(fkey) + && _fSafeLiterals.get(fkey).contains(pos); + } + private void constructFunctionCallSizeInfo() throws HopsException { - //determine function candidates by evaluating all function calls + //step 1: determine function candidates by evaluating all function calls for( String fkey : _fgraph.getReachableFunctions() ) { List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey); @@ -215,7 +250,8 @@ public class FunctionCallSizeInfo } } - //determine safe nnz propagation per input + //step 2: determine safe nnz propagation per input + //(considered for valid functions only) for( String fkey : _fcand ) { FunctionOp first = _fgraph.getFunctionCalls(fkey).get(0); HashSet<Long> tmp = new HashSet<Long>(); @@ -227,13 +263,38 @@ public class FunctionCallSizeInfo } _fcandSafeNNZ.put(fkey, tmp); } + + //step 3: determine safe literal replacement per function input + //(considered for all functions) + for( String fkey : _fgraph.getReachableFunctions() ) { + List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey); + FunctionOp first = flist.get(0); + //initialize w/ all literals of first call + HashSet<Integer> tmp = new HashSet<Integer>(); + for( int j=0; j<first.getInput().size(); j++ ) + if( first.getInput().get(j) instanceof LiteralOp ) + tmp.add(j); + //check consistency across all function calls + for( int i=1; i<flist.size(); i++ ) { + FunctionOp other = flist.get(i); + for( int j=0; j<first.getInput().size(); j++ ) + if( tmp.contains(j) ) { + Hop h1 = first.getInput().get(j); + Hop h2 = other.getInput().get(j); + if( !(h2 instanceof LiteralOp && HopRewriteUtils + .isEqualValue((LiteralOp)h1, (LiteralOp)h2)) ) + tmp.remove(j); + } + } + _fSafeLiterals.put(fkey, tmp); + } } @Override public String toString() { StringBuilder sb = new StringBuilder(); - 
sb.append("Valid Functions for Propagation: \n"); + sb.append("Valid functions for propagation: \n"); for( String fkey : getValidFunctions() ) { sb.append("--"); sb.append(fkey); @@ -247,7 +308,7 @@ public class FunctionCallSizeInfo } if( !getInvalidFunctions().isEmpty() ) { - sb.append("Invaid Functions for Propagation: \n"); + sb.append("Invaid functions for propagation: \n"); for( String fkey : getInvalidFunctions() ) { sb.append("--"); sb.append(fkey); @@ -258,7 +319,7 @@ public class FunctionCallSizeInfo } if( !getDimsPreservingFunctions().isEmpty() ) { - sb.append("Dims-Preserving Functions: \n"); + sb.append("Dimensions-preserving functions: \n"); for( String fkey : getDimsPreservingFunctions() ) { sb.append("--"); sb.append(fkey); @@ -268,6 +329,21 @@ public class FunctionCallSizeInfo } } + sb.append("Valid scalars for propagation: \n"); + for( Entry<String, Set<Integer>> e : _fSafeLiterals.entrySet() ) { + sb.append("--"); + sb.append(e.getKey()); + sb.append(": "); + for( Integer pos : e.getValue() ) { + sb.append(pos); + sb.append(":"); + sb.append(_fgraph.getFunctionCalls(e.getKey()) + .get(0).getInput().get(pos).getName()); + sb.append(" "); + } + sb.append("\n"); + } + return sb.toString(); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java index cfd9df7..ced407e 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java @@ -46,8 +46,9 @@ public abstract class IPAPass * * @param prog dml program * @param fgraph function call graph + * @param fcallSizes function call size infos * @throws HopsException */ - public abstract void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) + public abstract void rewriteProgram( 
DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) throws HopsException; } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java new file mode 100644 index 0000000..f436658 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassApplyStaticHopRewrites.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.ipa; + + +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.rewrite.ProgramRewriter; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.LanguageException; + +/** + * This rewrite applies static hop dag and statement block + * rewrites such as constant folding and branch removal + * in order to simplify statistic propagation. 
+ * + */ +public class IPAPassApplyStaticHopRewrites extends IPAPass +{ + @Override + public boolean isApplicable() { + return InterProceduralAnalysis.APPLY_STATIC_REWRITES; + } + + @Override + public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) + throws HopsException + { + try { + ProgramRewriter rewriter = new ProgramRewriter(true, false); + rewriter.rewriteProgramHopDAGs(prog); + } + catch (LanguageException ex) { + throw new HopsException(ex); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java index ee072e4..82f4681 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java @@ -48,7 +48,7 @@ public class IPAPassFlagFunctionsRecompileOnce extends IPAPass } @Override - public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) + public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) throws HopsException { try { http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java new file mode 100644 index 0000000..57647ff --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassPropagateReplaceLiterals.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.ipa; + +import java.util.ArrayList; + +import org.apache.sysml.hops.FunctionOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DataIdentifier; +import org.apache.sysml.parser.ForStatement; +import org.apache.sysml.parser.ForStatementBlock; +import org.apache.sysml.parser.FunctionStatement; +import org.apache.sysml.parser.FunctionStatementBlock; +import org.apache.sysml.parser.IfStatement; +import org.apache.sysml.parser.IfStatementBlock; +import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.WhileStatement; +import org.apache.sysml.parser.WhileStatementBlock; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; + +/** + * This rewrite propagates and replaces literals into functions + * in order to enable subsequent rewrites such as branch removal. 
+ * + */ +public class IPAPassPropagateReplaceLiterals extends IPAPass +{ + @Override + public boolean isApplicable() { + return InterProceduralAnalysis.PROPAGATE_SCALAR_LITERALS; + } + + @Override + public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) + throws HopsException + { + for( String fkey : fgraph.getReachableFunctions() ) { + FunctionOp first = fgraph.getFunctionCalls(fkey).get(0); + + //propagate and replace amenable literals into function + if( fcallSizes.hasSafeLiterals(fkey) ) { + FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey); + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + ArrayList<DataIdentifier> finputs = fstmt.getInputParams(); + + //populate call vars with amenable literals + LocalVariableMap callVars = new LocalVariableMap(); + for( int j=0; j<finputs.size(); j++ ) + if( fcallSizes.isSafeLiteral(fkey, j) ) { + LiteralOp lit = (LiteralOp) first.getInput().get(j); + callVars.put(finputs.get(j).getName(), ScalarObjectFactory + .createScalarObject(lit.getValueType(), lit)); + } + + //propagate and replace literals + for( StatementBlock sb : fstmt.getBody() ) + rReplaceLiterals(sb, callVars); + } + } + } + + private void rReplaceLiterals(StatementBlock sb, LocalVariableMap constants) + throws HopsException + { + //remove updated literals + for( String varname : sb.variablesUpdated().getVariableNames() ) + if( constants.keySet().contains(varname) ) + constants.remove(varname); + + //propagate and replace literals + if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement ws = (WhileStatement)sb.getStatement(0); + replaceLiterals(wsb.getPredicateHops(), constants); + for (StatementBlock current : ws.getBody()) + rReplaceLiterals(current, constants); + } + else if (sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement ifs = (IfStatement) sb.getStatement(0); + 
replaceLiterals(isb.getPredicateHops(), constants); + for (StatementBlock current : ifs.getIfBody()) + rReplaceLiterals(current, constants); + for (StatementBlock current : ifs.getElseBody()) + rReplaceLiterals(current, constants); + } + else if (sb instanceof ForStatementBlock) { + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fs = (ForStatement)sb.getStatement(0); + replaceLiterals(fsb.getFromHops(), constants); + replaceLiterals(fsb.getToHops(), constants); + replaceLiterals(fsb.getIncrementHops(), constants); + for (StatementBlock current : fs.getBody()) + rReplaceLiterals(current, constants); + } + else { + replaceLiterals(sb.get_hops(), constants); + } + } + + private void replaceLiterals(ArrayList<Hop> roots, LocalVariableMap constants) + throws HopsException + { + if( roots == null ) + return; + + try { + Hop.resetVisitStatus(roots); + for( Hop root : roots ) + Recompiler.rReplaceLiterals(root, constants, true); + Hop.resetVisitStatus(roots); + } + catch(Exception ex) { + throw new HopsException(ex); + } + } + + private void replaceLiterals(Hop root, LocalVariableMap constants) + throws HopsException + { + if( root == null ) + return; + + try { + root.resetVisitStatus(); + Recompiler.rReplaceLiterals(root, constants, true); + root.resetVisitStatus(); + } + catch(Exception ex) { + throw new HopsException(ex); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java index c71ed45..1a433a3 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java @@ -57,7 +57,7 @@ public class IPAPassRemoveConstantBinaryOps 
extends IPAPass } @Override - public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) + public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) throws HopsException { //approach: scan over top-level program (guaranteed to be unconditional), http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java index 20c47da..664ec2a 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java @@ -56,7 +56,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends IPAPass } @Override - public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) + public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) throws HopsException { //remove unnecessary checkpoint before update http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java index 3424a52..9d41ca6 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java +++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java @@ -43,7 +43,7 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass } @Override - public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph ) + public void rewriteProgram( 
DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) throws HopsException { try { http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index 1d997ed..7d371ac 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@ -94,6 +94,8 @@ public class InterProceduralAnalysis protected static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; //remove unnecessary checkpoints (unconditionally overwritten intermediates) protected static final boolean REMOVE_CONSTANT_BINARY_OPS = true; //remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) protected static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; //propagate scalar variables into functions that are called once + protected static final boolean PROPAGATE_SCALAR_LITERALS = true; //propagate and replace scalar literals into functions + protected static final boolean APPLY_STATIC_REWRITES = true; //apply static hop dag and statement block rewrites static { // for internal debugging only @@ -132,6 +134,8 @@ public class InterProceduralAnalysis _passes.add(new IPAPassFlagFunctionsRecompileOnce()); _passes.add(new IPAPassRemoveUnnecessaryCheckpoints()); _passes.add(new IPAPassRemoveConstantBinaryOps()); + _passes.add(new IPAPassPropagateReplaceLiterals()); + _passes.add(new IPAPassApplyStaticHopRewrites()); } public InterProceduralAnalysis(StatementBlock sb) { @@ -145,39 +149,64 @@ public class InterProceduralAnalysis } /** - * Public interface to perform IPA over a given DML program. + * Main interface to perform IPA over a given DML program. 
* - * @param dmlp the dml program - * @throws HopsException if HopsException occurs + * @throws HopsException in case of compilation errors */ - public void analyzeProgram() - throws HopsException + public void analyzeProgram() throws HopsException { + analyzeProgram(1); //single run + } + + /** + * Main interface to perform IPA over a given DML program. + * + * @param repetitions number of IPA rounds + * @throws HopsException in case of compilation errors + */ + public void analyzeProgram(int repetitions) + throws HopsException { - //step 1: intra- and inter-procedural - if( INTRA_PROCEDURAL_ANALYSIS ) { + //sanity check for valid number of repetitions + if( repetitions <= 0 ) + throw new HopsException("Invalid number of IPA repetitions: " + repetitions); + + //perform number of requested IPA iterations + for( int i=0; i<repetitions; i++ ) { + if( LOG.isDebugEnabled() ) + LOG.debug("IPA: start IPA iteration " + (i+1) + "/" + repetitions +"."); + //get function call size infos to obtain candidates for statistics propagation FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph); if( LOG.isDebugEnabled() ) LOG.debug("IPA: Initial FunctionCallSummary: \n" + fcallSizes); - //get unary dimension-preserving non-candidate functions - for( String tmp : fcallSizes.getInvalidFunctions() ) - if( isUnarySizePreservingFunction(_prog.getFunctionStatementBlock(tmp)) ) - fcallSizes.addDimsPreservingFunction(tmp); - if( LOG.isDebugEnabled() ) - LOG.debug("IPA: Extended FunctionCallSummary: \n" + fcallSizes); + //step 1: intra- and inter-procedural + if( INTRA_PROCEDURAL_ANALYSIS ) { + //get unary dimension-preserving non-candidate functions + for( String tmp : fcallSizes.getInvalidFunctions() ) + if( isUnarySizePreservingFunction(_prog.getFunctionStatementBlock(tmp)) ) + fcallSizes.addDimsPreservingFunction(tmp); + if( LOG.isDebugEnabled() ) + LOG.debug("IPA: Extended FunctionCallSummary: \n" + fcallSizes); + + //propagate statistics and scalars into functions and 
across DAGs + //(callVars used to chain outputs/inputs of multiple functions calls) + LocalVariableMap callVars = new LocalVariableMap(); + for ( StatementBlock sb : _prog.getStatementBlocks() ) //propagate stats into candidates + propagateStatisticsAcrossBlock( sb, callVars, fcallSizes, new HashSet<String>() ); + } - //propagate statistics and scalars into functions and across DAGs - //(callVars used to chain outputs/inputs of multiple functions calls) - LocalVariableMap callVars = new LocalVariableMap(); - for ( StatementBlock sb : _prog.getStatementBlocks() ) //propagate stats into candidates - propagateStatisticsAcrossBlock( sb, callVars, fcallSizes, new HashSet<String>() ); + //step 2: apply additional IPA passes + for( IPAPass pass : _passes ) + if( pass.isApplicable() ) + pass.rewriteProgram(_prog, _fgraph, fcallSizes); } - //step 2: apply additional IPA passes - for( IPAPass pass : _passes ) - if( pass.isApplicable() ) - pass.rewriteProgram(_prog, _fgraph); + //cleanup pass: remove unused functions + FunctionCallGraph graph2 = new FunctionCallGraph(_prog); + IPAPass rmFuns = new IPAPassRemoveUnusedFunctions(); + if( rmFuns.isApplicable() ) + rmFuns.rewriteProgram(_prog, graph2, null); } public Set<String> analyzeSubProgram() @@ -240,19 +269,6 @@ public class InterProceduralAnalysis // INTRA-PROCEDURE ANALYSIS ////// - /** - * Perform intra-procedural analysis (IPA) by propagating statistics - * across statement blocks. - * - * @param sb DML statement blocks. - * @param fcand Function candidates. - * @param callVars Map of variables eligible for propagation. - * @param fcandSafeNNZ Function candidate safe non-zeros. - * @param unaryFcands Unary function candidates. - * @param fnStack Function stack to determine current scope. - * @throws HopsException If a HopsException occurs. - * @throws ParseException If a ParseException occurs. 
- */ private void propagateStatisticsAcrossBlock( StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack ) throws HopsException { @@ -421,16 +437,12 @@ public class InterProceduralAnalysis * * @param prog The DML program. * @param roots List of HOP DAG root notes for propagation. - * @param fcand Function candidates. - * @param callVars Calling program's map of variables eligible for - * propagation. - * @param fcandSafeNNZ Function candidate safe non-zeros. - * @param unaryFcands Unary function candidates. + * @param callVars Calling program's map of variables eligible for propagation. + * @param fcallSizes function call summary * @param fnStack Function stack to determine current scope. * @throws HopsException If a HopsException occurs. - * @throws ParseException If a ParseException occurs. */ - private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack ) + private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack) throws HopsException { for( Hop root : roots ) @@ -443,14 +455,10 @@ public class InterProceduralAnalysis * * @param prog The DML program. * @param hop HOP to propagate statistics into. - * @param fcand Function candidates. - * @param callVars Calling program's map of variables eligible for - * propagation. - * @param fcandSafeNNZ Function candidate safe non-zeros. - * @param unaryFcands Unary function candidates. + * @param callVars Calling program's map of variables eligible for propagation. + * @param fcallSizes function call summary * @param fnStack Function stack to determine current scope. * @throws HopsException If a HopsException occurs. - * @throws ParseException If a ParseException occurs. 
*/ private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack ) throws HopsException http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 47446f6..42ab12e 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -266,18 +266,8 @@ public class DMLTranslator //propagate size information from main into functions (but conservatively) if( OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS ) { InterProceduralAnalysis ipa = new InterProceduralAnalysis(dmlp); - ipa.analyzeProgram(); + ipa.analyzeProgram(OptimizerUtils.ALLOW_IPA_SECOND_CHANCE ? 2 : 1); resetHopsDAGVisitStatus(dmlp); - if (OptimizerUtils.ALLOW_IPA_SECOND_CHANCE) { - // SECOND CHANCE: - // Rerun static rewrites + IPA to allow for further improvements, such as making use - // of constant folding (static rewrite) after scalar -> literal replacement (IPA), - // and then further scalar -> literal replacement (IPA). 
- rewriter.rewriteProgramHopDAGs(dmlp); - resetHopsDAGVisitStatus(dmlp); - ipa.analyzeProgram(); - resetHopsDAGVisitStatus(dmlp); - } } //apply hop rewrites (dynamic rewrites, after IPA) http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index af7102b..5cf0548 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -229,11 +229,12 @@ public class Explain for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) { FunctionStatementBlock fsb = prog.getFunctionStatementBlock(namespace, fname); FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + String fkey = DMLProgram.constructFunctionKey(namespace, fname); if (fstmt instanceof ExternalFunctionStatement) - sb.append("----EXTERNAL FUNCTION " + namespace + "::" + fname + "\n"); + sb.append("----EXTERNAL FUNCTION " + fkey + "\n"); else { - sb.append("----FUNCTION " + namespace + "::" + fname + " [recompile="+fsb.isRecompileOnce()+"]\n"); + sb.append("----FUNCTION " + fkey + " [recompile="+fsb.isRecompileOnce()+"]\n"); for (StatementBlock current : fstmt.getBody()) sb.append(explainStatementBlock(current, 3)); } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e6639c7/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java index a73fe5b..e4ff6c3 100644 --- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAConstantFoldingScalarVariablePropagationTest.java @@ -110,20 +110,9 @@ public class IPAConstantFoldingScalarVariablePropagationTest extends AutomatedTe runTest(true, false, null, -1); // Check for correct number of compiled & executed Spark jobs - if (IPA_SECOND_CHANCE) { - // No distributed instructions compiled/executed with second chance enabled - checkNumCompiledSparkInst(0); - checkNumExecutedSparkInst(0); - } else { - // without second chance enabled, distributed jobs will be compiled/executed - if (testname == TEST_NAME1) { - checkNumCompiledSparkInst(2); - checkNumExecutedSparkInst(1); - } else { //if (testname == TEST_NAME2) { - checkNumCompiledSparkInst(1); - checkNumExecutedSparkInst(0); - } - } + // (MB: originally, this required a second chance, but not anymore) + checkNumCompiledSparkInst(0); + checkNumExecutedSparkInst(0); } finally { // Reset