This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 3d97048e457a00af9288a9940f962392df3abcbc Author: Matthias Boehm <[email protected]> AuthorDate: Sat Oct 31 18:05:54 2020 +0100 [SYSTEMDS-2679] Fix dim size propagation for federated data ops This patch fixes two issues that were encountered in federated parameter servers. First, we fixed null pointer exceptions in the cleanup of parallelized RDDs for the case of redundant cleanups. Second, we fixed the size propagation for federated DataOps which so far was never refreshed during recompilation after initial parsing. Each of these fixes alone would already solve the reported bug, but the fixed size propagation is important for all federated use cases in default hybrid execution mode. --- src/main/java/org/apache/sysds/hops/DataOp.java | 29 +- src/main/java/org/apache/sysds/hops/Hop.java | 4 + .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 8 +- src/main/java/org/apache/sysds/lops/Federated.java | 3 +- .../context/SparkExecutionContext.java | 4 +- .../federated/io/FederatedReaderTest.java | 172 +++++------ .../federated/io/FederatedWriterTest.java | 192 ++++++------ .../paramserv/FederatedParamservTest.java | 325 +++++++++++---------- 8 files changed, 382 insertions(+), 355 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java index 0046078..114006f 100644 --- a/src/main/java/org/apache/sysds/hops/DataOp.java +++ b/src/main/java/org/apache/sysds/hops/DataOp.java @@ -30,6 +30,7 @@ import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.lops.Data; import org.apache.sysds.lops.Federated; import org.apache.sysds.lops.Lop; @@ -37,6 +38,7 @@ import org.apache.sysds.lops.LopProperties.ExecType; import org.apache.sysds.lops.LopsException; import org.apache.sysds.lops.Sql; 
import org.apache.sysds.parser.DataExpression; +import static org.apache.sysds.parser.DataExpression.FED_RANGES; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.util.LocalFileUtils; @@ -270,8 +272,7 @@ public class DataOp extends Hop { // construct lops for all input parameters HashMap<String, Lop> inputLops = new HashMap<>(); for (Entry<String, Integer> cur : _paramIndexMap.entrySet()) { - inputLops.put(cur.getKey(), getInput().get(cur.getValue()) - .constructLops()); + inputLops.put(cur.getKey(), getInput().get(cur.getValue()).constructLops()); } // Create the lop @@ -488,21 +489,30 @@ public class DataOp extends Hop { } @Override - public void refreshSizeInformation() - { - if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE ) - { + public void refreshSizeInformation() { + if( _op == OpOpData.PERSISTENTWRITE || _op == OpOpData.TRANSIENTWRITE ) { Hop input1 = getInput().get(0); setDim1(input1.getDim1()); setDim2(input1.getDim2()); setNnz(input1.getNnz()); } - else //READ - { + else if( _op == OpOpData.FEDERATED ) { + Hop ranges = getInput().get(getParameterIndex(FED_RANGES)); + long nrow = -1, ncol = -1; + for( Hop c : ranges.getInput() ) { + if( !(c.getInput(0) instanceof LiteralOp && c.getInput(1) instanceof LiteralOp)) + return; // invalid size inference if not all know. 
+ nrow = Math.max(nrow, HopRewriteUtils.getIntValueSafe(c.getInput(0))); + ncol = Math.max(ncol, HopRewriteUtils.getIntValueSafe(c.getInput(1))); + } + setDim1(nrow); + setDim2(ncol); + } + else { //READ //do nothing; dimensions updated via set output params } } - + /** * Explicitly disables recompilation of transient reads, this additional information @@ -590,5 +600,4 @@ public class DataOp extends Hop { } } } - } diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 60a4dc5..bb0960d 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -758,6 +758,10 @@ public abstract class Hop implements ParseInfo { return _input; } + public Hop getInput(int ix) { + return _input.get(ix); + } + public void addInput( Hop h ) { _input.add(h); h._parent.add(this); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index 1815d69..b1a8799 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -162,10 +162,14 @@ public class HopRewriteUtils } } + public static long getIntValueSafe( Hop op ) { + return getIntValueSafe((LiteralOp) op); + } + public static long getIntValueSafe( LiteralOp op ) { switch( op.getValueType() ) { - case FP64: return UtilFunctions.toLong(op.getDoubleValue()); - case INT64: return op.getLongValue(); + case FP64: return UtilFunctions.toLong(op.getDoubleValue()); + case INT64: return op.getLongValue(); case BOOLEAN: return op.getBooleanValue() ? 
1 : 0; default: return Long.MAX_VALUE; } diff --git a/src/main/java/org/apache/sysds/lops/Federated.java b/src/main/java/org/apache/sysds/lops/Federated.java index 8aacbd7..52b52be 100644 --- a/src/main/java/org/apache/sysds/lops/Federated.java +++ b/src/main/java/org/apache/sysds/lops/Federated.java @@ -63,7 +63,6 @@ public class Federated extends Lop { @Override public String toString() { - // TODO Federated.toString() lop - return null; + return "FedInit"; } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java index 2be647d..41ac510 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java @@ -1847,8 +1847,8 @@ public class SparkExecutionContext extends ExecutionContext } public synchronized void deregisterRDD(int rddID) { - long rddSize = _rdds.remove(rddID); - _size -= rddSize; + Long rddSize = _rdds.remove(rddID); + _size -= (rddSize!=null) ? 
rddSize : 0; } public synchronized void clear() { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java index a8e4407..c14ac1d 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java @@ -38,98 +38,98 @@ import org.junit.runners.Parameterized; @net.jcip.annotations.NotThreadSafe public class FederatedReaderTest extends AutomatedTestBase { - // private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName()); - private final static String TEST_DIR = "functions/federated/ioR/"; - private final static String TEST_NAME = "FederatedReaderTest"; - private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/"; - private final static int blocksize = 1024; - @Parameterized.Parameter() - public int rows; - @Parameterized.Parameter(1) - public int cols; - @Parameterized.Parameter(2) - public boolean rowPartitioned; - @Parameterized.Parameter(3) - public int fedCount; + // private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName()); + private final static String TEST_DIR = "functions/federated/ioR/"; + private final static String TEST_NAME = "FederatedReaderTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/"; + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + @Parameterized.Parameter(2) + public boolean rowPartitioned; + @Parameterized.Parameter(3) + public int fedCount; - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); - } + @Override + public void setUp() { + 
TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } - @Parameterized.Parameters - public static Collection<Object[]> data() { - // number of rows or cols has to be >= number of federated locations. - return Arrays.asList(new Object[][] {{10, 13, true, 2},}); - } + @Parameterized.Parameters + public static Collection<Object[]> data() { + // number of rows or cols has to be >= number of federated locations. + return Arrays.asList(new Object[][] {{10, 13, true, 2},}); + } - @Test - public void federatedSinglenodeRead() { - federatedRead(Types.ExecMode.SINGLE_NODE); - } + @Test + public void federatedSinglenodeRead() { + federatedRead(Types.ExecMode.SINGLE_NODE); + } - public void federatedRead(Types.ExecMode execMode) { - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - Types.ExecMode platformOld = rtplatform; - rtplatform = execMode; - if(rtplatform == Types.ExecMode.SPARK) { - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - } - getAndLoadTestConfiguration(TEST_NAME); + public void federatedRead(Types.ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + getAndLoadTestConfiguration(TEST_NAME); - // write input matrices - int halfRows = rows / 2; - long[][] begins = new long[][] {new long[] {0, 0}, new long[] {halfRows, 0}}; - long[][] ends = new long[][] {new long[] {halfRows, cols}, new long[] {rows, cols}}; - // We have two matrices handled by a single federated worker - double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); - double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); - writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); - writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, 
blocksize, halfRows * cols)); - // empty script name because we don't execute any script, just start the worker - fullDMLScriptName = ""; - int port1 = getRandomAvailablePort(); - int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - String host = "localhost"; + // write input matrices + int halfRows = rows / 2; + long[][] begins = new long[][] {new long[] {0, 0}, new long[] {halfRows, 0}}; + long[][] ends = new long[][] {new long[] {halfRows, cols}, new long[] {rows, cols}}; + // We have two matrices handled by a single federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1); + Thread t2 = startLocalFedWorkerThread(port2); + String host = "localhost"; - MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput(rows, - cols, - blocksize, - host, - begins, - ends, - new int[] {port1, port2}, - new String[] {input("X1"), input("X2")}, - input("X.json")); - writeInputFederatedWithMTD("X.json", fed, null); + MatrixObject fed = FederatedTestObjectConstructor.constructFederatedInput(rows, + cols, + blocksize, + host, + begins, + ends, + new int[] {port1, port2}, + new String[] {input("X1"), input("X2")}, + input("X.json")); + writeInputFederatedWithMTD("X.json", fed, null); - try { - // Run reference dml script with normal matrix - fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + (rowPartitioned ? 
"Row" : "Col") - + "Reference.dml"; - programArgs = new String[] {"-args", input("X1"), input("X2")}; - String refOut = runTest(null).toString(); - // Run federated - fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml"; - programArgs = new String[] {"-stats", "-args", input("X.json")}; - String out = runTest(null).toString(); - // LOG.error(out); - Assert.assertTrue(heavyHittersContainsString("fed_uak+")); - // Verify output - Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]), - Double.parseDouble(out.split("\n")[0]), - 0.00001); - } - catch(Exception e) { - e.printStackTrace(); - Assert.assertTrue(false); - } + try { + // Run reference dml script with normal matrix + fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + (rowPartitioned ? "Row" : "Col") + + "Reference.dml"; + programArgs = new String[] {"-args", input("X1"), input("X2")}; + String refOut = runTest(null).toString(); + // Run federated + fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X.json")}; + String out = runTest(null).toString(); + // LOG.error(out); + Assert.assertTrue(heavyHittersContainsString("fed_uak+")); + // Verify output + Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]), + Double.parseDouble(out.split("\n")[0]), + 0.00001); + } + catch(Exception e) { + e.printStackTrace(); + Assert.assertTrue(false); + } - TestUtils.shutdownThreads(t1, t2); - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java index ef92a67..e03474d 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java @@ -36,100 +36,100 @@ import org.junit.runners.Parameterized; @net.jcip.annotations.NotThreadSafe public class FederatedWriterTest extends AutomatedTestBase { - // private static final Log LOG = LogFactory.getLog(FederatedWriterTest.class.getName()); - private final static String TEST_DIR = "functions/federated/"; - private final static String TEST_NAME = "FederatedWriterTest"; - private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWriterTest.class.getSimpleName() + "/"; - private final static int blocksize = 1024; - - @Parameterized.Parameter() - public int rows; - @Parameterized.Parameter(1) - public int cols; - @Parameterized.Parameter(2) - public boolean rowPartitioned; - @Parameterized.Parameter(3) - public int fedCount; - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); - } - - @Parameterized.Parameters - public static Collection<Object[]> data() { - // number of rows or cols has to be >= number of federated locations. 
- return Arrays.asList(new Object[][] {{10, 13, true, 2},}); - } - - @Test - public void federatedSinglenodeWrite() { - federatedWrite(Types.ExecMode.SINGLE_NODE); - } - - public void federatedWrite(Types.ExecMode execMode) { - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - Types.ExecMode platformOld = rtplatform; - rtplatform = execMode; - if(rtplatform == Types.ExecMode.SPARK) { - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - } - getAndLoadTestConfiguration(TEST_NAME); - - // write input matrices - int halfRows = rows / 2; - // We have two matrices handled by a single federated worker - double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); - double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); - writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); - writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); - // empty script name because we don't execute any script, just start the worker - fullDMLScriptName = ""; - int port1 = getRandomAvailablePort(); - int port2 = getRandomAvailablePort(); - Thread t1 = startLocalFedWorkerThread(port1); - Thread t2 = startLocalFedWorkerThread(port2); - - try { - - // Run reader and write a federated json to enable the rest of the test - fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReaderTestCreate.dml"; - programArgs = new String[] {"-stats", "-explain", "-args", input("X1"), input("X2"), port1 + "", port2 + "", - input("X.json")}; - // String writer = runTest(null).toString(); - runTest(null); - // LOG.error(writer); - // LOG.error("Writing Done"); - - // Run reference dml script with normal matrix - fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReaderTest.dml"; - programArgs = new String[] {"-stats", "-args", input("X.json")}; - String out = runTest(null).toString(); - - Assert.assertTrue(heavyHittersContainsString("fed_uak+")); - - 
fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReference.dml"; - // programArgs = new String[] {"-args", input("X1"), input("X2")}; - programArgs = new String[] {"-stats", "100", "-nvargs", - "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), - "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols}; - String refOut = runTest(null).toString(); - - // Run federated - - // Verify output - Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]), - Double.parseDouble(out.split("\n")[0]), - 0.00001); - } - catch(Exception e) { - e.printStackTrace(); - Assert.assertTrue(false); - } - - TestUtils.shutdownThreads(t1, t2); - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } + // private static final Log LOG = LogFactory.getLog(FederatedWriterTest.class.getName()); + private final static String TEST_DIR = "functions/federated/"; + private final static String TEST_NAME = "FederatedWriterTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedWriterTest.class.getSimpleName() + "/"; + private final static int blocksize = 1024; + + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + @Parameterized.Parameter(2) + public boolean rowPartitioned; + @Parameterized.Parameter(3) + public int fedCount; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Parameterized.Parameters + public static Collection<Object[]> data() { + // number of rows or cols has to be >= number of federated locations. 
+ return Arrays.asList(new Object[][] {{10, 13, true, 2},}); + } + + @Test + public void federatedSinglenodeWrite() { + federatedWrite(Types.ExecMode.SINGLE_NODE); + } + + public void federatedWrite(Types.ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + getAndLoadTestConfiguration(TEST_NAME); + + // write input matrices + int halfRows = rows / 2; + // We have two matrices handled by a single federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + // empty script name because we don't execute any script, just start the worker + fullDMLScriptName = ""; + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1); + Thread t2 = startLocalFedWorkerThread(port2); + + try { + + // Run reader and write a federated json to enable the rest of the test + fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReaderTestCreate.dml"; + programArgs = new String[] {"-stats", "-explain", "-args", input("X1"), input("X2"), port1 + "", port2 + "", + input("X.json")}; + // String writer = runTest(null).toString(); + runTest(null); + // LOG.error(writer); + // LOG.error("Writing Done"); + + // Run reference dml script with normal matrix + fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReaderTest.dml"; + programArgs = new String[] {"-stats", "-args", input("X.json")}; + String out = runTest(null).toString(); + + Assert.assertTrue(heavyHittersContainsString("fed_uak+")); + + 
fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/FederatedReference.dml"; + // programArgs = new String[] {"-args", input("X1"), input("X2")}; + programArgs = new String[] {"-stats", "100", "-nvargs", + "in_X1=" + TestUtils.federatedAddress(port1, input("X1")), + "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols}; + String refOut = runTest(null).toString(); + + // Run federated + + // Verify output + Assert.assertEquals(Double.parseDouble(refOut.split("\n")[0]), + Double.parseDouble(out.split("\n")[0]), + 0.00001); + } + catch(Exception e) { + e.printStackTrace(); + Assert.assertTrue(false); + } + + TestUtils.shutdownThreads(t1, t2); + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java index 194df09..9b321e4 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java @@ -24,172 +24,183 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; -import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; + @RunWith(value = Parameterized.class) @net.jcip.annotations.NotThreadSafe public class FederatedParamservTest extends AutomatedTestBase { - private final static String TEST_DIR = "functions/federated/paramserv/"; - private final static 
String TEST_NAME = "FederatedParamservTest"; - private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/"; - private final static int _blocksize = 1024; - - private final String _networkType; - private final int _numFederatedWorkers; - private final int _examplesPerWorker; - private final int _epochs; - private final int _batch_size; - private final double _eta; - private final String _utype; - private final String _freq; - - private Types.ExecMode _platformOld; - - // parameters - @Parameterized.Parameters - public static Collection<Object[]> parameters() { - return Arrays.asList(new Object[][] { - //Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, - {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, - {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, - {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"}, - {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"}, - {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"}, - {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}, - {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"}, - {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"}, - {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"}, - {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"} - }); - } - - public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) { - _networkType = networkType; - _numFederatedWorkers = numFederatedWorkers; - _examplesPerWorker = examplesPerWorker; - _batch_size = batch_size; - _epochs = epochs; - _eta = eta; - _utype = utype; - _freq = freq; - } - - @Override - public void setUp() { - 
TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); - - _platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); - } - - @Override - public void tearDown() { - - rtplatform = _platformOld; - } - - @Test - public void federatedParamserv() { - // config - getAndLoadTestConfiguration(TEST_NAME); - String HOME = SCRIPT_DIR + TEST_DIR; - setOutputBuffering(true); - - int C = 1, Hin = 28, Win = 28; - int numFeatures = C*Hin*Win; - int numLabels = 10; - - // dml name - fullDMLScriptName = HOME + TEST_NAME + ".dml"; - // generate program args - List<String> programArgsList = new ArrayList<>(Arrays.asList( - "-stats", - "-nvargs", - "examples_per_worker=" + _examplesPerWorker, - "num_features=" + numFeatures, - "num_labels=" + numLabels, - "epochs=" + _epochs, - "batch_size=" + _batch_size, - "eta=" + _eta, - "utype=" + _utype, - "freq=" + _freq, - "network_type=" + _networkType, - "channels=" + C, - "hin=" + Hin, - "win=" + Win - )); - - // for each worker - List<Integer> ports = new ArrayList<>(); - List<Thread> threads = new ArrayList<>(); - for(int i = 0; i < _numFederatedWorkers; i++) { - // write row partitioned features to disk - writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false, - new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures)); - // write row partitioned labels to disk - writeInputMatrixWithMTD("y" + i, generateDummyMNISTLabels(_examplesPerWorker, numLabels), false, - new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker * numLabels)); - - // start worker - ports.add(getRandomAvailablePort()); - threads.add(startLocalFedWorkerThread(ports.get(i))); - - // add worker to program args - programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i))); - programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), 
input("y" + i))); - } - - programArgs = programArgsList.toArray(new String[0]); - // ByteArrayOutputStream stdout = - runTest(null); - // System.out.print(stdout.toString()); - - // cleanup - for(int i = 0; i < _numFederatedWorkers; i++) { - TestUtils.shutdownThreads(threads.get(i)); - } - } - - /** - * Generates an feature matrix that has the same format as the MNIST dataset, - * but is completely random and normalized - * - * @param numExamples Number of examples to generate - * @param C Channels in the input data - * @param Hin Height in Pixels of the input data - * @param Win Width in Pixels of the input data - * @return a dummy MNIST feature matrix - */ - private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) { - // Seed -1 takes the time in milliseconds as a seed - // Sparsity 1 means no sparsity - return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1); - } - - /** - * Generates an label matrix that has the same format as the MNIST dataset, but is completely random and consists - * of one hot encoded vectors as rows - * - * @param numExamples Number of examples to generate - * @param numLabels Number of labels to generate - * @return a dummy MNIST lable matrix - */ - private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) { - // Seed -1 takes the time in milliseconds as a seed - // Sparsity 1 means no sparsity - return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1); - } + private final static String TEST_DIR = "functions/federated/paramserv/"; + private final static String TEST_NAME = "FederatedParamservTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/"; + private final static int _blocksize = 1024; + + private final String _networkType; + private final int _numFederatedWorkers; + private final int _examplesPerWorker; + private final int _epochs; + private final int _batch_size; + private final double _eta; + private final String 
_utype; + private final String _freq; + + // parameters + @Parameterized.Parameters + public static Collection<Object[]> parameters() { + return Arrays.asList(new Object[][] { + //Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency + {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, + {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, + {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, + {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"}, + {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, + {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"}, + {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"}, + {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"}, + {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"}, + {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}, + {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"}, + {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"}, + {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"}, + {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"} + }); + } + + public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) { + _networkType = networkType; + _numFederatedWorkers = numFederatedWorkers; + _examplesPerWorker = examplesPerWorker; + _batch_size = batch_size; + _epochs = epochs; + _eta = eta; + _utype = utype; + _freq = freq; + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void federatedParamservSingleNode() { + federatedParamserv(ExecMode.SINGLE_NODE); + } + + @Test + public void federatedParamservHybrid() { + federatedParamserv(ExecMode.HYBRID); + } + + private void federatedParamserv(ExecMode mode) { + // config + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + 
setOutputBuffering(true); + + int C = 1, Hin = 28, Win = 28; + int numFeatures = C*Hin*Win; + int numLabels = 10; + + ExecMode platformOld = setExecMode(mode); + + try { + + // dml name + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + // generate program args + List<String> programArgsList = new ArrayList<>(Arrays.asList( + "-stats", + "-nvargs", + "examples_per_worker=" + _examplesPerWorker, + "num_features=" + numFeatures, + "num_labels=" + numLabels, + "epochs=" + _epochs, + "batch_size=" + _batch_size, + "eta=" + _eta, + "utype=" + _utype, + "freq=" + _freq, + "network_type=" + _networkType, + "channels=" + C, + "hin=" + Hin, + "win=" + Win + )); + + // for each worker + List<Integer> ports = new ArrayList<>(); + List<Thread> threads = new ArrayList<>(); + for(int i = 0; i < _numFederatedWorkers; i++) { + // write row partitioned features to disk + writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false, + new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures)); + // write row partitioned labels to disk + writeInputMatrixWithMTD("y" + i, generateDummyMNISTLabels(_examplesPerWorker, numLabels), false, + new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker * numLabels)); + + // start worker + ports.add(getRandomAvailablePort()); + threads.add(startLocalFedWorkerThread(ports.get(i))); + + // add worker to program args + programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i))); + programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i))); + } + + programArgs = programArgsList.toArray(new String[0]); + // ByteArrayOutputStream stdout = + runTest(null); + // System.out.print(stdout.toString()); + Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst()); + + // cleanup + for(int i = 0; i < _numFederatedWorkers; i++) { + TestUtils.shutdownThreads(threads.get(i)); + } 
+ } + finally { + resetExecMode(platformOld); + } + } + + /** + * Generates an feature matrix that has the same format as the MNIST dataset, + * but is completely random and normalized + * + * @param numExamples Number of examples to generate + * @param C Channels in the input data + * @param Hin Height in Pixels of the input data + * @param Win Width in Pixels of the input data + * @return a dummy MNIST feature matrix + */ + private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) { + // Seed -1 takes the time in milliseconds as a seed + // Sparsity 1 means no sparsity + return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1); + } + + /** + * Generates an label matrix that has the same format as the MNIST dataset, but is completely random and consists + * of one hot encoded vectors as rows + * + * @param numExamples Number of examples to generate + * @param numLabels Number of labels to generate + * @return a dummy MNIST lable matrix + */ + private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) { + // Seed -1 takes the time in milliseconds as a seed + // Sparsity 1 means no sparsity + return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1); + } }
