Repository: systemml
Updated Branches:
  refs/heads/master bfd495289 -> e0c271fe4


[SYSTEMML-2420,2457] Improved distributed paramserv comm and stats

Closes #808.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e0c271fe
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e0c271fe
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e0c271fe

Branch: refs/heads/master
Commit: e0c271fe4cb05d3a53eb0143ab298e26831d1ed7
Parents: bfd4952
Author: EdgarLGB <guobao...@atos.net>
Authored: Thu Jul 26 23:33:12 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Jul 26 23:33:13 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/caching/CacheableData.java   |  6 ++
 .../controlprogram/paramserv/LocalPSWorker.java | 48 +++++++---
 .../controlprogram/paramserv/PSWorker.java      | 24 ++++-
 .../paramserv/ParamservUtils.java               |  5 ++
 .../paramserv/spark/SparkPSProxy.java           | 35 ++++++--
 .../paramserv/spark/SparkPSWorker.java          | 93 ++++++++++++++++---
 .../paramserv/spark/rpc/PSRpcCall.java          | 92 +++++++++----------
 .../paramserv/spark/rpc/PSRpcFactory.java       | 24 ++---
 .../paramserv/spark/rpc/PSRpcHandler.java       | 32 ++++---
 .../paramserv/spark/rpc/PSRpcObject.java        | 85 ++++++++++++++----
 .../paramserv/spark/rpc/PSRpcResponse.java      | 94 ++++++++++----------
 .../cp/ParamservBuiltinCPInstruction.java       | 56 ++++++++----
 .../sysml/runtime/io/IOUtilFunctions.java       | 10 +++
 .../java/org/apache/sysml/utils/Statistics.java |  4 +
 .../paramserv/ParamservSparkNNTest.java         | 37 +++-----
 .../functions/paramserv/RpcObjectTest.java      | 22 ++---
 .../paramserv-spark-agg-service-failed.dml      |  6 +-
 .../paramserv/paramserv-spark-worker-failed.dml |  6 +-
 18 files changed, 441 insertions(+), 238 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index f524251..0265c33 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -364,6 +364,12 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
        // ***                                       ***
        // *********************************************
 
+       public T acquireReadAndRelease() {
+               T tmp = acquireRead();
+               release();
+               return tmp;
+       }
+       
        /**
	 * Acquires a shared "read-only" lock, produces the reference to the cache block,
         * restores the cache block to main memory, reads from HDFS if needed.

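For context, the new acquireReadAndRelease helper simply pairs the existing
acquireRead/release calls, so callers that only need the in-memory block do
not have to manage the pin count themselves. A minimal standalone sketch of
the pattern (illustrative class, not the SystemML API):

    // Sketch of the pin/unpin pairing behind acquireReadAndRelease.
    abstract class CacheEntry<T> {
        private int pinCount = 0;
        abstract T restore();                // e.g., read the block from disk
        synchronized T acquireRead() { pinCount++; return restore(); }
        synchronized void release()  { pinCount--; }
        // convenience: obtain the block and immediately give up the pin
        synchronized T acquireReadAndRelease() {
            T tmp = acquireRead();
            release();
            return tmp;
        }
    }
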
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index c23943d..b8a416f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -48,12 +48,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
        public String getWorkerName() {
                return String.format("Local worker_%d", _workerID);
        }
-
+       
        @Override
        public Void call() throws Exception {
-               if (DMLScript.STATISTICS)
-                       Statistics.incWorkerNumber();
-               
+               incWorkerNumber();
                try {
                        long dataSize = _features.getNumRows();
			int totalIter = (int) Math.ceil((double) dataSize / _batchSize);
@@ -94,17 +92,19 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 				if( j < totalIter - 1 )
 					params = updateModel(params, gradients, i, j, totalIter);
 				ParamservUtils.cleanupListObject(_ec, gradients);
+                               
+                               accNumBatches(1);
                        }
 
                        // Push the gradients to ps
                        pushGradients(accGradients);
 			ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
 
+                       accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
 				LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
                        }
                }
-
        }
 
 	private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) {
@@ -112,8 +112,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
 		globalParams = _ps.updateLocalModel(_ec, gradients, globalParams);
 
-               if (DMLScript.STATISTICS)
-			Statistics.accPSLocalModelUpdateTime((long) tUpd.stop());
+               accLocalModelUpdateTime(tUpd);
                
                if (LOG.isDebugEnabled()) {
 			LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
@@ -133,9 +132,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
                                // Push the gradients to ps
                                pushGradients(gradients);
-
 				ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
+                               
+                               accNumBatches(1);
                        }
+                       
+                       accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
 				LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
                        }
@@ -169,8 +171,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null;
 		MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
 		MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
-               if (DMLScript.STATISTICS)
-                       Statistics.accPSBatchIndexingTime((long) tSlic.stop());
+               accBatchIndexingTime(tSlic);
 
                _ec.setVariable(Statement.PS_FEATURES, bFeatures);
                _ec.setVariable(Statement.PS_LABELS, bLabels);
@@ -185,8 +186,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                // Invoke the update function
                Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null;
                _inst.processInstruction(_ec);
-               if (DMLScript.STATISTICS)
-			Statistics.accPSGradientComputeTime((long) tGrad.stop());
+               accGradientComputeTime(tGrad);
 
                // Get the gradients
 		ListObject gradients = (ListObject) _ec.getVariable(_output.getName());
@@ -195,4 +195,28 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                ParamservUtils.cleanupData(_ec, bLabels);
                return gradients;
        }
+       
+       @Override
+       protected void incWorkerNumber() {
+               if (DMLScript.STATISTICS)
+                       Statistics.incWorkerNumber();
+       }
+
+       @Override
+       protected void accLocalModelUpdateTime(Timing time) {
+               if (DMLScript.STATISTICS)
+			Statistics.accPSLocalModelUpdateTime((long) time.stop());
+       }
+
+       @Override
+       protected void accBatchIndexingTime(Timing time) {
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSBatchIndexingTime((long) time.stop());
+       }
+
+       @Override
+       protected void accGradientComputeTime(Timing time) {
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSGradientComputeTime((long) time.stop());
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 4b5c5c1..5f2d552 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -32,11 +32,12 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 
-public abstract class PSWorker implements Serializable {
-
+public abstract class PSWorker implements Serializable 
+{
        private static final long serialVersionUID = -3510485051178200118L;
 
        protected int _workerID;
@@ -133,4 +134,23 @@ public abstract class PSWorker implements Serializable {
        }
 
        public abstract String getWorkerName();
+
+       /**
+        * ----- The following methods are dedicated to statistics -------------
+        */
+       protected abstract void incWorkerNumber();
+
+       protected abstract void accLocalModelUpdateTime(Timing time);
+
+       protected abstract void accBatchIndexingTime(Timing time);
+
+       protected abstract void accGradientComputeTime(Timing time);
+
+       protected void accNumEpochs(int n) {
+               //do nothing
+       }
+       
+       protected void accNumBatches(int n) {
+               //do nothing
+       }
 }

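The statistics plumbing above is a template-method design: PSWorker declares
the hooks, LocalPSWorker routes them to the global Statistics singleton, and
SparkPSWorker (further below) overrides the same hooks to write to Spark
accumulators; the no-op defaults keep the counters optional. A compact
standalone sketch of that shape (illustrative names, not the SystemML
classes):

    // Overridable statistics hooks, as in the PSWorker hierarchy.
    abstract class Worker {
        protected void accNumBatches(int n) { /* default: do nothing */ }
        final void processBatch() {
            // ... slice the batch and compute gradients here ...
            accNumBatches(1); // hook call, cheap when unused
        }
    }
    class RemoteWorker extends Worker {
        private long batches = 0; // stand-in for a Spark LongAccumulator
        @Override protected void accNumBatches(int n) { batches += n; }
    }
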
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index cf27457..9624c55 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -156,11 +156,16 @@ public class ParamservUtils {
        }
 
        public static MatrixObject newMatrixObject(MatrixBlock mb) {
+               return newMatrixObject(mb, true);
+       }
+       
+	public static MatrixObject newMatrixObject(MatrixBlock mb, boolean cleanup) {
 		MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(),
 			new MetaDataFormat(new MatrixCharacteristics(-1, -1, ConfigurationManager.getBlocksize(),
 			ConfigurationManager.getBlocksize()), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
                result.acquireModify(mb);
                result.release();
+               result.enableCleanup(cleanup);
                return result;
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
index de7b6c6..48a4883 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
@@ -22,7 +22,10 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark;
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
 
+import java.io.IOException;
+
 import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
@@ -30,25 +33,35 @@ import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.utils.Statistics;
 
 public class SparkPSProxy extends ParamServer {
 
-       private TransportClient _client;
+       private final TransportClient _client;
        private final long _rpcTimeout;
+       private final LongAccumulator _aRPC;
 
-       public SparkPSProxy(TransportClient client, long rpcTimeout) {
+	public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
                super();
                _client = client;
                _rpcTimeout = rpcTimeout;
+               _aRPC = aRPC;
+       }
+
+       private void accRpcRequestTime(Timing tRpc) {
+               if (DMLScript.STATISTICS)
+                       _aRPC.add((long) tRpc.stop());
        }
 
        @Override
        public void push(int workerID, ListObject value) {
                Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
-		PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
-		if (DMLScript.STATISTICS)
-			Statistics.accPSRpcRequestTime((long) tRpc.stop());
+		PSRpcResponse response;
+		try {
+			response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
+		} catch (IOException e) {
+			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e);
+		}
+		accRpcRequestTime(tRpc);
 		if (!response.isSuccessful()) {
 			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage()));
 		}
@@ -57,9 +70,13 @@ public class SparkPSProxy extends ParamServer {
        @Override
        public ListObject pull(int workerID) {
                Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
-		PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
-		if (DMLScript.STATISTICS)
-			Statistics.accPSRpcRequestTime((long) tRpc.stop());
+		PSRpcResponse response;
+		try {
+			response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
+		} catch (IOException e) {
+			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e);
+		}
+		accRpcRequestTime(tRpc);
 		if (!response.isSuccessful()) {
 			throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage()));
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index fa06243..59203ad 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -23,7 +23,9 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.codegen.CodegenUtils;
@@ -34,7 +36,6 @@ import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.ProgramConverter;
-import org.apache.sysml.utils.Statistics;
 
 import scala.Tuple2;
 
@@ -42,13 +43,21 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 
        private static final long serialVersionUID = -8674739573419648732L;
 
-       private String _program;
-       private HashMap<String, byte[]> _clsMap;
-       private String _host; // host ip of driver
-       private long _rpcTimeout; // rpc ask timeout
-       private String _aggFunc;
-
-	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) {
+       private final String _program;
+       private final HashMap<String, byte[]> _clsMap;
+       private final SparkConf _conf;
+       private final int _port; // rpc port
+       private final String _aggFunc;
+       private final LongAccumulator _aSetup; // accumulator for setup time
+       private final LongAccumulator _aWorker; // accumulator for worker number
+       private final LongAccumulator _aUpdate; // accumulator for model update
+       private final LongAccumulator _aIndex; // accumulator for batch indexing
+	private final LongAccumulator _aGrad; // accumulator for gradient computation
+       private final LongAccumulator _aRPC; // accumulator for rpc request
+       private final LongAccumulator _nBatches; //number of executed batches
+	private final LongAccumulator _nEpochs; //number of executed epochs
+       
+	public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) {
                _updFunc = updFunc;
                _aggFunc = aggFunc;
                _freq = freq;
@@ -56,21 +65,29 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
                _batchSize = batchSize;
                _program = program;
                _clsMap = clsMap;
-               _host = host;
-               _rpcTimeout = rpcTimeout;
+               _conf = conf;
+               _port = port;
+               _aSetup = aSetup;
+               _aWorker = aWorker;
+               _aUpdate = aUpdate;
+               _aIndex = aIndex;
+               _aGrad = aGrad;
+               _aRPC = aRPC;
+               _nBatches = aBatches;
+               _nEpochs = aEpochs;
        }
 
        @Override
        public String getWorkerName() {
                return String.format("Spark worker_%d", _workerID);
        }
-
+       
        @Override
 	public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
                Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
                configureWorker(input);
-               if (DMLScript.STATISTICS)
-                       Statistics.accPSSetupTime((long) tSetup.stop());
+               accSetupTime(tSetup);
+
                call(); // Launch the worker
        }
 
@@ -89,8 +106,14 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
 		// Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleaned up at the end
                RemoteParForUtils.setupBufferPool(_workerID);
 
+               // Get some configurations
+               long rpcTimeout = _conf.contains("spark.rpc.askTimeout") ?
+                       _conf.getTimeAsMs("spark.rpc.askTimeout") :
+                       _conf.getTimeAsMs("spark.network.timeout", "120s");
+               String host = _conf.get("spark.driver.host");
+
                // Create the ps proxy
-               _ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout);
+		_ps = PSRpcFactory.createSparkPSProxy(_conf, host, _port, rpcTimeout, _aRPC);
 
                // Initialize the update function
                setupUpdateFunction(_updFunc, _ec);
@@ -104,4 +127,46 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
                _features.enableCleanup(false);
                _labels.enableCleanup(false);
        }
+       
+
+       @Override
+       public void incWorkerNumber() {
+               if (DMLScript.STATISTICS)
+                       _aWorker.add(1);
+       }
+       
+       @Override
+       public void accLocalModelUpdateTime(Timing time) {
+               if (DMLScript.STATISTICS)
+                       _aUpdate.add((long) time.stop());
+       }
+
+       @Override
+       public void accBatchIndexingTime(Timing time) {
+               if (DMLScript.STATISTICS)
+                       _aIndex.add((long) time.stop());
+       }
+
+       @Override
+       public void accGradientComputeTime(Timing time) {
+               if (DMLScript.STATISTICS)
+                       _aGrad.add((long) time.stop());
+       }
+       
+       @Override
+       protected void accNumEpochs(int n) {
+               if (DMLScript.STATISTICS)
+                       _nEpochs.add(n);
+       }
+       
+       @Override
+       protected void accNumBatches(int n) {
+               if (DMLScript.STATISTICS)
+                       _nBatches.add(n);
+       }
+       
+       private void accSetupTime(Timing tSetup) {
+               if (DMLScript.STATISTICS)
+                       _aSetup.add((long) tSetup.stop());
+       }
 }

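The rpc timeout resolution in configureWorker above follows Spark's own
convention: spark.rpc.askTimeout if set, otherwise spark.network.timeout
with a 120s default. A standalone sketch of the same lookup (illustrative
values, local SparkConf):

    import org.apache.spark.SparkConf;

    public class RpcTimeoutSketch {
        public static void main(String[] args) {
            // askTimeout not set, so the network timeout is the fallback
            SparkConf conf = new SparkConf()
                .set("spark.network.timeout", "300s");
            long rpcTimeout = conf.contains("spark.rpc.askTimeout") ?
                conf.getTimeAsMs("spark.rpc.askTimeout") :
                conf.getTimeAsMs("spark.network.timeout", "120s");
            System.out.println(rpcTimeout); // 300000 ms
        }
    }
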
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
index 999d409..b8f482c 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
@@ -19,71 +19,34 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
-import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
-import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
-
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.StringTokenizer;
 
+import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.util.FastBufferedDataOutputStream;
 
 public class PSRpcCall extends PSRpcObject {
 
-	private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN;
-       private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END;
-
-       private String _method;
+       private int _method;
        private int _workerID;
        private ListObject _data;
 
-       public PSRpcCall(String method, int workerID, ListObject data) {
+       public PSRpcCall(int method, int workerID, ListObject data) {
                _method = method;
                _workerID = workerID;
                _data = data;
        }
 
-       public PSRpcCall(ByteBuffer buffer) {
+       public PSRpcCall(ByteBuffer buffer) throws IOException {
                deserialize(buffer);
        }
 
-       public void deserialize(ByteBuffer buffer) {
-		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
-               String input = bufferToString(buffer);
-               //header elimination
-		input = input.substring(PS_RPC_CALL_BEGIN.length(), input.length() - PS_RPC_CALL_END.length()); //remove start/end
-		StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
-
-               _method = st.nextToken();
-               _workerID = Integer.valueOf(st.nextToken());
-               String dataStr = st.nextToken();
-               _data = dataStr.equals(EMPTY) ? null :
-			(ListObject) ProgramConverter.parseDataObject(dataStr)[1];
-       }
-
-       public ByteBuffer serialize() {
-		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
-               StringBuilder sb = new StringBuilder();
-               sb.append(PS_RPC_CALL_BEGIN);
-               sb.append(_method);
-               sb.append(COMPONENTS_DELIM);
-               sb.append(_workerID);
-               sb.append(COMPONENTS_DELIM);
-               if (_data == null) {
-                       sb.append(EMPTY);
-               } else {
-                       flushListObject(_data);
-			sb.append(ProgramConverter.serializeDataObject(DATA_KEY, _data));
-               }
-               sb.append(PS_RPC_CALL_END);
-               return ByteBuffer.wrap(sb.toString().getBytes());
-       }
-
-       public String getMethod() {
+       public int getMethod() {
                return _method;
        }
 
@@ -94,4 +57,37 @@ public class PSRpcCall extends PSRpcObject {
        public ListObject getData() {
                return _data;
        }
+       
+       public void deserialize(ByteBuffer buffer) throws IOException {
+               DataInputStream dis = new DataInputStream(
+			new ByteArrayInputStream(IOUtilFunctions.getBytes(buffer)));
+               _method = dis.readInt();
+               validateMethod(_method);
+               _workerID = dis.readInt();
+               if (dis.available() > 1)
+                       _data = readAndDeserialize(dis);
+               dis.close();
+       }
+
+       public ByteBuffer serialize() throws IOException {
+		//TODO: Perf: use CacheDataOutput to avoid multiple copies (needs UTF handling)
+		ByteArrayOutputStream bos = new ByteArrayOutputStream(getApproxSerializedSize(_data));
+		FastBufferedDataOutputStream dos = new FastBufferedDataOutputStream(bos);
+               dos.writeInt(_method);
+               dos.writeInt(_workerID);
+               if (_data != null)
+                       serializeAndWriteListObject(_data, dos);
+               dos.flush();
+               return ByteBuffer.wrap(bos.toByteArray());
+       }
+       
+       private void validateMethod(int method) {
+               switch (method) {
+                       case PUSH:
+                       case PULL:
+                               break;
+                       default:
+				throw new DMLRuntimeException("PSRpcCall: only supports rpc methods 'push' or 'pull'");
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
index c8b4024..2d921de 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
@@ -22,36 +22,36 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 import java.io.IOException;
 import java.util.Collections;
 
+import org.apache.spark.SparkConf;
 import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.netty.SparkTransportConf;
 import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
 
-//TODO should be able to configure the port by users
 public class PSRpcFactory {
 
-       private static final int PORT = 5055;
        private static final String MODULE_NAME = "ps";
 
-	private static TransportContext createTransportContext(LocalParamServer ps) {
-		TransportConf conf = new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider());
+	private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) {
+		TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);
                PSRpcHandler handler = new PSRpcHandler(ps);
-               return new TransportContext(conf, handler);
+               return new TransportContext(tc, handler);
        }
 
        /**
         * Create and start the server
         * @return server
         */
-	public static TransportServer createServer(LocalParamServer ps, String host) {
-		TransportContext context = createTransportContext(ps);
-		return context.createServer(host, PORT, Collections.emptyList());
+	public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) {
+		TransportContext context = createTransportContext(conf, ps);
+		return context.createServer(host, 0, Collections.emptyList()); // bind rpc to an ephemeral port
        }
 
-	public static SparkPSProxy createSparkPSProxy(String host, long rpcTimeout) throws IOException {
-		TransportContext context = createTransportContext(new LocalParamServer());
-		return new SparkPSProxy(context.createClientFactory().createClient(host, PORT), rpcTimeout);
+	public static SparkPSProxy createSparkPSProxy(SparkConf conf, String host, int port, long rpcTimeout, LongAccumulator aRPC) throws IOException {
+		TransportContext context = createTransportContext(conf, new LocalParamServer());
+		return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC);
        }
 }

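Binding the transport server to port 0 delegates port selection to the OS,
which avoids the clashes of the previously hardcoded port 5055; the driver
then reads the actual port back (server.getPort() in the instruction below)
and ships it to the workers. The same idiom with a plain JDK socket:

    import java.io.IOException;
    import java.net.ServerSocket;

    public class EphemeralPortSketch {
        public static void main(String[] args) throws IOException {
            try (ServerSocket server = new ServerSocket(0)) { // 0 = ephemeral
                int port = server.getLocalPort(); // the port to advertise
                System.out.println("bound to port " + port);
            }
        }
    }
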
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
index 3d73a37..a2c311e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
@@ -21,10 +21,8 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
 import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS;
 
+import java.io.IOException;
 import java.nio.ByteBuffer;
 
 import org.apache.commons.lang.exception.ExceptionUtils;
@@ -35,6 +33,7 @@ import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.Type;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 
 public final class PSRpcHandler extends RpcHandler {
@@ -47,28 +46,41 @@ public final class PSRpcHandler extends RpcHandler {
 
        @Override
 	public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
-               PSRpcCall call = new PSRpcCall(buffer);
+               PSRpcCall call;
+               try {
+                       call = new PSRpcCall(buffer);
+               } catch (IOException e) {
+			throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e);
+               }
                PSRpcResponse response = null;
                switch (call.getMethod()) {
                        case PUSH:
                                try {
 					_server.push(call.getWorkerID(), call.getData());
-					response = new PSRpcResponse(SUCCESS, EMPTY_DATA);
+					response = new PSRpcResponse(Type.SUCCESS_EMPTY);
 				} catch (DMLRuntimeException exception) {
-					response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+					response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
 				} finally {
-					callback.onSuccess(response.serialize());
+					try {
+						callback.onSuccess(response.serialize());
+					} catch (IOException e) {
+						throw new DMLRuntimeException("PSRpcHandler: some error occurred when wrapping the rpc response.", e);
+					}
                                }
                                break;
                        case PULL:
                                ListObject data;
                                try {
                                        data = _server.pull(call.getWorkerID());
-					response = new PSRpcResponse(SUCCESS, data);
+					response = new PSRpcResponse(Type.SUCCESS, data);
 				} catch (DMLRuntimeException exception) {
-					response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+					response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
 				} finally {
-					callback.onSuccess(response.serialize());
+					try {
+						callback.onSuccess(response.serialize());
+					} catch (IOException e) {
+						throw new DMLRuntimeException("PSRpcHandler: some error occurred when wrapping the rpc response.", e);
+					}
+                                       }
                                }
                                break;
                        default:

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
index c6d7fd3..7d3353f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
@@ -19,39 +19,86 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
 
-import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 
 public abstract class PSRpcObject {
 
-       public static final String PUSH = "push";
-       public static final String PULL = "pull";
-       public static final String DATA_KEY = "data";
-       public static final String EMPTY_DATA = "";
+       public static final int PUSH = 1;
+       public static final int PULL = 2;
 
-       public abstract void deserialize(ByteBuffer buffer);
+       public abstract void deserialize(ByteBuffer buffer) throws IOException;
 
-       public abstract ByteBuffer serialize();
+       public abstract ByteBuffer serialize() throws IOException;
 
        /**
-        * Convert direct byte buffer to string
-        * @param buffer direct byte buffer
-        * @return string
+	 * Deep serialize and write of a list object (currently supports only lists of matrices)
+        * @param lo a list object containing only matrices
+        * @param dos output data to write to
         */
-       protected String bufferToString(ByteBuffer buffer) {
-               byte[] result = new byte[buffer.limit()];
-               buffer.get(result, 0, buffer.limit());
-               return new String(result);
+	protected void serializeAndWriteListObject(ListObject lo, DataOutput dos) throws IOException {
+               validateListObject(lo);
+               dos.writeInt(lo.getLength()); //write list length
+               dos.writeBoolean(lo.isNamedList()); //write list named
+               for (int i = 0; i < lo.getLength(); i++) {
+                       if (lo.isNamedList())
+                               dos.writeUTF(lo.getName(i)); //write name
+                       ((MatrixObject) lo.getData().get(i))
+			.acquireReadAndRelease().write(dos); //write matrix
+               }
+       }
+       
+	protected ListObject readAndDeserialize(DataInput dis) throws IOException {
+               int listLen = dis.readInt();
+               List<Data> data = new ArrayList<>();
+               List<String> names = dis.readBoolean() ?
+                       new ArrayList<>() : null;
+               for(int i=0; i<listLen; i++) {
+                       if( names != null )
+                               names.add(dis.readUTF());
+                       MatrixBlock mb = new MatrixBlock();
+                       mb.readFields(dis);
+                       data.add(ParamservUtils.newMatrixObject(mb, false));
+               }
+               return new ListObject(data, names);
        }
 
        /**
-        * Flush the data into HDFS
-        * @param data list object
+        * Get serialization size of a list object
+        * (scheme: size|name|size|matrix)
+        * @param lo list object
+        * @return serialization size
         */
-       protected void flushListObject(ListObject data) {
-               data.getData().stream().filter(d -> d instanceof CacheableData)
-                       .forEach(d -> ((CacheableData<?>) d).exportData());
+       protected int getApproxSerializedSize(ListObject lo) {
+               if( lo == null ) return 0;
+		long result = 4 + 1; // list length and named flag
+		result += lo.getLength() * (Integer.BYTES); // bytes for the size of names
+               if (lo.isNamedList())
+			result += lo.getNames().stream().mapToLong(s -> s.length()).sum();
+               result += lo.getData().stream().mapToLong(d ->
+			((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum();
+               if( result > Integer.MAX_VALUE )
+			throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE.");
+               return (int) result;
+       }
+
+       private void validateListObject(ListObject lo) {
+               for (Data d : lo.getData()) {
+                       if (!(d instanceof MatrixObject)) {
+				throw new DMLRuntimeException(String.format("Paramserv func:"
+					+ " Unsupported deep serialize of %s, which is not a matrix.", d.getDebugName()));
+                       }
+               }
        }
 }

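The binary scheme above replaces the old string-based ProgramConverter round
trip: a list is framed as length | named flag | (name | matrix)*, with each
matrix written via its own binary serialization. A simplified, self-contained
round trip over the same framing (strings stand in for matrix blocks):

    import java.io.*;
    import java.util.*;

    public class ListFramingSketch {
        public static void main(String[] args) throws IOException {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            DataOutputStream dos = new DataOutputStream(bos);
            List<String> names = Arrays.asList("W", "b");
            dos.writeInt(names.size());    // list length
            dos.writeBoolean(true);        // named list
            for (String n : names) {
                dos.writeUTF(n);           // entry name
                dos.writeUTF("blob-" + n); // stand-in for MatrixBlock.write(dos)
            }
            dos.flush();
            DataInputStream dis = new DataInputStream(
                new ByteArrayInputStream(bos.toByteArray()));
            int len = dis.readInt();
            boolean named = dis.readBoolean();
            for (int i = 0; i < len; i++)
                System.out.println((named ? dis.readUTF() : i) + " -> " + dis.readUTF());
        }
    }
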
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
index 998c523..3517491 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
@@ -19,41 +19,43 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
 
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
-import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
-import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
-import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
-
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.StringTokenizer;
 
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.util.FastBufferedDataOutputStream;
 
 public class PSRpcResponse extends PSRpcObject {
+       public enum Type  {
+               SUCCESS,
+               SUCCESS_EMPTY,
+               ERROR,
+       }
+       
+       private Type _status;
+       private Object _data; // Could be list object or exception
 
-       public static final int SUCCESS = 1;
-       public static final int ERROR = 2;
-
-	private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN;
-       private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END;
-
-       private int _status;
-       private Object _data;   // Could be list object or exception
-
-       public PSRpcResponse(ByteBuffer buffer) {
+       public PSRpcResponse(ByteBuffer buffer) throws IOException {
                deserialize(buffer);
        }
 
-       public PSRpcResponse(int status, Object data) {
+       public PSRpcResponse(Type status) {
+               this(status, null);
+       }
+       
+       public PSRpcResponse(Type status, Object data) {
                _status = status;
                _data = data;
+               if( _status == Type.SUCCESS && data == null )
+                       _status = Type.SUCCESS_EMPTY;
        }
 
        public boolean isSuccessful() {
-               return _status == SUCCESS;
+               return _status != Type.ERROR;
        }
 
        public String getErrorMessage() {
@@ -65,48 +67,42 @@ public class PSRpcResponse extends PSRpcObject {
        }
 
        @Override
-       public void deserialize(ByteBuffer buffer) {
-		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
-		String input = bufferToString(buffer);
-		//header elimination
-		input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end
-		StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
-
-               _status = Integer.valueOf(st.nextToken());
-               String data = st.nextToken();
+       public void deserialize(ByteBuffer buffer) throws IOException {
+               DataInputStream dis = new DataInputStream(
+			new ByteArrayInputStream(IOUtilFunctions.getBytes(buffer)));
+               _status = Type.values()[dis.readInt()];
                switch (_status) {
                        case SUCCESS:
-                               _data = data.equals(EMPTY) ? null :
-					ProgramConverter.parseDataObject(data)[1];
+                               _data = readAndDeserialize(dis);
+                               break;
+                       case SUCCESS_EMPTY:
                                break;
                        case ERROR:
-                               _data = data;
+                               _data = dis.readUTF();
                                break;
                }
+               dis.close();
        }
 
        @Override
-       public ByteBuffer serialize() {
-		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
-               
-               StringBuilder sb = new StringBuilder();
-               sb.append(PS_RPC_RESPONSE_BEGIN);
-               sb.append(_status);
-               sb.append(COMPONENTS_DELIM);
+       public ByteBuffer serialize() throws IOException {
+		//TODO: Perf: use CacheDataOutput to avoid multiple copies (needs UTF handling)
+		int len = 4 + (_status==Type.SUCCESS ? getApproxSerializedSize((ListObject)_data) :
+			_status==Type.SUCCESS_EMPTY ? 0 : ((String)_data).length());
+		ByteArrayOutputStream bos = new ByteArrayOutputStream(len);
+		FastBufferedDataOutputStream dos = new FastBufferedDataOutputStream(bos);
+               dos.writeInt(_status.ordinal());
                switch (_status) {
                        case SUCCESS:
-                               if (_data.equals(EMPTY_DATA)) {
-                                       sb.append(EMPTY);
-                               } else {
-                                       flushListObject((ListObject) _data);
-					sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data));
-                               }
+				serializeAndWriteListObject((ListObject) _data, dos);
+                               break;
+                       case SUCCESS_EMPTY:
                                break;
                        case ERROR:
-                               sb.append(_data.toString());
+                               dos.writeUTF(_data.toString());
                                break;
                }
-               sb.append(PS_RPC_RESPONSE_END);
-               return ByteBuffer.wrap(sb.toString().getBytes());
+               dos.flush();
+               return ByteBuffer.wrap(bos.toByteArray());
        }
 }

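The response status is now encoded as the enum's ordinal
(dos.writeInt(_status.ordinal()) on write, Type.values()[dis.readInt()] on
read), which is compact but ties the wire format to declaration order, so
new constants should only be appended. A minimal round trip:

    public class EnumWireSketch {
        enum Type { SUCCESS, SUCCESS_EMPTY, ERROR }
        public static void main(String[] args) {
            int onWire = Type.ERROR.ordinal();    // 2 is written
            Type decoded = Type.values()[onWire]; // ERROR is read back
            System.out.println(onWire + " -> " + decoded);
        }
    }
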
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 6133987..fe238bd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -56,6 +56,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.recompile.Recompiler;
 import org.apache.sysml.lops.LopProperties;
@@ -125,11 +126,25 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
                // Get the compiled execution context
                LocalVariableMap newVarsMap = createVarsMap(sec);
-		ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // level of par is 1 in spark backend
+		// Level of par is 1 in spark backend because one worker will be launched per task
+		ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);
 
 		MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
                MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
 
+               // Create the agg service's execution context
+		ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+
+		// Create the parameter server
+		ListObject model = sec.getListObject(getParam(PS_MODEL));
+		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+
+		// Get driver host
+		String host = sec.getSparkContext().getConf().get("spark.driver.host");
+
+		// Create the netty server for ps
+		TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps, host); // Start the server
+
                // Force all the instructions to CP type
                Recompiler.recompileProgramBlockHierarchy2Forced(
 			newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP);
@@ -139,29 +154,24 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                HashMap<String, byte[]> clsMap = new HashMap<>();
 		String program = ProgramConverter.serializeSparkPSBody(body, clsMap);
 
-               // Get some configurations
-		String host = sec.getSparkContext().getConf().get("spark.driver.host");
-		long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ? 
-			sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") :
-			sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s");
+		// Add the accumulators for statistics
+		LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup");
+		LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum");
+		LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate");
+		LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex");
+		LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute");
+		LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
+		LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
+		LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
 
                // Create remote workers
-		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(),
-			getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout);
-
-		// Create the agg service's execution context
-		ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
-
-		// Create the parameter server
-		ListObject model = sec.getListObject(getParam(PS_MODEL));
-		ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+		SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), 
+			getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
+			server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);
 
                if (DMLScript.STATISTICS)
                        Statistics.accPSSetupTime((long) tSetup.stop());
 
-		TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server
PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server
-
                try {
 			ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
                                .foreach(worker); // Run remote workers
@@ -172,6 +182,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                        server.close();
                }
 
+               // Accumulate the statistics for remote workers
+               if (DMLScript.STATISTICS) {
+                       Statistics.accPSSetupTime(aSetup.sum());
+                       Statistics.incWorkerNumber(aWorker.sum());
+                       Statistics.accPSLocalModelUpdateTime(aUpdate.sum());
+                       Statistics.accPSBatchIndexingTime(aIndex.sum());
+                       Statistics.accPSGradientComputeTime(aGrad.sum());
+                       Statistics.accPSRpcRequestTime(aRPC.sum());
+               }
+
                // Fetch the final model from ps
                sec.setVariable(output.getName(), ps.getResult());
        }

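Worker-side statistics now flow back to the driver through Spark
LongAccumulators: each remote worker calls add() locally, and the driver
reads sum() once the foreach job has finished, as in the block above. A
generic standalone sketch of that pattern (local master, names illustrative):

    import java.util.Arrays;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.util.LongAccumulator;

    public class AccumulatorSketch {
        public static void main(String[] args) {
            JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setMaster("local[2]").setAppName("acc"));
            LongAccumulator nBatches = sc.sc().longAccumulator("numBatches");
            sc.parallelize(Arrays.asList(1, 2, 3, 4))
              .foreach(x -> nBatches.add(1));   // executor side
            System.out.println(nBatches.sum()); // driver side: 4
            sc.stop();
        }
    }
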
http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
index 94941a1..18f0e54 100644
--- a/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
+++ b/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java
@@ -26,6 +26,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.io.StringReader;
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -608,4 +609,13 @@ public class IOUtilFunctions
                ba[ off+6 ] = (byte)((val >>>  8) & 0xFF);
                ba[ off+7 ] = (byte)((val >>>  0) & 0xFF);
        }
+       
+       public static byte[] getBytes(ByteBuffer buff) {
+               int len = buff.limit();
+               if( buff.hasArray() )
+                       return Arrays.copyOf(buff.array(), len);
+               byte[] ret = new byte[len];
+               buff.get(ret, buff.position(), len);
+               return ret;
+       }
 }
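
The new helper copies a ByteBuffer into a fresh byte[], taking the cheap Arrays.copyOf path for heap buffers that expose a backing array and falling back to a bulk get() for direct buffers. A small usage sketch, under the assumption that the buffer is positioned at zero with the limit marking the payload end, as after flip(), which is the state in which the RPC code passes buffers around:

    import java.nio.ByteBuffer;
    import java.util.Arrays;

    import org.apache.sysml.runtime.io.IOUtilFunctions;

    public class GetBytesSketch {
        public static void main(String[] args) {
            byte[] payload = {1, 2, 3, 4};

            // Heap buffer: hasArray() is true, so the backing array is copied.
            byte[] a = IOUtilFunctions.getBytes(ByteBuffer.wrap(payload));

            // Direct buffer: no backing array, so the bulk-get branch runs.
            ByteBuffer direct = ByteBuffer.allocateDirect(payload.length);
            direct.put(payload);
            direct.flip(); // position=0, limit=4, the precondition noted above
            byte[] b = IOUtilFunctions.getBytes(direct);

            System.out.println(Arrays.equals(a, b)); // prints true
        }
    }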

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index 1dd8362..44667ba 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -541,6 +541,10 @@ public class Statistics
                psNumWorkers.increment();
        }
 
+       public static void incWorkerNumber(long n) {
+               psNumWorkers.add(n);
+       }
+
        public static void accPSSetupTime(long t) {
                psSetupTime.add(t);
        }
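
The bulk variant complements the per-worker increment directly above it, which, judging from its body, is the no-arg incWorkerNumber() used when workers register one at a time; the Spark path instead merges an accumulator sum in a single call. A hedged sketch of the two call sites, with aWorkerSum standing in for the aWorker.sum() of the instruction change:

    import org.apache.sysml.utils.Statistics;

    public class WorkerCountSketch {
        public static void main(String[] args) {
            // Local paramserv: each worker registers itself on startup
            // (assuming the no-arg variant shown at the top of the hunk).
            for (int i = 0; i < 3; i++)
                Statistics.incWorkerNumber();
            // Remote paramserv: the driver folds in the executor-side count
            // gathered by the "workersNum" accumulator in one call.
            long aWorkerSum = 4; // stand-in for aWorker.sum()
            Statistics.incWorkerNumber(aWorkerSum);
        }
    }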

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
index 30eccb3..89235d7 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
@@ -1,15 +1,8 @@
 package org.apache.sysml.test.integration.functions.paramserv;
 
-import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.mlcontext.MLContext;
-import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.junit.Test;
@@ -42,12 +35,12 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 
        @Test
        public void testParamservBSPEpochDisjointContiguous() {
-               runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+               runDMLTest(5, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
        @Test
        public void testParamservASPEpochDisjointContiguous() {
-               runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+               runDMLTest(5, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
        }
 
        @Test
@@ -62,9 +55,14 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
 
        private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) {
                programArgs = new String[] { "-explain" };
+               internalRunDMLTest(testname, exceptionExpected, expectedException, errMessage);
+       }
+
+       private void internalRunDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException,
+                       String errMessage) {
                DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform;
                boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK;
+               AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.HYBRID_SPARK;
                DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 
                try {
@@ -80,22 +78,7 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
        }
 
        private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
-               Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString())
-                       .in("$epochs", String.valueOf(epochs))
-                       .in("$workers", String.valueOf(workers))
-                       .in("$utype", utype.toString())
-                       .in("$freq", freq.toString())
-                       .in("$batchsize", String.valueOf(batchsize))
-                       .in("$scheme", scheme.toString());
-
-               SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]")
-                       .set("spark.driver.allowMultipleContexts", "true");
-               JavaSparkContext sc = new JavaSparkContext(conf);
-               MLContext ml = new MLContext(sc);
-               ml.setStatistics(true);
-               ml.execute(script);
-               ml.resetConfig();
-               sc.stop();
-               ml.close();
+               programArgs = new String[] { "-explain", "-nvargs", "mode=REMOTE_SPARK", "epochs=" + epochs, "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, "scheme=" + scheme};
+               internalRunDMLTest(TEST_NAME1, false, null, null);
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
index 57e1106..f2df1e6 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -19,13 +19,13 @@
 
 package org.apache.sysml.test.integration.functions.paramserv;
 
+import java.io.IOException;
 import java.util.Arrays;
 
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
 import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
-import org.apache.sysml.runtime.instructions.cp.IntObject;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.junit.Assert;
 import org.junit.Test;
@@ -33,24 +33,26 @@ import org.junit.Test;
 public class RpcObjectTest {
 
        @Test
-       public void testPSRpcCall() {
+       public void testPSRpcCall() throws IOException {
                MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
                MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
-               IntObject io = new IntObject(30);
-               ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+               ListObject lo = new ListObject(Arrays.asList(mo1, mo2));
                PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo);
                PSRpcCall actual = new PSRpcCall(expected.serialize());
-               Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+               Assert.assertTrue(Arrays.equals(
+                       expected.serialize().array(),
+                       actual.serialize().array()));
        }
 
        @Test
-       public void testPSRpcResponse() {
+       public void testPSRpcResponse() throws IOException {
                MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
                MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
-               IntObject io = new IntObject(30);
-               ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
-               PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo);
+               ListObject lo = new ListObject(Arrays.asList(mo1, mo2));
+               PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.Type.SUCCESS, lo);
                PSRpcResponse actual = new PSRpcResponse(expected.serialize());
-               Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+               Assert.assertTrue(Arrays.equals(
+                       expected.serialize().array(),
+                       actual.serialize().array()));
        }
 }
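
The assertion change above is not just cosmetic: round-tripping the serialized bytes through new String() decodes them with the platform charset, which is lossy for arbitrary binary data, so two different payloads can collide on the same string, while Arrays.equals compares the bytes exactly. A tiny standalone illustration, not taken from this commit:

    import java.util.Arrays;

    public class ByteCompareSketch {
        public static void main(String[] args) {
            // Two distinct byte sequences that are both invalid UTF-8.
            byte[] x = {(byte) 0xC3, (byte) 0x28};
            byte[] y = {(byte) 0xE2, (byte) 0x28};
            // Under a UTF-8 default charset both can decode to "\uFFFD(",
            // so the string comparison may report equality.
            System.out.println(new String(x).equals(new String(y)));
            // The byte-wise comparison is exact: prints false.
            System.out.println(Arrays.equals(x, y));
        }
    }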

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
index 4d0f32e..d1edc29 100644
--- a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-e1 = "element1"
+e1 = matrix(1, rows=100, cols=10)
 modelList = list(e1)
 X = matrix(1, rows=200, cols=30)
 Y = matrix(2, rows=200, cols=1)
@@ -42,11 +42,9 @@ aggregation = function(list[unknown] model,
   print(toString(as.matrix(gradients["agg_service_err"])))
 }
 
-e2 = "element2"
+e2 = matrix(2, rows=100, cols=10)
 params = list(e2)
 
-modelList = list("model")
-
 # Use paramserv function
 modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
index ad16122..bf0de68 100644
--- a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-e1 = "element1"
+e1 = matrix(1, rows=100, cols=10)
 modelList = list(e1)
 X = matrix(1, rows=200, cols=30)
 Y = matrix(2, rows=200, cols=1)
@@ -42,11 +42,9 @@ aggregation = function(list[unknown] model,
   modelResult = model
 }
 
-e2 = "element2"
+e2 = matrix(2, rows=100, cols=10)
 params = list(e2)
 
-modelList = list("model")
-
 # Use paramserv function
 modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
 
