This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 914b8f8966879c274ca30130a24d502c08f59b6c Author: baunsgaard <baunsga...@tugraz.at> AuthorDate: Sat Nov 14 10:29:08 2020 +0100 [MINOR] Federated Modifications Major reduction in federated tests, by reducing startup time of federated tests with multiple workers. Furthermore a timeout is added to functions tests allowing only 60 minutes of execution time before being forcefully terminated. This reduces the waiting time for feedback of tests that would time out anyway after 6 hours. Isolate Function Test in workflows, and stabilize negative federated test, and reduce Federated Kmeans Tests. Privacy monitor: added a null pointer check for the case where the object on the federated site becomes null. This error would result in stack traces that were hard to debug. Fix :bug: in federated right indexing if the indexing aligns to a split between locations. --- .github/workflows/functionsTests.yml | 4 +- .../fed/MatrixIndexingFEDInstruction.java | 4 +- .../sysds/runtime/privacy/PrivacyMonitor.java | 2 + .../org/apache/sysds/test/AutomatedTestBase.java | 15 ++++- .../federated/algorithms/FederatedBivarTest.java | 6 +- .../federated/algorithms/FederatedCorTest.java | 6 +- .../federated/algorithms/FederatedGLMTest.java | 2 +- .../federated/algorithms/FederatedKmeansTest.java | 22 ++++--- .../federated/algorithms/FederatedL2SVMTest.java | 2 +- .../federated/algorithms/FederatedLogRegTest.java | 2 +- .../federated/algorithms/FederatedPCATest.java | 6 +- .../federated/algorithms/FederatedUnivarTest.java | 6 +- .../federated/algorithms/FederatedVarTest.java | 6 +- .../federated/algorithms/FederatedYL2SVMTest.java | 2 +- .../federated/io/FederatedReaderTest.java | 2 +- .../functions/federated/io/FederatedSSLTest.java | 2 +- .../federated/io/FederatedWriterTest.java | 2 +- .../paramserv/FederatedParamservTest.java | 63 ++++++++++-------- .../primitives/FederatedBinaryMatrixTest.java | 2 +- .../primitives/FederatedBinaryVectorTest.java | 2 +- 
.../primitives/FederatedCastToFrameTest.java | 2 +- .../primitives/FederatedCastToMatrixTest.java | 2 +- .../primitives/FederatedCentralMomentTest.java | 8 +-- ...ateTest.java => FederatedColAggregateTest.java} | 74 ++++++--------------- .../primitives/FederatedFullAggregateTest.java | 8 ++- .../primitives/FederatedMultiplyTest.java | 2 +- .../primitives/FederatedNegativeTest.java | 28 +++++--- .../federated/primitives/FederatedRCBindTest.java | 4 +- .../primitives/FederatedRightIndexTest.java | 41 ++++++------ ...ateTest.java => FederatedRowAggregateTest.java} | 75 ++++++---------------- .../federated/primitives/FederatedSplitTest.java | 4 +- .../primitives/FederatedStatisticsTest.java | 2 +- .../TransformFederatedEncodeApplyTest.java | 6 +- .../TransformFederatedEncodeDecodeTest.java | 6 +- 34 files changed, 192 insertions(+), 228 deletions(-) diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml index a094652..c816245 100644 --- a/.github/workflows/functionsTests.yml +++ b/.github/workflows/functionsTests.yml @@ -32,13 +32,15 @@ on: jobs: applicationsTests: runs-on: ${{ matrix.os }} + timeout-minutes: 60 strategy: fail-fast: false matrix: tests: [ "**.functions.aggregate.**,**.functions.append.**,**.functions.binary.frame.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**", "**.functions.blocks.**,**.functions.compress.**,**.functions.countDistinct.**,**.functions.data.misc.**,**.functions.data.rand.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**", - "**.functions.federated.**,**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**", + "**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**", + "**.functions.federated.**", "**.functions.codegenalg.partone.**", "**.functions.builtin.**", 
"**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**", diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java index bc2c066..5c0a821 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MatrixIndexingFEDInstruction.java @@ -77,9 +77,9 @@ public final class MatrixIndexingFEDInstruction extends IndexingFEDInstruction { curFedRange.setBeginDim(0, Math.max(rs - ixrange.rowStart, 0)); curFedRange.setBeginDim(1, Math.max(cs - ixrange.colStart, 0)); curFedRange.setEndDim(0, - (ixrange.rowEnd > re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1)); + (ixrange.rowEnd >= re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1)); curFedRange.setEndDim(1, - (ixrange.colEnd > ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1)); + (ixrange.colEnd >= ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1)); if(LOG.isDebugEnabled()) { LOG.debug("Fed Mapping After : " + curFedRange); } diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java index 4e286d0..97ac22b 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyMonitor.java @@ -65,6 +65,8 @@ public class PrivacyMonitor * @return data object or data object with privacy constraint removed in case the privacy level was none. 
*/ public static Data handlePrivacy(Data dataObject){ + if(dataObject == null) + return null; PrivacyConstraint privacyConstraint = dataObject.getPrivacyConstraint(); if (privacyConstraint != null){ PrivacyLevel privacyLevel = privacyConstraint.getPrivacyLevel(); diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index c3a0c59..3c3471e 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -1421,6 +1421,19 @@ public abstract class AutomatedTestBase { * @return the thread associated with the worker. */ protected Thread startLocalFedWorkerThread(int port) { + return startLocalFedWorkerThread(port, FED_WORKER_WAIT); + } + + /** + * Start a thread for a worker. This will share the same JVM, so all static variables will be shared.! + * + * Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled. + * + * @param port Port to use + * @param sleep The amount of time to wait for the worker startup. in Milliseconds + * @return the thread associated with the worker. 
+ */ + protected Thread startLocalFedWorkerThread(int port, int sleep) { Thread t = null; String[] fedWorkArgs = {"-w", Integer.toString(port)}; ArrayList<String> args = new ArrayList<>(); @@ -1443,7 +1456,7 @@ public abstract class AutomatedTestBase { } }); t.start(); - java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT); + java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep); } catch(InterruptedException e) { e.printStackTrace(); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java index ced8bca..ff811e0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java @@ -114,9 +114,9 @@ public class FederatedBivarTest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java index 15383b2..82437b1 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java @@ -102,9 +102,9 @@ public class FederatedCorTest extends AutomatedTestBase { int port2 = 
getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); rtplatform = execMode; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java index 1e608ce..eb8aee8 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java @@ -95,7 +95,7 @@ public class FederatedGLMTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java index 8a33d20..f296b3a 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java @@ -19,10 +19,8 @@ package org.apache.sysds.test.functions.federated.algorithms; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import java.util.Arrays; +import java.util.Collection; import 
org.apache.sysds.common.Types; import org.apache.sysds.common.Types.ExecMode; @@ -33,9 +31,11 @@ import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; - -import java.util.Arrays; -import java.util.Collection; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe @@ -64,9 +64,10 @@ public class FederatedKmeansTest extends AutomatedTestBase { @Parameterized.Parameters public static Collection<Object[]> data() { // rows have to be even and > 1 - return Arrays.asList(new Object[][] {{10000, 10, 1, 1}, + return Arrays.asList(new Object[][] { + // {10000, 10, 1, 1}, // {2000, 50, 1, 1}, {1000, 100, 1, 1}, - {10000, 10, 2, 1}, + // {10000, 10, 2, 1}, // {2000, 50, 2, 1}, {1000, 100, 2, 1}, //concurrent requests {10000, 10, 2, 2}, // repeated exec // TODO more runs e.g., 16 -> but requires rework RPC framework first @@ -80,6 +81,7 @@ public class FederatedKmeansTest extends AutomatedTestBase { } @Test + @Ignore public void federatedKmeansHybrid() { federatedKmeans(Types.ExecMode.HYBRID); } @@ -102,7 +104,7 @@ public class FederatedKmeansTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java index 53bfc8d..f17754e 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java @@ -99,7 +99,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java index 42c614b..e7f1f80 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java @@ -95,7 +95,7 @@ public class FederatedLogRegTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java index 99c90ee..8438bb6 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java @@ -102,9 +102,9 @@ public class FederatedPCATest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = 
getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java index 588796a..7333533 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java @@ -100,9 +100,9 @@ public class FederatedUnivarTest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java index 280f0d3..46af1c9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java @@ -109,9 +109,9 @@ public class FederatedVarTest extends AutomatedTestBase { int port2 = 
getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); rtplatform = execMode; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java index 0657e50..d0eaf87 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedYL2SVMTest.java @@ -104,7 +104,7 @@ public class FederatedYL2SVMTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java index d4dc464..2587fe9 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java @@ -87,7 +87,7 @@ public class FederatedReaderTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = 
startLocalFedWorkerThread(port2); String host = "localhost"; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java index 6ec2f40..d086174 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java @@ -93,7 +93,7 @@ public class FederatedSSLTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); String host = "localhost"; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java index c8a50fe..a83fad3 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java @@ -83,7 +83,7 @@ public class FederatedWriterTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); try { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index c4d04ea..3015aaa 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -37,7 +37,6 @@ import org.junit.Test; import 
org.junit.runner.RunWith; import org.junit.runners.Parameterized; - @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FederatedParamservTest extends AutomatedTestBase { @@ -60,15 +59,12 @@ public class FederatedParamservTest extends AutomatedTestBase { @Parameterized.Parameters public static Collection<Object[]> parameters() { return Arrays.asList(new Object[][] { - //Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, + // Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update + // type, update frequency + {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, + {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, + {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"}, // {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"}, // {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"}, @@ -80,7 +76,8 @@ public class FederatedParamservTest extends AutomatedTestBase { }); } - public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) { + public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, + int epochs, double eta, String utype, String freq) { 
_networkType = networkType; _numFederatedWorkers = numFederatedWorkers; _examplesPerWorker = examplesPerWorker; @@ -101,12 +98,12 @@ public class FederatedParamservTest extends AutomatedTestBase { public void federatedParamservSingleNode() { federatedParamserv(ExecMode.SINGLE_NODE); } - + @Test public void federatedParamservHybrid() { federatedParamserv(ExecMode.HYBRID); } - + private void federatedParamserv(ExecMode mode) { // config getAndLoadTestConfiguration(TEST_NAME); @@ -114,18 +111,17 @@ public class FederatedParamservTest extends AutomatedTestBase { setOutputBuffering(true); int C = 1, Hin = 28, Win = 28; - int numFeatures = C*Hin*Win; + int numFeatures = C * Hin * Win; int numLabels = 10; ExecMode platformOld = setExecMode(mode); - + try { - + // dml name fullDMLScriptName = HOME + TEST_NAME + ".dml"; // generate program args - List<String> programArgsList = new ArrayList<>(Arrays.asList( - "-stats", + List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats", "-nvargs", "examples_per_worker=" + _examplesPerWorker, "num_features=" + numFeatures, @@ -138,28 +134,39 @@ public class FederatedParamservTest extends AutomatedTestBase { "network_type=" + _networkType, "channels=" + C, "hin=" + Hin, - "win=" + Win - )); - + "win=" + Win)); + // for each worker List<Integer> ports = new ArrayList<>(); List<Thread> threads = new ArrayList<>(); for(int i = 0; i < _numFederatedWorkers; i++) { // write row partitioned features to disk - writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false, - new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures)); + writeInputMatrixWithMTD("X" + i, + generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), + false, + new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, + _examplesPerWorker * numFeatures)); // write row partitioned labels to disk - writeInputMatrixWithMTD("y" + i, 
generateDummyMNISTLabels(_examplesPerWorker, numLabels), false, - new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker * numLabels)); - + writeInputMatrixWithMTD("y" + i, + generateDummyMNISTLabels(_examplesPerWorker, numLabels), + false, + new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, + _examplesPerWorker * numLabels)); + // start worker ports.add(getRandomAvailablePort()); - threads.add(startLocalFedWorkerThread(ports.get(i))); - + threads.add(startLocalFedWorkerThread(ports.get(i), 10)); + // add worker to program args programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i))); programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i))); } + try { + Thread.sleep(1000); + } + catch(InterruptedException e) { + e.printStackTrace(); + } programArgs = programArgsList.toArray(new String[0]); LOG.debug(runTest(null)); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java index 11f2bd4..958c09b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java @@ -95,7 +95,7 @@ public class FederatedBinaryMatrixTest extends AutomatedTestBase { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java index 
f11b7be..e8dd6f7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java @@ -96,7 +96,7 @@ public class FederatedBinaryVectorTest extends AutomatedTestBase { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java index b67cc93..fe03906 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java @@ -97,7 +97,7 @@ public class FederatedCastToFrameTest extends AutomatedTestBase { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java index 57ffacf..fa51d89 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java @@ -126,7 +126,7 @@ public class FederatedCastToMatrixTest extends AutomatedTestBase { int port1 = 
getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java index 4d644ce..828718e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java @@ -98,10 +98,10 @@ public class FederatedCentralMomentTest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); - Thread t4 = startLocalFedWorkerThread(port4); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); + Thread t4 = startLocalFedWorkerThread(port4); // reference file should not be written to hdfs, so we set platform here rtplatform = execMode; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java similarity index 70% copy from src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java copy to src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java index 31800af..a8480e9 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java @@ -36,20 +36,15 @@ import org.junit.runners.Parameterized; @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe -public class FederatedRowColAggregateTest extends AutomatedTestBase { +public class FederatedColAggregateTest extends AutomatedTestBase { private final static String TEST_NAME1 = "FederatedColSumTest"; private final static String TEST_NAME2 = "FederatedColMeanTest"; private final static String TEST_NAME3 = "FederatedColMaxTest"; private final static String TEST_NAME4 = "FederatedColMinTest"; - private final static String TEST_NAME5 = "FederatedRowSumTest"; - private final static String TEST_NAME6 = "FederatedRowMeanTest"; - private final static String TEST_NAME7 = "FederatedRowMaxTest"; - private final static String TEST_NAME8 = "FederatedRowMinTest"; - private final static String TEST_NAME9 = "FederatedRowVarTest"; private final static String TEST_NAME10 = "FederatedColVarTest"; private final static String TEST_DIR = "functions/federated/aggregate/"; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRowColAggregateTest.class.getSimpleName() + "/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedColAggregateTest.class.getSimpleName() + "/"; private final static int blocksize = 1024; @Parameterized.Parameter() @@ -72,10 +67,6 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { SUM, MEAN, MAX, MIN, VAR } - private enum InstType { - ROW, COL - } - @Override public void setUp() { TestUtils.clearAssertionInformation(); @@ -83,65 +74,36 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"})); addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"})); addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"})); - addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S"})); - addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {"S"})); - addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {"S"})); - addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {"S"})); - addTestConfiguration(TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {"S"})); addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"})); } @Test public void testColSumDenseMatrixCP() { - runAggregateOperationTest(OpType.SUM, InstType.COL, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.SUM, ExecMode.SINGLE_NODE); } @Test public void testColMeanDenseMatrixCP() { - runAggregateOperationTest(OpType.MEAN, InstType.COL, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.MEAN, ExecMode.SINGLE_NODE); } @Test public void testColMaxDenseMatrixCP() { - runAggregateOperationTest(OpType.MAX, InstType.COL, ExecMode.SINGLE_NODE); - } - - @Test - public void testRowSumDenseMatrixCP() { - runAggregateOperationTest(OpType.SUM, InstType.ROW, ExecMode.SINGLE_NODE); - } - - @Test - public void testRowMeanDenseMatrixCP() { - runAggregateOperationTest(OpType.MEAN, InstType.ROW, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.MAX, ExecMode.SINGLE_NODE); } - @Test - public void testRowMaxDenseMatrixCP() { - runAggregateOperationTest(OpType.MAX, InstType.ROW, ExecMode.SINGLE_NODE); - } - - @Test - public void testRowMinDenseMatrixCP() { - runAggregateOperationTest(OpType.MIN, InstType.ROW, ExecMode.SINGLE_NODE); - } @Test public void testColMinDenseMatrixCP() { - 
runAggregateOperationTest(OpType.MIN, InstType.COL, ExecMode.SINGLE_NODE); - } - - @Test - public void testRowVarDenseMatrixCP() { - runAggregateOperationTest(OpType.VAR, InstType.ROW, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.MIN, ExecMode.SINGLE_NODE); } @Test public void testColVarDenseMatrixCP() { - runAggregateOperationTest(OpType.VAR, InstType.COL, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.VAR, ExecMode.SINGLE_NODE); } - private void runAggregateOperationTest(OpType type, InstType instr, ExecMode execMode) { + private void runAggregateOperationTest(OpType type, ExecMode execMode) { boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; ExecMode platformOld = rtplatform; @@ -151,19 +113,19 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { String TEST_NAME = null; switch(type) { case SUM: - TEST_NAME = instr == InstType.COL ? TEST_NAME1 : TEST_NAME5; + TEST_NAME = TEST_NAME1; break; case MEAN: - TEST_NAME = instr == InstType.COL ? TEST_NAME2 : TEST_NAME6; + TEST_NAME = TEST_NAME2; break; case MAX: - TEST_NAME = instr == InstType.COL ? TEST_NAME3 : TEST_NAME7; + TEST_NAME = TEST_NAME3; break; case MIN: - TEST_NAME = instr == InstType.COL ? TEST_NAME4 : TEST_NAME8; + TEST_NAME = TEST_NAME4; break; case VAR: - TEST_NAME = instr == InstType.COL ? 
TEST_NAME10 : TEST_NAME9; + TEST_NAME = TEST_NAME10; break; } @@ -195,9 +157,9 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); rtplatform = execMode; @@ -227,9 +189,9 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { runTest(true, false, null, -1); // compare via files - compareResults(type == FederatedRowColAggregateTest.OpType.VAR ? 1e-2 : 1e-9); + compareResults(type == FederatedColAggregateTest.OpType.VAR ? 1e-2 : 1e-9); - String fedInst = instr == InstType.COL ? "fed_uac" : "fed_uar"; + String fedInst = "fed_uac"; switch(type) { case SUM: diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java index ec7bda6..d388913 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java @@ -108,6 +108,7 @@ public class FederatedFullAggregateTest extends AutomatedTestBase { } @Test + @Ignore public void testSumDenseMatrixSP() { runColAggregateOperationTest(OpType.SUM, ExecType.SPARK); } @@ -131,6 +132,7 @@ public class FederatedFullAggregateTest extends AutomatedTestBase { } @Test + @Ignore public void testVarDenseMatrixSP() { runColAggregateOperationTest(OpType.VAR, ExecType.SPARK); } @@ -196,9 +198,9 @@ public class FederatedFullAggregateTest extends 
AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java index 4170914..3bc2649 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java @@ -103,7 +103,7 @@ public class FederatedMultiplyTest extends AutomatedTestBase { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java index 2ebe0c8..59dfcff 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedNegativeTest.java @@ -19,27 +19,37 @@ package org.apache.sysds.test.functions.federated.primitives; -import org.apache.sysds.common.Types; -import org.apache.sysds.runtime.controlprogram.federated.*; -import 
org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.test.TestUtils; -import org.junit.Test; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import java.net.InetSocketAddress; import java.util.HashMap; import java.util.Map; import java.util.concurrent.Future; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRange; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederationMap; +import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; @net.jcip.annotations.NotThreadSafe public class FederatedNegativeTest { @Test public void NegativeTest1() { int port = AutomatedTestBase.getRandomAvailablePort(); - String[] args = {"-w", Integer.toString(port)}; - Thread t = AutomatedTestBase.startLocalFedWorkerWithArgs(args); + Thread t = null; + try{ + String[] args = {"-w", Integer.toString(port)}; + t = AutomatedTestBase.startLocalFedWorkerWithArgs(args); + } catch(Exception e){ + NegativeTest1(); + } FederationUtils.resetFedDataID(); //ensure expected ID when tests run in single JVM Map<FederatedRange, FederatedData> fedMap = new HashMap<>(); FederatedRange r = new FederatedRange(new long[]{0,0}, new long[]{1,1}); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java index abf37eb..540b188 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java @@ -93,8 +93,8 @@ public class FederatedRCBindTest extends AutomatedTestBase { writeInputMatrixWithMTD("B", B, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols)); int port1 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - int port2 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); // we need the reference file to not be written to hdfs, so we get the correct format diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java index d5f81e9..b9e7f62 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java @@ -65,14 +65,7 @@ public class FederatedRightIndexTest extends AutomatedTestBase { @Parameterized.Parameters public static Collection<Object[]> data() { - return Arrays.asList(new Object[][] { - // {20, 10, 6, 8, true}, - {20, 10, 1, 1, true}, - {20, 10, 2, 10, true}, - // {20, 10, 2, 10, true}, - // {20, 12, 2, 10, false}, - // {20, 12, 1, 4, false} - }); + return Arrays.asList(new Object[][] {{20, 10, 1, 1, true}, {20, 10, 3, 5, true}, {10, 12, 1, 10, false}}); } private enum IndexType { @@ -87,15 +80,15 @@ public class FederatedRightIndexTest extends AutomatedTestBase { addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"})); } - @Test - public void testRightIndexRightDenseMatrixCP() { - runAggregateOperationTest(IndexType.RIGHT, ExecMode.SINGLE_NODE); - } + // @Test + // 
public void testRightIndexRightDenseMatrixCP() { + // runAggregateOperationTest(IndexType.RIGHT, ExecMode.SINGLE_NODE); + // } - @Test - public void testRightIndexLeftDenseMatrixCP() { - runAggregateOperationTest(IndexType.LEFT, ExecMode.SINGLE_NODE); - } + // @Test + // public void testRightIndexLeftDenseMatrixCP() { + // runAggregateOperationTest(IndexType.LEFT, ExecMode.SINGLE_NODE); + // } @Test public void testRightIndexFullDenseMatrixCP() { @@ -112,13 +105,19 @@ public class FederatedRightIndexTest extends AutomatedTestBase { String TEST_NAME = null; switch(type) { case RIGHT: + from = from <= cols ? from : cols; + to = to <= cols ? to : cols; TEST_NAME = TEST_NAME1; break; case LEFT: + from = from <= rows ? from : rows; + to = to <= rows ? to : rows; TEST_NAME = TEST_NAME2; break; case FULL: TEST_NAME = TEST_NAME3; + from = from <= rows && from <= cols ? from : Math.min(rows, cols); + to = to <= rows && to <= cols ? to : Math.min(rows, cols); break; } @@ -150,9 +149,9 @@ public class FederatedRightIndexTest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); rtplatform = execMode; @@ -163,6 +162,10 @@ public class FederatedRightIndexTest extends AutomatedTestBase { TestConfiguration config = availableTestConfigurations.get(TEST_NAME); loadTestConfiguration(config); + if(from > to) { + from = to; + } + // Run reference dml script with normal matrix fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"), String.valueOf(from), diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java similarity index 70% rename from src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java rename to src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java index 31800af..49e692e 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java @@ -36,20 +36,15 @@ import org.junit.runners.Parameterized; @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe -public class FederatedRowColAggregateTest extends AutomatedTestBase { - private final static String TEST_NAME1 = "FederatedColSumTest"; - private final static String TEST_NAME2 = "FederatedColMeanTest"; - private final static String TEST_NAME3 = "FederatedColMaxTest"; - private final static String TEST_NAME4 = "FederatedColMinTest"; +public class FederatedRowAggregateTest extends AutomatedTestBase { private final static String TEST_NAME5 = "FederatedRowSumTest"; private final static String TEST_NAME6 = "FederatedRowMeanTest"; private final static String TEST_NAME7 = "FederatedRowMaxTest"; private final static String TEST_NAME8 = "FederatedRowMinTest"; private final static String TEST_NAME9 = "FederatedRowVarTest"; - private final static String TEST_NAME10 = "FederatedColVarTest"; private final static String TEST_DIR = "functions/federated/aggregate/"; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRowColAggregateTest.class.getSimpleName() + "/"; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRowAggregateTest.class.getSimpleName() + "/"; private final static int blocksize = 1024; @Parameterized.Parameter() @@ -72,76 
+67,42 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { SUM, MEAN, MAX, MIN, VAR } - private enum InstType { - ROW, COL - } - @Override public void setUp() { TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"})); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"})); - addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"})); - addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"})); addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S"})); addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {"S"})); addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {"S"})); addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {"S"})); addTestConfiguration(TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {"S"})); - addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"})); - } - - @Test - public void testColSumDenseMatrixCP() { - runAggregateOperationTest(OpType.SUM, InstType.COL, ExecMode.SINGLE_NODE); - } - - @Test - public void testColMeanDenseMatrixCP() { - runAggregateOperationTest(OpType.MEAN, InstType.COL, ExecMode.SINGLE_NODE); - } - - @Test - public void testColMaxDenseMatrixCP() { - runAggregateOperationTest(OpType.MAX, InstType.COL, ExecMode.SINGLE_NODE); } @Test public void testRowSumDenseMatrixCP() { - runAggregateOperationTest(OpType.SUM, InstType.ROW, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.SUM, ExecMode.SINGLE_NODE); } @Test public void testRowMeanDenseMatrixCP() { - runAggregateOperationTest(OpType.MEAN, InstType.ROW, ExecMode.SINGLE_NODE); 
+ runAggregateOperationTest(OpType.MEAN, ExecMode.SINGLE_NODE); } @Test public void testRowMaxDenseMatrixCP() { - runAggregateOperationTest(OpType.MAX, InstType.ROW, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.MAX, ExecMode.SINGLE_NODE); } @Test public void testRowMinDenseMatrixCP() { - runAggregateOperationTest(OpType.MIN, InstType.ROW, ExecMode.SINGLE_NODE); - } - - @Test - public void testColMinDenseMatrixCP() { - runAggregateOperationTest(OpType.MIN, InstType.COL, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.MIN, ExecMode.SINGLE_NODE); } @Test public void testRowVarDenseMatrixCP() { - runAggregateOperationTest(OpType.VAR, InstType.ROW, ExecMode.SINGLE_NODE); - } - - @Test - public void testColVarDenseMatrixCP() { - runAggregateOperationTest(OpType.VAR, InstType.COL, ExecMode.SINGLE_NODE); + runAggregateOperationTest(OpType.VAR, ExecMode.SINGLE_NODE); } - private void runAggregateOperationTest(OpType type, InstType instr, ExecMode execMode) { + private void runAggregateOperationTest(OpType type, ExecMode execMode) { boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; ExecMode platformOld = rtplatform; @@ -151,19 +112,19 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { String TEST_NAME = null; switch(type) { case SUM: - TEST_NAME = instr == InstType.COL ? TEST_NAME1 : TEST_NAME5; + TEST_NAME = TEST_NAME5; break; case MEAN: - TEST_NAME = instr == InstType.COL ? TEST_NAME2 : TEST_NAME6; + TEST_NAME = TEST_NAME6; break; case MAX: - TEST_NAME = instr == InstType.COL ? TEST_NAME3 : TEST_NAME7; + TEST_NAME = TEST_NAME7; break; case MIN: - TEST_NAME = instr == InstType.COL ? TEST_NAME4 : TEST_NAME8; + TEST_NAME = TEST_NAME8; break; case VAR: - TEST_NAME = instr == InstType.COL ? 
TEST_NAME10 : TEST_NAME9; + TEST_NAME = TEST_NAME9; break; } @@ -195,9 +156,9 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { int port2 = getRandomAvailablePort(); int port3 = getRandomAvailablePort(); int port4 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - Thread t3 = startLocalFedWorkerThread(port3); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2, 10); + Thread t3 = startLocalFedWorkerThread(port3, 10); Thread t4 = startLocalFedWorkerThread(port4); rtplatform = execMode; @@ -227,9 +188,9 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase { runTest(true, false, null, -1); // compare via files - compareResults(type == FederatedRowColAggregateTest.OpType.VAR ? 1e-2 : 1e-9); + compareResults(type == FederatedRowAggregateTest.OpType.VAR ? 1e-2 : 1e-9); - String fedInst = instr == InstType.COL ? "fed_uac" : "fed_uar"; + String fedInst = "fed_uar"; switch(type) { case SUM: diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java index 9c4b6d0..9d37aff 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java @@ -98,8 +98,8 @@ public class FederatedSplitTest extends AutomatedTestBase { int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); + Thread t1 = startLocalFedWorkerThread(port1, 10); + Thread t2 = startLocalFedWorkerThread(port2); // Run reference dml script with normal matrix fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java index 99e649f..865582d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java @@ -99,7 +99,7 @@ public class FederatedStatisticsTest extends AutomatedTestBase { fullDMLScriptName = ""; int port1 = getRandomAvailablePort(); int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); + Thread t1 = startLocalFedWorkerThread(port1, 10); Thread t2 = startLocalFedWorkerThread(port2); TestConfiguration config = availableTestConfigurations.get(TEST_NAME); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java index 3aa0981..b7036d0 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java @@ -196,12 +196,12 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase { getAndLoadTestConfiguration(TEST_NAME1); int port1 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1); int port2 = getRandomAvailablePort(); - t2 = startLocalFedWorkerThread(port2); int port3 = getRandomAvailablePort(); - t3 = startLocalFedWorkerThread(port3); int port4 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1, 10); + t2 = startLocalFedWorkerThread(port2, 10); + t3 = startLocalFedWorkerThread(port3, 10); t4 = startLocalFedWorkerThread(port4); FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, 
DataExpression.DEFAULT_DELIM_DELIMITER, diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java index 0c8ec1f..458dbc1 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java @@ -131,12 +131,12 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { getAndLoadTestConfiguration(TEST_NAME_RECODE); int port1 = getRandomAvailablePort(); - t1 = startLocalFedWorkerThread(port1); int port2 = getRandomAvailablePort(); - t2 = startLocalFedWorkerThread(port2); int port3 = getRandomAvailablePort(); - t3 = startLocalFedWorkerThread(port3); int port4 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1, 10); + t2 = startLocalFedWorkerThread(port2, 10); + t3 = startLocalFedWorkerThread(port3, 10); t4 = startLocalFedWorkerThread(port4); // schema