This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a6a1509c8a [SYSTEMDS-3796] Fix flaky federated primitive tests and instructions
a6a1509c8a is described below

commit a6a1509c8a40b400f63286b4631f0d42849be187
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Nov 26 09:46:31 2024 +0100

    [SYSTEMDS-3796] Fix flaky federated primitive tests and instructions
---
 .../instructions/fed/CovarianceFEDInstruction.java | 84 ++++++++++++----------
 .../primitives/part5/FederatedCovarianceTest.java  | 23 ++----
 .../part5/FederatedMatrixScalarOperationsTest.java |  4 +-
 3 files changed, 52 insertions(+), 59 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index 719fd91588..d7f28293ce 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -28,6 +28,7 @@ import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.lops.Lop;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -116,7 +117,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        
                FederatedRequest fr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
                FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), 
fr1.getID());
-               Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+               Future<FederatedResponse>[] covTmp = 
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
 
                //means
                Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3, 
0);
@@ -145,7 +146,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(moLin3, false);
 
                // the original instruction encodes weights as "pREADW", change 
to the new ID
-               String[] parts = instString.split("°");
+               String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
                String covInstr = instString.replace(parts[4], 
String.valueOf(fr1[0].getID()) + "·MATRIX·FP64");
 
                FederatedRequest fr2 = FederationUtils.callInstruction(
@@ -305,7 +306,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                Future<FederatedResponse>[] meanTmp = null;
                if (moLin3 == null) {
                        String meanInstr = instString.replace(getOpcode(), 
getOpcode().replace("cov", "uamean"));
-                       meanInstr = meanInstr.replace((var == 0 ? parts[2] : 
parts[3]) + "°", "");
+                       meanInstr = meanInstr.replace((var == 0 ? parts[2] : 
parts[3]) + Lop.OPERAND_DELIMITOR, "");
                        meanInstr = meanInstr.replace(parts[4], 
parts[4].replace("FP64", "STRING°16"));
 
                        //create federated commands for aggregation
@@ -321,7 +322,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        String multOutput = incrementVar(parts[4], 1);
                        String multInstr = instString
                                .replace(getOpcode(), 
getOpcode().replace("cov", "*"))
-                               .replace((var == 0 ? parts[2] : parts[3]) + 
"°", "")
+                               .replace((var == 0 ? parts[2] : parts[3]) + 
Lop.OPERAND_DELIMITOR, "")
                                .replace(parts[5], multOutput);
 
                        CPOperand multOutputCPOp = new CPOperand(
@@ -337,13 +338,13 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        );
 
                        // calculate the sum of the obtained vector
-                       String[] partsMult = multInstr.split("°");
+                       String[] partsMult = 
multInstr.split(Lop.OPERAND_DELIMITOR);
                        String sumInstr1Output = incrementVar(multOutput, 1)
                                .replace("m", "")
                                .replace("MATRIX", "SCALAR");
                        String sumInstr1 = multInstr
                                .replace(partsMult[1], "uak+")
-                               .replace(partsMult[3] + "°", "")
+                               .replace(partsMult[3] + Lop.OPERAND_DELIMITOR, 
"")
                                .replace(partsMult[4], sumInstr1Output)
                                .replace(partsMult[2], multOutput);
 
@@ -358,7 +359,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        );
 
                        // calculate the sum of weights
-                       String[] partsSum1 = sumInstr1.split("°");
+                       String[] partsSum1 = 
sumInstr1.split(Lop.OPERAND_DELIMITOR);
                        String sumInstr2Output = incrementVar(sumInstr1Output, 
1);
                        String sumInstr2 = sumInstr1
                                .replace(partsSum1[2], parts[4])
@@ -367,7 +368,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        FederatedRequest sumFr2 = 
FederationUtils.callInstruction(
                                sumInstr2,
                                new CPOperand(
-                                       sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                                       sumInstr2Output.substring(0, 
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
                                        output.getValueType(), 
output.getDataType()
                                ),
                                new CPOperand[]{input3},
@@ -375,12 +376,13 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        );
 
                        // divide sum(X*W) by sum(W)
-                       String[] partsSum2 = sumInstr2.split("°");
+                       String[] partsSum2 = 
sumInstr2.split(Lop.OPERAND_DELIMITOR);
                        String divInstrOutput = incrementVar(sumInstr2Output, 
1);
-                       String divInstrInput1 = 
partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false");
-                       String divInstrInput2 = 
partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false");
-                       String divInstr = partsSum2[0] + "°" + 
partsSum2[1].replace("uak+", "/") + "°" +
-                                       divInstrInput1 + "°" + divInstrInput2 + 
"°" + divInstrOutput + "°" + partsSum2[4];
+                       String divInstrInput1 = 
partsSum2[2].replace(partsSum2[2], sumInstr1Output + Lop.DATATYPE_PREFIX + 
"false");
+                       String divInstrInput2 = 
partsSum2[3].replace(partsSum2[3], sumInstr2Output + Lop.DATATYPE_PREFIX + 
"false");
+                       String divInstr = partsSum2[0] + Lop.OPERAND_DELIMITOR 
+ partsSum2[1].replace("uak+", "/") 
+                               + Lop.OPERAND_DELIMITOR + divInstrInput1 + 
Lop.OPERAND_DELIMITOR + divInstrInput2 
+                               + Lop.OPERAND_DELIMITOR + divInstrOutput + 
Lop.OPERAND_DELIMITOR + partsSum2[4];
 
                        FederatedRequest divFr1 = 
FederationUtils.callInstruction(
                                divInstr,
@@ -390,11 +392,11 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                                ),
                                new CPOperand[]{
                                        new CPOperand(
-                                               sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                                               sumInstr1Output.substring(0, 
sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
                                                output.getValueType(), 
output.getDataType(), output.isLiteral()
                                        ),
                                        new CPOperand(
-                                               sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                                               sumInstr2Output.substring(0, 
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
                                                output.getValueType(), 
output.getDataType(), output.isLiteral()
                                        )
                                },
@@ -409,19 +411,19 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
        }
 
        private Future<FederatedResponse>[] processMean(MatrixObject mo1, int 
var, long weightsID){
-               String[] parts = instString.split("°");
+               String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
                Future<FederatedResponse>[] meanTmp = null;
 
                // multiply input X by weights W element-wise
                String multOutput = (var == 0 ? incrementVar(parts[2], 5) : 
incrementVar(parts[3], 3));
                String multInstr = instString
                        .replace(getOpcode(), getOpcode().replace("cov", "*"))
-                       .replace((var == 0 ? parts[2] : parts[3]) + "°", "")
+                       .replace((var == 0 ? parts[2] : parts[3]) + 
Lop.OPERAND_DELIMITOR, "")
                        .replace(parts[4], String.valueOf(weightsID) + 
"·MATRIX·FP64")
                        .replace(parts[5], multOutput);
 
                CPOperand multOutputCPOp = new CPOperand(
-                       multOutput.substring(0, multOutput.indexOf("·")),
+                       multOutput.substring(0, 
multOutput.indexOf(Lop.DATATYPE_PREFIX)),
                        mo1.getValueType(), mo1.getDataType()
                );
 
@@ -433,20 +435,20 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                );
 
                // calculate the sum of the obtained vector
-               String[] partsMult = multInstr.split("°");
+               String[] partsMult = multInstr.split(Lop.OPERAND_DELIMITOR);
                String sumInstr1Output = incrementVar(multOutput, 1)
                        .replace("m", "")
                        .replace("MATRIX", "SCALAR");
                String sumInstr1 = multInstr
                        .replace(partsMult[1], "uak+")
-                       .replace(partsMult[3] + "°", "")
+                       .replace(partsMult[3] + Lop.OPERAND_DELIMITOR, "")
                        .replace(partsMult[4], sumInstr1Output)
                        .replace(partsMult[2], multOutput);
 
                FederatedRequest sumFr1 = FederationUtils.callInstruction(
                        sumInstr1,
                        new CPOperand(
-                               sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                               sumInstr1Output.substring(0, 
sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
                                output.getValueType(), output.getDataType()
                        ),
                        new CPOperand[]{multOutputCPOp},
@@ -454,7 +456,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                );
 
                // calculate the sum of weights
-               String[] partsSum1 = sumInstr1.split("°");
+               String[] partsSum1 = sumInstr1.split(Lop.OPERAND_DELIMITOR);
                String sumInstr2Output = incrementVar(sumInstr1Output, 1);
                String sumInstr2 = sumInstr1
                        .replace(partsSum1[2], String.valueOf(weightsID) + 
"·MATRIX·FP64")
@@ -463,7 +465,7 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                FederatedRequest sumFr2 = FederationUtils.callInstruction(
                        sumInstr2,
                        new CPOperand(
-                               sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                               sumInstr2Output.substring(0, 
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
                                output.getValueType(), output.getDataType()
                        ),
                        new CPOperand[]{input3},
@@ -471,26 +473,27 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                );
 
                // divide sum(X*W) by sum(W)
-               String[] partsSum2 = sumInstr2.split("°");
+               String[] partsSum2 = sumInstr2.split(Lop.OPERAND_DELIMITOR);
                String divInstrOutput = incrementVar(sumInstr2Output, 1);
-               String divInstrInput1 = partsSum2[2].replace(partsSum2[2], 
sumInstr1Output + "·false");
-               String divInstrInput2 = partsSum2[3].replace(partsSum2[3], 
sumInstr2Output + "·false");
-               String divInstr = partsSum2[0] + "°" + 
partsSum2[1].replace("uak+", "/") + "°" +
-                               divInstrInput1 + "°" + divInstrInput2 + "°" + 
divInstrOutput + "°" + partsSum2[4];
+               String divInstrInput1 = partsSum2[2].replace(partsSum2[2], 
sumInstr1Output + Lop.DATATYPE_PREFIX + "false");
+               String divInstrInput2 = partsSum2[3].replace(partsSum2[3], 
sumInstr2Output + Lop.DATATYPE_PREFIX + "false");
+               String divInstr = partsSum2[0] + Lop.OPERAND_DELIMITOR + 
partsSum2[1].replace("uak+", "/") + Lop.OPERAND_DELIMITOR 
+                               + divInstrInput1 + Lop.OPERAND_DELIMITOR + 
divInstrInput2 + Lop.OPERAND_DELIMITOR 
+                               + divInstrOutput + Lop.OPERAND_DELIMITOR + 
partsSum2[4];
 
                FederatedRequest divFr1 = FederationUtils.callInstruction(
                        divInstr,
                        new CPOperand(
-                               divInstrOutput.substring(0, 
divInstrOutput.indexOf("·")),
+                               divInstrOutput.substring(0, 
divInstrOutput.indexOf(Lop.DATATYPE_PREFIX)),
                                output.getValueType(), output.getDataType()
                        ),
                        new CPOperand[]{
                                new CPOperand(
-                                       sumInstr1Output.substring(0, 
sumInstr1Output.indexOf("·")),
+                                       sumInstr1Output.substring(0, 
sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
                                        output.getValueType(), 
output.getDataType(), output.isLiteral()
                                ),
                                new CPOperand(
-                                       sumInstr2Output.substring(0, 
sumInstr2Output.indexOf("·")),
+                                       sumInstr2Output.substring(0, 
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
                                        output.getValueType(), 
output.getDataType(), output.isLiteral()
                                )
                        },
@@ -506,14 +509,15 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
        private Future<FederatedResponse>[] getWeightsSum(MatrixLineagePair 
moLin3, long weightsID, String instString, FederationMap fedMap) {
                Future<FederatedResponse>[] weightsSumTmp = null;
 
-               String[] parts = instString.split("°");
+               String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
                if (!instString.contains("pREADW")) {
-                       String sumInstr = "CP°uak+°" + parts[4] + "°" + 
parts[5] + "°" + parts[6];
+                       String sumInstr = "CP"+Lop.OPERAND_DELIMITOR+"uak+" + 
Lop.OPERAND_DELIMITOR 
+                               + parts[4] + Lop.OPERAND_DELIMITOR + parts[5] + 
Lop.OPERAND_DELIMITOR + parts[6];
 
                        FederatedRequest sumFr = 
FederationUtils.callInstruction(
                                sumInstr,
                                new CPOperand(
-                                       parts[5].substring(0, 
parts[5].indexOf("·")),
+                                       parts[5].substring(0, 
parts[5].indexOf(Lop.DATATYPE_PREFIX)),
                                        output.getValueType(),
                                        output.getDataType()
                                ),
@@ -526,11 +530,13 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2, 
sumFr3);
                }
                else {
-                       String sumInstr = "CP°uak+°" + 
String.valueOf(weightsID) + "·MATRIX·FP64" + "°" + parts[5] + "°" + parts[6];
+                       String sumInstr = 
"CP"+Lop.OPERAND_DELIMITOR+"uak+"+Lop.OPERAND_DELIMITOR
+                               + String.valueOf(weightsID) + "·MATRIX·FP64" + 
Lop.OPERAND_DELIMITOR + parts[5] 
+                               + Lop.OPERAND_DELIMITOR + parts[6];
                        FederatedRequest sumFr = 
FederationUtils.callInstruction(
                                sumInstr,
                                new CPOperand(
-                                       parts[5].substring(0, 
parts[5].indexOf("·")),
+                                       parts[5].substring(0, 
parts[5].indexOf(Lop.DATATYPE_PREFIX)),
                                        output.getValueType(),
                                        output.getDataType()
                                ),
@@ -576,7 +582,8 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op, 
_mo2));
                }
 
-               @Override public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+               @Override 
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
                        return null;
                }
        }
@@ -600,7 +607,8 @@ public class CovarianceFEDInstruction extends 
BinaryFEDInstruction {
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op, 
_mo2, _weights));
                }
 
-               @Override public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+               @Override 
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
                        return null;
                }
        }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 4f23c641fa..9fb42f23fd 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -136,17 +136,13 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
                Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
                Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
                Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
-               Process t4 = startLocalFedWorker(port4);
+               Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
 
                try {
                        if(!isAlive(t1, t2, t3, t4))
                                throw new RuntimeException("Failed starting 
federated worker");
 
-                       rtplatform = execMode;
-                       if(rtplatform == ExecMode.SPARK) {
-                               System.out.println(7);
-                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-                       }
+                       setExecMode(execMode);
                        TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                        loadTestConfiguration(config);
 
@@ -214,11 +210,7 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
 
        private void runWeightedCovarianceTest(ExecMode execMode, boolean 
alignedInput, boolean alignedWeights) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               ExecMode platformOld = rtplatform;
-
-               if(rtplatform == ExecMode.SPARK)
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+               ExecMode platformOld = setExecMode(execMode);
                String TEST_NAME = !alignedInput ? TEST_NAME3 : 
(!alignedWeights ? TEST_NAME4 : TEST_NAME5);
                getAndLoadTestConfiguration(TEST_NAME);
 
@@ -256,18 +248,11 @@ public class FederatedCovarianceTest extends 
AutomatedTestBase {
                Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
                Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
                Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
-               Process t4 = startLocalFedWorker(port4);
+               Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
 
                try {
                        if(!isAlive(t1, t2, t3, t4))
                                throw new RuntimeException("Failed starting 
federated worker");
-
-                       rtplatform = execMode;
-                       if(rtplatform == ExecMode.SPARK) {
-                               System.out.println(7);
-                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-                       }
-
                        TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                        loadTestConfiguration(config);
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java
index 7f68ec1a7c..84b906b9a4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java
@@ -24,6 +24,7 @@ import org.junit.runners.Parameterized;
 import org.junit.runner.RunWith;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -197,8 +198,7 @@ public class FederatedMatrixScalarOperationsTest extends 
AutomatedTestBase {
                        // we need the reference file to not be written to 
hdfs, so we get the correct format
                        rtplatform = Types.ExecMode.SINGLE_NODE;
                        programArgs = new String[] {"-w", 
Integer.toString(FEDERATED_WORKER_PORT)};
-                       t = new Thread(() -> runTest(true, false, null, -1));
-                       t.start();
+                       CommonThreadPool.get().submit(() -> runTest(true, 
false, null, -1));
                        sleep(FED_WORKER_WAIT);
                        fullDMLScriptName = SCRIPT_DIR + TEST_DIR + dmlFile + 
".dml";
                        programArgs = new String[] {"-nvargs",

Reply via email to