This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 70c3e5f93d [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
70c3e5f93d is described below

commit 70c3e5f93d4ef22447d765ef261985129ff1a7e2
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Jun 4 23:45:20 2022 +0200

    [MINOR] Cleanup flaky privacy/FederatedWorkerHandlerTest
---
 .../privacy/FederatedWorkerHandlerTest.java        | 119 +++++++++------------
 1 file changed, 51 insertions(+), 68 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
index 7339aea931..d23fe8c533 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -23,6 +23,7 @@ import java.util.Arrays;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -32,9 +33,7 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Ignore;
 import org.junit.Test;
-import static java.lang.Thread.sleep;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
 @net.jcip.annotations.NotThreadSafe
 public class FederatedWorkerHandlerTest extends AutomatedTestBase {
@@ -49,7 +48,6 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
        private final static String TRANSFER_TEST_NAME = "FederatedRCBindTest";
        private final static String MATVECMULT_TEST_NAME = "FederatedMultiplyTest";
        private static final String FEDERATED_WORKER_HOST = "localhost";
-       private static final int FEDERATED_WORKER_PORT = 1222;
 
        private final static int blocksize = 1024;
        private final int rows = 10;
@@ -103,20 +101,15 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
 
        private void runGenericScalarTest(String dmlFile, int s, Class<?> expectedException, PrivacyLevel privacyLevel)
        {
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               Types.ExecMode platformOld = rtplatform;
+               ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
 
-               Thread t = null;
                try {
-                       // we need the reference file to not be written to hdfs, so we get the correct format
-                       rtplatform = Types.ExecMode.SINGLE_NODE;
-                       programArgs = new String[] {"-w", Integer.toString(FEDERATED_WORKER_PORT)};
-                       t = new Thread(() -> runTest(true, false, null, -1));
-                       t.start();
-                       sleep(FED_WORKER_WAIT);
+                       int port = getRandomAvailablePort();
+                       Thread t = startLocalFedWorkerThread(port);
+
                        fullDMLScriptName = SCRIPT_DIR + TEST_DIR_SCALAR + dmlFile + ".dml";
                        programArgs = new String[]{"-checkPrivacy", "-nvargs",
-                                       "in=" + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, FEDERATED_WORKER_PORT, input("M")),
+                                       "in=" + TestUtils.federatedAddress(FEDERATED_WORKER_HOST, port, input("M")),
                                        "rows=" + Integer.toString(rows), "cols=" + Integer.toString(cols),
                                        "scalar=" + Integer.toString(s),
                                        "out=" + output("R")};
@@ -125,15 +118,12 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
 
                        if ( !exceptionExpected )
                                compareResults();
-               } catch (InterruptedException e) {
-                       fail("InterruptedException thrown" + e.getMessage() + " " + Arrays.toString(e.getStackTrace()));
-               } finally {
+                       TestUtils.shutdownThread(t);
+               }
+               finally {
                        assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
                                checkedPrivacyConstraintsContains(privacyLevel));
-                       rtplatform = platformOld;
-                       TestUtils.shutdownThread(t);
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       resetExecMode(platformOld);
                }
        }
 
@@ -153,57 +143,50 @@ public class FederatedWorkerHandlerTest extends AutomatedTestBase {
        }
 
        public void federatedSum(Types.ExecMode execMode, PrivacyLevel privacyLevel, Class<?> expectedException) {
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               Types.ExecMode platformOld = rtplatform;
-
-
-               getAndLoadTestConfiguration("aggregation");
-               String HOME = SCRIPT_DIR + TEST_DIR_fed;
+               ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
 
-               double[][] A = getRandomMatrix(rows/2, cols, -10, 10, 1, 1);
-               writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows/2, cols, blocksize, (rows/2) * cols), new PrivacyConstraint(privacyLevel));
-               int port = getRandomAvailablePort();
-               Thread t = startLocalFedWorkerThread(port);
-
-               // we need the reference file to not be written to hdfs, so we get the correct format
-               rtplatform = Types.ExecMode.SINGLE_NODE;
-               // Run reference dml script with normal matrix for Row/Col sum
-               fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
-               runTest(true, false, null, -1);
-
-               // write expected sum
-               double sum = 0;
-               for(double[] doubles : A) {
-                       sum += Arrays.stream(doubles).sum();
+               try {
+                       getAndLoadTestConfiguration("aggregation");
+                       String HOME = SCRIPT_DIR + TEST_DIR_fed;
+       
+                       double[][] A = getRandomMatrix(rows/2, cols, -10, 10, 1, 1);
+                       writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows/2, cols, blocksize, (rows/2) * cols), new PrivacyConstraint(privacyLevel));
+                       int port = getRandomAvailablePort();
+                       Thread t = startLocalFedWorkerThread(port);
+       
+                       // Run reference dml script with normal matrix for Row/Col sum
+                       fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-args", input("A"), input("A"), expected("R"), expected("C")};
+                       runTest(true, false, null, -1);
+       
+                       // write expected sum
+                       double sum = 0;
+                       for(double[] doubles : A) {
+                               sum += Arrays.stream(doubles).sum();
+                       }
+       
+                       if ( expectedException == null )
+                               writeExpectedScalar("S", sum);
+       
+                       TestConfiguration config = availableTestConfigurations.get("aggregation");
+                       loadTestConfiguration(config);
+                       fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
+                       programArgs = new String[] {"-checkPrivacy", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
+                               "cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
+       
+                       runTest(true, (expectedException != null), expectedException, -1);
+       
+                       // compare all sums via files
+                       if ( expectedException == null )
+                               compareResults(1e-11);
+       
+                       assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
+                               checkedPrivacyConstraintsContains(privacyLevel));
+                       TestUtils.shutdownThread(t);
                }
-
-               if ( expectedException == null )
-                       writeExpectedScalar("S", sum);
-
-               // reference file should not be written to hdfs, so we set platform here
-               rtplatform = execMode;
-               if(rtplatform == Types.ExecMode.SPARK) {
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               finally {
+                       resetExecMode(platformOld);
                }
-               TestConfiguration config = availableTestConfigurations.get("aggregation");
-               loadTestConfiguration(config);
-               fullDMLScriptName = HOME + AGGREGATION_TEST_NAME + ".dml";
-               programArgs = new String[] {"-checkPrivacy", "-nvargs", "in=" + TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
-                       "cols=" + cols, "out_S=" + output("S"), "out_R=" + output("R"), "out_C=" + output("C")};
-
-               runTest(true, (expectedException != null), expectedException, -1);
-
-               // compare all sums via files
-               if ( expectedException == null )
-                       compareResults(1e-11);
-
-               assertTrue("The privacy level " + privacyLevel.toString() + " should have been checked during execution",
-                       checkedPrivacyConstraintsContains(privacyLevel));
-
-               TestUtils.shutdownThread(t);
-               rtplatform = platformOld;
-               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
 
        @Test
