This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 980ddca4429a78bb467828fe3c3056eed265e534
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Nov 30 23:24:15 2021 +0100

    [SYSTEMDS-3231] Fix IPA failure on non-existing paramserv functions
    
    This patch improves the robustness of the relatively recent extension
    of inter-procedural analysis (IPA) that includes function pointers
    passed to paramserv (a second-order function) in order to avoid
    unnecessary recompilation at runtime. So far, the integration was
    brittle with regard to non-existent namespaces and functions, leading
    to incomprehensible errors. We now handle this more robustly by
    skipping IPA for such calls and letting the runtime raise consistent
    errors.
---
 .../apache/sysds/hops/ipa/FunctionCallGraph.java   | 25 ++++++-------
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 18 +++++++++
 .../apache/sysds/parser/FunctionDictionary.java    |  2 +-
 .../sysds/runtime/compress/colgroup/APreAgg.java   |  2 +
 .../paramserv/ParamservRuntimeNegativeTest.java    | 37 ++++++++-----------
 .../paramserv/paramserv-invalid-function.dml       | 43 ++++++++++++++++++++++
 6 files changed, 91 insertions(+), 36 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java 
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index 1ac91c6..690eb19 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -438,13 +438,14 @@ public class FunctionCallGraph
                        rConstructFunctionCallGraph(h, fkey, sb, fstack, lfset);
                
                if( HopRewriteUtils.isParameterBuiltinOp(hop, 
ParamBuiltinOp.PARAMSERV)
-                       && HopRewriteUtils.knownParamservFunctions(hop))
+                       && HopRewriteUtils.knownParamservFunctions(hop, 
sb.getDMLProg()))
                {
                        ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) 
hop;
                        List<FunctionOp> fps = 
pop.getParamservPseudoFunctionCalls();
                        //include artificial function ops into functional call 
graph
-                       for( FunctionOp fop : fps )
-                               ret |= addFunctionOpToGraph(fop, fkey, sb, 
fstack, lfset);
+                       if( !fps.isEmpty() ) //valid functional parameters
+                               for( FunctionOp fop : fps )
+                                       ret |= addFunctionOpToGraph(fop, fkey, 
sb, fstack, lfset);
                }
                
                hop.setVisited();
@@ -472,16 +473,14 @@ public class FunctionCallGraph
 
                        //recursively construct function call dag
                        if( !fstack.contains(lfkey) ) {
-
-                                       fstack.push(lfkey);
-                                       _fGraph.get(fkey).add(lfkey);
-                                       FunctionStatementBlock fsb = 
sb.getDMLProg()
-                                               
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
-                                       FunctionStatement fs = 
(FunctionStatement) fsb.getStatement(0);
-                                       for( StatementBlock csb : fs.getBody() )
-                                               ret |= 
rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>());
-                                       fstack.pop();
-
+                               fstack.push(lfkey);
+                               _fGraph.get(fkey).add(lfkey);
+                               FunctionStatementBlock fsb = sb.getDMLProg()
+                                       
.getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
+                               FunctionStatement fs = (FunctionStatement) 
fsb.getStatement(0);
+                               for( StatementBlock csb : fs.getBody() )
+                                       ret |= 
rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>());
+                               fstack.pop();
                        }
                        //recursive function call
                        else {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index e23af65..c91671f 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -56,6 +56,7 @@ import org.apache.sysds.hops.ParameterizedBuiltinOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.UnaryOp;
+import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.ForStatement;
@@ -1623,6 +1624,23 @@ public class HopRewriteUtils
                        && (pop.getParameterHop("val") == null 
                         || pop.getParameterHop("val") instanceof LiteralOp);
        }
+       
+       public static boolean knownParamservFunctions(Hop hop, DMLProgram prog) 
{
+               if( !knownParamservFunctions(hop) )
+                       return false;
+               try {
+                       ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) 
hop;
+                       String supd = 
((LiteralOp)pop.getParameterHop("upd")).getStringValue();
+                       String sagg = 
((LiteralOp)pop.getParameterHop("agg")).getStringValue();
+                       //if functions not existing, let runtime handle it 
consistently
+                       return prog.getFunctionStatementBlock(supd) != null
+                               && prog.getFunctionStatementBlock(sagg) != null;
+               }
+               catch(Exception ex) {
+                       //robustness invalid function keys
+                       return false;
+               }
+       }
 
        public static void setUnoptimizedFunctionCalls(StatementBlock sb) {
                if( sb instanceof FunctionStatementBlock ) {
diff --git a/src/main/java/org/apache/sysds/parser/FunctionDictionary.java 
b/src/main/java/org/apache/sysds/parser/FunctionDictionary.java
index 32e2098..71c32ea 100644
--- a/src/main/java/org/apache/sysds/parser/FunctionDictionary.java
+++ b/src/main/java/org/apache/sysds/parser/FunctionDictionary.java
@@ -120,7 +120,7 @@ public class FunctionDictionary<T extends FunctionBlock> {
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder("Function Dictionary:");
-               sb.append("----------------------------------------");
+               sb.append("----------------------------------------\n");
                int pos = 0;
                for( Entry<String, T> e : _funs.entrySet() ) {
                        sb.append("-- [");
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java
index e0cb171..2a15a21 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java
@@ -36,6 +36,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
  */
 public abstract class APreAgg extends AColGroupValue {
 
+       private static final long serialVersionUID = 3250955207277128281L;
+
        private static ThreadLocal<double[]> tmpLeftMultDoubleArray = new 
ThreadLocal<double[]>() {
                @Override
                protected double[] initialValue() {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java
index d609b41..2a5bd27 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java
@@ -19,18 +19,22 @@
 
 package org.apache.sysds.test.functions.paramserv;
 
-import org.junit.Ignore;
 import org.junit.Test;
+
+import java.util.Arrays;
+
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 
-@Ignore
 public class ParamservRuntimeNegativeTest extends AutomatedTestBase {
 
-       private static final String TEST_NAME1 = "paramserv-worker-failed";
-       private static final String TEST_NAME2 = "paramserv-agg-service-failed";
-       private static final String TEST_NAME3 = 
"paramserv-wrong-aggregate-func";
+       private static final String[] TEST_NAMES = {
+               //"paramserv-worker-failed",
+               //"paramserv-agg-service-failed",
+               //"paramserv-wrong-aggregate-func-params",
+               "paramserv-invalid-function",
+       };
 
        private static final String TEST_DIR = "functions/paramserv/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
ParamservRuntimeNegativeTest.class.getSimpleName() + "/";
@@ -39,30 +43,19 @@ public class ParamservRuntimeNegativeTest extends 
AutomatedTestBase {
 
        @Override
        public void setUp() {
-               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
-               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
-               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
-       }
-
-       @Test
-       public void testParamservWorkerFailed() {
-               runDMLTest(TEST_NAME1, "Invalid indexing by name in unnamed 
list: worker_err.");
+               Arrays.stream(TEST_NAMES)
+                       .forEach(s -> addTestConfiguration(s, new 
TestConfiguration(TEST_CLASS_DIR, s, new String[]{})));
        }
-
-       @Test
-       public void testParamservAggServiceFailed() {
-               runDMLTest(TEST_NAME2, "Invalid indexing by name in unnamed 
list: agg_service_err");
-       }
-
+       
        @Test
-       public void testParamservWrongAggregateFunc() {
-               runDMLTest(TEST_NAME3, "The 'gradients' function should provide 
an input of 'MATRIX' type named 'labels'.");
+       public void testParamservMissingAggregateFunc() {
+               runDMLTest(TEST_NAMES[0], "namespace XXX is undefined");
        }
 
        private void runDMLTest(String testname, String errmsg) {
                TestConfiguration config = getTestConfiguration(testname);
                loadTestConfiguration(config);
-               programArgs = new String[] { };
+               programArgs = new String[] {"-explain"};
                fullDMLScriptName = HOME + testname + ".dml";
                runTest(true, true, DMLRuntimeException.class, errmsg, -1);
        }
diff --git 
a/src/test/scripts/functions/paramserv/paramserv-invalid-function.dml 
b/src/test/scripts/functions/paramserv/paramserv-invalid-function.dml
new file mode 100644
index 0000000..a8adb64
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-invalid-function.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+gradients = function (matrix[double] input) return (matrix[double] output) {
+       output = input + 1;
+}
+
+aggregation = function (matrix[double] input) return (matrix[double] output) {
+  output = input + 1;
+}
+
+model = list(e1="element1")
+X = matrix(1, 2, 3)
+Y = matrix(2, 2, 3)
+X_val = matrix(3, 2, 3)
+Y_val = matrix(4, 2, 3)
+hps = list(e2="element2")
+
+# Use paramserv function
+supd = ".defaultNS::gradients";
+sagg = "XXX::aggregation";
+model = paramserv(model=model, features=X, labels=Y, val_features=X_val, 
val_labels=Y_val,
+  upd=supd, agg=sagg, mode="LOCAL", utype="BSP", freq="EPOCH", epochs=10, 
hyperparams=hps)
+
+print(toString(model))

Reply via email to