Repository: systemml
Updated Branches:
  refs/heads/master e9268d9e7 -> e7fccd1c7


[SYSTEMML-2416] Simplified paramserv aggregation service

Closes #790.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e7fccd1c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e7fccd1c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e7fccd1c

Branch: refs/heads/master
Commit: e7fccd1c764ec470f6460e4e3cec90913e606798
Parents: e9268d9
Author: EdgarLGB <[email protected]>
Authored: Fri Jun 22 22:29:49 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Jun 22 23:56:15 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/LocalPSWorker.java |   2 +-
 .../paramserv/LocalParamServer.java             |  13 +-
 .../controlprogram/paramserv/ParamServer.java   | 301 ++++++++-----------
 .../cp/ParamservBuiltinCPInstruction.java       |   2 -
 4 files changed, 129 insertions(+), 189 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index 4f472ee..0ed7c81 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -107,7 +107,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
        private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int totalIter) {
                Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
-               globalParams = _ps.updateModel(_ec, gradients, globalParams);
+               globalParams = _ps.updateLocalModel(_ec, gradients, 
globalParams);
 
                if (DMLScript.STATISTICS)
                        Statistics.accPSLocalModelUpdateTime((long) 
tUpd.stop());

http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index d20383d..52372c9 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
-import java.util.concurrent.ExecutionException;
-
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
@@ -35,16 +33,7 @@ public class LocalParamServer extends ParamServer {
 
        @Override
        public void push(int workerID, ListObject gradients) {
-               try {
-                       _gradientsQueue.put(new Gradient(workerID, gradients));
-               } catch (InterruptedException e) {
-                       throw new DMLRuntimeException(e);
-               }
-               try {
-                       launchService();
-               } catch (ExecutionException | InterruptedException e) {
-                       throw new DMLRuntimeException("Aggregate service: some 
error occurred: ", e);
-               }
+               updateGlobalModel(workerID, gradients);
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index 4af72a4..abec267 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -27,16 +27,10 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.LinkedBlockingDeque;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import org.apache.commons.lang3.ArrayUtils;
-import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
@@ -53,213 +47,172 @@ import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.utils.Statistics;
 
-public abstract class ParamServer {
-
-       final BlockingQueue<Gradient> _gradientsQueue;
-       final Map<Integer, BlockingQueue<ListObject>> _modelMap;
-       private final AggregationService _aggService;
-       private final ExecutorService _es;
+public abstract class ParamServer 
+{
+       protected final Log LOG = 
LogFactory.getLog(ParamServer.class.getName());
+       
+       // worker input queues and global model
+       protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
        private ListObject _model;
 
-       ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType 
updateType, ExecutionContext ec, int workerNum) {
-               _gradientsQueue = new LinkedBlockingDeque<>();
+       //aggregation service
+       protected final ExecutionContext _ec;
+       private final Statement.PSUpdateType _updateType;
+       private final FunctionCallCPInstruction _inst;
+       private final String _outputName;
+       private final boolean[] _finishedStates;  // Workers' finished states
+
+       protected ParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
+               // init worker queues and global model
                _modelMap = new HashMap<>(workerNum);
                IntStream.range(0, workerNum).forEach(i -> {
                        // Create a single element blocking queue for workers 
to receive the broadcasted model
                        _modelMap.put(i, new ArrayBlockingQueue<>(1));
                });
                _model = model;
-               _aggService = new AggregationService(aggFunc, updateType, ec, 
workerNum);
+               
+               // init aggregation service
+               _ec = ec;
+               _updateType = updateType;
+               _finishedStates = new boolean[workerNum];
+               String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, 
PS_FUNC_PREFIX);
+               String ns = cfn[0];
+               String fname = cfn[1];
+               FunctionProgramBlock func = 
_ec.getProgram().getFunctionProgramBlock(ns, fname);
+               ArrayList<DataIdentifier> inputs = func.getInputParams();
+               ArrayList<DataIdentifier> outputs = func.getOutputParams();
+
+               // Check the output of the aggregation function
+               if (outputs.size() != 1) {
+                       throw new DMLRuntimeException(String.format("The output 
of the '%s' function should provide one list containing the updated model.", 
aggFunc));
+               }
+               if (outputs.get(0).getDataType() != Expression.DataType.LIST) {
+                       throw new DMLRuntimeException(String.format("The output 
of the '%s' function should be type of list.", aggFunc));
+               }
+               _outputName = outputs.get(0).getName();
+
+               CPOperand[] boundInputs = inputs.stream()
+                       .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
+                       .toArray(CPOperand[]::new);
+               ArrayList<String> inputNames = 
inputs.stream().map(DataIdentifier::getName)
+                       .collect(Collectors.toCollection(ArrayList::new));
+               ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
+                       .collect(Collectors.toCollection(ArrayList::new));
+               _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, 
inputNames, outputNames, "aggregate function");
+               
+               // broadcast initial model
                try {
-                       _aggService.broadcastModel();
+                       broadcastModel();
                }
                catch (InterruptedException e) {
                        throw new DMLRuntimeException("Param server: failed to 
broadcast the initial model.", e);
                }
-               BasicThreadFactory factory = new BasicThreadFactory.Builder()
-                       .namingPattern("agg-service-pool-thread-%d").build();
-               _es = Executors.newSingleThreadExecutor(factory);
        }
 
        public abstract void push(int workerID, ListObject value);
 
        public abstract Data pull(int workerID);
 
-       void launchService() throws ExecutionException, InterruptedException {
-               _es.submit(_aggService).get();
-       }
-
-       public void shutdown() {
-               _es.shutdownNow();
-       }
-
        public ListObject getResult() {
                // All the model updating work has terminated,
                // so we could return directly the result model
                return _model;
        }
-
-       public ListObject updateModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
-               return _aggService.updateModel(ec, gradients, model);
-       }
-
-       public static class Gradient {
-               final int _workerID;
-               final ListObject _gradients;
-
-               public Gradient(int workerID, ListObject gradients) {
-                       _workerID = workerID;
-                       _gradients = gradients;
-               }
-       }
        
-       /**
-        * Inner aggregation service which is for updating the model
-        */
-       private class AggregationService implements Callable<Void> {
-
-               protected final Log LOG = 
LogFactory.getLog(AggregationService.class.getName());
-
-               protected final ExecutionContext _ec;
-               private final Statement.PSUpdateType _updateType;
-               private final FunctionCallCPInstruction _inst;
-               private final DataIdentifier _output;
-               private final boolean[] _finishedStates;  // Workers' finished 
states
-
-               AggregationService(String aggFunc, Statement.PSUpdateType 
updateType, ExecutionContext ec, int workerNum) {
-                       _ec = ec;
-                       _updateType = updateType;
-                       _finishedStates = new boolean[workerNum];
-
-                       // Fetch the aggregation function
-                       String[] cfn = 
ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
-                       String ns = cfn[0];
-                       String fname = cfn[1];
-                       FunctionProgramBlock func = 
_ec.getProgram().getFunctionProgramBlock(ns, fname);
-                       ArrayList<DataIdentifier> inputs = 
func.getInputParams();
-                       ArrayList<DataIdentifier> outputs = 
func.getOutputParams();
-
-                       // Check the output of the aggregation function
-                       if (outputs.size() != 1) {
-                               throw new 
DMLRuntimeException(String.format("The output of the '%s' function should 
provide one list containing the updated model.", aggFunc));
-                       }
-                       if (outputs.get(0).getDataType() != 
Expression.DataType.LIST) {
-                               throw new 
DMLRuntimeException(String.format("The output of the '%s' function should be 
type of list.", aggFunc));
+       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+               try {
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
+                                       gradients.getDataSize() / 1024, 
workerID));
                        }
-                       _output = outputs.get(0);
-
-                       CPOperand[] boundInputs = inputs.stream()
-                               .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
-                               .toArray(CPOperand[]::new);
-                       ArrayList<String> inputNames = 
inputs.stream().map(DataIdentifier::getName)
-                               
.collect(Collectors.toCollection(ArrayList::new));
-                       ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
-                               
.collect(Collectors.toCollection(ArrayList::new));
-                       _inst = new FunctionCallCPInstruction(ns, fname, 
boundInputs, inputNames, outputNames, "aggregate function");
-               }
-
-               private boolean allFinished() {
-                       return !ArrayUtils.contains(_finishedStates, false);
-               }
-
-               private void resetFinishedStates() {
-                       Arrays.fill(_finishedStates, false);
-               }
-
-               private void setFinishedState(int workerID) {
-                       _finishedStates[workerID] = true;
-               }
-
-               private void broadcastModel() throws InterruptedException {
-                       Timing tBroad = DMLScript.STATISTICS ? new Timing(true) 
: null;
-
-                       //broadcast copy of the model to all workers, cleaned 
up by workers
-                       for (BlockingQueue<ListObject> q : _modelMap.values())
-                               q.put(ParamservUtils.copyList(_model));
 
+                       // Update and redistribute the model
+                       Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : 
null;
+                       _model = updateLocalModel(_ec, gradients, _model);
                        if (DMLScript.STATISTICS)
-                               Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
+                               Statistics.accPSAggregationTime((long) 
tAgg.stop());
+
+                       // Redistribute model according to update type
+                       switch(_updateType) {
+                               case BSP: {
+                                       setFinishedState(workerID);
+                                       if (allFinished()) {
+                                               // Broadcast the updated model
+                                               resetFinishedStates();
+                                               broadcastModel();
+                                               if (LOG.isDebugEnabled())
+                                                       LOG.debug("Global 
parameter is broadcasted successfully.");
+                                       }
+                                       break;
+                               }
+                               case ASP: {
+                                       broadcastModel(workerID);
+                                       break;
+                               }
+                               default:
+                                       throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name());
+                       }
+               } 
+               catch (Exception e) {
+                       throw new DMLRuntimeException("Aggregation service 
failed: ", e);
                }
+       }
 
-               private void broadcastModel(int workerID) throws 
InterruptedException {
-                       Timing tBroad = DMLScript.STATISTICS ? new Timing(true) 
: null;
+       /**
+        * A service method for updating model with gradients
+        *
+        * @param ec execution context
+        * @param gradients list of gradients
+        * @param model old model
+        * @return new model
+        */
+       protected ListObject updateLocalModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
+               // Populate the variables table with the gradients and model
+               ec.setVariable(Statement.PS_GRADIENTS, gradients);
+               ec.setVariable(Statement.PS_MODEL, model);
 
-                       //broadcast copy of model to specific worker, cleaned 
up by worker
-                       
_modelMap.get(workerID).put(ParamservUtils.copyList(_model));
+               // Invoke the aggregate function
+               _inst.processInstruction(ec);
 
-                       if (DMLScript.STATISTICS)
-                               Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
-               }
+               // Get the output
+               ListObject newModel = (ListObject) ec.getVariable(_outputName);
 
-               @Override
-               public Void call() throws Exception {
-                       try {
-                               Gradient grad;
-                               try {
-                                       grad = _gradientsQueue.take();
-                               } catch (InterruptedException e) {
-                                       throw new 
DMLRuntimeException("Aggregation service: error when waiting for the coming 
gradients.", e);
-                               }
-                               if (LOG.isDebugEnabled()) {
-                                       LOG.debug(String.format("Successfully 
pulled the gradients [size:%d kb] of worker_%d.",
-                                               grad._gradients.getDataSize() / 
1024, grad._workerID));
-                               }
+               // Update the model with the new output
+               ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+               ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+               return newModel;
+       }
+       
+       private boolean allFinished() {
+               return !ArrayUtils.contains(_finishedStates, false);
+       }
 
-                               // Update and redistribute the model
-                               Timing tAgg = DMLScript.STATISTICS ? new 
Timing(true) : null;
-                               _model = updateModel(grad._gradients, _model);
-                               if (DMLScript.STATISTICS)
-                                       Statistics.accPSAggregationTime((long) 
tAgg.stop());
+       private void resetFinishedStates() {
+               Arrays.fill(_finishedStates, false);
+       }
 
-                               // Redistribute model according to update type
-                               switch(_updateType) {
-                                       case BSP: {
-                                               
setFinishedState(grad._workerID);
-                                               if (allFinished()) {
-                                                       // Broadcast the 
updated model
-                                                       resetFinishedStates();
-                                                       broadcastModel();
-                                                       if 
(LOG.isDebugEnabled())
-                                                               
LOG.debug("Global parameter is broadcasted successfully.");
-                                               }
-                                               break;
-                                       }
-                                       case ASP: {
-                                               broadcastModel(grad._workerID);
-                                               break;
-                                       }
-                                       default:
-                                               throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name());
-                               }
-                       } 
-                       catch (Exception e) {
-                               throw new DMLRuntimeException("Aggregation 
service failed: ", e);
-                       }
-                       return null;
-               }
+       private void setFinishedState(int workerID) {
+               _finishedStates[workerID] = true;
+       }
+       
+       private void broadcastModel() throws InterruptedException {
+               Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
 
-               private ListObject updateModel(ListObject gradients, ListObject 
model) {
-                       return updateModel(_ec, gradients, model);
-               }
+               //broadcast copy of the model to all workers, cleaned up by 
workers
+               for (BlockingQueue<ListObject> q : _modelMap.values())
+                       q.put(ParamservUtils.copyList(_model));
 
-               /**
-                * A service method for updating model with gradients
-                */
-               private ListObject updateModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
-                       // Populate the variables table with the gradients and 
model
-                       ec.setVariable(Statement.PS_GRADIENTS, gradients);
-                       ec.setVariable(Statement.PS_MODEL, model);
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
+       }
 
-                       // Invoke the aggregate function
-                       _inst.processInstruction(ec);
+       private void broadcastModel(int workerID) throws InterruptedException {
+               Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
 
-                       // Get the output
-                       ListObject newModel = (ListObject) 
ec.getVariable(_output.getName());
+               //broadcast copy of model to specific worker, cleaned up by 
worker
+               _modelMap.get(workerID).put(ParamservUtils.copyList(_model));
 
-                       // Update the model with the new output
-                       ParamservUtils.cleanupListObject(ec, 
Statement.PS_MODEL);
-                       ParamservUtils.cleanupListObject(ec, 
Statement.PS_GRADIENTS);
-                       return newModel;
-               }
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index be80127..25bc113 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -160,8 +160,6 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
                } finally {
                        es.shutdownNow();
-                       // Should shutdown the thread pool in param server
-                       ps.shutdown();
                }
        }
 

Reply via email to