[SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc

Closes #805.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/15ecb723
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/15ecb723
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/15ecb723

Branch: refs/heads/master
Commit: 15ecb723e39e3154412ca8f8824c4554ee64ca35
Parents: 54dbe9b
Author: EdgarLGB <[email protected]>
Authored: Sat Jul 21 22:31:36 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jul 21 22:31:36 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/LocalPSWorker.java |  34 +++---
 .../paramserv/LocalParamServer.java             |   7 +-
 .../controlprogram/paramserv/PSWorker.java      |  15 ++-
 .../controlprogram/paramserv/ParamServer.java   |  39 ++++---
 .../paramserv/ParamservUtils.java               |  65 ++++++-----
 .../paramserv/spark/SparkPSBody.java            |   6 +-
 .../paramserv/spark/SparkPSProxy.java           |  68 +++++++++++
 .../paramserv/spark/SparkPSWorker.java          |  46 ++++++--
 .../paramserv/spark/rpc/PSRpcCall.java          |  97 ++++++++++++++++
 .../paramserv/spark/rpc/PSRpcFactory.java       |  57 ++++++++++
 .../paramserv/spark/rpc/PSRpcHandler.java       |  83 ++++++++++++++
 .../paramserv/spark/rpc/PSRpcObject.java        |  57 ++++++++++
 .../paramserv/spark/rpc/PSRpcResponse.java      | 112 +++++++++++++++++++
 .../cp/ParamservBuiltinCPInstruction.java       |  52 +++++++--
 .../sysml/runtime/util/ProgramConverter.java    |  11 +-
 .../java/org/apache/sysml/utils/Statistics.java |   6 +
 .../paramserv/ParamservLocalNNTest.java         |  41 +++----
 .../paramserv/ParamservSparkNNTest.java         |  68 +++++++++--
 .../functions/paramserv/RpcObjectTest.java      |  56 ++++++++++
 .../functions/paramserv/SerializationTest.java  |   2 +-
 .../paramserv/paramserv-nn-asp-batch.dml        |  53 ---------
 .../paramserv/paramserv-nn-asp-epoch.dml        |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-dc.dml     |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-dr.dml     |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-drr.dml    |  53 ---------
 .../paramserv/paramserv-nn-bsp-batch-or.dml     |  53 ---------
 .../paramserv/paramserv-nn-bsp-epoch.dml        |  53 ---------
 .../paramserv-spark-agg-service-failed.dml      |  53 +++++++++
 .../paramserv-spark-nn-bsp-batch-dc.dml         |  53 ---------
 .../paramserv/paramserv-spark-worker-failed.dml |  53 +++++++++
 .../functions/paramserv/paramserv-test.dml      |  48 ++++++++
 .../functions/paramserv/ZPackageSuite.java      |   4 +-
 32 files changed, 961 insertions(+), 543 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index bbf2dbe..c23943d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -35,6 +35,9 @@ import org.apache.sysml.utils.Statistics;
 public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
        protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+       private static final long serialVersionUID = 5195390748495357295L;
+
+       protected LocalPSWorker() {}
 
        public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
                MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
@@ -42,6 +45,11 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
        }
 
        @Override
+       public String getWorkerName() {
+               return String.format("Local worker_%d", _workerID);
+       }
+
+       @Override
        public Void call() throws Exception {
                if (DMLScript.STATISTICS)
                        Statistics.incWorkerNumber();
@@ -60,10 +68,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                        }
 
                        if (LOG.isDebugEnabled()) {
-                               LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+                               LOG.debug(String.format("%s: job finished.", getWorkerName()));
                        }
                } catch (Exception e) {
-                       throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
+                       throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
                }
                return null;
        }
@@ -93,7 +101,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                        ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
 
                        if (LOG.isDebugEnabled()) {
-                               LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+                               LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
                        }
                }
 
@@ -108,9 +116,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                        Statistics.accPSLocalModelUpdateTime((long) tUpd.stop());
                
                if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated. "
+                       LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
                                + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]",
-                               _workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
+                               getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
                }
                return globalParams;
        }
@@ -129,17 +137,17 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                                ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
                        }
                        if (LOG.isDebugEnabled()) {
-                               LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+                               LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
                        }
                }
        }
 
        private ListObject pullModel() {
                // Pull the global parameters from ps
-               ListObject globalParams = (ListObject)_ps.pull(_workerID);
+               ListObject globalParams = _ps.pull(_workerID);
                if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
-                               + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
+                       LOG.debug(String.format("%s: successfully pull the global parameters "
+                               + "[size:%d kb] from ps.", getWorkerName(), globalParams.getDataSize() / 1024));
                }
                return globalParams;
        }
@@ -148,8 +156,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
                if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
-                               + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
+                       LOG.debug(String.format("%s: successfully push the gradients "
+                               + "[size:%d kb] to ps.", getWorkerName(), gradients.getDataSize() / 1024));
                }
        }
 
@@ -168,8 +176,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                _ec.setVariable(Statement.PS_LABELS, bLabels);
 
                if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
-                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", _workerID,
+                       LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
+                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", getWorkerName(),
                                bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
                                j + 1, totalIter));
                }

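Note on the change above: the hard-coded "Local worker_%d" log prefixes are funneled through the new getWorkerName() accessor, so the Spark subclass introduced below can reuse every logging and error path unchanged. A minimal, self-contained sketch of this template-method pattern (all names here are illustrative, not SystemML classes):

    // Sketch: the base class owns the messages; subclasses only supply the name.
    public class WorkerNameSketch {
        static abstract class Worker {
            protected int _workerID = 0;
            abstract String getWorkerName();
            void logFinished() {
                System.out.println(String.format("%s: job finished.", getWorkerName()));
            }
        }
        static class LocalWorker extends Worker {
            @Override String getWorkerName() { return String.format("Local worker_%d", _workerID); }
        }
        static class SparkWorker extends Worker {
            @Override String getWorkerName() { return String.format("Spark worker_%d", _workerID); }
        }
        public static void main(String[] args) {
            new LocalWorker().logFinished(); // Local worker_0: job finished.
            new SparkWorker().logFinished(); // Spark worker_0: job finished.
        }
    }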
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index 52372c9..0c73acb 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -22,11 +22,14 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 public class LocalParamServer extends ParamServer {
 
+       public LocalParamServer() {
+               super();
+       }
+
        public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
                super(model, aggFunc, updateType, ec, workerNum);
        }
@@ -37,7 +40,7 @@ public class LocalParamServer extends ParamServer {
        }
 
        @Override
-       public Data pull(int workerID) {
+       public ListObject pull(int workerID) {
                ListObject model;
                try {
                        model = _modelMap.get(workerID).take();

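pull() now returns ListObject directly instead of the generic Data, which removes the cast at the call site in LocalPSWorker.pullModel(). The blocking-queue handoff it sits on can be sketched independently of SystemML (String stands in for ListObject; class and method names are illustrative):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.concurrent.ArrayBlockingQueue;
    import java.util.concurrent.BlockingQueue;

    // Sketch: one single-slot queue per worker; pull() blocks until the
    // aggregation service publishes the next model for that worker.
    public class ModelQueueSketch {
        private final Map<Integer, BlockingQueue<String>> _modelMap = new HashMap<>();

        public ModelQueueSketch(int workerNum) {
            for (int i = 0; i < workerNum; i++)
                _modelMap.put(i, new ArrayBlockingQueue<>(1));
        }
        public void broadcast(String model) throws InterruptedException {
            for (BlockingQueue<String> q : _modelMap.values())
                q.put(model); // blocks while a previous model is unconsumed
        }
        public String pull(int workerID) throws InterruptedException {
            return _modelMap.get(workerID).take(); // blocks until a model arrives
        }
        public static void main(String[] args) throws InterruptedException {
            ModelQueueSketch ps = new ModelQueueSketch(2);
            ps.broadcast("model-v0");
            System.out.println(ps.pull(0)); // model-v0
        }
    }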
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 1ab5f5e..464db9b 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -21,6 +21,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 
 import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.stream.Collectors;
 
@@ -34,7 +35,10 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 
-public abstract class PSWorker {
+public abstract class PSWorker implements Serializable {
+
+       private static final long serialVersionUID = -3510485051178200118L;
+
        protected int _workerID;
        protected int _epochs;
        protected long _batchSize;
@@ -50,10 +54,8 @@ public abstract class PSWorker {
        protected String _updFunc;
        protected Statement.PSFrequency _freq;
 
-       protected PSWorker() {
+       protected PSWorker() {}
 
-       }
-       
        protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
                MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
                _workerID = workerID;
@@ -65,7 +67,10 @@ public abstract class PSWorker {
                _valLabels = valLabels;
                _ec = ec;
                _ps = ps;
+               setupUpdateFunction(updFunc, ec);
+       }
 
+       protected void setupUpdateFunction(String updFunc, ExecutionContext ec) {
                // Get the update function
                String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX);
                String ns = cfn[0];
@@ -125,4 +130,6 @@ public abstract class PSWorker {
        public MatrixObject getLabels() {
                return _labels;
        }
+
+       public abstract String getWorkerName();
 }

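PSWorker becomes Serializable with a protected no-arg constructor, and the update-function wiring is pulled out into setupUpdateFunction so a worker deserialized on a Spark executor can rebuild its non-serializable state there (SparkPSWorker.configureWorker below does exactly that). A generic sketch of the two-phase initialization pattern (illustrative names, not the SystemML types):

    import java.io.Serializable;

    // Sketch: ship only plain configuration; rebuild heavy runtime state
    // after deserialization via an explicit setup step.
    public class TwoPhaseWorkerSketch implements Serializable {
        private static final long serialVersionUID = 1L;

        private String _updFuncName;          // plain data, travels with the object
        private transient Runnable _updFunc;  // runtime state, rebuilt remotely

        protected TwoPhaseWorkerSketch() {}   // required for deserialization

        public TwoPhaseWorkerSketch(String updFuncName) {
            _updFuncName = updFuncName;
            setupUpdateFunction();            // eager init on the driver
        }
        public final void setupUpdateFunction() { // call again after deserialization
            _updFunc = () -> System.out.println("running " + _updFuncName);
        }
        public void run() { _updFunc.run(); }
    }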
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index bd8ee36..2607036 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -42,7 +42,6 @@ import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
-import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.utils.Statistics;
@@ -53,17 +52,19 @@ public abstract class ParamServer
        protected static final boolean ACCRUE_BSP_GRADIENTS = true;
        
        // worker input queues and global model
-       protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
+       protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
        private ListObject _model;
 
        //aggregation service
-       protected final ExecutionContext _ec;
-       private final Statement.PSUpdateType _updateType;
-       private final FunctionCallCPInstruction _inst;
-       private final String _outputName;
-       private final boolean[] _finishedStates;  // Workers' finished states
+       protected ExecutionContext _ec;
+       private Statement.PSUpdateType _updateType;
+       private FunctionCallCPInstruction _inst;
+       private String _outputName;
+       private boolean[] _finishedStates;  // Workers' finished states
        private ListObject _accGradients = null;
 
+       protected ParamServer() {}
+
        protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
                // init worker queues and global model
                _modelMap = new HashMap<>(workerNum);
@@ -77,10 +78,22 @@ public abstract class ParamServer
                _ec = ec;
                _updateType = updateType;
                _finishedStates = new boolean[workerNum];
+               setupAggFunc(_ec, aggFunc);
+               
+               // broadcast initial model
+               try {
+                       broadcastModel();
+               }
+               catch (InterruptedException e) {
+                       throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
+               }
+       }
+
+       public void setupAggFunc(ExecutionContext ec, String aggFunc) {
                String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
                String ns = cfn[0];
                String fname = cfn[1];
-               FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname);
+               FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname);
                ArrayList<DataIdentifier> inputs = func.getInputParams();
                ArrayList<DataIdentifier> outputs = func.getOutputParams();
 
@@ -101,19 +114,11 @@ public abstract class ParamServer
                ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
                _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function");
-               
-               // broadcast initial model
-               try {
-                       broadcastModel();
-               }
-               catch (InterruptedException e) {
-                       throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
-               }
        }
 
        public abstract void push(int workerID, ListObject value);
 
-       public abstract Data pull(int workerID);
+       public abstract ListObject pull(int workerID);
 
        public ListObject getResult() {
                // All the model updating work has terminated,

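Symmetrically, setupAggFunc moves out of the ParamServer constructor and the fields lose their final modifiers, so the new SparkPSProxy below can be instantiated through a bare no-arg constructor on executors while remote workers bind the aggregation function against their own ExecutionContext. The shape of that split, reduced to its essentials (illustrative names):

    // Sketch: construction and function binding are separated, so a remote
    // proxy subclass can skip the local-only setup entirely.
    abstract class ServerSketch {
        protected String _aggFunc;
        protected ServerSketch() {}               // proxy path: no local setup
        protected ServerSketch(String aggFunc) {  // local path: bind eagerly
            setupAggFunc(aggFunc);
        }
        public void setupAggFunc(String aggFunc) {
            _aggFunc = aggFunc;                   // re-bindable after shipping
        }
    }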
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index b9fd7a8..cf27457 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -28,8 +28,11 @@ import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.MultiThreadedHop;
@@ -57,6 +60,7 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
@@ -68,13 +72,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;
 
 import scala.Tuple2;
 
 public class ParamservUtils {
 
+       protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName());
        public static final String PS_FUNC_PREFIX = "_ps_";
-
        public static long SEED = -1; // Used for generating permutation
 
        /**
@@ -140,6 +145,14 @@ public class ParamservUtils {
                CacheableData<?> cd = (CacheableData<?>) data;
                cd.enableCleanup(true);
                ec.cleanupCacheableData(cd);
+               if (LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("%s has been deleted.", cd.getFileName()));
+               }
+       }
+
+       public static void cleanupMatrixObject(ExecutionContext ec, MatrixObject mo) {
+               mo.enableCleanup(true);
+               ec.cleanupCacheableData(mo);
        }
 
        public static MatrixObject newMatrixObject(MatrixBlock mb) {
@@ -365,6 +378,7 @@ public class ParamservUtils {
 
        @SuppressWarnings("unchecked")
        public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, int workerNum) {
+               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
                // Get input RDD
                JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
                                sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo);
                JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
@@ -372,33 +386,34 @@
                                sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo);
 
                DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
-               return ParamservUtils.assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
+               JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils
+                       .assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
                        .flatMapToPair(mapper) // Do the data partitioning on spark (workerID => (rowBlockID, (single row features, single row labels))
                        // Aggregate the partitioned matrix according to rowID for each worker
                        // i.e. (workerID => ordered list[(rowBlockID, (single row features, single row labels)]
-                       .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(),
-                               new Partitioner() {
-                                       private static final long serialVersionUID = -7937781374718031224L;
-                                       @Override
-                                       public int getPartition(Object workerID) {
-                                               return (int) workerID;
-                                       }
-                                       @Override
-                                       public int numPartitions() {
-                                               return workerNum;
-                                       }
-                               }, 
-                               (list, input) -> {
-                                       list.add(input);
-                                       return list;
-                               },
-                               (l1, l2) -> {
-                                       l1.addAll(l2);
-                                       l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
-                                       return l1;
-                               })
-                       .mapToPair(new DataPartitionerSparkAggregator(
-                               features.getNumColumns(), labels.getNumColumns()));
+                       .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
+                               private static final long serialVersionUID = -7937781374718031224L;
+                               @Override
+                               public int getPartition(Object workerID) {
+                                       return (int) workerID;
+                               }
+                               @Override
+                               public int numPartitions() {
+                                       return workerNum;
+                               }
+                       }, (list, input) -> {
+                               list.add(input);
+                               return list;
+                       }, (l1, l2) -> {
+                               l1.addAll(l2);
+                               l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
+                               return l1;
+                       })
+                       .mapToPair(new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns()));
+
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
+               return result;
        }
 
        public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {

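doPartitionOnSpark pins each worker's rows to a dedicated partition by keying a custom Partitioner on the worker ID, then collects per-worker row lists with aggregateByKey (sorting on merge). The routing idea in isolation, runnable with toy String payloads instead of matrix blocks (a sketch, not the SystemML code path):

    import java.util.Arrays;
    import java.util.LinkedList;

    import org.apache.spark.Partitioner;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    import scala.Tuple2;

    // Sketch: worker ID -> partition, so each worker's rows land in exactly
    // one partition and can be consumed by a single foreach task.
    public class WorkerPartitionSketch {
        public static void main(String[] args) {
            final int workerNum = 2;
            try (JavaSparkContext jsc = new JavaSparkContext(
                    new SparkConf().setMaster("local[2]").setAppName("sketch"))) {
                JavaPairRDD<Integer, String> rows = jsc.parallelizePairs(Arrays.asList(
                    new Tuple2<>(0, "row-a"), new Tuple2<>(1, "row-b"), new Tuple2<>(0, "row-c")));
                JavaPairRDD<Integer, LinkedList<String>> perWorker = rows.aggregateByKey(
                    new LinkedList<String>(),
                    new Partitioner() {
                        private static final long serialVersionUID = 1L;
                        @Override public int getPartition(Object workerID) { return (int) workerID; }
                        @Override public int numPartitions() { return workerNum; }
                    },
                    (list, row) -> { list.add(row); return list; },
                    (l1, l2) -> { l1.addAll(l2); return l1; });
                perWorker.collectAsMap().forEach((k, v) -> System.out.println(k + " => " + v));
            }
        }
    }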
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
index ec10232..9354025 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
@@ -28,12 +28,10 @@ public class SparkPSBody {
 
        private ExecutionContext _ec;
 
-       public SparkPSBody() {
-
-       }
+       public SparkPSBody() {}
 
        public SparkPSBody(ExecutionContext ec) {
-               this._ec = ec;
+               _ec = ec;
        }
 
        public ExecutionContext getEc() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
new file mode 100644
index 0000000..de7b6c6
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.utils.Statistics;
+
+public class SparkPSProxy extends ParamServer {
+
+       private TransportClient _client;
+       private final long _rpcTimeout;
+
+       public SparkPSProxy(TransportClient client, long rpcTimeout) {
+               super();
+               _client = client;
+               _rpcTimeout = rpcTimeout;
+       }
+
+       @Override
+       public void push(int workerID, ListObject value) {
+               Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
+               PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSRpcRequestTime((long) tRpc.stop());
+               if (!response.isSuccessful()) {
+                       throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage()));
+               }
+       }
+
+       @Override
+       public ListObject pull(int workerID) {
+               Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
+               PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSRpcRequestTime((long) tRpc.stop());
+               if (!response.isSuccessful()) {
+                       throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage()));
+               }
+               return response.getResultModel();
+       }
+}

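SparkPSProxy gives executors the ParamServer interface while actually performing synchronous RPCs via TransportClient.sendRpcSync, accumulating the wait into the new PS statistics. The request/response discipline, separated from the Spark transport (the Transport interface here is a hypothetical stand-in for sendRpcSync):

    import java.nio.ByteBuffer;

    // Sketch: serialize a call, send it synchronously, surface remote errors
    // as local exceptions -- the same shape as push()/pull() above.
    public class ProxySketch {
        interface Transport { // stand-in for TransportClient#sendRpcSync
            ByteBuffer sendSync(ByteBuffer request, long timeoutMs);
        }
        private final Transport _client;
        private final long _timeoutMs;

        ProxySketch(Transport client, long timeoutMs) {
            _client = client;
            _timeoutMs = timeoutMs;
        }
        // Encodes "method:workerID", expects "OK..." or "ERR:<stack trace>".
        String call(String method, int workerID) {
            ByteBuffer req = ByteBuffer.wrap((method + ":" + workerID).getBytes());
            ByteBuffer resp = _client.sendSync(req, _timeoutMs);
            byte[] raw = new byte[resp.remaining()];
            resp.get(raw);
            String decoded = new String(raw);
            if (decoded.startsWith("ERR:"))
                throw new RuntimeException(decoded.substring(4));
            return decoded;
        }
    }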
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index 466801f..fa06243 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -20,43 +20,58 @@
 package org.apache.sysml.runtime.controlprogram.paramserv.spark;
 
 import java.io.IOException;
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.codegen.CodegenUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
 import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;
 
 import scala.Tuple2;
 
-public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
+public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
 
        private static final long serialVersionUID = -8674739573419648732L;
 
        private String _program;
        private HashMap<String, byte[]> _clsMap;
+       private String _host; // host ip of driver
+       private long _rpcTimeout; // rpc ask timeout
+       private String _aggFunc;
 
-       protected SparkPSWorker() {
-               // No-args constructor used for deserialization
-       }
-
-       public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) {
+       public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) {
                _updFunc = updFunc;
+               _aggFunc = aggFunc;
                _freq = freq;
                _epochs = epochs;
                _batchSize = batchSize;
                _program = program;
                _clsMap = clsMap;
+               _host = host;
+               _rpcTimeout = rpcTimeout;
+       }
+
+       @Override
+       public String getWorkerName() {
+               return String.format("Spark worker_%d", _workerID);
        }
 
        @Override
        public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
+               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
                configureWorker(input);
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
+               call(); // Launch the worker
        }
 
        private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
@@ -73,5 +88,20 @@ public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integ
 
                // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end
                RemoteParForUtils.setupBufferPool(_workerID);
+
+               // Create the ps proxy
+               _ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout);
+
+               // Initialize the update function
+               setupUpdateFunction(_updFunc, _ec);
+
+               // Initialize the agg function
+               _ps.setupAggFunc(_ec, _aggFunc);
+
+               // Lazy initialize the matrix of features and labels
+               setFeatures(ParamservUtils.newMatrixObject(input._2._1));
+               setLabels(ParamservUtils.newMatrixObject(input._2._2));
+               _features.enableCleanup(false);
+               _labels.enableCleanup(false);
        }
 }

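SparkPSWorker now extends LocalPSWorker, so each foreach task replays the full local training loop against the driver: configureWorker restores the program, connects the RPC proxy, re-binds the update and aggregation functions, and wraps its matrix partition, after which call() takes over. The driver-side hookup (shown in ParamservBuiltinCPInstruction below) is the standard VoidFunction pattern, sketched here with a toy payload:

    import org.apache.spark.api.java.function.VoidFunction;

    import scala.Tuple2;

    // Sketch: the worker doubles as a VoidFunction so one serialized instance
    // can be handed to rdd.foreach(...), configuring itself per task.
    public class ForeachWorkerSketch implements VoidFunction<Tuple2<Integer, String>> {
        private static final long serialVersionUID = 1L;
        private transient String _workerName; // rebuilt on the executor

        @Override
        public void call(Tuple2<Integer, String> partition) throws Exception {
            configure(partition._1);           // 1) rebuild local state
            System.out.println(_workerName + " trains on " + partition._2); // 2) run the loop
        }
        private void configure(int workerID) {
            _workerName = "Spark worker_" + workerID;
        }
    }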
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
new file mode 100644
index 0000000..999d409
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
+import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
+import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
+
+import java.nio.ByteBuffer;
+import java.util.StringTokenizer;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+public class PSRpcCall extends PSRpcObject {
+
+       private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN;
+       private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END;
+
+       private String _method;
+       private int _workerID;
+       private ListObject _data;
+
+       public PSRpcCall(String method, int workerID, ListObject data) {
+               _method = method;
+               _workerID = workerID;
+               _data = data;
+       }
+
+       public PSRpcCall(ByteBuffer buffer) {
+               deserialize(buffer);
+       }
+
+       public void deserialize(ByteBuffer buffer) {
+               //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
+               String input = bufferToString(buffer);
+               //header elimination
+               input = input.substring(PS_RPC_CALL_BEGIN.length(), input.length() - PS_RPC_CALL_END.length()); //remove start/end
+               StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
+
+               _method = st.nextToken();
+               _workerID = Integer.valueOf(st.nextToken());
+               String dataStr = st.nextToken();
+               _data = dataStr.equals(EMPTY) ? null :
+                       (ListObject) ProgramConverter.parseDataObject(dataStr)[1];
+       }
+
+       public ByteBuffer serialize() {
+               //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
+               StringBuilder sb = new StringBuilder();
+               sb.append(PS_RPC_CALL_BEGIN);
+               sb.append(_method);
+               sb.append(COMPONENTS_DELIM);
+               sb.append(_workerID);
+               sb.append(COMPONENTS_DELIM);
+               if (_data == null) {
+                       sb.append(EMPTY);
+               } else {
+                       flushListObject(_data);
+                       
sb.append(ProgramConverter.serializeDataObject(DATA_KEY, _data));
+               }
+               sb.append(PS_RPC_CALL_END);
+               return ByteBuffer.wrap(sb.toString().getBytes());
+       }
+
+       public String getMethod() {
+               return _method;
+       }
+
+       public int getWorkerID() {
+               return _workerID;
+       }
+
+       public ListObject getData() {
+               return _data;
+       }
+}

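PSRpcCall frames its payload as a delimited string (CDATA_BEGIN + tag + LEVELIN header, COMPONENTS_DELIM-separated fields, LEVELOUT + CDATA_END trailer) packed into a ByteBuffer. The same framing round trip, runnable with plain stand-in delimiters (the real ones come from ProgramConverter):

    import java.nio.ByteBuffer;
    import java.util.StringTokenizer;

    // Sketch: frame = BEGIN + method + DELIM + workerID + END; deserialization
    // strips the frame and tokenizes the body, as in PSRpcCall above.
    public class RpcFramingSketch {
        static final String BEGIN = "<PSRPCCALL>";
        static final String END = "</PSRPCCALL>";
        static final String DELIM = "|";

        static ByteBuffer serialize(String method, int workerID) {
            return ByteBuffer.wrap((BEGIN + method + DELIM + workerID + END).getBytes());
        }
        static String[] deserialize(ByteBuffer buffer) {
            byte[] raw = new byte[buffer.limit()];
            buffer.get(raw, 0, buffer.limit());
            String body = new String(raw);
            body = body.substring(BEGIN.length(), body.length() - END.length());
            StringTokenizer st = new StringTokenizer(body, DELIM);
            return new String[] { st.nextToken(), st.nextToken() };
        }
        public static void main(String[] args) {
            String[] fields = deserialize(serialize("push", 3));
            System.out.println(fields[0] + " from worker " + fields[1]); // push from worker 3
        }
    }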
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
new file mode 100644
index 0000000..c8b4024
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import java.io.IOException;
+import java.util.Collections;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
+
+//TODO should be able to configure the port by users
+public class PSRpcFactory {
+
+       private static final int PORT = 5055;
+       private static final String MODULE_NAME = "ps";
+
+       private static TransportContext createTransportContext(LocalParamServer ps) {
+               TransportConf conf = new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider());
+               PSRpcHandler handler = new PSRpcHandler(ps);
+               return new TransportContext(conf, handler);
+       }
+
+       /**
+        * Create and start the server
+        * @return server
+        */
+       public static TransportServer createServer(LocalParamServer ps, String host) {
+               TransportContext context = createTransportContext(ps);
+               return context.createServer(host, PORT, Collections.emptyList());
+       }
+
+       public static SparkPSProxy createSparkPSProxy(String host, long rpcTimeout) throws IOException {
+               TransportContext context = createTransportContext(new LocalParamServer());
+               return new SparkPSProxy(context.createClientFactory().createClient(host, PORT), rpcTimeout);
+       }
+}

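The factory derives both endpoints from one TransportContext over a fixed port (5055; the TODO above flags making it configurable). Wiring it up on the driver might look like the following fragment, using only the classes added in this patch (assumes a LocalParamServer ps, a ListObject gradients, and a host string in scope; exception handling elided):

    // Sketch: start the driver-side endpoint, hand a proxy to the workers,
    // and close the server when training ends.
    TransportServer server = PSRpcFactory.createServer(ps, host); // listens on 5055
    try {
        SparkPSProxy proxy = PSRpcFactory.createSparkPSProxy(host, 120000 /* ms */);
        proxy.push(0, gradients); // synchronous RPC to the LocalParamServer behind the server
    } finally {
        server.close(); // TransportServer is Closeable
    }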
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
new file mode 100644
index 0000000..3d73a37
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS;
+
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public final class PSRpcHandler extends RpcHandler {
+
+       private LocalParamServer _server;
+
+       protected PSRpcHandler(LocalParamServer server) {
+               _server = server;
+       }
+
+       @Override
+       public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
+               PSRpcCall call = new PSRpcCall(buffer);
+               PSRpcResponse response = null;
+               switch (call.getMethod()) {
+                       case PUSH:
+                               try {
+                                       _server.push(call.getWorkerID(), call.getData());
+                                       response = new PSRpcResponse(SUCCESS, EMPTY_DATA);
+                               } catch (DMLRuntimeException exception) {
+                                       response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+                               } finally {
+                                       callback.onSuccess(response.serialize());
+                               }
+                               break;
+                       case PULL:
+                               ListObject data;
+                               try {
+                                       data = _server.pull(call.getWorkerID());
+                                       response = new PSRpcResponse(SUCCESS, data);
+                               } catch (DMLRuntimeException exception) {
+                                       response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+                               } finally {
+                                       callback.onSuccess(response.serialize());
+                               }
+                               break;
+                       default:
+                               throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
+               }
+       }
+
+       @Override
+       public StreamManager getStreamManager() {
+               return new OneForOneStreamManager();
+       }
+}

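The handler's contract is that every PUSH/PULL request answers its callback exactly once, with server-side DMLRuntimeExceptions flattened into ERROR responses instead of propagating (an unknown method, by contrast, throws and lets the client run into its RPC timeout). The dispatch shape in isolation (illustrative types and payloads):

    // Sketch: dispatch on method, convert exceptions into error responses,
    // and always answer the callback via finally -- mirroring receive() above.
    public class DispatchSketch {
        interface Callback { void onSuccess(String response); }

        private String doPush() { return "pushed"; }
        private String doPull() { return "model"; }

        public void receive(String method, Callback callback) {
            String response = null;
            switch (method) {
                case "push":
                    try {
                        response = "OK:" + doPush();
                    } catch (RuntimeException e) {
                        response = "ERR:" + e.getMessage();
                    } finally {
                        callback.onSuccess(response);
                    }
                    break;
                case "pull":
                    try {
                        response = "OK:" + doPull();
                    } catch (RuntimeException e) {
                        response = "ERR:" + e.getMessage();
                    } finally {
                        callback.onSuccess(response);
                    }
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported rpc method: " + method);
            }
        }
    }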
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
new file mode 100644
index 0000000..c6d7fd3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public abstract class PSRpcObject {
+
+       public static final String PUSH = "push";
+       public static final String PULL = "pull";
+       public static final String DATA_KEY = "data";
+       public static final String EMPTY_DATA = "";
+
+       public abstract void deserialize(ByteBuffer buffer);
+
+       public abstract ByteBuffer serialize();
+
+       /**
+        * Convert direct byte buffer to string
+        * @param buffer direct byte buffer
+        * @return string
+        */
+       protected String bufferToString(ByteBuffer buffer) {
+               byte[] result = new byte[buffer.limit()];
+               buffer.get(result, 0, buffer.limit());
+               return new String(result);
+       }
+
+       /**
+        * Flush the data into HDFS
+        * @param data list object
+        */
+       protected void flushListObject(ListObject data) {
+               data.getData().stream().filter(d -> d instanceof CacheableData)
+                       .forEach(d -> ((CacheableData<?>) d).exportData());
+       }
+}

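bufferToString copies the readable bytes out and decodes them with the JVM's default charset, while flushListObject forces cached matrices to disk so only metadata rides in the string payload (the FIXMEs in PSRpcCall/PSRpcResponse note that a deep binary serialization would avoid the export). A charset-pinned variant of the buffer helper, worth considering if driver and executors could run with different default encodings (an assumption, not part of the patch):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    // Sketch: same copy-out as bufferToString, but pinned to UTF-8 so both
    // ends of the wire agree on the encoding regardless of JVM defaults.
    public final class BufferCodecSketch {
        static String bufferToString(ByteBuffer buffer) {
            byte[] result = new byte[buffer.limit()];
            buffer.get(result, 0, buffer.limit());
            return new String(result, StandardCharsets.UTF_8);
        }
        static ByteBuffer stringToBuffer(String s) {
            return ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8));
        }
        public static void main(String[] args) {
            System.out.println(bufferToString(stringToBuffer("push"))); // push
        }
    }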
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
new file mode 100644
index 0000000..998c523
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
+import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
+import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
+
+import java.nio.ByteBuffer;
+import java.util.StringTokenizer;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+public class PSRpcResponse extends PSRpcObject {
+
+       public static final int SUCCESS = 1;
+       public static final int ERROR = 2;
+
+       private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN;
+       private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END;
+
+       private int _status;
+       private Object _data;   // Could be list object or exception
+
+       public PSRpcResponse(ByteBuffer buffer) {
+               deserialize(buffer);
+       }
+
+       public PSRpcResponse(int status, Object data) {
+               _status = status;
+               _data = data;
+       }
+
+       public boolean isSuccessful() {
+               return _status == SUCCESS;
+       }
+
+       public String getErrorMessage() {
+               return (String) _data;
+       }
+
+       public ListObject getResultModel() {
+               return (ListObject) _data;
+       }
+
+       @Override
+       public void deserialize(ByteBuffer buffer) {
+               //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
+               String input = bufferToString(buffer);
+               //header elimination
+               input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end
+               StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
+
+               _status = Integer.valueOf(st.nextToken());
+               String data = st.nextToken();
+               switch (_status) {
+                       case SUCCESS:
+                               _data = data.equals(EMPTY) ? null :
+                                       ProgramConverter.parseDataObject(data)[1];
+                               break;
+                       case ERROR:
+                               _data = data;
+                               break;
+               }
+       }
+
+       @Override
+       public ByteBuffer serialize() {
+               //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
+               
+               StringBuilder sb = new StringBuilder();
+               sb.append(PS_RPC_RESPONSE_BEGIN);
+               sb.append(_status);
+               sb.append(COMPONENTS_DELIM);
+               switch (_status) {
+                       case SUCCESS:
+                               if (_data.equals(EMPTY_DATA)) {
+                                       sb.append(EMPTY);
+                               } else {
+                                       flushListObject((ListObject) _data);
+                                       sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data));
+                               }
+                               break;
+                       case ERROR:
+                               sb.append(_data.toString());
+                               break;
+               }
+               sb.append(PS_RPC_RESPONSE_END);
+               return ByteBuffer.wrap(sb.toString().getBytes());
+       }
+}

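The diffstat lists a new RpcObjectTest for these classes; a minimal round trip in that spirit, exercising the serialize/deserialize pair added above (a hypothetical JUnit sketch, not the actual test file):

    import static org.junit.Assert.assertEquals;

    import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
    import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
    import org.junit.Test;

    public class PSRpcCallRoundTripSketch {
        @Test
        public void pushCallSurvivesRoundTrip() {
            PSRpcCall original = new PSRpcCall(PSRpcObject.PUSH, 1, null); // no payload
            PSRpcCall restored = new PSRpcCall(original.serialize());
            assertEquals(PSRpcObject.PUSH, restored.getMethod());
            assertEquals(1, restored.getWorkerID());
        }
    }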
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 4e7a718..6133987 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -55,6 +55,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
+import org.apache.spark.network.server.TransportServer;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.recompile.Recompiler;
 import org.apache.sysml.lops.LopProperties;
@@ -71,6 +72,7 @@ import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -114,16 +116,16 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
+               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
+
                PSScheme scheme = getScheme();
                int workerNum = getWorkerNum(mode);
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
 
-               int k = getParLevel(workerNum);
-
                // Get the compiled execution context
                LocalVariableMap newVarsMap = createVarsMap(sec);
-               ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, k);
+               ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // the level of parallelism is 1 in the spark backend
 
                MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
                MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
@@ -131,16 +133,47 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                // Force all the instructions to CP type
                Recompiler.recompileProgramBlockHierarchy2Forced(
                        newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP);
-
+
                // Serialize all the needed params for remote workers
                SparkPSBody body = new SparkPSBody(newEC);
                HashMap<String, byte[]> clsMap = new HashMap<>();
                String program = ProgramConverter.serializeSparkPSBody(body, clsMap);
 
-               SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getFrequency(), getEpochs(), getBatchSize(), program, clsMap);
-               ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
-                       .foreach(worker);   // Run remote workers
+               // Get the driver host and RPC timeout configurations
+               String host = sec.getSparkContext().getConf().get("spark.driver.host");
+               long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ?
+                       sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") :
+                       sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s");
+
+               // Create remote workers
+               SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(),
+                       getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout);
+
+               // Create the agg service's execution context
+               ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+
+               // Create the parameter server
+               ListObject model = sec.getListObject(getParam(PS_MODEL));
+               ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
+
+               // Create and start the netty server for the ps
+               TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host);
 
+               try {
+                       ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
+                               .foreach(worker); // Run remote workers
+               } catch (Exception e) {
+                       throw new DMLRuntimeException("Paramserv function failed: ", e);
+               } finally {
+                       // Stop the netty server
+                       server.close();
+               }
+
+               // Fetch the final model from ps
+               sec.setVariable(output.getName(), ps.getResult());
        }
 
        private void runLocally(ExecutionContext ec, PSModeType mode) {
@@ -176,8 +209,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
                MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
-                  .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
-                  .collect(Collectors.toList());
+                       .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
+                       .collect(Collectors.toList());
 
                // Do data partition
                PSScheme scheme = getScheme();
@@ -296,6 +329,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
        private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
                switch (mode) {
                        case LOCAL:
+                       case REMOTE_SPARK:
                                return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
                        default:
                                throw new DMLRuntimeException("Unsupported parameter server: "+mode.name());
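
The timeout lookup in runOnSpark mirrors Spark's own fallback convention: spark.rpc.askTimeout if explicitly set, otherwise spark.network.timeout with a 120s default. A standalone sketch of that resolution (standard SparkConf API; the helper name is ours):

  import org.apache.spark.SparkConf;

  // Resolve the paramserv RPC timeout the way runOnSpark does: prefer
  // spark.rpc.askTimeout, else fall back to spark.network.timeout (120s default).
  public class RpcTimeoutSketch {
    static long resolveRpcTimeoutMs(SparkConf conf) {
      return conf.contains("spark.rpc.askTimeout")
        ? conf.getTimeAsMs("spark.rpc.askTimeout")
        : conf.getTimeAsMs("spark.network.timeout", "120s");
    }

    public static void main(String[] args) {
      SparkConf conf = new SparkConf().set("spark.network.timeout", "300s");
      System.out.println(resolveRpcTimeoutMs(conf)); // prints 300000
    }
  }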

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
index 1d2115e..fc9d9b4 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
@@ -143,7 +143,7 @@ public class ProgramConverter
        public static final String PB_IF = " IF" + LEVELIN;
        public static final String PB_FC = " FC" + LEVELIN;
        public static final String PB_EFC = " EFC" + LEVELIN;
-       
+
        public static final String CONF_STATS = "stats";
 
        // Used for parfor
@@ -716,9 +716,10 @@ public class ProgramConverter
                builder.append(rSerializeProgramBlocks(ec.getProgram().getProgramBlocks(), clsMap));
                builder.append(PBS_END);
                builder.append(NEWLINE);
+               builder.append(COMPONENTS_DELIM);
+               builder.append(NEWLINE);
 
                builder.append(PSBODY_END);
-
                return builder.toString();
        }
 
@@ -868,7 +869,7 @@ public class ProgramConverter
                                value = mo.getFileName();
                                PartitionFormat partFormat = (mo.getPartitionFormat()!=null) ? new PartitionFormat(
                                                mo.getPartitionFormat(),mo.getPartitionSize()) : PartitionFormat.NONE;
-                               metaData = new String[9];
+                               metaData = new String[11];
                                metaData[0] = String.valueOf( mc.getRows() );
                                metaData[1] = String.valueOf( mc.getCols() );
                                metaData[2] = String.valueOf( mc.getRowsPerBlock() );
@@ -878,6 +879,8 @@
                                metaData[6] = OutputInfo.outputInfoToString( md.getOutputInfo() );
                                metaData[7] = String.valueOf( partFormat );
                                metaData[8] = String.valueOf( mo.getUpdateType() );
+                               metaData[9] = String.valueOf(mo.isHDFSFileExists());
+                               metaData[10] = String.valueOf(mo.isCleanupEnabled());
                                break;
                        case LIST:
                                // SCHEMA: <name>|<datatype>|<valuetype>|value|<metadata>|<tab>element1<tab>element2<tab>element3 (this is the list)
@@ -1683,6 +1686,8 @@
                                if( partFormat._dpf != PDataPartitionFormat.NONE )
                                        mo.setPartitioned( partFormat._dpf, partFormat._N );
                                mo.setUpdateType(inplace);
+                               mo.setHDFSFileExists(Boolean.valueOf(st.nextToken()));
+                               mo.enableCleanup(Boolean.valueOf(st.nextToken()));
                                dat = mo;
                                break;
                        }
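
The matrix metadata array grows from 9 to 11 entries here, and the parser side consumes the two new boolean tokens in exactly the order the serializer emits them. A tiny sketch of that symmetric contract (the delimiter is hypothetical; ProgramConverter uses its own constants):

  import java.util.StringTokenizer;

  // Serializer and parser must agree on token order: hdfsFileExists first,
  // then cleanupEnabled. DELIM is illustrative, not ProgramConverter's constant.
  public class MetaTokenSketch {
    static final String DELIM = "|";

    static String serialize(boolean hdfsFileExists, boolean cleanupEnabled) {
      return hdfsFileExists + DELIM + cleanupEnabled;
    }

    static boolean[] parse(String s) {
      StringTokenizer st = new StringTokenizer(s, DELIM);
      return new boolean[] { Boolean.valueOf(st.nextToken()),   // hdfsFileExists
                             Boolean.valueOf(st.nextToken()) }; // cleanupEnabled
    }

    public static void main(String[] args) {
      boolean[] f = parse(serialize(true, false));
      System.out.println(f[0] + " " + f[1]); // prints: true false
    }
  }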

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index 8f0d853..1dd8362 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -125,6 +125,7 @@ public class Statistics
        private static final LongAdder psLocalModelUpdateTime = new LongAdder();
        private static final LongAdder psModelBroadcastTime = new LongAdder();
        private static final LongAdder psBatchIndexTime = new LongAdder();
+       private static final LongAdder psRpcRequestTime = new LongAdder();
 
        //PARFOR optimization stats (low frequency updates)
        private static long parforOptTime = 0; //in milli sec
@@ -564,6 +565,10 @@ public class Statistics
                psBatchIndexTime.add(t);
        }
 
+       public static void accPSRpcRequestTime(long t) {
+               psRpcRequestTime.add(t);
+       }
+
        public static String getCPHeavyHitterCode( Instruction inst )
        {
                String opcode = null;
@@ -1003,6 +1008,7 @@
                                                psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 1000));
                                sb.append(String.format("Paramserv model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
                                sb.append(String.format("Paramserv batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
+                               sb.append(String.format("Paramserv RPC request time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
                        }
                        if( parforOptCount>0 ){
                                sb.append("ParFor loops optimized:\t\t" + getParforOptCount() + ".\n");
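
The new psRpcRequestTime counter follows the same accumulate-only-when-enabled pattern as the other paramserv statistics. The actual caller is the new SparkPSProxy (not shown in this excerpt); a sketch of the expected call-site shape:

  import org.apache.sysml.api.DMLScript;
  import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
  import org.apache.sysml.utils.Statistics;

  // Hypothetical call-site shape: time the RPC round trip only when
  // statistics are enabled, then add the elapsed millis to the LongAdder.
  public class RpcTimingSketch {
    static void timedRpc(Runnable rpcCall) {
      Timing t = DMLScript.STATISTICS ? new Timing(true) : null;
      rpcCall.run(); // issue the request and block for the response
      if (DMLScript.STATISTICS)
        Statistics.accPSRpcRequestTime((long) t.stop());
    }
  }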

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
index d5fd509..905bfd1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
@@ -19,75 +19,66 @@
 
 package org.apache.sysml.test.integration.functions.paramserv;
 
+import org.apache.sysml.parser.Statement;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.junit.Test;
 
 public class ParamservLocalNNTest extends AutomatedTestBase {
 
-       private static final String TEST_NAME1 = "paramserv-nn-bsp-batch-dc";
-       private static final String TEST_NAME2 = "paramserv-nn-asp-batch";
-       private static final String TEST_NAME3 = "paramserv-nn-bsp-epoch";
-       private static final String TEST_NAME4 = "paramserv-nn-asp-epoch";
-       private static final String TEST_NAME5 = "paramserv-nn-bsp-batch-drr";
-       private static final String TEST_NAME6 = "paramserv-nn-bsp-batch-dr";
-       private static final String TEST_NAME7 = "paramserv-nn-bsp-batch-or";
+       private static final String TEST_NAME = "paramserv-test";
 
        private static final String TEST_DIR = "functions/paramserv/";
        private static final String TEST_CLASS_DIR = TEST_DIR + ParamservLocalNNTest.class.getSimpleName() + "/";
 
        @Override
        public void setUp() {
-               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
-               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
-               addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
-               addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {}));
-               addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
-               addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
-               addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {}));
        }
 
        @Test
        public void testParamservBSPBatchDisjointContiguous() {
-               runDMLTest(TEST_NAME1);
+               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
        @Test
        public void testParamservASPBatch() {
-               runDMLTest(TEST_NAME2);
+               runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
        @Test
        public void testParamservBSPEpoch() {
-               runDMLTest(TEST_NAME3);
+               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
        @Test
        public void testParamservASPEpoch() {
-               runDMLTest(TEST_NAME4);
+               runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
        @Test
        public void testParamservBSPBatchDisjointRoundRobin() {
-               runDMLTest(TEST_NAME5);
+               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
        }
 
        @Test
        public void testParamservBSPBatchDisjointRandom() {
-               runDMLTest(TEST_NAME6);
+               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
        }
 
        @Test
        public void testParamservBSPBatchOverlapReshuffle() {
-               runDMLTest(TEST_NAME7);
+               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
        }
 
-       private void runDMLTest(String testname) {
-               TestConfiguration config = getTestConfiguration(testname);
+       private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
+               TestConfiguration config = getTestConfiguration(ParamservLocalNNTest.TEST_NAME);
                loadTestConfiguration(config);
-               programArgs = new String[] { "-explain" };
+               programArgs = new String[] { "-explain", "-nvargs", "mode=LOCAL", "epochs=" + epochs,
+                       "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize,
+                       "scheme=" + scheme };
                String HOME = SCRIPT_DIR + TEST_DIR;
-               fullDMLScriptName = HOME + testname + ".dml";
+               fullDMLScriptName = HOME + ParamservLocalNNTest.TEST_NAME + ".dml";
                runTest(true, false, null, null, -1);
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
index 2441116..30eccb3 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
@@ -1,14 +1,24 @@
 package org.apache.sysml.test.integration.functions.paramserv;
 
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.junit.Test;
 
 public class ParamservSparkNNTest extends AutomatedTestBase {
 
-       private static final String TEST_NAME1 = "paramserv-spark-nn-bsp-batch-dc";
+       private static final String TEST_NAME1 = "paramserv-test";
+       private static final String TEST_NAME2 = "paramserv-spark-worker-failed";
+       private static final String TEST_NAME3 = "paramserv-spark-agg-service-failed";
 
        private static final String TEST_DIR = "functions/paramserv/";
        private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSparkNNTest.class.getSimpleName() + "/";
@@ -16,14 +26,42 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
        @Override
        public void setUp() {
                addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
+               addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
        }
 
        @Test
        public void testParamservBSPBatchDisjointContiguous() {
-               runDMLTest(TEST_NAME1);
+               runDMLTest(2, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+       }
+
+       @Test
+       public void testParamservASPBatchDisjointContiguous() {
+               runDMLTest(2, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+       }
+
+       @Test
+       public void testParamservBSPEpochDisjointContiguous() {
+               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+       }
+
+       @Test
+       public void testParamservASPEpochDisjointContiguous() {
+               runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
-       private void runDMLTest(String testname) {
+       @Test
+       public void testParamservWorkerFailed() {
+               runDMLTest(TEST_NAME2, true, DMLException.class, "Invalid indexing by name in unnamed list: worker_err.");
+       }
+
+       @Test
+       public void testParamservAggServiceFailed() {
+               runDMLTest(TEST_NAME3, true, DMLException.class, "Invalid indexing by name in unnamed list: agg_service_err.");
+       }
+
+       private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) {
+               programArgs = new String[] { "-explain" };
                DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform;
                boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
                AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK;
@@ -32,16 +70,32 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
                try {
                        TestConfiguration config = getTestConfiguration(testname);
                        loadTestConfiguration(config);
-                       programArgs = new String[] { "-explain" };
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       // The test is not already finished, so it is normal to have the NPE
-                       runTest(true, true, DMLException.class, null, -1);
+                       runTest(true, exceptionExpected, expectedException, errMessage, -1);
                } finally {
                        AutomatedTestBase.rtplatform = oldRtplatform;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = oldUseLocalSparkConfig;
                }
-
        }
 
+       private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
+               Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString())
+                       .in("$epochs", String.valueOf(epochs))
+                       .in("$workers", String.valueOf(workers))
+                       .in("$utype", utype.toString())
+                       .in("$freq", freq.toString())
+                       .in("$batchsize", String.valueOf(batchsize))
+                       .in("$scheme", scheme.toString());
+
+               SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]")
+                       .set("spark.driver.allowMultipleContexts", "true");
+               JavaSparkContext sc = new JavaSparkContext(conf);
+               MLContext ml = new MLContext(sc);
+               ml.setStatistics(true);
+               ml.execute(script);
+               ml.resetConfig();
+               sc.stop();
+               ml.close();
+       }
 }
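
One caveat on the MLContext runner above: the cleanup calls (ml.resetConfig, sc.stop, ml.close) run unconditionally in sequence, so a failing ml.execute would leak the local Spark context. A slightly more defensive shape for the same tail of the method, as a sketch using only the calls already present above:

  JavaSparkContext sc = new JavaSparkContext(conf);
  MLContext ml = new MLContext(sc);
  try {
    ml.setStatistics(true);
    ml.execute(script);
  } finally { // release the context even if the script fails
    ml.resetConfig();
    sc.stop();
    ml.close();
  }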

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
new file mode 100644
index 0000000..57e1106
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.paramserv;
+
+import java.util.Arrays;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.instructions.cp.IntObject;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RpcObjectTest {
+
+       @Test
+       public void testPSRpcCall() {
+               MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
+               MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
+               IntObject io = new IntObject(30);
+               ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+               PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo);
+               PSRpcCall actual = new PSRpcCall(expected.serialize());
+               Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+       }
+
+       @Test
+       public void testPSRpcResponse() {
+               MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
+               MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
+               IntObject io = new IntObject(30);
+               ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+               PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo);
+               PSRpcResponse actual = new PSRpcResponse(expected.serialize());
+               Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+       }
+}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
index 2a08ca6..64d6492 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
@@ -68,7 +68,7 @@ public class SerializationTest {
                Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
        }
 
-       private MatrixObject generateDummyMatrix(int size) {
+       public static MatrixObject generateDummyMatrix(int size) {
                double[] dl = new double[size];
                for (int i = 0; i < size; i++) {
                        dl[i] = i;

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
deleted file mode 100644
index ba22942..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", batchsize,"DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
deleted file mode 100644
index c8c6a2f..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
deleted file mode 100644
index 78fc1c4..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
deleted file mode 100644
index 9191b5a..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_RANDOM", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
deleted file mode 100644
index ec18cb4..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 4
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_ROUND_ROBIN", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
deleted file mode 100644
index 928dde2..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "OVERLAP_RESHUFFLE", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
