This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a6a1509c8a [SYSTEMDS-3796] Fix flaky federated primitive tests and instructions
a6a1509c8a is described below
commit a6a1509c8a40b400f63286b4631f0d42849be187
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Nov 26 09:46:31 2024 +0100
[SYSTEMDS-3796] Fix flaky federated primitive tests and instructions
---
.../instructions/fed/CovarianceFEDInstruction.java | 84 ++++++++++++----------
.../primitives/part5/FederatedCovarianceTest.java | 23 ++----
.../part5/FederatedMatrixScalarOperationsTest.java | 4 +-
3 files changed, 52 insertions(+), 59 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index 719fd91588..d7f28293ce 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -28,6 +28,7 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -116,7 +117,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest fr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(),
fr1.getID());
- Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
//means
Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3,
0);
@@ -145,7 +146,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(moLin3, false);
// the original instruction encodes weights as "pREADW", change
to the new ID
- String[] parts = instString.split("°");
+ String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
String covInstr = instString.replace(parts[4],
String.valueOf(fr1[0].getID()) + "·MATRIX·FP64");
FederatedRequest fr2 = FederationUtils.callInstruction(
@@ -305,7 +306,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
Future<FederatedResponse>[] meanTmp = null;
if (moLin3 == null) {
String meanInstr = instString.replace(getOpcode(),
getOpcode().replace("cov", "uamean"));
- meanInstr = meanInstr.replace((var == 0 ? parts[2] :
parts[3]) + "°", "");
+ meanInstr = meanInstr.replace((var == 0 ? parts[2] :
parts[3]) + Lop.OPERAND_DELIMITOR, "");
meanInstr = meanInstr.replace(parts[4],
parts[4].replace("FP64", "STRING°16"));
//create federated commands for aggregation
@@ -321,7 +322,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
String multOutput = incrementVar(parts[4], 1);
String multInstr = instString
.replace(getOpcode(),
getOpcode().replace("cov", "*"))
- .replace((var == 0 ? parts[2] : parts[3]) +
"°", "")
+ .replace((var == 0 ? parts[2] : parts[3]) +
Lop.OPERAND_DELIMITOR, "")
.replace(parts[5], multOutput);
CPOperand multOutputCPOp = new CPOperand(
@@ -337,13 +338,13 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
);
// calculate the sum of the obtained vector
- String[] partsMult = multInstr.split("°");
+ String[] partsMult =
multInstr.split(Lop.OPERAND_DELIMITOR);
String sumInstr1Output = incrementVar(multOutput, 1)
.replace("m", "")
.replace("MATRIX", "SCALAR");
String sumInstr1 = multInstr
.replace(partsMult[1], "uak+")
- .replace(partsMult[3] + "°", "")
+ .replace(partsMult[3] + Lop.OPERAND_DELIMITOR,
"")
.replace(partsMult[4], sumInstr1Output)
.replace(partsMult[2], multOutput);
@@ -358,7 +359,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
);
// calculate the sum of weights
- String[] partsSum1 = sumInstr1.split("°");
+ String[] partsSum1 =
sumInstr1.split(Lop.OPERAND_DELIMITOR);
String sumInstr2Output = incrementVar(sumInstr1Output,
1);
String sumInstr2 = sumInstr1
.replace(partsSum1[2], parts[4])
@@ -367,7 +368,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest sumFr2 =
FederationUtils.callInstruction(
sumInstr2,
new CPOperand(
- sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType()
),
new CPOperand[]{input3},
@@ -375,12 +376,13 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
);
// divide sum(X*W) by sum(W)
- String[] partsSum2 = sumInstr2.split("°");
+ String[] partsSum2 =
sumInstr2.split(Lop.OPERAND_DELIMITOR);
String divInstrOutput = incrementVar(sumInstr2Output,
1);
- String divInstrInput1 =
partsSum2[2].replace(partsSum2[2], sumInstr1Output + "·false");
- String divInstrInput2 =
partsSum2[3].replace(partsSum2[3], sumInstr2Output + "·false");
- String divInstr = partsSum2[0] + "°" +
partsSum2[1].replace("uak+", "/") + "°" +
- divInstrInput1 + "°" + divInstrInput2 +
"°" + divInstrOutput + "°" + partsSum2[4];
+ String divInstrInput1 =
partsSum2[2].replace(partsSum2[2], sumInstr1Output + Lop.DATATYPE_PREFIX +
"false");
+ String divInstrInput2 =
partsSum2[3].replace(partsSum2[3], sumInstr2Output + Lop.DATATYPE_PREFIX +
"false");
+ String divInstr = partsSum2[0] + Lop.OPERAND_DELIMITOR
+ partsSum2[1].replace("uak+", "/")
+ + Lop.OPERAND_DELIMITOR + divInstrInput1 +
Lop.OPERAND_DELIMITOR + divInstrInput2
+ + Lop.OPERAND_DELIMITOR + divInstrOutput +
Lop.OPERAND_DELIMITOR + partsSum2[4];
FederatedRequest divFr1 =
FederationUtils.callInstruction(
divInstr,
@@ -390,11 +392,11 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
),
new CPOperand[]{
new CPOperand(
- sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType(), output.isLiteral()
),
new CPOperand(
- sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType(), output.isLiteral()
)
},
@@ -409,19 +411,19 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
}
private Future<FederatedResponse>[] processMean(MatrixObject mo1, int
var, long weightsID){
- String[] parts = instString.split("°");
+ String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
Future<FederatedResponse>[] meanTmp = null;
// multiply input X by weights W element-wise
String multOutput = (var == 0 ? incrementVar(parts[2], 5) :
incrementVar(parts[3], 3));
String multInstr = instString
.replace(getOpcode(), getOpcode().replace("cov", "*"))
- .replace((var == 0 ? parts[2] : parts[3]) + "°", "")
+ .replace((var == 0 ? parts[2] : parts[3]) +
Lop.OPERAND_DELIMITOR, "")
.replace(parts[4], String.valueOf(weightsID) +
"·MATRIX·FP64")
.replace(parts[5], multOutput);
CPOperand multOutputCPOp = new CPOperand(
- multOutput.substring(0, multOutput.indexOf("·")),
+ multOutput.substring(0,
multOutput.indexOf(Lop.DATATYPE_PREFIX)),
mo1.getValueType(), mo1.getDataType()
);
@@ -433,20 +435,20 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
);
// calculate the sum of the obtained vector
- String[] partsMult = multInstr.split("°");
+ String[] partsMult = multInstr.split(Lop.OPERAND_DELIMITOR);
String sumInstr1Output = incrementVar(multOutput, 1)
.replace("m", "")
.replace("MATRIX", "SCALAR");
String sumInstr1 = multInstr
.replace(partsMult[1], "uak+")
- .replace(partsMult[3] + "°", "")
+ .replace(partsMult[3] + Lop.OPERAND_DELIMITOR, "")
.replace(partsMult[4], sumInstr1Output)
.replace(partsMult[2], multOutput);
FederatedRequest sumFr1 = FederationUtils.callInstruction(
sumInstr1,
new CPOperand(
- sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(), output.getDataType()
),
new CPOperand[]{multOutputCPOp},
@@ -454,7 +456,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
);
// calculate the sum of weights
- String[] partsSum1 = sumInstr1.split("°");
+ String[] partsSum1 = sumInstr1.split(Lop.OPERAND_DELIMITOR);
String sumInstr2Output = incrementVar(sumInstr1Output, 1);
String sumInstr2 = sumInstr1
.replace(partsSum1[2], String.valueOf(weightsID) +
"·MATRIX·FP64")
@@ -463,7 +465,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest sumFr2 = FederationUtils.callInstruction(
sumInstr2,
new CPOperand(
- sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(), output.getDataType()
),
new CPOperand[]{input3},
@@ -471,26 +473,27 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
);
// divide sum(X*W) by sum(W)
- String[] partsSum2 = sumInstr2.split("°");
+ String[] partsSum2 = sumInstr2.split(Lop.OPERAND_DELIMITOR);
String divInstrOutput = incrementVar(sumInstr2Output, 1);
- String divInstrInput1 = partsSum2[2].replace(partsSum2[2],
sumInstr1Output + "·false");
- String divInstrInput2 = partsSum2[3].replace(partsSum2[3],
sumInstr2Output + "·false");
- String divInstr = partsSum2[0] + "°" +
partsSum2[1].replace("uak+", "/") + "°" +
- divInstrInput1 + "°" + divInstrInput2 + "°" +
divInstrOutput + "°" + partsSum2[4];
+ String divInstrInput1 = partsSum2[2].replace(partsSum2[2],
sumInstr1Output + Lop.DATATYPE_PREFIX + "false");
+ String divInstrInput2 = partsSum2[3].replace(partsSum2[3],
sumInstr2Output + Lop.DATATYPE_PREFIX + "false");
+ String divInstr = partsSum2[0] + Lop.OPERAND_DELIMITOR +
partsSum2[1].replace("uak+", "/") + Lop.OPERAND_DELIMITOR
+ + divInstrInput1 + Lop.OPERAND_DELIMITOR +
divInstrInput2 + Lop.OPERAND_DELIMITOR
+ + divInstrOutput + Lop.OPERAND_DELIMITOR +
partsSum2[4];
FederatedRequest divFr1 = FederationUtils.callInstruction(
divInstr,
new CPOperand(
- divInstrOutput.substring(0,
divInstrOutput.indexOf("·")),
+ divInstrOutput.substring(0,
divInstrOutput.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(), output.getDataType()
),
new CPOperand[]{
new CPOperand(
- sumInstr1Output.substring(0,
sumInstr1Output.indexOf("·")),
+ sumInstr1Output.substring(0,
sumInstr1Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType(), output.isLiteral()
),
new CPOperand(
- sumInstr2Output.substring(0,
sumInstr2Output.indexOf("·")),
+ sumInstr2Output.substring(0,
sumInstr2Output.indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType(), output.isLiteral()
)
},
@@ -506,14 +509,15 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
private Future<FederatedResponse>[] getWeightsSum(MatrixLineagePair
moLin3, long weightsID, String instString, FederationMap fedMap) {
Future<FederatedResponse>[] weightsSumTmp = null;
- String[] parts = instString.split("°");
+ String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
if (!instString.contains("pREADW")) {
- String sumInstr = "CP°uak+°" + parts[4] + "°" +
parts[5] + "°" + parts[6];
+ String sumInstr = "CP"+Lop.OPERAND_DELIMITOR+"uak+" +
Lop.OPERAND_DELIMITOR
+ + parts[4] + Lop.OPERAND_DELIMITOR + parts[5] +
Lop.OPERAND_DELIMITOR + parts[6];
FederatedRequest sumFr =
FederationUtils.callInstruction(
sumInstr,
new CPOperand(
- parts[5].substring(0,
parts[5].indexOf("·")),
+ parts[5].substring(0,
parts[5].indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType()
),
@@ -526,11 +530,13 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
weightsSumTmp = fedMap.execute(getTID(), sumFr, sumFr2,
sumFr3);
}
else {
- String sumInstr = "CP°uak+°" +
String.valueOf(weightsID) + "·MATRIX·FP64" + "°" + parts[5] + "°" + parts[6];
+ String sumInstr =
"CP"+Lop.OPERAND_DELIMITOR+"uak+"+Lop.OPERAND_DELIMITOR
+ + String.valueOf(weightsID) + "·MATRIX·FP64" +
Lop.OPERAND_DELIMITOR + parts[5]
+ + Lop.OPERAND_DELIMITOR + parts[6];
FederatedRequest sumFr =
FederationUtils.callInstruction(
sumInstr,
new CPOperand(
- parts[5].substring(0,
parts[5].indexOf("·")),
+ parts[5].substring(0,
parts[5].indexOf(Lop.DATATYPE_PREFIX)),
output.getValueType(),
output.getDataType()
),
@@ -576,7 +582,8 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op,
_mo2));
}
- @Override public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
return null;
}
}
@@ -600,7 +607,8 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op,
_mo2, _weights));
}
- @Override public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
return null;
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 4f23c641fa..9fb42f23fd 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -136,17 +136,13 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
try {
if(!isAlive(t1, t2, t3, t4))
throw new RuntimeException("Failed starting
federated worker");
- rtplatform = execMode;
- if(rtplatform == ExecMode.SPARK) {
- System.out.println(7);
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- }
+ setExecMode(execMode);
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -214,11 +210,7 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
private void runWeightedCovarianceTest(ExecMode execMode, boolean
alignedInput, boolean alignedWeights) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- ExecMode platformOld = rtplatform;
-
- if(rtplatform == ExecMode.SPARK)
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ ExecMode platformOld = setExecMode(execMode);
String TEST_NAME = !alignedInput ? TEST_NAME3 :
(!alignedWeights ? TEST_NAME4 : TEST_NAME5);
getAndLoadTestConfiguration(TEST_NAME);
@@ -256,18 +248,11 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
try {
if(!isAlive(t1, t2, t3, t4))
throw new RuntimeException("Failed starting
federated worker");
-
- rtplatform = execMode;
- if(rtplatform == ExecMode.SPARK) {
- System.out.println(7);
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- }
-
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java
index 7f68ec1a7c..84b906b9a4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMatrixScalarOperationsTest.java
@@ -24,6 +24,7 @@ import org.junit.runners.Parameterized;
import org.junit.runner.RunWith;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -197,8 +198,7 @@ public class FederatedMatrixScalarOperationsTest extends
AutomatedTestBase {
// we need the reference file to not be written to
hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
programArgs = new String[] {"-w",
Integer.toString(FEDERATED_WORKER_PORT)};
- t = new Thread(() -> runTest(true, false, null, -1));
- t.start();
+ CommonThreadPool.get().submit(() -> runTest(true,
false, null, -1));
sleep(FED_WORKER_WAIT);
fullDMLScriptName = SCRIPT_DIR + TEST_DIR + dmlFile +
".dml";
programArgs = new String[] {"-nvargs",