This is an automated email from the ASF dual-hosted git repository.
ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 780d790c75 [SYSTEMDS-3945] Support Reversal of Federation Map w/
differently sized federated partitions
780d790c75 is described below
commit 780d790c75642a74bccc0001074980e1e14f9690
Author: ywcb00 <[email protected]>
AuthorDate: Thu Apr 30 11:36:09 2026 +0200
[SYSTEMDS-3945] Support Reversal of Federation Map w/ differently sized
federated partitions
This commit fixes the federated reverse instruction to support federation
maps with differently sized partitions and adds the corresponding test cases.
---
.../controlprogram/federated/FederationMap.java | 46 +++++++++----
.../primitives/part2/FederatedRevTest.java | 80 ++++++++++++++++++----
.../functions/federated/FederatedRevTest.dml | 8 +--
3 files changed, 103 insertions(+), 31 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b0cccce171..68dff4785b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.controlprogram.federated;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
@@ -724,21 +726,40 @@ public class FederationMap {
Arrays.stream(frset).forEach(fr ->
fr.setTID(tid));
}
+ /**
+ * Sort the entries of the federation map in ascending order of the begin
+ * index of their federated ranges along the partitioned dimension
+ * (columns for COL-partitioned maps, rows otherwise).
+ */
+ private void sortFederatedRanges() {
+ final int dim = (this.getType() == FType.COL) ? 1 : 0;
+ // compare the long begin indices directly; narrowing them to int and
+ // subtracting can overflow and violate the comparator contract
+ this._fedMap.sort(Comparator.comparingLong(e -> e.getLeft().getBeginDims()[dim]));
+ }
+
public void reverseFedMap() {
// TODO perf
- // TODO: add a check if the map is sorted based on indexes
before reversing.
// TODO: add a setup such that on construction the federated
map is already sorted.
- FederatedRange[] fedRanges = getFederatedRanges();
-
- for(int i = 0; i < Math.floor(fedRanges.length / 2.0); i++) {
- FederatedData data1 = getFederatedData(fedRanges[i]);
- FederatedData data2 =
getFederatedData(fedRanges[fedRanges.length-1-i]);
-
- removeFederatedData(fedRanges[i]);
- removeFederatedData(fedRanges[fedRanges.length-1-i]);
-
- _fedMap.add(Pair.of(fedRanges[i], data2));
- _fedMap.add(Pair.of(fedRanges[fedRanges.length-1-i],
data1));
+ if(this.getType() != FType.ROW)
+ throw new DMLRuntimeException("Reversing is only supported for row partitioned federation maps yet.");
+
+ // order the partitions by ascending begin row, then flip their order
+ this.sortFederatedRanges();
+ Collections.reverse(this._fedMap);
+
+ // reassign contiguous row ranges in the reversed order, keeping the
+ // number of rows of every partition; column ranges stay untouched
+ final int rowDim = 0;
+ long currentRowPos = 0;
+ for(Pair<FederatedRange, FederatedData> entry : this._fedMap) {
+ FederatedRange range = entry.getLeft();
+ long numRows = range.getSize(rowDim);
+ // NOTE(review): this in-place update relies on getBeginDims()/getEndDims()
+ // returning the range's backing arrays - confirm in FederatedRange
+ range.getBeginDims()[rowDim] = currentRowPos;
+ currentRowPos += numRows;
+ range.getEndDims()[rowDim] = currentRowPos;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
index e4b7ed5e24..7fe88228d0 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
@@ -39,8 +39,6 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedRevTest extends AutomatedTestBase {
- // private static final Log LOG =
LogFactory.getLog(FederatedRightIndexTest.class.getName());
-
private final static String TEST_NAME = "FederatedRevTest";
private final static String TEST_DIR = "functions/federated/";
@@ -86,11 +84,25 @@ public class FederatedRevTest extends AutomatedTestBase {
runRevTest(Types.ExecMode.SPARK, true);
}
+ @Test
+ public void testRevDifferentRangesCP() {
+ runRevTest(Types.ExecMode.SINGLE_NODE, false, true);
+ }
+
+ @Test
+ public void testRevDifferentRangesSP() {
+ runRevTest(Types.ExecMode.SPARK, false, true);
+ }
+
private void runRevTest(ExecMode execMode) {
runRevTest(execMode, false);
}
private void runRevTest(ExecMode execMode, boolean
activateFedCompilation) {
+ runRevTest(execMode, activateFedCompilation, false);
+ }
+
+ private void runRevTest(ExecMode execMode, boolean
activateFedCompilation, boolean differentPartitionSizes) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -108,20 +120,47 @@ public class FederatedRevTest extends AutomatedTestBase {
c = cols;
}
- double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
- double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
- double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
- double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+ // size of every federated partition (only the partitioned dimension varies)
+ int r_X1 = r; int r_X2 = r; int r_X3 = r; int r_X4 = r;
+ int c_X1 = c; int c_X2 = c; int c_X3 = c; int c_X4 = c;
+ if(differentPartitionSizes) {
+ // unequal partitions whose sizes still sum up to 4*r (resp. 4*c)
+ if(rowPartitioned) {
+ r_X1 = r + 1; r_X2 = r - 2; r_X3 = r + 1;
+ }
+ else {
+ c_X1 = c + 1; c_X2 = c - 2; c_X3 = c + 1;
+ }
+ }
+
+ // exclusive end index of every partition along the partitioned dimension;
+ // the non-partitioned dimension keeps its size r (resp. c)
+ int rend_X1 = r; int rend_X2 = r; int rend_X3 = r; int rend_X4 = r;
+ int cend_X1 = c; int cend_X2 = c; int cend_X3 = c; int cend_X4 = c;
+ if(rowPartitioned) {
+ rend_X1 = r_X1; rend_X2 = rend_X1 + r_X2; rend_X3 = rend_X2 + r_X3; rend_X4 = rend_X3 + r_X4;
+ }
+ else {
+ cend_X1 = c_X1; cend_X2 = cend_X1 + c_X2; cend_X3 = cend_X2 + c_X3; cend_X4 = cend_X3 + c_X4;
+ }
+
+ double[][] X1 = getRandomMatrix(r_X1, c_X1, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r_X2, c_X2, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r_X3, c_X3, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r_X4, c_X4, 1, 5, 1, 9);
for(int k : new int[] {1, 2, 3}) {
Arrays.fill(X3[k], 0);
}
- MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
- writeInputMatrixWithMTD("X1", X1, false, mc);
- writeInputMatrixWithMTD("X2", X2, false, mc);
- writeInputMatrixWithMTD("X3", X3, false, mc);
- writeInputMatrixWithMTD("X4", X4, false, mc);
+ writeInputMatrixWithMTD("X1", X1, false,
+ new MatrixCharacteristics(r_X1, c_X1, blocksize, r_X1 *
c_X1));
+ writeInputMatrixWithMTD("X2", X2, false,
+ new MatrixCharacteristics(r_X2, c_X2, blocksize, r_X2 *
c_X2));
+ writeInputMatrixWithMTD("X3", X3, false,
+ new MatrixCharacteristics(r_X3, c_X3, blocksize, r_X3 *
c_X3));
+ writeInputMatrixWithMTD("X4", X4, false,
+ new MatrixCharacteristics(r_X4, c_X4, blocksize, r_X4 *
c_X4));
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
@@ -134,7 +173,6 @@ public class FederatedRevTest extends AutomatedTestBase {
Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
Process t4 = startLocalFedWorker(port4);
-
try {
if(!isAlive(t1, t2, t3, t4))
throw new RuntimeException("Failed starting
federated worker");
@@ -147,7 +185,8 @@ public class FederatedRevTest extends AutomatedTestBase {
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
+ programArgs = new String[] {"-stats", "100", "-args",
+ input("X1"), input("X2"), input("X3"),
input("X4"),
Boolean.toString(rowPartitioned).toUpperCase(),
expected("S")};
runTest(null);
@@ -158,8 +197,14 @@ public class FederatedRevTest extends AutomatedTestBase {
"in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
- "rP=" +
Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")),
+ "rows=" + rows, "cols=" + cols,
+ "rend_X1=" + rend_X1, "cend_X1=" + cend_X1,
+ "rend_X2=" + rend_X2, "cend_X2=" + cend_X2,
+ "rend_X3=" + rend_X3, "cend_X3=" + cend_X3,
+ "rend_X4=" + rend_X4, "cend_X4=" + cend_X4,
+ "rP=" +
Boolean.toString(rowPartitioned).toUpperCase(),
+ "out_S=" + output("S")};
runTest(null);
diff --git a/src/test/scripts/functions/federated/FederatedRevTest.dml
b/src/test/scripts/functions/federated/FederatedRevTest.dml
index d43edd1728..128b511d5f 100644
--- a/src/test/scripts/functions/federated/FederatedRevTest.dml
+++ b/src/test/scripts/functions/federated/FederatedRevTest.dml
@@ -20,12 +20,12 @@
#-------------------------------------------------------------
if ($rP) {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
- list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ ranges=list(list(0, 0), list($rend_X1, $cols), list($rend_X1, 0),
list($rend_X2, $cols),
+ list($rend_X2, 0), list($rend_X3, $cols), list($rend_X3, 0),
list($rend_X4, $cols)));
} else {
A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
- list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+ ranges=list(list(0, 0), list($rows, $cend_X1), list(0,$cend_X1),
list($rows, $cend_X2),
+ list(0,$cend_X2), list($rows, $cend_X3), list(0, $cend_X3),
list($rows, $cend_X4)));
}
s = rev(A);