Repository: systemml Updated Branches: refs/heads/master bfd495289 -> e0c271fe4
[SYSTEMML-2420,2457] Improved distributed paramserv comm and stats Closes #808. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e0c271fe Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e0c271fe Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e0c271fe Branch: refs/heads/master Commit: e0c271fe4cb05d3a53eb0143ab298e26831d1ed7 Parents: bfd4952 Author: EdgarLGB <guobao...@atos.net> Authored: Thu Jul 26 23:33:12 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu Jul 26 23:33:13 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/caching/CacheableData.java | 6 ++ .../controlprogram/paramserv/LocalPSWorker.java | 48 +++++++--- .../controlprogram/paramserv/PSWorker.java | 24 ++++- .../paramserv/ParamservUtils.java | 5 ++ .../paramserv/spark/SparkPSProxy.java | 35 ++++++-- .../paramserv/spark/SparkPSWorker.java | 93 ++++++++++++++++--- .../paramserv/spark/rpc/PSRpcCall.java | 92 +++++++++---------- .../paramserv/spark/rpc/PSRpcFactory.java | 24 ++--- .../paramserv/spark/rpc/PSRpcHandler.java | 32 ++++--- .../paramserv/spark/rpc/PSRpcObject.java | 85 ++++++++++++++---- .../paramserv/spark/rpc/PSRpcResponse.java | 94 ++++++++++---------- .../cp/ParamservBuiltinCPInstruction.java | 56 ++++++++---- .../sysml/runtime/io/IOUtilFunctions.java | 10 +++ .../java/org/apache/sysml/utils/Statistics.java | 4 + .../paramserv/ParamservSparkNNTest.java | 37 +++----- .../functions/paramserv/RpcObjectTest.java | 22 ++--- .../paramserv-spark-agg-service-failed.dml | 6 +- .../paramserv/paramserv-spark-worker-failed.dml | 6 +- 18 files changed, 441 insertions(+), 238 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java index f524251..0265c33 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java @@ -364,6 +364,12 @@ public abstract class CacheableData<T extends CacheBlock> extends Data // *** *** // ********************************************* + public T acquireReadAndRelease() { + T tmp = acquireRead(); + release(); + return tmp; + } + /** * Acquires a shared "read-only" lock, produces the reference to the cache block, * restores the cache block to main memory, reads from HDFS if needed. http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index c23943d..b8a416f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -48,12 +48,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { public String getWorkerName() { return String.format("Local worker_%d", _workerID); } - + @Override public Void call() throws Exception { - if (DMLScript.STATISTICS) - Statistics.incWorkerNumber(); - + incWorkerNumber(); try { long dataSize = _features.getNumRows(); int totalIter = (int) Math.ceil((double) dataSize / _batchSize); @@ -94,17 +92,19 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { if( j < totalIter - 1 ) params = updateModel(params, gradients, i, j, totalIter); ParamservUtils.cleanupListObject(_ec, gradients); + + accNumBatches(1); } // Push the gradients to ps pushGradients(accGradients); ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); + accNumEpochs(1); if (LOG.isDebugEnabled()) { LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1)); } } - } private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) { @@ -112,8 +112,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { globalParams = _ps.updateLocalModel(_ec, gradients, globalParams); - if (DMLScript.STATISTICS) - Statistics.accPSLocalModelUpdateTime((long) tUpd.stop()); + accLocalModelUpdateTime(tUpd); if (LOG.isDebugEnabled()) { LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. " @@ -133,9 +132,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { // Push the gradients to ps pushGradients(gradients); - ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); + + accNumBatches(1); } + + accNumEpochs(1); if (LOG.isDebugEnabled()) { LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1)); } @@ -169,8 +171,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null; MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); - if (DMLScript.STATISTICS) - Statistics.accPSBatchIndexingTime((long) tSlic.stop()); + accBatchIndexingTime(tSlic); _ec.setVariable(Statement.PS_FEATURES, bFeatures); _ec.setVariable(Statement.PS_LABELS, bLabels); @@ -185,8 +186,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { // Invoke the update function Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null; _inst.processInstruction(_ec); - if (DMLScript.STATISTICS) - Statistics.accPSGradientComputeTime((long) tGrad.stop()); + accGradientComputeTime(tGrad); // Get the gradients ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); @@ -195,4 +195,28 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { ParamservUtils.cleanupData(_ec, bLabels); return gradients; } + + @Override + protected void incWorkerNumber() { + if (DMLScript.STATISTICS) + Statistics.incWorkerNumber(); + } + + @Override + protected void accLocalModelUpdateTime(Timing time) { + if (DMLScript.STATISTICS) + Statistics.accPSLocalModelUpdateTime((long) time.stop()); + } + + @Override + protected void accBatchIndexingTime(Timing time) { + if (DMLScript.STATISTICS) + Statistics.accPSBatchIndexingTime((long) time.stop()); + } + + @Override + protected void accGradientComputeTime(Timing time) { + if (DMLScript.STATISTICS) + Statistics.accPSGradientComputeTime((long) time.stop()); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index 4b5c5c1..5f2d552 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -32,11 +32,12 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; -public abstract class PSWorker implements Serializable { - +public abstract class PSWorker implements Serializable +{ private static final long serialVersionUID = -3510485051178200118L; protected int _workerID; @@ -133,4 +134,23 @@ public abstract class PSWorker implements Serializable { } public abstract String getWorkerName(); + + /** + * ----- The following methods are dedicated to statistics ------------- + */ + protected abstract void incWorkerNumber(); + + protected abstract void accLocalModelUpdateTime(Timing time); + + protected abstract void accBatchIndexingTime(Timing time); + + protected abstract void accGradientComputeTime(Timing time); + + protected void accNumEpochs(int n) { + //do nothing + } + + protected void accNumBatches(int n) { + //do nothing + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index cf27457..9624c55 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -156,11 +156,16 @@ public class ParamservUtils { } public static MatrixObject newMatrixObject(MatrixBlock mb) { + return newMatrixObject(mb, true); + } + + public static MatrixObject newMatrixObject(MatrixBlock mb, boolean cleanup) { MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(new MatrixCharacteristics(-1, -1, ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); result.acquireModify(mb); result.release(); + result.enableCleanup(cleanup); return result; } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java index de7b6c6..48a4883 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java @@ -22,7 +22,10 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark; import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL; import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH; +import java.io.IOException; + import org.apache.spark.network.client.TransportClient; +import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; @@ -30,25 +33,35 @@ import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall; import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.utils.Statistics; public class SparkPSProxy extends ParamServer { - private TransportClient _client; + private final TransportClient _client; private final long _rpcTimeout; + private final LongAccumulator _aRPC; - public SparkPSProxy(TransportClient client, long rpcTimeout) { + public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) { super(); _client = client; _rpcTimeout = rpcTimeout; + _aRPC = aRPC; + } + + private void accRpcRequestTime(Timing tRpc) { + if (DMLScript.STATISTICS) + _aRPC.add((long) tRpc.stop()); } @Override public void push(int workerID, ListObject value) { Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; - PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout)); - if (DMLScript.STATISTICS) - Statistics.accPSRpcRequestTime((long) tRpc.stop()); + PSRpcResponse response; + try { + response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout)); + } catch (IOException e) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e); + } + accRpcRequestTime(tRpc); if (!response.isSuccessful()) { throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage())); } @@ -57,9 +70,13 @@ public class SparkPSProxy extends ParamServer { @Override public ListObject pull(int workerID) { Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; - PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout)); - if (DMLScript.STATISTICS) - Statistics.accPSRpcRequestTime((long) tRpc.stop()); + PSRpcResponse response; + try { + response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout)); + } catch (IOException e) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e); + } + accRpcRequestTime(tRpc); if (!response.isSuccessful()) { throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage())); } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java index fa06243..59203ad 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java @@ -23,7 +23,9 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.codegen.CodegenUtils; @@ -34,7 +36,6 @@ import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.ProgramConverter; -import org.apache.sysml.utils.Statistics; import scala.Tuple2; @@ -42,13 +43,21 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< private static final long serialVersionUID = -8674739573419648732L; - private String _program; - private HashMap<String, byte[]> _clsMap; - private String _host; // host ip of driver - private long _rpcTimeout; // rpc ask timeout - private String _aggFunc; - - public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) { + private final String _program; + private final HashMap<String, byte[]> _clsMap; + private final SparkConf _conf; + private final int _port; // rpc port + private final String _aggFunc; + private final LongAccumulator _aSetup; // accumulator for setup time + private final LongAccumulator _aWorker; // accumulator for worker number + private final LongAccumulator _aUpdate; // accumulator for model update + private final LongAccumulator _aIndex; // accumulator for batch indexing + private final LongAccumulator _aGrad; // accumulator for gradients computing + private final LongAccumulator _aRPC; // accumulator for rpc request + private final LongAccumulator _nBatches; //number of executed batches + private final LongAccumulator _nEpochs; //number of executed epoches + + public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) { _updFunc = updFunc; _aggFunc = aggFunc; _freq = freq; @@ -56,21 +65,29 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< _batchSize = batchSize; _program = program; _clsMap = clsMap; - _host = host; - _rpcTimeout = rpcTimeout; + _conf = conf; + _port = port; + _aSetup = aSetup; + _aWorker = aWorker; + _aUpdate = aUpdate; + _aIndex = aIndex; + _aGrad = aGrad; + _aRPC = aRPC; + _nBatches = aBatches; + _nEpochs = aEpochs; } @Override public String getWorkerName() { return String.format("Spark worker_%d", _workerID); } - + @Override public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception { Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null; configureWorker(input); - if (DMLScript.STATISTICS) - Statistics.accPSSetupTime((long) tSetup.stop()); + accSetupTime(tSetup); + call(); // Launch the worker } @@ -89,8 +106,14 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end RemoteParForUtils.setupBufferPool(_workerID); + // Get some configurations + long rpcTimeout = _conf.contains("spark.rpc.askTimeout") ? + _conf.getTimeAsMs("spark.rpc.askTimeout") : + _conf.getTimeAsMs("spark.network.timeout", "120s"); + String host = _conf.get("spark.driver.host"); + // Create the ps proxy - _ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout); + _ps = PSRpcFactory.createSparkPSProxy(_conf, host, _port, rpcTimeout, _aRPC); // Initialize the update function setupUpdateFunction(_updFunc, _ec); @@ -104,4 +127,46 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< _features.enableCleanup(false); _labels.enableCleanup(false); } + + + @Override + public void incWorkerNumber() { + if (DMLScript.STATISTICS) + _aWorker.add(1); + } + + @Override + public void accLocalModelUpdateTime(Timing time) { + if (DMLScript.STATISTICS) + _aUpdate.add((long) time.stop()); + } + + @Override + public void accBatchIndexingTime(Timing time) { + if (DMLScript.STATISTICS) + _aIndex.add((long) time.stop()); + } + + @Override + public void accGradientComputeTime(Timing time) { + if (DMLScript.STATISTICS) + _aGrad.add((long) time.stop()); + } + + @Override + protected void accNumEpochs(int n) { + if (DMLScript.STATISTICS) + _nEpochs.add(n); + } + + @Override + protected void accNumBatches(int n) { + if (DMLScript.STATISTICS) + _nBatches.add(n); + } + + private void accSetupTime(Timing tSetup) { + if (DMLScript.STATISTICS) + _aSetup.add((long) tSetup.stop()); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java index 999d409..b8f482c 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java @@ -19,71 +19,34 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; -import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN; -import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END; -import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM; -import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY; -import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN; -import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT; - +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.IOException; import java.nio.ByteBuffer; -import java.util.StringTokenizer; +import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.runtime.util.ProgramConverter; +import org.apache.sysml.runtime.io.IOUtilFunctions; +import org.apache.sysml.runtime.util.FastBufferedDataOutputStream; public class PSRpcCall extends PSRpcObject { - private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN; - private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END; - - private String _method; + private int _method; private int _workerID; private ListObject _data; - public PSRpcCall(String method, int workerID, ListObject data) { + public PSRpcCall(int method, int workerID, ListObject data) { _method = method; _workerID = workerID; _data = data; } - public PSRpcCall(ByteBuffer buffer) { + public PSRpcCall(ByteBuffer buffer) throws IOException { deserialize(buffer); } - public void deserialize(ByteBuffer buffer) { - //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks. - String input = bufferToString(buffer); - //header elimination - input = input.substring(PS_RPC_CALL_BEGIN.length(), input.length() - PS_RPC_CALL_END.length()); //remove start/end - StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM); - - _method = st.nextToken(); - _workerID = Integer.valueOf(st.nextToken()); - String dataStr = st.nextToken(); - _data = dataStr.equals(EMPTY) ? null : - (ListObject) ProgramConverter.parseDataObject(dataStr)[1]; - } - - public ByteBuffer serialize() { - //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks. - StringBuilder sb = new StringBuilder(); - sb.append(PS_RPC_CALL_BEGIN); - sb.append(_method); - sb.append(COMPONENTS_DELIM); - sb.append(_workerID); - sb.append(COMPONENTS_DELIM); - if (_data == null) { - sb.append(EMPTY); - } else { - flushListObject(_data); - sb.append(ProgramConverter.serializeDataObject(DATA_KEY, _data)); - } - sb.append(PS_RPC_CALL_END); - return ByteBuffer.wrap(sb.toString().getBytes()); - } - - public String getMethod() { + public int getMethod() { return _method; } @@ -94,4 +57,37 @@ public class PSRpcCall extends PSRpcObject { public ListObject getData() { return _data; } + + public void deserialize(ByteBuffer buffer) throws IOException { + DataInputStream dis = new DataInputStream( + new ByteArrayInputStream(IOUtilFunctions.getBytes(buffer))); + _method = dis.readInt(); + validateMethod(_method); + _workerID = dis.readInt(); + if (dis.available() > 1) + _data = readAndDeserialize(dis); + dis.close(); + } + + public ByteBuffer serialize() throws IOException { + //TODO: Perf: use CacheDataOutput to avoid multiple copies (needs UTF handling) + ByteArrayOutputStream bos = new ByteArrayOutputStream(getApproxSerializedSize(_data)); + FastBufferedDataOutputStream dos = new FastBufferedDataOutputStream(bos); + dos.writeInt(_method); + dos.writeInt(_workerID); + if (_data != null) + serializeAndWriteListObject(_data, dos); + dos.flush(); + return ByteBuffer.wrap(bos.toByteArray()); + } + + private void validateMethod(int method) { + switch (method) { + case PUSH: + case PULL: + break; + default: + throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'"); + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java index c8b4024..2d921de 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java @@ -22,36 +22,36 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; import java.io.IOException; import java.util.Collections; +import org.apache.spark.SparkConf; import org.apache.spark.network.TransportContext; +import org.apache.spark.network.netty.SparkTransportConf; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; +import org.apache.spark.util.LongAccumulator; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy; -//TODO should be able to configure the port by users public class PSRpcFactory { - private static final int PORT = 5055; private static final String MODULE_NAME = "ps"; - private static TransportContext createTransportContext(LocalParamServer ps) { - TransportConf conf = new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider()); + private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) { + TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);; PSRpcHandler handler = new PSRpcHandler(ps); - return new TransportContext(conf, handler); + return new TransportContext(tc, handler); } /** * Create and start the server * @return server */ - public static TransportServer createServer(LocalParamServer ps, String host) { - TransportContext context = createTransportContext(ps); - return context.createServer(host, PORT, Collections.emptyList()); + public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) { + TransportContext context = createTransportContext(conf, ps); + return context.createServer(host, 0, Collections.emptyList()); // bind rpc to an ephemeral port } - public static SparkPSProxy createSparkPSProxy(String host, long rpcTimeout) throws IOException { - TransportContext context = createTransportContext(new LocalParamServer()); - return new SparkPSProxy(context.createClientFactory().createClient(host, PORT), rpcTimeout); + public static SparkPSProxy createSparkPSProxy(SparkConf conf, String host, int port, long rpcTimeout, LongAccumulator aRPC) throws IOException { + TransportContext context = createTransportContext(conf, new LocalParamServer()); + return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java index 3d73a37..a2c311e 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java @@ -21,10 +21,8 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL; import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH; -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA; -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR; -import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS; +import java.io.IOException; import java.nio.ByteBuffer; import org.apache.commons.lang.exception.ExceptionUtils; @@ -35,6 +33,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.Type; import org.apache.sysml.runtime.instructions.cp.ListObject; public final class PSRpcHandler extends RpcHandler { @@ -47,28 +46,41 @@ public final class PSRpcHandler extends RpcHandler { @Override public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) { - PSRpcCall call = new PSRpcCall(buffer); + PSRpcCall call; + try { + call = new PSRpcCall(buffer); + } catch (IOException e) { + throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e); + } PSRpcResponse response = null; switch (call.getMethod()) { case PUSH: try { _server.push(call.getWorkerID(), call.getData()); - response = new PSRpcResponse(SUCCESS, EMPTY_DATA); + response = new PSRpcResponse(Type.SUCCESS_EMPTY); } catch (DMLRuntimeException exception) { - response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception)); + response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception)); } finally { - callback.onSuccess(response.serialize()); + try { + callback.onSuccess(response.serialize()); + } catch (IOException e) { + throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e); + } } break; case PULL: ListObject data; try { data = _server.pull(call.getWorkerID()); - response = new PSRpcResponse(SUCCESS, data); + response = new PSRpcResponse(Type.SUCCESS, data); } catch (DMLRuntimeException exception) { - response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception)); + response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception)); } finally { - callback.onSuccess(response.serialize()); + try { + callback.onSuccess(response.serialize()); + } catch (IOException e) { + throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e); + } } break; default: http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java index c6d7fd3..7d3353f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java @@ -19,39 +19,86 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; -import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; public abstract class PSRpcObject { - public static final String PUSH = "push"; - public static final String PULL = "pull"; - public static final String DATA_KEY = "data"; - public static final String EMPTY_DATA = ""; + public static final int PUSH = 1; + public static final int PULL = 2; - public abstract void deserialize(ByteBuffer buffer); + public abstract void deserialize(ByteBuffer buffer) throws IOException; - public abstract ByteBuffer serialize(); + public abstract ByteBuffer serialize() throws IOException; /** - * Convert direct byte buffer to string - * @param buffer direct byte buffer - * @return string + * Deep serialize and write of a list object (currently only support list containing matrices) + * @param lo a list object containing only matrices + * @param dos output data to write to */ - protected String bufferToString(ByteBuffer buffer) { - byte[] result = new byte[buffer.limit()]; - buffer.get(result, 0, buffer.limit()); - return new String(result); + protected void serializeAndWriteListObject(ListObject lo, DataOutput dos) throws IOException { + validateListObject(lo); + dos.writeInt(lo.getLength()); //write list length + dos.writeBoolean(lo.isNamedList()); //write list named + for (int i = 0; i < lo.getLength(); i++) { + if (lo.isNamedList()) + dos.writeUTF(lo.getName(i)); //write name + ((MatrixObject) lo.getData().get(i)) + .acquireReadAndRelease().write(dos); //write matrix + } + } + + protected ListObject readAndDeserialize(DataInput dis) throws IOException { + int listLen = dis.readInt(); + List<Data> data = new ArrayList<>(); + List<String> names = dis.readBoolean() ? + new ArrayList<>() : null; + for(int i=0; i<listLen; i++) { + if( names != null ) + names.add(dis.readUTF()); + MatrixBlock mb = new MatrixBlock(); + mb.readFields(dis); + data.add(ParamservUtils.newMatrixObject(mb, false)); + } + return new ListObject(data, names); } /** - * Flush the data into HDFS - * @param data list object + * Get serialization size of a list object + * (scheme: size|name|size|matrix) + * @param lo list object + * @return serialization size */ - protected void flushListObject(ListObject data) { - data.getData().stream().filter(d -> d instanceof CacheableData) - .forEach(d -> ((CacheableData<?>) d).exportData()); + protected int getApproxSerializedSize(ListObject lo) { + if( lo == null ) return 0; + long result = 4 + 1; // list length and of named + result += lo.getLength() * (Integer.BYTES); // bytes for the size of names + if (lo.isNamedList()) + result += lo.getNames().stream().mapToLong(s -> s.length()).sum(); + result += lo.getData().stream().mapToLong(d -> + ((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum(); + if( result > Integer.MAX_VALUE ) + throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE."); + return (int) result; + } + + private void validateListObject(ListObject lo) { + for (Data d : lo.getData()) { + if (!(d instanceof MatrixObject)) { + throw new DMLRuntimeException(String.format("Paramserv func:" + + " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName())); + } + } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java index 998c523..3517491 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java @@ -19,41 +19,43 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; -import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN; -import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END; -import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM; -import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY; -import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN; -import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT; - +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.IOException; import java.nio.ByteBuffer; -import java.util.StringTokenizer; import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.runtime.util.ProgramConverter; +import org.apache.sysml.runtime.io.IOUtilFunctions; +import org.apache.sysml.runtime.util.FastBufferedDataOutputStream; public class PSRpcResponse extends PSRpcObject { + public enum Type { + SUCCESS, + SUCCESS_EMPTY, + ERROR, + } + + private Type _status; + private Object _data; // Could be list object or exception - public static final int SUCCESS = 1; - public static final int ERROR = 2; - - private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN; - private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END; - - private int _status; - private Object _data; // Could be list object or exception - - public PSRpcResponse(ByteBuffer buffer) { + public PSRpcResponse(ByteBuffer buffer) throws IOException { deserialize(buffer); } - public PSRpcResponse(int status, Object data) { + public PSRpcResponse(Type status) { + this(status, null); + } + + public PSRpcResponse(Type status, Object data) { _status = status; _data = data; + if( _status == Type.SUCCESS && data == null ) + _status = Type.SUCCESS_EMPTY; } public boolean isSuccessful() { - return _status == SUCCESS; + return _status != Type.ERROR; } public String getErrorMessage() { @@ -65,48 +67,42 @@ public class PSRpcResponse extends PSRpcObject { } @Override - public void deserialize(ByteBuffer buffer) { - //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks. - String input = bufferToString(buffer); - //header elimination - input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end - StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM); - - _status = Integer.valueOf(st.nextToken()); - String data = st.nextToken(); + public void deserialize(ByteBuffer buffer) throws IOException { + DataInputStream dis = new DataInputStream( + new ByteArrayInputStream(IOUtilFunctions.getBytes(buffer))); + _status = Type.values()[dis.readInt()]; switch (_status) { case SUCCESS: - _data = data.equals(EMPTY) ? null : - ProgramConverter.parseDataObject(data)[1]; + _data = readAndDeserialize(dis); + break; + case SUCCESS_EMPTY: break; case ERROR: - _data = data; + _data = dis.readUTF(); break; } + dis.close(); } @Override - public ByteBuffer serialize() { - //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks. - - StringBuilder sb = new StringBuilder(); - sb.append(PS_RPC_RESPONSE_BEGIN); - sb.append(_status); - sb.append(COMPONENTS_DELIM); + public ByteBuffer serialize() throws IOException { + //TODO: Perf: use CacheDataOutput to avoid multiple copies (needs UTF handling) + int len = 4 + (_status==Type.SUCCESS ? getApproxSerializedSize((ListObject)_data) : + _status==Type.SUCCESS_EMPTY ? 0 : ((String)_data).length()); + ByteArrayOutputStream bos = new ByteArrayOutputStream(len); + FastBufferedDataOutputStream dos = new FastBufferedDataOutputStream(bos); + dos.writeInt(_status.ordinal()); switch (_status) { case SUCCESS: - if (_data.equals(EMPTY_DATA)) { - sb.append(EMPTY); - } else { - flushListObject((ListObject) _data); - sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data)); - } + serializeAndWriteListObject((ListObject) _data, dos); + break; + case SUCCESS_EMPTY: break; case ERROR: - sb.append(_data.toString()); + dos.writeUTF(_data.toString()); break; } - sb.append(PS_RPC_RESPONSE_END); - return ByteBuffer.wrap(sb.toString().getBytes()); + dos.flush(); + return ByteBuffer.wrap(bos.toByteArray()); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 6133987..fe238bd 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -56,6 +56,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.lops.LopProperties; @@ -125,11 +126,25 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Get the compiled execution context LocalVariableMap newVarsMap = createVarsMap(sec); - ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // level of par is 1 in spark backend + // Level of par is 1 in spark backend because one worker will be launched per task + ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES)); MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS)); + // Create the agg service's execution context + ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0); + + // Create the parameter server + ListObject model = sec.getListObject(getParam(PS_MODEL)); + ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC); + + // Get driver host + String host = sec.getSparkContext().getConf().get("spark.driver.host"); + + // Create the netty server for ps + TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(),(LocalParamServer) ps, host); // Start the server + // Force all the instructions to CP type Recompiler.recompileProgramBlockHierarchy2Forced( newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP); @@ -139,29 +154,24 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc HashMap<String, byte[]> clsMap = new HashMap<>(); String program = ProgramConverter.serializeSparkPSBody(body, clsMap); - // Get some configurations - String host = sec.getSparkContext().getConf().get("spark.driver.host"); - long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ? - sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") : - sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s"); + // Add the accumulators for statistics + LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup"); + LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum"); + LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate"); + LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex"); + LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute"); + LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest"); + LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches"); + LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs"); // Create remote workers - SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(), - getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout); - - // Create the agg service's execution context - ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0); - - // Create the parameter server - ListObject model = sec.getListObject(getParam(PS_MODEL)); - ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC); + SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), + getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(), + server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch); if (DMLScript.STATISTICS) Statistics.accPSSetupTime((long) tSetup.stop()); - // Create the netty server for ps - TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server - try { ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning .foreach(worker); // Run remote workers @@ -172,6 +182,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc server.close(); } + // Accumulate the statistics for remote workers + if (DMLScript.STATISTICS) { + Statistics.accPSSetupTime(aSetup.sum()); + Statistics.incWorkerNumber(aWorker.sum()); + Statistics.accPSLocalModelUpdateTime(aUpdate.sum()); + Statistics.accPSBatchIndexingTime(aIndex.sum()); + Statistics.accPSGradientComputeTime(aGrad.sum()); + Statistics.accPSRpcRequestTime(aRPC.sum()); + } + // Fetch the final model from ps sec.setVariable(output.getName(), ps.getResult()); } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java index 94941a1..18f0e54 100644 --- a/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java +++ b/src/main/java/org/apache/sysml/runtime/io/IOUtilFunctions.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.StringReader; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -608,4 +609,13 @@ public class IOUtilFunctions ba[ off+6 ] = (byte)((val >>> 8) & 0xFF); ba[ off+7 ] = (byte)((val >>> 0) & 0xFF); } + + public static byte[] getBytes(ByteBuffer buff) { + int len = buff.limit(); + if( buff.hasArray() ) + return Arrays.copyOf(buff.array(), len); + byte[] ret = new byte[len]; + buff.get(ret, buff.position(), len); + return ret; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 1dd8362..44667ba 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -541,6 +541,10 @@ public class Statistics psNumWorkers.increment(); } + public static void incWorkerNumber(long n) { + psNumWorkers.add(n); + } + public static void accPSSetupTime(long t) { psSetupTime.add(t); } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java index 30eccb3..89235d7 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java @@ -1,15 +1,8 @@ package org.apache.sysml.test.integration.functions.paramserv; -import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.mlcontext.MLContext; -import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.junit.Test; @@ -42,12 +35,12 @@ public class ParamservSparkNNTest extends AutomatedTestBase { @Test public void testParamservBSPEpochDisjointContiguous() { - runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); + runDMLTest(5, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); } @Test public void testParamservASPEpochDisjointContiguous() { - runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); + runDMLTest(5, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); } @Test @@ -62,9 +55,14 @@ public class ParamservSparkNNTest extends AutomatedTestBase { private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) { programArgs = new String[] { "-explain" }; + internalRunDMLTest(testname, exceptionExpected, expectedException, errMessage); + } + + private void internalRunDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, + String errMessage) { DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform; boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; - AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK; + AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.HYBRID_SPARK; DMLScript.USE_LOCAL_SPARK_CONFIG = true; try { @@ -80,22 +78,7 @@ public class ParamservSparkNNTest extends AutomatedTestBase { } private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) { - Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString()) - .in("$epochs", String.valueOf(epochs)) - .in("$workers", String.valueOf(workers)) - .in("$utype", utype.toString()) - .in("$freq", freq.toString()) - .in("$batchsize", String.valueOf(batchsize)) - .in("$scheme", scheme.toString()); - - SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]") - .set("spark.driver.allowMultipleContexts", "true"); - JavaSparkContext sc = new JavaSparkContext(conf); - MLContext ml = new MLContext(sc); - ml.setStatistics(true); - ml.execute(script); - ml.resetConfig(); - sc.stop(); - ml.close(); + programArgs = new String[] { "-explain", "-nvargs", "mode=REMOTE_SPARK", "epochs=" + epochs, "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, "scheme=" + scheme}; + internalRunDMLTest(TEST_NAME1, false, null, null); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java index 57e1106..f2df1e6 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java @@ -19,13 +19,13 @@ package org.apache.sysml.test.integration.functions.paramserv; +import java.io.IOException; import java.util.Arrays; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall; import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject; import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse; -import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ListObject; import org.junit.Assert; import org.junit.Test; @@ -33,24 +33,26 @@ import org.junit.Test; public class RpcObjectTest { @Test - public void testPSRpcCall() { + public void testPSRpcCall() throws IOException { MatrixObject mo1 = SerializationTest.generateDummyMatrix(10); MatrixObject mo2 = SerializationTest.generateDummyMatrix(20); - IntObject io = new IntObject(30); - ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io)); + ListObject lo = new ListObject(Arrays.asList(mo1, mo2)); PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo); PSRpcCall actual = new PSRpcCall(expected.serialize()); - Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array())); + Assert.assertTrue(Arrays.equals( + expected.serialize().array(), + actual.serialize().array())); } @Test - public void testPSRpcResponse() { + public void testPSRpcResponse() throws IOException { MatrixObject mo1 = SerializationTest.generateDummyMatrix(10); MatrixObject mo2 = SerializationTest.generateDummyMatrix(20); - IntObject io = new IntObject(30); - ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io)); - PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo); + ListObject lo = new ListObject(Arrays.asList(mo1, mo2)); + PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.Type.SUCCESS, lo); PSRpcResponse actual = new PSRpcResponse(expected.serialize()); - Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array())); + Assert.assertTrue(Arrays.equals( + expected.serialize().array(), + actual.serialize().array())); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml index 4d0f32e..d1edc29 100644 --- a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml +++ b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml @@ -19,7 +19,7 @@ # #------------------------------------------------------------- -e1 = "element1" +e1 = matrix(1, rows=100, cols=10) modelList = list(e1) X = matrix(1, rows=200, cols=30) Y = matrix(2, rows=200, cols=1) @@ -42,11 +42,9 @@ aggregation = function(list[unknown] model, print(toString(as.matrix(gradients["agg_service_err"]))) } -e2 = "element2" +e2 = matrix(2, rows=100, cols=10) params = list(e2) -modelList = list("model") - # Use paramserv function modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1) http://git-wip-us.apache.org/repos/asf/systemml/blob/e0c271fe/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml index ad16122..bf0de68 100644 --- a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml +++ b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml @@ -19,7 +19,7 @@ # #------------------------------------------------------------- -e1 = "element1" +e1 = matrix(1, rows=100, cols=10) modelList = list(e1) X = matrix(1, rows=200, cols=30) Y = matrix(2, rows=200, cols=1) @@ -42,11 +42,9 @@ aggregation = function(list[unknown] model, modelResult = model } -e2 = "element2" +e2 = matrix(2, rows=100, cols=10) params = list(e2) -modelList = list("model") - # Use paramserv function modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)