Repository: systemml
Updated Branches:
  refs/heads/master 095781868 -> 51057e471


[SYSTEMML-2359] Additional paramserv update frequency: per-epoch

Closes #780.
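
With this change the paramserv builtin accepts freq="EPOCH" in addition to freq="BATCH". A minimal DML sketch of such a call (the script path and the gradients/aggregation function names are placeholders; the remaining arguments mirror the test scripts touched by this commit):

  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./my_script.dml::gradients", agg="./my_script.dml::aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")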


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/51057e47
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/51057e47
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/51057e47

Branch: refs/heads/master
Commit: 51057e4712d6ab9a190a9c1f8e9f36d48a8a1fd5
Parents: 0957818
Author: EdgarLGB <[email protected]>
Authored: Mon Jun 4 21:25:53 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jun 4 22:02:23 2018 -0700

----------------------------------------------------------------------
 .../ParameterizedBuiltinFunctionExpression.java |  24 +-
 .../controlprogram/paramserv/LocalPSWorker.java | 139 +++++--
 .../paramserv/LocalParamServer.java             |   8 +-
 .../controlprogram/paramserv/PSWorker.java      |   2 +-
 .../controlprogram/paramserv/ParamServer.java   |  42 ++-
 .../cp/ParamservBuiltinCPInstruction.java       |  61 ++-
 .../functions/paramserv/ParamservFuncTest.java  |  22 +-
 .../paramserv/mnist_lenet_paramserv.dml         |   4 +-
 .../paramserv/mnist_lenet_paramserv_asp.dml     | 376 -------------------
 .../mnist_lenet_paramserv_minimum_version.dml   |   4 +-
 .../paramserv/paramserv-nn-asp-batch.dml        |  52 +++
 .../paramserv/paramserv-nn-asp-epoch.dml        |  52 +++
 .../functions/paramserv/paramserv-nn-asp.dml    |  52 ---
 .../paramserv/paramserv-nn-bsp-batch.dml        |  52 +++
 .../paramserv/paramserv-nn-bsp-epoch.dml        |  52 +++
 .../functions/paramserv/paramserv-nn-test.dml   |  52 ---
 .../paramserv/paramserv-wrong-args.dml          |  19 +-
 17 files changed, 425 insertions(+), 588 deletions(-)
----------------------------------------------------------------------
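
The new per-epoch test scripts listed above (paramserv-nn-bsp-epoch.dml, paramserv-nn-asp-epoch.dml) exercise the same shared LeNet driver as paramserv-nn-asp-batch.dml below; presumably they differ only in the update type and frequency passed to the train function, roughly (a sketch, not the committed file):

  [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH")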


http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 99aec78..33c5b0e 100644
--- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -334,22 +334,15 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                checkDataType(fname, Statement.PS_VAL_LABELS, DataType.MATRIX, conditional);
                checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, DataType.SCALAR, ValueType.STRING, conditional);
                checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, DataType.SCALAR, ValueType.STRING, conditional);
-               Set<String> modes = Arrays.stream(Statement.PSModeType.values()).map(Enum::name)
-                               .collect(Collectors.toSet());
-               checkStringParam(false, fname, Statement.PS_MODE, modes, conditional);
-               Set<String> utypes = Arrays.stream(Statement.PSUpdateType.values()).map(Enum::name)
-                               .collect(Collectors.toSet());
-               checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, utypes, conditional);
-               Set<String> frequencies = Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet());
-               checkStringParam(true, fname, Statement.PS_FREQUENCY, frequencies, conditional);
+               checkStringParam(false, fname, Statement.PS_MODE, conditional);
+               checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, conditional);
+               checkStringParam(true, fname, Statement.PS_FREQUENCY, conditional);
                checkDataValueType(false, fname, Statement.PS_EPOCHS, DataType.SCALAR, ValueType.INT, conditional);
                checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional);
                checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
-               Set<String> schemes = Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet());
-               checkStringParam(true, fname, Statement.PS_SCHEME, schemes, conditional);
+               checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
                checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
-               Set<String> checkpointings = Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet());
-               checkStringParam(true, fname, Statement.PS_CHECKPOINTING, checkpointings, conditional);
+               checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
 
                // set output characteristics
                output.setDataType(DataType.LIST);
@@ -358,7 +351,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                output.setBlockDimensions(-1, -1);
        }
 
-       private void checkStringParam(boolean optional, String fname, String pname, Set<String> validOptions, boolean conditional) {
+       private void checkStringParam(boolean optional, String fname, String pname, boolean conditional) {
                Expression param = getVarParam(pname);
                if (param == null) {
                        if (optional) {
@@ -371,11 +364,6 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                                        String.format("Function %s should provide a string value for %s parameter.", fname, pname),
                                        conditional);
                }
-               StringIdentifier si = (StringIdentifier) param;
-               if (!validOptions.contains(si.getValue())) {
-                       raiseValidateError(String.format("Function %s does not support value '%s' as the '%s' parameter.", fname,
-                                       si.getValue(), pname), conditional, LanguageErrorCodes.INVALID_PARAMETERS);
-               }
        }
 
        // example: A = transformapply(target=X, meta=M, spec=s)

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index e902aea..1583fbf 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -42,59 +42,118 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
        public Void call() throws Exception {
                try {
                        long dataSize = _features.getNumRows();
-                       for (int i = 0; i < _epochs; i++) {
-                               int totalIter = (int) Math.ceil(dataSize / _batchSize);
-                               for (int j = 0; j < totalIter; j++) {
-                                       // Pull the global parameters from ps
-                                       ListObject globalParams = (ListObject)_ps.pull(_workerID);
-                                       if (LOG.isDebugEnabled()) {
-                                               LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
-                                                       + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
-                                       }
-                                       _ec.setVariable(Statement.PS_MODEL, globalParams);
+                       int totalIter = (int) Math.ceil(dataSize / _batchSize);
 
-                                       long begin = j * _batchSize + 1;
-                                       long end = Math.min(begin + _batchSize, dataSize);
+                       switch (_freq) {
+                               case BATCH:
+                                       computeBatch(dataSize, totalIter);
+                                       break;
+                               case EPOCH:
+                                       computeEpoch(dataSize, totalIter);
+                                       break;
+                       }
 
-                                       // Get batch features and labels
-                                       MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
-                                       MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
-                                       _ec.setVariable(Statement.PS_FEATURES, bFeatures);
-                                       _ec.setVariable(Statement.PS_LABELS, bLabels);
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+                       }
+               } catch (Exception e) {
+                       throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
+               }
+               return null;
+       }
 
-                                       if (LOG.isDebugEnabled()) {
-                                               LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. "
-                                                       + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", _workerID, bFeatures.getDataSize()
-                                                       / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter));
-                                       }
+       private void computeEpoch(long dataSize, int totalIter) {
+               for (int i = 0; i < _epochs; i++) {
+                       // Pull the global parameters from ps
+                       ListObject globalParams = pullModel();
 
-                                       // Invoke the update function
-                                       _inst.processInstruction(_ec);
+                       for (int j = 0; j < totalIter; j++) {
+                               _ec.setVariable(Statement.PS_MODEL, globalParams);
 
-                                       // Get the gradients
-                                       ListObject gradients = (ListObject) _ec.getVariable(_output.getName());
+                               ListObject gradients = computeGradients(dataSize, totalIter, i, j);
 
+                               if (j == totalIter - 1) {
                                        // Push the gradients to ps
-                                       _ps.push(_workerID, gradients);
+                                       pushGradients(gradients);
+                                       ParamservUtils.cleanupListObject(_ec, globalParams);
+                               } else {
+                                       // Update the local model with gradients
+                                       globalParams = _ps.updateModel(gradients, globalParams);
                                        if (LOG.isDebugEnabled()) {
-                                               LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
-                                                       + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
+                                               LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated.",
+                                                               _workerID, globalParams.getDataSize()));
                                        }
-
-                                       ParamservUtils.cleanupListObject(_ec, globalParams);
-                                       ParamservUtils.cleanupData(bFeatures);
-                                       ParamservUtils.cleanupData(bLabels);
-                               }
-                               if (LOG.isDebugEnabled()) {
-                                       LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
                                }
                        }
                        if (LOG.isDebugEnabled()) {
-                               LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+                               LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
                        }
-               } catch (Exception e) {
-                       throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
                }
-               return null;
+
+       }
+
+       private void computeBatch(long dataSize, int totalIter) {
+               for (int i = 0; i < _epochs; i++) {
+                       for (int j = 0; j < totalIter; j++) {
+                               ListObject globalParams = pullModel();
+
+                               _ec.setVariable(Statement.PS_MODEL, globalParams);
+                               ListObject gradients = computeGradients(dataSize, totalIter, i, j);
+
+                               // Push the gradients to ps
+                               pushGradients(gradients);
+
+                               ParamservUtils.cleanupListObject(_ec, globalParams);
+                       }
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+                       }
+               }
+       }
+
+       private ListObject pullModel() {
+               // Pull the global parameters from ps
+               ListObject globalParams = (ListObject)_ps.pull(_workerID);
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
+                               + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
+               }
+               return globalParams;
+       }
+
+       private void pushGradients(ListObject gradients) {
+               // Push the gradients to ps
+               _ps.push(_workerID, gradients);
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
+                                       + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
+               }
+       }
+
+       private ListObject computeGradients(long dataSize, int totalIter, int i, int j) {
+               long begin = j * _batchSize + 1;
+               long end = Math.min(begin + _batchSize, dataSize);
+
+               // Get batch features and labels
+               MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
+               MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
+               _ec.setVariable(Statement.PS_FEATURES, bFeatures);
+               _ec.setVariable(Statement.PS_LABELS, bLabels);
+
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. "
+                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", _workerID, bFeatures.getDataSize()
+                               / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter));
+               }
+
+               // Invoke the update function
+               _inst.processInstruction(_ec);
+
+               // Get the gradients
+               ListObject gradients = (ListObject) _ec.getVariable(_output.getName());
+
+               ParamservUtils.cleanupData(bFeatures);
+               ParamservUtils.cleanupData(bLabels);
+               return gradients;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index 665395e..bac507c 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -29,9 +29,9 @@ import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 public class LocalParamServer extends ParamServer {
 
-       public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
-                       Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
-               super(model, aggFunc, freq, updateType, ec, workerNum);
+       public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec,
+                       int workerNum) {
+               super(model, aggFunc, updateType, ec, workerNum);
        }
 
        @Override
@@ -52,7 +52,7 @@ public class LocalParamServer extends ParamServer {
        public Data pull(int workerID) {
                ListObject model;
                try {
-                       model = _modelMap.get((int) workerID).take();
+                       model = _modelMap.get(workerID).take();
                } catch (InterruptedException e) {
                        throw new DMLRuntimeException(e);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 56fda22..affa3c1 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -49,7 +49,7 @@ public abstract class PSWorker {
        private MatrixObject _valFeatures;
        private MatrixObject _valLabels;
        private final String _updFunc;
-       private final Statement.PSFrequency _freq;
+       protected final Statement.PSFrequency _freq;
        
        protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq,
                int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index 1052390..d7cd78d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -56,8 +56,7 @@ public abstract class ParamServer {
        private final ExecutorService _es;
        private ListObject _model;
 
-       ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType,
-                       ExecutionContext ec, int workerNum) {
+       ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
                _gradientsQueue = new LinkedBlockingDeque<>();
                _modelMap = new HashMap<>(workerNum);
                IntStream.range(0, workerNum).forEach(i -> {
@@ -65,7 +64,7 @@ public abstract class ParamServer {
                        _modelMap.put(i, new ArrayBlockingQueue<>(1));
                });
                _model = model;
-               _aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum);
+               _aggService = new AggregationService(aggFunc, updateType, ec, workerNum);
                try {
                        _aggService.broadcastModel();
                }
@@ -91,6 +90,10 @@ public abstract class ParamServer {
                return _model;
        }
 
+       public ListObject updateModel(ListObject gradients, ListObject model) {
+               return _aggService.updateModel(gradients, model);
+       }
+
        public static class Gradient {
                final int _workerID;
                final ListObject _gradients;
@@ -109,16 +112,13 @@ public abstract class ParamServer {
                protected final Log LOG = LogFactory.getLog(AggregationService.class.getName());
 
                protected ExecutionContext _ec;
-               //private Statement.PSFrequency _freq;
                private Statement.PSUpdateType _updateType;
                private FunctionCallCPInstruction _inst;
                private DataIdentifier _output;
                private boolean[] _finishedStates;  // Workers' finished states
 
-               AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType,
-                               ExecutionContext ec, int workerNum) {
+               AggregationService(String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
                        _ec = ec;
-                       //_freq = freq;
                        _updateType = updateType;
                        _finishedStates = new boolean[workerNum];
 
@@ -192,13 +192,13 @@ public abstract class ParamServer {
                                }
 
                                // Update and redistribute the model
-                               updateModel(grad);
-                               
+                               _model = updateModel(grad._gradients, _model);
+
                                // Redistribute model according to update type
-                               switch( _updateType ) {
+                               switch(_updateType) {
                                        case BSP: {
                                                setFinishedState(grad._workerID);
-                                               if( allFinished() ) {
+                                               if (allFinished()) {
                                                        // Broadcast the updated model
                                                        resetFinishedStates();
                                                        broadcastModel();
@@ -212,7 +212,7 @@ public abstract class ParamServer {
                                                break;
                                        }
                                        default:
-                                               throw new DMLRuntimeException("Unsupported update: "+_updateType.name());
+                                               throw new DMLRuntimeException("Unsupported update: " + _updateType.name());
                                }
                        } 
                        catch (Exception e) {
@@ -221,10 +221,16 @@ public abstract class ParamServer {
                        return null;
                }
 
-               private void updateModel(Gradient grad) throws InterruptedException {
+               /**
+                * A synchronized service method for updating model with gradients
+                *
+                * @param gradients A list object of gradients
+                * @return A updated list object of model
+                */
+               private synchronized ListObject updateModel(ListObject gradients, ListObject model) {
                        // Populate the variables table with the gradients and model
-                       _ec.setVariable(Statement.PS_GRADIENTS, grad._gradients);
-                       _ec.setVariable(Statement.PS_MODEL, _model);
+                       _ec.setVariable(Statement.PS_GRADIENTS, gradients);
+                       _ec.setVariable(Statement.PS_MODEL, model);
 
                        // Invoke the aggregate function
                        _inst.processInstruction(_ec);
@@ -233,9 +239,9 @@ public abstract class ParamServer {
                        ListObject newModel = (ListObject) _ec.getVariable(_output.getName());
 
                        // Update the model with the new output
-                       ParamservUtils.cleanupListObject(_ec, _model);
-                       ParamservUtils.cleanupListObject(_ec, grad._gradients);
-                       _model = newModel;
+                       ParamservUtils.cleanupListObject(_ec, model);
+                       ParamservUtils.cleanupListObject(_ec, gradients);
+                       return newModel;
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 79d1ff3..6e2b187 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -51,6 +51,8 @@ import java.util.concurrent.Future;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.sysml.hops.Hop;
@@ -86,11 +88,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
        //internal local debug level
        private static final boolean LDEBUG = false;
+       protected static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
+
 
        static {
                // for internal debugging only
                if (LDEBUG) {
                        Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG);
+                       Logger.getLogger(ParamservBuiltinCPInstruction.class.getName()).setLevel(Level.DEBUG);
                }
        }
 
@@ -100,7 +105,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               PSModeType mode = PSModeType.valueOf(getParam(PS_MODE));
+               PSModeType mode = getPSMode();
                int workerNum = getWorkerNum(mode);
                ExecutorService es = Executors.newFixedThreadPool(workerNum);
                String updFunc = getParam(PS_UPDATE_FUN);
@@ -119,7 +124,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, aggServiceEC);
+               ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC);
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
@@ -129,9 +134,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                // Do data partition
                doDataPartition(ec, workers);
 
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("\nConfiguration of paramserv func: \nmode: %s \nworkerNum: %d \nupdate frequency: %s \nstrategy: %s",
+                                       mode, workerNum, freq, updateType));
+               }
+
                // Launch the worker threads and wait for completion
                try {
-                       for( Future<Void> ret : es.invokeAll(workers) )
+                       for (Future<Void> ret : es.invokeAll(workers))
                                ret.get(); //error handling
                } catch (InterruptedException | ExecutionException e) {
                        throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
@@ -145,6 +155,18 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                ec.setVariable(output.getName(), result);
        }
 
+       private PSModeType getPSMode() {
+               PSModeType mode;
+               try {
+                       mode = PSModeType.valueOf(getParam(PS_MODE));
+               } catch (IllegalArgumentException e) {
+                       throw new DMLRuntimeException(String.format("Paramserv function: not support ps execution mode '%s'", getParam(PS_MODE)));
+               }
+               if( mode == PSModeType.REMOTE_SPARK )
+                       throw new DMLRuntimeException("Do not support remote spark.");
+               return mode;
+       }
+
        private int getEpochs() {
                int epochs = Integer.valueOf(getParam(PS_EPOCHS));
                if (epochs <= 0) {
@@ -224,7 +246,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                        else if (pb instanceof IfProgramBlock) {
                                IfProgramBlock ipb = (IfProgramBlock) pb;
                                recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
-                               if( ipb.getChildBlocksElseBody() != null )
+                               if (ipb.getChildBlocksElseBody() != null)
                                        recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
                        }
                        else {
@@ -259,9 +281,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
        }
 
        private PSUpdateType getUpdateType() {
-               PSUpdateType updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
-               if( updType == PSUpdateType.SSP )
-                       throw new DMLRuntimeException(String.format("Not support update type '%s'.", updType));
+               PSUpdateType updType;
+               try {
+                       updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
+               } catch (IllegalArgumentException e) {
+                       throw new DMLRuntimeException(String.format("Paramserv function: not support update type '%s'.", getParam(PS_UPDATE_TYPE)));
+               }
+               if (updType == PSUpdateType.SSP)
+                       throw new DMLRuntimeException("Not support update type SSP.");
                return updType;
        }
 
@@ -269,9 +296,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                if (!getParameterMap().containsKey(PS_FREQUENCY)) {
                        return DEFAULT_UPDATE_FREQUENCY;
                }
-               PSFrequency freq = PSFrequency.valueOf(getParam(PS_FREQUENCY));
-               if( freq == PSFrequency.EPOCH )
-                       throw new DMLRuntimeException("Not support epoch update frequency.");
+               PSFrequency freq;
+               try {
+                       freq = PSFrequency.valueOf(getParam(PS_FREQUENCY));
+               } catch (IllegalArgumentException e) {
+                       throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' update frequency.", getParam(PS_FREQUENCY)));
+               }
                return freq;
        }
 
@@ -306,12 +336,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
         *
         * @return parameter server
         */
-       private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType,
-                       int workerNum, ListObject model, ExecutionContext ec) {
+       private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
                ParamServer ps = null;
                switch (mode) {
                        case LOCAL:
-                               ps = new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum);
+                               ps = new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
                                break;
                        case REMOTE_SPARK:
                                throw new DMLRuntimeException("Do not support remote spark.");
@@ -346,7 +375,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
                PSScheme scheme = DEFAULT_SCHEME;
                if (getParameterMap().containsKey(PS_SCHEME)) {
-                       scheme = PSScheme.valueOf(getParam(PS_SCHEME));
+                       try {
+                               scheme = PSScheme.valueOf(getParam(PS_SCHEME));
+                       } catch (IllegalArgumentException e) {
+                               throw new DMLRuntimeException(String.format("Paramserv function: not support data partition scheme '%s'", getParam(PS_SCHEME)));
+                       }
                }
                switch (scheme) {
                case DISJOINT_CONTIGUOUS:

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
index 28b525a..3185621 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
@@ -32,13 +32,15 @@ public class ParamservFuncTest extends AutomatedTestBase {
        private static final String TEST_NAME4 = "paramserv-wrong-type-args";
        private static final String TEST_NAME5 = "paramserv-wrong-args";
        private static final String TEST_NAME6 = "paramserv-wrong-args2";
-       private static final String TEST_NAME7 = "paramserv-nn-test";
+       private static final String TEST_NAME7 = "paramserv-nn-bsp-batch";
        private static final String TEST_NAME8 = "paramserv-minimum-version";
        private static final String TEST_NAME9 = "paramserv-worker-failed";
        private static final String TEST_NAME10 = "paramserv-agg-service-failed";
        private static final String TEST_NAME11 = "paramserv-large-parallelism";
        private static final String TEST_NAME12 = "paramserv-wrong-aggregate-func";
-       private static final String TEST_NAME13 = "paramserv-nn-asp";
+       private static final String TEST_NAME13 = "paramserv-nn-asp-batch";
+       private static final String TEST_NAME14 = "paramserv-nn-bsp-epoch";
+       private static final String TEST_NAME15 = "paramserv-nn-asp-epoch";
 
        private static final String TEST_DIR = "functions/paramserv/";
        private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/";
@@ -60,6 +62,8 @@ public class ParamservFuncTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {}));
                addTestConfiguration(TEST_NAME12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME12, new String[] {}));
                addTestConfiguration(TEST_NAME13, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME13, new String[] {}));
+               addTestConfiguration(TEST_NAME14, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME14, new String[] {}));
+               addTestConfiguration(TEST_NAME15, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME15, new String[] {}));
        }
 
        @Test
@@ -86,7 +90,7 @@ public class ParamservFuncTest extends AutomatedTestBase {
 
        @Test
        public void testParamservWrongArgs() {
-               final String errmsg = "Function PARAMSERV does not support value 'NSP' as the 'utype' parameter.";
+               final String errmsg = "Paramserv function: not support update type 'NSP'.";
                runDMLTest(TEST_NAME5, true, DMLException.class, errmsg);
        }
 
@@ -97,7 +101,7 @@ public class ParamservFuncTest extends AutomatedTestBase {
        }
 
        @Test
-       public void testParamservNNTest() {
+       public void testParamservNNBspBatchTest() {
                runDMLTest(TEST_NAME7, false, null, null);
        }
 
@@ -132,6 +136,16 @@ public class ParamservFuncTest extends AutomatedTestBase {
                runDMLTest(TEST_NAME13, false, null, null);
        }
 
+       @Test
+       public void testParamservBSPEpochTest() {
+               runDMLTest(TEST_NAME14, false, null, null);
+       }
+
+       @Test
+       public void testParamservASPEpochTest() {
+               runDMLTest(TEST_NAME15, false, null, null);
+       }
+
        private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) {
                TestConfiguration config = getTestConfiguration(testname);
                loadTestConfiguration(config);

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index 4ea6e5f..041c2bf 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -35,7 +35,7 @@ source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
 
 train = function(matrix[double] X, matrix[double] Y,
                  matrix[double] X_val, matrix[double] Y_val,
-                 int C, int Hin, int Win, int epochs, int workers)
+                 int C, int Hin, int Win, int epochs, int workers, string utype, string freq)
     return (matrix[double] W1, matrix[double] b1,
             matrix[double] W2, matrix[double] b2,
             matrix[double] W3, matrix[double] b3,
@@ -107,7 +107,7 @@ train = function(matrix[double] X, matrix[double] Y,
   params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
 
   # Use paramserv function
-  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype="BSP", freq="BATCH", epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")
 
   W1 = as.matrix(modelList2["W1"])
   b1 = as.matrix(modelList2["b1"])

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml
deleted file mode 100644
index b2e155e..0000000
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml
+++ /dev/null
@@ -1,376 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-/*
- * MNIST LeNet Example
- */
-# Imports
-source("nn/layers/affine.dml") as affine
-source("nn/layers/conv2d_builtin.dml") as conv2d
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-source("nn/layers/dropout.dml") as dropout
-source("nn/layers/l2_reg.dml") as l2_reg
-source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
-source("nn/layers/relu.dml") as relu
-source("nn/layers/softmax.dml") as softmax
-source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
-
-train = function(matrix[double] X, matrix[double] Y,
-                 matrix[double] X_val, matrix[double] Y_val,
-                 int C, int Hin, int Win, int epochs, int workers)
-    return (matrix[double] W1, matrix[double] b1,
-            matrix[double] W2, matrix[double] b2,
-            matrix[double] W3, matrix[double] b3,
-            matrix[double] W4, matrix[double] b4) {
-  /*
-   * Trains a convolutional net using the "LeNet" architecture.
-   *
-   * The input matrix, X, has N examples, each represented as a 3D
-   * volume unrolled into a single vector.  The targets, Y, have K
-   * classes, and are one-hot encoded.
-   *
-   * Inputs:
-   *  - X: Input data matrix, of shape (N, C*Hin*Win).
-   *  - Y: Target matrix, of shape (N, K).
-   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
-   *  - Y_val: Target validation matrix, of shape (N, K).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   *  - epochs: Total number of full training loops over the full data set.
-   *
-   * Outputs:
-   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
-   *  - b1: 1st layer biases vector, of shape (F1, 1).
-   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
-   *  - b2: 2nd layer biases vector, of shape (F2, 1).
-   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
-   *  - b3: 3rd layer biases vector, of shape (1, N3).
-   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
-   *  - b4: 4th layer biases vector, of shape (1, K).
-   */
-  N = nrow(X)
-  K = ncol(Y)
-
-  # Create network:
-  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
-  Hf = 5  # filter height
-  Wf = 5  # filter width
-  stride = 1
-  pad = 2  # For same dimensions, (Hf - stride) / 2
-
-  F1 = 32  # num conv filters in conv1
-  F2 = 64  # num conv filters in conv2
-  N3 = 512  # num nodes in affine3
-  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
-
-  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
-  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
-  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, 
F2*(Hin/2/2)*(Win/2/2))
-  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
-  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
-
-  # Initialize SGD w/ Nesterov momentum optimizer
-  lr = 0.01  # learning rate
-  mu = 0.9  #0.5  # momentum
-  decay = 0.95  # learning rate decay constant
-  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
-  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
-  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
-  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
-
-  # Regularization
-  lambda = 5e-04
-
-  # Create the model object
-  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, 
vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
-
-  # Create the hyper parameter list
-  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
-
-  # Use paramserv function
-  modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::aggregation", mode="LOCAL", utype="ASP", epochs=epochs, hyperparams=params)
-
-  W1 = as.matrix(modelList2["W1"])
-  b1 = as.matrix(modelList2["b1"])
-  W2 = as.matrix(modelList2["W2"])
-  b2 = as.matrix(modelList2["b2"])
-  W3 = as.matrix(modelList2["W3"])
-  b3 = as.matrix(modelList2["b3"])
-  W4 = as.matrix(modelList2["W4"])
-  b4 = as.matrix(modelList2["b4"])
-
-}
-
-gradients = function(matrix[double] features,
-                     matrix[double] labels,
-                     list[unknown] hyperparams,
-                     list[unknown] model)
-          return (list[unknown] gradients) {
-
-  C = 1
-  Hin = 28
-  Win = 28
-  Hf = 5
-  Wf = 5
-  stride = 1
-  pad = 2
-  lambda = 5e-04
-  F1 = 32
-  F2 = 64
-  N3 = 512
-  W1 = as.matrix(model["W1"])
-  b1 = as.matrix(model["b1"])
-  W2 = as.matrix(model["W2"])
-  b2 = as.matrix(model["b2"])
-  W3 = as.matrix(model["W3"])
-  b3 = as.matrix(model["b3"])
-  W4 = as.matrix(model["W4"])
-  b4 = as.matrix(model["b4"])
-
-  # Compute forward pass
-  ## layer 1: conv1 -> relu1 -> pool1
-  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, 
Wf,
-                                              stride, stride, pad, pad)
-  outr1 = relu::forward(outc1)
-  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
-                                                strideh=2, stridew=2, pad=0, 
pad=0)
-  ## layer 2: conv2 -> relu2 -> pool2
-  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, 
Hf, Wf,
-                                            stride, stride, pad, pad)
-  outr2 = relu::forward(outc2)
-  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
-                                                strideh=2, stridew=2, pad=0, 
pad=0)
-  ## layer 3:  affine3 -> relu3 -> dropout
-  outa3 = affine::forward(outp2, W3, b3)
-  outr3 = relu::forward(outa3)
-  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
-  ## layer 4:  affine4 -> softmax
-  outa4 = affine::forward(outd3, W4, b4)
-  probs = softmax::forward(outa4)
-
-  # Compute data backward pass
-  ## loss:
-  dprobs = cross_entropy_loss::backward(probs, labels)
-  ## layer 4:  affine4 -> softmax
-  douta4 = softmax::backward(dprobs, outa4)
-  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
-  ## layer 3:  affine3 -> relu3 -> dropout
-  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
-  douta3 = relu::backward(doutr3, outa3)
-  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
-  ## layer 2: conv2 -> relu2 -> pool2
-  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, 
Woutc2, Hf=2, Wf=2,
-                                strideh=2, stridew=2, pad=0, pad=0)
-  doutc2 = relu::backward(doutr2, outc2)
-  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, 
F1,
-                                        Houtp1, Woutp1, Hf, Wf, stride, 
stride, pad, pad)
-  ## layer 1: conv1 -> relu1 -> pool1
-  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, 
Woutc1, Hf=2, Wf=2,
-                                strideh=2, stridew=2, pad=0, pad=0)
-  doutc1 = relu::backward(doutr1, outc1)
-  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, 
W1, b1, C, Hin, Win,
-                                          Hf, Wf, stride, stride, pad, pad)
-
-  # Compute regularization backward pass
-  dW1_reg = l2_reg::backward(W1, lambda)
-  dW2_reg = l2_reg::backward(W2, lambda)
-  dW3_reg = l2_reg::backward(W3, lambda)
-  dW4_reg = l2_reg::backward(W4, lambda)
-  dW1 = dW1 + dW1_reg
-  dW2 = dW2 + dW2_reg
-  dW3 = dW3 + dW3_reg
-  dW4 = dW4 + dW4_reg
-
-  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, 
db3=db3, db4=db4)
-
-}
-
-aggregation = function(list[unknown] model,
-                       list[unknown] gradients,
-                       list[unknown] hyperparams)
-   return (list[unknown] modelResult) {
-
-     W1 = as.matrix(model["W1"])
-     W2 = as.matrix(model["W2"])
-     W3 = as.matrix(model["W3"])
-     W4 = as.matrix(model["W4"])
-     b1 = as.matrix(model["b1"])
-     b2 = as.matrix(model["b2"])
-     b3 = as.matrix(model["b3"])
-     b4 = as.matrix(model["b4"])
-     dW1 = as.matrix(gradients["dW1"])
-     dW2 = as.matrix(gradients["dW2"])
-     dW3 = as.matrix(gradients["dW3"])
-     dW4 = as.matrix(gradients["dW4"])
-     db1 = as.matrix(gradients["db1"])
-     db2 = as.matrix(gradients["db2"])
-     db3 = as.matrix(gradients["db3"])
-     db4 = as.matrix(gradients["db4"])
-     vW1 = as.matrix(model["vW1"])
-     vW2 = as.matrix(model["vW2"])
-     vW3 = as.matrix(model["vW3"])
-     vW4 = as.matrix(model["vW4"])
-     vb1 = as.matrix(model["vb1"])
-     vb2 = as.matrix(model["vb2"])
-     vb3 = as.matrix(model["vb3"])
-     vb4 = as.matrix(model["vb4"])
-     lr = 0.01
-     mu = 0.9
-
-     # Optimize with SGD w/ Nesterov momentum
-     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
-     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
-     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
-     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
-     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
-     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
-     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
-     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
-
-     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, 
b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
-   }
-
-predict = function(matrix[double] X, int C, int Hin, int Win,
-                   matrix[double] W1, matrix[double] b1,
-                   matrix[double] W2, matrix[double] b2,
-                   matrix[double] W3, matrix[double] b3,
-                   matrix[double] W4, matrix[double] b4)
-    return (matrix[double] probs) {
-  /*
-   * Computes the class probability predictions of a convolutional
-   * net using the "LeNet" architecture.
-   *
-   * The input matrix, X, has N examples, each represented as a 3D
-   * volume unrolled into a single vector.
-   *
-   * Inputs:
-   *  - X: Input data matrix, of shape (N, C*Hin*Win).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
-   *  - b1: 1st layer biases vector, of shape (F1, 1).
-   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
-   *  - b2: 2nd layer biases vector, of shape (F2, 1).
-   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
-   *  - b3: 3rd layer biases vector, of shape (1, N3).
-   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
-   *  - b4: 4th layer biases vector, of shape (1, K).
-   *
-   * Outputs:
-   *  - probs: Class probabilities, of shape (N, K).
-   */
-  N = nrow(X)
-
-  # Network:
-  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
-  Hf = 5  # filter height
-  Wf = 5  # filter width
-  stride = 1
-  pad = 2  # For same dimensions, (Hf - stride) / 2
-
-  F1 = nrow(W1)  # num conv filters in conv1
-  F2 = nrow(W2)  # num conv filters in conv2
-  N3 = ncol(W3)  # num nodes in affine3
-  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions 
(num classes)
-
-  # Compute predictions over mini-batches
-  probs = matrix(0, rows=N, cols=K)
-  batch_size = 64
-  iters = ceil(N / batch_size)
-  for(i in 1:iters) {
-    # Get next batch
-    beg = ((i-1) * batch_size) %% N + 1
-    end = min(N, beg + batch_size - 1)
-    X_batch = X[beg:end,]
-
-    # Compute forward pass
-    ## layer 1: conv1 -> relu1 -> pool1
-    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, 
Hf, Wf, stride, stride,
-                                              pad, pad)
-    outr1 = relu::forward(outc1)
-    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
-                                                  strideh=2, stridew=2, pad=0, 
pad=0)
-    ## layer 2: conv2 -> relu2 -> pool2
-    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf,
-                                              stride, stride, pad, pad)
-    outr2 = relu::forward(outc2)
-    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
-                                                  strideh=2, stridew=2, pad=0, 
pad=0)
-    ## layer 3:  affine3 -> relu3
-    outa3 = affine::forward(outp2, W3, b3)
-    outr3 = relu::forward(outa3)
-    ## layer 4:  affine4 -> softmax
-    outa4 = affine::forward(outr3, W4, b4)
-    probs_batch = softmax::forward(outa4)
-
-    # Store predictions
-    probs[beg:end,] = probs_batch
-  }
-}
-
-eval = function(matrix[double] probs, matrix[double] Y)
-    return (double loss, double accuracy) {
-  /*
-   * Evaluates a convolutional net using the "LeNet" architecture.
-   *
-   * The probs matrix contains the class probability predictions
-   * of K classes over N examples.  The targets, Y, have K classes,
-   * and are one-hot encoded.
-   *
-   * Inputs:
-   *  - probs: Class probabilities, of shape (N, K).
-   *  - Y: Target matrix, of shape (N, K).
-   *
-   * Outputs:
-   *  - loss: Scalar loss, of shape (1).
-   *  - accuracy: Scalar accuracy, of shape (1).
-   */
-  # Compute loss & accuracy
-  loss = cross_entropy_loss::forward(probs, Y)
-  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
-  accuracy = mean(correct_pred)
-}
-
-generate_dummy_data = function()
-    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
-  /*
-   * Generate a dummy dataset similar to the MNIST dataset.
-   *
-   * Outputs:
-   *  - X: Input data matrix, of shape (N, D).
-   *  - Y: Target matrix, of shape (N, K).
-   *  - C: Number of input channels (dimensionality of input depth).
-   *  - Hin: Input height.
-   *  - Win: Input width.
-   */
-  # Generate dummy input data
-  N = 1024  # num examples
-  C = 1  # num input channels
-  Hin = 28  # input height
-  Win = 28  # input width
-  K = 10  # num target classes
-  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
-  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
-  Y = table(seq(1, N), classes)  # one-hot encoding
-}
-

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index 707722e..d02e5d6 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -233,8 +233,8 @@ aggregation = function(list[unknown] model,
      vb2 = as.matrix(model["vb2"])
      vb3 = as.matrix(model["vb3"])
      vb4 = as.matrix(model["vb4"])
-     lr = as.scalar(hyperparams['lr']);
-     mu = as.scalar(hyperparams['mu']);
+     lr = 0.01
+     mu = 0.9
 
      # Optimize with SGD w/ Nesterov momentum
      [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
new file mode 100644
index 0000000..346cc08
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH")
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
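
For context, the last two string arguments of mnist_lenet::train above select the update type ("ASP") and the update frequency ("BATCH") that this test exercises. A minimal sketch of how such a wrapper is assumed to forward these values to the paramserv builtin follows; the gradients/aggregation function names, batch size, scheme, and hyperparams list are illustrative placeholders rather than the actual body of mnist_lenet_paramserv.dml:

  # Hypothetical wrapper body (illustration only, not part of this commit):
  # hand the chosen update type and frequency through to the paramserv builtin.
  modelList = paramserv(model=modelList, features=X, labels=Y,
    val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation",
    mode="LOCAL", utype="ASP", freq="BATCH", epochs=epochs, batchsize=64,
    k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params)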

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
new file mode 100644
index 0000000..8d553ae
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH")
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
deleted file mode 100644
index b50e17c..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
+++ /dev/null
@@ -1,52 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml
new file mode 100644
index 0000000..7b6523b
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH")
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
new file mode 100644
index 0000000..d0a6570
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH")
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-nn-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-test.dml b/src/test/scripts/functions/paramserv/paramserv-nn-test.dml
deleted file mode 100644
index 740a208..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-test.dml
+++ /dev/null
@@ -1,52 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/51057e47/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml
index 13a05c9..8f5f53e 100644
--- a/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-wrong-args.dml
@@ -26,16 +26,25 @@ Y = matrix(2, rows=2, cols=3)
 X_val = matrix(3, rows=2, cols=3)
 Y_val = matrix(4, rows=2, cols=3)
 
-gradients = function (matrix[double] input) return (matrix[double] output) {
-  output = input
+gradients = function(matrix[double] features,
+                     matrix[double] wrong_labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+            return (list[unknown] gradients) {
+  gradients = model
 }
 
-aggregation = function (matrix[double] input) return (matrix[double] output) {
-  output = input
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+              return (list[unknown] modelResult) {
+  modelResult = model
 }
 
 e2 = "element2"
 params = list(e2)
 
 # Use paramserv function
-modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="NSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")
\ No newline at end of file
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="NSP", epochs=10, hyperparams=params)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file
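
The updated script above is still intended to fail validation, e.g. "NSP" is not one of the update types exercised by the other tests in this change. For comparison, a minimal sketch of how accepted utype/freq values would be passed, reusing the same argument names and adding the new per-epoch frequency, could look as follows (illustration only):

  # Hypothetical valid call (not part of this commit): BSP updates applied per epoch.
  modelList2 = paramserv(model=modelList, features=X, labels=Y,
    val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation",
    mode="LOCAL", utype="BSP", freq="EPOCH", epochs=10, hyperparams=params)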
