[SYSTEMML-1693] New IPA pass for function inlining after rewrites

The existing function inlining (of single-statement-block functions)
happens during validate, i.e., before rewrites. However, after constant
propagation, constant folding, branch removal, and statement block
merge, often additional opportunities arise. This patch exploits such
opportunities by adding a new inter-procedural analysis pass for
inlining single-statement-block functions. To limit the potential
exponential increase of program size, we only inline functions with at
most 10 operations, not counting data operators and literals.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/83e01b02
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/83e01b02
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/83e01b02

Branch: refs/heads/master
Commit: 83e01b02891e54612e2ab82d3f2f805eee2f09f1
Parents: 897d29d
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Oct 27 20:23:56 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Oct 27 20:23:56 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallGraph.java       |  31 +++-
 .../sysml/hops/ipa/IPAPassInlineFunctions.java  | 156 +++++++++++++++++++
 .../sysml/hops/ipa/InterProceduralAnalysis.java |   4 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |   8 +-
 .../functions/misc/IPAFunctionInliningTest.java | 122 +++++++++++++++
 .../test/integration/functions/misc/IfTest.java | 155 +++++++++---------
 .../scripts/functions/misc/IPAFunInline1.dml    |  34 ++++
 .../scripts/functions/misc/IPAFunInline2.dml    |  36 +++++
 .../scripts/functions/misc/IPAFunInline3.dml    |  39 +++++
 .../scripts/functions/misc/IPAFunInline4.dml    |  36 +++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 11 files changed, 532 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java 
b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
index d719da7..4735f47 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
@@ -55,8 +55,9 @@ public class FunctionCallGraph
        //program-wide function call operators per target function
        //(mapping from function keys to set of its function calls)
        private final HashMap<String, ArrayList<FunctionOp>> _fCalls;
+       private final HashMap<String, ArrayList<StatementBlock>> _fCallsSB;
        
-       //subset of direct or indirect recursive functions      
+       //subset of direct or indirect recursive functions
        private final HashSet<String> _fRecursive;
        
        /**
@@ -68,6 +69,7 @@ public class FunctionCallGraph
        public FunctionCallGraph(DMLProgram prog) {
                _fGraph = new HashMap<>();
                _fCalls = new HashMap<>();
+               _fCallsSB = new HashMap<>();
                _fRecursive = new HashSet<>();
                
                constructFunctionCallGraph(prog);
@@ -82,6 +84,7 @@ public class FunctionCallGraph
        public FunctionCallGraph(StatementBlock sb) {
                _fGraph = new HashMap<>();
                _fCalls = new HashMap<>();
+               _fCallsSB = new HashMap<>();
                _fRecursive = new HashSet<>();
                
                constructFunctionCallGraph(sb);
@@ -125,6 +128,21 @@ public class FunctionCallGraph
        }
        
        /**
+        * Returns all statement blocks that contain a function operator
+        * calling the given function. The list is index-aligned with
+        * {@link #getFunctionCalls(String)}, i.e., the i-th statement block
+        * is the block in which the i-th function operator was collected.
+        * 
+        * @param fkey function key of called function,
+        *      null indicates the main program and returns an empty list
+        * @return list of statement blocks, never null (empty for the main
+        *      program or for functions without recorded calls)
+        */
+       public List<StatementBlock> getFunctionCallsSB(String fkey) {
+               //main program cannot have function calls, and functions without
+               //recorded calls yield an empty list instead of null
+               if( fkey == null || !_fCallsSB.containsKey(fkey) )
+                       return Collections.emptyList();
+               return _fCallsSB.get(fkey);
+       }
+       
+       /**
         * Indicates if the given function is either directly or indirectly 
recursive.
         * An example of an indirect recursive function is foo2 in the 
following call
         * chain: foo1 -&gt; foo2 -&gt; foo1.
@@ -135,7 +153,7 @@ public class FunctionCallGraph
         */
        public boolean isRecursiveFunction(String fnamespace, String fname) {
                return isRecursiveFunction(
-                       DMLProgram.constructFunctionKey(fnamespace, fname));    
                
+                       DMLProgram.constructFunctionKey(fnamespace, fname));
        }
        
        /**
@@ -268,9 +286,12 @@ public class FunctionCallGraph
                                        FunctionOp fop = (FunctionOp) h;
                                        String lfkey = fop.getFunctionKey();
                                        //keep all function operators
-                                       if( !_fCalls.containsKey(lfkey) )
-                                               _fCalls.put(lfkey, new 
ArrayList<FunctionOp>());
+                                       if( !_fCalls.containsKey(lfkey) ) {
+                                               _fCalls.put(lfkey, new 
ArrayList<>());
+                                               _fCallsSB.put(lfkey, new 
ArrayList<>());
+                                       }
                                        _fCalls.get(lfkey).add(fop);
+                                       _fCallsSB.get(lfkey).add(sb);
                                        
                                        //prevent redundant call edges
                                        if( lfset.contains(lfkey) || 
fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) )
@@ -278,7 +299,7 @@ public class FunctionCallGraph
                                        
                                        if( !_fGraph.containsKey(lfkey) )
                                                _fGraph.put(lfkey, new 
HashSet<String>());
-                                               
+                                       
                                        //recursively construct function call 
dag
                                        if( !fstack.contains(lfkey) ) {
                                                fstack.push(lfkey);

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java 
b/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
new file mode 100644
index 0000000..0527a10
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassInlineFunctions.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.FunctionOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.FunctionStatement;
+import org.apache.sysml.parser.FunctionStatementBlock;
+import org.apache.sysml.parser.StatementBlock;
+
+/**
+ * This rewrite inlines functions whose body is a single last-level statement
+ * block that contains no function calls and has at most
+ * {@link InterProceduralAnalysis#INLINING_MAX_NUM_OPS} operators (not counting
+ * data operators and literals). Function inlining already happens during
+ * validate, but rewrites such as constant folding and branch removal often
+ * create additional opportunities afterwards.
+ * 
+ */
+public class IPAPassInlineFunctions extends IPAPass
+{
+       @Override
+       public boolean isApplicable() {
+               return InterProceduralAnalysis.INLINING_MAX_NUM_OPS > 0;
+       }
+       
+       /**
+        * Replaces every call of an eligible reachable function with a deep copy
+        * of the function body's hop DAG. Transient reads of the input parameters
+        * are bound to the call's arguments, and transient writes of the output
+        * parameters are renamed to the call's output variables.
+        * 
+        * @param prog DML program providing the function statement blocks
+        * @param fgraph function call graph providing the calls per function
+        * @param fcallSizes function call size information (not used by this pass)
+        * @throws HopsException if the hop DAG manipulation fails
+        */
+       @Override
+       public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
+               throws HopsException
+       {
+               for( String fkey : fgraph.getReachableFunctions() ) {
+                       FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
+                       FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
+                       if( !isInlineCandidate(fstmt) )
+                               continue;
+                       
+                       if( LOG.isDebugEnabled() )
+                               LOG.debug("IPA: Inline function '"+fkey+"'");
+                       
+                       //replace all relevant function calls 
+                       ArrayList<Hop> hops = fstmt.getBody().get(0).get_hops();
+                       List<FunctionOp> fcalls = fgraph.getFunctionCalls(fkey);
+                       List<StatementBlock> fcallsSB = fgraph.getFunctionCallsSB(fkey);
+                       for(int i=0; i<fcalls.size(); i++) {
+                               FunctionOp op = fcalls.get(i);
+                               ArrayList<Hop> sbHops = fcallsSB.get(i).get_hops();
+                               
+                               //step 0: robustness for special cases; the call must also still be
+                               //a root of its statement block, otherwise (e.g., if this pass runs
+                               //again with a call graph that was not rebuilt) the function body
+                               //would be added a second time
+                               if( op.getInput().size() != fstmt.getInputParams().size()
+                                       || op.getOutputVariableNames().length != fstmt.getOutputParams().size()
+                                       || sbHops == null || !sbHops.contains(op) )
+                                       continue;
+                               
+                               //step 1: deep copy hop dag
+                               ArrayList<Hop> hops2 = Recompiler.deepCopyHopsDag(hops);
+                               
+                               //step 2: replace inputs (positional binding of arguments to parameters)
+                               HashMap<String,Hop> inMap = new HashMap<>();
+                               for(int j=0; j<op.getInput().size(); j++)
+                                       inMap.put(fstmt.getInputParams().get(j).getName(), op.getInput().get(j));
+                               replaceTransientReads(hops2, inMap);
+                               
+                               //step 3: replace outputs (every transient write refers to an
+                               //output parameter, as ensured by isInlineCandidate)
+                               HashMap<String,String> outMap = new HashMap<>();
+                               String[] opOutputs = op.getOutputVariableNames();
+                               for(int j=0; j<opOutputs.length; j++)
+                                       outMap.put(fstmt.getOutputParams().get(j).getName(), opOutputs[j]);
+                               for( Hop out : hops2 ) {
+                                       if( HopRewriteUtils.isData(out, DataOpTypes.TRANSIENTWRITE) )
+                                               out.setName(outMap.get(out.getName()));
+                               }
+                               
+                               //step 4: replace the function call by the inlined hop dag
+                               sbHops.remove(op);
+                               sbHops.addAll(hops2);
+                       }
+               }
+       }
+       
+       /**
+        * Indicates if the given function qualifies for inlining. Its body must be
+        * a single last-level statement block with hops, without function calls,
+        * and with at most {@link InterProceduralAnalysis#INLINING_MAX_NUM_OPS}
+        * operators. Its transient writes must also match its output parameters.
+        * 
+        * @param fstmt function statement
+        * @return true if the function body qualifies for inlining
+        */
+       private static boolean isInlineCandidate(FunctionStatement fstmt) {
+               if( fstmt.getBody().size() != 1 )
+                       return false;
+               StatementBlock body = fstmt.getBody().get(0);
+               if( !HopRewriteUtils.isLastLevelStatementBlock(body) )
+                       return false;
+               ArrayList<Hop> hops = body.get_hops();
+               //skip bodies without hops, there is no dag to copy into the callers
+               return hops != null
+                       && !containsFunctionOp(hops)
+                       && countOperators(hops) <= InterProceduralAnalysis.INLINING_MAX_NUM_OPS
+                       && writesExactlyOutputs(hops, fstmt);
+       }
+       
+       /**
+        * Indicates if the root transient writes of the function body correspond
+        * one-to-one to the function's output parameters. Otherwise inlining is
+        * unsafe for two reasons. First, a write of another variable could not be
+        * renamed to a caller variable. Second, an output without a write
+        * (presumably an output parameter returned unmodified) would never be
+        * bound in the caller once the function call is removed.
+        * 
+        * @param hops root hops of the function body
+        * @param fstmt function statement
+        * @return true if the written variables are exactly the output parameters
+        */
+       private static boolean writesExactlyOutputs(ArrayList<Hop> hops, FunctionStatement fstmt) {
+               ArrayList<String> outNames = new ArrayList<>();
+               for( int j=0; j<fstmt.getOutputParams().size(); j++ )
+                       outNames.add(fstmt.getOutputParams().get(j).getName());
+               int numWrites = 0;
+               for( Hop hop : hops ) {
+                       if( !HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTWRITE) )
+                               continue;
+                       if( !outNames.contains(hop.getName()) )
+                               return false;
+                       numWrites++;
+               }
+               return numWrites == outNames.size();
+       }
+       
+       /**
+        * Indicates if the hop dag rooted at the given hops contains a function call.
+        * 
+        * @param hops root hops, may be null or empty
+        * @return true if any hop is a {@link FunctionOp}
+        */
+       private static boolean containsFunctionOp(ArrayList<Hop> hops) {
+               if( hops==null || hops.isEmpty() )
+                       return false;
+               Hop.resetVisitStatus(hops);
+               boolean ret = HopRewriteUtils.containsOp(hops, FunctionOp.class);
+               Hop.resetVisitStatus(hops);
+               return ret;
+       }
+       
+       /**
+        * Counts the distinct operators of the hop dag rooted at the given hops,
+        * excluding data operators and literals.
+        * 
+        * @param hops root hops, may be null or empty
+        * @return number of operators
+        */
+       private static int countOperators(ArrayList<Hop> hops) {
+               if( hops==null || hops.isEmpty() )
+                       return 0;
+               Hop.resetVisitStatus(hops);
+               int count = 0;
+               for( Hop hop : hops )
+                       count += rCountOperators(hop);
+               Hop.resetVisitStatus(hops);
+               return count;
+       }
+       
+       //recursive helper of countOperators; the visit status ensures that
+       //hops shared by multiple parents are counted only once
+       private static int rCountOperators(Hop current) {
+               if( current.isVisited() )
+                       return 0;
+               int count = !(current instanceof DataOp 
+                       || current instanceof LiteralOp) ? 1 : 0;
+               for( Hop c : current.getInput() )
+                       count += rCountOperators(c);
+               current.setVisited();
+               return count;
+       }
+       
+       /**
+        * Replaces all transient reads in the hop dag rooted at the given hops
+        * by the hops mapped to their variable names.
+        * 
+        * @param hops root hops of the copied function body
+        * @param inMap mapping from input parameter names to caller argument hops
+        */
+       private static void replaceTransientReads(ArrayList<Hop> hops, HashMap<String, Hop> inMap) {
+               Hop.resetVisitStatus(hops);
+               for( Hop hop : hops )
+                       rReplaceTransientReads(hop, inMap);
+               Hop.resetVisitStatus(hops);
+       }
+       
+       //recursive helper of replaceTransientReads; rewires each parent of a
+       //transient read to the caller's argument hop
+       private static void rReplaceTransientReads(Hop current, HashMap<String, Hop> inMap) {
+               if( current.isVisited() )
+                       return;
+               for( int i=0; i<current.getInput().size(); i++ ) {
+                       Hop c = current.getInput().get(i);
+                       rReplaceTransientReads(c, inMap);
+                       if( HopRewriteUtils.isData(c, DataOpTypes.TRANSIENTREAD) )
+                               HopRewriteUtils.replaceChildReference(current, c, inMap.get(c.getName()));
+               }
+               current.setVisited();
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 65f7e54..2ab6b3c 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -96,12 +96,13 @@ public class InterProceduralAnalysis
        protected static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; 
//propagate scalar variables into functions that are called once
        protected static final boolean PROPAGATE_SCALAR_LITERALS      = true; 
//propagate and replace scalar literals into functions
        protected static final boolean APPLY_STATIC_REWRITES          = true; 
//apply static hop dag and statement block rewrites
+       protected static final int     INLINING_MAX_NUM_OPS           = 10;    
//inline single-statement functions w/ #ops <= threshold, other than dataops 
and literals
        
        static {
                // for internal debugging only
                if( LDEBUG ) {
                        
Logger.getLogger("org.apache.sysml.hops.ipa.InterProceduralAnalysis")
-                                 .setLevel((Level) Level.DEBUG);
+                               .setLevel((Level) Level.DEBUG);
                }
        }
        
@@ -136,6 +137,7 @@ public class InterProceduralAnalysis
                _passes.add(new IPAPassRemoveConstantBinaryOps());
                _passes.add(new IPAPassPropagateReplaceLiterals());
                _passes.add(new IPAPassApplyStaticHopRewrites());
+               _passes.add(new IPAPassInlineFunctions());
        }
        
        public InterProceduralAnalysis(StatementBlock sb) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 565c367..75103d1 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -1451,12 +1451,8 @@ public class DMLTranslator
                                        }
 
                                        //create function op
-                                       String[] foutputs = new 
String[mas.getTargetList().size()]; 
-                                       int count = 0;
-                                       for ( DataIdentifier paramName : 
mas.getTargetList() ){
-                                               
foutputs[count++]=paramName.getName();
-                                       }
-                                       
+                                       String[] foutputs = 
mas.getTargetList().stream()
+                                               .map(d -> 
d.getName()).toArray(String[]::new);
                                        FunctionType ftype = 
fsb.getFunctionOpType();
                                        FunctionOp fcall = new 
FunctionOp(ftype, fci.getNamespace(), fci.getName(), finputs, foutputs, false);
                                        output.add(fcall);

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAFunctionInliningTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAFunctionInliningTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAFunctionInliningTest.java
new file mode 100644
index 0000000..f58d400
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPAFunctionInliningTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+public class IPAFunctionInliningTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME1 = "IPAFunInline1"; //pos 1
+       private final static String TEST_NAME2 = "IPAFunInline2"; //pos 2
+       private final static String TEST_NAME3 = "IPAFunInline3"; //neg 1 (too large)
+       private final static String TEST_NAME4 = "IPAFunInline4"; //neg 2 (control flow)
+       
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + IPAFunctionInliningTest.class.getSimpleName() + "/";
+       
+       //expected scalar result of all test scripts and comparison tolerance
+       private final static double EXPECTED_VAL = 7;
+       private final static double EPS = Math.pow(10, -14);
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testFunInline1NoIPA() {
+               runIPAFunInlineTest( TEST_NAME1, false );
+       }
+       
+       @Test
+       public void testFunInline2NoIPA() {
+               runIPAFunInlineTest( TEST_NAME2, false );
+       }
+       
+       @Test
+       public void testFunInline3NoIPA() {
+               runIPAFunInlineTest( TEST_NAME3, false );
+       }
+       
+       @Test
+       public void testFunInline4NoIPA() {
+               runIPAFunInlineTest( TEST_NAME4, false );
+       }
+       
+       @Test
+       public void testFunInline1IPA() {
+               runIPAFunInlineTest( TEST_NAME1, true );
+       }
+       
+       @Test
+       public void testFunInline2IPA() {
+               runIPAFunInlineTest( TEST_NAME2, true );
+       }
+       
+       @Test
+       public void testFunInline3IPA() {
+               runIPAFunInlineTest( TEST_NAME3, true );
+       }
+       
+       @Test
+       public void testFunInline4IPA() {
+               runIPAFunInlineTest( TEST_NAME4, true );
+       }
+       
+       /**
+        * Runs the given test script with IPA enabled or disabled and checks the
+        * result. It also checks whether function 'foo' was inlined, which is
+        * expected only for the positive tests with IPA enabled.
+        * 
+        * @param testName name of the test configuration and DML script
+        * @param IPA value for {@code OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS}
+        */
+       private void runIPAFunInlineTest( String testName, boolean IPA )
+       {
+               boolean oldFlagIPA = OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS;
+               
+               try
+               {
+                       TestConfiguration config = getTestConfiguration(testName);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[]{"-explain", "-stats", "-args", output("R") };
+                       
+                       fullRScriptName = HOME + testName + ".R";
+                       rCmd = getRCmd(expectedDir());
+
+                       OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = IPA;
+
+                       //run script and compare output
+                       runTest(true, false, null, -1); 
+                       double val = readDMLMatrixFromHDFS("R").get(new CellIndex(1,1));
+                       Assert.assertEquals("Wrong result", EXPECTED_VAL, val, EPS);
+                       
+                       //check inlining: 'foo' should be absent from the statistics'
+                       //heavy hitters (presumably the executed instructions) iff inlined
+                       boolean expectInlined = IPA && (testName.equals(TEST_NAME1) || testName.equals(TEST_NAME2));
+                       boolean fooCalled = heavyHittersContainsSubString("foo");
+                       Assert.assertEquals("Unexpected function call of 'foo' (expected inlined: "+expectInlined+")",
+                               expectInlined, !fooCalled);
+               }
+               finally {
+                       OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS = oldFlagIPA;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test/java/org/apache/sysml/test/integration/functions/misc/IfTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IfTest.java 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IfTest.java
index 82635ae..08deba4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/IfTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/IfTest.java
@@ -26,82 +26,81 @@ import org.junit.Test;
 
 public class IfTest extends AutomatedTestBase
 {
-
-    private final static String TEST_DIR = "functions/misc/";
-    private final static String TEST_NAME1 = "IfTest";
-    private final static String TEST_NAME2 = "IfTest2";
-    private final static String TEST_NAME3 = "IfTest3";
-    private final static String TEST_NAME4 = "IfTest4";
-    private final static String TEST_NAME5 = "IfTest5";
-    private final static String TEST_NAME6 = "IfTest6";
-    private final static String TEST_CLASS_DIR = TEST_DIR + 
IfTest.class.getSimpleName() + "/";
-
-    @Override
-    public void setUp()
-    {
-        addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME1, new String[] {}));
-        addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME2, new String[] {}));
-        addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME3, new String[] {}));
-        addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME4, new String[] {}));
-        addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME5, new String[] {}));
-        addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME6, new String[] {}));
-    }
-
-    @Test
-    public void testIf() { runTest(TEST_NAME1, 1); }
-
-    @Test
-    public void testIfElse() {
-        runTest(TEST_NAME2, 1);
-        runTest(TEST_NAME2, 2);
-    }
-
-    @Test
-    public void testIfElif() {
-        runTest(TEST_NAME3, 1);
-        runTest(TEST_NAME3, 2);
-    }
-
-    @Test
-    public void testIfElifElse() {
-        runTest(TEST_NAME4, 1);
-        runTest(TEST_NAME4, 2);
-        runTest(TEST_NAME4, 3);
-    }
-
-    @Test
-    public void testIfElifElif() {
-        runTest(TEST_NAME5, 1);
-        runTest(TEST_NAME5, 2);
-        runTest(TEST_NAME5, 3);
-    }
-
-    @Test
-    public void testIfElifElifElse() {
-        runTest(TEST_NAME6, 1);
-        runTest(TEST_NAME6, 2);
-        runTest(TEST_NAME6, 3);
-        runTest(TEST_NAME6, 4);
-    }
-
-    private void runTest( String testName, int val )
-    {
-        TestConfiguration config = getTestConfiguration(testName);
-        loadTestConfiguration(config);
-
-        String HOME = SCRIPT_DIR + TEST_DIR;
-        fullDMLScriptName = HOME + testName + ".pydml";
-        programArgs = new String[]{"-python","-nvargs","val=" + 
Integer.toString(val)};
-
-        if (val == 1)
-            setExpectedStdOut("A");
-        else if (val == 2)
-            setExpectedStdOut("B");
-        else if (val == 3)
-            setExpectedStdOut("C");
-        else
-            setExpectedStdOut("D");
-
-        runTest(true, false, null, -1);
-    }
+       //PyDML test scripts (functions/misc/IfTest*.pydml) covering if, if-else,
+       //and elif chains; each test runs its script with different "val" arguments
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_NAME1 = "IfTest";
+       private final static String TEST_NAME2 = "IfTest2";
+       private final static String TEST_NAME3 = "IfTest3";
+       private final static String TEST_NAME4 = "IfTest4";
+       private final static String TEST_NAME5 = "IfTest5";
+       private final static String TEST_NAME6 = "IfTest6";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
IfTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp()
+       {
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
+               addTestConfiguration(TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {}));
+               addTestConfiguration(TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
+               addTestConfiguration(TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
+       }
+
+       @Test
+       public void testIf() { runTest(TEST_NAME1, 1); }
+
+       @Test
+       public void testIfElse() {
+               runTest(TEST_NAME2, 1);
+               runTest(TEST_NAME2, 2);
+       }
+
+       @Test
+       public void testIfElif() {
+               runTest(TEST_NAME3, 1);
+               runTest(TEST_NAME3, 2);
+       }
+
+       @Test
+       public void testIfElifElse() {
+               runTest(TEST_NAME4, 1);
+               runTest(TEST_NAME4, 2);
+               runTest(TEST_NAME4, 3);
+       }
+
+       @Test
+       public void testIfElifElif() {
+               runTest(TEST_NAME5, 1);
+               runTest(TEST_NAME5, 2);
+               runTest(TEST_NAME5, 3);
+       }
+
+       @Test
+       public void testIfElifElifElse() {
+               runTest(TEST_NAME6, 1);
+               runTest(TEST_NAME6, 2);
+               runTest(TEST_NAME6, 3);
+               runTest(TEST_NAME6, 4);
+       }
+
+       /**
+        * Runs the given PyDML test script with the named argument {@code val}.
+        * The expected standard output is set to "A", "B", "C", or "D" for
+        * val = 1, 2, 3, or any other value, respectively.
+        * 
+        * @param testName name of the test configuration and PyDML script
+        * @param val value passed to the script as named argument "val"
+        */
+       private void runTest( String testName, int val )
+       {
+               TestConfiguration config = getTestConfiguration(testName);
+               loadTestConfiguration(config);
+
+               String HOME = SCRIPT_DIR + TEST_DIR;
+               fullDMLScriptName = HOME + testName + ".pydml";
+               programArgs = new String[]{"-python","-nvargs","val=" + 
Integer.toString(val)};
+
+               if (val == 1)
+                       setExpectedStdOut("A");
+               else if (val == 2)
+                       setExpectedStdOut("B");
+               else if (val == 3)
+                       setExpectedStdOut("C");
+               else
+                       setExpectedStdOut("D");
+
+               runTest(true, false, null, -1);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test/scripts/functions/misc/IPAFunInline1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPAFunInline1.dml 
b/src/test/scripts/functions/misc/IPAFunInline1.dml
new file mode 100644
index 0000000..6492502
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPAFunInline1.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function(Matrix[Double] A, Integer type) return (Matrix[Double] B) {
+  if( type==1 )  # scalar branch; presumably removable once type is propagated as a constant
+    B = A * A * A;
+  else
+    B = A - 0.1;
+}
+# Driver: foo is called with a literal type, so after branch removal it is presumably a small single-statement-block function (inlining candidate)
+X = matrix(0.1, rows=100, cols=10);
+Y = foo(X, 1);  # type=1 selects B = A * A * A
+z = as.matrix(sum(Y)*7);
+
+write(z, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test/scripts/functions/misc/IPAFunInline2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPAFunInline2.dml b/src/test/scripts/functions/misc/IPAFunInline2.dml
new file mode 100644
index 0000000..3f7c36d
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPAFunInline2.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function(Matrix[Double] A, Integer type) return (Matrix[Double] B) {
+  if( type==1 ) {  # two-statement branch; presumably left as one merged block after branch removal
+    T = matrix(as.scalar(A[1,1]), nrow(A), ncol(A));  # matrix of A's shape filled with A[1,1]
+    B = T * T * T;
+  }
+  else
+    B = A - 0.1;
+}
+# Driver: foo is called with the literal type=1, so only the T/B branch is presumably kept (inlining candidate)
+X = matrix(0.1, rows=100, cols=10);
+Y = foo(X, 1);
+z = as.matrix(sum(Y)*7);
+
+write(z, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test/scripts/functions/misc/IPAFunInline3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPAFunInline3.dml b/src/test/scripts/functions/misc/IPAFunInline3.dml
new file mode 100644
index 0000000..f384717
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPAFunInline3.dml
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function(Matrix[Double] A, Integer type) return (Matrix[Double] B) {
+  if( type==1 ) {  # roughly 14 operations in this branch, above the 10-op limit of the new inlining pass
+    C = (A * A * A) / 3 + 2;
+    D = (A^2 + A^2 + 7) * A;
+    E = min(C, D)
+    B = ((E != 0) * A) * A * A;
+  }
+  else {
+    B = A - 0.1;
+  } 
+}
+# NOTE(review): foo is presumably expected to remain a function call (too many ops to inline) - confirm in IPAFunctionInliningTest
+X = matrix(0.1, rows=100, cols=10);
+Y = foo(X, 1);
+z = as.matrix(sum(Y)*7);
+
+write(z, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test/scripts/functions/misc/IPAFunInline4.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPAFunInline4.dml b/src/test/scripts/functions/misc/IPAFunInline4.dml
new file mode 100644
index 0000000..42dd29c
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPAFunInline4.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+foo = function(Matrix[Double] A, Integer type) return (Matrix[Double] B) {
+  for(i in 1:2) {  # the loop keeps foo from being a single-statement-block function, even after branch removal
+    if( type==1 )
+      B = A * A * A;
+    else
+      B = A - 0.1; 
+  } 
+}
+# NOTE(review): because of the for loop, foo is presumably expected NOT to be inlined - confirm in IPAFunctionInliningTest
+X = matrix(0.1, rows=100, cols=10);
+Y = foo(X, 1);
+z = as.matrix(sum(Y)*7);
+
+write(z, $1);

http://git-wip-us.apache.org/repos/asf/systemml/blob/83e01b02/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e3833f4..cac39e1 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -36,6 +36,7 @@ import org.junit.runners.Suite;
        InvalidFunctionAssignmentTest.class,
        InvalidFunctionSignatureTest.class,
        IPAConstantFoldingScalarVariablePropagationTest.class,
+       IPAFunctionInliningTest.class,
        IPALiteralReplacementTest.class,
        IPANnzPropagationTest.class,
        IPAScalarRecursionTest.class,

Reply via email to