This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e8dcd3d  [SYSTEMDS-2911] Asynchronous accrue gradients for paramserv 
epoch sync
e8dcd3d is described below

commit e8dcd3de52ab27dc19890f7252978d5e4d9e83bd
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Mar 24 23:55:47 2021 +0100

    [SYSTEMDS-2911] Asynchronous accrue gradients for paramserv epoch sync
    
    This patch moves the aggregation of gradients in parameter server epoch
    synchronization out of the critical path by executing these updates
    asynchronously. On a single worker paramserv CNN run on mnist (with
    batch size 32), this patch improved end-to-end performance from 331s to
    315s (~5%). For the federated parameter server this is not applied yet,
    because it would require a holistic handling of caching and buffer pool
    management (which also brings up/down the system-wide maintenance thread
    pool).
    
    Furthermore, this patch also includes a fix of the recently modified
    broadcast handling in parfor spark jobs. An mlcontext test (mis)used the
    known export for testing proper scratch space cleanup when running
    multiple scripts through the programmatic APIs.
---
 .../runtime/controlprogram/ParForProgramBlock.java |  2 +-
 .../controlprogram/paramserv/LocalPSWorker.java    | 52 +++++++++++++---------
 .../runtime/controlprogram/paramserv/PSWorker.java |  7 +++
 .../controlprogram/paramserv/ParamServer.java      |  4 +-
 .../cp/ParamservBuiltinCPInstruction.java          |  2 +-
 .../mlcontext/MLContextScratchCleanupTest.java     |  9 ++--
 .../functions/mlcontext/ScratchCleanup1.dml        |  4 +-
 7 files changed, 50 insertions(+), 30 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index c3c6909..0c97d0a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -297,7 +297,7 @@ public class ParForProgramBlock extends ForProgramBlock
        public static final boolean FORCE_CP_ON_REMOTE_SPARK    = true; // 
compile body to CP if exec type forced to Spark
        public static final boolean LIVEVAR_AWARE_EXPORT        = true; // 
export only read variables according to live variable analysis
        public static final boolean RESET_RECOMPILATION_FLAGs   = true;
-       public static final boolean ALLOW_BROADCAST_INPUTS      = true; // 
enables to broadcast inputs for remote_spark
+       public static       boolean ALLOW_BROADCAST_INPUTS      = true; // 
enables to broadcast inputs for remote_spark
        
        public static final String PARFOR_FNAME_PREFIX          = "/parfor/"; 
        public static final String PARFOR_MR_TASKS_TMP_FNAME    = 
PARFOR_FNAME_PREFIX + "%ID%_MR_taskfile"; 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 241bfc6..8ba81f6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -20,7 +20,10 @@
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
 import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
 
+import org.apache.commons.lang3.concurrent.ConcurrentUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
@@ -37,15 +40,12 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        protected static final Log LOG = 
LogFactory.getLog(LocalPSWorker.class.getName());
        private static final long serialVersionUID = 5195390748495357295L;
 
-       private boolean _parUpdates = false;
-       
        protected LocalPSWorker() {}
 
        public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq,
-               int epochs, long batchSize, ExecutionContext ec, ParamServer 
ps, boolean parUpdates)
+               int epochs, long batchSize, ExecutionContext ec, ParamServer ps)
        {
                super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
-               _parUpdates = parUpdates;
        }
 
        @Override
@@ -84,27 +84,35 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                for (int i = 0; i < _epochs; i++) {
                        // Pull the global parameters from ps
                        ListObject params = pullModel();
-                       ListObject accGradients = null;
-                       
-                       for (int j = 0; j < batchIter; j++) {
-                               ListObject gradients = computeGradients(params, 
dataSize, batchIter, i, j);
+                       Future<ListObject> accGradients = 
ConcurrentUtils.constantFuture(null);
 
-                               boolean localUpdate = j < batchIter - 1;
-                               // Accumulate the intermediate gradients
-                               accGradients = ParamservUtils.accrueGradients(
-                                       accGradients, gradients, _parUpdates, 
!localUpdate);
-
-                               // Update the local model with gradients
-                               if(localUpdate)
-                                       params = updateModel(params, gradients, 
i, j, batchIter);
+                       try {
+                               for (int j = 0; j < batchIter; j++) {
+                                       ListObject gradients = 
computeGradients(params, dataSize, batchIter, i, j);
+       
+                                       boolean localUpdate = j < batchIter - 1;
+                                       
+                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates 
+                                       // and gradient computation, sequential 
over gradient matrices to avoid deadlocks)
+                                       ListObject accGradientsPrev = 
accGradients.get();
+                                       accGradients = _tpool.submit(() -> 
ParamservUtils.accrueGradients(
+                                               accGradientsPrev, gradients, 
false, !localUpdate));
+       
+                                       // Update the local model with gradients
+                                       if(localUpdate)
+                                               params = updateModel(params, 
gradients, i, j, batchIter);
+       
+                                       accNumBatches(1);
+                               }
 
-                               accNumBatches(1);
+                               // Push the gradients to ps
+                               pushGradients(accGradients.get());
+                               ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
                        }
-
-                       // Push the gradients to ps
-                       pushGradients(accGradients);
-                       ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
-
+                       catch(ExecutionException | InterruptedException ex) {
+                               throw new DMLRuntimeException(ex);
+                       }
+                       
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index c0389f3..cc75e52 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.concurrent.ExecutorService;
 import java.util.stream.Collectors;
 
 import org.apache.sysds.common.Types.DataType;
@@ -29,6 +30,7 @@ import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
@@ -39,6 +41,11 @@ public abstract class PSWorker implements Serializable
 {
        private static final long serialVersionUID = -3510485051178200118L;
 
+       // thread pool for asynchronous accrue gradients on epoch scheduling
+       // Note: we use a non-static variable to obtain the live maintenance 
thread pool
+       // which is important in scenarios w/ multiple scripts in a single JVM 
(e.g., tests)
+       protected ExecutorService _tpool = LazyWriteBuffer.getUtilThreadPool();
+       
        protected int _workerID;
        protected int _epochs;
        protected long _batchSize;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 4fe072c..96a08e3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -230,8 +230,8 @@ public abstract class ParamServer
                                }
                                case ASP: {
                                        updateGlobalModel(gradients);
-                                       // This if works similarly to the one 
for BSP, but divides the sync couter through the number of workers,
-                                       // creating "Pseudo Epochs"
+                                       // This works similarly to the one for 
BSP, but divides the sync counter by
+                                       // the number of workers, creating 
"Pseudo Epochs"
                                        if (_numBatchesPerEpoch != -1 &&
                                                ((_freq == 
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
                                                (_freq == 
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) 
_numBatchesPerEpoch == 0))) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index e64fdf8..4057f73 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -332,7 +332,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
                        .mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
-                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps, workerNum==1))
+                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps))
                        .collect(Collectors.toList());
 
                // Do data partition
diff --git 
a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
 
b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
index 28a6543..f2e8c2e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextScratchCleanupTest.java
@@ -28,13 +28,14 @@ import org.junit.After;
 import org.junit.Test;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
 import org.apache.sysds.api.mlcontext.MLContext;
 import org.apache.sysds.api.mlcontext.Matrix;
 import org.apache.sysds.api.mlcontext.Script;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestUtils;
 
-
+@net.jcip.annotations.NotThreadSafe
 public class MLContextScratchCleanupTest extends AutomatedTestBase 
 {
        private final static String TEST_DIR = "functions/mlcontext";
@@ -88,10 +89,11 @@ public class MLContextScratchCleanupTest extends 
AutomatedTestBase
                //create mlcontext
                SparkSession spark = 
createSystemDSSparkSession("MLContextScratchCleanupTest", "local");
                MLContext ml = new MLContext(spark);
-               ml.setExplain(true);
-
+               
                String dml1 = baseDirectory + File.separator + 
"ScratchCleanup1.dml";
                String dml2 = baseDirectory + File.separator + 
(wRead?"ScratchCleanup2b.dml":"ScratchCleanup2.dml");
+               boolean broadcastOld = 
ParForProgramBlock.ALLOW_BROADCAST_INPUTS;
+               ParForProgramBlock.ALLOW_BROADCAST_INPUTS = false;
                
                try
                {
@@ -110,6 +112,7 @@ public class MLContextScratchCleanupTest extends 
AutomatedTestBase
                }
                finally {
                        DMLScript.setGlobalExecMode(oldplatform);
+                       ParForProgramBlock.ALLOW_BROADCAST_INPUTS = 
broadcastOld;
                        
                        // stop underlying spark context to allow single jvm 
tests (otherwise the
                        // next test that tries to create a SparkContext would 
fail)
diff --git a/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml 
b/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
index 1fb2c7e..b0c466c 100644
--- a/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
+++ b/src/test/scripts/functions/mlcontext/ScratchCleanup1.dml
@@ -23,7 +23,9 @@ X = rand(rows=$rows, cols=$cols);
 
 #force export of X via remote parfor
 parfor(i in 1:ncol(X), opt=CONSTRAINED, mode=REMOTE_SPARK) {
-   print(sum(X[,i]));
+  x = sum(X[,i]);
+  if( x < 0 )
+    print(x)
 }
 
 write(X, "out/X", format="binary");
\ No newline at end of file

Reply via email to