[SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc. Closes #805.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/15ecb723 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/15ecb723 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/15ecb723 Branch: refs/heads/master Commit: 15ecb723e39e3154412ca8f8824c4554ee64ca35 Parents: 54dbe9b Author: EdgarLGB <[email protected]> Authored: Sat Jul 21 22:31:36 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 21 22:31:36 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/paramserv/LocalPSWorker.java | 34 +++--- .../paramserv/LocalParamServer.java | 7 +- .../controlprogram/paramserv/PSWorker.java | 15 ++- .../controlprogram/paramserv/ParamServer.java | 39 ++++--- .../paramserv/ParamservUtils.java | 65 ++++++----- .../paramserv/spark/SparkPSBody.java | 6 +- .../paramserv/spark/SparkPSProxy.java | 68 +++++++++++ .../paramserv/spark/SparkPSWorker.java | 46 ++++++-- .../paramserv/spark/rpc/PSRpcCall.java | 97 ++++++++++++++++ .../paramserv/spark/rpc/PSRpcFactory.java | 57 ++++++++++ .../paramserv/spark/rpc/PSRpcHandler.java | 83 ++++++++++++++ .../paramserv/spark/rpc/PSRpcObject.java | 57 ++++++++++ .../paramserv/spark/rpc/PSRpcResponse.java | 112 +++++++++++++++++++ .../cp/ParamservBuiltinCPInstruction.java | 52 +++++++-- .../sysml/runtime/util/ProgramConverter.java | 11 +- .../java/org/apache/sysml/utils/Statistics.java | 6 + .../paramserv/ParamservLocalNNTest.java | 41 +++---- .../paramserv/ParamservSparkNNTest.java | 68 +++++++++-- .../functions/paramserv/RpcObjectTest.java | 56 ++++++++++ .../functions/paramserv/SerializationTest.java | 2 +- .../paramserv/paramserv-nn-asp-batch.dml | 53 --------- .../paramserv/paramserv-nn-asp-epoch.dml | 53 --------- .../paramserv/paramserv-nn-bsp-batch-dc.dml | 53 --------- .../paramserv/paramserv-nn-bsp-batch-dr.dml | 53 --------- .../paramserv/paramserv-nn-bsp-batch-drr.dml | 53 --------- 
.../paramserv/paramserv-nn-bsp-batch-or.dml | 53 --------- .../paramserv/paramserv-nn-bsp-epoch.dml | 53 --------- .../paramserv-spark-agg-service-failed.dml | 53 +++++++++ .../paramserv-spark-nn-bsp-batch-dc.dml | 53 --------- .../paramserv/paramserv-spark-worker-failed.dml | 53 +++++++++ .../functions/paramserv/paramserv-test.dml | 48 ++++++++ .../functions/paramserv/ZPackageSuite.java | 4 +- 32 files changed, 961 insertions(+), 543 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index bbf2dbe..c23943d 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -35,6 +35,9 @@ import org.apache.sysml.utils.Statistics; public class LocalPSWorker extends PSWorker implements Callable<Void> { protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName()); + private static final long serialVersionUID = 5195390748495357295L; + + protected LocalPSWorker() {} public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { @@ -42,6 +45,11 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { } @Override + public String getWorkerName() { + return String.format("Local worker_%d", _workerID); + } + + @Override public Void call() throws Exception { if (DMLScript.STATISTICS) Statistics.incWorkerNumber(); @@ -60,10 +68,10 @@ public class LocalPSWorker extends 
PSWorker implements Callable<Void> { } if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); + LOG.debug(String.format("%s: job finished.", getWorkerName())); } } catch (Exception e) { - throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e); + throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e); } return null; } @@ -93,7 +101,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); + LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1)); } } @@ -108,9 +116,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { Statistics.accPSLocalModelUpdateTime((long) tUpd.stop()); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated. " + LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. 
" + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", - _workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter)); + getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter)); } return globalParams; } @@ -129,17 +137,17 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); } if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); + LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1)); } } } private ListObject pullModel() { // Pull the global parameters from ps - ListObject globalParams = (ListObject)_ps.pull(_workerID); + ListObject globalParams = _ps.pull(_workerID); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters " - + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024)); + LOG.debug(String.format("%s: successfully pull the global parameters " + + "[size:%d kb] from ps.", getWorkerName(), globalParams.getDataSize() / 1024)); } return globalParams; } @@ -148,8 +156,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { // Push the gradients to ps _ps.push(_workerID, gradients); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Successfully push the gradients " - + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); + LOG.debug(String.format("%s: successfully push the gradients " + + "[size:%d kb] to ps.", getWorkerName(), gradients.getDataSize() / 1024)); } } @@ -168,8 +176,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { _ec.setVariable(Statement.PS_LABELS, bLabels); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d [last index: %d]. 
" - + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID, + LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. " + + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", getWorkerName(), bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs, j + 1, totalIter)); } http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java index 52372c9..0c73acb 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -22,11 +22,14 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ListObject; public class LocalParamServer extends ParamServer { + public LocalParamServer() { + super(); + } + public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { super(model, aggFunc, updateType, ec, workerNum); } @@ -37,7 +40,7 @@ public class LocalParamServer extends ParamServer { } @Override - public Data pull(int workerID) { + public ListObject pull(int workerID) { ListObject model; try { model = _modelMap.get(workerID).take(); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index 1ab5f5e..464db9b 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -21,6 +21,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX; +import java.io.Serializable; import java.util.ArrayList; import java.util.stream.Collectors; @@ -34,7 +35,10 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; -public abstract class PSWorker { +public abstract class PSWorker implements Serializable { + + private static final long serialVersionUID = -3510485051178200118L; + protected int _workerID; protected int _epochs; protected long _batchSize; @@ -50,10 +54,8 @@ public abstract class PSWorker { protected String _updFunc; protected Statement.PSFrequency _freq; - protected PSWorker() { + protected PSWorker() {} - } - protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { _workerID = workerID; @@ -65,7 +67,10 @@ public abstract class PSWorker { _valLabels = valLabels; _ec = ec; _ps = ps; + setupUpdateFunction(updFunc, ec); + } + protected void setupUpdateFunction(String updFunc, ExecutionContext ec) { // Get the update function String[] cfn = 
ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX); String ns = cfn[0]; @@ -125,4 +130,6 @@ public abstract class PSWorker { public MatrixObject getLabels() { return _labels; } + + public abstract String getWorkerName(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index bd8ee36..2607036 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -42,7 +42,6 @@ import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; import org.apache.sysml.runtime.instructions.cp.ListObject; import org.apache.sysml.utils.Statistics; @@ -53,17 +52,19 @@ public abstract class ParamServer protected static final boolean ACCRUE_BSP_GRADIENTS = true; // worker input queues and global model - protected final Map<Integer, BlockingQueue<ListObject>> _modelMap; + protected Map<Integer, BlockingQueue<ListObject>> _modelMap; private ListObject _model; //aggregation service - protected final ExecutionContext _ec; - private final Statement.PSUpdateType _updateType; - private final FunctionCallCPInstruction _inst; - private final String _outputName; - private final boolean[] _finishedStates; // Workers' finished states + protected ExecutionContext _ec; + private Statement.PSUpdateType 
_updateType; + private FunctionCallCPInstruction _inst; + private String _outputName; + private boolean[] _finishedStates; // Workers' finished states private ListObject _accGradients = null; + protected ParamServer() {} + protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { // init worker queues and global model _modelMap = new HashMap<>(workerNum); @@ -77,10 +78,22 @@ public abstract class ParamServer _ec = ec; _updateType = updateType; _finishedStates = new boolean[workerNum]; + setupAggFunc(_ec, aggFunc); + + // broadcast initial model + try { + broadcastModel(); + } + catch (InterruptedException e) { + throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e); + } + } + + public void setupAggFunc(ExecutionContext ec, String aggFunc) { String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX); String ns = cfn[0]; String fname = cfn[1]; - FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname); + FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname); ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); @@ -101,19 +114,11 @@ public abstract class ParamServer ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) .collect(Collectors.toCollection(ArrayList::new)); _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function"); - - // broadcast initial model - try { - broadcastModel(); - } - catch (InterruptedException e) { - throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e); - } } public abstract void push(int workerID, ListObject value); - public abstract Data pull(int workerID); + public abstract ListObject pull(int workerID); public ListObject getResult() { // All the model updating work has terminated, 
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index b9fd7a8..cf27457 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -28,8 +28,11 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.lang.StringUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.MultiThreadedHop; @@ -57,6 +60,7 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator; import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ListObject; @@ -68,13 +72,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.runtime.util.ProgramConverter; +import org.apache.sysml.utils.Statistics; import 
scala.Tuple2; public class ParamservUtils { + protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName()); public static final String PS_FUNC_PREFIX = "_ps_"; - public static long SEED = -1; // Used for generating permutation /** @@ -140,6 +145,14 @@ public class ParamservUtils { CacheableData<?> cd = (CacheableData<?>) data; cd.enableCleanup(true); ec.cleanupCacheableData(cd); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("%s has been deleted.", cd.getFileName())); + } + } + + public static void cleanupMatrixObject(ExecutionContext ec, MatrixObject mo) { + mo.enableCleanup(true); + ec.cleanupCacheableData(mo); } public static MatrixObject newMatrixObject(MatrixBlock mb) { @@ -365,6 +378,7 @@ public class ParamservUtils { @SuppressWarnings("unchecked") public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, int workerNum) { + Timing tSetup = DMLScript.STATISTICS ? 
new Timing(true) : null; // Get input RDD JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo); @@ -372,33 +386,34 @@ public class ParamservUtils { sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo); DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows()); - return ParamservUtils.assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels)) + JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils + .assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels)) .flatMapToPair(mapper) // Do the data partitioning on spark (workerID => (rowBlockID, (single row features, single row labels)) // Aggregate the partitioned matrix according to rowID for each worker // i.e. 
(workerID => ordered list[(rowBlockID, (single row features, single row labels)] - .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), - new Partitioner() { - private static final long serialVersionUID = -7937781374718031224L; - @Override - public int getPartition(Object workerID) { - return (int) workerID; - } - @Override - public int numPartitions() { - return workerNum; - } - }, - (list, input) -> { - list.add(input); - return list; - }, - (l1, l2) -> { - l1.addAll(l2); - l1.sort((o1, o2) -> o1._1.compareTo(o2._1)); - return l1; - }) - .mapToPair(new DataPartitionerSparkAggregator( - features.getNumColumns(), labels.getNumColumns())); + .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() { + private static final long serialVersionUID = -7937781374718031224L; + @Override + public int getPartition(Object workerID) { + return (int) workerID; + } + @Override + public int numPartitions() { + return workerNum; + } + }, (list, input) -> { + list.add(input); + return list; + }, (l1, l2) -> { + l1.addAll(l2); + l1.sort((o1, o2) -> o1._1.compareTo(o2._1)); + return l1; + }) + .mapToPair(new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns())); + + if (DMLScript.STATISTICS) + Statistics.accPSSetupTime((long) tSetup.stop()); + return result; } public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) { http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java index ec10232..9354025 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java +++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java @@ -28,12 +28,10 @@ public class SparkPSBody { private ExecutionContext _ec; - public SparkPSBody() { - - } + public SparkPSBody() {} public SparkPSBody(ExecutionContext ec) { - this._ec = ec; + _ec = ec; } public ExecutionContext getEc() { http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java new file mode 100644 index 0000000..de7b6c6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark; + +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL; +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH; + +import org.apache.spark.network.client.TransportClient; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.utils.Statistics; + +public class SparkPSProxy extends ParamServer { + + private TransportClient _client; + private final long _rpcTimeout; + + public SparkPSProxy(TransportClient client, long rpcTimeout) { + super(); + _client = client; + _rpcTimeout = rpcTimeout; + } + + @Override + public void push(int workerID, ListObject value) { + Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; + PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout)); + if (DMLScript.STATISTICS) + Statistics.accPSRpcRequestTime((long) tRpc.stop()); + if (!response.isSuccessful()) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage())); + } + } + + @Override + public ListObject pull(int workerID) { + Timing tRpc = DMLScript.STATISTICS ? 
new Timing(true) : null; + PSRpcResponse response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout)); + if (DMLScript.STATISTICS) + Statistics.accPSRpcRequestTime((long) tRpc.stop()); + if (!response.isSuccessful()) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage())); + } + return response.getResultModel(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java index 466801f..fa06243 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java @@ -20,43 +20,58 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark; import java.io.IOException; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; import org.apache.spark.api.java.function.VoidFunction; +import org.apache.sysml.api.DMLScript; import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.codegen.CodegenUtils; -import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory; import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import 
org.apache.sysml.runtime.util.ProgramConverter; +import org.apache.sysml.utils.Statistics; import scala.Tuple2; -public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable { +public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> { private static final long serialVersionUID = -8674739573419648732L; private String _program; private HashMap<String, byte[]> _clsMap; + private String _host; // host ip of driver + private long _rpcTimeout; // rpc ask timeout + private String _aggFunc; - protected SparkPSWorker() { - // No-args constructor used for deserialization - } - - public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) { + public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) { _updFunc = updFunc; + _aggFunc = aggFunc; _freq = freq; _epochs = epochs; _batchSize = batchSize; _program = program; _clsMap = clsMap; + _host = host; + _rpcTimeout = rpcTimeout; + } + + @Override + public String getWorkerName() { + return String.format("Spark worker_%d", _workerID); } @Override public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception { + Timing tSetup = DMLScript.STATISTICS ? 
new Timing(true) : null; configureWorker(input); + if (DMLScript.STATISTICS) + Statistics.accPSSetupTime((long) tSetup.stop()); + call(); // Launch the worker } private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException { @@ -73,5 +88,20 @@ public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integ // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end RemoteParForUtils.setupBufferPool(_workerID); + + // Create the ps proxy + _ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout); + + // Initialize the update function + setupUpdateFunction(_updFunc, _ec); + + // Initialize the agg function + _ps.setupAggFunc(_ec, _aggFunc); + + // Lazy initialize the matrix of features and labels + setFeatures(ParamservUtils.newMatrixObject(input._2._1)); + setLabels(ParamservUtils.newMatrixObject(input._2._2)); + _features.enableCleanup(false); + _labels.enableCleanup(false); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java new file mode 100644 index 0000000..999d409 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; + +import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN; +import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END; +import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM; +import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY; +import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN; +import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT; + +import java.nio.ByteBuffer; +import java.util.StringTokenizer; + +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.util.ProgramConverter; + +public class PSRpcCall extends PSRpcObject { + + private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN; + private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END; + + private String _method; + private int _workerID; + private ListObject _data; + + public PSRpcCall(String method, int workerID, ListObject data) { + _method = method; + _workerID = workerID; + _data = data; + } + + public PSRpcCall(ByteBuffer buffer) { + deserialize(buffer); + } + + public void deserialize(ByteBuffer buffer) { + //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks. 
+ String input = bufferToString(buffer); + //header elimination + input = input.substring(PS_RPC_CALL_BEGIN.length(), input.length() - PS_RPC_CALL_END.length()); //remove start/end + StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM); + + _method = st.nextToken(); + _workerID = Integer.valueOf(st.nextToken()); + String dataStr = st.nextToken(); + _data = dataStr.equals(EMPTY) ? null : + (ListObject) ProgramConverter.parseDataObject(dataStr)[1]; + } + + public ByteBuffer serialize() { + //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks. + StringBuilder sb = new StringBuilder(); + sb.append(PS_RPC_CALL_BEGIN); + sb.append(_method); + sb.append(COMPONENTS_DELIM); + sb.append(_workerID); + sb.append(COMPONENTS_DELIM); + if (_data == null) { + sb.append(EMPTY); + } else { + flushListObject(_data); + sb.append(ProgramConverter.serializeDataObject(DATA_KEY, _data)); + } + sb.append(PS_RPC_CALL_END); + return ByteBuffer.wrap(sb.toString().getBytes()); + } + + public String getMethod() { + return _method; + } + + public int getWorkerID() { + return _workerID; + } + + public ListObject getData() { + return _data; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java new file mode 100644 index 0000000..c8b4024 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; + +import java.io.IOException; +import java.util.Collections; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy; + +//TODO should be able to configure the port by users +public class PSRpcFactory { + + private static final int PORT = 5055; + private static final String MODULE_NAME = "ps"; + + private static TransportContext createTransportContext(LocalParamServer ps) { + TransportConf conf = new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider()); + PSRpcHandler handler = new PSRpcHandler(ps); + return new TransportContext(conf, handler); + } + + /** + * Create and start the server + * @return server + */ + public static TransportServer createServer(LocalParamServer ps, String host) { + TransportContext context = createTransportContext(ps); + return context.createServer(host, PORT, Collections.emptyList()); + } + + public static SparkPSProxy 
createSparkPSProxy(String host, long rpcTimeout) throws IOException { + TransportContext context = createTransportContext(new LocalParamServer()); + return new SparkPSProxy(context.createClientFactory().createClient(host, PORT), rpcTimeout); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java new file mode 100644 index 0000000..3d73a37 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; + +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL; +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH; +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA; +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR; +import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS; + +import java.nio.ByteBuffer; + +import org.apache.commons.lang.exception.ExceptionUtils; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public final class PSRpcHandler extends RpcHandler { + + private LocalParamServer _server; + + protected PSRpcHandler(LocalParamServer server) { + _server = server; + } + + @Override + public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) { + PSRpcCall call = new PSRpcCall(buffer); + PSRpcResponse response = null; + switch (call.getMethod()) { + case PUSH: + try { + _server.push(call.getWorkerID(), call.getData()); + response = new PSRpcResponse(SUCCESS, EMPTY_DATA); + } catch (DMLRuntimeException exception) { + response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception)); + } finally { + callback.onSuccess(response.serialize()); + } + break; + case PULL: + ListObject data; + try { + data = _server.pull(call.getWorkerID()); + response = new PSRpcResponse(SUCCESS, data); + } catch (DMLRuntimeException 
exception) { + response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception)); + } finally { + callback.onSuccess(response.serialize()); + } + break; + default: + throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod())); + } + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java new file mode 100644 index 0000000..c6d7fd3 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; + +import java.nio.ByteBuffer; + +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public abstract class PSRpcObject { + + public static final String PUSH = "push"; + public static final String PULL = "pull"; + public static final String DATA_KEY = "data"; + public static final String EMPTY_DATA = ""; + + public abstract void deserialize(ByteBuffer buffer); + + public abstract ByteBuffer serialize(); + + /** + * Convert direct byte buffer to string + * @param buffer direct byte buffer + * @return string + */ + protected String bufferToString(ByteBuffer buffer) { + byte[] result = new byte[buffer.limit()]; + buffer.get(result, 0, buffer.limit()); + return new String(result); + } + + /** + * Flush the data into HDFS + * @param data list object + */ + protected void flushListObject(ListObject data) { + data.getData().stream().filter(d -> d instanceof CacheableData) + .forEach(d -> ((CacheableData<?>) d).exportData()); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java new file mode 100644 index 0000000..998c523 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc; + +import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN; +import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END; +import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM; +import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY; +import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN; +import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT; + +import java.nio.ByteBuffer; +import java.util.StringTokenizer; + +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.util.ProgramConverter; + +public class PSRpcResponse extends PSRpcObject { + + public static final int SUCCESS = 1; + public static final int ERROR = 2; + + private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN; + private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END; + + private int _status; + private Object _data; // Could be list object or exception + + public PSRpcResponse(ByteBuffer buffer) { + deserialize(buffer); + } + + public PSRpcResponse(int status, Object data) { + _status = status; + _data = data; + } + + public boolean isSuccessful() { + return _status == SUCCESS; + } + + public String getErrorMessage() { + return (String) _data; + } + + public 
ListObject getResultModel() { + return (ListObject) _data; + } + + @Override + public void deserialize(ByteBuffer buffer) { + //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks. + String input = bufferToString(buffer); + //header elimination + input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end + StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM); + + _status = Integer.valueOf(st.nextToken()); + String data = st.nextToken(); + switch (_status) { + case SUCCESS: + _data = data.equals(EMPTY) ? null : + ProgramConverter.parseDataObject(data)[1]; + break; + case ERROR: + _data = data; + break; + } + } + + @Override + public ByteBuffer serialize() { + //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks. + + StringBuilder sb = new StringBuilder(); + sb.append(PS_RPC_RESPONSE_BEGIN); + sb.append(_status); + sb.append(COMPONENTS_DELIM); + switch (_status) { + case SUCCESS: + if (_data.equals(EMPTY_DATA)) { + sb.append(EMPTY); + } else { + flushListObject((ListObject) _data); + sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data)); + } + break; + case ERROR: + sb.append(_data.toString()); + break; + } + sb.append(PS_RPC_RESPONSE_END); + return ByteBuffer.wrap(sb.toString().getBytes()); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 4e7a718..6133987 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -55,6 +55,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; +import org.apache.spark.network.server.TransportServer; import org.apache.sysml.api.DMLScript; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.lops.LopProperties; @@ -71,6 +72,7 @@ import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody; import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.matrix.operators.Operator; @@ -114,16 +116,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } private void runOnSpark(SparkExecutionContext sec, PSModeType mode) { + Timing tSetup = DMLScript.STATISTICS ? 
new Timing(true) : null; + PSScheme scheme = getScheme(); int workerNum = getWorkerNum(mode); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); - int k = getParLevel(workerNum); - // Get the compiled execution context LocalVariableMap newVarsMap = createVarsMap(sec); - ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, k); + ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // level of par is 1 in spark backend MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES)); MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS)); @@ -131,16 +133,47 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Force all the instructions to CP type Recompiler.recompileProgramBlockHierarchy2Forced( newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP); - + // Serialize all the needed params for remote workers SparkPSBody body = new SparkPSBody(newEC); HashMap<String, byte[]> clsMap = new HashMap<>(); String program = ProgramConverter.serializeSparkPSBody(body, clsMap); - SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getFrequency(), getEpochs(), getBatchSize(), program, clsMap); - ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning - .foreach(worker); // Run remote workers + // Get some configurations + String host = sec.getSparkContext().getConf().get("spark.driver.host"); + long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ? 
+ sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") : + sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s"); + + // Create remote workers + SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(), + getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout); + + // Create the agg service's execution context + ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0); + + // Create the parameter server + ListObject model = sec.getListObject(getParam(PS_MODEL)); + ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC); + + if (DMLScript.STATISTICS) + Statistics.accPSSetupTime((long) tSetup.stop()); + + // Create the netty server for ps + TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server + try { + ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning + .foreach(worker); // Run remote workers + } catch (Exception e) { + throw new DMLRuntimeException("Paramserv function failed: ", e); + } finally { + // Stop the netty server + server.close(); + } + + // Fetch the final model from ps + sec.setVariable(output.getName(), ps.getResult()); } private void runLocally(ExecutionContext ec, PSModeType mode) { @@ -176,8 +209,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES)); MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); List<LocalPSWorker> workers = IntStream.range(0, workerNum) - .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps)) - .collect(Collectors.toList()); + .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps)) + .collect(Collectors.toList()); // Do 
data partition PSScheme scheme = getScheme(); @@ -296,6 +329,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) { switch (mode) { case LOCAL: + case REMOTE_SPARK: return new LocalParamServer(model, aggFunc, updateType, ec, workerNum); default: throw new DMLRuntimeException("Unsupported parameter server: "+mode.name()); http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java index 1d2115e..fc9d9b4 100644 --- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java +++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java @@ -143,7 +143,7 @@ public class ProgramConverter public static final String PB_IF = " IF" + LEVELIN; public static final String PB_FC = " FC" + LEVELIN; public static final String PB_EFC = " EFC" + LEVELIN; - + public static final String CONF_STATS = "stats"; // Used for parfor @@ -716,9 +716,10 @@ public class ProgramConverter builder.append(rSerializeProgramBlocks(ec.getProgram().getProgramBlocks(), clsMap)); builder.append(PBS_END); builder.append(NEWLINE); + builder.append(COMPONENTS_DELIM); + builder.append(NEWLINE); builder.append(PSBODY_END); - return builder.toString(); } @@ -868,7 +869,7 @@ public class ProgramConverter value = mo.getFileName(); PartitionFormat partFormat = (mo.getPartitionFormat()!=null) ? 
new PartitionFormat( mo.getPartitionFormat(),mo.getPartitionSize()) : PartitionFormat.NONE; - metaData = new String[9]; + metaData = new String[11]; metaData[0] = String.valueOf( mc.getRows() ); metaData[1] = String.valueOf( mc.getCols() ); metaData[2] = String.valueOf( mc.getRowsPerBlock() ); @@ -878,6 +879,8 @@ public class ProgramConverter metaData[6] = OutputInfo.outputInfoToString( md.getOutputInfo() ); metaData[7] = String.valueOf( partFormat ); metaData[8] = String.valueOf( mo.getUpdateType() ); + metaData[9] = String.valueOf(mo.isHDFSFileExists()); + metaData[10] = String.valueOf(mo.isCleanupEnabled()); break; case LIST: // SCHEMA: <name>|<datatype>|<valuetype>|value|<metadata>|<tab>element1<tab>element2<tab>element3 (this is the list) @@ -1683,6 +1686,8 @@ public class ProgramConverter if( partFormat._dpf != PDataPartitionFormat.NONE ) mo.setPartitioned( partFormat._dpf, partFormat._N ); mo.setUpdateType(inplace); + mo.setHDFSFileExists(Boolean.valueOf(st.nextToken())); + mo.enableCleanup(Boolean.valueOf(st.nextToken())); dat = mo; break; } http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 8f0d853..1dd8362 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -125,6 +125,7 @@ public class Statistics private static final LongAdder psLocalModelUpdateTime = new LongAdder(); private static final LongAdder psModelBroadcastTime = new LongAdder(); private static final LongAdder psBatchIndexTime = new LongAdder(); + private static final LongAdder psRpcRequestTime = new LongAdder(); //PARFOR optimization stats (low frequency updates) private static long parforOptTime = 0; //in milli sec @@ -564,6 +565,10 @@ public class Statistics 
psBatchIndexTime.add(t); } + public static void accPSRpcRequestTime(long t) { + psRpcRequestTime.add(t); + } + public static String getCPHeavyHitterCode( Instruction inst ) { String opcode = null; @@ -1003,6 +1008,7 @@ public class Statistics psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 1000)); sb.append(String.format("Paramserv model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000)); sb.append(String.format("Paramserv batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000)); + sb.append(String.format("Paramserv RPC request time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000)); } if( parforOptCount>0 ){ sb.append("ParFor loops optimized:\t\t" + getParforOptCount() + ".\n"); http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java index d5fd509..905bfd1 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java @@ -19,75 +19,66 @@ package org.apache.sysml.test.integration.functions.paramserv; +import org.apache.sysml.parser.Statement; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.junit.Test; public class ParamservLocalNNTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "paramserv-nn-bsp-batch-dc"; - private static final String TEST_NAME2 = "paramserv-nn-asp-batch"; - private static final String TEST_NAME3 = "paramserv-nn-bsp-epoch"; - private static final String 
TEST_NAME4 = "paramserv-nn-asp-epoch"; - private static final String TEST_NAME5 = "paramserv-nn-bsp-batch-drr"; - private static final String TEST_NAME6 = "paramserv-nn-bsp-batch-dr"; - private static final String TEST_NAME7 = "paramserv-nn-bsp-batch-or"; + private static final String TEST_NAME = "paramserv-test"; private static final String TEST_DIR = "functions/paramserv/"; private static final String TEST_CLASS_DIR = TEST_DIR + ParamservLocalNNTest.class.getSimpleName() + "/"; @Override public void setUp() { - addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); - addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); - addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {})); - addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {})); - addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); - addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {})); } @Test public void testParamservBSPBatchDisjointContiguous() { - runDMLTest(TEST_NAME1); + runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS); } @Test public void testParamservASPBatch() { - runDMLTest(TEST_NAME2); + runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS); } @Test public void testParamservBSPEpoch() { - runDMLTest(TEST_NAME3); + runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS); } @Test public void testParamservASPEpoch() { - 
runDMLTest(TEST_NAME4); + runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS); } @Test public void testParamservBSPBatchDisjointRoundRobin() { - runDMLTest(TEST_NAME5); + runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN); } @Test public void testParamservBSPBatchDisjointRandom() { - runDMLTest(TEST_NAME6); + runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM); } @Test public void testParamservBSPBatchOverlapReshuffle() { - runDMLTest(TEST_NAME7); + runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE); } - private void runDMLTest(String testname) { - TestConfiguration config = getTestConfiguration(testname); + private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) { + TestConfiguration config = getTestConfiguration(ParamservLocalNNTest.TEST_NAME); loadTestConfiguration(config); - programArgs = new String[] { "-explain" }; + programArgs = new String[] { "-explain", "-nvargs", "mode=LOCAL", "epochs=" + epochs, + "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize, + "scheme=" + scheme }; String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; + fullDMLScriptName = HOME + ParamservLocalNNTest.TEST_NAME + ".dml"; runTest(true, false, null, null, -1); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java index 2441116..30eccb3 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java @@ -1,14 +1,24 @@ package org.apache.sysml.test.integration.functions.paramserv; +import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.Script; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.junit.Test; public class ParamservSparkNNTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "paramserv-spark-nn-bsp-batch-dc"; + private static final String TEST_NAME1 = "paramserv-test"; + private static final String TEST_NAME2 = "paramserv-spark-worker-failed"; + private static final String TEST_NAME3 = "paramserv-spark-agg-service-failed"; private static final String TEST_DIR = "functions/paramserv/"; private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSparkNNTest.class.getSimpleName() + "/"; @@ -16,14 +26,42 @@ public class ParamservSparkNNTest extends AutomatedTestBase { @Override public void setUp() { addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); } @Test public void 
testParamservBSPBatchDisjointContiguous() { - runDMLTest(TEST_NAME1); + runDMLTest(2, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); + } + + @Test + public void testParamservASPBatchDisjointContiguous() { + runDMLTest(2, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); + } + + @Test + public void testParamservBSPEpochDisjointContiguous() { + runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); + } + + @Test + public void testParamservASPEpochDisjointContiguous() { + runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS); } - private void runDMLTest(String testname) { + @Test + public void testParamservWorkerFailed() { + runDMLTest(TEST_NAME2, true, DMLException.class, "Invalid indexing by name in unnamed list: worker_err."); + } + + @Test + public void testParamservAggServiceFailed() { + runDMLTest(TEST_NAME3, true, DMLException.class, "Invalid indexing by name in unnamed list: agg_service_err."); + } + + private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) { + programArgs = new String[] { "-explain" }; DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform; boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK; @@ -32,16 +70,32 @@ public class ParamservSparkNNTest extends AutomatedTestBase { try { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); - programArgs = new String[] { "-explain" }; String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - // The test is not already finished, so it is normal to have the NPE - runTest(true, true, DMLException.class, null, -1); + runTest(true, exceptionExpected, 
expectedException, errMessage, -1); } finally { AutomatedTestBase.rtplatform = oldRtplatform; DMLScript.USE_LOCAL_SPARK_CONFIG = oldUseLocalSparkConfig; } - } + private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) { + Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString()) + .in("$epochs", String.valueOf(epochs)) + .in("$workers", String.valueOf(workers)) + .in("$utype", utype.toString()) + .in("$freq", freq.toString()) + .in("$batchsize", String.valueOf(batchsize)) + .in("$scheme", scheme.toString()); + + SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]") + .set("spark.driver.allowMultipleContexts", "true"); + JavaSparkContext sc = new JavaSparkContext(conf); + MLContext ml = new MLContext(sc); + ml.setStatistics(true); + ml.execute(script); + ml.resetConfig(); + sc.stop(); + ml.close(); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java new file mode 100644 index 0000000..57e1106 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import java.util.Arrays; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject; +import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse; +import org.apache.sysml.runtime.instructions.cp.IntObject; +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.junit.Assert; +import org.junit.Test; + +public class RpcObjectTest { + + @Test + public void testPSRpcCall() { + MatrixObject mo1 = SerializationTest.generateDummyMatrix(10); + MatrixObject mo2 = SerializationTest.generateDummyMatrix(20); + IntObject io = new IntObject(30); + ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io)); + PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo); + PSRpcCall actual = new PSRpcCall(expected.serialize()); + Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array())); + } + + @Test + public void testPSRpcResponse() { + MatrixObject mo1 = SerializationTest.generateDummyMatrix(10); + MatrixObject mo2 = SerializationTest.generateDummyMatrix(20); + IntObject io = new IntObject(30); + ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io)); + PSRpcResponse 
expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo); + PSRpcResponse actual = new PSRpcResponse(expected.serialize()); + Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array())); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java index 2a08ca6..64d6492 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java @@ -68,7 +68,7 @@ public class SerializationTest { Assert.assertEquals(io.getLongValue(), actualIO.getLongValue()); } - private MatrixObject generateDummyMatrix(int size) { + public static MatrixObject generateDummyMatrix(int size) { double[] dl = new double[size]; for (int i = 0; i < size; i++) { dl[i] = i; http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml deleted file mode 100644 index ba22942..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml +++ /dev/null @@ -1,53 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 -batchsize = 32 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", batchsize,"DISJOINT_CONTIGUOUS", "LOCAL") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml ---------------------------------------------------------------------- diff 
--git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml deleted file mode 100644 index c8c6a2f..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml +++ /dev/null @@ -1,53 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 -batchsize = 32 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml deleted file mode 100644 index 78fc1c4..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml +++ /dev/null @@ -1,53 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 -batchsize = 32 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml deleted file mode 100644 index 9191b5a..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml +++ /dev/null @@ -1,53 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 -batchsize = 32 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_RANDOM", "LOCAL") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml deleted file mode 100644 index ec18cb4..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml +++ /dev/null @@ -1,53 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 4 -batchsize = 32 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_ROUND_ROBIN", "LOCAL") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml deleted file mode 100644 index 928dde2..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml +++ /dev/null @@ -1,53 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 -batchsize = 32 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "OVERLAP_RESHUFFLE", "LOCAL") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file
