This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9a318eeccc [SYSTEMDS-3796] Fix robustness federated weighted covariance and tests
9a318eeccc is described below
commit 9a318eeccc3ae1da999f47a3b8f1c4d003ea32bc
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Nov 28 09:38:31 2024 +0100
[SYSTEMDS-3796] Fix robustness federated weighted covariance and tests
---
.../controlprogram/federated/FederationMap.java | 6 ++--
.../instructions/fed/CovarianceFEDInstruction.java | 41 +++++++---------------
.../primitives/part5/FederatedCovarianceTest.java | 7 ++--
3 files changed, 19 insertions(+), 35 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 91e6c156c4..2574c4f175 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -147,10 +147,8 @@ public class FederationMap {
return broadcastSliced(data, null, transposed);
}
- public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin,
- boolean transposed) {
- return broadcastSliced(moLin.getMO(), moLin.getLI(),
- transposed);
+ public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin,
boolean transposed) {
+ return broadcastSliced(moLin.getMO(), moLin.getLI(),
transposed);
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
index d7f28293ce..4d22fd753e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -114,27 +114,20 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
new CPOperand[]{input1, input2, input3}, new
long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(),
moLin3.getFedMapping().getID()});
}
-
+
FederatedRequest fr2 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(),
fr1.getID());
- Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
- //means
- Future<FederatedResponse>[] meanTmp1 = processMean(mo1, moLin3,
0);
- Future<FederatedResponse>[] meanTmp2 = processMean(mo2, moLin3,
1);
-
- Double[] cov = getResponses(covTmp);
- Double[] mean1 = getResponses(meanTmp1);
- Double[] mean2 = getResponses(meanTmp2);
+ Double[] cov =
getResponses(mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3));
+ Double[] mean1 = getResponses(processMean(mo1, moLin3, 0));
+ Double[] mean2 = getResponses(processMean(mo2, moLin3, 1));
if (moLin3 == null) {
double result = aggCov(cov, mean1, mean2,
mo1.getFedMapping().getFederatedRanges());
ec.setVariable(output.getName(), new
DoubleObject(result));
}
else {
- Future<FederatedResponse>[] weightsSumTmp =
getWeightsSum(moLin3, moLin3.getFedMapping().getID(), instString,
moLin3.getFedMapping());
- Double[] weights = getResponses(weightsSumTmp);
-
+ Double[] weights = getResponses(
+ getWeightsSum(moLin3,
moLin3.getFedMapping().getID(), instString, moLin3.getFedMapping()));
double result = aggWeightedCov(cov, mean1, mean2,
weights);
ec.setVariable(output.getName(), new
DoubleObject(result));
}
@@ -154,21 +147,13 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
new CPOperand[]{input1, input2, input3},
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID(), fr1[0].getID()}
);
+ //sequential execution of cov and means for robustness
FederatedRequest fr3 = new
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
FederatedRequest fr4 = mo1.getFedMapping().cleanup(getTID(),
fr2.getID());
- Future<FederatedResponse>[] covTmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
-
- //means
- Future<FederatedResponse>[] meanTmp1 = processMean(mo1, 0,
fr1[0].getID());
- Future<FederatedResponse>[] meanTmp2 = processMean(mo2, 1,
fr1[0].getID());
-
- Double[] cov = getResponses(covTmp);
- Double[] mean1 = getResponses(meanTmp1);
- Double[] mean2 = getResponses(meanTmp2);
-
- Future<FederatedResponse>[] weightsSumTmp =
getWeightsSum(moLin3, fr1[0].getID(), instString, mo1.getFedMapping());
- Double[] weights = getResponses(weightsSumTmp);
-
+ Double[] cov =
getResponses(mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3, fr4));
+ Double[] mean1 = getResponses(processMean(mo1, 0,
fr1[0].getID()));
+ Double[] mean2 = getResponses(processMean(mo2, 1,
fr1[0].getID()));
+ Double[] weights = getResponses(getWeightsSum(moLin3,
fr1[0].getID(), instString, mo1.getFedMapping()));
double result = aggWeightedCov(cov, mean1, mean2, weights);
ec.setVariable(output.getName(), new DoubleObject(result));
}
@@ -243,7 +228,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
fr[i] = ((ScalarObject)
ffr[i].get().getData()[0]).getDoubleValue();
}
catch(Exception e) {
- throw new
DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.");
+ throw new
DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.", e);
}
});
@@ -302,7 +287,7 @@ public class CovarianceFEDInstruction extends
BinaryFEDInstruction {
}
private Future<FederatedResponse>[] processMean(MatrixObject mo1,
MatrixLineagePair moLin3, int var){
- String[] parts = instString.split("°");
+ String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
Future<FederatedResponse>[] meanTmp = null;
if (moLin3 == null) {
String meanInstr = instString.replace(getOpcode(),
getOpcode().replace("cov", "uamean"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 9fb42f23fd..48c9cab632 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -49,7 +49,7 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedCovarianceTest.class.getSimpleName() + "/";
- private final static int blocksize = 1024;
+ private final static int blocksize = 1000;
@Parameterized.Parameter
public int rows;
@Parameterized.Parameter(1)
@@ -57,8 +57,9 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{20, 1},
- // {100, 1}, {1000, 1}
+ return Arrays.asList(new Object[][] {
+ {120, 1},
+ {1100, 1},
});
}