Repository: systemml
Updated Branches:
  refs/heads/master 095781868 -> 51057e471
[SYSTEMML-2359] Additional paramserv update frequency: per-epoch Closes #780. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/51057e47 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/51057e47 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/51057e47 Branch: refs/heads/master Commit: 51057e4712d6ab9a190a9c1f8e9f36d48a8a1fd5 Parents: 0957818 Author: EdgarLGB <[email protected]> Authored: Mon Jun 4 21:25:53 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jun 4 22:02:23 2018 -0700 ---------------------------------------------------------------------- .../ParameterizedBuiltinFunctionExpression.java | 24 +- .../controlprogram/paramserv/LocalPSWorker.java | 139 +++++-- .../paramserv/LocalParamServer.java | 8 +- .../controlprogram/paramserv/PSWorker.java | 2 +- .../controlprogram/paramserv/ParamServer.java | 42 ++- .../cp/ParamservBuiltinCPInstruction.java | 61 ++- .../functions/paramserv/ParamservFuncTest.java | 22 +- .../paramserv/mnist_lenet_paramserv.dml | 4 +- .../paramserv/mnist_lenet_paramserv_asp.dml | 376 ------------------- .../mnist_lenet_paramserv_minimum_version.dml | 4 +- .../paramserv/paramserv-nn-asp-batch.dml | 52 +++ .../paramserv/paramserv-nn-asp-epoch.dml | 52 +++ .../functions/paramserv/paramserv-nn-asp.dml | 52 --- .../paramserv/paramserv-nn-bsp-batch.dml | 52 +++ .../paramserv/paramserv-nn-bsp-epoch.dml | 52 +++ .../functions/paramserv/paramserv-nn-test.dml | 52 --- .../paramserv/paramserv-wrong-args.dml | 19 +- 17 files changed, 425 insertions(+), 588 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java index 99aec78..33c5b0e 100644 --- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java @@ -334,22 +334,15 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkDataType(fname, Statement.PS_VAL_LABELS, DataType.MATRIX, conditional); checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, DataType.SCALAR, ValueType.STRING, conditional); checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, DataType.SCALAR, ValueType.STRING, conditional); - Set<String> modes = Arrays.stream(Statement.PSModeType.values()).map(Enum::name) - .collect(Collectors.toSet()); - checkStringParam(false, fname, Statement.PS_MODE, modes, conditional); - Set<String> utypes = Arrays.stream(Statement.PSUpdateType.values()).map(Enum::name) - .collect(Collectors.toSet()); - checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, utypes, conditional); - Set<String> frequencies = Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet()); - checkStringParam(true, fname, Statement.PS_FREQUENCY, frequencies, conditional); + checkStringParam(false, fname, Statement.PS_MODE, conditional); + checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, conditional); + checkStringParam(true, fname, Statement.PS_FREQUENCY, conditional); checkDataValueType(false, fname, Statement.PS_EPOCHS, 
DataType.SCALAR, ValueType.INT, conditional); checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional); checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional); - Set<String> schemes = Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet()); - checkStringParam(true, fname, Statement.PS_SCHEME, schemes, conditional); + checkStringParam(true, fname, Statement.PS_SCHEME, conditional); checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional); - Set<String> checkpointings = Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet()); - checkStringParam(true, fname, Statement.PS_CHECKPOINTING, checkpointings, conditional); + checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional); // set output characteristics output.setDataType(DataType.LIST); @@ -358,7 +351,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier output.setBlockDimensions(-1, -1); } - private void checkStringParam(boolean optional, String fname, String pname, Set<String> validOptions, boolean conditional) { + private void checkStringParam(boolean optional, String fname, String pname, boolean conditional) { Expression param = getVarParam(pname); if (param == null) { if (optional) { @@ -371,11 +364,6 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier String.format("Function %s should provide a string value for %s parameter.", fname, pname), conditional); } - StringIdentifier si = (StringIdentifier) param; - if (!validOptions.contains(si.getValue())) { - raiseValidateError(String.format("Function %s does not support value '%s' as the '%s' parameter.", fname, - si.getValue(), pname), conditional, LanguageErrorCodes.INVALID_PARAMETERS); - } } // example: A = transformapply(target=X, meta=M, spec=s) http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index e902aea..1583fbf 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -42,59 +42,118 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { public Void call() throws Exception { try { long dataSize = _features.getNumRows(); - for (int i = 0; i < _epochs; i++) { - int totalIter = (int) Math.ceil(dataSize / _batchSize); - for (int j = 0; j < totalIter; j++) { - // Pull the global parameters from ps - ListObject globalParams = (ListObject)_ps.pull(_workerID); - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters " - + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024)); - } - _ec.setVariable(Statement.PS_MODEL, globalParams); + int totalIter = (int) Math.ceil(dataSize / _batchSize); - long begin = j * _batchSize + 1; - long end = Math.min(begin + _batchSize, dataSize); + switch (_freq) { + case BATCH: + computeBatch(dataSize, totalIter); + break; + case EPOCH: + computeEpoch(dataSize, totalIter); + break; + } - // Get batch features and labels - 
MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); - MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); - _ec.setVariable(Statement.PS_FEATURES, bFeatures); - _ec.setVariable(Statement.PS_LABELS, bLabels); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); + } + } catch (Exception e) { + throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e); + } + return null; + } - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. " - + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, bFeatures.getDataSize() - / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter)); - } + private void computeEpoch(long dataSize, int totalIter) { + for (int i = 0; i < _epochs; i++) { + // Pull the global parameters from ps + ListObject globalParams = pullModel(); - // Invoke the update function - _inst.processInstruction(_ec); + for (int j = 0; j < totalIter; j++) { + _ec.setVariable(Statement.PS_MODEL, globalParams); - // Get the gradients - ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); + ListObject gradients = computeGradients(dataSize, totalIter, i, j); + if (j == totalIter - 1) { // Push the gradients to ps - _ps.push(_workerID, gradients); + pushGradients(gradients); + ParamservUtils.cleanupListObject(_ec, globalParams); + } else { + // Update the local model with gradients + globalParams = _ps.updateModel(gradients, globalParams); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Successfully push the gradients " - + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); + LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated.", + _workerID, globalParams.getDataSize())); } - - ParamservUtils.cleanupListObject(_ec, globalParams); - ParamservUtils.cleanupData(bFeatures); - ParamservUtils.cleanupData(bLabels); - } - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); } } if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); + LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); } - } catch (Exception e) { - throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e); } - return null; + + } + + private void computeBatch(long dataSize, int totalIter) { + for (int i = 0; i < _epochs; i++) { + for (int j = 0; j < totalIter; j++) { + ListObject globalParams = pullModel(); + + _ec.setVariable(Statement.PS_MODEL, globalParams); + ListObject gradients = computeGradients(dataSize, totalIter, i, j); + + // Push the gradients to ps + pushGradients(gradients); + + ParamservUtils.cleanupListObject(_ec, globalParams); + } + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); + } + } + } + + private ListObject pullModel() { + // Pull the global parameters from ps + ListObject globalParams = (ListObject)_ps.pull(_workerID); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters " + + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024)); + } + return globalParams; + } + + private void pushGradients(ListObject gradients) { + // Push the gradients to ps + _ps.push(_workerID, gradients); + if 
(LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Successfully push the gradients " + + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); + } + } + + private ListObject computeGradients(long dataSize, int totalIter, int i, int j) { + long begin = j * _batchSize + 1; + long end = Math.min(begin + _batchSize, dataSize); + + // Get batch features and labels + MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); + MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); + _ec.setVariable(Statement.PS_FEATURES, bFeatures); + _ec.setVariable(Statement.PS_LABELS, bLabels); + + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. " + + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, bFeatures.getDataSize() + / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter)); + } + + // Invoke the update function + _inst.processInstruction(_ec); + + // Get the gradients + ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); + + ParamservUtils.cleanupData(bFeatures); + ParamservUtils.cleanupData(bLabels); + return gradients; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java index 665395e..bac507c 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -29,9 +29,9 @@ import org.apache.sysml.runtime.instructions.cp.ListObject; public class LocalParamServer extends ParamServer { - public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, - Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { - super(model, aggFunc, freq, updateType, ec, workerNum); + public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, + int workerNum) { + super(model, aggFunc, updateType, ec, workerNum); } @Override @@ -52,7 +52,7 @@ public class LocalParamServer extends ParamServer { public Data pull(int workerID) { ListObject model; try { - model = _modelMap.get((int) workerID).take(); + model = _modelMap.get(workerID).take(); } catch (InterruptedException e) { throw new DMLRuntimeException(e); } http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index 56fda22..affa3c1 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -49,7 +49,7 @@ public abstract class PSWorker { private MatrixObject _valFeatures; private MatrixObject _valLabels; private final String _updFunc; - private final Statement.PSFrequency _freq; + protected final Statement.PSFrequency _freq; protected PSWorker(int 
workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index 1052390..d7cd78d 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -56,8 +56,7 @@ public abstract class ParamServer { private final ExecutorService _es; private ListObject _model; - ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType, - ExecutionContext ec, int workerNum) { + ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { _gradientsQueue = new LinkedBlockingDeque<>(); _modelMap = new HashMap<>(workerNum); IntStream.range(0, workerNum).forEach(i -> { @@ -65,7 +64,7 @@ public abstract class ParamServer { _modelMap.put(i, new ArrayBlockingQueue<>(1)); }); _model = model; - _aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum); + _aggService = new AggregationService(aggFunc, updateType, ec, workerNum); try { _aggService.broadcastModel(); } @@ -91,6 +90,10 @@ public abstract class ParamServer { return _model; } + public ListObject updateModel(ListObject gradients, ListObject model) { + return _aggService.updateModel(gradients, model); + } + public static class Gradient { final int _workerID; final ListObject _gradients; @@ -109,16 +112,13 @@ public abstract class ParamServer { protected final Log LOG = LogFactory.getLog(AggregationService.class.getName()); protected ExecutionContext _ec; - //private Statement.PSFrequency _freq; private Statement.PSUpdateType _updateType; private FunctionCallCPInstruction _inst; private DataIdentifier _output; private boolean[] _finishedStates; // Workers' finished states - AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType, - ExecutionContext ec, int workerNum) { + AggregationService(String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { _ec = ec; - //_freq = freq; _updateType = updateType; _finishedStates = new boolean[workerNum]; @@ -192,13 +192,13 @@ public abstract class ParamServer { } // Update and redistribute the model - updateModel(grad); - + _model = updateModel(grad._gradients, _model); + // Redistribute model according to update type - switch( _updateType ) { + switch(_updateType) { case BSP: { setFinishedState(grad._workerID); - if( allFinished() ) { + if (allFinished()) { // Broadcast the updated model resetFinishedStates(); broadcastModel(); @@ -212,7 +212,7 @@ public abstract class ParamServer { break; } default: - throw new DMLRuntimeException("Unsupported update: "+_updateType.name()); + throw new DMLRuntimeException("Unsupported update: " + _updateType.name()); } } catch (Exception e) { @@ -221,10 +221,16 @@ public abstract class ParamServer { return null; } - private void updateModel(Gradient grad) throws InterruptedException { + /** + * A synchronized service method for updating model with gradients + * + * @param gradients A list object of gradients + * @return A updated 
list object of model + */ + private synchronized ListObject updateModel(ListObject gradients, ListObject model) { // Populate the variables table with the gradients and model - _ec.setVariable(Statement.PS_GRADIENTS, grad._gradients); - _ec.setVariable(Statement.PS_MODEL, _model); + _ec.setVariable(Statement.PS_GRADIENTS, gradients); + _ec.setVariable(Statement.PS_MODEL, model); // Invoke the aggregate function _inst.processInstruction(_ec); @@ -233,9 +239,9 @@ public abstract class ParamServer { ListObject newModel = (ListObject) _ec.getVariable(_output.getName()); // Update the model with the new output - ParamservUtils.cleanupListObject(_ec, _model); - ParamservUtils.cleanupListObject(_ec, grad._gradients); - _model = newModel; + ParamservUtils.cleanupListObject(_ec, model); + ParamservUtils.cleanupListObject(_ec, gradients); + return newModel; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 79d1ff3..6e2b187 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -51,6 +51,8 @@ import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.sysml.hops.Hop; @@ -86,11 +88,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc //internal local debug level private static final boolean LDEBUG = false; + protected static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName()); + static { // for internal debugging only if (LDEBUG) { Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG); + Logger.getLogger(ParamservBuiltinCPInstruction.class.getName()).setLevel(Level.DEBUG); } } @@ -100,7 +105,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc @Override public void processInstruction(ExecutionContext ec) { - PSModeType mode = PSModeType.valueOf(getParam(PS_MODE)); + PSModeType mode = getPSMode(); int workerNum = getWorkerNum(mode); ExecutorService es = Executors.newFixedThreadPool(workerNum); String updFunc = getParam(PS_UPDATE_FUN); @@ -119,7 +124,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Create the parameter server ListObject model = ec.getListObject(getParam(PS_MODEL)); - ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, aggServiceEC); + ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers List<LocalPSWorker> workers = IntStream.range(0, workerNum) @@ -129,9 +134,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Do data partition doDataPartition(ec, workers); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("\nConfiguration of paramserv func: \nmode: %s \nworkerNum: %d \nupdate frequency: %s \nstrategy: %s", + mode, workerNum, freq, updateType)); + } + // 
Launch the worker threads and wait for completion try { - for( Future<Void> ret : es.invokeAll(workers) ) + for (Future<Void> ret : es.invokeAll(workers)) ret.get(); //error handling } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e); @@ -145,6 +155,18 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc ec.setVariable(output.getName(), result); } + private PSModeType getPSMode() { + PSModeType mode; + try { + mode = PSModeType.valueOf(getParam(PS_MODE)); + } catch (IllegalArgumentException e) { + throw new DMLRuntimeException(String.format("Paramserv function: not support ps execution mode '%s'", getParam(PS_MODE))); + } + if( mode == PSModeType.REMOTE_SPARK ) + throw new DMLRuntimeException("Do not support remote spark."); + return mode; + } + private int getEpochs() { int epochs = Integer.valueOf(getParam(PS_EPOCHS)); if (epochs <= 0) { @@ -224,7 +246,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc else if (pb instanceof IfProgramBlock) { IfProgramBlock ipb = (IfProgramBlock) pb; recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled); - if( ipb.getChildBlocksElseBody() != null ) + if (ipb.getChildBlocksElseBody() != null) recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled); } else { @@ -259,9 +281,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } private PSUpdateType getUpdateType() { - PSUpdateType updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE)); - if( updType == PSUpdateType.SSP ) - throw new DMLRuntimeException(String.format("Not support update type '%s'.", updType)); + PSUpdateType updType; + try { + updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE)); + } catch (IllegalArgumentException e) { + throw new DMLRuntimeException(String.format("Paramserv function: not support update type '%s'.", getParam(PS_UPDATE_TYPE))); + } + if (updType == PSUpdateType.SSP) + throw new DMLRuntimeException("Not support update type SSP."); return updType; } @@ -269,9 +296,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc if (!getParameterMap().containsKey(PS_FREQUENCY)) { return DEFAULT_UPDATE_FREQUENCY; } - PSFrequency freq = PSFrequency.valueOf(getParam(PS_FREQUENCY)); - if( freq == PSFrequency.EPOCH ) - throw new DMLRuntimeException("Not support epoch update frequency."); + PSFrequency freq; + try { + freq = PSFrequency.valueOf(getParam(PS_FREQUENCY)); + } catch (IllegalArgumentException e) { + throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' update frequency.", getParam(PS_FREQUENCY))); + } return freq; } @@ -306,12 +336,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc * * @return parameter server */ - private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType, - int workerNum, ListObject model, ExecutionContext ec) { + private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) { ParamServer ps = null; switch (mode) { case LOCAL: - ps = new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum); + ps = new LocalParamServer(model, aggFunc, updateType, ec, workerNum); break; case REMOTE_SPARK: throw new DMLRuntimeException("Do not support remote spark."); @@ -346,7 +375,11 @@ public class 
ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); PSScheme scheme = DEFAULT_SCHEME; if (getParameterMap().containsKey(PS_SCHEME)) { - scheme = PSScheme.valueOf(getParam(PS_SCHEME)); + try { + scheme = PSScheme.valueOf(getParam(PS_SCHEME)); + } catch (IllegalArgumentException e) { + throw new DMLRuntimeException(String.format("Paramserv function: not support data partition scheme '%s'", getParam(PS_SCHEME))); + } } switch (scheme) { case DISJOINT_CONTIGUOUS: http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java index 28b525a..3185621 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java @@ -32,13 +32,15 @@ public class ParamservFuncTest extends AutomatedTestBase { private static final String TEST_NAME4 = "paramserv-wrong-type-args"; private static final String TEST_NAME5 = "paramserv-wrong-args"; private static final String TEST_NAME6 = "paramserv-wrong-args2"; - private static final String TEST_NAME7 = "paramserv-nn-test"; + private static final String TEST_NAME7 = "paramserv-nn-bsp-batch"; private static final String TEST_NAME8 = "paramserv-minimum-version"; private static final String TEST_NAME9 = "paramserv-worker-failed"; private static final String TEST_NAME10 = "paramserv-agg-service-failed"; private static final String TEST_NAME11 = "paramserv-large-parallelism"; private static final String TEST_NAME12 = "paramserv-wrong-aggregate-func"; - private static final String TEST_NAME13 = "paramserv-nn-asp"; + private static final String TEST_NAME13 = "paramserv-nn-asp-batch"; + private static final String TEST_NAME14 = "paramserv-nn-bsp-epoch"; + private static final String TEST_NAME15 = "paramserv-nn-asp-epoch"; private static final String TEST_DIR = "functions/paramserv/"; private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/"; @@ -60,6 +62,8 @@ public class ParamservFuncTest extends AutomatedTestBase { addTestConfiguration(TEST_NAME11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {})); addTestConfiguration(TEST_NAME12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME12, new String[] {})); addTestConfiguration(TEST_NAME13, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME13, new String[] {})); + addTestConfiguration(TEST_NAME14, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME14, new String[] {})); + addTestConfiguration(TEST_NAME15, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME15, new String[] {})); } @Test @@ -86,7 +90,7 @@ public class ParamservFuncTest extends AutomatedTestBase { @Test public void testParamservWrongArgs() { - final String errmsg = "Function PARAMSERV does not support value 'NSP' as the 'utype' parameter."; + final String errmsg = "Paramserv function: not support update type 'NSP'."; runDMLTest(TEST_NAME5, true, DMLException.class, errmsg); } @@ -97,7 +101,7 @@ public class ParamservFuncTest extends AutomatedTestBase { } @Test - public void testParamservNNTest() { + public void testParamservNNBspBatchTest() { 
runDMLTest(TEST_NAME7, false, null, null); } @@ -132,6 +136,16 @@ public class ParamservFuncTest extends AutomatedTestBase { runDMLTest(TEST_NAME13, false, null, null); } + @Test + public void testParamservBSPEpochTest() { + runDMLTest(TEST_NAME14, false, null, null); + } + + @Test + public void testParamservASPEpochTest() { + runDMLTest(TEST_NAME15, false, null, null); + } + private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml index 4ea6e5f..041c2bf 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -35,7 +35,7 @@ source("nn/optim/sgd_nesterov.dml") as sgd_nesterov train = function(matrix[double] X, matrix[double] Y, matrix[double] X_val, matrix[double] Y_val, - int C, int Hin, int Win, int epochs, int workers) + int C, int Hin, int Win, int epochs, int workers, string utype, string freq) return (matrix[double] W1, matrix[double] b1, matrix[double] W2, matrix[double] b2, matrix[double] W3, matrix[double] b3, @@ -107,7 +107,7 @@ train = function(matrix[double] X, matrix[double] Y, params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) # Use paramserv function - modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype="BSP", freq="BATCH", epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") + modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") W1 = as.matrix(modelList2["W1"]) b1 = as.matrix(modelList2["b1"]) http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml deleted file mode 100644 index b2e155e..0000000 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml +++ /dev/null @@ -1,376 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -/* - * MNIST LeNet Example - */ -# Imports -source("nn/layers/affine.dml") as affine -source("nn/layers/conv2d_builtin.dml") as conv2d -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss -source("nn/layers/dropout.dml") as dropout -source("nn/layers/l2_reg.dml") as l2_reg -source("nn/layers/max_pool2d_builtin.dml") as max_pool2d -source("nn/layers/relu.dml") as relu -source("nn/layers/softmax.dml") as softmax -source("nn/optim/sgd_nesterov.dml") as sgd_nesterov - -train = function(matrix[double] X, matrix[double] Y, - matrix[double] X_val, matrix[double] Y_val, - int C, int Hin, int Win, int epochs, int workers) - return (matrix[double] W1, matrix[double] b1, - matrix[double] W2, matrix[double] b2, - matrix[double] W3, matrix[double] b3, - matrix[double] W4, matrix[double] b4) { - /* - * Trains a convolutional net using the "LeNet" architecture. - * - * The input matrix, X, has N examples, each represented as a 3D - * volume unrolled into a single vector. The targets, Y, have K - * classes, and are one-hot encoded. - * - * Inputs: - * - X: Input data matrix, of shape (N, C*Hin*Win). - * - Y: Target matrix, of shape (N, K). - * - X_val: Input validation data matrix, of shape (N, C*Hin*Win). - * - Y_val: Target validation matrix, of shape (N, K). - * - C: Number of input channels (dimensionality of input depth). - * - Hin: Input height. - * - Win: Input width. - * - epochs: Total number of full training loops over the full data set. - * - * Outputs: - * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). - * - b1: 1st layer biases vector, of shape (F1, 1). - * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). - * - b2: 2nd layer biases vector, of shape (F2, 1). - * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). - * - b3: 3rd layer biases vector, of shape (1, N3). - * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). - * - b4: 4th layer biases vector, of shape (1, K). 
- */ - N = nrow(X) - K = ncol(Y) - - # Create network: - # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax - Hf = 5 # filter height - Wf = 5 # filter width - stride = 1 - pad = 2 # For same dimensions, (Hf - stride) / 2 - - F1 = 32 # num conv filters in conv1 - F2 = 64 # num conv filters in conv2 - N3 = 512 # num nodes in affine3 - # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes) - - [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win) - [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2)) - [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) - [W4, b4] = affine::init(N3, K) # inputs: (N, N3) - W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu - - # Initialize SGD w/ Nesterov momentum optimizer - lr = 0.01 # learning rate - mu = 0.9 #0.5 # momentum - decay = 0.95 # learning rate decay constant - vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1) - vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2) - vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3) - vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4) - - # Regularization - lambda = 5e-04 - - # Create the model object - modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) - - # Create the hyper parameter list - params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) - - # Use paramserv function - modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::aggregation", mode="LOCAL", utype="ASP", epochs=epochs, hyperparams=params) - - W1 = as.matrix(modelList2["W1"]) - b1 = as.matrix(modelList2["b1"]) - W2 = as.matrix(modelList2["W2"]) - b2 = as.matrix(modelList2["b2"]) - W3 = as.matrix(modelList2["W3"]) - b3 = as.matrix(modelList2["b3"]) - W4 = as.matrix(modelList2["W4"]) - b4 = as.matrix(modelList2["b4"]) - -} - -gradients = function(matrix[double] features, - matrix[double] labels, - list[unknown] hyperparams, - list[unknown] model) - return (list[unknown] gradients) { - - C = 1 - Hin = 28 - Win = 28 - Hf = 5 - Wf = 5 - stride = 1 - pad = 2 - lambda = 5e-04 - F1 = 32 - F2 = 64 - N3 = 512 - W1 = as.matrix(model["W1"]) - b1 = as.matrix(model["b1"]) - W2 = as.matrix(model["W2"]) - b2 = as.matrix(model["b2"]) - W3 = as.matrix(model["W3"]) - b3 = as.matrix(model["b3"]) - W4 = as.matrix(model["W4"]) - b4 = as.matrix(model["b4"]) - - # Compute forward pass - ## layer 1: conv1 -> relu1 -> pool1 - [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, - stride, stride, pad, pad) - outr1 = relu::forward(outc1) - [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, - strideh=2, stridew=2, pad=0, pad=0) - ## layer 2: conv2 -> relu2 -> pool2 - [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, - stride, stride, pad, pad) - outr2 = relu::forward(outc2) - [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, - strideh=2, stridew=2, pad=0, pad=0) - ## layer 3: affine3 -> relu3 -> dropout - outa3 = affine::forward(outp2, W3, b3) - outr3 
= relu::forward(outa3) - [outd3, maskd3] = dropout::forward(outr3, 0.5, -1) - ## layer 4: affine4 -> softmax - outa4 = affine::forward(outd3, W4, b4) - probs = softmax::forward(outa4) - - # Compute data backward pass - ## loss: - dprobs = cross_entropy_loss::backward(probs, labels) - ## layer 4: affine4 -> softmax - douta4 = softmax::backward(dprobs, outa4) - [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4) - ## layer 3: affine3 -> relu3 -> dropout - doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3) - douta3 = relu::backward(doutr3, outa3) - [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3) - ## layer 2: conv2 -> relu2 -> pool2 - doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, - strideh=2, stridew=2, pad=0, pad=0) - doutc2 = relu::backward(doutr2, outc2) - [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1, - Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad) - ## layer 1: conv1 -> relu1 -> pool1 - doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, - strideh=2, stridew=2, pad=0, pad=0) - doutc1 = relu::backward(doutr1, outc1) - [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win, - Hf, Wf, stride, stride, pad, pad) - - # Compute regularization backward pass - dW1_reg = l2_reg::backward(W1, lambda) - dW2_reg = l2_reg::backward(W2, lambda) - dW3_reg = l2_reg::backward(W3, lambda) - dW4_reg = l2_reg::backward(W4, lambda) - dW1 = dW1 + dW1_reg - dW2 = dW2 + dW2_reg - dW3 = dW3 + dW3_reg - dW4 = dW4 + dW4_reg - - gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4) - -} - -aggregation = function(list[unknown] model, - list[unknown] gradients, - list[unknown] hyperparams) - return (list[unknown] modelResult) { - - W1 = as.matrix(model["W1"]) - W2 = as.matrix(model["W2"]) - W3 = as.matrix(model["W3"]) - W4 = as.matrix(model["W4"]) - b1 = as.matrix(model["b1"]) - b2 = as.matrix(model["b2"]) - b3 = as.matrix(model["b3"]) - b4 = as.matrix(model["b4"]) - dW1 = as.matrix(gradients["dW1"]) - dW2 = as.matrix(gradients["dW2"]) - dW3 = as.matrix(gradients["dW3"]) - dW4 = as.matrix(gradients["dW4"]) - db1 = as.matrix(gradients["db1"]) - db2 = as.matrix(gradients["db2"]) - db3 = as.matrix(gradients["db3"]) - db4 = as.matrix(gradients["db4"]) - vW1 = as.matrix(model["vW1"]) - vW2 = as.matrix(model["vW2"]) - vW3 = as.matrix(model["vW3"]) - vW4 = as.matrix(model["vW4"]) - vb1 = as.matrix(model["vb1"]) - vb2 = as.matrix(model["vb2"]) - vb3 = as.matrix(model["vb3"]) - vb4 = as.matrix(model["vb4"]) - lr = 0.01 - mu = 0.9 - - # Optimize with SGD w/ Nesterov momentum - [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1) - [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1) - [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2) - [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2) - [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3) - [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3) - [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4) - [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4) - - modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) - } - -predict = function(matrix[double] X, int C, int Hin, int Win, - matrix[double] W1, matrix[double] b1, - matrix[double] W2, matrix[double] b2, - matrix[double] W3, matrix[double] b3, - matrix[double] W4, matrix[double] b4) - return 
(matrix[double] probs) { - /* - * Computes the class probability predictions of a convolutional - * net using the "LeNet" architecture. - * - * The input matrix, X, has N examples, each represented as a 3D - * volume unrolled into a single vector. - * - * Inputs: - * - X: Input data matrix, of shape (N, C*Hin*Win). - * - C: Number of input channels (dimensionality of input depth). - * - Hin: Input height. - * - Win: Input width. - * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). - * - b1: 1st layer biases vector, of shape (F1, 1). - * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). - * - b2: 2nd layer biases vector, of shape (F2, 1). - * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). - * - b3: 3rd layer biases vector, of shape (1, N3). - * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). - * - b4: 4th layer biases vector, of shape (1, K). - * - * Outputs: - * - probs: Class probabilities, of shape (N, K). - */ - N = nrow(X) - - # Network: - # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax - Hf = 5 # filter height - Wf = 5 # filter width - stride = 1 - pad = 2 # For same dimensions, (Hf - stride) / 2 - - F1 = nrow(W1) # num conv filters in conv1 - F2 = nrow(W2) # num conv filters in conv2 - N3 = ncol(W3) # num nodes in affine3 - K = ncol(W4) # num nodes in affine4, equal to number of target dimensions (num classes) - - # Compute predictions over mini-batches - probs = matrix(0, rows=N, cols=K) - batch_size = 64 - iters = ceil(N / batch_size) - for(i in 1:iters) { - # Get next batch - beg = ((i-1) * batch_size) %% N + 1 - end = min(N, beg + batch_size - 1) - X_batch = X[beg:end,] - - # Compute forward pass - ## layer 1: conv1 -> relu1 -> pool1 - [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride, - pad, pad) - outr1 = relu::forward(outc1) - [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, - strideh=2, stridew=2, pad=0, pad=0) - ## layer 2: conv2 -> relu2 -> pool2 - [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, - stride, stride, pad, pad) - outr2 = relu::forward(outc2) - [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, - strideh=2, stridew=2, pad=0, pad=0) - ## layer 3: affine3 -> relu3 - outa3 = affine::forward(outp2, W3, b3) - outr3 = relu::forward(outa3) - ## layer 4: affine4 -> softmax - outa4 = affine::forward(outr3, W4, b4) - probs_batch = softmax::forward(outa4) - - # Store predictions - probs[beg:end,] = probs_batch - } -} - -eval = function(matrix[double] probs, matrix[double] Y) - return (double loss, double accuracy) { - /* - * Evaluates a convolutional net using the "LeNet" architecture. - * - * The probs matrix contains the class probability predictions - * of K classes over N examples. The targets, Y, have K classes, - * and are one-hot encoded. - * - * Inputs: - * - probs: Class probabilities, of shape (N, K). - * - Y: Target matrix, of shape (N, K). - * - * Outputs: - * - loss: Scalar loss, of shape (1). - * - accuracy: Scalar accuracy, of shape (1). - */ - # Compute loss & accuracy - loss = cross_entropy_loss::forward(probs, Y) - correct_pred = rowIndexMax(probs) == rowIndexMax(Y) - accuracy = mean(correct_pred) -} - -generate_dummy_data = function() - return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) { - /* - * Generate a dummy dataset similar to the MNIST dataset. 
- * - * Outputs: - * - X: Input data matrix, of shape (N, D). - * - Y: Target matrix, of shape (N, K). - * - C: Number of input channels (dimensionality of input depth). - * - Hin: Input height. - * - Win: Input width. - */ - # Generate dummy input data - N = 1024 # num examples - C = 1 # num input channels - Hin = 28 # input height - Win = 28 # input width - K = 10 # num target classes - X = rand(rows=N, cols=C*Hin*Win, pdf="normal") - classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform")) - Y = table(seq(1, N), classes) # one-hot encoding -} - http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml index 707722e..d02e5d6 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml @@ -233,8 +233,8 @@ aggregation = function(list[unknown] model, vb2 = as.matrix(model["vb2"]) vb3 = as.matrix(model["vb3"]) vb4 = as.matrix(model["vb4"]) - lr = as.scalar(hyperparams['lr']); - mu = as.scalar(hyperparams['mu']); + lr = 0.01 + mu = 0.9 # Optimize with SGD w/ Nesterov momentum [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1) http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml new file mode 100644 index 0000000..346cc08 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml new file mode 100644 index 0000000..8d553ae --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml deleted file mode 100644 index b50e17c..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml +++ /dev/null @@ -1,52 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers) - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml new file mode 100644 index 0000000..7b6523b --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml new file mode 100644 index 0000000..d0a6570 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-test.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-test.dml b/src/test/scripts/functions/paramserv/paramserv-nn-test.dml deleted file mode 100644 index 740a208..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-test.dml +++ /dev/null @@ -1,52 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers) - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml index 13a05c9..8f5f53e 100644 --- a/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml +++ b/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml @@ -26,16 +26,25 @@ Y = matrix(2, rows=2, cols=3) X_val = matrix(3, rows=2, cols=3) Y_val = matrix(4, rows=2, cols=3) -gradients = function (matrix[double] input) return (matrix[double] output) { - output = input +gradients = function(matrix[double] features, + matrix[double] wrong_labels, + list[unknown] hyperparams, + list[unknown] model) + return (list[unknown] gradients) { + gradients = model } -aggregation = function (matrix[double] input) return (matrix[double] output) { - output = input +aggregation = function(list[unknown] model, + list[unknown] gradients, + list[unknown] hyperparams) + return (list[unknown] modelResult) { + modelResult = model } e2 = "element2" params = list(e2) # Use paramserv function -modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="NSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") \ No newline at end of file +modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="NSP", epochs=10, hyperparams=params) + +print(toString(as.matrix(modelList2[1]))) \ No newline at end of file
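For quick reference, a minimal sketch of a paramserv call that selects the newly supported per-epoch update frequency; it assumes user-defined gradients/aggregation functions, a modelList of parameters, and a params hyper-parameter list, as in the mnist_lenet_paramserv.dml test script above:

  # Sketch only: the new freq="EPOCH" option added by this commit.
  # 'modelList', 'params', and the "gradients"/"aggregation" functions are
  # assumed to be defined by the caller, as in the test scripts in this diff.
  modelList2 = paramserv(model=modelList, features=X, labels=Y,
    val_features=X_val, val_labels=Y_val,
    upd="gradients", agg="aggregation",
    mode="LOCAL", utype="BSP", freq="EPOCH",
    epochs=10, batchsize=64, k=2,
    scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")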
