Repository: systemml Updated Branches: refs/heads/master 717831daf -> 009561384
[SYSTEMML-1690] New IPA function call summary for size information This patch introduces a new IPA abstraction called FunctionCallSizeInfo to hold information about function call summaries such as number of function calls, valid functions for statistics propagation, information about consistent dimensions and sparsity across function calls, and dimension-preserving functions. The benefits are (1) a simplified IPA by separating concerns into candidate selection and actual size propagation, (2) avoiding unnecessary passes over the program by leveraging the existing function call graph, and (3) better debugging capabilities in terms of concise internal explain functionality. Furthermore, this is the foundation for fine-grained constant propagation and literal replacement. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/00956138 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/00956138 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/00956138 Branch: refs/heads/master Commit: 009561384e117b34fb34729ea0c59b07fd8de52c Parents: 717831d Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Jun 15 00:55:28 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu Jun 15 00:55:28 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/ipa/FunctionCallGraph.java | 32 +- .../sysml/hops/ipa/FunctionCallSizeInfo.java | 273 +++++++++++++++ .../sysml/hops/ipa/InterProceduralAnalysis.java | 342 ++++--------------- ...IPAPropagationSizeMultipleFunctionsTest.java | 50 +-- 4 files changed, 369 insertions(+), 328 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/00956138/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java index 9e55eaa..4a630c0 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java @@ -20,11 +20,13 @@ package org.apache.sysml.hops.ipa; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; +import java.util.Set; import java.util.Stack; +import java.util.stream.Collectors; import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; @@ -90,9 +92,9 @@ public class FunctionCallGraph * * @param fnamespace function namespace * @param fname function name - * @return list of function keys (namespace and name) + * @return set of function keys (namespace and name) */ - public Collection<String> getCalledFunctions(String fnamespace, String fname) { + public Set<String> getCalledFunctions(String fnamespace, String fname) { return getCalledFunctions( DMLProgram.constructFunctionKey(fnamespace, fname)); } @@ -101,9 +103,9 @@ public class FunctionCallGraph * Returns all functions called from the given function. * * @param fkey function key of calling function, null indicates the main program - * @return list of function keys (namespace and name) + * @return set of function keys (namespace and name) */ - public Collection<String> getCalledFunctions(String fkey) { + public Set<String> getCalledFunctions(String fkey) { String lfkey = (fkey == null) ? 
MAIN_FUNCTION_KEY : fkey; return _fGraph.get(lfkey); } @@ -115,7 +117,7 @@ public class FunctionCallGraph * null indicates the main program and returns an empty list * @return list of function call hops */ - public Collection<FunctionOp> getFunctionCalls(String fkey) { + public List<FunctionOp> getFunctionCalls(String fkey) { //main program cannot have function calls if( fkey == null ) return Collections.emptyList(); @@ -153,10 +155,10 @@ public class FunctionCallGraph * Returns all functions that are reachable either directly or indirectly * form the main program, except the main program itself. * - * @return list of function keys (namespace and name) + * @return set of function keys (namespace and name) */ - public Collection<String> getReachableFunctions() { - return getReachableFunctions(Collections.emptyList()); + public Set<String> getReachableFunctions() { + return getReachableFunctions(Collections.emptySet()); } /** @@ -165,14 +167,12 @@ public class FunctionCallGraph * blacklist of function names. 
* * @param blacklist list of function keys to exclude - * @return list of function keys (namespace and name) + * @return set of function keys (namespace and name) */ - public Collection<String> getReachableFunctions(Collection<String> blacklist) { - HashSet<String> ret = new HashSet<String>(); - for( String tmp : _fGraph.keySet() ) - if( !blacklist.contains(tmp) && !MAIN_FUNCTION_KEY.equals(tmp) ) - ret.add(tmp); - return ret; + public Set<String> getReachableFunctions(Set<String> blacklist) { + return _fGraph.keySet().stream() + .filter(p -> !blacklist.contains(p) && !MAIN_FUNCTION_KEY.equals(p)) + .collect(Collectors.toSet()); } /** http://git-wip-us.apache.org/repos/asf/systemml/blob/00956138/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java new file mode 100644 index 0000000..20054a2 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.ipa; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.sysml.hops.FunctionOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.rewrite.HopRewriteUtils; + +/** + * Auxiliary data structure to hold function call summaries in terms + * of information about number of function calls, consistent dimensions, + * consistent sparsity, and dimension-preserving functions. + * + */ +public class FunctionCallSizeInfo +{ + //basic function call graph to obtain size information + private final FunctionCallGraph _fgraph; + + //functions that are subject to size propagation + //(called once or multiple times with consistent sizes) + private final Set<String> _fcand; + + //functions that are not subject to size propagation + //but preserve the dimensions (used to propagate inputs + //to subsequent statement blocks and functions) + private final Set<String> _fcandUnary; + + //indicators for which function arguments it is safe to propagate nnz + //(mapping from function keys to set of function input HopIDs) + private final Map<String, Set<Long>> _fcandSafeNNZ; + + /** + * Constructs the function call summary for all functions + * reachable from the main program. + * + * @param fgraph function call graph + * @throws HopsException + */ + public FunctionCallSizeInfo(FunctionCallGraph fgraph) + throws HopsException + { + this(fgraph, true); + } + + /** + * Constructs the function call summary for all functions + * reachable from the main program. 
+ * + * @param fgraph function call graph + * @param init if true, compute the function call summary; if false, create an empty summary (no valid functions) + * @throws HopsException + */ + public FunctionCallSizeInfo(FunctionCallGraph fgraph, boolean init) + throws HopsException + { + _fgraph = fgraph; + _fcand = new HashSet<String>(); + _fcandUnary = new HashSet<String>(); + _fcandSafeNNZ = new HashMap<String, Set<Long>>(); + if( init ) //otherwise keep an empty summary without valid functions + constructFunctionCallSizeInfo(); + } + + /** + * Gets the number of function calls to a given function. + * + * @param fkey function key + * @return number of function calls + */ + public int getFunctionCallCount(String fkey) { + return _fgraph.getFunctionCalls(fkey).size(); + } + + /** + * Indicates if the given function is valid for statistics + * propagation. + * + * @param fkey function key + * @return true if valid + */ + public boolean isValidFunction(String fkey) { + return _fcand.contains(fkey); + } + + /** + * Gets the set of functions that are valid for statistics + * propagation. + * + * @return set of function keys + */ + public Set<String> getValidFunctions() { + return _fcand; + } + + /** + * Gets the set of functions that are invalid for statistics + * propagation. This is literally the set of reachable + * functions minus the set of valid functions. + * + * @return set of function keys. + */ + public Set<String> getInvalidFunctions() { + return _fgraph.getReachableFunctions(getValidFunctions()); + } + + /** + * Adds a function to the set of dimension-preserving + * functions. + * + * @param fkey function key + */ + public void addDimsPreservingFunction(String fkey) { + _fcandUnary.add(fkey); + } + + /** + * Gets the set of dimension-preserving functions, i.e., + * functions with one matrix input and output of equal + * dimension sizes. + * + * @return set of function keys + */ + public Set<String> getDimsPreservingFunctions() { + return _fcandUnary; + } + + /** + * Indicates if the given function belongs to the set + * of dimension-preserving functions.
+ * + * @param fkey function key + * @return true if the function is dimension-preserving + */ + public boolean isDimsPreservingFunction(String fkey) { + return _fcandUnary.contains(fkey); + } + + /** + * Indicates if the given function input allows for safe + * nnz propagation, i.e., all function calls have a consistent + * number of non-zeros. + * + * @param fkey function key + * @param inputHopID hop ID of the input + * @return true if nnz can safely be propagated + */ + public boolean isSafeNnz(String fkey, long inputHopID) { + return _fcandSafeNNZ.containsKey(fkey) + && _fcandSafeNNZ.get(fkey).contains(inputHopID); + } + + private void constructFunctionCallSizeInfo() + throws HopsException + { + //determine function candidates by evaluating all function calls + for( String fkey : _fgraph.getReachableFunctions() ) { + List<FunctionOp> flist = _fgraph.getFunctionCalls(fkey); + + //condition 1: function called just once + if( flist.size() == 1 ) { + _fcand.add(fkey); + } + //condition 2: check for consistent input sizes + else if( InterProceduralAnalysis.ALLOW_MULTIPLE_FUNCTION_CALLS ) { + //compare input matrix characteristics of first against all other calls + FunctionOp first = flist.get(0); + boolean consistent = true; + for( int i=1; i<flist.size(); i++ ) { + FunctionOp other = flist.get(i); + for( int j=0; j<first.getInput().size(); j++ ) { + Hop h1 = first.getInput().get(j); + Hop h2 = other.getInput().get(j); + //check matrix and scalar sizes (if known dims, nnz known/unknown, + // safeness of nnz propagation, determined later per input) + consistent &= (h1.dimsKnown() && h2.dimsKnown() + && h1.getDim1()==h2.getDim1() + && h1.getDim2()==h2.getDim2() + && h1.getNnz()==h2.getNnz() ); + //check literal values (must have equal values across calls) + if( h1 instanceof LiteralOp ){ + consistent &= (h2 instanceof LiteralOp + && HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2)); + } + } + } + if( consistent ) + _fcand.add(fkey); + } + } + + //determine safe nnz propagation per
input + for( String fkey : _fcand ) { + FunctionOp first = _fgraph.getFunctionCalls(fkey).get(0); + HashSet<Long> tmp = new HashSet<Long>(); + for( Hop input : first.getInput() ) { + //if nnz known it is safe to propagate those nnz because for multiple calls + //we checked for equivalence and hence all calls have the same nnz + if( input.getNnz()>=0 ) + tmp.add(input.getHopID()); + } + _fcandSafeNNZ.put(fkey, tmp); + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + sb.append("Valid Functions for Propagation: \n"); + for( String fkey : getValidFunctions() ) { + sb.append("--"); + sb.append(fkey); + sb.append(": "); + sb.append(getFunctionCallCount(fkey)); + if( !_fcandSafeNNZ.get(fkey).isEmpty() ) { + sb.append("\n----"); + sb.append(Arrays.toString(_fcandSafeNNZ.get(fkey).toArray(new Long[0]))); + } + sb.append("\n"); + } + + if( !getInvalidFunctions().isEmpty() ) { + sb.append("Invalid Functions for Propagation: \n"); + for( String fkey : getInvalidFunctions() ) { + sb.append("--"); + sb.append(fkey); + sb.append(": "); + sb.append(getFunctionCallCount(fkey)); + sb.append("\n"); + } + } + + if( !getDimsPreservingFunctions().isEmpty() ) { + sb.append("Dims-Preserving Functions: \n"); + for( String fkey : getDimsPreservingFunctions() ) { + sb.append("--"); + sb.append(fkey); + sb.append(": "); + sb.append(getFunctionCallCount(fkey)); + sb.append("\n"); + } + } + + return sb.toString(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/00956138/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index 0602208..1d997ed 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@
-20,11 +20,7 @@ package org.apache.sysml.hops.ipa; import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; -import java.util.Map; -import java.util.Map.Entry; import java.util.Set; import org.apache.commons.logging.Log; @@ -39,7 +35,6 @@ import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.LiteralOp; -import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.DMLTranslator; @@ -53,7 +48,6 @@ import org.apache.sysml.parser.FunctionStatement; import org.apache.sysml.parser.FunctionStatementBlock; import org.apache.sysml.parser.IfStatement; import org.apache.sysml.parser.IfStatementBlock; -import org.apache.sysml.parser.LanguageException; import org.apache.sysml.parser.ParseException; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.WhileStatement; @@ -100,7 +94,6 @@ public class InterProceduralAnalysis protected static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; //remove unnecessary checkpoints (unconditionally overwritten intermediates) protected static final boolean REMOVE_CONSTANT_BINARY_OPS = true; //remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) protected static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; //propagate scalar variables into functions that are called once - public static boolean UNARY_DIMS_PRESERVING_FUNS = true; //determine and exploit unary dimension preserving functions static { // for internal debugging only @@ -123,9 +116,8 @@ public class InterProceduralAnalysis * Creates a handle for performing inter-procedural analysis * for a given DML program and its associated HOP DAGs. 
This * call initializes various internal information such as the - * function call graph and auxiliary function call information - * which can be reused across multiple IPA calls (e.g., for - * second chance analysis). + * function call graph which can be reused across multiple IPA + * calls (e.g., for second chance analysis). * */ public InterProceduralAnalysis(DMLProgram dmlp) { @@ -157,45 +149,32 @@ public class InterProceduralAnalysis * * @param dmlp the dml program * @throws HopsException if HopsException occurs - * @throws ParseException if ParseException occurs - * @throws LanguageException if LanguageException occurs */ public void analyzeProgram() - throws HopsException, ParseException, LanguageException + throws HopsException { - //TODO move main IPA into separate IPA pass for size propagation - //together with rework of candidate selection - - //step 1: get candidates for statistics propagation into functions (if required) - Map<String, Integer> fcandCounts = new HashMap<String, Integer>(); - Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>(); - Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>(); - if( !_prog.getFunctionStatementBlocks().isEmpty() ) { - for ( StatementBlock sb : _prog.getStatementBlocks() ) //get candidates (over entire program) - getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops ); - pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops ); - determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ ); - DMLTranslator.resetHopsDAGVisitStatus( _prog ); - } - - //step 2: get unary dimension-preserving non-candidate functions - Collection<String> unaryFcandTmp = _fgraph.getReachableFunctions(fcandCounts.keySet()); - HashSet<String> unaryFcands = new HashSet<String>(); - if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) { - for( String tmp : unaryFcandTmp ) + //step 1: intra- and inter-procedural + if( INTRA_PROCEDURAL_ANALYSIS ) { + //get 
function call size infos to obtain candidates for statistics propagation + FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph); + if( LOG.isDebugEnabled() ) + LOG.debug("IPA: Initial FunctionCallSummary: \n" + fcallSizes); + + //get unary dimension-preserving non-candidate functions + for( String tmp : fcallSizes.getInvalidFunctions() ) if( isUnarySizePreservingFunction(_prog.getFunctionStatementBlock(tmp)) ) - unaryFcands.add(tmp); - } - - //step 3: propagate statistics and scalars into functions and across DAGs - if( !fcandCounts.isEmpty() || INTRA_PROCEDURAL_ANALYSIS ) { - //(callVars used to chain outputs/inputs of multiple functions calls) + fcallSizes.addDimsPreservingFunction(tmp); + if( LOG.isDebugEnabled() ) + LOG.debug("IPA: Extended FunctionCallSummary: \n" + fcallSizes); + + //propagate statistics and scalars into functions and across DAGs + //(callVars used to chain outputs/inputs of multiple functions calls) LocalVariableMap callVars = new LocalVariableMap(); for ( StatementBlock sb : _prog.getStatementBlocks() ) //propagate stats into candidates - propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, unaryFcands, new HashSet<String>() ); + propagateStatisticsAcrossBlock( sb, callVars, fcallSizes, new HashSet<String>() ); } - //step 4: apply additional IPA passes + //step 2: apply additional IPA passes for( IPAPass pass : _passes ) if( pass.isApplicable() ) pass.rewriteProgram(_prog, _fgraph); @@ -206,183 +185,34 @@ public class InterProceduralAnalysis { DMLTranslator.resetHopsDAGVisitStatus(_sb); - //step 1: get candidates for statistics propagation into functions (if required) - Map<String, Integer> fcandCounts = new HashMap<String, Integer>(); - Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>(); - Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>(); - Set<String> allFCandKeys = new HashSet<String>(); - getFunctionCandidatesForStatisticPropagation( _sb, fcandCounts, fcandHops 
); - allFCandKeys.addAll(fcandCounts.keySet()); //cp before pruning - pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops ); - determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ ); - DMLTranslator.resetHopsDAGVisitStatus( _sb ); + //get function call size infos to obtain candidates for statistics propagation + FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph); - if( !fcandCounts.isEmpty() ) { - //step 2: propagate statistics into functions and across DAGs - //(callVars used to chain outputs/inputs of multiple functions calls) + //propagate statistics and scalars into functions and across DAGs + //(callVars used to chain outputs/inputs of multiple functions calls) + if( !fcallSizes.getValidFunctions().isEmpty() ) { LocalVariableMap callVars = new LocalVariableMap(); - propagateStatisticsAcrossBlock( _sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>(), new HashSet<String>() ); + propagateStatisticsAcrossBlock( _sb, callVars, fcallSizes, new HashSet<String>() ); } - return fcandCounts.keySet(); - } - - - ///////////////////////////// - // GET FUNCTION CANDIDATES - ////// - - private void getFunctionCandidatesForStatisticPropagation( StatementBlock sb, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops ) - throws HopsException, ParseException - { - if (sb instanceof FunctionStatementBlock) - { - FunctionStatementBlock fsb = (FunctionStatementBlock)sb; - FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - for (StatementBlock sbi : fstmt.getBody()) - getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); - } - else if (sb instanceof WhileStatementBlock) - { - WhileStatementBlock wsb = (WhileStatementBlock) sb; - WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - for (StatementBlock sbi : wstmt.getBody()) - getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); - } - else if (sb instanceof IfStatementBlock) - { - 
IfStatementBlock isb = (IfStatementBlock) sb; - IfStatement istmt = (IfStatement)isb.getStatement(0); - for (StatementBlock sbi : istmt.getIfBody()) - getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); - for (StatementBlock sbi : istmt.getElseBody()) - getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); - } - else if (sb instanceof ForStatementBlock) //incl parfor - { - ForStatementBlock fsb = (ForStatementBlock) sb; - ForStatement fstmt = (ForStatement)fsb.getStatement(0); - for (StatementBlock sbi : fstmt.getBody()) - getFunctionCandidatesForStatisticPropagation(sbi, fcandCounts, fcandHops); - } - else //generic (last-level) - { - ArrayList<Hop> roots = sb.get_hops(); - if( roots != null ) //empty statement blocks - for( Hop root : roots ) - getFunctionCandidatesForStatisticPropagation(sb.getDMLProg(), root, fcandCounts, fcandHops); - } - } - - private void getFunctionCandidatesForStatisticPropagation(DMLProgram prog, Hop hop, Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops ) - throws HopsException, ParseException - { - if( hop.isVisited() ) - return; - - if( hop instanceof FunctionOp && !((FunctionOp)hop).getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) ) - { - //maintain counters and investigate functions if not seen so far - FunctionOp fop = (FunctionOp) hop; - String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); - - if( fcandCounts.containsKey(fkey) ) { - if( ALLOW_MULTIPLE_FUNCTION_CALLS ) - { - //compare input matrix characteristics for both function calls - //(if unknown or difference: maintain counter - this function is no candidate) - boolean consistent = true; - FunctionOp efop = fcandHops.get(fkey); - int numInputs = efop.getInput().size(); - for( int i=0; i<numInputs; i++ ) - { - Hop h1 = efop.getInput().get(i); - Hop h2 = fop.getInput().get(i); - //check matrix and scalar sizes (if known dims, nnz known/unknown, - // safeness of 
nnz propagation, determined later per input) - consistent &= (h1.dimsKnown() && h2.dimsKnown() - && h1.getDim1()==h2.getDim1() - && h1.getDim2()==h2.getDim2() - && h1.getNnz()==h2.getNnz() ); - //check literal values (equi value) - if( h1 instanceof LiteralOp ){ - consistent &= (h2 instanceof LiteralOp - && HopRewriteUtils.isEqualValue((LiteralOp)h1, (LiteralOp)h2)); - } - - - } - - if( !consistent ) //if differences, do not propagate - fcandCounts.put(fkey, fcandCounts.get(fkey)+1); - } - else - { - //maintain counter (this function is no candidate) - fcandCounts.put(fkey, fcandCounts.get(fkey)+1); - } - } - else { //first appearance - fcandCounts.put(fkey, 1); //create a new count entry - fcandHops.put(fkey, fop); //keep the function call hop - FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); - getFunctionCandidatesForStatisticPropagation(fsb, fcandCounts, fcandHops); - } - } - - for( Hop c : hop.getInput() ) - getFunctionCandidatesForStatisticPropagation(prog, c, fcandCounts, fcandHops); - - hop.setVisited(); - } - - private void pruneFunctionCandidatesForStatisticPropagation(Map<String, Integer> fcandCounts, Map<String, FunctionOp> fcandHops) - { - //debug input - if( LOG.isDebugEnabled() ) - for( Entry<String,Integer> e : fcandCounts.entrySet() ) - { - String key = e.getKey(); - Integer count = e.getValue(); - LOG.debug("IPA: FUNC statistic propagation candidate: "+key+", callCount="+count); - } - - //materialize key set - Set<String> tmp = new HashSet<String>(fcandCounts.keySet()); - - //check and prune candidate list - for( String key : tmp ) - { - Integer cnt = fcandCounts.get(key); - if( cnt != null && cnt > 1 ) //if multiple refs - fcandCounts.remove(key); - } - - //debug output - if( LOG.isDebugEnabled() ) - for( String key : fcandCounts.keySet() ) - { - LOG.debug("IPA: FUNC statistic propagation candidate (after pruning): "+key); - } + return fcallSizes.getValidFunctions(); } private boolean 
isUnarySizePreservingFunction(FunctionStatementBlock fsb) - throws HopsException, ParseException + throws HopsException { FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); //check unary functions over matrices boolean ret = (fstmt.getInputParams().size() == 1 - && fstmt.getInputParams().get(0).getDataType()==DataType.MATRIX - && fstmt.getOutputParams().size() == 1 - && fstmt.getOutputParams().get(0).getDataType()==DataType.MATRIX); + && fstmt.getInputParams().get(0).getDataType()==DataType.MATRIX + && fstmt.getOutputParams().size() == 1 + && fstmt.getOutputParams().get(0).getDataType()==DataType.MATRIX); //check size-preserving characteristic if( ret ) { - HashMap<String, Integer> tmp1 = new HashMap<String,Integer>(); - HashMap<String, Set<Long>> tmp2 = new HashMap<String, Set<Long>>(); - HashSet<String> tmp3 = new HashSet<String>(); - HashSet<String> tmp4 = new HashSet<String>(); + FunctionCallSizeInfo fcallSizes = new FunctionCallSizeInfo(_fgraph, false); + HashSet<String> fnStack = new HashSet<String>(); LocalVariableMap callVars = new LocalVariableMap(); //populate input @@ -391,7 +221,7 @@ public class InterProceduralAnalysis //propagate statistics for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, tmp1, callVars, tmp2, tmp3, tmp4); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); //compare output MatrixObject mo2 = (MatrixObject)callVars.get(fstmt.getOutputParams().get(0).getName()); @@ -400,44 +230,11 @@ public class InterProceduralAnalysis //reset function mo.getMatrixCharacteristics().setDimension(-1, -1); for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, tmp1, callVars, tmp2, tmp3, tmp4); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); } return ret; } - - ///////////////////////////// - // DETERMINE NNZ PROPAGATE SAFENESS - ////// - - /** - * Populates fcandSafeNNZ with all <functionKey,hopID> pairs where it is safe to - * propagate nnz 
into the function. - * - * @param fcandHops function candidate HOPs - * @param fcandSafeNNZ function candidate safe non-zeros - */ - private void determineFunctionCandidatesNNZPropagation(Map<String, FunctionOp> fcandHops, Map<String, Set<Long>> fcandSafeNNZ) - { - //for all function candidates - for( Entry<String, FunctionOp> e : fcandHops.entrySet() ) - { - String fKey = e.getKey(); - FunctionOp fop = e.getValue(); - HashSet<Long> tmp = new HashSet<Long>(); - - //for all inputs of this function call - for( Hop input : fop.getInput() ) - { - //if nnz known it is safe to propagate those nnz because for multiple calls - //we checked of equivalence and hence all calls have the same nnz - if( input.getNnz()>=0 ) - tmp.add(input.getHopID()); - } - - fcandSafeNNZ.put(fKey, tmp); - } - } ///////////////////////////// // INTRA-PROCEDURE ANALYSIS @@ -456,15 +253,15 @@ public class InterProceduralAnalysis * @throws HopsException If a HopsException occurs. * @throws ParseException If a ParseException occurs. 
*/ - private void propagateStatisticsAcrossBlock( StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) - throws HopsException, ParseException + private void propagateStatisticsAcrossBlock( StatementBlock sb, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack ) + throws HopsException { if (sb instanceof FunctionStatementBlock) { FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); } else if (sb instanceof WhileStatementBlock) { @@ -477,11 +274,11 @@ public class InterProceduralAnalysis //check and propagate stats into body LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); for (StatementBlock sbi : wstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); if( Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb) ){ //second pass if required propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars); for (StatementBlock sbi : wstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); } //remove updated constant scalars Recompiler.removeUpdatedScalars(callVars, sb); @@ -496,9 +293,9 @@ public class InterProceduralAnalysis LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone(); for (StatementBlock sbi : istmt.getIfBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, 
fnStack); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); for (StatementBlock sbi : istmt.getElseBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVarsElse, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(sbi, callVarsElse, fcallSizes, fnStack); callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb); //remove updated constant scalars Recompiler.removeUpdatedScalars(callVars, sb); @@ -516,10 +313,10 @@ public class InterProceduralAnalysis //check and propagate stats into body LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); if( Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb) ) for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(sbi, callVars, fcallSizes, fnStack); //remove updated constant scalars Recompiler.removeUpdatedScalars(callVars, sb); } @@ -538,7 +335,7 @@ public class InterProceduralAnalysis propagateStatisticsAcrossDAG(roots, callVars); //propagate stats into function calls Hop.resetVisitStatus(roots); - propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsIntoFunctions(prog, roots, callVars, fcallSizes, fnStack); } } @@ -578,15 +375,11 @@ public class InterProceduralAnalysis //reset visit status because potentially called multiple times root.resetVisitStatus(); - try - { - Recompiler.rUpdateStatistics( root, vars ); - + try { //note: for predicates no output statistics - //Recompiler.extractDAGOutputStatistics(root, vars); + Recompiler.rUpdateStatistics( root, vars ); } - catch(Exception ex) - { + catch(Exception ex) { throw new 
HopsException("Failed to update Hop DAG statistics.", ex); } } @@ -604,8 +397,7 @@ public class InterProceduralAnalysis if( roots == null ) return; - try - { + try { //update DAG statistics from leafs to roots for( Hop hop : roots ) Recompiler.rUpdateStatistics( hop, vars ); @@ -613,8 +405,7 @@ public class InterProceduralAnalysis //extract statistics from roots Recompiler.extractDAGOutputStatistics(roots, vars, true); } - catch( Exception ex ) - { + catch( Exception ex ) { throw new HopsException("Failed to update Hop DAG statistics.", ex); } } @@ -639,11 +430,11 @@ public class InterProceduralAnalysis * @throws HopsException If a HopsException occurs. * @throws ParseException If a ParseException occurs. */ - private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) - throws HopsException, ParseException + private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack ) + throws HopsException { for( Hop root : roots ) - propagateStatisticsIntoFunctions(prog, root, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsIntoFunctions(prog, root, callVars, fcallSizes, fnStack); } /** @@ -661,14 +452,14 @@ public class InterProceduralAnalysis * @throws HopsException If a HopsException occurs. * @throws ParseException If a ParseException occurs. 
*/ - private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) - throws HopsException, ParseException + private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes, Set<String> fnStack ) + throws HopsException { if( hop.isVisited() ) return; for( Hop c : hop.getInput() ) - propagateStatisticsIntoFunctions(prog, c, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsIntoFunctions(prog, c, callVars, fcallSizes, fnStack); if( hop instanceof FunctionOp ) { @@ -681,7 +472,7 @@ public class InterProceduralAnalysis FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); - if( fcand.containsKey(fkey) && + if( fcallSizes.isValidFunction(fkey) && !fnStack.contains(fkey) ) //prevent recursion { //maintain function call stack @@ -689,11 +480,10 @@ public class InterProceduralAnalysis //create mapping and populate symbol table for refresh LocalVariableMap tmpVars = new LocalVariableMap(); - populateLocalVariableMapForFunctionCall( fstmt, fop, - callVars, tmpVars, fcandSafeNNZ.get(fkey), fcand.get(fkey) ); + populateLocalVariableMapForFunctionCall( fstmt, fop, callVars, tmpVars, fcallSizes); //recursively propagate statistics - propagateStatisticsAcrossBlock(fsb, fcand, tmpVars, fcandSafeNNZ, unaryFcands, fnStack); + propagateStatisticsAcrossBlock(fsb, tmpVars, fcallSizes, fnStack); //extract vars from symbol table, re-map and refresh main program extractFunctionCallReturnStatistics(fstmt, fop, tmpVars, callVars, true); @@ -701,7 +491,7 @@ public class InterProceduralAnalysis //maintain function call stack fnStack.remove(fkey); } - else if( unaryFcands.contains(fkey) ) { + else if( 
fcallSizes.isDimsPreservingFunction(fkey) ) { extractFunctionCallEquivalentReturnStatistics(fstmt, fop, callVars); } else { @@ -724,11 +514,12 @@ public class InterProceduralAnalysis hop.setVisited(); } - private void populateLocalVariableMapForFunctionCall( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callvars, LocalVariableMap vars, Set<Long> inputSafeNNZ, Integer numCalls ) + private void populateLocalVariableMapForFunctionCall( FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callvars, LocalVariableMap vars, FunctionCallSizeInfo fcallSizes ) throws HopsException { ArrayList<DataIdentifier> inputVars = fstmt.getInputParams(); ArrayList<Hop> inputOps = fop.getInput(); + String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); for( int i=0; i<inputVars.size(); i++ ) { @@ -740,10 +531,9 @@ public class InterProceduralAnalysis { //propagate matrix characteristics MatrixObject mo = new MatrixObject(ValueType.DOUBLE, null); - MatrixCharacteristics mc = new MatrixCharacteristics( - input.getDim1(), input.getDim2(), - ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), - inputSafeNNZ.contains(input.getHopID())?input.getNnz():-1 ); + MatrixCharacteristics mc = new MatrixCharacteristics( input.getDim1(), input.getDim2(), + ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), + fcallSizes.isSafeNnz(fkey, input.getHopID())?input.getNnz():-1 ); MatrixFormatMetaData meta = new MatrixFormatMetaData(mc,null,null); mo.setMetaData(meta); vars.put(dat.getName(), mo); @@ -759,7 +549,7 @@ public class InterProceduralAnalysis //propagate scalar variables into functions if called once //and input scalar is existing variable in symbol table else if( PROPAGATE_SCALAR_VARS_INTO_FUN - && numCalls != null && numCalls == 1 + && fcallSizes.getFunctionCallCount(fkey) == 1 && input instanceof DataOp ) { Data scalar = callvars.get(input.getName()); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/00956138/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java index 61f122a..7103775 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java @@ -23,7 +23,6 @@ import java.util.HashMap; import org.junit.Test; import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.hops.ipa.InterProceduralAnalysis; import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; @@ -59,75 +58,57 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase @Test public void testFunctionSizePropagationSameInput() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, false, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, false); } @Test public void testFunctionSizePropagationEqualDimsUnknownNnzRight() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, false, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, false); } @Test public void testFunctionSizePropagationEqualDimsUnknownNnzLeft() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, false, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, false); } @Test public void testFunctionSizePropagationEqualDimsUnknownNnz() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, false, false); - } - - @Test - public 
void testFunctionSizePropagationDifferentDims() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, false); } @Test public void testFunctionSizePropagationDifferentDimsUnary() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false, true); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false); } @Test public void testFunctionSizePropagationSameInputIPA() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, true, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, true); } @Test public void testFunctionSizePropagationEqualDimsUnknownNnzRightIPA() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, true, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, true); } @Test public void testFunctionSizePropagationEqualDimsUnknownNnzLeftIPA() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, true, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, true); } @Test public void testFunctionSizePropagationEqualDimsUnknownNnzIPA() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, true, false); - } - - @Test - public void testFunctionSizePropagationDifferentDimsIPA() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true, false); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, true); } @Test public void testFunctionSizePropagationDifferentDimsIPAUnary() { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true, true); + runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true); } - - /** - * - * @param condition - * @param branchRemoval - * @param IPA - */ - private void runIPASizePropagationMultipleFunctionsTest( String TEST_NAME, boolean IPA, boolean unary ) + private void runIPASizePropagationMultipleFunctionsTest( String TEST_NAME, boolean IPA ) { boolean oldFlagIPA = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS; - boolean oldFlagUnary = 
InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS; try { @@ -142,8 +123,7 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA; - InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS = unary; - + //generate input data double[][] V = getRandomMatrix(rows, cols, 0, 1, sparsity, 7); writeInputMatrixWithMTD("V", V, true); @@ -158,17 +138,15 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, 0, "Stat-DML", "Stat-R"); //check expected number of compiled and executed MR jobs - int expectedNumCompiled = (IPA) ? ((TEST_NAME.equals(TEST_NAME5))?(unary?2:4):1) : + int expectedNumCompiled = (IPA) ? ((TEST_NAME.equals(TEST_NAME5))?2:1) : (TEST_NAME.equals(TEST_NAME5)?5:4); //reblock, 2xGMR foo, GMR int expectedNumExecuted = 0; checkNumCompiledMRJobs(expectedNumCompiled); checkNumExecutedMRJobs(expectedNumExecuted); } - finally - { + finally { OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA; - InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS = oldFlagUnary; } }