This is an automated email from the ASF dual-hosted git repository.

ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 780d790c75 [SYSTEMDS-3945] Support Reversal of Federation Map w/ differently sized federated partitions
780d790c75 is described below

commit 780d790c75642a74bccc0001074980e1e14f9690
Author: ywcb00 <[email protected]>
AuthorDate: Thu Apr 30 11:36:09 2026 +0200

    [SYSTEMDS-3945] Support Reversal of Federation Map w/ differently sized federated partitions
    
    This commit fixes the federated reverse instruction to support federation maps with differently sized partitions and adds the corresponding test cases.
---
 .../controlprogram/federated/FederationMap.java    | 46 +++++++++----
 .../primitives/part2/FederatedRevTest.java         | 80 ++++++++++++++++++----
 .../functions/federated/FederatedRevTest.dml       |  8 +--
 3 files changed, 103 insertions(+), 31 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b0cccce171..68dff4785b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.controlprogram.federated;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map.Entry;
@@ -724,21 +726,41 @@ public class FederationMap {
                                Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
        }
 
+       /**
+        * Sort the entries of the federation map based on their federated ranges
+        */
+       private void sortFederatedRanges() {
+               int dim = (this.getType() == FType.COL) ? 1 : 0;
+
+               this._fedMap.sort(new Comparator<Pair<FederatedRange, FederatedData>>() {
+                       @Override
+                       public int compare(Pair<FederatedRange, FederatedData> o1, Pair<FederatedRange, FederatedData> o2) {
+                               return o1.getLeft().getBeginDimsInt()[dim] - o2.getLeft().getBeginDimsInt()[dim];
+                       }
+               });
+       }
+
        public void reverseFedMap() {
                // TODO perf
-               // TODO: add a check if the map is sorted based on indexes before reversing.
                // TODO: add a setup such that on construction the federated map is already sorted.
-               FederatedRange[] fedRanges = getFederatedRanges();
-
-               for(int i = 0; i < Math.floor(fedRanges.length / 2.0); i++) {
-                       FederatedData data1 = getFederatedData(fedRanges[i]);
-                       FederatedData data2 = getFederatedData(fedRanges[fedRanges.length-1-i]);
-
-                       removeFederatedData(fedRanges[i]);
-                       removeFederatedData(fedRanges[fedRanges.length-1-i]);
-
-                       _fedMap.add(Pair.of(fedRanges[i], data2));
-                       _fedMap.add(Pair.of(fedRanges[fedRanges.length-1-i], data1));
+               if(this.getType() != FType.ROW)
+                       throw new DMLRuntimeException("Reversing is only supported for row partitioned federation maps yet.");
+
+               this.sortFederatedRanges();
+
+               Collections.reverse(this._fedMap);
+
+               int dim = (getType() == FType.COL) ? 1 : 0;
+               int currentDimPos = 0;
+               Iterator<Pair<FederatedRange, FederatedData>> fmIter = this._fedMap.iterator();
+               while(fmIter.hasNext()) {
+                       Pair<FederatedRange, FederatedData> elem = fmIter.next();
+                       long dimSize = elem.getLeft().getSize(dim);
+                       long[] beginDims = elem.getLeft().getBeginDims();
+                       long[] endDims = elem.getLeft().getEndDims();
+                       beginDims[dim] = currentDimPos;
+                       currentDimPos += dimSize;
+                       endDims[dim] = currentDimPos;
                }
        }
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
index e4b7ed5e24..7fe88228d0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
@@ -39,8 +39,6 @@ import org.junit.runners.Parameterized;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedRevTest extends AutomatedTestBase {
-       // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName());
-
        private final static String TEST_NAME = "FederatedRevTest";
 
        private final static String TEST_DIR = "functions/federated/";
@@ -86,11 +84,25 @@ public class FederatedRevTest extends AutomatedTestBase {
                runRevTest(Types.ExecMode.SPARK, true);
        }
 
+       @Test
+       public void testRevDifferentRangesCP() {
+               runRevTest(Types.ExecMode.SINGLE_NODE, false, true);
+       }
+
+               @Test
+       public void testRevDifferentRangesSP() {
+               runRevTest(Types.ExecMode.SPARK, false, true);
+       }
+
        private void runRevTest(ExecMode execMode) {
                runRevTest(execMode, false);
        }
 
        private void runRevTest(ExecMode execMode, boolean activateFedCompilation) {
+               runRevTest(execMode, activateFedCompilation, false);
+       }
+
+       private void runRevTest(ExecMode execMode, boolean activateFedCompilation, boolean differentPartitionSizes) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
@@ -108,20 +120,52 @@ public class FederatedRevTest extends AutomatedTestBase {
                        c = cols;
                }
 
-               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
-               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
-               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
-               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+               int r_X1 = r; int r_X2 = r; int r_X3 = r; int r_X4 = r;
+               int rend_X1 = r; int rend_X2 = r; int rend_X3 = r; int rend_X4 = r;
+               int c_X1 = c; int c_X2 = c; int c_X3 = c; int c_X4 = c;
+               int cend_X1 = c_X1; int cend_X2 = c_X1+c_X2; int cend_X3 = cend_X2+c_X3; int cend_X4 = cend_X3+c_X4;
+               if(rowPartitioned) {
+                       if(differentPartitionSizes) {
+                               r_X1 = r+1;
+                               r_X2 = r-2;
+                               r_X3 = r+1;
+                               r_X4 = r-0;
+                       }
+                       else {
+                               r_X1 = r;
+                               r_X2 = r;
+                               r_X3 = r;
+                               r_X4 = r;
+                       }
+                       rend_X1 = r_X1; rend_X2 = r_X1+r_X2; rend_X3 = rend_X2+r_X3; rend_X4 = rend_X3+r_X4;
+                       c_X1 = c; c_X2 = c; c_X3 = c; c_X4 = c;
+                       cend_X1 = c; cend_X2 = c; cend_X3 = c; cend_X4 = c;
+               }
+               else if(differentPartitionSizes) {
+                       c_X1 = c+1;
+                       c_X2 = c-2;
+                       c_X3 = c+1;
+                       c_X4 = c-0;
+                       cend_X1 = c_X1; cend_X2 = c_X1+c_X2; cend_X3 = cend_X2+c_X3; cend_X4 = cend_X3+c_X4;
+               }
+
+               double[][] X1 = getRandomMatrix(r_X1, c_X1, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r_X2, c_X2, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r_X3, c_X3, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r_X4, c_X4, 1, 5, 1, 9);
 
                for(int k : new int[] {1, 2, 3}) {
                        Arrays.fill(X3[k], 0);
                }
 
-               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
-               writeInputMatrixWithMTD("X1", X1, false, mc);
-               writeInputMatrixWithMTD("X2", X2, false, mc);
-               writeInputMatrixWithMTD("X3", X3, false, mc);
-               writeInputMatrixWithMTD("X4", X4, false, mc);
+               writeInputMatrixWithMTD("X1", X1, false,
+                       new MatrixCharacteristics(r_X1, c_X1, blocksize, r_X1 * c_X1));
+               writeInputMatrixWithMTD("X2", X2, false,
+                       new MatrixCharacteristics(r_X2, c_X2, blocksize, r_X2 * c_X2));
+               writeInputMatrixWithMTD("X3", X3, false,
+                       new MatrixCharacteristics(r_X3, c_X3, blocksize, r_X3 * c_X3));
+               writeInputMatrixWithMTD("X4", X4, false,
+                       new MatrixCharacteristics(r_X4, c_X4, blocksize, r_X4 * c_X4));
 
                // empty script name because we don't execute any script, just start the worker
                fullDMLScriptName = "";
@@ -134,7 +178,6 @@ public class FederatedRevTest extends AutomatedTestBase {
                Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
                Process t4 = startLocalFedWorker(port4);
 
-               
                try {
                        if(!isAlive(t1, t2, t3, t4))
                                throw new RuntimeException("Failed starting federated worker");
@@ -147,7 +190,8 @@ public class FederatedRevTest extends AutomatedTestBase {
 
                        // Run reference dml script with normal matrix
                        fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-                       programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                       programArgs = new String[] {"-stats", "100", "-args",
+                               input("X1"), input("X2"), input("X3"), input("X4"),
                                Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
 
                        runTest(null);
@@ -158,8 +202,14 @@ public class FederatedRevTest extends AutomatedTestBase {
                                "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
                                "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
                                "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
-                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
-                               "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+                               "rows=" + rows, "cols=" + cols,
+                               "rend_X1=" + rend_X1, "cend_X1=" + cend_X1,
+                               "rend_X2=" + rend_X2, "cend_X2=" + cend_X2,
+                               "rend_X3=" + rend_X3, "cend_X3=" + cend_X3,
+                               "rend_X4=" + rend_X4, "cend_X4=" + cend_X4,
+                               "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
+                               "out_S=" + output("S")};
 
                        runTest(null);
 
diff --git a/src/test/scripts/functions/federated/FederatedRevTest.dml b/src/test/scripts/functions/federated/FederatedRevTest.dml
index d43edd1728..128b511d5f 100644
--- a/src/test/scripts/functions/federated/FederatedRevTest.dml
+++ b/src/test/scripts/functions/federated/FederatedRevTest.dml
@@ -20,12 +20,12 @@
  #-------------------------------------------------------------
 if ($rP) {
     A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
-               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+        ranges=list(list(0, 0), list($rend_X1, $cols), list($rend_X1, 0), list($rend_X2, $cols),
+            list($rend_X2, 0), list($rend_X3, $cols), list($rend_X3, 0), list($rend_X4, $cols)));
 } else {
     A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
-               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+        ranges=list(list(0, 0), list($rows, $cend_X1), list(0,$cend_X1), list($rows, $cend_X2),
+            list(0,$cend_X2), list($rows, $cend_X3), list(0, $cend_X3), list($rows, $cend_X4)));
 }
 
 s = rev(A);

Reply via email to