Repository: systemml
Updated Branches:
  refs/heads/master 2b86a4d92 -> d44b3280f


[SYSTEMML-2344,48,49,52] Various improvements to the local paramserv backend

Closes #777.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d44b3280
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d44b3280
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d44b3280

Branch: refs/heads/master
Commit: d44b3280f2132deb303e955ff5b9a17daac4c31e
Parents: 2b86a4d
Author: EdgarLGB <[email protected]>
Authored: Sun Jun 3 23:07:02 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jun 3 23:07:03 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/parser/Statement.java   |   8 +-
 .../controlprogram/paramserv/LocalPSWorker.java   |  99 ++---
 .../paramserv/LocalParamServer.java               |  40 +-
 .../controlprogram/paramserv/PSWorker.java        | 113 +++---
 .../controlprogram/paramserv/ParamServer.java     | 227 +++++------
 .../paramserv/ParamservUtils.java                 |  12 +-
 .../cp/ParamservBuiltinCPInstruction.java         | 260 +++++++++----
 .../functions/paramserv/ParamservFuncTest.java    |  39 +-
 .../paramserv/mnist_lenet_paramserv.dml           |   1 -
 .../paramserv/mnist_lenet_paramserv_asp.dml       | 376 +++++++++++++++++++
 .../mnist_lenet_paramserv_minimum_version.dml     |   1 -
 .../paramserv/paramserv-agg-service-failed.dml    |  53 +++
 .../paramserv/paramserv-large-parallelism.dml     |  52 +++
 .../functions/paramserv/paramserv-nn-asp.dml      |  52 +++
 .../paramserv/paramserv-worker-failed.dml         |  53 +++
 .../paramserv-wrong-aggregate-func.dml            |  50 +++
 16 files changed, 1117 insertions(+), 319 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/parser/Statement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Statement.java b/src/main/java/org/apache/sysml/parser/Statement.java
index 1987d31..d49eada 100644
--- a/src/main/java/org/apache/sysml/parser/Statement.java
+++ b/src/main/java/org/apache/sysml/parser/Statement.java
@@ -77,7 +77,13 @@ public abstract class Statement implements ParseInfo
 	}
 	public static final String PS_UPDATE_TYPE = "utype";
 	public enum PSUpdateType {
-		BSP, ASP, SSP
+		BSP, ASP, SSP;
+		public boolean isBSP() {
+			return this == BSP;
+		}
+		public boolean isASP() {
+			return this == ASP;
+		}
 	}
 	public static final String PS_FREQUENCY = "freq";
 	public enum PSFrequency {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index 181b866..e902aea 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -19,79 +19,82 @@
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
+import java.util.concurrent.Callable;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.cp.ListObject; -public class LocalPSWorker extends PSWorker implements Runnable { +public class LocalPSWorker extends PSWorker implements Callable<Void> { protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName()); - public LocalPSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, - ListObject hyperParams, ExecutionContext ec, ParamServer ps) { - super(workerID, updFunc, freq, epochs, batchSize, hyperParams, ec, ps); + public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, + ExecutionContext ec, ParamServer ps) { + super(workerID, updFunc, freq, epochs, batchSize, ec, ps); } @Override - public void run() { + public Void call() throws Exception { + try { + long dataSize = _features.getNumRows(); + for (int i = 0; i < _epochs; i++) { + int totalIter = (int) Math.ceil(dataSize / _batchSize); + for (int j = 0; j < totalIter; j++) { + // Pull the global parameters from ps + ListObject globalParams = (ListObject)_ps.pull(_workerID); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters " + + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024)); + } + _ec.setVariable(Statement.PS_MODEL, globalParams); - long dataSize = _features.getNumRows(); + long begin = j * _batchSize + 1; + long end = Math.min(begin + _batchSize, dataSize); - for (int i = 0; i < _epochs; i++) { - int totalIter = (int) Math.ceil(dataSize / _batchSize); - for (int j = 0; j < totalIter; j++) { - // Pull the global parameters from ps - // Need to copy the global parameter - ListObject globalParams = ParamservUtils.copyList((ListObject) _ps.pull(_workerID)); - if (LOG.isDebugEnabled()) { - LOG.debug(String.format( - "Local worker_%d: Successfully pull the global parameters [size:%d kb] from ps.", _workerID, - globalParams.getDataSize() / 1024)); - } - _ec.setVariable(Statement.PS_MODEL, globalParams); + // Get batch features and labels + MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); + MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); + _ec.setVariable(Statement.PS_FEATURES, bFeatures); + _ec.setVariable(Statement.PS_LABELS, bLabels); - long begin = j * _batchSize + 1; - long end = Math.min(begin + _batchSize, dataSize); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. " + + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, bFeatures.getDataSize() + / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter)); + } - // Get batch features and labels - MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); - MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); - _ec.setVariable(Statement.PS_FEATURES, bFeatures); - _ec.setVariable(Statement.PS_LABELS, bLabels); + // Invoke the update function + _inst.processInstruction(_ec); - if (LOG.isDebugEnabled()) { - LOG.debug(String.format( - "Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. 
[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", - _workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, - _epochs, j + 1, totalIter)); - } + // Get the gradients + ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); - // Invoke the update function - _inst.processInstruction(_ec); + // Push the gradients to ps + _ps.push(_workerID, gradients); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Successfully push the gradients " + + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); + } - // Get the gradients - ListObject gradients = (ListObject) _ec.getVariable(_outputs.get(0).getName()); - - // Push the gradients to ps - _ps.push(_workerID, gradients); + ParamservUtils.cleanupListObject(_ec, globalParams); + ParamservUtils.cleanupData(bFeatures); + ParamservUtils.cleanupData(bLabels); + } if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Successfully push the gradients [size:%d kb] to ps.", - _workerID, gradients.getDataSize() / 1024)); + LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); } - - ParamservUtils.cleanupListObject(_ec, globalParams); - ParamservUtils.cleanupData(bFeatures); - ParamservUtils.cleanupData(bLabels); } if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); + LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); } + } catch (Exception e) { + throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e); } - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); - } + return null; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java index d060a91..665395e 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -19,6 +19,8 @@ package org.apache.sysml.runtime.controlprogram.paramserv; +import java.util.concurrent.ExecutionException; + import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; @@ -28,32 +30,32 @@ import org.apache.sysml.runtime.instructions.cp.ListObject; public class LocalParamServer extends ParamServer { public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, - Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum, - ListObject hyperParams) { - super(model, aggFunc, freq, updateType, ec, workerNum, hyperParams); + Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { + super(model, aggFunc, freq, updateType, ec, workerNum); } @Override - public void push(long workerID, ListObject gradients) { - synchronized (_lock) { - _queue.add(new Gradient(workerID, gradients)); - _lock.notifyAll(); + public void push(int workerID, ListObject gradients) { + try { + _gradientsQueue.put(new Gradient(workerID, gradients)); + } catch (InterruptedException e) { + throw new DMLRuntimeException(e); + } + try { + 
launchService(); + } catch (ExecutionException | InterruptedException e) { + throw new DMLRuntimeException("Aggregate service: some error occurred: ", e); } } @Override - public Data pull(long workerID) { - synchronized (_lock) { - while (getPulledState((int) workerID)) { - try { - _lock.wait(); - } catch (InterruptedException e) { - throw new DMLRuntimeException( - String.format("Local worker_%d: failed to pull the global parameters.", workerID), e); - } - } - setPulledState((int) workerID, true); + public Data pull(int workerID) { + ListObject model; + try { + model = _modelMap.get((int) workerID).take(); + } catch (InterruptedException e) { + throw new DMLRuntimeException(e); } - return getResult(); + return model; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index 9ace823..56fda22 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -30,102 +30,99 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; -import org.apache.sysml.runtime.instructions.cp.ListObject; @SuppressWarnings("unused") public abstract class PSWorker { - long _workerID = -1; - int _epochs; - long _batchSize; - MatrixObject _features; - MatrixObject _labels; - ExecutionContext _ec; - ParamServer _ps; - private String _updFunc; - private Statement.PSFrequency _freq; + protected final int _workerID; + protected final int _epochs; + protected final long _batchSize; + protected final ExecutionContext _ec; + protected final ParamServer _ps; + protected final DataIdentifier _output; + protected final FunctionCallCPInstruction _inst; + protected MatrixObject _features; + protected MatrixObject _labels; + private MatrixObject _valFeatures; private MatrixObject _valLabels; - - ArrayList<DataIdentifier> _outputs; - FunctionCallCPInstruction _inst; - - public PSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, - ListObject hyperParams, ExecutionContext ec, ParamServer ps) { - this._workerID = workerID; - this._updFunc = updFunc; - this._freq = freq; - this._epochs = epochs; - this._batchSize = batchSize; - this._ec = ExecutionContextFactory.createContext(ec.getProgram()); - if (hyperParams != null) { - this._ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams); - } - this._ps = ps; + private final String _updFunc; + private final Statement.PSFrequency _freq; + + protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, + int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { + _workerID = workerID; + _updFunc = updFunc; + _freq = freq; + _epochs = epochs; + _batchSize = batchSize; + _ec = ec; + _ps = ps; // Get the update function String[] keys = 
DMLProgram.splitFunctionKey(updFunc); - String _funcName = keys[0]; - String _funcNS = null; + String funcName = keys[0]; + String funcNS = null; if (keys.length == 2) { - _funcNS = keys[0]; - _funcName = keys[1]; + funcNS = keys[0]; + funcName = keys[1]; } - FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(_funcNS, _funcName); - ArrayList<DataIdentifier> _inputs = func.getInputParams(); - _outputs = func.getOutputParams(); - CPOperand[] _boundInputs = _inputs.stream() + FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(funcNS, funcName); + ArrayList<DataIdentifier> inputs = func.getInputParams(); + ArrayList<DataIdentifier> outputs = func.getOutputParams(); + CPOperand[] boundInputs = inputs.stream() .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) .toArray(CPOperand[]::new); - ArrayList<String> _inputNames = _inputs.stream().map(DataIdentifier::getName) + ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName) .collect(Collectors.toCollection(ArrayList::new)); - ArrayList<String> _outputNames = _outputs.stream().map(DataIdentifier::getName) + ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) .collect(Collectors.toCollection(ArrayList::new)); - _inst = new FunctionCallCPInstruction(_funcNS, _funcName, _boundInputs, _inputNames, _outputNames, + _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, "update function"); // Check the inputs of the update function - checkInput(_inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES); - checkInput(_inputs, Expression.DataType.MATRIX, Statement.PS_LABELS); - checkInput(_inputs, Expression.DataType.LIST, Statement.PS_MODEL); - if (hyperParams != null) { - checkInput(_inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS); - } + checkInput(false, inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES); + checkInput(false, inputs, Expression.DataType.MATRIX, Statement.PS_LABELS); + checkInput(false, inputs, Expression.DataType.LIST, Statement.PS_MODEL); + checkInput(true, inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS); // Check the output of the update function - if (_outputs.size() != 1) { - throw new DMLRuntimeException( - String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc)); + if (outputs.size() != 1) { + throw new DMLRuntimeException(String.format("The output of the '%s' function " + + "should provide one list containing the gradients.", updFunc)); } - if (_outputs.get(0).getDataType() != Expression.DataType.LIST) { - throw new DMLRuntimeException( - String.format("The output of the '%s' function should be type of list.", updFunc)); + if (outputs.get(0).getDataType() != Expression.DataType.LIST) { + throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", updFunc)); } + _output = outputs.get(0); } - private void checkInput(ArrayList<DataIdentifier> _inputs, Expression.DataType dt, String pname) { - if (_inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) { - throw new DMLRuntimeException( - String.format("The '%s' function should provide an input of '%s' type named '%s'.", _updFunc, dt, pname)); + private void checkInput(boolean optional, ArrayList<DataIdentifier> inputs, Expression.DataType dt, String pname) { + if (optional && inputs.stream().noneMatch(input -> pname.equals(input.getName()))) { + // We 
do not need to check if the input is optional and is not provided + return; + } + if (inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) { + throw new DMLRuntimeException(String.format("The '%s' function should provide " + + "an input of '%s' type named '%s'.", _updFunc, dt, pname)); } } public void setFeatures(MatrixObject features) { - this._features = features; + _features = features; } public void setLabels(MatrixObject labels) { - this._labels = labels; + _labels = labels; } public void setValFeatures(MatrixObject valFeatures) { - this._valFeatures = valFeatures; + _valFeatures = valFeatures; } public void setValLabels(MatrixObject valLabels) { - this._valLabels = valLabels; + _valLabels = valLabels; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index 6e1cd13..1052390 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -21,9 +21,17 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import java.util.ArrayList; import java.util.Arrays; -import java.util.Queue; -import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingDeque; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.logging.Log; @@ -35,98 +43,83 @@ import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; import org.apache.sysml.runtime.instructions.cp.ListObject; public abstract class ParamServer { + + final BlockingQueue<Gradient> _gradientsQueue; + final Map<Integer, BlockingQueue<ListObject>> _modelMap; + private final AggregationService _aggService; + private final ExecutorService _es; + private ListObject _model; - public class Gradient { - final long _workerID; - final ListObject _gradients; - - public Gradient(long workerID, ListObject gradients) { - this._workerID = workerID; - this._gradients = gradients; + ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType, + ExecutionContext ec, int workerNum) { + _gradientsQueue = new LinkedBlockingDeque<>(); + _modelMap = new HashMap<>(workerNum); + IntStream.range(0, workerNum).forEach(i -> { + // Create a single element blocking queue for workers to receive the broadcasted model + _modelMap.put(i, new 
ArrayBlockingQueue<>(1)); + }); + _model = model; + _aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum); + try { + _aggService.broadcastModel(); } + catch (InterruptedException e) { + throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e); + } + _es = Executors.newSingleThreadExecutor(); } - Queue<Gradient> _queue; - final Object _lock = new Object(); - private ListObject _model; - private AggregationService _aggService; - private Thread _aggThread; - private boolean[] _pulledStates; - - protected ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, - Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum, ListObject hyperParams) { - this._queue = new ConcurrentLinkedQueue<>(); - this._model = model; - this._aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum, hyperParams); - this._pulledStates = new boolean[workerNum]; - this._aggThread = new Thread(_aggService); - } - - public abstract void push(long workerID, ListObject value); + public abstract void push(int workerID, ListObject value); - public abstract Data pull(long workerID); + public abstract Data pull(int workerID); - public void start() { - _aggService._alive = true; - _aggThread.start(); + void launchService() throws ExecutionException, InterruptedException { + _es.submit(_aggService).get(); } - public void stop() { - _aggService._alive = false; - try { - _aggThread.join(); - } catch (InterruptedException e) { - throw new DMLRuntimeException("Parameter server: failed when stopping the server.", e); - } + public void shutdown() { + _es.shutdownNow(); } public ListObject getResult() { return _model; } - public boolean getPulledState(int workerID) { - return _pulledStates[workerID]; - } - - public void setPulledState(int workerID, boolean state) { - _pulledStates[workerID] = state; - } + public static class Gradient { + final int _workerID; + final ListObject _gradients; - private void resetPulledStates() { - _pulledStates = new boolean[_pulledStates.length]; + public Gradient(int workerID, ListObject gradients) { + _workerID = workerID; + _gradients = gradients; + } } - + /** * Inner aggregation service which is for updating the model */ - @SuppressWarnings("unused") - private class AggregationService implements Runnable { + private class AggregationService implements Callable<Void> { protected final Log LOG = LogFactory.getLog(AggregationService.class.getName()); protected ExecutionContext _ec; - private Statement.PSFrequency _freq; + //private Statement.PSFrequency _freq; private Statement.PSUpdateType _updateType; private FunctionCallCPInstruction _inst; private DataIdentifier _output; - private boolean _alive; private boolean[] _finishedStates; // Workers' finished states AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType, - ExecutionContext ec, int workerNum, ListObject hyperParams) { - _ec = ExecutionContextFactory.createContext(ec.getProgram()); - _freq = freq; + ExecutionContext ec, int workerNum) { + _ec = ec; + //_freq = freq; _updateType = updateType; - if (hyperParams != null) { - _ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams); - } _finishedStates = new boolean[workerNum]; // Fetch the aggregation function @@ -143,13 +136,11 @@ public abstract class ParamServer { // Check the output of the aggregation function if (outputs.size() != 1) { - throw new DMLRuntimeException(String.format( - "The output of the '%s' function should provide one list 
containing the updated model.", + throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", aggFunc)); } if (outputs.get(0).getDataType() != Expression.DataType.LIST) { - throw new DMLRuntimeException( - String.format("The output of the '%s' function should be type of list.", aggFunc)); + throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", aggFunc)); } _output = outputs.get(0); @@ -160,12 +151,7 @@ public abstract class ParamServer { .collect(Collectors.toCollection(ArrayList::new)); ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) .collect(Collectors.toCollection(ArrayList::new)); - _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, - "aggregate function"); - } - - boolean isAlive() { - return _alive; + _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, "aggregate function"); } private boolean allFinished() { @@ -180,53 +166,76 @@ public abstract class ParamServer { _finishedStates[workerID] = true; } - @Override - public void run() { - synchronized (_lock) { - while (isAlive()) { - do { - while (_queue.isEmpty()) { - try { - _lock.wait(); - } catch (InterruptedException e) { - throw new DMLRuntimeException( - "Aggregation service: error when waiting for the coming gradients.", e); - } - } - Gradient p = _queue.remove(); - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", - p._gradients.getDataSize() / 1024, p._workerID)); - } - - setFinishedState((int) p._workerID); + private void broadcastModel() throws InterruptedException { + //broadcast copy of the model to all workers, cleaned up by workers + for (BlockingQueue<ListObject> q : _modelMap.values()) + q.put(ParamservUtils.copyList(_model)); + } + + private void broadcastModel(int workerID) throws InterruptedException { + //broadcast copy of model to specific worker, cleaned up by worker + _modelMap.get(workerID).put(ParamservUtils.copyList(_model)); + } - // Populate the variables table with the gradients and model - _ec.setVariable(Statement.PS_GRADIENTS, p._gradients); - _ec.setVariable(Statement.PS_MODEL, _model); + @Override + public Void call() throws Exception { + try { + Gradient grad; + try { + grad = _gradientsQueue.take(); + } catch (InterruptedException e) { + throw new DMLRuntimeException("Aggregation service: error when waiting for the coming gradients.", e); + } + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", + grad._gradients.getDataSize() / 1024, grad._workerID)); + } - // Invoke the aggregate function - _inst.processInstruction(_ec); + // Update and redistribute the model + updateModel(grad); + + // Redistribute model according to update type + switch( _updateType ) { + case BSP: { + setFinishedState(grad._workerID); + if( allFinished() ) { + // Broadcast the updated model + resetFinishedStates(); + broadcastModel(); + if (LOG.isDebugEnabled()) + LOG.debug("Global parameter is broadcasted successfully."); + } + break; + } + case ASP: { + broadcastModel(grad._workerID); + break; + } + default: + throw new DMLRuntimeException("Unsupported update: "+_updateType.name()); + } + } + catch (Exception e) { + throw new DMLRuntimeException("Aggregation service failed: ", e); + } + return null; + } - // Get the output - ListObject newModel = (ListObject) 
_ec.getVariable(_output.getName()); + private void updateModel(Gradient grad) throws InterruptedException { + // Populate the variables table with the gradients and model + _ec.setVariable(Statement.PS_GRADIENTS, grad._gradients); + _ec.setVariable(Statement.PS_MODEL, _model); - // Update the model with the new output - ParamservUtils.cleanupListObject(_ec, _model); - ParamservUtils.cleanupListObject(_ec, p._gradients); - _model = newModel; + // Invoke the aggregate function + _inst.processInstruction(_ec); - } while (!allFinished()); + // Get the output + ListObject newModel = (ListObject) _ec.getVariable(_output.getName()); - // notify all the workers to get the updated model - resetPulledStates(); - resetFinishedStates(); - _lock.notifyAll(); - if (LOG.isDebugEnabled()) { - LOG.debug("Global parameter is broadcasted successfully."); - } - } - } + // Update the model with the new output + ParamservUtils.cleanupListObject(_ec, _model); + ParamservUtils.cleanupListObject(_ec, grad._gradients); + _model = newModel; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index 54c5d6c..426a7fe 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -22,6 +22,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.sysml.parser.Expression; import org.apache.sysml.runtime.DMLRuntimeException; @@ -49,8 +50,8 @@ public class ParamservUtils { if (lo.getLength() == 0) { return lo; } - List<Data> newData = lo.getNames().stream().map(name -> { - Data oldData = lo.slice(name); + List<Data> newData = IntStream.range(0, lo.getLength()).mapToObj(i -> { + Data oldData = lo.slice(i); if (oldData instanceof MatrixObject) { MatrixObject mo = (MatrixObject) oldData; return sliceMatrix(mo, 1, mo.getNumRows()); @@ -69,7 +70,7 @@ public class ParamservUtils { } public static void cleanupData(Data data) { - if( !(data instanceof CacheableData) ) + if (!(data instanceof CacheableData)) return; CacheableData<?> cd = (CacheableData<?>) data; cd.enableCleanup(true); @@ -78,6 +79,7 @@ public class ParamservUtils { /** * Slice the matrix + * * @param mo input matrix * @param rl low boundary * @param rh high boundary @@ -88,10 +90,10 @@ public class ParamservUtils { new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); MatrixBlock tmp = mo.acquireRead(); - result.acquireModify(tmp.slice((int)rl-1, (int)rh-1, 0, - tmp.getNumColumns()-1, new MatrixBlock())); + result.acquireModify(tmp.slice((int) rl - 1, (int) rh - 1, 0, tmp.getNumColumns() - 1, new MatrixBlock())); mo.release(); result.release(); + result.enableCleanup(false); return result; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 3ab0fc8..79d1ff3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -39,30 +39,49 @@ import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE; import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES; import static org.apache.sysml.parser.Statement.PS_VAL_LABELS; +import java.io.IOException; import java.util.ArrayList; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.log4j.Level; import org.apache.log4j.Logger; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DMLTranslator; +import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.ForProgramBlock; +import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysml.runtime.controlprogram.IfProgramBlock; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; +import org.apache.sysml.runtime.controlprogram.Program; +import org.apache.sysml.runtime.controlprogram.ProgramBlock; +import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.matrix.operators.Operator; -import org.apache.sysml.utils.NativeHelper; public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { private static final int DEFAULT_BATCH_SIZE = 64; private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH; - private static final int DEFAULT_LEVEL_PARALLELISM = InfrastructureAnalyzer.getLocalParallelism(); private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS; //internal local debug level @@ -71,79 +90,178 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc static { // for internal debugging only if (LDEBUG) { - Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel((Level) Level.DEBUG); + Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG); } } - protected ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, - String opcode, String 
istr) { + public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) { super(op, paramsMap, out, opcode, istr); } @Override public void processInstruction(ExecutionContext ec) { - PSModeType mode = PSModeType.valueOf(getParam(PS_MODE)); int workerNum = getWorkerNum(mode); + ExecutorService es = Executors.newFixedThreadPool(workerNum); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); + + // Create the workers' execution context + int k = getParLevel(workerNum); + List<ExecutionContext> workerECs = createExecutionContext(ec, updFunc, workerNum, k); + + // Create the agg service's execution context + ExecutionContext aggServiceEC = createExecutionContext(ec, aggFunc, 1, 1).get(0); + PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); - int epochs = Integer.valueOf(getParam(PS_EPOCHS)); - if (epochs <= 0) { - throw new DMLRuntimeException( - String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.", - PS_EPOCHS)); - } - long batchSize = getBatchSize(); + int epochs = getEpochs(); // Create the parameter server ListObject model = ec.getListObject(getParam(PS_MODEL)); - ListObject hyperParams = getHyperParams(ec); - ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, ec, hyperParams); + ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, aggServiceEC); // Create the local workers List<LocalPSWorker> workers = IntStream.range(0, workerNum) - .mapToObj(i -> new LocalPSWorker((long) i, updFunc, freq, epochs, batchSize, hyperParams, ec, ps)) - .collect(Collectors.toList()); + .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), workerECs.get(i), ps)) + .collect(Collectors.toList()); // Do data partition doDataPartition(ec, workers); - // Create the worker threads - List<Thread> threads = workers.stream().map(Thread::new).collect(Collectors.toList()); + // Launch the worker threads and wait for completion + try { + for( Future<Void> ret : es.invokeAll(workers) ) + ret.get(); //error handling + } catch (InterruptedException | ExecutionException e) { + throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e); + } finally { + es.shutdownNow(); + } - // Start the ps - ps.start(); + // Fetch the final model from ps + ListObject result; + result = ps.getResult(); + ec.setVariable(output.getName(), result); + } - // Start the workers - threads.forEach(Thread::start); + private int getEpochs() { + int epochs = Integer.valueOf(getParam(PS_EPOCHS)); + if (epochs <= 0) { + throw new DMLRuntimeException(String.format("Paramserv function: " + + "The argument '%s' could not be less than or equal to 0.", PS_EPOCHS)); + } + return epochs; + } + + private int getParLevel(int workerNum) { + return Math.max((int)Math.ceil((double)getRemainingCores()/workerNum), 1); + } + + private List<ExecutionContext> createExecutionContext(ExecutionContext ec, String funcName, int workerNum, int k) { + // Fetch the target function + String[] keys = DMLProgram.splitFunctionKey(funcName); + String namespace = null; + String func = keys[0]; + if (keys.length == 2) { + namespace = keys[0]; + func = keys[1]; + } + return createExecutionContext(ec, namespace, func, workerNum, k); + } - // Wait for the workers stopping - threads.forEach(thread -> { + private List<ExecutionContext> createExecutionContext(ExecutionContext ec, String namespace, String func, + int 
workerNum, int k) { + FunctionProgramBlock targetFunc = ec.getProgram().getFunctionProgramBlock(namespace, func); + return IntStream.range(0, workerNum).mapToObj(i -> { + // Put the hyperparam into the variables table + LocalVariableMap varsMap = new LocalVariableMap(); + ListObject hyperParams = getHyperParams(ec); + if (hyperParams != null) { + varsMap.put(PS_HYPER_PARAMS, hyperParams); + } + + // Deep copy the target func + FunctionProgramBlock copiedFunc = ProgramConverter + .createDeepCopyFunctionProgramBlock(targetFunc, new HashSet<>(), new HashSet<>()); + + // Reset the visit status from root + for( ProgramBlock pb : copiedFunc.getChildBlocks() ) + DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock()); + + // Should recursively assign the level of parallelism + // and recompile the program block try { - thread.join(); - } catch (InterruptedException e) { - throw new DMLRuntimeException("Paramserv function: Failed to join the worker threads.", e); + rAssignParallelism(copiedFunc.getChildBlocks(), k, false); + } catch (IOException e) { + throw new DMLRuntimeException(e); } - }); - ps.stop(); + Program prog = new Program(); + prog.addProgramBlock(copiedFunc); + prog.addFunctionProgramBlock(namespace, func, copiedFunc); + return ExecutionContextFactory.createContext(varsMap, prog); - // Create the output - ListObject result = ps.getResult(); - ec.setVariable(output.getName(), result); + }).collect(Collectors.toList()); + } + + private boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, int k, boolean recompiled) throws IOException { + for (ProgramBlock pb : pbs) { + if (pb instanceof ParForProgramBlock) { + ParForProgramBlock pfpb = (ParForProgramBlock) pb; + pfpb.setDegreeOfParallelism(k); + recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled); + } + else if (pb instanceof ForProgramBlock) { + recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled); + } + else if (pb instanceof WhileProgramBlock) { + recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled); + } + else if (pb instanceof FunctionProgramBlock) { + recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled); + } + else if (pb instanceof IfProgramBlock) { + IfProgramBlock ipb = (IfProgramBlock) pb; + recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled); + if( ipb.getChildBlocksElseBody() != null ) + recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled); + } + else { + StatementBlock sb = pb.getStatementBlock(); + for (Hop hop : sb.getHops()) + recompiled |= rAssignParallelism(hop, k, recompiled); + } + // Recompile the program block + if (recompiled) { + Recompiler.recompileProgramBlockInstructions(pb); + } + } + return recompiled; + } + + private boolean rAssignParallelism(Hop hop, int k, boolean recompiled) { + if (hop.isVisited()) { + return recompiled; + } + if (hop instanceof Hop.MultiThreadedHop) { + // Reassign the level of parallelism + Hop.MultiThreadedHop mhop = (Hop.MultiThreadedHop) hop; + mhop.setMaxNumThreads(k); + recompiled = true; + } + ArrayList<Hop> inputs = hop.getInput(); + for (Hop h : inputs) { + recompiled |= rAssignParallelism(h, k, recompiled); + } + hop.setVisited(); + return recompiled; } private PSUpdateType getUpdateType() { PSUpdateType updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE)); - switch (updType) { - case ASP: - case SSP: + if( updType == PSUpdateType.SSP ) throw new 
DMLRuntimeException(String.format("Not support update type '%s'.", updType)); - case BSP: - break; - } return updType; } @@ -152,15 +270,15 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc return DEFAULT_UPDATE_FREQUENCY; } PSFrequency freq = PSFrequency.valueOf(getParam(PS_FREQUENCY)); - switch (freq) { - case EPOCH: + if( freq == PSFrequency.EPOCH ) throw new DMLRuntimeException("Not support epoch update frequency."); - case BATCH: - break; - } return freq; } + private int getRemainingCores() { + return InfrastructureAnalyzer.getLocalParallelism() - 1; + } + /** * Get the worker numbers according to the vcores * @@ -168,24 +286,17 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc * @return worker numbers */ private int getWorkerNum(PSModeType mode) { - int workerNum = DEFAULT_LEVEL_PARALLELISM; - if (getParameterMap().containsKey(PS_PARALLELISM)) { - workerNum = Integer.valueOf(getParam(PS_PARALLELISM)); - } + int workerNum = -1; switch (mode) { - case LOCAL: - //FIXME: this is a workaround for a maximum number of buffers in openblas - //However, the root cause is a missing function preparation for each worker - //(i.e., deep copy with unique file names, and reduced degree of parallelism) - int vcores = InfrastructureAnalyzer.getLocalParallelism(); - if ("openblas".equals(NativeHelper.getCurrentBLAS())) { - workerNum = Math.min(workerNum, vcores / 2); - } else { - workerNum = Math.min(workerNum, vcores); - } - break; - case REMOTE_SPARK: - throw new DMLRuntimeException("Do not support remote spark."); + case LOCAL: + // default worker number: available cores - 1 (assign one process for agg service) + workerNum = getRemainingCores(); + if (getParameterMap().containsKey(PS_PARALLELISM)) { + workerNum = Math.min(workerNum, Integer.valueOf(getParam(PS_PARALLELISM))); + } + break; + case REMOTE_SPARK: + throw new DMLRuntimeException("Do not support remote spark."); } return workerNum; } @@ -196,14 +307,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc * @return parameter server */ private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType, - int workerNum, ListObject model, ExecutionContext ec, ListObject hyperParams) { + int workerNum, ListObject model, ExecutionContext ec) { ParamServer ps = null; switch (mode) { - case LOCAL: - ps = new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum, hyperParams); - break; - case REMOTE_SPARK: - throw new DMLRuntimeException("Do not support remote spark."); + case LOCAL: + ps = new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum); + break; + case REMOTE_SPARK: + throw new DMLRuntimeException("Do not support remote spark."); } return ps; } @@ -214,9 +325,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } long batchSize = Integer.valueOf(getParam(PS_BATCH_SIZE)); if (batchSize <= 0) { - throw new DMLRuntimeException(String.format( - "Paramserv function: the number of argument '%s' could not be less than or equal to 0.", - PS_BATCH_SIZE)); + throw new DMLRuntimeException(String.format("Paramserv function: the number " + + "of argument '%s' could not be less than or equal to 0.", PS_BATCH_SIZE)); } return batchSize; } @@ -245,8 +355,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc case DISJOINT_RANDOM: case OVERLAP_RESHUFFLE: case DISJOINT_ROUND_ROBIN: - throw new DMLRuntimeException( - 
String.format("Paramserv function: the scheme '%s' is not supported.", scheme)); + throw new DMLRuntimeException(String.format("Paramserv function: the scheme '%s' is not supported.", scheme)); } } @@ -256,9 +365,10 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc List<MatrixObject> pfs = disjointContiguous(workers.size(), features); List<MatrixObject> pls = disjointContiguous(workers.size(), labels); if (pfs.size() < workers.size()) { - LOG.warn(String.format( - "There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.", - pfs.size(), workers.size(), pfs.size())); + if (LOG.isWarnEnabled()) { + LOG.warn(String.format("There is only %d batches of data but has %d workers. " + + "Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size())); + } workers = workers.subList(0, pfs.size()); } for (int i = 0; i < workers.size(); i++) { http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java index 6370099..28b525a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java @@ -34,6 +34,11 @@ public class ParamservFuncTest extends AutomatedTestBase { private static final String TEST_NAME6 = "paramserv-wrong-args2"; private static final String TEST_NAME7 = "paramserv-nn-test"; private static final String TEST_NAME8 = "paramserv-minimum-version"; + private static final String TEST_NAME9 = "paramserv-worker-failed"; + private static final String TEST_NAME10 = "paramserv-agg-service-failed"; + private static final String TEST_NAME11 = "paramserv-large-parallelism"; + private static final String TEST_NAME12 = "paramserv-wrong-aggregate-func"; + private static final String TEST_NAME13 = "paramserv-nn-asp"; private static final String TEST_DIR = "functions/paramserv/"; private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/"; @@ -50,6 +55,11 @@ public class ParamservFuncTest extends AutomatedTestBase { addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {})); + addTestConfiguration(TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {})); + addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {})); + addTestConfiguration(TEST_NAME11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {})); + addTestConfiguration(TEST_NAME12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME12, new String[] {})); + addTestConfiguration(TEST_NAME13, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME13, new String[] {})); } @Test @@ -96,8 +106,33 @@ public class ParamservFuncTest extends AutomatedTestBase { runDMLTest(TEST_NAME8, false, null, null); } - private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, - String errmsg) { + 
@Test + public void testParamservWorkerFailedTest() { + runDMLTest(TEST_NAME9, true, DMLException.class, "Invalid lookup by name in unnamed list: worker_err."); + } + + @Test + public void testParamservAggServiceFailedTest() { + runDMLTest(TEST_NAME10, true, DMLException.class, "Invalid lookup by name in unnamed list: agg_service_err"); + } + + @Test + public void testParamservLargeParallelismTest() { + runDMLTest(TEST_NAME11, false, null, null); + } + + @Test + public void testParamservWrongAggregateFuncTest() { + runDMLTest(TEST_NAME12, true, DMLException.class, + "The 'gradients' function should provide an input of 'MATRIX' type named 'labels'."); + } + + @Test + public void testParamservASPTest() { + runDMLTest(TEST_NAME13, false, null, null); + } + + private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); programArgs = new String[] { "-explain" }; http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml index 2a3bbe2..4ea6e5f 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -208,7 +208,6 @@ gradients = function(matrix[double] features, gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4) } -# PB: how to handle the velocity? (put into the model) # Should use the arguments named 'model', 'gradients', 'hyperparams' # and return always a model of type list aggregation = function(list[unknown] model, http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml new file mode 100644 index 0000000..b2e155e --- /dev/null +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml @@ -0,0 +1,376 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +/* + * MNIST LeNet Example + */ +# Imports +source("nn/layers/affine.dml") as affine +source("nn/layers/conv2d_builtin.dml") as conv2d +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss +source("nn/layers/dropout.dml") as dropout +source("nn/layers/l2_reg.dml") as l2_reg +source("nn/layers/max_pool2d_builtin.dml") as max_pool2d +source("nn/layers/relu.dml") as relu +source("nn/layers/softmax.dml") as softmax +source("nn/optim/sgd_nesterov.dml") as sgd_nesterov + +train = function(matrix[double] X, matrix[double] Y, + matrix[double] X_val, matrix[double] Y_val, + int C, int Hin, int Win, int epochs, int workers) + return (matrix[double] W1, matrix[double] b1, + matrix[double] W2, matrix[double] b2, + matrix[double] W3, matrix[double] b3, + matrix[double] W4, matrix[double] b4) { + /* + * Trains a convolutional net using the "LeNet" architecture. + * + * The input matrix, X, has N examples, each represented as a 3D + * volume unrolled into a single vector. The targets, Y, have K + * classes, and are one-hot encoded. + * + * Inputs: + * - X: Input data matrix, of shape (N, C*Hin*Win). + * - Y: Target matrix, of shape (N, K). + * - X_val: Input validation data matrix, of shape (N, C*Hin*Win). + * - Y_val: Target validation matrix, of shape (N, K). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - epochs: Total number of full training loops over the full data set. + * + * Outputs: + * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). + * - b1: 1st layer biases vector, of shape (F1, 1). + * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). + * - b2: 2nd layer biases vector, of shape (F2, 1). + * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). + * - b3: 3rd layer biases vector, of shape (1, N3). + * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). + * - b4: 4th layer biases vector, of shape (1, K). 
+ */ + N = nrow(X) + K = ncol(Y) + + # Create network: + # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax + Hf = 5 # filter height + Wf = 5 # filter width + stride = 1 + pad = 2 # For same dimensions, (Hf - stride) / 2 + + F1 = 32 # num conv filters in conv1 + F2 = 64 # num conv filters in conv2 + N3 = 512 # num nodes in affine3 + # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes) + + [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win) + [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2)) + [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) + [W4, b4] = affine::init(N3, K) # inputs: (N, N3) + W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu + + # Initialize SGD w/ Nesterov momentum optimizer + lr = 0.01 # learning rate + mu = 0.9 #0.5 # momentum + decay = 0.95 # learning rate decay constant + vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1) + vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2) + vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3) + vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4) + + # Regularization + lambda = 5e-04 + + # Create the model object + modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) + + # Create the hyper parameter list + params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) + + # Use paramserv function + modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::aggregation", mode="LOCAL", utype="ASP", epochs=epochs, hyperparams=params) + + W1 = as.matrix(modelList2["W1"]) + b1 = as.matrix(modelList2["b1"]) + W2 = as.matrix(modelList2["W2"]) + b2 = as.matrix(modelList2["b2"]) + W3 = as.matrix(modelList2["W3"]) + b3 = as.matrix(modelList2["b3"]) + W4 = as.matrix(modelList2["W4"]) + b4 = as.matrix(modelList2["b4"]) + +} + +gradients = function(matrix[double] features, + matrix[double] labels, + list[unknown] hyperparams, + list[unknown] model) + return (list[unknown] gradients) { + + C = 1 + Hin = 28 + Win = 28 + Hf = 5 + Wf = 5 + stride = 1 + pad = 2 + lambda = 5e-04 + F1 = 32 + F2 = 64 + N3 = 512 + W1 = as.matrix(model["W1"]) + b1 = as.matrix(model["b1"]) + W2 = as.matrix(model["W2"]) + b2 = as.matrix(model["b2"]) + W3 = as.matrix(model["W3"]) + b3 = as.matrix(model["b3"]) + W4 = as.matrix(model["W4"]) + b4 = as.matrix(model["b4"]) + + # Compute forward pass + ## layer 1: conv1 -> relu1 -> pool1 + [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, + stride, stride, pad, pad) + outr1 = relu::forward(outc1) + [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 2: conv2 -> relu2 -> pool2 + [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, + stride, stride, pad, pad) + outr2 = relu::forward(outc2) + [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 3: affine3 -> relu3 -> dropout + outa3 = affine::forward(outp2, W3, b3) + outr3 
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4: affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4: affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+  ## layer 3: affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+            return (list[unknown] modelResult) {
+
+  W1 = as.matrix(model["W1"])
+  W2 = as.matrix(model["W2"])
+  W3 = as.matrix(model["W3"])
+  W4 = as.matrix(model["W4"])
+  b1 = as.matrix(model["b1"])
+  b2 = as.matrix(model["b2"])
+  b3 = as.matrix(model["b3"])
+  b4 = as.matrix(model["b4"])
+  dW1 = as.matrix(gradients["dW1"])
+  dW2 = as.matrix(gradients["dW2"])
+  dW3 = as.matrix(gradients["dW3"])
+  dW4 = as.matrix(gradients["dW4"])
+  db1 = as.matrix(gradients["db1"])
+  db2 = as.matrix(gradients["db2"])
+  db3 = as.matrix(gradients["db3"])
+  db4 = as.matrix(gradients["db4"])
+  vW1 = as.matrix(model["vW1"])
+  vW2 = as.matrix(model["vW2"])
+  vW3 = as.matrix(model["vW3"])
+  vW4 = as.matrix(model["vW4"])
+  vb1 = as.matrix(model["vb1"])
+  vb2 = as.matrix(model["vb2"])
+  vb3 = as.matrix(model["vb3"])
+  vb4 = as.matrix(model["vb4"])
+  lr = 0.01
+  mu = 0.9
+
+  # Optimize with SGD w/ Nesterov momentum
+  [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+  [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+  [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+  [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+  [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+  [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+  [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+  [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+  modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+  }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+        return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, pad=0, pad=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, pad=0, pad=0)
+    ## layer 3: affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4: affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+     return (double loss, double accuracy) {
+  /*
+   * Evaluates a convolutional net using the "LeNet" architecture.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples. The targets, Y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, Y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+  accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+                    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  Y = table(seq(1, N), classes)  # one-hot encoding
+}
+
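
The script above drives the same LeNet model as the existing BSP test script but requests asynchronous updates through utype="ASP", so the server applies each worker's gradients as they arrive instead of waiting for all workers in every step. For comparison, a minimal sketch (not part of this diff; the generic upd/agg names and the optional k argument are borrowed from the smaller test scripts below) of the synchronous variant of the same call:

  # Sketch only: synchronous (BSP) updates with an explicit degree of parallelism k
  modelList2 = paramserv(model=modelList, features=X, labels=Y,
                         val_features=X_val, val_labels=Y_val,
                         upd="gradients", agg="aggregation",
                         mode="LOCAL", utype="BSP", epochs=epochs,
                         hyperparams=params, k=2)
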
http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index 8811c36..707722e 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -204,7 +204,6 @@ gradients = function(matrix[double] features,
 
 }
 
-# how to handle the velocity?
 aggregation = function(list[unknown] model,
                        list[unknown] gradients,
                        list[unknown] hyperparams)
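
The removed question about velocity handling is answered in these scripts by carrying the optimizer state in the model list itself, as mnist_lenet_paramserv_asp.dml above does: the velocities enter the aggregation function inside the model list, are updated together with the weights, and are handed back to the server. A one-parameter sketch of that round trip (names as in the ASP script):

  # Sketch (single parameter shown): the Nesterov velocity travels with the model list
  vW1 = as.matrix(model["vW1"])
  [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
  modelResult = list(W1=W1, vW1=vW1)
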

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml b/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml
new file mode 100644
index 0000000..1edd237
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  gradients = model
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+            return (list[unknown] modelResult) {
+  modelResult = model
+  print(toString(as.matrix(gradients["agg_service_err"])))
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
+# Use paramserv function
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file
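
This test forces a failure inside the aggregation service: the gradients function simply returns the model list, which contains no entry named "agg_service_err", so the lookup gradients["agg_service_err"] in aggregation fails and the test expects that error to surface from the paramserv call. A non-failing counterpart would only touch entries that actually exist, for example (sketch only, not part of the diff):

  aggregation = function(list[unknown] model, list[unknown] gradients, list[unknown] hyperparams)
              return (list[unknown] modelResult) {
    modelResult = model  # pass the model through; no lookup of a missing key
  }
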

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
new file mode 100644
index 0000000..4e2d2b7
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 10
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
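
This script requests workers = 10, a deliberately large degree of parallelism for the local backend. The value is handed to mnist_lenet::train, which presumably forwards it as the k argument of paramserv (an assumption: train is defined in the sourced mnist_lenet_paramserv_minimum_version.dml and is not shown in this diff), roughly:

  # Assumed shape of the call inside mnist_lenet::train (not part of this diff)
  modelList2 = paramserv(model=modelList, features=X, labels=Y,
                         val_features=X_val, val_labels=Y_val,
                         upd="gradients", agg="aggregation",
                         mode="LOCAL", utype="BSP", epochs=epochs,
                         hyperparams=params, k=workers)
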

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
new file mode 100644
index 0000000..b50e17c
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml b/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml
new file mode 100644
index 0000000..7ccde60
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  gradients = model
+  print(toString(as.matrix(gradients["worker_err"])))
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+            return (list[unknown] modelResult) {
+  modelResult = model
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
+# Use paramserv function
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml
new file mode 100644
index 0000000..ca76850
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml
@@ -0,0 +1,50 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=2, cols=3)
+Y = matrix(2, rows=2, cols=3)
+X_val = matrix(3, rows=2, cols=3)
+Y_val = matrix(4, rows=2, cols=3)
+
+gradients = function(matrix[double] features,
+                     matrix[double] wrong_labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  gradients = model
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+            return (list[unknown] modelResult) {
+  modelResult = model
+}
+
+e2 = "element2"
+params = list(e2)
+
+# Use paramserv function
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", epochs=10, hyperparams=params)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file
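
The final script is the negative test for function validation: its gradients function declares its second input as wrong_labels instead of labels, so it does not match the input names used by the other scripts, and the test expects paramserv to reject it when validating the supplied functions. For reference, the shape used by every other script in this commit is:

  gradients = function(matrix[double] features,
                       matrix[double] labels,
                       list[unknown] hyperparams,
                       list[unknown] model)
            return (list[unknown] gradients) {
    gradients = model  # toy behavior from these tests; a real update function returns the computed gradients
  }
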
