[SYSTEMML-694] New IPA pass for unary, dimension-preserving functions (for LSTM)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a5584c0f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a5584c0f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a5584c0f

Branch: refs/heads/master
Commit: a5584c0fd39a8687f5858e87a9acb4dbd43c2c24
Parents: 084afea
Author: Matthias Boehm <[email protected]>
Authored: Fri Jul 22 22:58:32 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jul 23 14:41:54 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/InterProceduralAnalysis.java | 144 +++++++++++++++----
 .../org/apache/sysml/parser/DMLProgram.java     |  11 ++
 ...IPAPropagationSizeMultipleFunctionsTest.java |  71 ++++-----
 3 files changed, 161 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a5584c0f/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 00fd643..849d8ff 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -20,6 +20,7 @@
 package org.apache.sysml.hops.ipa;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
@@ -27,6 +28,7 @@ import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
@@ -129,6 +131,7 @@ public class InterProceduralAnalysis
        private static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; 
//remove unnecessary checkpoints (unconditionally overwritten intermediates) 
        private static final boolean REMOVE_CONSTANT_BINARY_OPS     = true; 
//remove constant binary operations (e.g., X*ones, where ones=matrix(1,...)) 
        private static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; 
//propagate scalar variables into functions that are called once
+       public static boolean UNARY_DIMS_PRESERVING_FUNS = true; //determine 
and exploit unary dimension preserving functions 
        
        static {
                // for internal debugging only
@@ -151,6 +154,7 @@ public class InterProceduralAnalysis
         * @throws ParseException
         * @throws LanguageException
         */
+       @SuppressWarnings("unchecked")
        public void analyzeProgram( DMLProgram dmlp ) 
                throws HopsException, ParseException, LanguageException
        {
@@ -159,8 +163,7 @@ public class InterProceduralAnalysis
                Map<String, FunctionOp> fcandHops = new HashMap<String, 
FunctionOp>();
                Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, 
Set<Long>>(); 
                Set<String> allFCandKeys = new HashSet<String>();
-               if( dmlp.getFunctionStatementBlocks().size() > 0 )
-               {
+               if( !dmlp.getFunctionStatementBlocks().isEmpty() ) {
                        for ( StatementBlock sb : dmlp.getStatementBlocks() ) 
//get candidates (over entire program)
                                getFunctionCandidatesForStatisticPropagation( 
sb, fcandCounts, fcandHops );
                        allFCandKeys.addAll(fcandCounts.keySet()); //cp before 
pruning
@@ -169,35 +172,47 @@ public class InterProceduralAnalysis
                        DMLTranslator.resetHopsDAGVisitStatus( dmlp );
                }
                
-               //step 2: propagate statistics and scalars into functions and 
across DAGs
+               //step 2: get unary dimension-preserving non-candidate functions
+               Collection<String> unaryFcandTmp = 
CollectionUtils.subtract(allFCandKeys, fcandCounts.keySet());
+               HashSet<String> unaryFcands = new HashSet<String>();
+               if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) {
+                       for( String tmp : unaryFcandTmp )
+                               if( 
isUnarySizePreservingFunction(dmlp.getFunctionStatementBlock(tmp)) )
+                                       unaryFcands.add(tmp);
+               }
+               
+               //step 3: propagate statistics and scalars into functions and 
across DAGs
                if( !fcandCounts.isEmpty() || INTRA_PROCEDURAL_ANALYSIS ) {
                        //(callVars used to chain outputs/inputs of multiple 
functions calls) 
                        LocalVariableMap callVars = new LocalVariableMap();
                        for ( StatementBlock sb : dmlp.getStatementBlocks() ) 
//propagate stats into candidates
-                               propagateStatisticsAcrossBlock( sb, 
fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>() );
+                               propagateStatisticsAcrossBlock( sb, 
fcandCounts, callVars, fcandSafeNNZ, unaryFcands, new HashSet<String>() );
                }
                
-               //step 3: remove unused functions (e.g., inlined or never 
called)
+               //step 4: remove unused functions (e.g., inlined or never 
called)
                if( REMOVE_UNUSED_FUNCTIONS ) {
                        removeUnusedFunctions( dmlp, allFCandKeys );
                }
                
-               //step 4: flag functions with loops for 'recompile-on-entry'
+               //step 5: flag functions with loops for 'recompile-on-entry'
                if( FLAG_FUNCTION_RECOMPILE_ONCE ) {
                        flagFunctionsForRecompileOnce( dmlp );
                }
                
-               //step 5: set global data flow properties
+               //step 6: set global data flow properties
                if( REMOVE_UNNECESSARY_CHECKPOINTS 
                        && OptimizerUtils.isSparkExecutionMode() )
                {
                        removeUnnecessaryCheckpoints(dmlp);
                }
                
-               //step 6: remove constant binary ops
+               //step 7: remove constant binary ops
                if( REMOVE_CONSTANT_BINARY_OPS ) {
                        removeConstantBinaryOps(dmlp);
                }
+               
+               //TODO evaluate potential of SECOND_CHANCE
+               //(consistent call stats after first IPA pass and hence 
additional potential)
        }
        
        /**
@@ -227,7 +242,7 @@ public class InterProceduralAnalysis
                        //step 2: propagate statistics into functions and 
across DAGs
                        //(callVars used to chain outputs/inputs of multiple 
functions calls) 
                        LocalVariableMap callVars = new LocalVariableMap();
-                       propagateStatisticsAcrossBlock( sb, fcandCounts, 
callVars, fcandSafeNNZ, new HashSet<String>() );
+                       propagateStatisticsAcrossBlock( sb, fcandCounts, 
callVars, fcandSafeNNZ, new HashSet<String>(), new HashSet<String>() );
                }
                
                return fcandCounts.keySet();
@@ -390,6 +405,53 @@ public class InterProceduralAnalysis
                                LOG.debug("IPA: FUNC statistic propagation 
candidate (after pruning): "+key);
                        }
        }
+       
+       /**
+        * Checks whether a function is unary over matrices (exactly one matrix input and one matrix output) and dimension-preserving, probed by propagating a dummy 7777x3333 input through its body.
+        * @param fsb    function statement block to probe
+        * @return true if the function is unary over matrices and the propagated output dims equal the input dims
+        * @throws HopsException propagated from statistics propagation over the function body
+        * @throws ParseException propagated from statistics propagation over the function body
+        */
+       private boolean isUnarySizePreservingFunction(FunctionStatementBlock 
fsb) 
+               throws HopsException, ParseException 
+       {
+               FunctionStatement fstmt = (FunctionStatement) 
fsb.getStatement(0);
+               
+               //check unary functions over matrices (one matrix input, one matrix output)
+               boolean ret = (fstmt.getInputParams().size() == 1 
+                               && 
fstmt.getInputParams().get(0).getDataType()==DataType.MATRIX
+                               && fstmt.getOutputParams().size() == 1
+                               && 
fstmt.getOutputParams().get(0).getDataType()==DataType.MATRIX);
+               
+               //check size-preserving characteristic by propagating over scratch candidate maps/sets and a fresh variable map
+               if( ret ) {
+                       HashMap<String, Integer> tmp1 = new 
HashMap<String,Integer>();
+                       HashMap<String, Set<Long>> tmp2 = new HashMap<String, 
Set<Long>>();
+                       HashSet<String> tmp3 = new HashSet<String>();
+                       HashSet<String> tmp4 = new HashSet<String>();
+                       LocalVariableMap callVars = new LocalVariableMap();
+                       
+                       //populate input with a dummy 7777x3333 matrix of unknown nnz (-1)
+                       MatrixObject mo = createOutputMatrix(7777, 3333, -1);
+                       callVars.put(fstmt.getInputParams().get(0).getName(), 
mo);
+                       
+                       //propagate statistics through the function body
+                       for (StatementBlock sbi : fstmt.getBody())
+                               propagateStatisticsAcrossBlock(sbi,  tmp1, 
callVars, tmp2, tmp3, tmp4);
+               
+                       //compare output dims with input dims (NOTE(review): mo2 may be null -> NPE, or not a MatrixObject -> ClassCastException, if the output is not bound as a matrix in callVars; confirm)
+                       MatrixObject mo2 = 
(MatrixObject)callVars.get(fstmt.getOutputParams().get(0).getName());
+                       ret &= mo.getNumRows() == mo2.getNumRows() && 
mo.getNumColumns() == mo2.getNumColumns();
+               
+                       //reset function: re-propagate with unknown input dims, presumably to undo the sizes derived from the dummy input
+                       mo.getMatrixCharacteristics().setDimension(-1, -1);
+                       for (StatementBlock sbi : fstmt.getBody())
+                               propagateStatisticsAcrossBlock(sbi,  tmp1, 
callVars, tmp2, tmp3, tmp4);
+               }
+               
+               return ret;
+       }
 
        /////////////////////////////
        // DETERMINE NNZ PROPAGATE SAFENESS
@@ -436,7 +498,7 @@ public class InterProceduralAnalysis
         * @throws ParseException
         * @throws CloneNotSupportedException 
         */
-       private void propagateStatisticsAcrossBlock( StatementBlock sb, 
Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> 
fcandSafeNNZ, Set<String> fnStack ) 
+       private void propagateStatisticsAcrossBlock( StatementBlock sb, 
Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> 
fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) 
                throws HopsException, ParseException
        {
                if (sb instanceof FunctionStatementBlock)
@@ -444,7 +506,7 @@ public class InterProceduralAnalysis
                        FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
                        for (StatementBlock sbi : fstmt.getBody())
-                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
                }
                else if (sb instanceof WhileStatementBlock)
                {
@@ -457,11 +519,11 @@ public class InterProceduralAnalysis
                        //check and propagate stats into body
                        LocalVariableMap oldCallVars = (LocalVariableMap) 
callVars.clone();
                        for (StatementBlock sbi : wstmt.getBody())
-                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
                        if( 
Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, wsb) ){ 
//second pass if required
                                
propagateStatisticsAcrossPredicateDAG(wsb.getPredicateHops(), callVars);
                                for (StatementBlock sbi : wstmt.getBody())
-                                       propagateStatisticsAcrossBlock(sbi, 
fcand, callVars, fcandSafeNNZ, fnStack);
+                                       propagateStatisticsAcrossBlock(sbi, 
fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
                        }
                        //remove updated constant scalars
                        Recompiler.removeUpdatedScalars(callVars, sb);
@@ -476,9 +538,9 @@ public class InterProceduralAnalysis
                        LocalVariableMap oldCallVars = (LocalVariableMap) 
callVars.clone();
                        LocalVariableMap callVarsElse = (LocalVariableMap) 
callVars.clone();
                        for (StatementBlock sbi : istmt.getIfBody())
-                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
                        for (StatementBlock sbi : istmt.getElseBody())
-                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVarsElse, fcandSafeNNZ, fnStack);
+                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVarsElse, fcandSafeNNZ, unaryFcands, fnStack);
                        callVars = 
Recompiler.reconcileUpdatedCallVarsIf(oldCallVars, callVars, callVarsElse, isb);
                        //remove updated constant scalars
                        Recompiler.removeUpdatedScalars(callVars, sb);
@@ -496,10 +558,10 @@ public class InterProceduralAnalysis
                        //check and propagate stats into body
                        LocalVariableMap oldCallVars = (LocalVariableMap) 
callVars.clone();
                        for (StatementBlock sbi : fstmt.getBody())
-                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                               propagateStatisticsAcrossBlock(sbi, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
                        if( 
Recompiler.reconcileUpdatedCallVarsLoops(oldCallVars, callVars, fsb) )
                                for (StatementBlock sbi : fstmt.getBody())
-                                       propagateStatisticsAcrossBlock(sbi, 
fcand, callVars, fcandSafeNNZ, fnStack);
+                                       propagateStatisticsAcrossBlock(sbi, 
fcand, callVars, fcandSafeNNZ, unaryFcands, fnStack);
                        //remove updated constant scalars
                        Recompiler.removeUpdatedScalars(callVars, sb);
                }
@@ -515,7 +577,7 @@ public class InterProceduralAnalysis
                        propagateStatisticsAcrossDAG(roots, callVars);
                        //propagate stats into function calls
                        Hop.resetVisitStatus(roots);
-                       propagateStatisticsIntoFunctions(prog, roots, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                       propagateStatisticsIntoFunctions(prog, roots, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
                }
        }
        
@@ -591,11 +653,11 @@ public class InterProceduralAnalysis
         * @throws HopsException
         * @throws ParseException
         */
-       private void propagateStatisticsIntoFunctions(DMLProgram prog, 
ArrayList<Hop> roots, Map<String, Integer> fcand, LocalVariableMap callVars, 
Map<String, Set<Long>> fcandSafeNNZ, Set<String> fnStack ) 
+       private void propagateStatisticsIntoFunctions(DMLProgram prog, 
ArrayList<Hop> roots, Map<String, Integer> fcand, LocalVariableMap callVars, 
Map<String, Set<Long>> fcandSafeNNZ, Set<String> unaryFcands, Set<String> 
fnStack ) 
                        throws HopsException, ParseException
        {
                for( Hop root : roots )
-                       propagateStatisticsIntoFunctions(prog, root, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                       propagateStatisticsIntoFunctions(prog, root, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
        }
        
        
@@ -607,14 +669,14 @@ public class InterProceduralAnalysis
         * @throws HopsException
         * @throws ParseException
         */
-       private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, 
Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> 
fcandSafeNNZ, Set<String> fnStack ) 
+       private void propagateStatisticsIntoFunctions(DMLProgram prog, Hop hop, 
Map<String, Integer> fcand, LocalVariableMap callVars, Map<String, Set<Long>> 
fcandSafeNNZ, Set<String> unaryFcands, Set<String> fnStack ) 
                throws HopsException, ParseException
        {
                if( hop.getVisited() == VisitStatus.DONE )
                        return;
                
                for( Hop c : hop.getInput() )
-                       propagateStatisticsIntoFunctions(prog, c, fcand, 
callVars, fcandSafeNNZ, fnStack);
+                       propagateStatisticsIntoFunctions(prog, c, fcand, 
callVars, fcandSafeNNZ, unaryFcands, fnStack);
                
                if( hop instanceof FunctionOp )
                {
@@ -639,7 +701,7 @@ public class InterProceduralAnalysis
                                                        callVars, tmpVars, 
fcandSafeNNZ.get(fkey), fcand.get(fkey) );
        
                                        //recursively propagate statistics
-                                       propagateStatisticsAcrossBlock(fsb, 
fcand, tmpVars, fcandSafeNNZ, fnStack);
+                                       propagateStatisticsAcrossBlock(fsb, 
fcand, tmpVars, fcandSafeNNZ, unaryFcands, fnStack);
                                        
                                        //extract vars from symbol table, 
re-map and refresh main program
                                        
extractFunctionCallReturnStatistics(fstmt, fop, tmpVars, callVars, true);       
        
@@ -647,8 +709,10 @@ public class InterProceduralAnalysis
                                        //maintain function call stack
                                        fnStack.remove(fkey);
                                }
-                               else
-                               {
+                               else if( unaryFcands.contains(fkey) ) {
+                                       
extractFunctionCallEquivalentReturnStatistics(fstmt, fop, callVars);
+                               }
+                               else {
                                        
extractFunctionCallUnknownReturnStatistics(fstmt, fop, callVars);
                                }
                        }
@@ -834,6 +898,27 @@ public class InterProceduralAnalysis
         * @param callVars
         * @throws HopsException
         */
+       private void extractFunctionCallEquivalentReturnStatistics( //binds output stats with the same dims as the call's input (unary dim-preserving functions)
FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) //NOTE(review): fstmt is currently unused
+               throws HopsException
+       {
+               String fkey = 
DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), 
fop.getFunctionName());
+               try { //output inherits the dims of the single input hop; nnz is left unknown (-1)
+                       Hop input = fop.getInput().get(0);
+                       MatrixObject moOut = 
createOutputMatrix(input.getDim1(), input.getDim2(), -1);  
+                       callVars.put(fop.getOutputVariableNames()[0], moOut); //bind to the first (only) output variable of the call
+               }
+               catch( Exception ex ) {
+                       throw new HopsException( "Failed to extract output 
statistics for unary function "+fkey+".", ex);
+               }
+       }
+       
+       /**
+        * 
+        * @param fstmt
+        * @param fop
+        * @param callVars
+        * @throws HopsException
+        */
        private void extractExternalFunctionCallReturnStatistics( 
ExternalFunctionStatement fstmt, FunctionOp fop, LocalVariableMap callVars ) 
                throws HopsException
        {
@@ -883,13 +968,10 @@ public class InterProceduralAnalysis
         * @param nnz
         * @return
         */
-       private MatrixObject createOutputMatrix( long dim1, long dim2, long nnz 
)
-       {
+       private MatrixObject createOutputMatrix( long dim1, long dim2, long nnz 
) {
                MatrixObject moOut = new MatrixObject(ValueType.DOUBLE, null);
-               MatrixCharacteristics mc = new MatrixCharacteristics( 
-                                                                       dim1, 
dim2,
-                                                                       
ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize(),
-                                                                       nnz);
+               MatrixCharacteristics mc = new MatrixCharacteristics( dim1, 
dim2,
+                               ConfigurationManager.getBlocksize(), 
ConfigurationManager.getBlocksize(), nnz);
                MatrixFormatMetaData meta = new 
MatrixFormatMetaData(mc,null,null);
                moOut.setMetaData(meta);
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a5584c0f/src/main/java/org/apache/sysml/parser/DMLProgram.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLProgram.java 
b/src/main/java/org/apache/sysml/parser/DMLProgram.java
index ee76d56..12a8369 100644
--- a/src/main/java/org/apache/sysml/parser/DMLProgram.java
+++ b/src/main/java/org/apache/sysml/parser/DMLProgram.java
@@ -80,6 +80,17 @@ public class DMLProgram
                return _blocks.size();
        }
 
+       /**
+        * Looks up a function statement block by its combined function key.
+        * @param fkey   function key as concatenation of namespace and 
function name 
+        *               (see DMLProgram.constructFunctionKey)
+        * @return the block returned by getFunctionStatementBlock(namespaceKey, functionName) for the split key parts
+        */
+       public FunctionStatementBlock getFunctionStatementBlock(String fkey) {
+               String[] tmp = splitFunctionKey(fkey); //NOTE(review): assumes the split yields [namespace, name]; confirm behavior for malformed keys
+               return getFunctionStatementBlock(tmp[0], tmp[1]);
+       }
+       
        public FunctionStatementBlock getFunctionStatementBlock(String 
namespaceKey, String functionName) {
                DMLProgram namespaceProgram = 
this.getNamespaces().get(namespaceKey);
                if (namespaceProgram == null)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a5584c0f/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java
index cdd15e5..61f122a 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/recompile/IPAPropagationSizeMultipleFunctionsTest.java
@@ -22,8 +22,8 @@ package org.apache.sysml.test.integration.functions.recompile;
 import java.util.HashMap;
 
 import org.junit.Test;
-
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.ipa.InterProceduralAnalysis;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
@@ -31,7 +31,6 @@ import org.apache.sysml.test.utils.TestUtils;
 
 public class IPAPropagationSizeMultipleFunctionsTest extends AutomatedTestBase 
 {
-       
        private final static String TEST_NAME1 = "multiple_function_calls1";
        private final static String TEST_NAME2 = "multiple_function_calls2";
        private final static String TEST_NAME3 = "multiple_function_calls3";
@@ -59,63 +58,63 @@ public class IPAPropagationSizeMultipleFunctionsTest 
extends AutomatedTestBase
        
        
        @Test
-       public void testFunctionSizePropagationSameInput() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, false);
+       public void testFunctionSizePropagationSameInput() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, false, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationEqualDimsUnknownNnzRight() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, false);
+       public void testFunctionSizePropagationEqualDimsUnknownNnzRight() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, false, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationEqualDimsUnknownNnzLeft() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, false);
+       public void testFunctionSizePropagationEqualDimsUnknownNnzLeft() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, false, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationEqualDimsUnknownNnz() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, false);
+       public void testFunctionSizePropagationEqualDimsUnknownNnz() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, false, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationDifferentDims() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false);
+       public void testFunctionSizePropagationDifferentDims() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationSameInputIPA() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, true);
+       public void testFunctionSizePropagationDifferentDimsUnary() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, false, 
true);
        }
        
        @Test
-       public void testFunctionSizePropagationEqualDimsUnknownNnzRightIPA() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, true);
+       public void testFunctionSizePropagationSameInputIPA() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME1, true, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationEqualDimsUnknownNnzLeftIPA() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, true);
+       public void testFunctionSizePropagationEqualDimsUnknownNnzRightIPA() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME2, true, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationEqualDimsUnknownNnzIPA() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, true);
+       public void testFunctionSizePropagationEqualDimsUnknownNnzLeftIPA() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME3, true, 
false);
        }
        
        @Test
-       public void testFunctionSizePropagationDifferentDimsIPA() 
-       {
-               runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true);
+       public void testFunctionSizePropagationEqualDimsUnknownNnzIPA() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME4, true, 
false);
+       }
+       
+       @Test
+       public void testFunctionSizePropagationDifferentDimsIPA() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true, 
false);
+       }
+       
+       @Test
+       public void testFunctionSizePropagationDifferentDimsIPAUnary() {
+               runIPASizePropagationMultipleFunctionsTest(TEST_NAME5, true, 
true);
        }
        
        
@@ -125,9 +124,10 @@ public class IPAPropagationSizeMultipleFunctionsTest 
extends AutomatedTestBase
         * @param branchRemoval
         * @param IPA
         */
-       private void runIPASizePropagationMultipleFunctionsTest( String 
TEST_NAME, boolean IPA )
+       private void runIPASizePropagationMultipleFunctionsTest( String 
TEST_NAME, boolean IPA, boolean unary )
        {       
                boolean oldFlagIPA = 
OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;
+               boolean oldFlagUnary = 
InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS;
                
                try
                {
@@ -142,6 +142,7 @@ public class IPAPropagationSizeMultipleFunctionsTest 
extends AutomatedTestBase
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + expectedDir();
                        
                        OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;
+                       InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS = 
unary;
 
                        //generate input data
                        double[][] V = getRandomMatrix(rows, cols, 0, 1, 
sparsity, 7);
@@ -157,7 +158,8 @@ public class IPAPropagationSizeMultipleFunctionsTest 
extends AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, 0, 
"Stat-DML", "Stat-R");
                        
                        //check expected number of compiled and executed MR jobs
-                       int expectedNumCompiled = (IPA) ? 
(TEST_NAME.equals(TEST_NAME5)?4:1) : (TEST_NAME.equals(TEST_NAME5)?5:4); 
//reblock, 2xGMR foo, GMR 
+                       int expectedNumCompiled = (IPA) ? 
((TEST_NAME.equals(TEST_NAME5))?(unary?2:4):1) : 
+                               (TEST_NAME.equals(TEST_NAME5)?5:4); //reblock, 
2xGMR foo, GMR 
                        int expectedNumExecuted = 0;                    
                        
                        checkNumCompiledMRJobs(expectedNumCompiled); 
@@ -166,6 +168,7 @@ public class IPAPropagationSizeMultipleFunctionsTest 
extends AutomatedTestBase
                finally
                {
                        OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = 
oldFlagIPA;
+                       InterProceduralAnalysis.UNARY_DIMS_PRESERVING_FUNS = 
oldFlagUnary;
                }
        }
        

Reply via email to