This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e8dcd3d [SYSTEMDS-2911] Asynchronous accrue gradients for paramserv
epoch sync
e8dcd3d is described below
commit e8dcd3de52ab27dc19890f7252978d5e4d9e83bd
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Mar 24 23:55:47 2021 +0100
[SYSTEMDS-2911] Asynchronous accrue gradients for paramserv epoch sync
This patch moves the aggregation of gradients in parameter server epoch
synchronization out of the critical path by executing these updates
asynchronously. On a single-worker paramserv CNN run on MNIST (batch
size 32), this patch improved end-to-end performance from 331s to
315s (~5%). For the federated parameter server this is not applied yet,
because it would require a holistic handling of caching and buffer pool
management (which also brings up/down the system-wide maintenance thread
pool).
Furthermore, this patch also includes a fix for the recently modified
broadcast handling in parfor spark jobs. An mlcontext test (mis)used the
known export of parfor inputs for testing proper scratch space cleanup
when running multiple scripts through the programmatic APIs; this test
now disables broadcast inputs (ALLOW_BROADCAST_INPUTS) for its duration.
---
.../runtime/controlprogram/ParForProgramBlock.java | 2 +-
.../controlprogram/paramserv/LocalPSWorker.java | 52 +++++++++++++---------
.../runtime/controlprogram/paramserv/PSWorker.java | 7 +++
.../controlprogram/paramserv/ParamServer.java | 4 +-
.../cp/ParamservBuiltinCPInstruction.java | 2 +-
.../mlcontext/MLContextScratchCleanupTest.java | 9 ++--
.../functions/mlcontext/ScratchCleanup1.dml | 4 +-
7 files changed, 50 insertions(+), 30 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index c3c6909..0c97d0a 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -297,7 +297,7 @@ public class ParForProgramBlock extends ForProgramBlock
public static final boolean FORCE_CP_ON_REMOTE_SPARK = true; //
compile body to CP if exec type forced to Spark
public static final boolean LIVEVAR_AWARE_EXPORT = true; //
export only read variables according to live variable analysis
public static final boolean RESET_RECOMPILATION_FLAGs = true;
- public static final boolean ALLOW_BROADCAST_INPUTS = true; //
enables to broadcast inputs for remote_spark
+ public static boolean ALLOW_BROADCAST_INPUTS = true; //
enables to broadcast inputs for remote_spark
public static final String PARFOR_FNAME_PREFIX = "/parfor/";
public static final String PARFOR_MR_TASKS_TMP_FNAME =
PARFOR_FNAME_PREFIX + "%ID%_MR_taskfile";
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 241bfc6..8ba81f6 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -20,7 +20,10 @@
package org.apache.sysds.runtime.controlprogram.paramserv;
import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
@@ -37,15 +40,12 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
protected static final Log LOG =
LogFactory.getLog(LocalPSWorker.class.getName());
private static final long serialVersionUID = 5195390748495357295L;
- private boolean _parUpdates = false;
-
protected LocalPSWorker() {}
public LocalPSWorker(int workerID, String updFunc,
Statement.PSFrequency freq,
- int epochs, long batchSize, ExecutionContext ec, ParamServer
ps, boolean parUpdates)
+ int epochs, long batchSize, ExecutionContext ec, ParamServer ps)
{
super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
- _parUpdates = parUpdates;
}
@Override
@@ -84,27 +84,35 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
for (int i = 0; i < _epochs; i++) {
// Pull the global parameters from ps
ListObject params = pullModel();
- ListObject accGradients = null;
-
- for (int j = 0; j < batchIter; j++) {
- ListObject gradients = computeGradients(params,
dataSize, batchIter, i, j);
+ Future<ListObject> accGradients =
ConcurrentUtils.constantFuture(null);
- boolean localUpdate = j < batchIter - 1;
- // Accumulate the intermediate gradients
- accGradients = ParamservUtils.accrueGradients(
- accGradients, gradients, _parUpdates,
!localUpdate);
-
- // Update the local model with gradients
- if(localUpdate)
- params = updateModel(params, gradients,
i, j, batchIter);
+ try {
+ for (int j = 0; j < batchIter; j++) {
+ ListObject gradients =
computeGradients(params, dataSize, batchIter, i, j);
+
+ boolean localUpdate = j < batchIter - 1;
+
+ // Accumulate the intermediate
gradients (async for overlap w/ model updates
+ // and gradient computation, sequential
over gradient matrices to avoid deadlocks)
+ ListObject accGradientsPrev =
accGradients.get();
+ accGradients = _tpool.submit(() ->
ParamservUtils.accrueGradients(
+ accGradientsPrev, gradients,
false, !localUpdate));
+
+ // Update the local model with gradients
+ if(localUpdate)
+ params = updateModel(params,
gradients, i, j, batchIter);
+
+ accNumBatches(1);
+ }
- accNumBatches(1);
+ // Push the gradients to ps
+ pushGradients(accGradients.get());
+ ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
}
-
- // Push the gradients to ps
- pushGradients(accGradients);
- ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
-
+ catch(ExecutionException | InterruptedException ex) {
+ throw new DMLRuntimeException(ex);
+ }
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index c0389f3..cc75e52 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types.DataType;
@@ -29,6 +30,7 @@ import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
@@ -39,6 +41,11 @@ public abstract class PSWorker implements Serializable
{
private static final long serialVersionUID = -3510485051178200118L;
+ // thread pool for asynchronous accrue gradients on epoch scheduling
+ // Note: we use a non-static variable to obtain the live maintenance
thread pool
+ // which is important in scenarios w/ multiple scripts in a single JVM
(e.g., tests)
+ protected ExecutorService _tpool = LazyWriteBuffer.getUtilThreadPool();
+
protected int _workerID;
protected int _epochs;
protected long _batchSize;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 4fe072c..96a08e3 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -230,8 +230,8 @@ public abstract class ParamServer
}
case ASP: {
updateGlobalModel(gradients);
- // This if works similarly to the one
for BSP, but divides the sync couter through the number of workers,
- // creating "Pseudo Epochs"
+ // This works similarly to the one for
BSP, but divides the sync counter by
+ // the number of workers, creating
"Pseudo Epochs"
if (_numBatchesPerEpoch != -1 &&
((_freq ==
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
(_freq ==
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float)
_numBatchesPerEpoch == 0))) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index e64fdf8..4057f73 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -332,7 +332,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// Create the local workers
List<LocalPSWorker> workers = IntStream.range(0, workerNum)
.mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
- getEpochs(), getBatchSize(), workerECs.get(i),
ps, workerNum==1))
+ getEpochs(), getBatchSize(), workerECs.get(i),
ps))
.collect(Collectors.toList());
// Do data partition
diff --git
a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
index 28a6543..f2e8c2e 100644
---
a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
@@ -28,13 +28,14 @@ import org.junit.After;
import org.junit.Test;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.api.mlcontext.MLContext;
import org.apache.sysds.api.mlcontext.Matrix;
import org.apache.sysds.api.mlcontext.Script;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestUtils;
-
[email protected]
public class MLContextScratchCleanupTest extends AutomatedTestBase
{
private final static String TEST_DIR = "functions/mlcontext";
@@ -88,10 +89,11 @@ public class MLContextScratchCleanupTest extends
AutomatedTestBase
//create mlcontext
SparkSession spark =
createSystemDSSparkSession("MLContextScratchCleanupTest", "local");
MLContext ml = new MLContext(spark);
- ml.setExplain(true);
-
+
String dml1 = baseDirectory + File.separator +
"ScratchCleanup1.dml";
String dml2 = baseDirectory + File.separator +
(wRead?"ScratchCleanup2b.dml":"ScratchCleanup2.dml");
+ boolean broadcastOld =
ParForProgramBlock.ALLOW_BROADCAST_INPUTS;
+ ParForProgramBlock.ALLOW_BROADCAST_INPUTS = false;
try
{
@@ -110,6 +112,7 @@ public class MLContextScratchCleanupTest extends
AutomatedTestBase
}
finally {
DMLScript.setGlobalExecMode(oldplatform);
+ ParForProgramBlock.ALLOW_BROADCAST_INPUTS =
broadcastOld;
// stop underlying spark context to allow single jvm
tests (otherwise the
// next test that tries to create a SparkContext would
fail)
diff --git a/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
b/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
index 1fb2c7e..b0c466c 100644
--- a/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
+++ b/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
@@ -23,7 +23,9 @@ X = rand(rows=$rows, cols=$cols);
#force export of X via remote parfor
parfor(i in 1:ncol(X), opt=CONSTRAINED, mode=REMOTE_SPARK) {
- print(sum(X[,i]));
+ x = sum(X[,i]);
+ if( x < 0 )
+ print(x)
}
write(X, "out/X", format="binary");
\ No newline at end of file