[SYSTEMML-2085] Initial version of local backend for paramserv builtin

Closes #771.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/97018d4e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/97018d4e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/97018d4e

Branch: refs/heads/master
Commit: 97018d4e688ba7eeaaa4567ca1e174a3c5525468
Parents: c7a9e01
Author: EdgarLGB <[email protected]>
Authored: Mon May 28 23:17:18 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon May 28 23:17:20 2018 -0700

----------------------------------------------------------------------
 .../ParameterizedBuiltinFunctionExpression.java |   6 +-
 .../java/org/apache/sysml/parser/Statement.java |   1 +
 .../context/ExecutionContext.java               |  13 +-
 .../controlprogram/paramserv/LocalPSWorker.java |  97 +++++
 .../paramserv/LocalParamServer.java             |  59 +++
 .../controlprogram/paramserv/PSWorker.java      | 131 +++++++
 .../controlprogram/paramserv/ParamServer.java   | 232 +++++++++++
 .../paramserv/ParamservUtils.java               |  97 +++++
 .../runtime/instructions/cp/CPOperand.java      |   2 +-
 .../runtime/instructions/cp/ListObject.java     |  14 +
 .../cp/MatrixIndexingCPInstruction.java         |   4 +-
 .../cp/ParamservBuiltinCPInstruction.java       | 257 ++++++++++++-
 .../test/integration/AutomatedTestBase.java     |  18 +-
 .../functions/paramserv/ParamservFuncTest.java  |  29 +-
 .../paramserv/mnist_lenet_paramserv.dml         | 383 +++++++++++++++++++
 .../mnist_lenet_paramserv_minimum_version.dml   | 377 ++++++++++++++++++
 .../functions/paramserv/paramserv-all-args.dml  |   4 +-
 .../functions/paramserv/paramserv-ipa-test.dml  |  47 ---
 .../paramserv/paramserv-minimum-version.dml     |  52 +++
 .../functions/paramserv/paramserv-miss-args.dml |   4 +-
 .../functions/paramserv/paramserv-nn-test.dml   |  52 +++
 .../paramserv-without-optional-args.dml         |   4 +-
 22 files changed, 1805 insertions(+), 78 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 3d74f8d..99aec78 100644
--- 
a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -341,12 +341,12 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                                .collect(Collectors.toSet());
                checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, 
utypes, conditional);
                Set<String> frequencies = 
Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet());
-               checkStringParam(false, fname, Statement.PS_FREQUENCY, 
frequencies, conditional);
+               checkStringParam(true, fname, Statement.PS_FREQUENCY, 
frequencies, conditional);
                checkDataValueType(false, fname, Statement.PS_EPOCHS, 
DataType.SCALAR, ValueType.INT, conditional);
                checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, 
DataType.SCALAR, ValueType.INT, conditional);
-               checkDataValueType(false, fname, Statement.PS_PARALLELISM, 
DataType.SCALAR, ValueType.INT, conditional);
+               checkDataValueType(true, fname, Statement.PS_PARALLELISM, 
DataType.SCALAR, ValueType.INT, conditional);
                Set<String> schemes = 
Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet());
-               checkStringParam(false, fname, Statement.PS_SCHEME, schemes, 
conditional);
+               checkStringParam(true, fname, Statement.PS_SCHEME, schemes, 
conditional);
                checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, 
DataType.LIST, ValueType.UNKNOWN, conditional);
                Set<String> checkpointings = 
Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet());
                checkStringParam(true, fname, Statement.PS_CHECKPOINTING, 
checkpointings, conditional);

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/Statement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Statement.java 
b/src/main/java/org/apache/sysml/parser/Statement.java
index 4853a47..1987d31 100644
--- a/src/main/java/org/apache/sysml/parser/Statement.java
+++ b/src/main/java/org/apache/sysml/parser/Statement.java
@@ -71,6 +71,7 @@ public abstract class Statement implements ParseInfo
        public static final String PS_UPDATE_FUN = "upd";
        public static final String PS_AGGREGATION_FUN = "agg";
        public static final String PS_MODE = "mode";
+       public static final String PS_GRADIENTS = "gradients";
        public enum PSModeType {
                LOCAL, REMOTE_SPARK
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 67b2a83..6807848 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -42,6 +42,7 @@ import org.apache.sysml.runtime.instructions.cp.CPInstruction;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
@@ -443,7 +444,17 @@ public class ExecutionContext {
        public void setScalarOutput(String varName, ScalarObject so) {
                setVariable(varName, so);
        }
-       
+
+       public ListObject getListObject(String name) {
+               Data dat = getVariable(name);
+               //error handling if non existing or no list
+               if (dat == null)
+                       throw new DMLRuntimeException("Variable '" + name + "' 
does not exist in the symbol table.");
+               if (!(dat instanceof ListObject))
+                       throw new DMLRuntimeException("Variable '" + name + "' 
is not a list.");
+               return (ListObject) dat;
+       }
+
        public void releaseMatrixOutputForGPUInstruction(String varName) {
                MatrixObject mo = getMatrixObject(varName);
                if(mo.getGPUObject(getGPUContext(0)) == null || 
!mo.getGPUObject(getGPUContext(0)).isAllocated()) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
new file mode 100644
index 0000000..181b866
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Parameter server worker that is executed as a {@link Runnable}.
+ *
+ * For every epoch and every mini-batch the worker pulls the global model from
+ * the parameter server, binds a private copy of it together with the current
+ * batch of features and labels into its own execution context, invokes the
+ * update function and pushes the resulting gradients back to the server.
+ */
+public class LocalPSWorker extends PSWorker implements Runnable {
+
+       protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+
+       public LocalPSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
+                       ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
+               super(workerID, updFunc, freq, epochs, batchSize, hyperParams, ec, ps);
+       }
+
+       /**
+        * Runs all epochs over the features and labels assigned to this worker.
+        */
+       @Override
+       public void run() {
+
+               long dataSize = _features.getNumRows();
+
+               // Number of batches per epoch, rounded up so that a trailing partial
+               // batch is still processed. Done in integer arithmetic: dividing two
+               // longs and applying Math.ceil afterwards would already have truncated.
+               int totalIter = (int) ((dataSize + _batchSize - 1) / _batchSize);
+
+               for (int i = 0; i < _epochs; i++) {
+                       for (int j = 0; j < totalIter; j++) {
+                               // Pull the global parameters from ps
+                               // Need to copy the global parameter
+                               ListObject globalParams = ParamservUtils.copyList((ListObject) _ps.pull(_workerID));
+                               if (LOG.isDebugEnabled()) {
+                                       LOG.debug(String.format(
+                                                       "Local worker_%d: Successfully pull the global parameters [size:%d kb] from ps.", _workerID,
+                                                       globalParams.getDataSize() / 1024));
+                               }
+                               _ec.setVariable(Statement.PS_MODEL, globalParams);
+
+                               // Row range of the current batch; the last batch is clipped to the data size.
+                               // NOTE(review): assumes ParamservUtils.sliceMatrix takes 1-based inclusive
+                               // row bounds, so a batch of _batchSize rows ends at (j + 1) * _batchSize - confirm.
+                               long begin = j * _batchSize + 1;
+                               long end = Math.min((j + 1) * _batchSize, dataSize);
+
+                               // Get batch features and labels
+                               MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
+                               MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
+                               _ec.setVariable(Statement.PS_FEATURES, bFeatures);
+                               _ec.setVariable(Statement.PS_LABELS, bLabels);
+
+                               if (LOG.isDebugEnabled()) {
+                                       LOG.debug(String.format(
+                                                       "Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. [Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]",
+                                                       _workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1,
+                                                       _epochs, j + 1, totalIter));
+                               }
+
+                               // Invoke the update function
+                               _inst.processInstruction(_ec);
+
+                               // Get the gradients; fails with a descriptive error if the
+                               // output variable is missing or is not a list
+                               ListObject gradients = _ec.getListObject(_outputs.get(0).getName());
+
+                               // Push the gradients to ps
+                               _ps.push(_workerID, gradients);
+                               if (LOG.isDebugEnabled()) {
+                                       LOG.debug(String.format("Local worker_%d: Successfully push the gradients [size:%d kb] to ps.",
+                                                       _workerID, gradients.getDataSize() / 1024));
+                               }
+
+                               // Release the per-iteration copies of the model and the batch data
+                               ParamservUtils.cleanupListObject(_ec, globalParams);
+                               ParamservUtils.cleanupData(bFeatures);
+                               ParamservUtils.cleanupData(bLabels);
+                       }
+                       if (LOG.isDebugEnabled()) {
+                               LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+                       }
+               }
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
new file mode 100644
index 0000000..d060a91
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Parameter server whose push and pull operations are coordinated through the
+ * monitor {@code _lock} inherited from {@link ParamServer}.
+ */
+public class LocalParamServer extends ParamServer {
+
+       public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
+                       Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum,
+                       ListObject hyperParams) {
+               super(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
+       }
+
+       /**
+        * Enqueues the gradients of a worker and wakes up all threads waiting on the lock.
+        *
+        * @param workerID id of the pushing worker
+        * @param gradients gradients computed by the worker
+        */
+       @Override
+       public void push(long workerID, ListObject gradients) {
+               synchronized (_lock) {
+                       _queue.add(new Gradient(workerID, gradients));
+                       _lock.notifyAll();
+               }
+       }
+
+       /**
+        * Blocks while the pulled state of the worker is still set, then marks the
+        * worker as having pulled and returns the current model.
+        *
+        * @param workerID id of the pulling worker
+        * @return the current model
+        * @throws DMLRuntimeException if the calling thread is interrupted while waiting
+        */
+       @Override
+       public Data pull(long workerID) {
+               synchronized (_lock) {
+                       while (getPulledState((int) workerID)) {
+                               try {
+                                       _lock.wait();
+                               } catch (InterruptedException e) {
+                                       // wait() cleared the interrupt status; restore it so that
+                                       // code further up the stack can still observe the interruption
+                                       Thread.currentThread().interrupt();
+                                       throw new DMLRuntimeException(
+                                                       String.format("Local worker_%d: failed to pull the global parameters.", workerID), e);
+                               }
+                       }
+                       setPulledState((int) workerID, true);
+               }
+               return getResult();
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
new file mode 100644
index 0000000..9ace823
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.ArrayList;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+@SuppressWarnings("unused")
+public abstract class PSWorker {
+
+       long _workerID = -1;
+       int _epochs;
+       long _batchSize;
+       MatrixObject _features;
+       MatrixObject _labels;
+       ExecutionContext _ec;
+       ParamServer _ps;
+       private String _updFunc;
+       private Statement.PSFrequency _freq;
+       private MatrixObject _valFeatures;
+       private MatrixObject _valLabels;
+
+       ArrayList<DataIdentifier> _outputs;
+       FunctionCallCPInstruction _inst;
+
+       public PSWorker(long workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize,
+                       ListObject hyperParams, ExecutionContext ec, 
ParamServer ps) {
+               this._workerID = workerID;
+               this._updFunc = updFunc;
+               this._freq = freq;
+               this._epochs = epochs;
+               this._batchSize = batchSize;
+               this._ec = 
ExecutionContextFactory.createContext(ec.getProgram());
+               if (hyperParams != null) {
+                       this._ec.setVariable(Statement.PS_HYPER_PARAMS, 
hyperParams);
+               }
+               this._ps = ps;
+
+               // Get the update function
+               String[] keys = DMLProgram.splitFunctionKey(updFunc);
+               String _funcName = keys[0];
+               String _funcNS = null;
+               if (keys.length == 2) {
+                       _funcNS = keys[0];
+                       _funcName = keys[1];
+               }
+               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(_funcNS, _funcName);
+               ArrayList<DataIdentifier> _inputs = func.getInputParams();
+               _outputs = func.getOutputParams();
+               CPOperand[] _boundInputs = _inputs.stream()
+                               .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
+                               .toArray(CPOperand[]::new);
+               ArrayList<String> _inputNames = 
_inputs.stream().map(DataIdentifier::getName)
+                               
.collect(Collectors.toCollection(ArrayList::new));
+               ArrayList<String> _outputNames = 
_outputs.stream().map(DataIdentifier::getName)
+                               
.collect(Collectors.toCollection(ArrayList::new));
+               _inst = new FunctionCallCPInstruction(_funcNS, _funcName, 
_boundInputs, _inputNames, _outputNames,
+                               "update function");
+
+               // Check the inputs of the update function
+               checkInput(_inputs, Expression.DataType.MATRIX, 
Statement.PS_FEATURES);
+               checkInput(_inputs, Expression.DataType.MATRIX, 
Statement.PS_LABELS);
+               checkInput(_inputs, Expression.DataType.LIST, 
Statement.PS_MODEL);
+               if (hyperParams != null) {
+                       checkInput(_inputs, Expression.DataType.LIST, 
Statement.PS_HYPER_PARAMS);
+               }
+
+               // Check the output of the update function
+               if (_outputs.size() != 1) {
+                       throw new DMLRuntimeException(
+                               String.format("The output of the '%s' function 
should provide one list containing the gradients.", updFunc));
+               }
+               if (_outputs.get(0).getDataType() != Expression.DataType.LIST) {
+                       throw new DMLRuntimeException(
+                                       String.format("The output of the '%s' 
function should be type of list.", updFunc));
+               }
+       }
+
+       private void checkInput(ArrayList<DataIdentifier> _inputs, 
Expression.DataType dt, String pname) {
+               if (_inputs.stream().filter(input -> input.getDataType() == dt 
&& pname.equals(input.getName())).count() != 1) {
+                       throw new DMLRuntimeException(
+                               String.format("The '%s' function should provide 
an input of '%s' type named '%s'.", _updFunc, dt, pname));
+               }
+       }
+
+       public void setFeatures(MatrixObject features) {
+               this._features = features;
+       }
+
+       public void setLabels(MatrixObject labels) {
+               this._labels = labels;
+       }
+
+       public void setValFeatures(MatrixObject valFeatures) {
+               this._valFeatures = valFeatures;
+       }
+
+       public void setValLabels(MatrixObject valLabels) {
+               this._valLabels = valLabels;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
new file mode 100644
index 0000000..6e1cd13
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public abstract class ParamServer {
+
+       public class Gradient {
+               final long _workerID;
+               final ListObject _gradients;
+
+               public Gradient(long workerID, ListObject gradients) {
+                       this._workerID = workerID;
+                       this._gradients = gradients;
+               }
+       }
+
+       Queue<Gradient> _queue;
+       final Object _lock = new Object();
+       private ListObject _model;
+       private AggregationService _aggService;
+       private Thread _aggThread;
+       private boolean[] _pulledStates;
+
+       protected ParamServer(ListObject model, String aggFunc, 
Statement.PSFrequency freq,
+                       Statement.PSUpdateType updateType, ExecutionContext ec, 
int workerNum, ListObject hyperParams) {
+               this._queue = new ConcurrentLinkedQueue<>();
+               this._model = model;
+               this._aggService = new AggregationService(aggFunc, freq, 
updateType, ec, workerNum, hyperParams);
+               this._pulledStates = new boolean[workerNum];
+               this._aggThread = new Thread(_aggService);
+       }
+
+       public abstract void push(long workerID, ListObject value);
+
+       public abstract Data pull(long workerID);
+
+       public void start() {
+               _aggService._alive = true;
+               _aggThread.start();
+       }
+
+       /**
+        * Clears the alive flag of the aggregation service and waits for the
+        * aggregation thread to terminate.
+        *
+        * @throws DMLRuntimeException if the calling thread is interrupted while
+        *         waiting for the aggregation thread
+        */
+       public void stop() {
+               _aggService._alive = false;
+               try {
+                       _aggThread.join();
+               } catch (InterruptedException e) {
+                       // join() cleared the interrupt status; restore it so that
+                       // callers further up the stack can still react to the interruption
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("Parameter server: failed when stopping the server.", e);
+               }
+       }
+
+       public ListObject getResult() {
+               return _model;
+       }
+
+       public boolean getPulledState(int workerID) {
+               return _pulledStates[workerID];
+       }
+
+       public void setPulledState(int workerID, boolean state) {
+               _pulledStates[workerID] = state;
+       }
+
+       private void resetPulledStates() {
+               _pulledStates = new boolean[_pulledStates.length];
+       }
+
+       /**
+        * Inner aggregation service which is for updating the model
+        */
+       @SuppressWarnings("unused")
+       private class AggregationService implements Runnable {
+
+               protected final Log LOG = 
LogFactory.getLog(AggregationService.class.getName());
+
+               protected ExecutionContext _ec;
+               private Statement.PSFrequency _freq;
+               private Statement.PSUpdateType _updateType;
+               private FunctionCallCPInstruction _inst;
+               private DataIdentifier _output;
+               private boolean _alive;
+               private boolean[] _finishedStates;  // Workers' finished states
+
+               /**
+                * Sets up the aggregation service: creates a new execution context on the
+                * program of the given context, binds the hyper-parameters (if present) and
+                * prepares the function call instruction of the aggregation function.
+                *
+                * @param aggFunc     key of the aggregation function, optionally namespace-qualified
+                * @param freq        update frequency
+                * @param updateType  update type
+                * @param ec          execution context whose program holds the aggregation function
+                * @param workerNum   number of workers whose finished states are tracked
+                * @param hyperParams hyper-parameters to expose to the function, may be null
+                */
+               AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType,
+                               ExecutionContext ec, int workerNum, ListObject hyperParams) {
+                       _ec = ExecutionContextFactory.createContext(ec.getProgram());
+                       _freq = freq;
+                       _updateType = updateType;
+                       if (hyperParams != null) {
+                               _ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams);
+                       }
+                       _finishedStates = new boolean[workerNum];
+
+                       // Resolve namespace and name of the aggregation function
+                       String[] keyParts = DMLProgram.splitFunctionKey(aggFunc);
+                       boolean qualified = (keyParts.length == 2);
+                       String funcNS = qualified ? keyParts[0] : null;
+                       String funcName = qualified ? keyParts[1] : keyParts[0];
+                       FunctionProgramBlock aggBlock = _ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+                       ArrayList<DataIdentifier> inParams = aggBlock.getInputParams();
+                       ArrayList<DataIdentifier> outParams = aggBlock.getOutputParams();
+
+                       // The function must return exactly one value, and that value must be a list
+                       if (outParams.size() != 1) {
+                               throw new DMLRuntimeException(String.format(
+                                               "The output of the '%s' function should provide one list containing the updated model.",
+                                               aggFunc));
+                       }
+                       _output = outParams.get(0);
+                       if (_output.getDataType() != Expression.DataType.LIST) {
+                               throw new DMLRuntimeException(
+                                               String.format("The output of the '%s' function should be type of list.", aggFunc));
+                       }
+
+                       // Bind every formal input by its own name and collect the parameter names
+                       CPOperand[] boundInputs = new CPOperand[inParams.size()];
+                       ArrayList<String> inputNames = new ArrayList<>();
+                       for (int i = 0; i < boundInputs.length; i++) {
+                               DataIdentifier in = inParams.get(i);
+                               boundInputs[i] = new CPOperand(in.getName(), in.getValueType(), in.getDataType());
+                               inputNames.add(in.getName());
+                       }
+                       ArrayList<String> outputNames = new ArrayList<>();
+                       for (DataIdentifier out : outParams) {
+                               outputNames.add(out.getName());
+                       }
+                       _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames,
+                                       "aggregate function");
+               }
+
+               /**
+                * @return the current value of the alive flag of this service
+                */
+               boolean isAlive() {
+                       return this._alive;
+               }
+
+               /**
+                * @return true if no worker is still marked as unfinished
+                */
+               private boolean allFinished() {
+                       for (boolean finished : _finishedStates) {
+                               if (!finished) {
+                                       return false;
+                               }
+                       }
+                       return true;
+               }
+
+               /**
+                * Marks every worker as unfinished again.
+                */
+               private void resetFinishedStates() {
+                       for (int i = 0; i < _finishedStates.length; i++) {
+                               _finishedStates[i] = false;
+                       }
+               }
+
+               /**
+                * Marks the given worker as finished.
+                *
+                * @param workerID index of the worker in the finished-state array
+                */
+               private void setFinishedState(int workerID) {
+                       this._finishedStates[workerID] = true;
+               }
+
+               /**
+                * Aggregation loop, executed while holding the shared lock.
+                *
+                * As long as the service is alive it waits until a gradient is queued, marks
+                * the sending worker as finished, invokes the aggregation function on the
+                * gradients and the current model, and replaces the model with the function
+                * output. Once all workers have contributed, the pulled and finished states
+                * are reset and all threads waiting on the lock are notified.
+                *
+                * @throws DMLRuntimeException if the thread is interrupted while waiting for gradients
+                */
+               @Override
+               public void run() {
+                       synchronized (_lock) {
+                               while (isAlive()) {
+                                       do {
+                                               // Block until a worker has pushed gradients
+                                               while (_queue.isEmpty()) {
+                                                       try {
+                                                               _lock.wait();
+                                                       } catch (InterruptedException e) {
+                                                               // Restore the interrupt status so that code further up the
+                                                               // stack can still observe the interruption
+                                                               Thread.currentThread().interrupt();
+                                                               throw new DMLRuntimeException(
+                                                                               "Aggregation service: error when waiting for the coming gradients.", e);
+                                                       }
+                                               }
+                                               Gradient p = _queue.remove();
+                                               if (LOG.isDebugEnabled()) {
+                                                       LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.",
+                                                                       p._gradients.getDataSize() / 1024, p._workerID));
+                                               }
+
+                                               setFinishedState((int) p._workerID);
+
+                                               // Populate the variables table with the gradients and model
+                                               _ec.setVariable(Statement.PS_GRADIENTS, p._gradients);
+                                               _ec.setVariable(Statement.PS_MODEL, _model);
+
+                                               // Invoke the aggregate function
+                                               _inst.processInstruction(_ec);
+
+                                               // Get the output
+                                               ListObject newModel = (ListObject) _ec.getVariable(_output.getName());
+
+                                               // Update the model with the new output; the previous model and the
+                                               // consumed gradients are cleaned up before the swap
+                                               ParamservUtils.cleanupListObject(_ec, _model);
+                                               ParamservUtils.cleanupListObject(_ec, p._gradients);
+                                               _model = newModel;
+
+                                       } while (!allFinished());
+
+                                       // notify all the workers to get the updated model
+                                       resetPulledStates();
+                                       resetFinishedStates();
+                                       _lock.notifyAll();
+                                       if (LOG.isDebugEnabled()) {
+                                               LOG.debug("Global parameter is broadcasted successfully.");
+                                       }
+                               }
+                       }
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
new file mode 100644
index 0000000..54c5d6c
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.MetaDataFormat;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+
+/**
+ * Static helpers of the parameter server for copying, slicing and cleaning up
+ * list and matrix objects.
+ */
+public class ParamservUtils {
+
+       /**
+        * Deep copy the list object. Matrices are copied by slicing all of their
+        * rows; any other non-list, non-frame entry is reused as is.
+        *
+        * @param lo list object
+        * @return a new copied list object, or the input itself if it is empty
+        * @throws DMLRuntimeException if the list contains a nested list or a frame
+        */
+       public static ListObject copyList(ListObject lo) {
+               if (lo.getLength() == 0) {
+                       return lo;
+               }
+               List<Data> newData = lo.getNames().stream().map(name -> {
+                       Data oldData = lo.slice(name);
+                       if (oldData instanceof MatrixObject) {
+                               MatrixObject mo = (MatrixObject) oldData;
+                               return sliceMatrix(mo, 1, mo.getNumRows());
+                       } else if (oldData instanceof ListObject || oldData instanceof FrameObject) {
+                               throw new DMLRuntimeException("Copy list: does not support list or frame.");
+                       } else {
+                               return oldData;
+                       }
+               }).collect(Collectors.toList());
+               return new ListObject(newData, lo.getNames());
+       }
+
+       /**
+        * Removes the variables named like the list entries from the given context
+        * and clears the data of all cacheable list entries.
+        *
+        * @param ec execution context
+        * @param lo list object to clean up
+        */
+       public static void cleanupListObject(ExecutionContext ec, ListObject lo) {
+               // A list without names has no variables to remove; skipping avoids a
+               // NullPointerException when building the name set
+               if (lo.getNames() != null) {
+                       ec.getVariables().removeAllIn(new HashSet<>(lo.getNames()));
+               }
+               lo.getData().forEach(ParamservUtils::cleanupData);
+       }
+
+       /**
+        * Enables cleanup on the given data and clears it; data that is not
+        * cacheable is ignored.
+        *
+        * @param data data object
+        */
+       public static void cleanupData(Data data) {
+               if( !(data instanceof CacheableData) )
+                       return;
+               CacheableData<?> cd = (CacheableData<?>) data;
+               cd.enableCleanup(true);
+               cd.clearData();
+       }
+
+       /**
+        * Slice the matrix by rows, keeping all columns.
+        *
+        * @param mo input matrix
+        * @param rl low boundary (1-based)
+        * @param rh high boundary (1-based)
+        * @return new sliced matrix
+        */
+       public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) {
+               MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, null,
+                       new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1),
+                               OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+               MatrixBlock tmp = mo.acquireRead();
+               try {
+                       result.acquireModify(tmp.slice((int)rl-1, (int)rh-1, 0,
+                               tmp.getNumColumns()-1, new MatrixBlock()));
+                       result.release();
+               } finally {
+                       // Always give the read lock back, even if slicing fails
+                       mo.release();
+               }
+               return result;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
index 1ca8eab..22b79b0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
@@ -46,7 +46,7 @@ public class CPOperand
                this(name, vt, dt, false);
        }
 
-       private CPOperand(String name, ValueType vt, DataType dt, boolean 
literal) {
+       public CPOperand(String name, ValueType vt, DataType dt, boolean 
literal) {
                _name = name;
                _valueType = vt;
                _dataType = dt;

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
index 95f03b5..670190c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
@@ -25,6 +25,7 @@ import java.util.List;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 
 public class ListObject extends Data {
        private static final long serialVersionUID = 3652422061598967358L;
@@ -107,6 +108,19 @@ public class ListObject extends Data {
                return (_names == null) ? null : _names.get(ix);
        }
 
+       /**
+        * @return true if this list carries names, false otherwise
+        */
+       public boolean isNamedList() {
+               return !(_names == null);
+       }
+
+       /**
+        * @return the internal list of data entries (not a copy)
+        */
+       public List<Data> getData() {
+               return this._data;
+       }
+
+       /**
+        * Sums up the data sizes of all cacheable entries of this list.
+        *
+        * @return total data size of the cacheable entries, 0 if there are none
+        */
+       public long getDataSize() {
+               // sum() yields 0 for a list without cacheable entries, where reducing
+               // to an Optional and calling get() would throw
+               return _data.stream().filter(data -> data instanceof CacheableData)
+                               .mapToLong(data -> ((CacheableData<?>) data).getDataSize()).sum();
+       }
+
        @Override
        public String getDebugName() {
                return toString();

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
index 51cc4c1..4e5d4c0 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
@@ -34,8 +34,8 @@ import org.apache.sysml.utils.Statistics;
 
 public final class MatrixIndexingCPInstruction extends IndexingCPInstruction {
 
-       protected MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, 
CPOperand ru, CPOperand cl,
-                       CPOperand cu, CPOperand out, String opcode, String 
istr) {
+       public MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, 
CPOperand ru, CPOperand cl, CPOperand cu,
+                       CPOperand out, String opcode, String istr) {
                super(in, rl, ru, cl, cu, out, opcode, istr);
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index ddc56ae..3ab0fc8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,14 +19,62 @@
 
 package org.apache.sysml.runtime.instructions.cp;
 
+import static org.apache.sysml.parser.Statement.PSFrequency;
+import static org.apache.sysml.parser.Statement.PSModeType;
+import static org.apache.sysml.parser.Statement.PSScheme;
+import static org.apache.sysml.parser.Statement.PSUpdateType;
+import static org.apache.sysml.parser.Statement.PS_AGGREGATION_FUN;
+import static org.apache.sysml.parser.Statement.PS_BATCH_SIZE;
+import static org.apache.sysml.parser.Statement.PS_EPOCHS;
+import static org.apache.sysml.parser.Statement.PS_FEATURES;
+import static org.apache.sysml.parser.Statement.PS_FREQUENCY;
+import static org.apache.sysml.parser.Statement.PS_HYPER_PARAMS;
+import static org.apache.sysml.parser.Statement.PS_LABELS;
+import static org.apache.sysml.parser.Statement.PS_MODE;
+import static org.apache.sysml.parser.Statement.PS_MODEL;
+import static org.apache.sysml.parser.Statement.PS_PARALLELISM;
+import static org.apache.sysml.parser.Statement.PS_SCHEME;
+import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN;
+import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
+import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
+import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;
+
+import java.util.ArrayList;
 import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
-import org.apache.sysml.parser.Statement;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.NativeHelper;
 
 public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruction {
 
+       // Defaults used when the corresponding optional argument is not given
+       private static final int DEFAULT_BATCH_SIZE = 64;
+       private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH;
+       private static final int DEFAULT_LEVEL_PARALLELISM = InfrastructureAnalyzer.getLocalParallelism();
+       private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;
+
+       //internal local debug level
+       private static final boolean LDEBUG = false;
+
+       static {
+               // for internal debugging only: raise the log level of the paramserv package
+               if (LDEBUG) {
+                       Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG);
+               }
+       }
+
        protected ParamservBuiltinCPInstruction(Operator op, 
LinkedHashMap<String, String> paramsMap, CPOperand out,
                        String opcode, String istr) {
                super(op, paramsMap, out, opcode, istr);
@@ -34,8 +82,209 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               ListObject model = (ListObject) 
ec.getVariable(getParam(Statement.PS_MODEL));
-               ListObject outList = model.slice(0, model.getLength() - 1);
-               ec.setVariable(output.getName(), outList);
+               /*
+                * Runs the parameter server: reads the arguments, creates the server and
+                * the local workers, partitions the data among the workers, starts server
+                * and workers, waits for all workers to finish and finally binds the
+                * resulting model to the output variable.
+                */
+
+               PSModeType mode = PSModeType.valueOf(getParam(PS_MODE));
+               int workerNum = getWorkerNum(mode);
+               String updFunc = getParam(PS_UPDATE_FUN);
+               String aggFunc = getParam(PS_AGGREGATION_FUN);
+               PSFrequency freq = getFrequency();
+               PSUpdateType updateType = getUpdateType();
+               int epochs = Integer.parseInt(getParam(PS_EPOCHS));
+               if (epochs <= 0) {
+                       throw new DMLRuntimeException(
+                                       String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.",
+                                                       PS_EPOCHS));
+               }
+               long batchSize = getBatchSize();
+
+               // Create the parameter server
+               ListObject model = ec.getListObject(getParam(PS_MODEL));
+               ListObject hyperParams = getHyperParams(ec);
+               ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, ec, hyperParams);
+
+               // Create the local workers
+               List<LocalPSWorker> workers = IntStream.range(0, workerNum)
+                               .mapToObj(i -> new LocalPSWorker((long) i, updFunc, freq, epochs, batchSize, hyperParams, ec, ps))
+                               .collect(Collectors.toList());
+
+               // Do data partition
+               // NOTE(review): the partitioning may hand data to fewer workers than
+               // were created here; all created workers are started nevertheless
+               doDataPartition(ec, workers);
+
+               // Create the worker threads
+               List<Thread> threads = workers.stream().map(Thread::new).collect(Collectors.toList());
+
+               // Start the ps
+               ps.start();
+
+               // Start the workers
+               threads.forEach(Thread::start);
+
+               // Wait for the workers stopping
+               for (Thread thread : threads) {
+                       try {
+                               thread.join();
+                       } catch (InterruptedException e) {
+                               // Keep the interrupt status visible to the caller
+                               Thread.currentThread().interrupt();
+                               throw new DMLRuntimeException("Paramserv function: Failed to join the worker threads.", e);
+                       }
+               }
+
+               ps.stop();
+
+               // Create the output
+               ListObject result = ps.getResult();
+               ec.setVariable(output.getName(), result);
+       }
+
+       /**
+        * Reads the update type argument.
+        *
+        * @return the update type; only BSP passes, ASP and SSP are rejected
+        */
+       private PSUpdateType getUpdateType() {
+               final PSUpdateType type = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
+               if (type == PSUpdateType.ASP || type == PSUpdateType.SSP) {
+                       throw new DMLRuntimeException(String.format("Not support update type '%s'.", type));
+               }
+               return type;
+       }
+
+       /**
+        * Reads the update frequency argument.
+        *
+        * @return the given frequency, or the default if the argument is absent;
+        *         the epoch frequency is rejected
+        */
+       private PSFrequency getFrequency() {
+               if (getParameterMap().containsKey(PS_FREQUENCY)) {
+                       final PSFrequency given = PSFrequency.valueOf(getParam(PS_FREQUENCY));
+                       if (given == PSFrequency.EPOCH) {
+                               throw new DMLRuntimeException("Not support epoch update frequency.");
+                       }
+                       return given;
+               }
+               return DEFAULT_UPDATE_FREQUENCY;
+       }
+
+       /**
+        * Get the worker numbers according to the vcores. The requested
+        * parallelism (or the default if absent) is capped by the local
+        * parallelism, and never drops below one worker.
+        *
+        * @param mode execution mode
+        * @return worker numbers, at least 1
+        * @throws DMLRuntimeException if the requested parallelism is not positive
+        *         or the mode is remote spark
+        */
+       private int getWorkerNum(PSModeType mode) {
+               int workerNum = DEFAULT_LEVEL_PARALLELISM;
+               if (getParameterMap().containsKey(PS_PARALLELISM)) {
+                       workerNum = Integer.parseInt(getParam(PS_PARALLELISM));
+                       if (workerNum <= 0) {
+                               throw new DMLRuntimeException(
+                                               String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.",
+                                                               PS_PARALLELISM));
+                       }
+               }
+               switch (mode) {
+               case LOCAL:
+                       //FIXME: this is a workaround for a maximum number of buffers in openblas
+                       //However, the root cause is a missing function preparation for each worker
+                       //(i.e., deep copy with unique file names, and reduced degree of parallelism)
+                       int vcores = InfrastructureAnalyzer.getLocalParallelism();
+                       int limit = "openblas".equals(NativeHelper.getCurrentBLAS()) ? vcores / 2 : vcores;
+                       // vcores / 2 is 0 on a single-core host; keep at least one worker
+                       workerNum = Math.max(1, Math.min(workerNum, limit));
+                       break;
+               case REMOTE_SPARK:
+                       throw new DMLRuntimeException("Do not support remote spark.");
+               }
+               return workerNum;
+       }
+
+       /**
+        * Create a server which serves the local or remote workers. Only the
+        * local mode yields a server, the remote spark mode is rejected.
+        *
+        * @return parameter server
+        */
+       private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType,
+                       int workerNum, ListObject model, ExecutionContext ec, ListObject hyperParams) {
+               switch (mode) {
+               case REMOTE_SPARK:
+                       throw new DMLRuntimeException("Do not support remote spark.");
+               case LOCAL:
+                       return new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
+               default:
+                       return null;
+               }
+       }
+
+       /**
+        * Reads the batch size argument.
+        *
+        * @return the given batch size, or the default if the argument is absent
+        */
+       private long getBatchSize() {
+               if (getParameterMap().containsKey(PS_BATCH_SIZE)) {
+                       final long given = Integer.parseInt(getParam(PS_BATCH_SIZE));
+                       if (given > 0) {
+                               return given;
+                       }
+                       throw new DMLRuntimeException(String.format(
+                                       "Paramserv function: the number of argument '%s' could not be less than or equal to 0.",
+                                       PS_BATCH_SIZE));
+               }
+               return DEFAULT_BATCH_SIZE;
+       }
+
+       /**
+        * Looks up the hyper-parameter list in the given context.
+        *
+        * @param ec execution context
+        * @return the hyper-parameters, or null if the argument is absent
+        */
+       private ListObject getHyperParams(ExecutionContext ec) {
+               return getParameterMap().containsKey(PS_HYPER_PARAMS)
+                               ? ec.getListObject(getParam(PS_HYPER_PARAMS))
+                               : null;
+       }
+
+       /**
+        * Fetches the training and validation matrices from the context and
+        * distributes them among the workers according to the partitioning
+        * scheme (the default scheme if none is given).
+        *
+        * @param ec      execution context
+        * @param workers workers to receive the data partitions
+        */
+       private void doDataPartition(ExecutionContext ec, List<LocalPSWorker> workers) {
+               MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES));
+               MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
+               MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
+               MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
+               final PSScheme scheme = getParameterMap().containsKey(PS_SCHEME)
+                               ? PSScheme.valueOf(getParam(PS_SCHEME))
+                               : DEFAULT_SCHEME;
+               if (scheme == PSScheme.DISJOINT_CONTIGUOUS) {
+                       disjointContiguous(features, labels, valFeatures, valLabels, workers);
+                       return;
+               }
+               // All remaining known schemes are not implemented
+               if (scheme == PSScheme.DISJOINT_RANDOM || scheme == PSScheme.OVERLAP_RESHUFFLE
+                               || scheme == PSScheme.DISJOINT_ROUND_ROBIN) {
+                       throw new DMLRuntimeException(
+                                       String.format("Paramserv function: the scheme '%s' is not supported.", scheme));
+               }
+       }
+
+       /**
+        * Partitions training and validation data into contiguous row ranges and
+        * assigns the i-th partition of each matrix to the i-th worker.
+        *
+        * @param features    training features
+        * @param labels      training labels
+        * @param valFeatures validation features
+        * @param valLabels   validation labels
+        * @param workers     workers to receive the partitions
+        */
+       private void disjointContiguous(MatrixObject features, MatrixObject labels, MatrixObject valFeatures,
+                       MatrixObject valLabels, List<LocalPSWorker> workers) {
+               // training data
+               List<MatrixObject> featureParts = disjointContiguous(workers.size(), features);
+               List<MatrixObject> labelParts = disjointContiguous(workers.size(), labels);
+               List<LocalPSWorker> assigned = workers;
+               if (featureParts.size() < workers.size()) {
+                       LOG.warn(String.format(
+                                       "There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.",
+                                       featureParts.size(), workers.size(), featureParts.size()));
+                       // NOTE(review): only this method sees the shortened view; the list
+                       // held by the caller keeps its original size
+                       assigned = workers.subList(0, featureParts.size());
+               }
+               for (int i = 0; i < assigned.size(); i++) {
+                       LocalPSWorker worker = assigned.get(i);
+                       worker.setFeatures(featureParts.get(i));
+                       worker.setLabels(labelParts.get(i));
+               }
+
+               // validation data
+               List<MatrixObject> valFeatureParts = disjointContiguous(assigned.size(), valFeatures);
+               List<MatrixObject> valLabelParts = disjointContiguous(assigned.size(), valLabels);
+               for (int i = 0; i < assigned.size(); i++) {
+                       LocalPSWorker worker = assigned.get(i);
+                       worker.setValFeatures(valFeatureParts.get(i));
+                       worker.setValLabels(valLabelParts.get(i));
+               }
+       }
+
+       /**
+        * Splits the rows of the matrix into at most workerNum contiguous,
+        * non-overlapping ranges of ceil(rows / workerNum) rows each; the last
+        * range may be shorter. Every row ends up in exactly one partition.
+        *
+        * @param workerNum number of workers, must be positive
+        * @param mo        matrix to partition
+        * @return the row partitions in row order; empty if the matrix has no rows
+        * @throws DMLRuntimeException if workerNum is not positive
+        */
+       private List<MatrixObject> disjointContiguous(int workerNum, MatrixObject mo) {
+               if (workerNum <= 0) {
+                       throw new DMLRuntimeException(
+                                       String.format("Paramserv function: cannot partition the data for %d workers.", workerNum));
+               }
+               List<MatrixObject> list = new ArrayList<>();
+               long numRows = mo.getNumRows();
+               // Divide in floating point: an integer division would already be
+               // truncated before Math.ceil could round up
+               long stepSize = (long) Math.ceil((double) numRows / workerNum);
+               // Row bounds are 1-based with an inclusive end, so a range of stepSize
+               // rows ends at begin + stepSize - 1, and "<=" keeps a trailing single row
+               for (long begin = 1; begin <= numRows; begin += stepSize) {
+                       long end = Math.min(begin + stepSize - 1, numRows);
+                       list.add(ParamservUtils.sliceMatrix(mo, begin, end));
+               }
+               return list;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java 
b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 47ea66e..43f5229 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -1250,9 +1250,11 @@ public abstract class AutomatedTestBase
                        if (exceptionExpected)
                                fail("expected exception which has not been 
raised: " + expectedException);
                } catch (Exception e) {
-                       if (exceptionExpected && 
e.getClass().equals(expectedException) && errMessage != null
-                                       && 
!e.getMessage().contains(errMessage)) {
-                               fail("expected exception message has not been 
raised: " + errMessage);
+                       if (errMessage != null && !errMessage.equals("")) {
+                               boolean result = 
rCompareException(exceptionExpected, errMessage, e, false);
+                               if (exceptionExpected && !result) {
+                                       fail(String.format("expected exception 
message '%s' has not been raised.", errMessage));
+                               }
                        }
                        if (!exceptionExpected || (expectedException != null && 
!(e.getClass().equals(expectedException)))) {
                                e.printStackTrace();
@@ -1269,6 +1271,16 @@ public abstract class AutomatedTestBase
                }
        }
 
+       /**
+        * Recursively checks whether the given throwable or any throwable in its
+        * cause chain carries a message containing the expected error message.
+        *
+        * @param exceptionExpected whether an exception is expected at all
+        * @param errMessage        expected message fragment, may be null
+        * @param e                 throwable to inspect
+        * @param result            result accumulated so far
+        * @return true if result was already true or a matching message was found
+        */
+       private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) {
+               if (e.getCause() != null) {
+                       result |= rCompareException(exceptionExpected, errMessage, e.getCause(), result);
+               }
+               // A throwable may have no message at all (e.g. a plain NullPointerException)
+               String message = e.getMessage();
+               if (exceptionExpected && errMessage != null && message != null && message.contains(errMessage)) {
+                       result = true;
+               }
+               return result;
+       }
+
        public void cleanupScratchSpace()
        {
                try

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
index 1b227f1..6370099 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
@@ -32,7 +32,8 @@ public class ParamservFuncTest extends AutomatedTestBase {
        private static final String TEST_NAME4 = "paramserv-wrong-type-args";
        private static final String TEST_NAME5 = "paramserv-wrong-args";
        private static final String TEST_NAME6 = "paramserv-wrong-args2";
-       private static final String TEST_NAME7 = "paramserv-ipa-test";
+       private static final String TEST_NAME7 = "paramserv-nn-test";
+       private static final String TEST_NAME8 = "paramserv-minimum-version";
 
        private static final String TEST_DIR = "functions/paramserv/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
ParamservFuncTest.class.getSimpleName() + "/";
@@ -48,53 +49,59 @@ public class ParamservFuncTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
                addTestConfiguration(TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
                addTestConfiguration(TEST_NAME7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
+               addTestConfiguration(TEST_NAME8, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {}));
        }
 
        @Test
        public void testParamservWithAllArgs() {
-               runDMLTest(TEST_NAME1, true, false, null, null);
+               runDMLTest(TEST_NAME1, false, null, null);
        }
 
        @Test
        public void testParamservWithoutOptionalArgs() {
-               runDMLTest(TEST_NAME2, true, false, null, null);
+               runDMLTest(TEST_NAME2, false, null, null);
        }
 
        @Test
        public void testParamservMissArgs() {
                final String errmsg = "Named parameter 'features' missing. 
Please specify the input.";
-               runDMLTest(TEST_NAME3, true, true, DMLException.class, errmsg);
+               runDMLTest(TEST_NAME3, true, DMLException.class, errmsg);
        }
 
        @Test
        public void testParamservWrongTypeArgs() {
                final String errmsg = "Input to PARAMSERV::model must be of 
type 'LIST'. It should not be of type 'MATRIX'";
-               runDMLTest(TEST_NAME4, true, true, DMLException.class, errmsg);
+               runDMLTest(TEST_NAME4, true, DMLException.class, errmsg);
        }
 
        @Test
        public void testParamservWrongArgs() {
                final String errmsg = "Function PARAMSERV does not support 
value 'NSP' as the 'utype' parameter.";
-               runDMLTest(TEST_NAME5, true, true, DMLException.class, errmsg);
+               runDMLTest(TEST_NAME5, true, DMLException.class, errmsg);
        }
 
        @Test
        public void testParamservWrongArgs2() {
                final String errmsg = "Invalid parameters for PARAMSERV: 
[modelList, val_featur=X_val]";
-               runDMLTest(TEST_NAME6, true, true, DMLException.class, errmsg);
+               runDMLTest(TEST_NAME6, true, DMLException.class, errmsg);
        }
 
        @Test
-       public void testParamservIpaTest() {
-               runDMLTest(TEST_NAME7, true, false, null, "1");
+       public void testParamservNNTest() {
+               runDMLTest(TEST_NAME7, false, null, null);
        }
 
-       private void runDMLTest(String testname, boolean newWay, boolean 
exceptionExpected, Class<?> exceptionClass,
+       @Test
+       public void testParamservMinimumVersionTest() {
+               runDMLTest(TEST_NAME8, false, null, null);
+       }
+
+       private void runDMLTest(String testname, boolean exceptionExpected, 
Class<?> exceptionClass,
                        String errmsg) {
                TestConfiguration config = getTestConfiguration(testname);
                loadTestConfiguration(config);
                programArgs = new String[] { "-explain" };
                fullDMLScriptName = HOME + testname + ".dml";
-               runTest(newWay, exceptionExpected, exceptionClass, errmsg, -1);
+               runTest(true, exceptionExpected, exceptionClass, errmsg, -1);
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
new file mode 100644
index 0000000..2a3bbe2
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -0,0 +1,383 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, 
F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model object
+  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, 
vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+  # Create the hyper parameter list
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation",
 mode="LOCAL", utype="BSP", freq="BATCH", epochs=epochs, batchsize=64, 
k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, 
checkpointing="NONE")
+
+  W1 = as.matrix(modelList2["W1"])
+  b1 = as.matrix(modelList2["b1"])
+  W2 = as.matrix(modelList2["W2"])
+  b2 = as.matrix(modelList2["b2"])
+  W3 = as.matrix(modelList2["W3"])
+  b3 = as.matrix(modelList2["b3"])
+  W4 = as.matrix(modelList2["W4"])
+  b4 = as.matrix(modelList2["b4"])
+
+}
+
+# The update function must take the arguments 'features' (batch features),
+# 'labels' (batch labels), 'hyperparams', and 'model',
+# and must return the gradients as a list.
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  # Computes the LeNet gradients (dW1..dW4, db1..db4) for one batch: forward pass,
+  # data backward pass, and L2 regularization on the weights. The hyperparameters are
+  # hard-coded below because scalars cannot yet be read from the 'hyperparams' list.
+  C = 1
+  Hin = 28
+  Win = 28
+  Hf = 5
+  Wf = 5
+  stride = 1
+  pad = 2
+  lambda = 5e-04
+  F1 = 32
+  F2 = 64
+  N3 = 512
+  W1 = as.matrix(model["W1"])
+  b1 = as.matrix(model["b1"])
+  W2 = as.matrix(model["W2"])
+  b2 = as.matrix(model["b2"])
+  W3 = as.matrix(model["W3"])
+  b3 = as.matrix(model["b3"])
+  W4 = as.matrix(model["W4"])
+  b4 = as.matrix(model["b4"])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, pad=0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, pad=0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)  # outd3 is the input affine4 received in the forward pass
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+}
+
+# Open problem: how to handle the velocity? (currently stored in the model)
+# The aggregation function must take the arguments 'model', 'gradients', and 'hyperparams',
+# and must always return a model of type list.
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+   return (list[unknown] modelResult) {
+     # Applies one SGD w/ Nesterov momentum update to every weight and bias in 'model' using 'gradients'; 'hyperparams' is not read here.
+     W1 = as.matrix(model["W1"])
+     W2 = as.matrix(model["W2"])
+     W3 = as.matrix(model["W3"])
+     W4 = as.matrix(model["W4"])
+     b1 = as.matrix(model["b1"])
+     b2 = as.matrix(model["b2"])
+     b3 = as.matrix(model["b3"])
+     b4 = as.matrix(model["b4"])
+     dW1 = as.matrix(gradients["dW1"])
+     dW2 = as.matrix(gradients["dW2"])
+     dW3 = as.matrix(gradients["dW3"])
+     dW4 = as.matrix(gradients["dW4"])
+     db1 = as.matrix(gradients["db1"])
+     db2 = as.matrix(gradients["db2"])
+     db3 = as.matrix(gradients["db3"])
+     db4 = as.matrix(gradients["db4"])
+     vW1 = as.matrix(model["vW1"])
+     vW2 = as.matrix(model["vW2"])
+     vW3 = as.matrix(model["vW3"])
+     vW4 = as.matrix(model["vW4"])
+     vb1 = as.matrix(model["vb1"])
+     vb2 = as.matrix(model["vb2"])
+     vb3 = as.matrix(model["vb3"])
+     vb4 = as.matrix(model["vb4"])
+     lr = 0.01  # learning rate, hard-coded rather than taken from 'hyperparams'
+     mu = 0.9  # momentum, hard-coded rather than taken from 'hyperparams'
+
+     # Optimize with SGD w/ Nesterov momentum
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     # The updated velocities are returned inside the model list alongside the parameters
+     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, 
b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+   }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions 
(num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, 
Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, pad=0, 
pad=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, pad=0, 
pad=0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Evaluates a convolutional net using the "LeNet" architecture.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples.  The targets, Y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, Y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+  accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  Y = table(seq(1, N), classes)  # one-hot encoding
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
new file mode 100644
index 0000000..2ef7411
--- /dev/null
+++ 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -0,0 +1,377 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, 
F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model object
+  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, 
vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+  # Create the hyper parameter list
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation",
 mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params)
+
+  W1 = as.matrix(modelList2["W1"])
+  b1 = as.matrix(modelList2["b1"])
+  W2 = as.matrix(modelList2["W2"])
+  b2 = as.matrix(modelList2["b2"])
+  W3 = as.matrix(modelList2["W3"])
+  b3 = as.matrix(modelList2["b3"])
+  W4 = as.matrix(modelList2["W4"])
+  b4 = as.matrix(modelList2["b4"])
+
+}
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+  # Computes the LeNet gradients (dW1..dW4, db1..db4) for one batch: forward pass, data backward pass,
+  # and L2 regularization on the weights. 'hyperparams' is not read; the values are hard-coded below.
+  C = 1
+  Hin = 28
+  Win = 28
+  Hf = 5
+  Wf = 5
+  stride = 1
+  pad = 2
+  lambda = 5e-04
+  F1 = 32
+  F2 = 64
+  N3 = 512
+  W1 = as.matrix(model["W1"])
+  b1 = as.matrix(model["b1"])
+  W2 = as.matrix(model["W2"])
+  b2 = as.matrix(model["b2"])
+  W3 = as.matrix(model["W3"])
+  b3 = as.matrix(model["b3"])
+  W4 = as.matrix(model["W4"])
+  b4 = as.matrix(model["b4"])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, pad=0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, pad=0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)  # outd3 is the input affine4 received in the forward pass
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+}
+
+# Open question: how should the velocity be handled? (it is currently carried in the model list)
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+   return (list[unknown] modelResult) {
+     # Applies one SGD w/ Nesterov momentum update to every weight and bias in 'model' using 'gradients'; 'hyperparams' is not read here.
+     W1 = as.matrix(model["W1"])
+     W2 = as.matrix(model["W2"])
+     W3 = as.matrix(model["W3"])
+     W4 = as.matrix(model["W4"])
+     b1 = as.matrix(model["b1"])
+     b2 = as.matrix(model["b2"])
+     b3 = as.matrix(model["b3"])
+     b4 = as.matrix(model["b4"])
+     dW1 = as.matrix(gradients["dW1"])
+     dW2 = as.matrix(gradients["dW2"])
+     dW3 = as.matrix(gradients["dW3"])
+     dW4 = as.matrix(gradients["dW4"])
+     db1 = as.matrix(gradients["db1"])
+     db2 = as.matrix(gradients["db2"])
+     db3 = as.matrix(gradients["db3"])
+     db4 = as.matrix(gradients["db4"])
+     vW1 = as.matrix(model["vW1"])
+     vW2 = as.matrix(model["vW2"])
+     vW3 = as.matrix(model["vW3"])
+     vW4 = as.matrix(model["vW4"])
+     vb1 = as.matrix(model["vb1"])
+     vb2 = as.matrix(model["vb2"])
+     vb3 = as.matrix(model["vb3"])
+     vb4 = as.matrix(model["vb4"])
+     lr = 0.01  # learning rate, hard-coded rather than taken from 'hyperparams'
+     mu = 0.9  # momentum, hard-coded rather than taken from 'hyperparams'
+
+     # Optimize with SGD w/ Nesterov momentum
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     # The updated velocities are returned inside the model list alongside the parameters
+     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, 
b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+   }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64  # rows scored per forward pass
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)  # clamp: the last batch may hold fewer rows
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Scores the predictions of the "LeNet" convolutional net
+   * against the expected targets.
+   *
+   * probs holds the predicted class probabilities of K classes
+   * for N examples; Y holds the one-hot encoded targets over the
+   * same K classes.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Accuracy: share of rows whose highest-scoring column matches the target's
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(Y))
+  # Loss: delegated to the cross-entropy loss layer
+  loss = cross_entropy_loss::forward(probs, Y)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D), where D = C*Hin*Win.
+   *  - Y: Target matrix (one-hot encoded), of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))  # 1 and K drawn half as often
+  Y = table(seq(1, N), classes, N, K)  # one-hot encoding; dims pinned so Y is always (N, K)
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-all-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
index bcb3ac3..ec6e087 100644
--- a/src/test/scripts/functions/paramserv/paramserv-all-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
@@ -20,7 +20,7 @@
 #-------------------------------------------------------------
 
 e1 = "element1"
-paramsList = list(e1)
+paramsList = list(e1=e1)
 X = matrix(1, rows=2, cols=3)
 Y = matrix(2, rows=2, cols=3)
 X_val = matrix(3, rows=2, cols=3)
@@ -35,7 +35,7 @@ aggregation = function (matrix[double] input) return (matrix[double] output) {
 }
 
 e2 = "element2"
-hps = list(e2)
+hps = list(e2=e2)
 
 # Use paramserv function
 paramsList2 = paramserv(model=paramsList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", 
mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, 
scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")

http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml b/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
deleted file mode 100644
index 5aed767..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
+++ /dev/null
@@ -1,47 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-e1 = "element1"
-paramsList = list(e1)
-X = matrix(1, rows=2, cols=3)
-Y = matrix(2, rows=2, cols=3)
-X_val = matrix(3, rows=2, cols=3)
-Y_val = matrix(4, rows=2, cols=3)
-
-gradients = function (matrix[double] input) return (matrix[double] output) {
-  output = input
-}
-
-aggregation = function (matrix[double] input) return (matrix[double] output) {
-  output = input
-}
-
-e2 = "element2"
-hps = list(e2)
-
-# Use paramserv function
-paramsList2 = list(1, 2, 3)
-
-if (length(paramsList2) == 3) {
-  paramsList2 = paramserv(model=paramsList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", 
mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, 
scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")
-}
-
-print(length(paramsList2))
\ No newline at end of file

Reply via email to