This is an automated email from the ASF dual-hosted git repository.
Baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 249a8f52ee [SYSTEMDS-2651] Federated test workers: TCP port polling
and bulk startup
249a8f52ee is described below
commit 249a8f52eea041356f292adf88a2b58efdd7f688
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon May 18 12:59:24 2026 +0000
[SYSTEMDS-2651] Federated test workers: TCP port polling and bulk startup
Replace fixed Thread.sleep after each federated worker start with TCP
port polling that returns as soon as the worker accepts a connection.
Add bulk helpers that spawn N workers in parallel and wait once for the
slowest to become ready, instead of summing per-worker waits.
Cuts the federated CI total by ~7 min (-5%) vs main, with the biggest
wins in setup-heavy suites such as transform+fedplanner (-66%) and
codegen (-25%).
Closes #2468.
---
.../org/apache/sysds/test/AutomatedTestBase.java | 260 +++++++++++++++------
.../apache/sysds/test/FederatedWorkerUtils.java | 196 ++++++++++++++++
.../federated/algorithms/FederatedAlsCGTest.java | 5 +-
.../federated/algorithms/FederatedBivarTest.java | 7 +-
.../federated/algorithms/FederatedCorTest.java | 7 +-
.../federated/algorithms/FederatedGLMTest.java | 5 +-
.../federated/algorithms/FederatedKmeansTest.java | 5 +-
.../federated/algorithms/FederatedL2SVMTest.java | 5 +-
.../federated/algorithms/FederatedLmPipeline.java | 7 +-
.../federated/algorithms/FederatedLogRegTest.java | 5 +-
.../federated/algorithms/FederatedMSVMTest.java | 5 +-
.../federated/algorithms/FederatedPCATest.java | 7 +-
.../federated/algorithms/FederatedPNMFTest.java | 5 +-
.../federated/algorithms/FederatedUnivarTest.java | 7 +-
.../federated/algorithms/FederatedVarTest.java | 7 +-
.../federated/algorithms/FederatedYL2SVMTest.java | 5 +-
.../codegen/FederatedCellwiseTmplTest.java | 5 +-
.../codegen/FederatedCodegenMultipleFedMOTest.java | 5 +-
.../codegen/FederatedMultiAggTmplTest.java | 5 +-
.../codegen/FederatedOuterProductTmplTest.java | 5 +-
.../codegen/FederatedRowwiseTmplTest.java | 5 +-
.../fedplanning/FederatedDynamicPlanningTest.java | 7 +-
.../fedplanning/FederatedKMeansPlanningTest.java | 7 +-
.../fedplanning/FederatedL2SVMPlanningTest.java | 7 +-
.../fedplanning/FederatedMultiplyPlanningTest.java | 7 +-
.../federated/io/FederatedReaderTest.java | 5 +-
.../functions/federated/io/FederatedSSLTest.java | 5 +-
.../io/FederatedSparsityPropagationTest.java | 5 +-
.../federated/io/FederatedWriterTest.java | 5 +-
.../federated/multitenant/MultiTenantTestBase.java | 13 +-
.../paramserv/AvgModelFederatedParamservTest.java | 19 +-
.../paramserv/EncryptedFederatedParamservTest.java | 27 +--
.../paramserv/FederatedParamservTest.java | 30 +--
.../paramserv/NbatchesFederatedParamservTest.java | 19 +-
.../part1/FederatedBinaryMatrixTest.java | 7 +-
.../part1/FederatedBinaryVectorTest.java | 7 +-
.../primitives/part1/FederatedBroadcastTest.java | 7 +-
.../primitives/part1/FederatedCastToFrameTest.java | 8 +-
.../part1/FederatedCastToMatrixTest.java | 7 +-
.../part1/FederatedCentralMomentTest.java | 10 +-
.../part1/FederatedColAggregateTest.java | 9 +-
.../primitives/part1/FederatedLeftIndexTest.java | 9 +-
.../primitives/part1/FederatedMisAlignedTest.java | 10 +-
.../primitives/part2/FederatedMultiplyTest.java | 7 +-
.../primitives/part2/FederatedProdTest.java | 9 +-
.../primitives/part2/FederatedQuantileTest.java | 11 +-
.../part2/FederatedQuantileWeightsTest.java | 11 +-
.../primitives/part2/FederatedRCBindTest.java | 9 +-
.../primitives/part2/FederatedRdiagTest.java | 9 +-
.../primitives/part2/FederatedRemoveEmptyTest.java | 9 +-
.../primitives/part2/FederatedReplaceTest.java | 9 +-
.../primitives/part2/FederatedReshapeTest.java | 9 +-
.../primitives/part2/FederatedRevTest.java | 9 +-
.../primitives/part2/FederatedRightIndexTest.java | 9 +-
.../primitives/part2/FederatedRollTest.java | 9 +-
.../primitives/part2/FederatedRowIndexTest.java | 9 +-
.../primitives/part3/FederatedSplitTest.java | 7 +-
.../primitives/part3/FederatedStatisticsTest.java | 7 +-
.../primitives/part3/FederatedTokenizeTest.java | 8 +-
.../part3/FederatedTransferLocalDataTest.java | 9 +-
.../primitives/part3/FederatedTriTest.java | 9 +-
.../part3/FederatedWeightedCrossEntropyTest.java | 7 +-
.../part3/FederatedWeightedDivMatrixMultTest.java | 7 +-
.../part3/FederatedWeightedSigmoidTest.java | 7 +-
.../part3/FederatedWeightedSquaredLossTest.java | 7 +-
.../FederatedWeightedUnaryMatrixMultTest.java | 7 +-
.../part4/FederatedRowAggregateTest.java | 9 +-
.../primitives/part5/FederatedCovarianceTest.java | 18 +-
.../primitives/part5/FederatedCtableTest.java | 9 +-
.../primitives/part5/FederatedFrameMapTest.java | 9 +-
.../part5/FederatedFullAggregateTest.java | 9 +-
.../part5/FederatedFullCumulativeTest.java | 9 +-
.../primitives/part5/FederatedIfelseTest.java | 9 +-
.../primitives/part5/FederatedMMChainTest.java | 9 +-
.../TransformFederatedEncodeApplyTest.java | 9 +-
.../TransformFederatedEncodeDecodeTest.java | 9 +-
.../test/functions/lineage/FedFullReuseTest.java | 5 +-
.../test/functions/lineage/FedUDFReuseTest.java | 7 +-
.../test/functions/lineage/LineageFedReuseAlg.java | 5 +-
79 files changed, 615 insertions(+), 488 deletions(-)
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index c7f62b02a2..85a37b7dbd 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -117,9 +117,15 @@ public abstract class AutomatedTestBase {
public static boolean TEST_GPU = false;
public static final double GPU_TOLERANCE = 1e-9;
- // ms wait time
- public static final int FED_WORKER_WAIT = 3000;
- public static final int FED_MONITOR_WAIT = 10000;
+ /**
+ * Default upper bound (ms) passed to federated worker readiness waits.
The wait returns as soon
+ * as the worker's TCP port accepts a connection, so this value only
affects the deadline used
+ * when a worker never becomes ready. {@link FederatedWorkerUtils}
clamps caller values below its
+ * enforced floor up to that floor, so the effective ceiling is at
least that floor regardless
+ * of this constant.
+ */
+ public static final int FED_WORKER_WAIT = 3000;
+ public static final int FED_MONITOR_WAIT = 10000;
public static final int FED_WORKER_WAIT_S = 50;
@@ -1642,13 +1648,14 @@ public abstract class AutomatedTestBase {
/**
* Start a new JVM for a federated worker at the port.
- *
- * @param port Port to use for the JVM
- * @param sleep The sleep time to wait for the worker to start
+ *
+ * @param port Port to use for the JVM
+ * @param timeoutMs Upper bound on the wait for the worker to become
ready, in ms; raised to a
+ * minimum value enforced inside {@link
FederatedWorkerUtils}.
* @return The process containing the worker
*/
- protected Process startLocalFedWorker(int port, int sleep){
- return startLocalFedWorker(port, null, sleep);
+ protected Process startLocalFedWorker(int port, int timeoutMs){
+ return startLocalFedWorker(port, null, timeoutMs);
}
/**
@@ -1665,18 +1672,64 @@ public abstract class AutomatedTestBase {
/**
* Start new JVM for a federated worker at the port.
- *
- * @param port Port to use for the JVM
- * @param addArgs The arguments to add
- * @param sleep The time to wait for the process to start
+ *
+ * <p>Returns once the worker's TCP port accepts connections (the
worker opens the port after
+ * Netty's bind completes), or throws a {@link RuntimeException} after
{@code timeoutMs} elapses.
+ *
+ * @param port Port to use for the JVM
+ * @param addArgs The arguments to add
+ * @param timeoutMs Upper bound on the wait for the worker to become
ready, in ms; raised to a
+ * minimum value enforced inside {@link
FederatedWorkerUtils}.
* @return the process associated with the worker.
*/
- protected static Process startLocalFedWorker(int port, String[]
addArgs, int sleep) {
- Process process = null;
+ protected static Process startLocalFedWorker(int port, String[]
addArgs, int timeoutMs) {
+ Process process = spawnLocalFedWorker(port, addArgs);
+ FederatedWorkerUtils.waitForWorker(process, port, timeoutMs);
+ return process;
+ }
+
+ /**
+ * Start N federated worker JVMs back to back, then wait for all of
them to become ready in one
+ * shared poll loop. The wall-clock wait scales with the slowest worker
rather than the sum of the
+ * per-worker waits.
+ *
+ * @param ports Ports to use, one per worker
+ * @return The process per port, in the same order as {@code ports}.
+ */
+ protected static Process[] startLocalFedWorkers(int[] ports) {
+ return startLocalFedWorkers(ports, null, FED_WORKER_WAIT);
+ }
+
+ /** @see #startLocalFedWorkers(int[], String[], int) */
+ protected static Process[] startLocalFedWorkers(int[] ports, String[]
addArgs) {
+ return startLocalFedWorkers(ports, addArgs, FED_WORKER_WAIT);
+ }
+
+ /**
+ * Start N federated worker JVMs back to back, then wait for all of
them to become ready in one
+ * shared poll loop.
+ *
+ * @param ports Ports to use, one per worker
+ * @param addArgs Extra worker CLI args (applied to every worker), or
null
+ * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum
value enforced inside
+ * {@link FederatedWorkerUtils}.
+ * @return The process per port, in the same order as {@code ports}.
+ */
+ protected static Process[] startLocalFedWorkers(int[] ports, String[]
addArgs, int timeoutMs) {
+ Process[] processes = new Process[ports.length];
+ for(int i = 0; i < ports.length; i++) {
+ processes[i] = spawnLocalFedWorker(ports[i], addArgs);
+ }
+ FederatedWorkerUtils.waitForWorkers(processes, ports,
timeoutMs);
+ return processes;
+ }
+
+ /** Spawn a federated worker JVM and return without waiting for the
port to bind. */
+ private static Process spawnLocalFedWorker(int port, String[] addArgs) {
String separator = System.getProperty("file.separator");
String classpath = System.getProperty("java.class.path");
String path = System.getProperty("java.home") + separator +
"bin" + separator + "java";
- String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m",
"-Xmn100m",
+ String[] args = new String[] {path, "-Xmx1000m", "-Xms1000m",
"-Xmn100m",
"--add-opens=java.base/java.nio=ALL-UNNAMED" ,
"--add-opens=java.base/java.io=ALL-UNNAMED" ,
"--add-opens=java.base/java.util=ALL-UNNAMED" ,
@@ -1701,19 +1754,14 @@ public abstract class AutomatedTestBase {
DMLScript.class.getName(), "-w",
Integer.toString(port), "-stats"});
if(addArgs != null)
args = ArrayUtils.addAll(args, addArgs);
-
- ProcessBuilder processBuilder = new
ProcessBuilder(args).inheritIO();
+ ProcessBuilder processBuilder = new
ProcessBuilder(args).inheritIO();
try {
- process = processBuilder.start();
- // Give some time to startup the worker.
- sleep(sleep);
+ return processBuilder.start();
}
- catch(IOException | InterruptedException e) {
- e.printStackTrace();
+ catch(IOException e) {
+ throw new RuntimeException("Failed to launch federated
worker process on port " + port, e);
}
- isAlive(process);
- return process;
}
/**
@@ -1743,7 +1791,7 @@ public abstract class AutomatedTestBase {
}
/**
- * Start a thread for a worker. This will share the same JVM, so all
static variables will be shared.!
+ * Start a thread for a worker. This will share the same JVM, so all
static variables will be shared.
*
* Also when using the local Fed Worker thread the statistics printing,
and clearing from the worker is disabled.
*
@@ -1769,63 +1817,112 @@ public abstract class AutomatedTestBase {
}
/**
- * Start a thread for a worker. This will share the same JVM, so all
static variables will be shared.!
+ * Start a thread for a worker. This will share the same JVM, so all
static variables will be shared.
*
* Also when using the local Fed Worker thread the statistics printing,
and clearing from the worker is disabled.
*
- * @param port Port to use
- * @param sleep The amount of time to wait for the worker startup. in
Milliseconds
+ * @param port Port to use
+ * @param timeoutMs Upper bound on the wait for the worker to become
ready, in ms; raised to a
+ * minimum value enforced inside {@link
FederatedWorkerUtils}.
* @return The thread associated with the worker.
*/
- public static Thread startLocalFedWorkerThread(int port, int sleep) {
- return startLocalFedWorkerThread(port, null, sleep);
+ public static Thread startLocalFedWorkerThread(int port, int timeoutMs)
{
+ return startLocalFedWorkerThread(port, null, timeoutMs);
}
/**
- * Start a thread for a worker. This will share the same JVM, so all
static variables will be shared.!
- *
- * Also when using the local Fed Worker thread the statistics printing,
and clearing from the worker is disabled.
- *
+ * Start a thread for a worker. This will share the same JVM, so all
static variables will be shared.
+ *
+ * <p>Also when using the local Fed Worker thread the statistics
printing, and clearing from the worker is
+ * disabled.
+ *
+ * <p>Returns once the worker's TCP port accepts connections (the
worker opens the port after Netty's bind
+ * completes), or throws a {@link RuntimeException} after {@code
timeoutMs} elapses.
+ *
* @param port Port to use
* @param otherArgs The command line arguments to start the worker with
- * @param sleep The amount of time to wait for the worker startup.
in Milliseconds
+ * @param timeoutMs Upper bound on the wait for the worker to become
ready, in ms; raised to a
+ * minimum value enforced inside {@link
FederatedWorkerUtils}.
* @return The thread associated with the worker.
*/
- public static Thread startLocalFedWorkerThread(int port, String[]
otherArgs, int sleep) {
+ public static Thread startLocalFedWorkerThread(int port, String[]
otherArgs, int timeoutMs) {
+ Thread t = spawnLocalFedWorkerThread(port, otherArgs);
+ FederatedWorkerUtils.waitForWorker(t, port, timeoutMs);
+ return t;
+ }
+ /**
+ * Start N federated worker threads in the same JVM back to back, then
wait for all of them to
+ * become ready in one shared poll loop. The wall-clock wait scales
with the slowest worker rather
+ * than the sum of the per-worker waits.
+ *
+ * @param ports Ports to use, one per worker
+ * @return The thread per port, in the same order as {@code ports}.
+ */
+ public static Thread[] startLocalFedWorkerThreads(int[] ports) {
+ return startLocalFedWorkerThreads(ports, null, FED_WORKER_WAIT);
+ }
+
+ /** @see #startLocalFedWorkerThreads(int[], String[], int) */
+ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[]
otherArgs) {
+ return startLocalFedWorkerThreads(ports, otherArgs,
FED_WORKER_WAIT);
+ }
+
+ /**
+ * Start N federated worker threads in the same JVM back to back, then
wait for all of them to
+ * become ready in one shared poll loop.
+ *
+ * @param ports Ports to use, one per worker
+ * @param otherArgs Extra worker CLI args (applied to every worker), or
null
+ * @param timeoutMs Upper bound on the wait, in ms; raised to a minimum
value enforced inside
+ * {@link FederatedWorkerUtils}.
+ * @return The thread per port, in the same order as {@code ports}.
+ */
+ public static Thread[] startLocalFedWorkerThreads(int[] ports, String[]
otherArgs, int timeoutMs) {
+ Thread[] threads = new Thread[ports.length];
+ for(int i = 0; i < ports.length; i++) {
+ threads[i] = spawnLocalFedWorkerThread(ports[i],
otherArgs);
+ // Sleep THREAD_SPAWN_STAGGER_MS between in-JVM thread
spawns to reduce contention on
+ // shared static initialization in DMLScript /
FederatedWorker (e.g. LineageCacheConfig
+ // setters) when multiple worker threads enter main()
concurrently.
+ if(i + 1 < ports.length) {
+ try {
+
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(THREAD_SPAWN_STAGGER_MS);
+ }
+ catch(InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("Interrupted
while spawning federated worker threads", e);
+ }
+ }
+ }
+ FederatedWorkerUtils.waitForWorkers(threads, ports, timeoutMs);
+ return threads;
+ }
+
+ private static final int THREAD_SPAWN_STAGGER_MS = 25;
+
+ /** Spawn a federated worker thread in this JVM and return without
waiting for the port to bind. */
+ private static Thread spawnLocalFedWorkerThread(int port, String[]
otherArgs) {
ArrayList<String> args = new ArrayList<>();
-
args.add("-w");
args.add(Integer.toString(port));
-
if(otherArgs != null)
- for( String s : otherArgs)
+ for(String s : otherArgs)
args.add(s);
String[] finalArguments = args.toArray(new String[args.size()]);
Statistics.allowWorkerStatistics = false;
- try {
- Thread t = new Thread(() -> {
- try {
- main(finalArguments);
- }
- catch(Exception e) {
- LOG.error("Exception in startup of
federated worker", e);
- }
- });
- t.start();
- java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep);
- if(!t.isAlive())
- throw new RuntimeException("Failed starting
federated worker");
- return t;
- }
- catch(InterruptedException e) {
- e.printStackTrace();
- fail("Failed to start federated worker : " + e);
- // should never happen
- return null;
- }
+ Thread t = new Thread(() -> {
+ try {
+ main(finalArguments);
+ }
+ catch(Exception e) {
+ LOG.error("Exception in startup of federated
worker", e);
+ }
+ });
+ t.start();
+ return t;
}
public static boolean isAlive(Thread... threads){
@@ -1846,28 +1943,43 @@ public abstract class AutomatedTestBase {
/**
* Start java worker in same JVM.
- *
+ *
+ * <p>Returns once the worker's TCP port accepts connections (the
worker opens the port after
+ * Netty's bind completes), or throws a {@link RuntimeException} after
the default federated worker
+ * timeout elapses. The port is extracted from {@code args}, which must
contain {@code "-w" <port>}.
+ *
* @param args the command line arguments
- * @return the thread associated with the process.s
+ * @return the thread associated with the worker.
*/
public static Thread startLocalFedWorkerWithArgs(String[] args) {
- Thread t = null;
+ final int port = extractWorkerPort(args);
+ Thread t = new Thread(() -> {
+ try {
+ main(args);
+ }
+ catch(IOException e) {
+ LOG.error("Exception in startup of federated
worker on port " + port, e);
+ }
+ });
+ t.start();
+ FederatedWorkerUtils.waitForWorker(t, port, FED_WORKER_WAIT);
+ return t;
+ }
- try {
- t = new Thread(() -> {
+ private static int extractWorkerPort(String[] args) {
+ for(int i = 0; i < args.length - 1; i++) {
+ if("-w".equals(args[i])) {
try {
- main(args);
+ return Integer.parseInt(args[i + 1]);
}
- catch(IOException e) {
+ catch(NumberFormatException e) {
+ throw new IllegalArgumentException(
+ "Federated worker args contain
non-numeric port after -w: " + args[i + 1], e);
}
- });
- t.start();
-
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT);
- }
- catch(InterruptedException e) {
- // Should happen at closing of the worker so don't print
+ }
}
- return t;
+ throw new IllegalArgumentException("Federated worker args must
contain '-w <port>': "
+ + Arrays.toString(args));
}
private boolean rCompareException(boolean exceptionExpected, String
errMessage, Throwable e, boolean result) {
diff --git a/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java
b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java
new file mode 100644
index 0000000000..d604d7dcab
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/FederatedWorkerUtils.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.util.function.BooleanSupplier;
+
+/**
+ * Test helpers that block until a federated worker is accepting TCP
connections on its port.
+ *
+ * <p>The federated worker opens its TCP port after Netty's {@code
bind().sync()} returns; a successful
+ * TCP connect to that port therefore indicates that the worker is ready to
accept requests. The methods
+ * here poll for that signal and throw {@link RuntimeException} on timeout or
if the underlying
+ * {@code Process}/{@code Thread} exits before the port becomes ready.
+ */
+public final class FederatedWorkerUtils {
+
+ /** Sleep between successive poll rounds, in milliseconds. */
+ private static final int POLL_INTERVAL_MS = 25;
+
+ /** Per-attempt {@link Socket#connect} timeout, in milliseconds. */
+ private static final int CONNECT_TIMEOUT_MS = 25;
+
+ /**
+ * Minimum value applied to the caller-supplied {@code timeoutMs}. The
wait returns as soon as the
+ * worker accepts a connection, so this only affects the upper bound
used when a worker never becomes
+ * ready. Set to 60s to accommodate cold JVM startup on heavily
contended CI runners: tests starting
+ * four workers in parallel can have all four still pending after 30s
when the runner is CPU-starved,
+ * and burning a surefire retry costs more wall time than padding this
clamp.
+ */
+ private static final int MIN_TIMEOUT_MS = 60_000;
+
+ private FederatedWorkerUtils() {
+ // utility class
+ }
+
+ /**
+ * Block until a federated worker is accepting TCP connections on
{@code port}, or throw a
+ * {@link RuntimeException} after the effective timeout elapses.
+ *
+ * @param port port the federated worker is expected to bind
+ * @param timeoutMs upper bound on the wait, in ms; raised to {@link
#MIN_TIMEOUT_MS} if smaller
+ */
+ public static void waitForWorker(int port, int timeoutMs) {
+ waitForWorker(port, timeoutMs, () -> true, "worker");
+ }
+
+ /**
+ * Block until a federated worker is accepting TCP connections on
{@code port}. Fails fast by throwing
+ * a {@link RuntimeException} if {@code aliveCheck} reports the worker
is no longer alive.
+ */
+ public static void waitForWorker(int port, int timeoutMs,
BooleanSupplier aliveCheck, String workerKind) {
+ final int effectiveTimeout = Math.max(timeoutMs,
MIN_TIMEOUT_MS);
+ final long deadline = System.currentTimeMillis() +
effectiveTimeout;
+ while(System.currentTimeMillis() < deadline) {
+ if(!aliveCheck.getAsBoolean()) {
+ throw new RuntimeException(
+ "Federated " + workerKind + " on port "
+ port + " died before becoming ready.");
+ }
+ if(tryConnect(port)) {
+ return;
+ }
+ sleepQuietly();
+ }
+ throw new RuntimeException("Federated " + workerKind + " on
port " + port
+ + " did not become ready within " + effectiveTimeout +
"ms.");
+ }
+
+ /** Overload that also throws a {@link RuntimeException} if the given worker process exits
before the port is ready. */
+ public static void waitForWorker(Process process, int port, int
timeoutMs) {
+ waitForWorker(port, timeoutMs, process::isAlive, "worker
process");
+ }
+
+ /** Overload that also throws a {@link RuntimeException} if the given worker thread exits
before the port is ready. */
+ public static void waitForWorker(Thread thread, int port, int
timeoutMs) {
+ waitForWorker(port, timeoutMs, thread::isAlive, "worker
thread");
+ }
+
+ /**
+ * Block until every listed federated worker is accepting TCP
connections. All ports are polled in
+ * one shared loop, so the wall-clock wait is bounded by the slowest
worker rather than the sum of
+ * individual waits.
+ *
+ * @param ports ports the workers are expected to bind
+ * @param timeoutMs upper bound on the wait, in ms; raised to {@link
#MIN_TIMEOUT_MS} if smaller
+ */
+ public static void waitForWorkers(int[] ports, int timeoutMs) {
+ waitForWorkers(ports, timeoutMs, i -> true, "workers");
+ }
+
+ /**
+ * Overload that also throws a {@link RuntimeException} if any of the worker processes
exits before its port is ready.
+ *
+ * @throws IllegalArgumentException if {@code processes.length !=
ports.length}
+ */
+ public static void waitForWorkers(Process[] processes, int[] ports, int
timeoutMs) {
+ if(processes.length != ports.length) {
+ throw new IllegalArgumentException(
+ "processes/ports length mismatch: " +
processes.length + " vs " + ports.length);
+ }
+ waitForWorkers(ports, timeoutMs, i -> processes[i].isAlive(),
"worker processes");
+ }
+
+ /**
+ * Overload that also throws a {@link RuntimeException} if any of the worker threads exits
before its port is ready.
+ *
+ * @throws IllegalArgumentException if {@code threads.length !=
ports.length}
+ */
+ public static void waitForWorkers(Thread[] threads, int[] ports, int
timeoutMs) {
+ if(threads.length != ports.length) {
+ throw new IllegalArgumentException(
+ "threads/ports length mismatch: " +
threads.length + " vs " + ports.length);
+ }
+ waitForWorkers(ports, timeoutMs, i -> threads[i].isAlive(),
"worker threads");
+ }
+
+ /**
+ * Bulk variant taking a per-index liveness predicate so callers can
plug in either {@code Process}
+ * or {@code Thread} liveness. Each port flips to ready as soon as it
accepts a connection; the loop
+ * yields between sweeps so a still-pending worker is not starved by
repeated probes on the same CPU.
+ */
+ public static void waitForWorkers(int[] ports, int timeoutMs,
java.util.function.IntPredicate aliveCheck,
+ String workerKind) {
+ final int effectiveTimeout = Math.max(timeoutMs,
MIN_TIMEOUT_MS);
+ final long deadline = System.currentTimeMillis() +
effectiveTimeout;
+ final boolean[] ready = new boolean[ports.length];
+ int remaining = ports.length;
+ while(remaining > 0 && System.currentTimeMillis() < deadline) {
+ for(int i = 0; i < ports.length; i++) {
+ if(ready[i]) {
+ continue;
+ }
+ if(!aliveCheck.test(i)) {
+ throw new RuntimeException("Federated "
+ workerKind + " on port " + ports[i]
+ + " died before becoming
ready.");
+ }
+ if(tryConnect(ports[i])) {
+ ready[i] = true;
+ remaining--;
+ }
+ }
+ if(remaining > 0) {
+ sleepQuietly();
+ }
+ }
+ if(remaining > 0) {
+ StringBuilder sb = new StringBuilder("Federated
").append(workerKind)
+ .append(" did not all become ready within
").append(effectiveTimeout).append("ms. Pending ports:");
+ for(int i = 0; i < ports.length; i++) {
+ if(!ready[i]) {
+ sb.append(' ').append(ports[i]);
+ }
+ }
+ throw new RuntimeException(sb.toString());
+ }
+ }
+
+ private static boolean tryConnect(int port) {
+ try(Socket s = new Socket()) {
+ s.connect(new InetSocketAddress("localhost", port),
CONNECT_TIMEOUT_MS);
+ return true;
+ }
+ catch(IOException e) {
+ return false;
+ }
+ }
+
+ private static void sleepQuietly() {
+ try {
+ Thread.sleep(POLL_INTERVAL_MS);
+ }
+ catch(InterruptedException ie) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException("Interrupted while waiting
for federated worker", ie);
+ }
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
index 5e880166af..3b53c71c0e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
@@ -120,8 +120,7 @@ public class FederatedAlsCGTest extends AutomatedTestBase
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(testname);
@@ -153,7 +152,7 @@ public class FederatedAlsCGTest extends AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_!="));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
index f2f3570011..0537ecf18d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
@@ -114,10 +114,7 @@ public class FederatedBivarTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -139,7 +136,7 @@ public class FederatedBivarTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
index 73bf8e91de..2c81da6333 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
@@ -102,10 +102,7 @@ public class FederatedCorTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -140,7 +137,7 @@ public class FederatedCorTest extends AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 6d8e816530..b54cb3d0a6 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -100,8 +100,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -122,7 +121,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
// compare via files
compareResults(1e-2);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index c8605ac3d8..c1fc83ae3c 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -113,8 +113,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -155,7 +154,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
// compare via files
// compareResults(1e-9); --> randomized
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
index f7040e6a4d..0b38317f73 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
@@ -104,8 +104,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -126,7 +125,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
index 8b02246e0a..f3e921dd93 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLmPipeline.java
@@ -109,10 +109,7 @@ public class FederatedLmPipeline extends AutomatedTestBase
{
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2,
FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3,
FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[]
{port1, port2, port3, port4});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -134,7 +131,7 @@ public class FederatedLmPipeline extends AutomatedTestBase {
// compare via files
compareResults(1e-2);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check correct federated operations
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index a3e91ef37d..74882637d9 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -95,8 +95,7 @@ public class FederatedLogRegTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -117,7 +116,7 @@ public class FederatedLogRegTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue("contains fed_ba+*",
heavyHittersContainsString("fed_ba+*"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
index c5344f9d84..97a4d84a0b 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
@@ -97,8 +97,7 @@ public class FederatedMSVMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -119,7 +118,7 @@ public class FederatedMSVMTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
index 2a4186680c..9ddd59483c 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
@@ -102,10 +102,7 @@ public class FederatedPCATest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -127,7 +124,7 @@ public class FederatedPCATest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
index 19fb72ce47..183f9900a6 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
@@ -113,8 +113,7 @@ public class FederatedPNMFTest extends AutomatedTestBase
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -141,7 +140,7 @@ public class FederatedPNMFTest extends AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
index a4a8236e3a..21f2e39225 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
@@ -100,10 +100,7 @@ public class FederatedUnivarTest extends AutomatedTestBase
{
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4});
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -126,7 +123,7 @@ public class FederatedUnivarTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_uacmax"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
index db071220d0..445abfa148 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
@@ -115,10 +115,7 @@ public class FederatedVarTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4});
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK)
@@ -157,7 +154,7 @@ public class FederatedVarTest extends AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
OptimizerUtils.FEDERATED_COMPILATION = false;
rtplatform = platformOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
index b8eef26f35..a87328a8aa 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java
@@ -117,8 +117,7 @@ public class FederatedYL2SVMTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
TestConfiguration config =
availableTestConfigurations.get(testName);
loadTestConfiguration(config);
@@ -140,7 +139,7 @@ public class FederatedYL2SVMTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
index 2f6e5e4617..94fab208e0 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -158,8 +158,7 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -188,7 +187,7 @@ public class FederatedCellwiseTmplTest extends
AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofCell"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
index 61722dbc46..a23219e158 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -194,8 +194,7 @@ public class FederatedCodegenMultipleFedMOTest extends
AutomatedTestBase
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -228,7 +227,7 @@ public class FederatedCodegenMultipleFedMOTest extends
AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
if(test_num >= 0 && test_num < 100)
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
index 33a551d28d..74cd2c0a42 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
@@ -143,8 +143,7 @@ public class FederatedMultiAggTmplTest extends
AutomatedTestBase
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -173,7 +172,7 @@ public class FederatedMultiAggTmplTest extends
AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
index cef5fd5e99..dd08ae5863 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -146,8 +146,7 @@ public class FederatedOuterProductTmplTest extends
AutomatedTestBase
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -176,7 +175,7 @@ public class FederatedOuterProductTmplTest extends
AutomatedTestBase
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP"));
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index b47a718c0c..185c15557d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -147,8 +147,7 @@ public class FederatedRowwiseTmplTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- Thread thread2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2});
getAndLoadTestConfiguration(TEST_NAME);
@@ -178,7 +177,7 @@ public class FederatedRowwiseTmplTest extends
AutomatedTestBase {
TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
- TestUtils.shutdownThreads(thread1, thread2);
+ TestUtils.shutdownThreads(workers);
// check for federated operations
if(row_partitioned)
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
index 53dce3f01c..bcc2b61adf 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedDynamicPlanningTest.java
@@ -118,7 +118,7 @@ public class FederatedDynamicPlanningTest extends
AutomatedTestBase {
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(testName);
@@ -128,8 +128,7 @@ public class FederatedDynamicPlanningTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -164,7 +163,7 @@ public class FederatedDynamicPlanningTest extends
AutomatedTestBase {
+
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
index 9a9ff18d28..b6c239854b 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedKMeansPlanningTest.java
@@ -125,7 +125,7 @@ public class FederatedKMeansPlanningTest extends
AutomatedTestBase {
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(testName);
@@ -135,8 +135,7 @@ public class FederatedKMeansPlanningTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -158,7 +157,7 @@ public class FederatedKMeansPlanningTest extends
AutomatedTestBase {
// fail("The following expected heavy hitters are
missing: "
// +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
} finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
index 9114339620..39ab490148 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedL2SVMPlanningTest.java
@@ -170,7 +170,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
Types.ExecMode platformOld = rtplatform;
rtplatform = Types.ExecMode.SINGLE_NODE;
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(testName);
@@ -180,8 +180,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -205,7 +204,7 @@ public class FederatedL2SVMPlanningTest extends
AutomatedTestBase {
// +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
index 5b54f14d05..0a2bd230da 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/fedplanning/FederatedMultiplyPlanningTest.java
@@ -240,7 +240,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- Thread t1 = null, t2 = null;
+ Thread[] workers = null;
try{
getAndLoadTestConfiguration(testName);
@@ -250,8 +250,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2);
+ workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + testName + ".dml";
@@ -275,7 +274,7 @@ public class FederatedMultiplyPlanningTest extends
AutomatedTestBase {
fail("The following expected heavy hitters are
missing: "
+
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
} finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index ff96ad8af2..566d3ff323 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -91,8 +91,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
String host = "localhost";
try {
@@ -137,6 +136,6 @@ public class FederatedReaderTest extends AutomatedTestBase {
resetExecMode(oldPlatform);
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
index 5d4d2c4e9c..5f5c09e07c 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
@@ -96,8 +96,7 @@ public class FederatedSSLTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
String host = "localhost";
@@ -134,7 +133,7 @@ public class FederatedSSLTest extends AutomatedTestBase {
resetExecMode(oldPlatform);
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
/**
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
index 42165d6b68..f2654b9bb6 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSparsityPropagationTest.java
@@ -102,8 +102,7 @@ public class FederatedSparsityPropagationTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
getAndLoadTestConfiguration(TEST_NAME);
@@ -136,7 +135,7 @@ public class FederatedSparsityPropagationTest extends
AutomatedTestBase {
compareNNZ(refNNZ, fedNNZ);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
index d8bb743147..68fb3a7da6 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
@@ -83,8 +83,7 @@ public class FederatedWriterTest extends AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, null, FED_WORKER_WAIT);
try {
@@ -122,6 +121,6 @@ public class FederatedWriterTest extends AutomatedTestBase {
resetExecMode(oldPlatform);
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
index c3a4756a2d..866bbe0013 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/MultiTenantTestBase.java
@@ -38,8 +38,6 @@ import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.test.AutomatedTestBase;
import org.junit.After;
-import com.google.crypto.tink.subtle.Random;
-
public abstract class MultiTenantTestBase extends AutomatedTestBase {
protected static final Log LOG =
LogFactory.getLog(MultiTenantTestBase.class.getName());
@@ -63,7 +61,8 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
}
/**
- * Start numFedWorkers federated worker processes on available ports
and add them to the workerProcesses
+ * Start numFedWorkers federated worker processes on available ports
and add them to the workerProcesses.
+ * Workers are spawned together and their port-bind is awaited in
parallel.
*
* @param numFedWorkers the number of federated workers to start
* @return int[] the ports of the created federated workers
@@ -72,10 +71,10 @@ public abstract class MultiTenantTestBase extends
AutomatedTestBase {
int[] ports = new int[numFedWorkers];
for(int counter = 0; counter < numFedWorkers; counter++) {
ports[counter] = getRandomAvailablePort();
- // start process but only wait long for last one.
- Process tmpProcess =
startLocalFedWorker(ports[counter], addArgs,
- counter == numFedWorkers - 1 ? (FED_WORKER_WAIT
+ Random.randInt(1000)) * 3 : FED_WORKER_WAIT_S);
- workerProcesses.add(tmpProcess);
+ }
+ Process[] processes = startLocalFedWorkers(ports, addArgs);
+ for(Process p : processes) {
+ workerProcesses.add(p);
}
return ports;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index b7eb0c2283..66d2afaa6a 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -142,12 +142,13 @@ public class AvgModelFederatedParamservTest extends
AutomatedTestBase {
try {
// start threads
List<Integer> ports = new ArrayList<>();
- List<Thread> threads = new ArrayList<>();
+ int[] portArr = new int[_numFederatedWorkers];
for(int i = 0; i < _numFederatedWorkers; i++) {
- ports.add(getRandomAvailablePort());
-
threads.add(startLocalFedWorkerThread(ports.get(i),
- i==(_numFederatedWorkers-1) ?
FED_WORKER_WAIT : FED_WORKER_WAIT_S));
+ int port = getRandomAvailablePort();
+ portArr[i] = port;
+ ports.add(port);
}
+ Thread[] threads = startLocalFedWorkerThreads(portArr);
// generate test data
double[][] features =
generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
@@ -171,14 +172,6 @@ public class AvgModelFederatedParamservTest extends
AutomatedTestBase {
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels,
_numFederatedWorkers, ports, ranges);
}
- try {
- //wait for all workers to be setup
- Thread.sleep(FED_WORKER_WAIT);
- }
- catch(InterruptedException e) {
- e.printStackTrace();
- }
-
// dml name
fullDMLScriptName = HOME + TEST_NAME + ".dml";
// generate program args
@@ -207,7 +200,7 @@ public class AvgModelFederatedParamservTest extends
AutomatedTestBase {
// shut down threads
for(int i = 0; i < _numFederatedWorkers; i++) {
- TestUtils.shutdownThreads(threads.get(i));
+ TestUtils.shutdownThreads(threads[i]);
}
}
finally {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
index 2663b0e762..a63cc40bdb 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
@@ -26,7 +26,6 @@ import java.util.List;
import java.util.Objects;
import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -151,19 +150,16 @@ public class EncryptedFederatedParamservTest extends
AutomatedTestBase {
}
ExecMode platformOld = setExecMode(mode);
- // start threads
List<Integer> ports = new ArrayList<>();
- List<Thread> threads = new ArrayList<>();
+ int[] portArr = new int[_numFederatedWorkers];
+ for(int i = 0; i < _numFederatedWorkers; i++) {
+ int port = getRandomAvailablePort();
+ portArr[i] = port;
+ ports.add(port);
+ }
+ Thread[] threads = new Thread[0];
try {
- for(int i = 0; i < _numFederatedWorkers; i++) {
- int port = getRandomAvailablePort();
- threads.add(startLocalFedWorkerThread(port,
- i==(_numFederatedWorkers-1) ?
FED_WORKER_WAIT : FED_WORKER_WAIT_S));
- ports.add(port);
-
- if ( threads.get(i).isInterrupted() ||
!threads.get(i).isAlive() )
- throw new
DMLRuntimeException("Federated worker thread dead or interrupted! Port " +
port);
- }
+ threads = startLocalFedWorkerThreads(portArr);
// generate test data
double[][] features =
ParamServTestUtils.generateFeatures(_networkType, _dataSetSize, C, Hin, Win);
@@ -187,9 +183,6 @@ public class EncryptedFederatedParamservTest extends
AutomatedTestBase {
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels,
_numFederatedWorkers, ports, ranges);
}
- //wait for all workers to be setup
- Thread.sleep(FED_WORKER_WAIT);
-
// dml name
fullDMLScriptName = HOME + TEST_NAME + ".dml";
// generate program args
@@ -220,10 +213,6 @@ public class EncryptedFederatedParamservTest extends
AutomatedTestBase {
+
Arrays.toString(missingHeavyHitters("paramserv")));
Assert.assertEquals("Test Failed \n" + log, 0,
Statistics.getNoOfExecutedSPInst());
}
- catch(InterruptedException e) {
- e.printStackTrace();
- fail(e.getMessage());
- }
finally {
// shut down threads
for ( Thread thread : threads ){
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index b9877b3f16..887b63f0af 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -25,7 +25,6 @@ import java.util.Collection;
import java.util.List;
import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -35,8 +34,6 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
-import static org.junit.Assert.fail;
-
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedParamservTest extends AutomatedTestBase {
@@ -147,17 +144,15 @@ public class FederatedParamservTest extends
AutomatedTestBase {
ExecMode platformOld = setExecMode(mode);
List<Integer> ports = new ArrayList<>();
- List<Thread> threads = new ArrayList<>();
+ int[] portArr = new int[_numFederatedWorkers];
+ for(int i = 0; i < _numFederatedWorkers; i++) {
+ int port = getRandomAvailablePort();
+ portArr[i] = port;
+ ports.add(port);
+ }
+ Thread[] threads = new Thread[0];
try {
- // start threads
- for(int i = 0; i < _numFederatedWorkers; i++) {
- int port = getRandomAvailablePort();
- threads.add(startLocalFedWorkerThread(port,
FED_WORKER_WAIT_S));
- ports.add(port);
-
- if ( threads.get(i).isInterrupted() ||
!threads.get(i).isAlive() )
- throw new
DMLRuntimeException("Federated worker thread dead or interrupted! Port " +
port);
- }
+ threads = startLocalFedWorkerThreads(portArr);
// generate test data
double[][] features =
ParamServTestUtils.generateFeatures(_networkType, _dataSetSize, C, Hin, Win);
@@ -181,11 +176,6 @@ public class FederatedParamservTest extends
AutomatedTestBase {
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels,
_numFederatedWorkers, ports, ranges);
}
- //wait for all workers to be setup
- Thread.sleep(FED_WORKER_WAIT);
- if (threads.stream().anyMatch(t -> !t.isAlive()))
- throw new DMLRuntimeException("Federated worker
thread interrupted!");
-
// dml name
fullDMLScriptName = HOME + TEST_NAME + ".dml";
// generate program args
@@ -211,10 +201,6 @@ public class FederatedParamservTest extends
AutomatedTestBase {
String log = runTest(null).toString();
Assert.assertEquals("Test Failed \n" + log, 0,
Statistics.getNoOfExecutedSPInst());
}
- catch(InterruptedException e) {
- e.printStackTrace();
- fail(e.getMessage());
- }
finally {
// shut down threads
for ( Thread thread : threads ){
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
index 9b307a4e53..e4f85df50a 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/NbatchesFederatedParamservTest.java
@@ -121,12 +121,13 @@ public class NbatchesFederatedParamservTest extends
AutomatedTestBase {
try {
// start threads
List<Integer> ports = new ArrayList<>();
- List<Thread> threads = new ArrayList<>();
+ int[] portArr = new int[_numFederatedWorkers];
for(int i = 0; i < _numFederatedWorkers; i++) {
- ports.add(getRandomAvailablePort());
-
threads.add(startLocalFedWorkerThread(ports.get(i),
- (i==(_numFederatedWorkers-1) ?
FED_WORKER_WAIT : FED_WORKER_WAIT_S)));
+ int port = getRandomAvailablePort();
+ portArr[i] = port;
+ ports.add(port);
}
+ Thread[] threads = startLocalFedWorkerThreads(portArr);
// generate test data
double[][] features =
generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
@@ -150,14 +151,6 @@ public class NbatchesFederatedParamservTest extends
AutomatedTestBase {
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels,
_numFederatedWorkers, ports, ranges);
}
- try {
- //wait for all workers to be setup
- Thread.sleep(FED_WORKER_WAIT);
- }
- catch(InterruptedException e) {
- e.printStackTrace();
- }
-
// dml name
fullDMLScriptName = HOME + TEST_NAME + ".dml";
// generate program args
@@ -186,7 +179,7 @@ public class NbatchesFederatedParamservTest extends
AutomatedTestBase {
// shut down threads
for(int i = 0; i < _numFederatedWorkers; i++) {
- TestUtils.shutdownThreads(threads.get(i));
+ TestUtils.shutdownThreads(threads[i]);
}
}
finally {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
index 8991d28194..20945303b0 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryMatrixTest.java
@@ -105,11 +105,10 @@ public class FederatedBinaryMatrixTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -134,7 +133,7 @@ public class FederatedBinaryMatrixTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
index e038d9efda..d47a7efcfa 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBinaryVectorTest.java
@@ -97,11 +97,10 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -127,7 +126,7 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
index 434eb08e26..1a88f7a0d3 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedBroadcastTest.java
@@ -90,11 +90,10 @@ public class FederatedBroadcastTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -117,7 +116,7 @@ public class FederatedBroadcastTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
index a7bc0d8064..392e53aa77 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToFrameTest.java
@@ -97,12 +97,10 @@ public class FederatedCastToFrameTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
-
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -127,7 +125,7 @@ public class FederatedCastToFrameTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
index ccfd4a6c43..832ae98d8e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCastToMatrixTest.java
@@ -122,10 +122,9 @@ public class FederatedCastToMatrixTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1,
FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[]
{port1, port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed
starting federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -152,7 +151,7 @@ public class FederatedCastToMatrixTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG =
sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
index 1d9f951e78..f38e322097 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedCentralMomentTest.java
@@ -118,14 +118,10 @@ public class FederatedCentralMomentTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT + 1000);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
-
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// reference file should not be written to hdfs, so we
set platform here
@@ -187,7 +183,7 @@ public class FederatedCentralMomentTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
index f52a74bd4f..5254378cf2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedColAggregateTest.java
@@ -174,13 +174,10 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -244,7 +241,7 @@ public class FederatedColAggregateTest extends
AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
index 30d6b8fc0e..6f9744ded1 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedLeftIndexTest.java
@@ -165,13 +165,10 @@ public class FederatedLeftIndexTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1,
FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2,
FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3,
FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[]
{port1, port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed
starting federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -212,7 +209,7 @@ public class FederatedLeftIndexTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
}
}
finally {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
index 8fb8a80663..06f86dd61b 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part1/FederatedMisAlignedTest.java
@@ -220,14 +220,10 @@ public class FederatedMisAlignedTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
-
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// Run reference dml script with normal matrix
@@ -281,7 +277,7 @@ public class FederatedMisAlignedTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
index 05e4954af1..cfdac1932f 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedMultiplyTest.java
@@ -114,12 +114,11 @@ public class FederatedMultiplyTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -145,7 +144,7 @@ public class FederatedMultiplyTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
index 530a8b7b55..2447143de2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedProdTest.java
@@ -108,14 +108,11 @@ public class FederatedProdTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -154,7 +151,7 @@ public class FederatedProdTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
index 25d460526e..6f13fce81b 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileTest.java
@@ -164,12 +164,13 @@ public class FederatedQuantileTest extends
AutomatedTestBase {
port2 = getRandomAvailablePort();
port3 = getRandomAvailablePort();
port4 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorker(port2,
FED_WORKER_WAIT_S);
- t3 = startLocalFedWorker(port3,
FED_WORKER_WAIT_S);
- t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new
int[] {port1, port2, port3, port4});
+ t1 = workers[0];
+ t2 = workers[1];
+ t3 = workers[2];
+ t4 = workers[3];
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed
starting federated worker");
programArgs1 = new String[] {"-explain",
"-stats", "100", "-args", String.valueOf(p), expected("S"),
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
index dadbd6b590..ebf4708b3e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedQuantileWeightsTest.java
@@ -132,12 +132,13 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
port2 = getRandomAvailablePort();
port3 = getRandomAvailablePort();
port4 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorker(port2,
FED_WORKER_WAIT_S);
- t3 = startLocalFedWorker(port3,
FED_WORKER_WAIT_S);
- t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new
int[] {port1, port2, port3, port4});
+ t1 = workers[0];
+ t2 = workers[1];
+ t3 = workers[2];
+ t4 = workers[3];
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed
starting federated worker");
programArgs1 = new String[] {"-explain",
"-stats", "100", "-args", String.valueOf(p), expected("S"),
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
index d8ef8ca4c2..be69fac747 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRCBindTest.java
@@ -109,14 +109,11 @@ public class FederatedRCBindTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// we need the reference file to not be written to
hdfs, so we get the correct format
@@ -158,7 +155,7 @@ public class FederatedRCBindTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
index 46bf7a4565..ce6aba438e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRdiagTest.java
@@ -112,14 +112,11 @@ public class FederatedRdiagTest extends AutomatedTestBase
{
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
ProgramBlock.CHECK_MATRIX_PROPERTIES = true;
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// reference file should not be written to hdfs, so we
set platform here
@@ -159,7 +156,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.FEDERATED_COMPILATION = false;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
index 71f7c58366..0714a6e6fb 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRemoveEmptyTest.java
@@ -121,14 +121,11 @@ public class FederatedRemoveEmptyTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -166,7 +163,7 @@ public class FederatedRemoveEmptyTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
index d2ec98a4a8..c5347d6276 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReplaceTest.java
@@ -108,13 +108,10 @@ public class FederatedReplaceTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -153,7 +150,7 @@ public class FederatedReplaceTest extends AutomatedTestBase
{
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
index 5d6887c3e2..fb8bc619ec 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedReshapeTest.java
@@ -101,14 +101,11 @@ public class FederatedReshapeTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// reference file should not be written to hdfs, so we
set platform here
rtplatform = execMode;
@@ -146,7 +143,7 @@ public class FederatedReshapeTest extends AutomatedTestBase
{
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
index 7fe88228d0..64b7d396d2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRevTest.java
@@ -173,13 +173,10 @@ public class FederatedRevTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -226,7 +223,7 @@ public class FederatedRevTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
index b8fe21ef52..7a72a868b6 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRightIndexTest.java
@@ -178,13 +178,10 @@ public class FederatedRightIndexTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -236,7 +233,7 @@ public class FederatedRightIndexTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
index f242710338..aa367ea7d4 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRollTest.java
@@ -131,14 +131,11 @@ public class FederatedRollTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if (!isAlive(t1, t2, t3, t4))
+ if (!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
if (rtplatform == ExecMode.SPARK) {
@@ -177,7 +174,7 @@ public class FederatedRollTest extends AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
} finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
index 25b02ca153..37b6f8826e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part2/FederatedRowIndexTest.java
@@ -108,13 +108,10 @@ public class FederatedRowIndexTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -155,7 +152,7 @@ public class FederatedRowIndexTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
index fc9e4a73ba..462181e75e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedSplitTest.java
@@ -106,11 +106,10 @@ public class FederatedSplitTest extends AutomatedTestBase
{
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// Run reference dml script with normal matrix
@@ -141,7 +140,7 @@ public class FederatedSplitTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
index ac9706a44a..dec5fac443 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedStatisticsTest.java
@@ -96,10 +96,9 @@ public class FederatedStatisticsTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -134,7 +133,7 @@ public class FederatedStatisticsTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
index f8acc4623a..84b632c5e2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java
@@ -91,12 +91,10 @@ public class FederatedTokenizeTest extends
AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3});
try {
- if(!isAlive(t1, t2, t3))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
FileFormatPropertiesCSV ffpCSV = new
FileFormatPropertiesCSV(false, DataExpression.DEFAULT_DELIM_DELIMITER,
@@ -143,7 +141,7 @@ public class FederatedTokenizeTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
index 0d29617eb7..2f2adf2bc9 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTransferLocalDataTest.java
@@ -88,13 +88,10 @@ public class FederatedTransferLocalDataTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -125,7 +122,7 @@ public class FederatedTransferLocalDataTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
index 63432e4465..837fb979fc 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTriTest.java
@@ -108,12 +108,9 @@ public class FederatedTriTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -151,7 +148,7 @@ public class FederatedTriTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
index bd48a45924..5e0953d0aa 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedCrossEntropyTest.java
@@ -136,11 +136,10 @@ public class FederatedWeightedCrossEntropyTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
getAndLoadTestConfiguration(testname);
@@ -172,7 +171,7 @@ public class FederatedWeightedCrossEntropyTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
index 6753774f65..3f866be984 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedDivMatrixMultTest.java
@@ -297,11 +297,10 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
getAndLoadTestConfiguration(test_name);
@@ -334,7 +333,7 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
index ad21760d08..15a709a205 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSigmoidTest.java
@@ -161,11 +161,10 @@ public class FederatedWeightedSigmoidTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
getAndLoadTestConfiguration(test_name);
@@ -198,7 +197,7 @@ public class FederatedWeightedSigmoidTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
index 7fac163c4a..111c16ded8 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedSquaredLossTest.java
@@ -151,11 +151,10 @@ public class FederatedWeightedSquaredLossTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
getAndLoadTestConfiguration(test_name);
@@ -187,7 +186,7 @@ public class FederatedWeightedSquaredLossTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
index 8019edf83f..610b5ebedf 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedWeightedUnaryMatrixMultTest.java
@@ -160,11 +160,10 @@ public class FederatedWeightedUnaryMatrixMultTest extends
AutomatedTestBase {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2});
try {
- if(!isAlive(t1, t2))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
getAndLoadTestConfiguration(test_name);
@@ -197,7 +196,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends
AutomatedTestBase {
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platform_old);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
index ba725519a8..1088de5362 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part4/FederatedRowAggregateTest.java
@@ -214,14 +214,11 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
// Run reference dml script with normal matrix
@@ -280,7 +277,7 @@ public class FederatedRowAggregateTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
index 48c9cab632..0993b41698 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCovarianceTest.java
@@ -134,13 +134,10 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
setExecMode(execMode);
@@ -203,7 +200,7 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
@@ -246,13 +243,10 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4, FED_WORKER_WAIT);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -364,7 +358,7 @@ public class FederatedCovarianceTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
index 4ba19fb3d0..ebb42dfec8 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedCtableTest.java
@@ -133,14 +133,11 @@ public class FederatedCtableTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -154,7 +151,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
index 037675c584..189c109fa8 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFrameMapTest.java
@@ -101,14 +101,11 @@ public class FederatedFrameMapTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -149,7 +146,7 @@ public class FederatedFrameMapTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
}
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
index d75924746c..006b5d7688 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullAggregateTest.java
@@ -203,13 +203,10 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -264,7 +261,7 @@ public class FederatedFullAggregateTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
index 5eb1179efd..b20d949857 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedFullCumulativeTest.java
@@ -168,14 +168,11 @@ public class FederatedFullCumulativeTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
@@ -245,7 +242,7 @@ public class FederatedFullCumulativeTest extends
AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(platformOld);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
index abf0c8c228..375a4f9518 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedIfelseTest.java
@@ -138,14 +138,11 @@ public class FederatedIfelseTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1);
- Process t2 = startLocalFedWorker(port2);
- Process t3 = startLocalFedWorker(port3);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -220,7 +217,7 @@ public class FederatedIfelseTest extends AutomatedTestBase {
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
index c37341ed68..73cb5470a3 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part5/FederatedMMChainTest.java
@@ -138,14 +138,11 @@ public class FederatedMMChainTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- Process t1 = startLocalFedWorker(port1, FED_WORKER_WAIT_S);
- Process t2 = startLocalFedWorker(port2, FED_WORKER_WAIT_S);
- Process t3 = startLocalFedWorker(port3, FED_WORKER_WAIT_S);
- Process t4 = startLocalFedWorker(port4);
+ Process[] workers = startLocalFedWorkers(new int[] {port1,
port2, port3, port4});
try {
- if(!isAlive(t1, t2, t3, t4))
+ if(!isAlive(workers))
throw new RuntimeException("Failed starting
federated worker");
rtplatform = execMode;
@@ -181,7 +178,7 @@ public class FederatedMMChainTest extends AutomatedTestBase
{
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index b70bdd2940..fd1f88f276 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -230,7 +230,7 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
default: throw new RuntimeException("Not supported
type");
}
- Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(TEST_NAME1);
@@ -239,10 +239,7 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
String[] otherargs = lineage ? new String[]
{"-lineage", "reuse_full"} : null;
- t1 = startLocalFedWorkerThread(port1, otherargs);
- t2 = startLocalFedWorkerThread(port2, otherargs);
- t3 = startLocalFedWorkerThread(port3, otherargs);
- t4 = startLocalFedWorkerThread(port4, otherargs);
+ workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4}, otherargs);
FileFormatPropertiesCSV ffpCSV = new
FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
DataExpression.DEFAULT_DELIM_FILL,
DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ?
@@ -345,7 +342,7 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
throw new RuntimeException(ex);
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(rtold);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
index f144e03984..8c94e62692 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
@@ -127,7 +127,7 @@ public class TransformFederatedEncodeDecodeTest extends
AutomatedTestBase {
private void runTransformEncodeDecodeTest(boolean recode, boolean
sparse, Types.FileFormat format) {
ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE);
- Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ Thread[] workers = null;
try {
getAndLoadTestConfiguration(TEST_NAME_RECODE);
@@ -135,10 +135,7 @@ public class TransformFederatedEncodeDecodeTest extends
AutomatedTestBase {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
- t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
- t2 = startLocalFedWorkerThread(port2,
FED_WORKER_WAIT_S);
- t3 = startLocalFedWorkerThread(port3,
FED_WORKER_WAIT_S);
- t4 = startLocalFedWorkerThread(port4);
+ workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4}, null, FED_WORKER_WAIT);
// schema
Types.ValueType[] schema = new Types.ValueType[cols /
2];
@@ -205,7 +202,7 @@ public class TransformFederatedEncodeDecodeTest extends
AutomatedTestBase {
Assert.fail(ex.getMessage());
}
finally {
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
resetExecMode(rtold);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
index ca133caa76..0c46cd68ea 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -103,8 +103,7 @@ public class FedFullReuseTest extends AutomatedTestBase {
int port2 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage", "reuse_full"};
Lineage.resetInternalState();
- Thread t1 = startLocalFedWorkerThread(port1, otherargs,
FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2}, otherargs, FED_WORKER_WAIT);
TestConfiguration config =
availableTestConfigurations.get(test);
loadTestConfiguration(config);
@@ -149,7 +148,7 @@ public class FedFullReuseTest extends AutomatedTestBase {
}
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
index f7f01f6f9e..0cf9d97271 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedUDFReuseTest.java
@@ -108,10 +108,7 @@ public class FedUDFReuseTest extends AutomatedTestBase {
int port4 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage", "reuse_full"};
Lineage.resetInternalState();
- Thread t1 = startLocalFedWorkerThread(port1, otherargs,
FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, otherargs,
FED_WORKER_WAIT_S);
- Thread t3 = startLocalFedWorkerThread(port3, otherargs,
FED_WORKER_WAIT_S);
- Thread t4 = startLocalFedWorkerThread(port4, otherargs);
+ Thread[] workers = startLocalFedWorkerThreads(new int[] {port1,
port2, port3, port4}, otherargs, FED_WORKER_WAIT);
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
@@ -146,7 +143,7 @@ public class FedUDFReuseTest extends AutomatedTestBase {
// assert reuse count
Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
+ TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
index 15f6a7978e..b8a6619667 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageFedReuseAlg.java
@@ -93,8 +93,7 @@ public class LineageFedReuseAlg extends AutomatedTestBase {
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage",
"reuse_full"};
- Thread t1 = startLocalFedWorkerThread(port1, otherargs,
FED_WORKER_WAIT_S);
- Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+ Thread[] workers = startLocalFedWorkerThreads(new int[]
{port1, port2}, otherargs, FED_WORKER_WAIT);
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
@@ -136,7 +135,7 @@ public class LineageFedReuseAlg extends AutomatedTestBase {
assertTrue(mmCount > mmCount_reuse);
assertTrue(fed_mmCount > fed_mmCount_reuse);
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(workers);
}
finally {
resetExecMode(oldExec);