Repository: systemml
Updated Branches:
  refs/heads/master db9da2855 -> 07e65189e


[SYSTEMML-2467] Fix IPA robustness for permuted named function args

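This patch fixes incorrect size propagation in inter-procedural analysis
when functions are called with the recently introduced named function
arguments. Previously, statistics were mapped to function parameters by
their order in the function signature. They are now propagated according
to the actual input-name binding of each call, so permuted named
arguments (e.g., foo(B=X2, A=X1)) receive the correct sizes.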


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/07e65189
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/07e65189
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/07e65189

Branch: refs/heads/master
Commit: 07e65189ef8a2b9d15f17ae7f502bfd2d7588933
Parents: db9da28
Author: Matthias Boehm <[email protected]>
Authored: Wed Jul 25 19:14:34 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jul 25 19:14:34 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/InterProceduralAnalysis.java | 24 +++++++------
 .../functions/misc/FunctionPotpourriTest.java   |  7 ++++
 .../functions/misc/FunPotpourriNamedArgsIPA.dml | 37 ++++++++++++++++++++
 3 files changed, 57 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/07e65189/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 85a433a..77cc3e2 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -313,7 +313,7 @@ public class InterProceduralAnalysis
                        IfStatementBlock isb = (IfStatementBlock) sb;
                        IfStatement istmt = (IfStatement)isb.getStatement(0);
                        //old stats into predicate
-                       
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);        
                
+                       
propagateStatisticsAcrossPredicateDAG(isb.getPredicateHops(), callVars);
                        //check and propagate stats into body
                        LocalVariableMap oldCallVars = (LocalVariableMap) 
callVars.clone();
                        LocalVariableMap callVarsElse = (LocalVariableMap) 
callVars.clone();
@@ -476,8 +476,8 @@ public class InterProceduralAnalysis
                                FunctionStatementBlock fsb = 
prog.getFunctionStatementBlock(fop.getFunctionNamespace(), 
fop.getFunctionName());
                                FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
                                
-                               if(  fcallSizes.isValidFunction(fkey) && 
-                                   !fnStack.contains(fkey)  ) //prevent 
recursion      
+                               if( fcallSizes.isValidFunction(fkey) && 
+                                       !fnStack.contains(fkey)  ) //prevent 
recursion
                                {
                                        //maintain function call stack
                                        fnStack.add(fkey);
@@ -490,7 +490,7 @@ public class InterProceduralAnalysis
                                        propagateStatisticsAcrossBlock(fsb, 
tmpVars, fcallSizes, fnStack);
                                        
                                        //extract vars from symbol table, 
re-map and refresh main program
-                                       
extractFunctionCallReturnStatistics(fstmt, fop, tmpVars, callVars, true);       
        
+                                       
extractFunctionCallReturnStatistics(fstmt, fop, tmpVars, callVars, true);
                                        
                                        //maintain function call stack
                                        fnStack.remove(fkey);
@@ -520,26 +520,28 @@ public class InterProceduralAnalysis
        
        private static void populateLocalVariableMapForFunctionCall( 
FunctionStatement fstmt, FunctionOp fop, LocalVariableMap callvars, 
LocalVariableMap vars, FunctionCallSizeInfo fcallSizes ) 
        {
-               ArrayList<DataIdentifier> inputVars = fstmt.getInputParams();
+               //note: due to arbitrary binding sequences of named function 
arguments,
+               //we cannot use the sequence as defined in the function 
signature
+               String[] funArgNames = fop.getInputVariableNames();
                ArrayList<Hop> inputOps = fop.getInput();
                String fkey = fop.getFunctionKey();
                
-               for( int i=0; i<inputVars.size(); i++ )
+               for( int i=0; i<funArgNames.length; i++ )
                {
                        //create mapping between input hops and vars
-                       DataIdentifier dat = inputVars.get(i);
+                       DataIdentifier dat = 
fstmt.getInputParam(funArgNames[i]);
                        Hop input = inputOps.get(i);
                        
                        if( input.getDataType()==DataType.MATRIX )
                        {
                                //propagate matrix characteristics
                                MatrixObject mo = new 
MatrixObject(ValueType.DOUBLE, null);
-                               MatrixCharacteristics mc = new 
MatrixCharacteristics( input.getDim1(), input.getDim2(), 
+                               MatrixCharacteristics mc = new 
MatrixCharacteristics( input.getDim1(), input.getDim2(),
                                        ConfigurationManager.getBlocksize(), 
ConfigurationManager.getBlocksize(),
                                        fcallSizes.isSafeNnz(fkey, 
i)?input.getNnz():-1 );
                                MetaDataFormat meta = new 
MetaDataFormat(mc,null,null);
-                               mo.setMetaData(meta);   
-                               vars.put(dat.getName(), mo);    
+                               mo.setMetaData(meta);
+                               vars.put(dat.getName(), mo);
                        }
                        else if( input.getDataType()==DataType.SCALAR )
                        {
@@ -553,7 +555,7 @@ public class InterProceduralAnalysis
                                //and input scalar is existing variable in 
symbol table
                                else if( PROPAGATE_SCALAR_VARS_INTO_FUN 
                                        && 
fcallSizes.getFunctionCallCount(fkey) == 1
-                                       && input instanceof DataOp  ) 
+                                       && input instanceof DataOp )
                                {
                                        Data scalar = 
callvars.get(input.getName()); 
                                        if( scalar != null && scalar instanceof 
ScalarObject ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/07e65189/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
index 34269e2..f1d2493 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/FunctionPotpourriTest.java
@@ -39,6 +39,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
        private final static String TEST_NAME9 = "FunPotpourriNamedArgsPartial";
        private final static String TEST_NAME10 = 
"FunPotpourriNamedArgsUnknown1";
        private final static String TEST_NAME11 = 
"FunPotpourriNamedArgsUnknown2";
+       private final static String TEST_NAME12 = "FunPotpourriNamedArgsIPA";
        
        private final static String TEST_DIR = "functions/misc/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FunctionPotpourriTest.class.getSimpleName() + "/";
@@ -57,6 +58,7 @@ public class FunctionPotpourriTest extends AutomatedTestBase
                addTestConfiguration( TEST_NAME9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME11, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME12, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME12, new String[] { "R" }) );
        }
 
        @Test
@@ -124,6 +126,11 @@ public class FunctionPotpourriTest extends 
AutomatedTestBase
                runFunctionTest( TEST_NAME11, true );
        }
        
+       @Test
+       public void testFunctionNamedArgsIPA() {
+               runFunctionTest( TEST_NAME12, false );
+       }
+       
        private void runFunctionTest(String testName, boolean error) {
                TestConfiguration config = getTestConfiguration(testName);
                loadTestConfiguration(config);

http://git-wip-us.apache.org/repos/asf/systemml/blob/07e65189/src/test/scripts/functions/misc/FunPotpourriNamedArgsIPA.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/FunPotpourriNamedArgsIPA.dml 
b/src/test/scripts/functions/misc/FunPotpourriNamedArgsIPA.dml
new file mode 100644
index 0000000..f0769f6
--- /dev/null
+++ b/src/test/scripts/functions/misc/FunPotpourriNamedArgsIPA.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo2 = function(Matrix[Double] A, Matrix[Double] B) return (Matrix[Double] C){
+  A = A[,1:100]; # check non-applied rewrite
+  B = B[1:100,]; # check non-applied rewrite
+  C = A %*% B + 7;
+  while(FALSE){} #no inlining
+}
+
+X1 = matrix(1, 100, 101)
+X2 = matrix(2, 101, 100)
+
+C = foo2(B=X2, A=X1);
+
+if( nrow(C) != 100 | ncol(C) != 100 )
+  C = X1 %*% X1; # cause error
+
+print(sum(C));

Reply via email to