This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f00c6c7  [SYSTEMDS-3278] Fix federated unary aggregate, duplicates in fedinit
f00c6c7 is described below

commit f00c6c70023826c5f27876edaa59e4b82853ec5b
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sun Jan 23 23:19:42 2022 +0100

    [SYSTEMDS-3278] Fix federated unary aggregate, duplicates in fedinit
    
    This patch addresses two severe issues that were recently introduced
    or identified:
    
    * The unary aggregate incorrectly issued a cleanup along with the GET
    request (for the local output), which causes failures if a federated
    instruction is also called on the output (besides the CP output data).
    
    * The parallel event loop revealed shortcomings with duplicated
    addresses in federated data: concurrent requests with fixed variable
    names across requests overwrite each other's intermediates. This
    patch adds a warning on federated init (as sketched below) and fixes
    the incorrect test.
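
For context, the duplicate-address check added in InitFEDInstruction (see
the diff below) boils down to a set-based scan over the address list that
warns but does not abort. The following standalone sketch illustrates that
pattern outside of SystemDS; the class and method names are illustrative,
not the actual SystemDS API:

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    public class DuplicateAddressCheck {
        // Warn (but do not abort) if the address list contains duplicates,
        // mirroring the warn-only behavior of the new fedinit check.
        static boolean warnOnDuplicates(List<String> addresses) {
            Set<String> seen = new HashSet<>();
            boolean hasDuplicates = false;
            for(String address : addresses) {
                if(!seen.add(address)) { // add() returns false for duplicates
                    hasDuplicates = true;
                    System.err.println("WARN: federated data contains address duplicates: " + addresses);
                }
            }
            return hasDuplicates;
        }

        public static void main(String[] args) {
            // Two ranges mapped to the same worker address trigger the warning.
            warnOnDuplicates(Arrays.asList("localhost:8001", "localhost:8001"));
        }
    }
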
---
 .../federated/FederatedLookupTable.java            |  2 +-
 .../fed/AggregateUnaryFEDInstruction.java          | 21 +++---
 .../instructions/fed/InitFEDInstruction.java       | 13 ++++
 .../federated/primitives/FederatedSumTest.java     | 86 +++++++++++-----------
 .../functions/federated/FederatedSumTest.dml       |  2 +-
 .../federated/FederatedSumTestReference.dml        |  2 +-
 6 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index 63defe4..55ab971 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -31,7 +31,7 @@ import org.apache.sysds.api.DMLScript;
  * ExecutionContextMap (ECM) so that every coordinator can address federated
  * variables with its own local sequential variable IDs. Therefore, the IDs
  * among different coordinators do not have to be distinct, as every
- * coordinator works with a seperate ECM at the FederatedWorker.
+ * coordinator works with a separate ECM at the FederatedWorker.
  */
 public class FederatedLookupTable {
        // the NOHOST constant is needed for creating FederatedLocalData where there
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index d329f44..88a066a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -201,10 +201,9 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
                FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
                        new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, true);
                FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
-               FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
 
                //execute federated commands and cleanups
-               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
                if( output.isScalar() )
                        ec.setVariable(output.getName(), FederationUtils.aggScalar(aggUOptr, tmp, map));
                else
@@ -250,18 +249,20 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
                        FederatedRequest meanFr1 =  FederationUtils.callInstruction(meanInstr, output, id,
                                new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, isSpark);
                        FederatedRequest meanFr2 = new FederatedRequest(RequestType.GET_VAR, meanFr1.getID());
-                       FederatedRequest meanFr3 = map.cleanup(getTID(), meanFr1.getID());
-                       meanTmp = map.execute(getTID(), isSpark ? new FederatedRequest[] {tmpRequest, meanFr1, meanFr2, meanFr3} : new FederatedRequest[] {meanFr1, meanFr2, meanFr3});
+                       meanTmp = map.execute(getTID(), isSpark ?
+                               new FederatedRequest[] {tmpRequest, meanFr1, meanFr2} :
+                               new FederatedRequest[] {meanFr1, meanFr2});
                }
 
                //create federated commands for aggregation
                FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, id,
                        new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, isSpark);
                FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
-               FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
                
                //execute federated commands and cleanups
-               Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark ? new FederatedRequest[] {tmpRequest,  fr1, fr2, fr3} : new FederatedRequest[] { fr1, fr2, fr3});
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark ?
+                       new FederatedRequest[] {tmpRequest, fr1, fr2} :
+                       new FederatedRequest[] { fr1, fr2});
                if( output.isScalar() )
                        ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, meanTmp, map));
                else
@@ -281,7 +282,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
                FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
                        new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, ExecType.SPARK, true);
 
-               map.execute(getTID(), fr1, fr2);
+               map.execute(getTID(), true, fr1, fr2);
                // derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
                out.setFedMapping(in.getFedMapping().copyWithNewID(fr2.getID()));
@@ -298,7 +299,6 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
                        id = fr1.getID();
                }
                else {
-
                        if((map.getType() == FederationMap.FType.COL && aop.isColAggregate()) || (map.getType() == FederationMap.FType.ROW && aop.isRowAggregate()))
                                fr1 = new FederatedRequest(RequestType.PUT_VAR, id, new MatrixCharacteristics(-1, -1), in.getDataType());
                        else
@@ -307,11 +307,10 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 
                FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, id,
                        new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()}, ExecType.SPARK, true);
-               FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
-               FederatedRequest fr4 = map.cleanup(getTID(), fr2.getID());
+               FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
 
                //execute federated commands and cleanups
-               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3, fr4);
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
                if( output.isScalar() )
                        ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, map));
                else
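
As context for the hunks above: the first fix removes the cleanup requests
(fr3/meanFr3/fr4 in the removed lines) from the batches that also carry the
GET_VAR request, so the federated intermediate stays alive for any later
federated instruction that still references it. A rough, standalone
illustration of why the old ordering was problematic (plain Java with a toy
variable store, not the SystemDS FederationMap API):

    import java.util.HashMap;
    import java.util.Map;

    public class CleanupOrderingSketch {
        // Minimal stand-in for a worker-side variable store.
        static final Map<Long, double[]> vars = new HashMap<>();

        static void put(long id, double[] value) { vars.put(id, value); }
        static double[] get(long id) { return vars.get(id); }
        static void cleanup(long id) { vars.remove(id); }

        public static void main(String[] args) {
            long id = 1L;
            put(id, new double[] {1, 2, 3});   // federated intermediate
            double[] local = get(id);          // GET_VAR for the local (CP) output
            cleanup(id);                       // eager cleanup, as in the old code path
            // A later federated instruction that still references the same id now
            // sees nothing, even though the CP side already has its local copy.
            double[] reused = get(id);
            System.out.println(reused == null
                ? "intermediate already cleaned up -> follow-up federated op fails"
                : "intermediate still available, length " + local.length);
        }
    }
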
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 3db1c8b..29b2a17 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -26,7 +26,9 @@ import java.net.URL;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -105,6 +107,17 @@ public class InitFEDInstruction extends FEDInstruction implements LineageTraceab
                        throw new DMLRuntimeException("Federated read needs twice the amount of addresses as ranges "
                                + "(begin and end): addresses=" + addresses.getLength() + " ranges=" + ranges.getLength());
 
+               //check for duplicate addresses (would lead to overwrite with common variable names)
+               // TODO relax requirement by using different execution contexts per federated data?
+               Set<String> addCheck = new HashSet<>();
+               for( Data dat : addresses.getData() )
+                       if( dat instanceof StringObject ) {
+                               String address = ((StringObject)dat).getStringValue();
+                               if(addCheck.contains(address))
+                                       LOG.warn("Federated data contains address duplicates: " + addresses);
+                               addCheck.add(address);
+                       }
+               
                Types.DataType fedDataType;
                if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
                        fedDataType = Types.DataType.MATRIX;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
index 82ac6eb..4f70cce 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
@@ -91,49 +91,51 @@ public class FederatedSumTest extends AutomatedTestBase {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
 
-               getAndLoadTestConfiguration(TEST_NAME);
-               String HOME = SCRIPT_DIR + TEST_DIR;
-
-               double[][] A = getRandomMatrix(rows / 2, cols, -10, 10, 1, 1);
-               writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows / 2, cols, blocksize, (rows / 2) * cols));
-               int port = getRandomAvailablePort();
-               Thread t = startLocalFedWorkerThread(port);
-
-               // we need the reference file to not be written to hdfs, so we get the correct format
-               rtplatform = Types.ExecMode.SINGLE_NODE;
-               // Run reference dml script with normal matrix for Row/Col sum
-               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
-               runTest(true, false, null, -1);
-
-               // write expected sum
-               double sum = 0;
-               for(double[] doubles : A) {
-                       sum += Arrays.stream(doubles).sum();
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+       
+                       double[][] A = getRandomMatrix(rows / 2, cols, -10, 10, 1, 1);
+                       writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows / 2, cols, blocksize, (rows / 2) * cols));
+                       int port = getRandomAvailablePort();
+                       Thread t = startLocalFedWorkerThread(port);
+       
+                       // we need the reference file to not be written to hdfs, so we get the correct format
+                       rtplatform = Types.ExecMode.SINGLE_NODE;
+                       // Run reference dml script with normal matrix for Row/Col sum
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
+                       runTest(true, false, null, -1);
+       
+                       // write expected sum
+                       double sum = 0;
+                       for(double[] doubles : A)
+                               sum += Arrays.stream(doubles).sum();
+                       writeExpectedScalar("S", sum);
+       
+                       // reference file should not be written to hdfs, so we set platform here
+                       rtplatform = execMode;
+                       if(rtplatform == Types.ExecMode.SPARK) {
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       }
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+                       OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-explain","-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
+                               "cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
+       
+                       runTest(true, false, null, -1);
+       
+                       // compare all sums via files
+                       compareResults(1e-11);
+       
+                       TestUtils.shutdownThread(t);
+                       rtplatform = platformOld;
                }
-               sum *= 2;
-               writeExpectedScalar("S", sum);
-
-               // reference file should not be written to hdfs, so we set platform here
-               rtplatform = execMode;
-               if(rtplatform == Types.ExecMode.SPARK) {
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               finally {
+                       OptimizerUtils.FEDERATED_COMPILATION = false;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
-               TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
-               loadTestConfiguration(config);
-               OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
-               fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
-                       "cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
-
-               runTest(true, false, null, -1);
-
-               // compare all sums via files
-               compareResults(1e-11);
-
-               TestUtils.shutdownThread(t);
-               rtplatform = platformOld;
-               OptimizerUtils.FEDERATED_COMPILATION = false;
-               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
 }
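
Besides dropping the duplicate-address usage (note the removed line
sum *= 2;, since the federated matrix is now backed by a single range from
one address), the test above is wrapped in try/finally so that the static
configuration flags are restored even if the test body throws. A minimal,
generic sketch of that save-and-restore pattern (field names are
illustrative stand-ins, not the actual SystemDS or test framework classes):

    public class ConfigRestoreSketch {
        // Stand-ins for global flags a test might flip.
        static boolean FEDERATED_COMPILATION = false;
        static boolean USE_LOCAL_SPARK_CONFIG = false;

        static void runScenario(boolean federatedCompilation, boolean useLocalSpark) {
            boolean sparkConfigOld = USE_LOCAL_SPARK_CONFIG;
            try {
                FEDERATED_COMPILATION = federatedCompilation;
                USE_LOCAL_SPARK_CONFIG = useLocalSpark;
                // ... run the actual test body here ...
            }
            finally {
                // Restore globals even on failure, so later tests in the same
                // JVM do not inherit stale settings.
                FEDERATED_COMPILATION = false;
                USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
            }
        }

        public static void main(String[] args) {
            runScenario(true, true);
            System.out.println(FEDERATED_COMPILATION + " " + USE_LOCAL_SPARK_CONFIG);
        }
    }
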
diff --git a/src/test/scripts/functions/federated/FederatedSumTest.dml b/src/test/scripts/functions/federated/FederatedSumTest.dml
index 37a19f6..385b1dd 100644
--- a/src/test/scripts/functions/federated/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/FederatedSumTest.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in, $in), ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+A = federated(addresses=list($in), ranges=list(list(0, 0), list($rows / 2, $cols)))
 s = sum(A)
 r = rowSums(A)
 c = colSums(A)
diff --git a/src/test/scripts/functions/federated/FederatedSumTestReference.dml b/src/test/scripts/functions/federated/FederatedSumTestReference.dml
index af51717..3684860 100644
--- a/src/test/scripts/functions/federated/FederatedSumTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedSumTestReference.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-A = rbind(read($1), read($2))
+A = read($1)
 r = rowSums(A)
 c = colSums(A)
 write(r, $3)
