[SYSTEMML-1699] Fix IPA nnz propagation w/ multiple function calls

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a625c642
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a625c642
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a625c642

Branch: refs/heads/master
Commit: a625c6423655210fa91b41162734c9214d3aaa0d
Parents: 72645d3
Author: Matthias Boehm <[email protected]>
Authored: Sat Jun 17 14:41:35 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jun 17 14:41:35 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/ipa/FunctionCallSizeInfo.java    | 19 ++---
 .../sysml/hops/ipa/InterProceduralAnalysis.java |  2 +-
 .../functions/misc/IPANnzPropagationTest.java   | 83 ++++++++++++++++++++
 .../functions/misc/IPANnzPropagation1.dml       | 30 +++++++
 .../functions/misc/IPANnzPropagation2.dml       | 31 ++++++++
 .../functions/misc/ZPackageSuite.java           |  1 +
 6 files changed, 156 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java 
b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
index 402e780..fb668b5 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java
@@ -55,8 +55,8 @@ public class FunctionCallSizeInfo
        
        //indicators for which function arguments of valid functions it 
        //is safe to propagate the number of non-zeros 
-       //(mapping from function keys to set of function input HopIDs)
-       private final Map<String, Set<Long>> _fcandSafeNNZ;
+       //(mapping from function keys to set of zero-based function input positions)
+       private final Map<String, Set<Integer>> _fcandSafeNNZ;
        
        //indicators which literal function arguments can be safely 
        //propagated into and replaced in the respective functions 
@@ -90,7 +90,7 @@ public class FunctionCallSizeInfo
                _fgraph = fgraph;
                _fcand = new HashSet<String>();
                _fcandUnary = new HashSet<String>();
-               _fcandSafeNNZ =  new HashMap<String, Set<Long>>();
+               _fcandSafeNNZ =  new HashMap<String, Set<Integer>>();
                _fSafeLiterals = new HashMap<String, Set<Integer>>();
                
                constructFunctionCallSizeInfo();
@@ -176,12 +176,12 @@ public class FunctionCallSizeInfo
         * number of non-zeros.  
         * 
         * @param fkey function key
-        * @param inputHopID hop ID of the input
+        * @param pos zero-based position of the function input (index into the call's input list)
         * @return true if nnz can safely be propagated
         */
-       public boolean isSafeNnz(String fkey, long inputHopID) {
+       public boolean isSafeNnz(String fkey, int pos) {
                return _fcandSafeNNZ.containsKey(fkey)
-                       && _fcandSafeNNZ.get(fkey).contains(inputHopID);
+                       && _fcandSafeNNZ.get(fkey).contains(pos);
        }
        
        
        /**
@@ -254,12 +254,13 @@ public class FunctionCallSizeInfo
                //(considered for valid functions only)
                for( String fkey : _fcand ) {
                        FunctionOp first = 
_fgraph.getFunctionCalls(fkey).get(0);
-                       HashSet<Long> tmp = new HashSet<Long>();
-                       for( Hop input : first.getInput() ) {
+                       HashSet<Integer> tmp = new HashSet<Integer>();
+                       for( int j=0; j<first.getInput().size(); j++ ) {
                                //if nnz known it is safe to propagate those 
nnz because for multiple calls 
                                //we checked of equivalence and hence all calls 
have the same nnz
+                               Hop input = first.getInput().get(j); //input at position j, matching the position recorded below
                                if( input.getNnz()>=0 ) 
-                                       tmp.add(input.getHopID());
+                                       tmp.add(j);
                        }
                        _fcandSafeNNZ.put(fkey, tmp);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 7d371ac..b813685 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -541,7 +541,7 @@ public class InterProceduralAnalysis
                                MatrixObject mo = new 
MatrixObject(ValueType.DOUBLE, null);
                                MatrixCharacteristics mc = new 
MatrixCharacteristics( input.getDim1(), input.getDim2(), 
                                        ConfigurationManager.getBlocksize(), 
ConfigurationManager.getBlocksize(),
-                                       fcallSizes.isSafeNnz(fkey, 
input.getHopID())?input.getNnz():-1 );
+                                       fcallSizes.isSafeNnz(fkey, 
i)?input.getNnz():-1 );
                                MatrixFormatMetaData meta = new 
MatrixFormatMetaData(mc,null,null);
                                mo.setMetaData(meta);   
                                vars.put(dat.getName(), mo);    

http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java
new file mode 100644
index 0000000..2809ecf
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/IPANnzPropagationTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+
+public class IPANnzPropagationTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME1 = "IPANnzPropagation1";
+       private final static String TEST_NAME2 = "IPANnzPropagation2";
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
IPANnzPropagationTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{}));
+       }
+
+       @Test
+       public void testNnzPropagationPositive() {
+               runIPANnzPropagationTest(TEST_NAME1);
+       }
+
+       @Test
+       public void testNnzPropagationNegative() {
+               runIPANnzPropagationTest(TEST_NAME2);
+       }
+
+
+       private void runIPANnzPropagationTest(String testname)
+       {
+               // Save old settings
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               
+               try
+               {
+                       // Setup test
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{"-stats", "-explain", 
"recompile_hops"};
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+                       
+                       runTest(true, false, null, -1);
+                       
+                       //expect 0 compiled spark insts if nnz was propagated (TEST_NAME1), else 1 (TEST_NAME2)
+                       checkNumCompiledSparkInst(testname.equals(TEST_NAME1) ? 
0 : 1);
+                       checkNumExecutedSparkInst(0);
+               }
+               finally {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       rtplatform = platformOld;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test/scripts/functions/misc/IPANnzPropagation1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPANnzPropagation1.dml 
b/src/test/scripts/functions/misc/IPANnzPropagation1.dml
new file mode 100644
index 0000000..919dc32
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPANnzPropagation1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(matrix[double] X) return (double sum) {
+    if( 1==1 ) {} # presumably keeps foo from being inlined so IPA size propagation is exercised (TODO confirm)
+    sum = sum(X);
+}
+
+X = rand(rows=1000, cols=1000000000, sparsity=1e-6)
+s1 = foo(X);
+s2 = foo(X); # same input X for both calls, so both calls see the same nnz
+print(s1+" "+s2);

http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test/scripts/functions/misc/IPANnzPropagation2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/IPANnzPropagation2.dml 
b/src/test/scripts/functions/misc/IPANnzPropagation2.dml
new file mode 100644
index 0000000..86990ca
--- /dev/null
+++ b/src/test/scripts/functions/misc/IPANnzPropagation2.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+foo = function(matrix[double] X) return (double sum) {
+    if( 1==1 ) {} # presumably keeps foo from being inlined so IPA size propagation is exercised (TODO confirm)
+    sum = sum(X);
+}
+
+X = rand(rows=1000, cols=1000000000, sparsity=1e-6)
+s1 = foo(X);
+X = rand(rows=1000, cols=1000000000, sparsity=1e-7) # different sparsity: the two calls of foo see different nnz
+s2 = foo(X);
+print(s1+" "+s2);

http://git-wip-us.apache.org/repos/asf/systemml/blob/a625c642/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index 9917bb3..e352e6d 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -36,6 +36,7 @@ import org.junit.runners.Suite;
        InvalidFunctionSignatureTest.class,
        IPAConstantFoldingScalarVariablePropagationTest.class,
        IPALiteralReplacementTest.class,
+       IPANnzPropagationTest.class,
        IPAScalarRecursionTest.class,
        IPAScalarVariablePropagationTest.class,
        IPAUnknownRecursionTest.class,

Reply via email to