[SYSTEMML-694] New IPA pass for unary, dimension-preserving functions (for LSTM) Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a5584c0f Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a5584c0f Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a5584c0f
Branch: refs/heads/master Commit: a5584c0fd39a8687f5858e87a9acb4dbd43c2c24 Parents: 084afea Author: Matthias Boehm <[email protected]> Authored: Fri Jul 22 22:58:32 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 23 14:41:54 2016 -0700 ---------------------------------------------------------------------- .../sysml/hops/ipa/InterProceduralAnalysis.java | 144 +++++++++++++++---- .../org/apache/sysml/parser/DMLProgram.java | 11 ++ ...IPAPropagationSizeMultipleFunctionsTest.java | 71 ++++----- 3 files changed, 161 insertions(+), 65 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a5584c0f/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java index 00fd643..849d8ff 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java @@ -20,6 +20,7 @@ package org.apache.sysml.hops.ipa; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -27,6 +28,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; @@ -129,6 +131,7 @@ public class InterProceduralAnalysis private static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; //remove unnecessary checkpoints (unconditionally overwritten intermediates) private static final boolean REMOVE_CONSTANT_BINARY_OPS = true; //remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) private 
static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; //propagate scalar variables into functions that are called once + public static boolean UNARY_DIMS_PRESERVING_FUNS = true; //determine and exploit unary dimension preserving functions static { // for internal debugging only @@ -151,6 +154,7 @@ public class InterProceduralAnalysis * @throws ParseException * @throws LanguageException */ + @SuppressWarnings("unchecked") public void analyzeProgram( DMLProgram dmlp ) throws HopsException, ParseException, LanguageException { @@ -159,8 +163,7 @@ public class InterProceduralAnalysis Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>(); Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>(); Set<String> allFCandKeys = new HashSet<String>(); - if( dmlp.getFunctionStatementBlocks().size() > 0 ) - { + if( !dmlp.getFunctionStatementBlocks().isEmpty() ) { for ( StatementBlock sb : dmlp.getStatementBlocks() ) //get candidates (over entire program) getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops ); allFCandKeys.addAll(fcandCounts.keySet()); //cp before pruning @@ -169,35 +172,47 @@ public class InterProceduralAnalysis DMLTranslator.resetHopsDAGVisitStatus( dmlp ); } - //step 2: propagate statistics and scalars into functions and across DAGs + //step 2: get unary dimension-preserving non-candidate functions + Collection<String> unaryFcandTmp = CollectionUtils.subtract(allFCandKeys, fcandCounts.keySet()); + HashSet<String> unaryFcands = new HashSet<String>(); + if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) { + for( String tmp : unaryFcandTmp ) + if( isUnarySizePreservingFunction(dmlp.getFunctionStatementBlock(tmp)) ) + unaryFcands.add(tmp); + } + + //step 3: propagate statistics and scalars into functions and across DAGs if( !fcandCounts.isEmpty() || INTRA_PROCEDURAL_ANALYSIS ) { //(callVars used to chain outputs/inputs of multiple functions calls) LocalVariableMap callVars = new LocalVariableMap(); 
for ( StatementBlock sb : dmlp.getStatementBlocks() ) //propagate stats into candidates - propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>() ); + propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, unaryFcands, new HashSet<String>() ); } - //step 3: remove unused functions (e.g., inlined or never called) + //step 4: remove unused functions (e.g., inlined or never called) if( REMOVE_UNUSED_FUNCTIONS ) { removeUnusedFunctions( dmlp, allFCandKeys ); } - //step 4: flag functions with loops for 'recompile-on-entry' + //step 5: flag functions with loops for 'recompile-on-entry' if( FLAG_FUNCTION_RECOMPILE_ONCE ) { flagFunctionsForRecompileOnce( dmlp ); } - //step 5: set global data flow properties + //step 6: set global data flow properties if( REMOVE_UNNECESSARY_CHECKPOINTS && OptimizerUtils.isSparkExecutionMode() ) { removeUnnecessaryCheckpoints(dmlp); } - //step 6: remove constant binary ops + //step 7: remove constant binary ops if( REMOVE_CONSTANT_BINARY_OPS ) { removeConstantBinaryOps(dmlp); } + + //TODO evaluate potential of SECOND_CHANCE + //(consistent call stats after first IPA pass and hence additional potential) } /** @@ -227,7 +242,7 @@ public class InterProceduralAnalysis //step 2: propagate statistics into functions and across DAGs //(callVars used to chain outputs/inputs of multiple functions calls) LocalVariableMap callVars = new LocalVariableMap(); - propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>() ); + propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>(), new HashSet<String>() ); } return fcandCounts.keySet(); @@ -390,6 +405,53 @@ public class InterProceduralAnalysis LOG.debug("IPA: FUNC statistic propagation candidate (after pruning): "+key); } } + + /** + * + * @param fsb + * @return + * @throws HopsException + * @throws ParseException + */ + private boolean 
isUnarySizePreservingFunction(FunctionStatementBlock fsb) + throws HopsException, ParseException + { + FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0); + + //check unary functions over matrices + boolean ret = (fstmt.getInputParams().size() == 1 + && fstmt.getInputParams().get(0).getDataType()==DataType.MATRIX + && fstmt.getOutputParams().size() == 1 + && fstmt.getOutputParams().get(0).getDataType()==DataType.MATRIX); + + //check size-preserving characteristic + if( ret ) { + HashMap<String, Integer> tmp1 = new HashMap<String,Integer>(); + HashMap<String, Set<Long>> tmp2 = new HashMap<String, Set<Long>>(); + HashSet<String> tmp3 = new HashSet<String>(); + HashSet<String> tmp4 = new HashSet<String>(); + LocalVariableMap callVars = new LocalVariableMap(); + + //populate input + MatrixObject mo = createOutputMatrix(7777, 3333, -1); + callVars.put(fstmt.getInputParams().get(0).getName(), mo); + + //propagate statistics + for (StatementBlock sbi : fstmt.getBody()) + propagateStatisticsAcrossBlock(sbi, tmp1, callVars, tmp2, tmp3, tmp4); + + //compare output + MatrixObject mo2 = (MatrixObject)callVars.get(fstmt.getOutputParams().get(0).getName()); + ret &= mo.getNumRows() == mo2.getNumRows() && mo.getNumColumns() == mo2.getNumColumns(); + + //reset function + mo.getMatrixCharacteristics().setDimension(-1, -1); + for (StatementBlock sbi : fstmt.getBody()) + propagateStatisticsAcrossBlock(sbi, tmp1, callVars, tmp2, tmp3, tmp4); + } + + return ret; + } ///////////////////////////// // DETERMINE NNZ PROPAGATE SAFENESS @@ -436,7 +498,7 @@ public class InterProceduralAnalysis * @throws ParseException * @throws CloneNotSupportedException */ - private void propagateStatisticsAcrossBlock( StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> fnStack ) + private void propagateStatisticsAcrossBlock( StatementBlock sb, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> 
fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) throws HopsException, ParseException { if (sb instanceof FunctionStatementBlock) @@ -444,7 +506,7 @@ public class InterProceduralAnalysis FunctionStatementBlock fsb = (FunctionStatementBlock)sb; FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); } else if (sb instanceof WhileStatementBlock) { @@ -457,11 +519,11 @@ public class InterProceduralAnalysis //check and propagate stats into body LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); for (StatementBlock sbi : wstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); if( Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb) ){ //second pass if required propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars); for (StatementBlock sbi : wstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); } //remove updated constant scalars Recompiler.removeUpdatedScalars(callVars, sb); @@ -476,9 +538,9 @@ public class InterProceduralAnalysis LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); LocalVariableMap callVarsElse = (LocalVariableMap) callVars.clone(); for (StatementBlock sbi : istmt.getIfBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); for (StatementBlock sbi : istmt.getElseBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVarsElse, fcandSafeNNZ, fnStack); + 
propagateStatisticsAcrossBlock(sbi, fcand, callVarsElse, fcandSafeNNZ, unaryFcands, fnStack); callVars = Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb); //remove updated constant scalars Recompiler.removeUpdatedScalars(callVars, sb); @@ -496,10 +558,10 @@ public class InterProceduralAnalysis //check and propagate stats into body LocalVariableMap oldCallVars = (LocalVariableMap) callVars.clone(); for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); if( Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb) ) for (StatementBlock sbi : fstmt.getBody()) - propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(sbi, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); //remove updated constant scalars Recompiler.removeUpdatedScalars(callVars, sb); } @@ -515,7 +577,7 @@ public class InterProceduralAnalysis propagateStatisticsAcrossDAG(roots, callVars); //propagate stats into function calls Hop.resetVisitStatus(roots); - propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsIntoFunctions(prog, roots, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); } } @@ -591,11 +653,11 @@ public class InterProceduralAnalysis * @throws HopsException * @throws ParseException */ - private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> fnStack ) + private void propagateStatisticsIntoFunctions(DMLProgram prog, ArrayList<Hop> roots, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) throws HopsException, ParseException { for( Hop root : roots ) - 
propagateStatisticsIntoFunctions(prog, root, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsIntoFunctions(prog, root, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); } @@ -607,14 +669,14 @@ public class InterProceduralAnalysis * @throws HopsException * @throws ParseException */ - private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> fnStack ) + private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) throws HopsException, ParseException { if( hop.getVisited() == VisitStatus.DONE ) return; for( Hop c : hop.getInput() ) - propagateStatisticsIntoFunctions(prog, c, fcand, callVars, fcandSafeNNZ, fnStack); + propagateStatisticsIntoFunctions(prog, c, fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack); if( hop instanceof FunctionOp ) { @@ -639,7 +701,7 @@ public class InterProceduralAnalysis callVars, tmpVars, fcandSafeNNZ.get(fkey), fcand.get(fkey) ); //recursively propagate statistics - propagateStatisticsAcrossBlock(fsb, fcand, tmpVars, fcandSafeNNZ, fnStack); + propagateStatisticsAcrossBlock(fsb, fcand, tmpVars, fcandSafeNNZ, unaryFcands, fnStack); //extract vars from symbol table, re-map and refresh main program extractFunctionCallReturnStatistics(fstmt, fop, tmpVars, callVars, true); @@ -647,8 +709,10 @@ public class InterProceduralAnalysis //maintain function call stack fnStack.remove(fkey); } - else - { + else if( unaryFcands.contains(fkey) ) { + extractFunctionCallEquivalentReturnStatistics(fstmt, fop, callVars); + } + else { extractFunctionCallUnknownReturnStatistics(fstmt, fop, callVars); } } @@ -834,6 +898,27 @@ public class InterProceduralAnalysis * @param callVars * @throws HopsException */ + private void extractFunctionCallEquivalentReturnStatistics( 
FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) + throws HopsException + { + String fkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName()); + try { + Hop input = fop.getInput().get(0); + MatrixObject moOut = createOutputMatrix(input.getDim1(), input.getDim2(), -1); + callVars.put(fop.getOutputVariableNames()[0], moOut); + } + catch( Exception ex ) { + throw new HopsException( "Failed to extract output statistics for unary function "+fkey+".", ex); + } + } + + /** + * + * @param fstmt + * @param fop + * @param callVars + * @throws HopsException + */ private void extractExternalFunctionCallReturnStatistics( ExternalFunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) throws HopsException { @@ -883,13 +968,10 @@ public class InterProceduralAnalysis * @param nnz * @return */ - private MatrixObject createOutputMatrix( long dim1, long dim2, long nnz ) - { + private MatrixObject createOutputMatrix( long dim1, long dim2, long nnz ) { MatrixObject moOut = new MatrixObject(ValueType.DOUBLE, null); - MatrixCharacteristics mc = new MatrixCharacteristics( - dim1, dim2, - ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), - nnz); + MatrixCharacteristics mc = new MatrixCharacteristics( dim1, dim2, + ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(), nnz); MatrixFormatMetaData meta = new MatrixFormatMetaData(mc,null,null); moOut.setMetaData(meta); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a5584c0f/src/main/java/org/apache/sysml/parser/DMLProgram.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java b/src/main/java/org/apache/sysml/parser/DMLProgram.java index ee76d56..12a8369 100644 --- a/src/main/java/org/apache/sysml/parser/DMLProgram.java +++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java @@ -80,6 +80,17 @@ public class DMLProgram return 
_blocks.size(); } + /** + * + * @param fkey function key as concatenation of namespace and function name + * (see DMLProgram.constructFunctionKey) + * @return + */ + public FunctionStatementBlock getFunctionStatementBlock(String fkey) { + String[] tmp = splitFunctionKey(fkey); + return getFunctionStatementBlock(tmp[0], tmp[1]); + } + public FunctionStatementBlock getFunctionStatementBlock(String namespaceKey, String functionName) { DMLProgram namespaceProgram = this.getNamespaces().get(namespaceKey); if (namespaceProgram == null) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a5584c0f/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java index cdd15e5..61f122a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java @@ -22,8 +22,8 @@ package org.apache.sysml.test.integration.functions.recompile; import java.util.HashMap; import org.junit.Test; - import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.ipa.InterProceduralAnalysis; import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; @@ -31,7 +31,6 @@ import org.apache.sysml.test.utils.TestUtils; public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase { - private final static String TEST_NAME1 = "multiple_function_calls1"; private final static String TEST_NAME2 = 
"multiple_function_calls2"; private final static String TEST_NAME3 = "multiple_function_calls3"; @@ -59,63 +58,63 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase @Test - public void testFunctionSizePropagationSameInput() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, false); + public void testFunctionSizePropagationSameInput() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, false, false); } @Test - public void testFunctionSizePropagationEqualDimsUnknownNnzRight() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, false); + public void testFunctionSizePropagationEqualDimsUnknownNnzRight() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, false, false); } @Test - public void testFunctionSizePropagationEqualDimsUnknownNnzLeft() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, false); + public void testFunctionSizePropagationEqualDimsUnknownNnzLeft() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, false, false); } @Test - public void testFunctionSizePropagationEqualDimsUnknownNnz() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, false); + public void testFunctionSizePropagationEqualDimsUnknownNnz() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, false, false); } @Test - public void testFunctionSizePropagationDifferentDims() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false); + public void testFunctionSizePropagationDifferentDims() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false, false); } @Test - public void testFunctionSizePropagationSameInputIPA() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, true); + public void testFunctionSizePropagationDifferentDimsUnary() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false, true); } @Test - public void testFunctionSizePropagationEqualDimsUnknownNnzRightIPA() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, true); + public void 
testFunctionSizePropagationSameInputIPA() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, true, false); } @Test - public void testFunctionSizePropagationEqualDimsUnknownNnzLeftIPA() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, true); + public void testFunctionSizePropagationEqualDimsUnknownNnzRightIPA() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, true, false); } @Test - public void testFunctionSizePropagationEqualDimsUnknownNnzIPA() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, true); + public void testFunctionSizePropagationEqualDimsUnknownNnzLeftIPA() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, true, false); } @Test - public void testFunctionSizePropagationDifferentDimsIPA() - { - runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true); + public void testFunctionSizePropagationEqualDimsUnknownNnzIPA() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, true, false); + } + + @Test + public void testFunctionSizePropagationDifferentDimsIPA() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true, false); + } + + @Test + public void testFunctionSizePropagationDifferentDimsIPAUnary() { + runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true, true); } @@ -125,9 +124,10 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase * @param branchRemoval * @param IPA */ - private void runIPASizePropagationMultipleFunctionsTest( String TEST_NAME, boolean IPA ) + private void runIPASizePropagationMultipleFunctionsTest( String TEST_NAME, boolean IPA, boolean unary ) { boolean oldFlagIPA = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS; + boolean oldFlagUnary = InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS; try { @@ -142,6 +142,7 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA; + 
InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS = unary; //generate input data double[][] V = getRandomMatrix(rows, cols, 0, 1, sparsity, 7); @@ -157,7 +158,8 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase TestUtils.compareMatrices(dmlfile, rfile, 0, "Stat-DML", "Stat-R"); //check expected number of compiled and executed MR jobs - int expectedNumCompiled = (IPA) ? (TEST_NAME.equals(TEST_NAME5)?4:1) : (TEST_NAME.equals(TEST_NAME5)?5:4); //reblock, 2xGMR foo, GMR + int expectedNumCompiled = (IPA) ? ((TEST_NAME.equals(TEST_NAME5))?(unary?2:4):1) : + (TEST_NAME.equals(TEST_NAME5)?5:4); //reblock, 2xGMR foo, GMR int expectedNumExecuted = 0; checkNumCompiledMRJobs(expectedNumCompiled); @@ -166,6 +168,7 @@ public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase finally { OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA; + InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS = oldFlagUnary; } }
