Repository: systemml Updated Branches: refs/heads/master e9268d9e7 -> e7fccd1c7
[SYSTEMML-2416] Simplified paramserv aggregation service Closes #790. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e7fccd1c Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e7fccd1c Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e7fccd1c Branch: refs/heads/master Commit: e7fccd1c764ec470f6460e4e3cec90913e606798 Parents: e9268d9 Author: EdgarLGB <[email protected]> Authored: Fri Jun 22 22:29:49 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jun 22 23:56:15 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/paramserv/LocalPSWorker.java | 2 +- .../paramserv/LocalParamServer.java | 13 +- .../controlprogram/paramserv/ParamServer.java | 301 ++++++++----------- .../cp/ParamservBuiltinCPInstruction.java | 2 - 4 files changed, 129 insertions(+), 189 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index 4f472ee..0ed7c81 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -107,7 +107,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) { Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null; - globalParams = _ps.updateModel(_ec, gradients, globalParams); + globalParams = _ps.updateLocalModel(_ec, gradients, globalParams); if (DMLScript.STATISTICS) Statistics.accPSLocalModelUpdateTime((long) tUpd.stop()); http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java index d20383d..52372c9 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -19,8 +19,6 @@ package org.apache.sysml.runtime.controlprogram.paramserv; -import java.util.concurrent.ExecutionException; - import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; @@ -35,16 +33,7 @@ public class LocalParamServer extends ParamServer { @Override public void push(int workerID, ListObject gradients) { - try { - _gradientsQueue.put(new Gradient(workerID, gradients)); - } catch (InterruptedException e) { - throw new DMLRuntimeException(e); - } - try { - launchService(); - } catch (ExecutionException | InterruptedException e) { - throw new DMLRuntimeException("Aggregate service: some error occurred: ", e); - } + updateGlobalModel(workerID, gradients); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index 4af72a4..abec267 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -27,16 +27,10 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.Callable; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingDeque; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.lang3.ArrayUtils; -import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; @@ -53,213 +47,172 @@ import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; import org.apache.sysml.runtime.instructions.cp.ListObject; import org.apache.sysml.utils.Statistics; -public abstract class ParamServer { - - final BlockingQueue<Gradient> _gradientsQueue; - final Map<Integer, BlockingQueue<ListObject>> _modelMap; - private final AggregationService _aggService; - private final ExecutorService _es; +public abstract class ParamServer +{ + protected final Log LOG = LogFactory.getLog(ParamServer.class.getName()); + + // worker input queues and global model + protected final Map<Integer, BlockingQueue<ListObject>> _modelMap; private ListObject _model; - ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { - _gradientsQueue = new LinkedBlockingDeque<>(); + //aggregation service + protected final ExecutionContext _ec; + private final Statement.PSUpdateType _updateType; + private final FunctionCallCPInstruction _inst; + private final String _outputName; + private final boolean[] _finishedStates; // Workers' finished states + + protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { + // init worker queues and global model _modelMap = new HashMap<>(workerNum); IntStream.range(0, workerNum).forEach(i -> { // Create a single element blocking queue for workers to receive the broadcasted model _modelMap.put(i, new ArrayBlockingQueue<>(1)); }); _model = model; - _aggService = new AggregationService(aggFunc, updateType, ec, workerNum); + + // init aggregation service + _ec = ec; + _updateType = updateType; + _finishedStates = new boolean[workerNum]; + String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX); + String ns = cfn[0]; + String fname = cfn[1]; + FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname); + ArrayList<DataIdentifier> inputs = func.getInputParams(); + ArrayList<DataIdentifier> outputs = func.getOutputParams(); + + // Check the output of the aggregation function + if (outputs.size() != 1) { + throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", aggFunc)); + } + if (outputs.get(0).getDataType() != Expression.DataType.LIST) { + throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", aggFunc)); + } + _outputName = outputs.get(0).getName(); + + CPOperand[] boundInputs = inputs.stream() + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); + ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function"); + + // broadcast initial model try { - _aggService.broadcastModel(); + broadcastModel(); } catch (InterruptedException e) { throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e); } - BasicThreadFactory factory = new BasicThreadFactory.Builder() - .namingPattern("agg-service-pool-thread-%d").build(); - _es = Executors.newSingleThreadExecutor(factory); } public abstract void push(int workerID, ListObject value); public abstract Data pull(int workerID); - void launchService() throws ExecutionException, InterruptedException { - _es.submit(_aggService).get(); - } - - public void shutdown() { - _es.shutdownNow(); - } - public ListObject getResult() { // All the model updating work has terminated, // so we could return directly the result model return _model; } - - public ListObject updateModel(ExecutionContext ec, ListObject gradients, ListObject model) { - return _aggService.updateModel(ec, gradients, model); - } - - public static class Gradient { - final int _workerID; - final ListObject _gradients; - - public Gradient(int workerID, ListObject gradients) { - _workerID = workerID; - _gradients = gradients; - } - } - /** - * Inner aggregation service which is for updating the model - */ - private class AggregationService implements Callable<Void> { - - protected final Log LOG = LogFactory.getLog(AggregationService.class.getName()); - - protected final ExecutionContext _ec; - private final Statement.PSUpdateType _updateType; - private final FunctionCallCPInstruction _inst; - private final DataIdentifier _output; - private final boolean[] _finishedStates; // Workers' finished states - - AggregationService(String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { - _ec = ec; - _updateType = updateType; - _finishedStates = new boolean[workerNum]; - - // Fetch the aggregation function - String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX); - String ns = cfn[0]; - String fname = cfn[1]; - FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname); - ArrayList<DataIdentifier> inputs = func.getInputParams(); - ArrayList<DataIdentifier> outputs = func.getOutputParams(); - - // Check the output of the aggregation function - if (outputs.size() != 1) { - throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", aggFunc)); - } - if (outputs.get(0).getDataType() != Expression.DataType.LIST) { - throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", aggFunc)); + protected synchronized void updateGlobalModel(int workerID, ListObject gradients) { + try { + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", + gradients.getDataSize() / 1024, workerID)); } - _output = outputs.get(0); - - CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); - ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); - ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); - _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function"); - } - - private boolean allFinished() { - return !ArrayUtils.contains(_finishedStates, false); - } - - private void resetFinishedStates() { - Arrays.fill(_finishedStates, false); - } - - private void setFinishedState(int workerID) { - _finishedStates[workerID] = true; - } - - private void broadcastModel() throws InterruptedException { - Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; - - //broadcast copy of the model to all workers, cleaned up by workers - for (BlockingQueue<ListObject> q : _modelMap.values()) - q.put(ParamservUtils.copyList(_model)); + // Update and redistribute the model + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; + _model = updateLocalModel(_ec, gradients, _model); if (DMLScript.STATISTICS) - Statistics.accPSModelBroadcastTime((long) tBroad.stop()); + Statistics.accPSAggregationTime((long) tAgg.stop()); + + // Redistribute model according to update type + switch(_updateType) { + case BSP: { + setFinishedState(workerID); + if (allFinished()) { + // Broadcast the updated model + resetFinishedStates(); + broadcastModel(); + if (LOG.isDebugEnabled()) + LOG.debug("Global parameter is broadcasted successfully."); + } + break; + } + case ASP: { + broadcastModel(workerID); + break; + } + default: + throw new DMLRuntimeException("Unsupported update: " + _updateType.name()); + } + } + catch (Exception e) { + throw new DMLRuntimeException("Aggregation service failed: ", e); } + } - private void broadcastModel(int workerID) throws InterruptedException { - Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; + /** + * A service method for updating model with gradients + * + * @param ec execution context + * @param gradients list of gradients + * @param model old model + * @return new model + */ + protected ListObject updateLocalModel(ExecutionContext ec, ListObject gradients, ListObject model) { + // Populate the variables table with the gradients and model + ec.setVariable(Statement.PS_GRADIENTS, gradients); + ec.setVariable(Statement.PS_MODEL, model); - //broadcast copy of model to specific worker, cleaned up by worker - _modelMap.get(workerID).put(ParamservUtils.copyList(_model)); + // Invoke the aggregate function + _inst.processInstruction(ec); - if (DMLScript.STATISTICS) - Statistics.accPSModelBroadcastTime((long) tBroad.stop()); - } + // Get the output + ListObject newModel = (ListObject) ec.getVariable(_outputName); - @Override - public Void call() throws Exception { - try { - Gradient grad; - try { - grad = _gradientsQueue.take(); - } catch (InterruptedException e) { - throw new DMLRuntimeException("Aggregation service: error when waiting for the coming gradients.", e); - } - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", - grad._gradients.getDataSize() / 1024, grad._workerID)); - } + // Update the model with the new output + ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL); + ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS); + return newModel; + } + + private boolean allFinished() { + return !ArrayUtils.contains(_finishedStates, false); + } - // Update and redistribute the model - Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; - _model = updateModel(grad._gradients, _model); - if (DMLScript.STATISTICS) - Statistics.accPSAggregationTime((long) tAgg.stop()); + private void resetFinishedStates() { + Arrays.fill(_finishedStates, false); + } - // Redistribute model according to update type - switch(_updateType) { - case BSP: { - setFinishedState(grad._workerID); - if (allFinished()) { - // Broadcast the updated model - resetFinishedStates(); - broadcastModel(); - if (LOG.isDebugEnabled()) - LOG.debug("Global parameter is broadcasted successfully."); - } - break; - } - case ASP: { - broadcastModel(grad._workerID); - break; - } - default: - throw new DMLRuntimeException("Unsupported update: " + _updateType.name()); - } - } - catch (Exception e) { - throw new DMLRuntimeException("Aggregation service failed: ", e); - } - return null; - } + private void setFinishedState(int workerID) { + _finishedStates[workerID] = true; + } + + private void broadcastModel() throws InterruptedException { + Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; - private ListObject updateModel(ListObject gradients, ListObject model) { - return updateModel(_ec, gradients, model); - } + //broadcast copy of the model to all workers, cleaned up by workers + for (BlockingQueue<ListObject> q : _modelMap.values()) + q.put(ParamservUtils.copyList(_model)); - /** - * A service method for updating model with gradients - */ - private ListObject updateModel(ExecutionContext ec, ListObject gradients, ListObject model) { - // Populate the variables table with the gradients and model - ec.setVariable(Statement.PS_GRADIENTS, gradients); - ec.setVariable(Statement.PS_MODEL, model); + if (DMLScript.STATISTICS) + Statistics.accPSModelBroadcastTime((long) tBroad.stop()); + } - // Invoke the aggregate function - _inst.processInstruction(ec); + private void broadcastModel(int workerID) throws InterruptedException { + Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; - // Get the output - ListObject newModel = (ListObject) ec.getVariable(_output.getName()); + //broadcast copy of model to specific worker, cleaned up by worker + _modelMap.get(workerID).put(ParamservUtils.copyList(_model)); - // Update the model with the new output - ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL); - ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS); - return newModel; - } + if (DMLScript.STATISTICS) + Statistics.accPSModelBroadcastTime((long) tBroad.stop()); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e7fccd1c/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index be80127..25bc113 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -160,8 +160,6 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e); } finally { es.shutdownNow(); - // Should shutdown the thread pool in param server - ps.shutdown(); } }
