Repository: systemml
Updated Branches:
  refs/heads/master 2b86a4d92 -> d44b3280f


[SYSTEMML-2344,48,49,52] Various improvements local paramserv backend

Closes #777.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d44b3280
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d44b3280
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d44b3280

Branch: refs/heads/master
Commit: d44b3280f2132deb303e955ff5b9a17daac4c31e
Parents: 2b86a4d
Author: EdgarLGB <[email protected]>
Authored: Sun Jun 3 23:07:02 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jun 3 23:07:03 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/parser/Statement.java |   8 +-
 .../controlprogram/paramserv/LocalPSWorker.java |  99 ++---
 .../paramserv/LocalParamServer.java             |  40 +-
 .../controlprogram/paramserv/PSWorker.java      | 113 +++---
 .../controlprogram/paramserv/ParamServer.java   | 227 +++++------
 .../paramserv/ParamservUtils.java               |  12 +-
 .../cp/ParamservBuiltinCPInstruction.java       | 260 +++++++++----
 .../functions/paramserv/ParamservFuncTest.java  |  39 +-
 .../paramserv/mnist_lenet_paramserv.dml         |   1 -
 .../paramserv/mnist_lenet_paramserv_asp.dml     | 376 +++++++++++++++++++
 .../mnist_lenet_paramserv_minimum_version.dml   |   1 -
 .../paramserv/paramserv-agg-service-failed.dml  |  53 +++
 .../paramserv/paramserv-large-parallelism.dml   |  52 +++
 .../functions/paramserv/paramserv-nn-asp.dml    |  52 +++
 .../paramserv/paramserv-worker-failed.dml       |  53 +++
 .../paramserv-wrong-aggregate-func.dml          |  50 +++
 16 files changed, 1117 insertions(+), 319 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/parser/Statement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Statement.java 
b/src/main/java/org/apache/sysml/parser/Statement.java
index 1987d31..d49eada 100644
--- a/src/main/java/org/apache/sysml/parser/Statement.java
+++ b/src/main/java/org/apache/sysml/parser/Statement.java
@@ -77,7 +77,13 @@ public abstract class Statement implements ParseInfo
        }
        public static final String PS_UPDATE_TYPE = "utype";
        public enum PSUpdateType {
-               BSP, ASP, SSP
+               BSP, ASP, SSP;
+               public boolean isBSP() {
+                       return this == BSP;
+               }
+               public boolean isASP() {
+                       return this == ASP;
+               }
        }
        public static final String PS_FREQUENCY = "freq";
        public enum PSFrequency {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index 181b866..e902aea 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -19,79 +19,82 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
+import java.util.concurrent.Callable;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 
-public class LocalPSWorker extends PSWorker implements Runnable {
+public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
        protected static final Log LOG = 
LogFactory.getLog(LocalPSWorker.class.getName());
 
-       public LocalPSWorker(long workerID, String updFunc, 
Statement.PSFrequency freq, int epochs, long batchSize,
-                       ListObject hyperParams, ExecutionContext ec, 
ParamServer ps) {
-               super(workerID, updFunc, freq, epochs, batchSize, hyperParams, 
ec, ps);
+       public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq, int epochs, long batchSize,
+                       ExecutionContext ec, ParamServer ps) {
+               super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
        }
 
        @Override
-       public void run() {
+       public Void call() throws Exception {
+               try {
+                       long dataSize = _features.getNumRows();
+                       for (int i = 0; i < _epochs; i++) {
+                               int totalIter = (int) Math.ceil(dataSize / 
_batchSize);
+                               for (int j = 0; j < totalIter; j++) {
+                                       // Pull the global parameters from ps
+                                       ListObject globalParams = 
(ListObject)_ps.pull(_workerID);
+                                       if (LOG.isDebugEnabled()) {
+                                               LOG.debug(String.format("Local 
worker_%d: Successfully pull the global parameters "
+                                                       + "[size:%d kb] from 
ps.", _workerID, globalParams.getDataSize() / 1024));
+                                       }
+                                       _ec.setVariable(Statement.PS_MODEL, 
globalParams);
 
-               long dataSize = _features.getNumRows();
+                                       long begin = j * _batchSize + 1;
+                                       long end = Math.min(begin + _batchSize, 
dataSize);
 
-               for (int i = 0; i < _epochs; i++) {
-                       int totalIter = (int) Math.ceil(dataSize / _batchSize);
-                       for (int j = 0; j < totalIter; j++) {
-                               // Pull the global parameters from ps
-                               // Need to copy the global parameter
-                               ListObject globalParams = 
ParamservUtils.copyList((ListObject) _ps.pull(_workerID));
-                               if (LOG.isDebugEnabled()) {
-                                       LOG.debug(String.format(
-                                                       "Local worker_%d: 
Successfully pull the global parameters [size:%d kb] from ps.", _workerID,
-                                                       
globalParams.getDataSize() / 1024));
-                               }
-                               _ec.setVariable(Statement.PS_MODEL, 
globalParams);
+                                       // Get batch features and labels
+                                       MatrixObject bFeatures = 
ParamservUtils.sliceMatrix(_features, begin, end);
+                                       MatrixObject bLabels = 
ParamservUtils.sliceMatrix(_labels, begin, end);
+                                       _ec.setVariable(Statement.PS_FEATURES, 
bFeatures);
+                                       _ec.setVariable(Statement.PS_LABELS, 
bLabels);
 
-                               long begin = j * _batchSize + 1;
-                               long end = Math.min(begin + _batchSize, 
dataSize);
+                                       if (LOG.isDebugEnabled()) {
+                                               LOG.debug(String.format("Local 
worker_%d: Got batch data [size:%d kb] of index from %d to %d. "
+                                                       + "[Epoch:%d  Total 
epoch:%d  Iteration:%d  Total iteration:%d]", _workerID, bFeatures.getDataSize()
+                                                       / 1024 + 
bLabels.getDataSize() / 1024, begin, end, i + 1, _epochs, j + 1, totalIter));
+                                       }
 
-                               // Get batch features and labels
-                               MatrixObject bFeatures = 
ParamservUtils.sliceMatrix(_features, begin, end);
-                               MatrixObject bLabels = 
ParamservUtils.sliceMatrix(_labels, begin, end);
-                               _ec.setVariable(Statement.PS_FEATURES, 
bFeatures);
-                               _ec.setVariable(Statement.PS_LABELS, bLabels);
+                                       // Invoke the update function
+                                       _inst.processInstruction(_ec);
 
-                               if (LOG.isDebugEnabled()) {
-                                       LOG.debug(String.format(
-                                                       "Local worker_%d: Got 
batch data [size:%d kb] of index from %d to %d. [Epoch:%d  Total epoch:%d  
Iteration:%d  Total iteration:%d]",
-                                                       _workerID, 
bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 
1,
-                                                       _epochs, j + 1, 
totalIter));
-                               }
+                                       // Get the gradients
+                                       ListObject gradients = (ListObject) 
_ec.getVariable(_output.getName());
 
-                               // Invoke the update function
-                               _inst.processInstruction(_ec);
+                                       // Push the gradients to ps
+                                       _ps.push(_workerID, gradients);
+                                       if (LOG.isDebugEnabled()) {
+                                               LOG.debug(String.format("Local 
worker_%d: Successfully push the gradients "
+                                                       + "[size:%d kb] to 
ps.", _workerID, gradients.getDataSize() / 1024));
+                                       }
 
-                               // Get the gradients
-                               ListObject gradients = (ListObject) 
_ec.getVariable(_outputs.get(0).getName());
-
-                               // Push the gradients to ps
-                               _ps.push(_workerID, gradients);
+                                       ParamservUtils.cleanupListObject(_ec, 
globalParams);
+                                       ParamservUtils.cleanupData(bFeatures);
+                                       ParamservUtils.cleanupData(bLabels);
+                               }
                                if (LOG.isDebugEnabled()) {
-                                       LOG.debug(String.format("Local 
worker_%d: Successfully push the gradients [size:%d kb] to ps.",
-                                                       _workerID, 
gradients.getDataSize() / 1024));
+                                       LOG.debug(String.format("Local 
worker_%d: Finished %d epoch.", _workerID, i + 1));
                                }
-
-                               ParamservUtils.cleanupListObject(_ec, 
globalParams);
-                               ParamservUtils.cleanupData(bFeatures);
-                               ParamservUtils.cleanupData(bLabels);
                        }
                        if (LOG.isDebugEnabled()) {
-                               LOG.debug(String.format("Local worker_%d: 
Finished %d epoch.", _workerID, i + 1));
+                               LOG.debug(String.format("Local worker_%d: Job 
finished.", _workerID));
                        }
+               } catch (Exception e) {
+                       throw new DMLRuntimeException(String.format("Local 
worker_%d failed", _workerID), e);
                }
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("Local worker_%d: Job 
finished.", _workerID));
-               }
+               return null;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index d060a91..665395e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
+import java.util.concurrent.ExecutionException;
+
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
@@ -28,32 +30,32 @@ import org.apache.sysml.runtime.instructions.cp.ListObject;
 public class LocalParamServer extends ParamServer {
 
        public LocalParamServer(ListObject model, String aggFunc, 
Statement.PSFrequency freq,
-                       Statement.PSUpdateType updateType, ExecutionContext ec, 
int workerNum,
-                       ListObject hyperParams) {
-               super(model, aggFunc, freq, updateType, ec, workerNum, 
hyperParams);
+                       Statement.PSUpdateType updateType, ExecutionContext ec, 
int workerNum) {
+               super(model, aggFunc, freq, updateType, ec, workerNum);
        }
 
        @Override
-       public void push(long workerID, ListObject gradients) {
-               synchronized (_lock) {
-                       _queue.add(new Gradient(workerID, gradients));
-                       _lock.notifyAll();
+       public void push(int workerID, ListObject gradients) {
+               try {
+                       _gradientsQueue.put(new Gradient(workerID, gradients));
+               } catch (InterruptedException e) {
+                       throw new DMLRuntimeException(e);
+               }
+               try {
+                       launchService();
+               } catch (ExecutionException | InterruptedException e) {
+                       throw new DMLRuntimeException("Aggregate service: some 
error occurred: ", e);
                }
        }
 
        @Override
-       public Data pull(long workerID) {
-               synchronized (_lock) {
-                       while (getPulledState((int) workerID)) {
-                               try {
-                                       _lock.wait();
-                               } catch (InterruptedException e) {
-                                       throw new DMLRuntimeException(
-                                                       String.format("Local 
worker_%d: failed to pull the global parameters.", workerID), e);
-                               }
-                       }
-                       setPulledState((int) workerID, true);
+       public Data pull(int workerID) {
+               ListObject model;
+               try {
+                       model = _modelMap.get((int) workerID).take();
+               } catch (InterruptedException e) {
+                       throw new DMLRuntimeException(e);
                }
-               return getResult();
+               return model;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 9ace823..56fda22 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -30,102 +30,99 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 @SuppressWarnings("unused")
 public abstract class PSWorker {
 
-       long _workerID = -1;
-       int _epochs;
-       long _batchSize;
-       MatrixObject _features;
-       MatrixObject _labels;
-       ExecutionContext _ec;
-       ParamServer _ps;
-       private String _updFunc;
-       private Statement.PSFrequency _freq;
+       protected final int _workerID;
+       protected final int _epochs;
+       protected final long _batchSize;
+       protected final ExecutionContext _ec;
+       protected final ParamServer _ps;
+       protected final DataIdentifier _output;
+       protected final FunctionCallCPInstruction _inst;
+       protected MatrixObject _features;
+       protected MatrixObject _labels;
+       
        private MatrixObject _valFeatures;
        private MatrixObject _valLabels;
-
-       ArrayList<DataIdentifier> _outputs;
-       FunctionCallCPInstruction _inst;
-
-       public PSWorker(long workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize,
-                       ListObject hyperParams, ExecutionContext ec, 
ParamServer ps) {
-               this._workerID = workerID;
-               this._updFunc = updFunc;
-               this._freq = freq;
-               this._epochs = epochs;
-               this._batchSize = batchSize;
-               this._ec = 
ExecutionContextFactory.createContext(ec.getProgram());
-               if (hyperParams != null) {
-                       this._ec.setVariable(Statement.PS_HYPER_PARAMS, 
hyperParams);
-               }
-               this._ps = ps;
+       private final String _updFunc;
+       private final Statement.PSFrequency _freq;
+       
+       protected PSWorker(int workerID, String updFunc, Statement.PSFrequency 
freq,
+               int epochs, long batchSize, ExecutionContext ec, ParamServer 
ps) {
+               _workerID = workerID;
+               _updFunc = updFunc;
+               _freq = freq;
+               _epochs = epochs;
+               _batchSize = batchSize;
+               _ec = ec;
+               _ps = ps;
 
                // Get the update function
                String[] keys = DMLProgram.splitFunctionKey(updFunc);
-               String _funcName = keys[0];
-               String _funcNS = null;
+               String funcName = keys[0];
+               String funcNS = null;
                if (keys.length == 2) {
-                       _funcNS = keys[0];
-                       _funcName = keys[1];
+                       funcNS = keys[0];
+                       funcName = keys[1];
                }
-               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(_funcNS, _funcName);
-               ArrayList<DataIdentifier> _inputs = func.getInputParams();
-               _outputs = func.getOutputParams();
-               CPOperand[] _boundInputs = _inputs.stream()
+               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+               ArrayList<DataIdentifier> inputs = func.getInputParams();
+               ArrayList<DataIdentifier> outputs = func.getOutputParams();
+               CPOperand[] boundInputs = inputs.stream()
                                .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
                                .toArray(CPOperand[]::new);
-               ArrayList<String> _inputNames = 
_inputs.stream().map(DataIdentifier::getName)
+               ArrayList<String> inputNames = 
inputs.stream().map(DataIdentifier::getName)
                                
.collect(Collectors.toCollection(ArrayList::new));
-               ArrayList<String> _outputNames = 
_outputs.stream().map(DataIdentifier::getName)
+               ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                                
.collect(Collectors.toCollection(ArrayList::new));
-               _inst = new FunctionCallCPInstruction(_funcNS, _funcName, 
_boundInputs, _inputNames, _outputNames,
+               _inst = new FunctionCallCPInstruction(funcNS, funcName, 
boundInputs, inputNames, outputNames,
                                "update function");
 
                // Check the inputs of the update function
-               checkInput(_inputs, Expression.DataType.MATRIX, 
Statement.PS_FEATURES);
-               checkInput(_inputs, Expression.DataType.MATRIX, 
Statement.PS_LABELS);
-               checkInput(_inputs, Expression.DataType.LIST, 
Statement.PS_MODEL);
-               if (hyperParams != null) {
-                       checkInput(_inputs, Expression.DataType.LIST, 
Statement.PS_HYPER_PARAMS);
-               }
+               checkInput(false, inputs, Expression.DataType.MATRIX, 
Statement.PS_FEATURES);
+               checkInput(false, inputs, Expression.DataType.MATRIX, 
Statement.PS_LABELS);
+               checkInput(false, inputs, Expression.DataType.LIST, 
Statement.PS_MODEL);
+               checkInput(true, inputs, Expression.DataType.LIST, 
Statement.PS_HYPER_PARAMS);
 
                // Check the output of the update function
-               if (_outputs.size() != 1) {
-                       throw new DMLRuntimeException(
-                               String.format("The output of the '%s' function 
should provide one list containing the gradients.", updFunc));
+               if (outputs.size() != 1) {
+                       throw new DMLRuntimeException(String.format("The output 
of the '%s' function "
+                               + "should provide one list containing the 
gradients.", updFunc));
                }
-               if (_outputs.get(0).getDataType() != Expression.DataType.LIST) {
-                       throw new DMLRuntimeException(
-                                       String.format("The output of the '%s' 
function should be type of list.", updFunc));
+               if (outputs.get(0).getDataType() != Expression.DataType.LIST) {
+                       throw new DMLRuntimeException(String.format("The output 
of the '%s' function should be type of list.", updFunc));
                }
+               _output = outputs.get(0);
        }
 
-       private void checkInput(ArrayList<DataIdentifier> _inputs, 
Expression.DataType dt, String pname) {
-               if (_inputs.stream().filter(input -> input.getDataType() == dt 
&& pname.equals(input.getName())).count() != 1) {
-                       throw new DMLRuntimeException(
-                               String.format("The '%s' function should provide 
an input of '%s' type named '%s'.", _updFunc, dt, pname));
+       private void checkInput(boolean optional, ArrayList<DataIdentifier> 
inputs, Expression.DataType dt, String pname) {
+               if (optional && inputs.stream().noneMatch(input -> 
pname.equals(input.getName()))) {
+                       // We do not need to check if the input is optional and 
is not provided
+                       return;
+               }
+               if (inputs.stream().filter(input -> input.getDataType() == dt 
&& pname.equals(input.getName())).count() != 1) {
+                       throw new DMLRuntimeException(String.format("The '%s' 
function should provide "
+                               + "an input of '%s' type named '%s'.", 
_updFunc, dt, pname));
                }
        }
 
        public void setFeatures(MatrixObject features) {
-               this._features = features;
+               _features = features;
        }
 
        public void setLabels(MatrixObject labels) {
-               this._labels = labels;
+               _labels = labels;
        }
 
        public void setValFeatures(MatrixObject valFeatures) {
-               this._valFeatures = valFeatures;
+               _valFeatures = valFeatures;
        }
 
        public void setValLabels(MatrixObject valLabels) {
-               this._valLabels = valLabels;
+               _valLabels = valLabels;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index 6e1cd13..1052390 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -21,9 +21,17 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Queue;
-import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingDeque;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.logging.Log;
@@ -35,98 +43,83 @@ import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 public abstract class ParamServer {
+       
+       final BlockingQueue<Gradient> _gradientsQueue;
+       final Map<Integer, BlockingQueue<ListObject>> _modelMap;
+       private final AggregationService _aggService;
+       private final ExecutorService _es;
+       private ListObject _model;
 
-       public class Gradient {
-               final long _workerID;
-               final ListObject _gradients;
-
-               public Gradient(long workerID, ListObject gradients) {
-                       this._workerID = workerID;
-                       this._gradients = gradients;
+       ParamServer(ListObject model, String aggFunc, Statement.PSFrequency 
freq, Statement.PSUpdateType updateType,
+                       ExecutionContext ec, int workerNum) {
+               _gradientsQueue = new LinkedBlockingDeque<>();
+               _modelMap = new HashMap<>(workerNum);
+               IntStream.range(0, workerNum).forEach(i -> {
+                       // Create a single element blocking queue for workers 
to receive the broadcasted model
+                       _modelMap.put(i, new ArrayBlockingQueue<>(1));
+               });
+               _model = model;
+               _aggService = new AggregationService(aggFunc, freq, updateType, 
ec, workerNum);
+               try {
+                       _aggService.broadcastModel();
                }
+               catch (InterruptedException e) {
+                       throw new DMLRuntimeException("Param server: failed to 
broadcast the initial model.", e);
+               }
+               _es = Executors.newSingleThreadExecutor();
        }
 
-       Queue<Gradient> _queue;
-       final Object _lock = new Object();
-       private ListObject _model;
-       private AggregationService _aggService;
-       private Thread _aggThread;
-       private boolean[] _pulledStates;
-
-       protected ParamServer(ListObject model, String aggFunc, 
Statement.PSFrequency freq,
-                       Statement.PSUpdateType updateType, ExecutionContext ec, 
int workerNum, ListObject hyperParams) {
-               this._queue = new ConcurrentLinkedQueue<>();
-               this._model = model;
-               this._aggService = new AggregationService(aggFunc, freq, 
updateType, ec, workerNum, hyperParams);
-               this._pulledStates = new boolean[workerNum];
-               this._aggThread = new Thread(_aggService);
-       }
-
-       public abstract void push(long workerID, ListObject value);
+       public abstract void push(int workerID, ListObject value);
 
-       public abstract Data pull(long workerID);
+       public abstract Data pull(int workerID);
 
-       public void start() {
-               _aggService._alive = true;
-               _aggThread.start();
+       void launchService() throws ExecutionException, InterruptedException {
+               _es.submit(_aggService).get();
        }
 
-       public void stop() {
-               _aggService._alive = false;
-               try {
-                       _aggThread.join();
-               } catch (InterruptedException e) {
-                       throw new DMLRuntimeException("Parameter server: failed 
when stopping the server.", e);
-               }
+       public void shutdown() {
+               _es.shutdownNow();
        }
 
        public ListObject getResult() {
                return _model;
        }
 
-       public boolean getPulledState(int workerID) {
-               return _pulledStates[workerID];
-       }
-
-       public void setPulledState(int workerID, boolean state) {
-               _pulledStates[workerID] = state;
-       }
+       public static class Gradient {
+               final int _workerID;
+               final ListObject _gradients;
 
-       private void resetPulledStates() {
-               _pulledStates = new boolean[_pulledStates.length];
+               public Gradient(int workerID, ListObject gradients) {
+                       _workerID = workerID;
+                       _gradients = gradients;
+               }
        }
-
+       
        /**
         * Inner aggregation service which is for updating the model
         */
-       @SuppressWarnings("unused")
-       private class AggregationService implements Runnable {
+       private class AggregationService implements Callable<Void> {
 
                protected final Log LOG = 
LogFactory.getLog(AggregationService.class.getName());
 
                protected ExecutionContext _ec;
-               private Statement.PSFrequency _freq;
+               //private Statement.PSFrequency _freq;
                private Statement.PSUpdateType _updateType;
                private FunctionCallCPInstruction _inst;
                private DataIdentifier _output;
-               private boolean _alive;
                private boolean[] _finishedStates;  // Workers' finished states
 
                AggregationService(String aggFunc, Statement.PSFrequency freq, 
Statement.PSUpdateType updateType,
-                               ExecutionContext ec, int workerNum, ListObject 
hyperParams) {
-                       _ec = 
ExecutionContextFactory.createContext(ec.getProgram());
-                       _freq = freq;
+                               ExecutionContext ec, int workerNum) {
+                       _ec = ec;
+                       //_freq = freq;
                        _updateType = updateType;
-                       if (hyperParams != null) {
-                               _ec.setVariable(Statement.PS_HYPER_PARAMS, 
hyperParams);
-                       }
                        _finishedStates = new boolean[workerNum];
 
                        // Fetch the aggregation function
@@ -143,13 +136,11 @@ public abstract class ParamServer {
 
                        // Check the output of the aggregation function
                        if (outputs.size() != 1) {
-                               throw new DMLRuntimeException(String.format(
-                                               "The output of the '%s' 
function should provide one list containing the updated model.",
+                               throw new 
DMLRuntimeException(String.format("The output of the '%s' function should 
provide one list containing the updated model.",
                                                aggFunc));
                        }
                        if (outputs.get(0).getDataType() != 
Expression.DataType.LIST) {
-                               throw new DMLRuntimeException(
-                                               String.format("The output of 
the '%s' function should be type of list.", aggFunc));
+                               throw new 
DMLRuntimeException(String.format("The output of the '%s' function should be 
type of list.", aggFunc));
                        }
                        _output = outputs.get(0);
 
@@ -160,12 +151,7 @@ public abstract class ParamServer {
                                        
.collect(Collectors.toCollection(ArrayList::new));
                        ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                                        
.collect(Collectors.toCollection(ArrayList::new));
-                       _inst = new FunctionCallCPInstruction(funcNS, funcName, 
boundInputs, inputNames, outputNames,
-                                       "aggregate function");
-               }
-
-               boolean isAlive() {
-                       return _alive;
+                       _inst = new FunctionCallCPInstruction(funcNS, funcName, 
boundInputs, inputNames, outputNames, "aggregate function");
                }
 
                private boolean allFinished() {
@@ -180,53 +166,76 @@ public abstract class ParamServer {
                        _finishedStates[workerID] = true;
                }
 
-               @Override
-               public void run() {
-                       synchronized (_lock) {
-                               while (isAlive()) {
-                                       do {
-                                               while (_queue.isEmpty()) {
-                                                       try {
-                                                               _lock.wait();
-                                                       } catch 
(InterruptedException e) {
-                                                               throw new 
DMLRuntimeException(
-                                                                               
"Aggregation service: error when waiting for the coming gradients.", e);
-                                                       }
-                                               }
-                                               Gradient p = _queue.remove();
-                                               if (LOG.isDebugEnabled()) {
-                                                       
LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of 
worker_%d.",
-                                                                       
p._gradients.getDataSize() / 1024, p._workerID));
-                                               }
-
-                                               setFinishedState((int) 
p._workerID);
+               private void broadcastModel() throws InterruptedException {
+                       //broadcast copy of the model to all workers, cleaned 
up by workers
+                       for (BlockingQueue<ListObject> q : _modelMap.values())
+                               q.put(ParamservUtils.copyList(_model));
+               }
+               
+               private void broadcastModel(int workerID) throws 
InterruptedException {
+                       //broadcast copy of model to specific worker, cleaned 
up by worker
+                       
_modelMap.get(workerID).put(ParamservUtils.copyList(_model));
+               }
 
-                                               // Populate the variables table 
with the gradients and model
-                                               
_ec.setVariable(Statement.PS_GRADIENTS, p._gradients);
-                                               
_ec.setVariable(Statement.PS_MODEL, _model);
+               @Override
+               public Void call() throws Exception {
+                       try {
+                               Gradient grad;
+                               try {
+                                       grad = _gradientsQueue.take();
+                               } catch (InterruptedException e) {
+                                       throw new 
DMLRuntimeException("Aggregation service: error when waiting for the coming 
gradients.", e);
+                               }
+                               if (LOG.isDebugEnabled()) {
+                                       LOG.debug(String.format("Successfully 
pulled the gradients [size:%d kb] of worker_%d.",
+                                                       
grad._gradients.getDataSize() / 1024, grad._workerID));
+                               }
 
-                                               // Invoke the aggregate function
-                                               _inst.processInstruction(_ec);
+                               // Update and redistribute the model
+                               updateModel(grad);
+                               
+                               // Redistribute model according to update type
+                               switch( _updateType ) {
+                                       case BSP: {
+                                               
setFinishedState(grad._workerID);
+                                               if( allFinished() ) {
+                                                       // Broadcast the 
updated model
+                                                       resetFinishedStates();
+                                                       broadcastModel();
+                                                       if 
(LOG.isDebugEnabled())
+                                                               
LOG.debug("Global parameter is broadcasted successfully.");
+                                               }
+                                               break;
+                                       }
+                                       case ASP: {
+                                               broadcastModel(grad._workerID);
+                                               break;
+                                       }
+                                       default:
+                                               throw new 
DMLRuntimeException("Unsupported update: "+_updateType.name());
+                               }
+                       } 
+                       catch (Exception e) {
+                               throw new DMLRuntimeException("Aggregation 
service failed: ", e);
+                       }
+                       return null;
+               }
 
-                                               // Get the output
-                                               ListObject newModel = 
(ListObject) _ec.getVariable(_output.getName());
+               private void updateModel(Gradient grad) throws 
InterruptedException {
+                       // Populate the variables table with the gradients and 
model
+                       _ec.setVariable(Statement.PS_GRADIENTS, 
grad._gradients);
+                       _ec.setVariable(Statement.PS_MODEL, _model);
 
-                                               // Update the model with the 
new output
-                                               
ParamservUtils.cleanupListObject(_ec, _model);
-                                               
ParamservUtils.cleanupListObject(_ec, p._gradients);
-                                               _model = newModel;
+                       // Invoke the aggregate function
+                       _inst.processInstruction(_ec);
 
-                                       } while (!allFinished());
+                       // Get the output
+                       ListObject newModel = (ListObject) 
_ec.getVariable(_output.getName());
 
-                                       // notify all the workers to get the 
updated model
-                                       resetPulledStates();
-                                       resetFinishedStates();
-                                       _lock.notifyAll();
-                                       if (LOG.isDebugEnabled()) {
-                                               LOG.debug("Global parameter is 
broadcasted successfully.");
-                                       }
-                               }
-                       }
+                       // Update the model with the new output
+                       ParamservUtils.cleanupListObject(_ec, _model);
+                       ParamservUtils.cleanupListObject(_ec, grad._gradients);
+                       _model = newModel;
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index 54c5d6c..426a7fe 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 import java.util.HashSet;
 import java.util.List;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import org.apache.sysml.parser.Expression;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -49,8 +50,8 @@ public class ParamservUtils {
                if (lo.getLength() == 0) {
                        return lo;
                }
-               List<Data> newData = lo.getNames().stream().map(name -> {
-                       Data oldData = lo.slice(name);
+               List<Data> newData = IntStream.range(0, 
lo.getLength()).mapToObj(i -> {
+                       Data oldData = lo.slice(i);
                        if (oldData instanceof MatrixObject) {
                                MatrixObject mo = (MatrixObject) oldData;
                                return sliceMatrix(mo, 1, mo.getNumRows());
@@ -69,7 +70,7 @@ public class ParamservUtils {
        }
 
        public static void cleanupData(Data data) {
-               if( !(data instanceof CacheableData) )
+               if (!(data instanceof CacheableData))
                        return;
                CacheableData<?> cd = (CacheableData<?>) data;
                cd.enableCleanup(true);
@@ -78,6 +79,7 @@ public class ParamservUtils {
 
        /**
         * Slice the matrix
+        *
         * @param mo input matrix
         * @param rl low boundary
         * @param rh high boundary
@@ -88,10 +90,10 @@ public class ParamservUtils {
                        new MetaDataFormat(new MatrixCharacteristics(-1, -1, 
-1, -1),
                                OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo));
                MatrixBlock tmp = mo.acquireRead();
-               result.acquireModify(tmp.slice((int)rl-1, (int)rh-1, 0,
-                       tmp.getNumColumns()-1, new MatrixBlock()));
+               result.acquireModify(tmp.slice((int) rl - 1, (int) rh - 1, 0, 
tmp.getNumColumns() - 1, new MatrixBlock()));
                mo.release();
                result.release();
+               result.enableCleanup(false);
                return result;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 3ab0fc8..79d1ff3 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -39,30 +39,49 @@ import static 
org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
 import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
 import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;
 
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DMLTranslator;
+import org.apache.sysml.parser.StatementBlock;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
+import org.apache.sysml.runtime.controlprogram.Program;
+import org.apache.sysml.runtime.controlprogram.ProgramBlock;
+import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.matrix.operators.Operator;
-import org.apache.sysml.utils.NativeHelper;
 
 public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruction {
 
        private static final int DEFAULT_BATCH_SIZE = 64;
        private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = 
PSFrequency.BATCH;
-       private static final int DEFAULT_LEVEL_PARALLELISM = 
InfrastructureAnalyzer.getLocalParallelism();
        private static final PSScheme DEFAULT_SCHEME = 
PSScheme.DISJOINT_CONTIGUOUS;
 
        //internal local debug level
@@ -71,79 +90,178 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        static {
                // for internal debugging only
                if (LDEBUG) {
-                       
Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel((Level)
 Level.DEBUG);
+                       
Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel(Level.DEBUG);
                }
        }
 
-       protected ParamservBuiltinCPInstruction(Operator op, 
LinkedHashMap<String, String> paramsMap, CPOperand out,
-                       String opcode, String istr) {
+       public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, 
String> paramsMap, CPOperand out, String opcode, String istr) {
                super(op, paramsMap, out, opcode, istr);
        }
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-
                PSModeType mode = PSModeType.valueOf(getParam(PS_MODE));
                int workerNum = getWorkerNum(mode);
+               ExecutorService es = Executors.newFixedThreadPool(workerNum);
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+               // Create the workers' execution context
+               int k = getParLevel(workerNum);
+               List<ExecutionContext> workerECs = createExecutionContext(ec, 
updFunc, workerNum, k);
+
+               // Create the agg service's execution context
+               ExecutionContext aggServiceEC = createExecutionContext(ec, 
aggFunc, 1, 1).get(0);
+
                PSFrequency freq = getFrequency();
                PSUpdateType updateType = getUpdateType();
-               int epochs = Integer.valueOf(getParam(PS_EPOCHS));
-               if (epochs <= 0) {
-                       throw new DMLRuntimeException(
-                                       String.format("Paramserv function: The 
argument '%s' could not be less than or equal to 0.",
-                                                       PS_EPOCHS));
-               }
-               long batchSize = getBatchSize();
+               int epochs = getEpochs();
 
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
-               ListObject hyperParams = getHyperParams(ec);
-               ParamServer ps = createPS(mode, aggFunc, freq, updateType, 
workerNum, model, ec, hyperParams);
+               ParamServer ps = createPS(mode, aggFunc, freq, updateType, 
workerNum, model, aggServiceEC);
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
-                               .mapToObj(i -> new LocalPSWorker((long) i, 
updFunc, freq, epochs, batchSize, hyperParams, ec, ps))
-                               .collect(Collectors.toList());
+                       .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, 
epochs, getBatchSize(), workerECs.get(i), ps))
+                       .collect(Collectors.toList());
 
                // Do data partition
                doDataPartition(ec, workers);
 
-               // Create the worker threads
-               List<Thread> threads = 
workers.stream().map(Thread::new).collect(Collectors.toList());
+               // Launch the worker threads and wait for completion
+               try {
+                       for( Future<Void> ret : es.invokeAll(workers) )
+                               ret.get(); //error handling
+               } catch (InterruptedException | ExecutionException e) {
+                       throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
+               } finally {
+                       es.shutdownNow();
+               }
 
-               // Start the ps
-               ps.start();
+               // Fetch the final model from ps
+               ListObject result;
+               result = ps.getResult();
+               ec.setVariable(output.getName(), result);
+       }
 
-               // Start the workers
-               threads.forEach(Thread::start);
+       private int getEpochs() {
+               int epochs = Integer.valueOf(getParam(PS_EPOCHS));
+               if (epochs <= 0) {
+                       throw new DMLRuntimeException(String.format("Paramserv 
function: "
+                               + "The argument '%s' could not be less than or 
equal to 0.", PS_EPOCHS));
+               }
+               return epochs;
+       }
+
+       private int getParLevel(int workerNum) {
+               return 
Math.max((int)Math.ceil((double)getRemainingCores()/workerNum), 1);
+       }
+
+       private List<ExecutionContext> createExecutionContext(ExecutionContext 
ec, String funcName, int workerNum, int k) {
+               // Fetch the target function
+               String[] keys = DMLProgram.splitFunctionKey(funcName);
+               String namespace = null;
+               String func = keys[0];
+               if (keys.length == 2) {
+                       namespace = keys[0];
+                       func = keys[1];
+               }
+               return createExecutionContext(ec, namespace, func, workerNum, 
k);
+       }
 
-               // Wait for the workers stopping
-               threads.forEach(thread -> {
+       private List<ExecutionContext> createExecutionContext(ExecutionContext 
ec, String namespace, String func,
+                       int workerNum, int k) {
+               FunctionProgramBlock targetFunc = 
ec.getProgram().getFunctionProgramBlock(namespace, func);
+               return IntStream.range(0, workerNum).mapToObj(i -> {
+                       // Put the hyperparam into the variables table
+                       LocalVariableMap varsMap = new LocalVariableMap();
+                       ListObject hyperParams = getHyperParams(ec);
+                       if (hyperParams != null) {
+                               varsMap.put(PS_HYPER_PARAMS, hyperParams);
+                       }
+
+                       // Deep copy the target func
+                       FunctionProgramBlock copiedFunc = ProgramConverter
+                               .createDeepCopyFunctionProgramBlock(targetFunc, 
new HashSet<>(), new HashSet<>());
+
+                       // Reset the visit status from root
+                       for( ProgramBlock pb : copiedFunc.getChildBlocks() )
+                               
DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
+
+                       // Should recursively assign the level of parallelism
+                       // and recompile the program block
                        try {
-                               thread.join();
-                       } catch (InterruptedException e) {
-                               throw new DMLRuntimeException("Paramserv 
function: Failed to join the worker threads.", e);
+                               rAssignParallelism(copiedFunc.getChildBlocks(), 
k, false);
+                       } catch (IOException e) {
+                               throw new DMLRuntimeException(e);
                        }
-               });
 
-               ps.stop();
+                       Program prog = new Program();
+                       prog.addProgramBlock(copiedFunc);
+                       prog.addFunctionProgramBlock(namespace, func, 
copiedFunc);
+                       return ExecutionContextFactory.createContext(varsMap, 
prog);
 
-               // Create the output
-               ListObject result = ps.getResult();
-               ec.setVariable(output.getName(), result);
+               }).collect(Collectors.toList());
+       }
+
+       private boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, int k, 
boolean recompiled) throws IOException {
+               for (ProgramBlock pb : pbs) {
+                       if (pb instanceof ParForProgramBlock) {
+                               ParForProgramBlock pfpb = (ParForProgramBlock) 
pb;
+                               pfpb.setDegreeOfParallelism(k);
+                               recompiled |= 
rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
+                       }
+                       else if (pb instanceof ForProgramBlock) {
+                               recompiled |= 
rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
+                       }
+                       else if (pb instanceof WhileProgramBlock) {
+                               recompiled |= 
rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
+                       }
+                       else if (pb instanceof FunctionProgramBlock) {
+                               recompiled |= 
rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
+                       }
+                       else if (pb instanceof IfProgramBlock) {
+                               IfProgramBlock ipb = (IfProgramBlock) pb;
+                               recompiled |= 
rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
+                               if( ipb.getChildBlocksElseBody() != null )
+                                       recompiled |= 
rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
+                       }
+                       else {
+                               StatementBlock sb = pb.getStatementBlock();
+                               for (Hop hop : sb.getHops())
+                                       recompiled |= rAssignParallelism(hop, 
k, recompiled);
+                       }
+                       // Recompile the program block
+                       if (recompiled) {
+                               
Recompiler.recompileProgramBlockInstructions(pb);
+                       }
+               }
+               return recompiled;
+       }
+
+       private boolean rAssignParallelism(Hop hop, int k, boolean recompiled) {
+               if (hop.isVisited()) {
+                       return recompiled;
+               }
+               if (hop instanceof Hop.MultiThreadedHop) {
+                       // Reassign the level of parallelism
+                       Hop.MultiThreadedHop mhop = (Hop.MultiThreadedHop) hop;
+                       mhop.setMaxNumThreads(k);
+                       recompiled = true;
+               }
+               ArrayList<Hop> inputs = hop.getInput();
+               for (Hop h : inputs) {
+                       recompiled |= rAssignParallelism(h, k, recompiled);
+               }
+               hop.setVisited();
+               return recompiled;
        }
 
        private PSUpdateType getUpdateType() {
                PSUpdateType updType = 
PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
-               switch (updType) {
-               case ASP:
-               case SSP:
+               if( updType == PSUpdateType.SSP )
                        throw new DMLRuntimeException(String.format("Not 
support update type '%s'.", updType));
-               case BSP:
-                       break;
-               }
                return updType;
        }
 
@@ -152,15 +270,15 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        return DEFAULT_UPDATE_FREQUENCY;
                }
                PSFrequency freq = PSFrequency.valueOf(getParam(PS_FREQUENCY));
-               switch (freq) {
-               case EPOCH:
+               if( freq == PSFrequency.EPOCH )
                        throw new DMLRuntimeException("Not support epoch update 
frequency.");
-               case BATCH:
-                       break;
-               }
                return freq;
        }
 
+       private int getRemainingCores() {
+               return InfrastructureAnalyzer.getLocalParallelism() - 1;
+       }
+
        /**
         * Get the worker numbers according to the vcores
         *
@@ -168,24 +286,17 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         * @return worker numbers
         */
        private int getWorkerNum(PSModeType mode) {
-               int workerNum = DEFAULT_LEVEL_PARALLELISM;
-               if (getParameterMap().containsKey(PS_PARALLELISM)) {
-                       workerNum = Integer.valueOf(getParam(PS_PARALLELISM));
-               }
+               int workerNum = -1;
                switch (mode) {
-               case LOCAL:
-                       //FIXME: this is a workaround for a maximum number of 
buffers in openblas
-                       //However, the root cause is a missing function 
preparation for each worker
-                       //(i.e., deep copy with unique file names, and reduced 
degree of parallelism)
-                       int vcores = 
InfrastructureAnalyzer.getLocalParallelism();
-                       if ("openblas".equals(NativeHelper.getCurrentBLAS())) {
-                               workerNum = Math.min(workerNum, vcores / 2);
-                       } else {
-                               workerNum = Math.min(workerNum, vcores);
-                       }
-                       break;
-               case REMOTE_SPARK:
-                       throw new DMLRuntimeException("Do not support remote 
spark.");
+                       case LOCAL:
+                               // default worker number: available cores - 1 
(assign one process for agg service)
+                               workerNum = getRemainingCores();
+                               if 
(getParameterMap().containsKey(PS_PARALLELISM)) {
+                                       workerNum = Math.min(workerNum, 
Integer.valueOf(getParam(PS_PARALLELISM)));
+                               }
+                               break;
+                       case REMOTE_SPARK:
+                               throw new DMLRuntimeException("Do not support 
remote spark.");
                }
                return workerNum;
        }
@@ -196,14 +307,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         * @return parameter server
         */
        private ParamServer createPS(PSModeType mode, String aggFunc, 
PSFrequency freq, PSUpdateType updateType,
-                       int workerNum, ListObject model, ExecutionContext ec, 
ListObject hyperParams) {
+                       int workerNum, ListObject model, ExecutionContext ec) {
                ParamServer ps = null;
                switch (mode) {
-               case LOCAL:
-                       ps = new LocalParamServer(model, aggFunc, freq, 
updateType, ec, workerNum, hyperParams);
-                       break;
-               case REMOTE_SPARK:
-                       throw new DMLRuntimeException("Do not support remote 
spark.");
+                       case LOCAL:
+                               ps = new LocalParamServer(model, aggFunc, freq, 
updateType, ec, workerNum);
+                               break;
+                       case REMOTE_SPARK:
+                               throw new DMLRuntimeException("Do not support 
remote spark.");
                }
                return ps;
        }
@@ -214,9 +325,8 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                }
                long batchSize = Integer.valueOf(getParam(PS_BATCH_SIZE));
                if (batchSize <= 0) {
-                       throw new DMLRuntimeException(String.format(
-                                       "Paramserv function: the number of 
argument '%s' could not be less than or equal to 0.",
-                                       PS_BATCH_SIZE));
+                       throw new DMLRuntimeException(String.format("Paramserv 
function: the number "
+                               + "of argument '%s' could not be less than or 
equal to 0.", PS_BATCH_SIZE));
                }
                return batchSize;
        }
@@ -245,8 +355,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                case DISJOINT_RANDOM:
                case OVERLAP_RESHUFFLE:
                case DISJOINT_ROUND_ROBIN:
-                       throw new DMLRuntimeException(
-                                       String.format("Paramserv function: the 
scheme '%s' is not supported.", scheme));
+                       throw new DMLRuntimeException(String.format("Paramserv 
function: the scheme '%s' is not supported.", scheme));
                }
        }
 
@@ -256,9 +365,10 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                List<MatrixObject> pfs = disjointContiguous(workers.size(), 
features);
                List<MatrixObject> pls = disjointContiguous(workers.size(), 
labels);
                if (pfs.size() < workers.size()) {
-                       LOG.warn(String.format(
-                                       "There is only %d batches of data but 
has %d workers. Hence, reset the number of workers with %d.",
-                                       pfs.size(), workers.size(), 
pfs.size()));
+                       if (LOG.isWarnEnabled()) {
+                               LOG.warn(String.format("There is only %d 
batches of data but has %d workers. "
+                                       + "Hence, reset the number of workers 
with %d.", pfs.size(), workers.size(), pfs.size()));
+                       }
                        workers = workers.subList(0, pfs.size());
                }
                for (int i = 0; i < workers.size(); i++) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
index 6370099..28b525a 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
@@ -34,6 +34,11 @@ public class ParamservFuncTest extends AutomatedTestBase {
        private static final String TEST_NAME6 = "paramserv-wrong-args2";
        private static final String TEST_NAME7 = "paramserv-nn-test";
        private static final String TEST_NAME8 = "paramserv-minimum-version";
+       private static final String TEST_NAME9 = "paramserv-worker-failed";
+       private static final String TEST_NAME10 = 
"paramserv-agg-service-failed";
+       private static final String TEST_NAME11 = "paramserv-large-parallelism";
+       private static final String TEST_NAME12 = 
"paramserv-wrong-aggregate-func";
+       private static final String TEST_NAME13 = "paramserv-nn-asp";
 
        private static final String TEST_DIR = "functions/paramserv/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
ParamservFuncTest.class.getSimpleName() + "/";
@@ -50,6 +55,11 @@ public class ParamservFuncTest extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
                addTestConfiguration(TEST_NAME7, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
                addTestConfiguration(TEST_NAME8, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {}));
+               addTestConfiguration(TEST_NAME9, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {}));
+               addTestConfiguration(TEST_NAME10, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {}));
+               addTestConfiguration(TEST_NAME11, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {}));
+               addTestConfiguration(TEST_NAME12, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME12, new String[] {}));
+               addTestConfiguration(TEST_NAME13, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME13, new String[] {}));
        }
 
        @Test
@@ -96,8 +106,33 @@ public class ParamservFuncTest extends AutomatedTestBase {
                runDMLTest(TEST_NAME8, false, null, null);
        }
 
-       private void runDMLTest(String testname, boolean exceptionExpected, 
Class<?> exceptionClass,
-                       String errmsg) {
+       @Test
+       public void testParamservWorkerFailedTest() {
+               runDMLTest(TEST_NAME9, true, DMLException.class, "Invalid 
lookup by name in unnamed list: worker_err.");
+       }
+
+       @Test
+       public void testParamservAggServiceFailedTest() {
+               runDMLTest(TEST_NAME10, true, DMLException.class, "Invalid 
lookup by name in unnamed list: agg_service_err");
+       }
+
+       @Test
+       public void testParamservLargeParallelismTest() {
+               runDMLTest(TEST_NAME11, false, null, null);
+       }
+
+       @Test
+       public void testParamservWrongAggregateFuncTest() {
+               runDMLTest(TEST_NAME12, true, DMLException.class,
+                               "The 'gradients' function should provide an 
input of 'MATRIX' type named 'labels'.");
+       }
+
+       @Test
+       public void testParamservASPTest() {
+               runDMLTest(TEST_NAME13, false, null, null);
+       }
+
+       private void runDMLTest(String testname, boolean exceptionExpected, 
Class<?> exceptionClass, String errmsg) {
                TestConfiguration config = getTestConfiguration(testname);
                loadTestConfiguration(config);
                programArgs = new String[] { "-explain" };

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index 2a3bbe2..4ea6e5f 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -208,7 +208,6 @@ gradients = function(matrix[double] features,
   gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, 
db3=db3, db4=db4)
 }
 
-# PB: how to handle the velocity? (put into the model)
 # Should use the arguments named 'model', 'gradients', 'hyperparams'
 # and return always a model of type list
 aggregation = function(list[unknown] model,

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml
new file mode 100644
index 0000000..b2e155e
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml
@@ -0,0 +1,376 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, 
F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model object
+  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, 
vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+  # Create the hyper parameter list
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml::aggregation",
 mode="LOCAL", utype="ASP", epochs=epochs, hyperparams=params)
+
+  W1 = as.matrix(modelList2["W1"])
+  b1 = as.matrix(modelList2["b1"])
+  W2 = as.matrix(modelList2["W2"])
+  b2 = as.matrix(modelList2["b2"])
+  W3 = as.matrix(modelList2["W3"])
+  b3 = as.matrix(modelList2["b3"])
+  W4 = as.matrix(modelList2["W4"])
+  b4 = as.matrix(modelList2["b4"])
+
+}
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+          return (list[unknown] gradients) {
+
+  C = 1
+  Hin = 28
+  Win = 28
+  Hf = 5
+  Wf = 5
+  stride = 1
+  pad = 2
+  lambda = 5e-04
+  F1 = 32
+  F2 = 64
+  N3 = 512
+  W1 = as.matrix(model["W1"])
+  b1 = as.matrix(model["b1"])
+  W2 = as.matrix(model["W2"])
+  b2 = as.matrix(model["b2"])
+  W3 = as.matrix(model["W3"])
+  b3 = as.matrix(model["b3"])
+  W4 = as.matrix(model["W4"])
+  b4 = as.matrix(model["b4"])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, 
Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, 
pad=0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, 
Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
+                                                strideh=2, stridew=2, pad=0, 
pad=0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, 
Woutc2, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, 
F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, 
stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, 
Woutc1, Hf=2, Wf=2,
+                                strideh=2, stridew=2, pad=0, pad=0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, 
W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, 
db3=db3, db4=db4)
+
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+   return (list[unknown] modelResult) {
+
+     W1 = as.matrix(model["W1"])
+     W2 = as.matrix(model["W2"])
+     W3 = as.matrix(model["W3"])
+     W4 = as.matrix(model["W4"])
+     b1 = as.matrix(model["b1"])
+     b2 = as.matrix(model["b2"])
+     b3 = as.matrix(model["b3"])
+     b4 = as.matrix(model["b4"])
+     dW1 = as.matrix(gradients["dW1"])
+     dW2 = as.matrix(gradients["dW2"])
+     dW3 = as.matrix(gradients["dW3"])
+     dW4 = as.matrix(gradients["dW4"])
+     db1 = as.matrix(gradients["db1"])
+     db2 = as.matrix(gradients["db2"])
+     db3 = as.matrix(gradients["db3"])
+     db4 = as.matrix(gradients["db4"])
+     vW1 = as.matrix(model["vW1"])
+     vW2 = as.matrix(model["vW2"])
+     vW3 = as.matrix(model["vW3"])
+     vW4 = as.matrix(model["vW4"])
+     vb1 = as.matrix(model["vb1"])
+     vb2 = as.matrix(model["vb2"])
+     vb3 = as.matrix(model["vb3"])
+     vb4 = as.matrix(model["vb4"])
+     lr = 0.01
+     mu = 0.9
+
+     # Optimize with SGD w/ Nesterov momentum
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, 
b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+   }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions 
(num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, 
Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 
Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, pad=0, 
pad=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 
Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, pad=0, 
pad=0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Evaluates a convolutional net using the "LeNet" architecture.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples.  The targets, Y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, Y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+  accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  Y = table(seq(1, N), classes)  # one-hot encoding
+}
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index 8811c36..707722e 100644
--- 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -204,7 +204,6 @@ gradients = function(matrix[double] features,
 
 }
 
-# how to handle the velocity?
 aggregation = function(list[unknown] model,
                        list[unknown] gradients,
                        list[unknown] hyperparams)

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml 
b/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml
new file mode 100644
index 0000000..1edd237
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-agg-service-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+            return (list[unknown] gradients) {
+  gradients = model
+}
+
+aggregation = function(list[unknown] model,
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+              return (list[unknown] modelResult) {
+  modelResult = model
+  print(toString(as.matrix(gradients["agg_service_err"])))
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
+# Use paramserv function
+modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", 
mode="LOCAL", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml 
b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
new file mode 100644
index 0000000..4e2d2b7
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml")
 as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the data a second time; X and Y are overwritten by the split below
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 10
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
new file mode 100644
index 0000000..b50e17c
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_asp.dml") 
as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the data a second time; X and Y are overwritten by the split below
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml 
b/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml
new file mode 100644
index 0000000..7ccde60
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-worker-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)  # reassigned to list("model") before the paramserv call
+X = matrix(1, rows=200, cols=30)  # constant 200x30 matrix, passed as features
+Y = matrix(2, rows=200, cols=1)  # constant 200x1 matrix, passed as labels
+X_val = matrix(3, rows=200, cols=30)  # constant 200x30 matrix, passed as val_features
+Y_val = matrix(4, rows=200, cols=1)  # constant 200x1 matrix, passed as val_labels
+
+gradients = function(matrix[double] features,  # update function, referenced by name via upd="gradients"
+                     matrix[double] labels,
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+            return (list[unknown] gradients) {
+  gradients = model  # the result is simply the incoming model list
+  print(toString(as.matrix(gradients["worker_err"])))  # NOTE(review): looks up the name "worker_err"; presumably absent on purpose so the worker fails - confirm
+}
+
+aggregation = function(list[unknown] model,  # aggregation function, referenced by name via agg="aggregation"
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+              return (list[unknown] modelResult) {
+  modelResult = model  # returns the incoming model unchanged; gradients and hyperparams are not used
+}
+
+e2 = "element2"
+params = list(e2)  # passed as hyperparams
+
+modelList = list("model")  # overrides the earlier modelList = list(e1)
+
+# Run paramserv with mode="LOCAL", utype="BSP", epochs=10 and k=1
+modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", 
mode="LOCAL", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))  # prints the first element of the returned list as a matrix
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/d44b3280/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml b/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml
new file mode 100644
index 0000000..ca76850
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-wrong-aggregate-func.dml
@@ -0,0 +1,50 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)  # single-element list, passed as model
+X = matrix(1, rows=2, cols=3)  # constant 2x3 matrix, passed as features
+Y = matrix(2, rows=2, cols=3)  # constant 2x3 matrix, passed as labels
+X_val = matrix(3, rows=2, cols=3)  # constant 2x3 matrix, passed as val_features
+Y_val = matrix(4, rows=2, cols=3)  # constant 2x3 matrix, passed as val_labels
+
+gradients = function(matrix[double] features,  # update function, referenced by name via upd="gradients"
+                     matrix[double] wrong_labels,  # NOTE(review): named wrong_labels where paramserv-worker-failed.dml uses labels; presumably the deliberate signature error under test - confirm
+                     list[unknown] hyperparams,
+                     list[unknown] model)
+            return (list[unknown] gradients) {
+  gradients = model  # the result is simply the incoming model list
+}
+
+aggregation = function(list[unknown] model,  # aggregation function, referenced by name via agg="aggregation"
+                       list[unknown] gradients,
+                       list[unknown] hyperparams)
+              return (list[unknown] modelResult) {
+  modelResult = model  # returns the incoming model unchanged; gradients and hyperparams are not used
+}
+
+e2 = "element2"
+params = list(e2)  # passed as hyperparams
+
+# Run paramserv with mode="LOCAL", utype="BSP" and epochs=10 (no k argument in this script)
+modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", 
mode="LOCAL", utype="BSP", epochs=10, hyperparams=params)
+
+print(toString(as.matrix(modelList2[1])))  # prints the first element of the returned list as a matrix
\ No newline at end of file

Reply via email to