This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 980ddca4429a78bb467828fe3c3056eed265e534 Author: Matthias Boehm <[email protected]> AuthorDate: Tue Nov 30 23:24:15 2021 +0100 [SYSTEMDS-3231] Fix IPA failure on non-existing paramserv functions This patch improves the robustness of the relatively recent extension of inter-procedural analysis (IPA) that includes function pointers passed to paramserv (as a second order function) in order to avoid unnecessary recompilation during runtime. So far the integration was brittle with regard to non-existing namespaces and functions, leading to incomprehensible errors. We now handle this in a more robust manner avoiding IPA and letting the runtime raise consistent errors. --- .../apache/sysds/hops/ipa/FunctionCallGraph.java | 25 ++++++------- .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 18 +++++++++ .../apache/sysds/parser/FunctionDictionary.java | 2 +- .../sysds/runtime/compress/colgroup/APreAgg.java | 2 + .../paramserv/ParamservRuntimeNegativeTest.java | 37 ++++++++----------- .../paramserv/paramserv-invalid-function.dml | 43 ++++++++++++++++++++++ 6 files changed, 91 insertions(+), 36 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java index 1ac91c6..690eb19 100644 --- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java +++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java @@ -438,13 +438,14 @@ public class FunctionCallGraph rConstructFunctionCallGraph(h, fkey, sb, fstack, lfset); if( HopRewriteUtils.isParameterBuiltinOp(hop, ParamBuiltinOp.PARAMSERV) - && HopRewriteUtils.knownParamservFunctions(hop)) + && HopRewriteUtils.knownParamservFunctions(hop, sb.getDMLProg())) { ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) hop; List<FunctionOp> fps = pop.getParamservPseudoFunctionCalls(); //include artificial function ops into functional call graph - for( FunctionOp fop : fps ) - ret |= addFunctionOpToGraph(fop, fkey, sb, fstack, lfset); + if( !fps.isEmpty() ) //valid functional parameters + for( FunctionOp fop : fps ) + ret |= addFunctionOpToGraph(fop, fkey, sb, fstack, lfset); } hop.setVisited(); @@ -472,16 +473,14 @@ public class FunctionCallGraph //recursively construct function call dag if( !fstack.contains(lfkey) ) { - - fstack.push(lfkey); - _fGraph.get(fkey).add(lfkey); - FunctionStatementBlock fsb = sb.getDMLProg() - .getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); - FunctionStatement fs = (FunctionStatement) fsb.getStatement(0); - for( StatementBlock csb : fs.getBody() ) - ret |= rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>()); - fstack.pop(); - + fstack.push(lfkey); + _fGraph.get(fkey).add(lfkey); + FunctionStatementBlock fsb = sb.getDMLProg() + .getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName()); + FunctionStatement fs = (FunctionStatement) fsb.getStatement(0); + for( StatementBlock csb : fs.getBody() ) + ret |= rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<String>()); + fstack.pop(); } //recursive function call else { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index e23af65..c91671f 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -56,6 +56,7 @@ import org.apache.sysds.hops.ParameterizedBuiltinOp; import org.apache.sysds.hops.ReorgOp; import org.apache.sysds.hops.TernaryOp; import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.DataExpression; import org.apache.sysds.parser.DataIdentifier; import org.apache.sysds.parser.ForStatement; @@ -1623,6 +1624,23 @@ public class HopRewriteUtils && (pop.getParameterHop("val") == null || pop.getParameterHop("val") instanceof LiteralOp); } + + public static boolean knownParamservFunctions(Hop hop, DMLProgram prog) { + if( !knownParamservFunctions(hop) ) + return false; + try { + ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) hop; + String supd = ((LiteralOp)pop.getParameterHop("upd")).getStringValue(); + String sagg = ((LiteralOp)pop.getParameterHop("agg")).getStringValue(); + //if functions not existing, let runtime handle it consistently + return prog.getFunctionStatementBlock(supd) != null + && prog.getFunctionStatementBlock(sagg) != null; + } + catch(Exception ex) { + //robustness invalid function keys + return false; + } + } public static void setUnoptimizedFunctionCalls(StatementBlock sb) { if( sb instanceof FunctionStatementBlock ) { diff --git a/src/main/java/org/apache/sysds/parser/FunctionDictionary.java b/src/main/java/org/apache/sysds/parser/FunctionDictionary.java index 32e2098..71c32ea 100644 --- a/src/main/java/org/apache/sysds/parser/FunctionDictionary.java +++ b/src/main/java/org/apache/sysds/parser/FunctionDictionary.java @@ -120,7 +120,7 @@ public class FunctionDictionary<T extends FunctionBlock> { @Override public String toString() { StringBuilder sb = new StringBuilder("Function Dictionary:"); - sb.append("----------------------------------------"); + sb.append("----------------------------------------\n"); int pos = 0; for( Entry<String, T> e : _funs.entrySet() ) { sb.append("-- ["); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java index e0cb171..2a15a21 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java @@ -36,6 +36,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; */ public abstract class APreAgg extends AColGroupValue { + private static final long serialVersionUID = 3250955207277128281L; + private static ThreadLocal<double[]> tmpLeftMultDoubleArray = new ThreadLocal<double[]>() { @Override protected double[] initialValue() { diff --git a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java index d609b41..2a5bd27 100644 --- a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservRuntimeNegativeTest.java @@ -19,18 +19,22 @@ package org.apache.sysds.test.functions.paramserv; -import org.junit.Ignore; import org.junit.Test; + +import java.util.Arrays; + import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; -@Ignore public class ParamservRuntimeNegativeTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "paramserv-worker-failed"; - private static final String TEST_NAME2 = "paramserv-agg-service-failed"; - private static final String TEST_NAME3 = "paramserv-wrong-aggregate-func"; + private static final String[] TEST_NAMES = { + //"paramserv-worker-failed", + //"paramserv-agg-service-failed", + //"paramserv-wrong-aggregate-func-params", + "paramserv-invalid-function", + }; private static final String TEST_DIR = "functions/paramserv/"; private static final String TEST_CLASS_DIR = TEST_DIR + ParamservRuntimeNegativeTest.class.getSimpleName() + "/"; @@ -39,30 +43,19 @@ public class ParamservRuntimeNegativeTest extends AutomatedTestBase { @Override public void setUp() { - addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); - addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); - } - - @Test - public void testParamservWorkerFailed() { - runDMLTest(TEST_NAME1, "Invalid indexing by name in unnamed list: worker_err."); + Arrays.stream(TEST_NAMES) + .forEach(s -> addTestConfiguration(s, new TestConfiguration(TEST_CLASS_DIR, s, new String[]{}))); } - - @Test - public void testParamservAggServiceFailed() { - runDMLTest(TEST_NAME2, "Invalid indexing by name in unnamed list: agg_service_err"); - } - + @Test - public void testParamservWrongAggregateFunc() { - runDMLTest(TEST_NAME3, "The 'gradients' function should provide an input of 'MATRIX' type named 'labels'."); + public void testParamservMissingAggregateFunc() { + runDMLTest(TEST_NAMES[0], "namespace XXX is undefined"); } private void runDMLTest(String testname, String errmsg) { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); - programArgs = new String[] { }; + programArgs = new String[] {"-explain"}; fullDMLScriptName = HOME + testname + ".dml"; runTest(true, true, DMLRuntimeException.class, errmsg, -1); } diff --git a/src/test/scripts/functions/paramserv/paramserv-invalid-function.dml b/src/test/scripts/functions/paramserv/paramserv-invalid-function.dml new file mode 100644 index 0000000..a8adb64 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-invalid-function.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +gradients = function (matrix[double] input) return (matrix[double] output) { + output = input + 1; +} + +aggregation = function (matrix[double] input) return (matrix[double] output) { + output = input + 1; +} + +model = list(e1="element1") +X = matrix(1, 2, 3) +Y = matrix(2, 2, 3) +X_val = matrix(3, 2, 3) +Y_val = matrix(4, 2, 3) +hps = list(e2="element2") + +# Use paramserv function +supd = ".defaultNS::gradients"; +sagg = "XXX::aggregation"; +model = paramserv(model=model, features=X, labels=Y, val_features=X_val, val_labels=Y_val, + upd=supd, agg=sagg, mode="LOCAL", utype="BSP", freq="EPOCH", epochs=10, hyperparams=hps) + +print(toString(model))
