[MINOR] Various paramserv refactorings and code cleanups

Closes #814.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/382f847d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/382f847d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/382f847d

Branch: refs/heads/master
Commit: 382f847de6e33cdb5386b5eb5912eb5da0dff8d6
Parents: e11ae6a
Author: EdgarLGB <[email protected]>
Authored: Fri Aug 3 19:17:15 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Aug 3 19:17:15 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/DCScheme.java | 61 -------
 .../controlprogram/paramserv/DRRScheme.java | 57 -------
 .../controlprogram/paramserv/DRScheme.java | 62 -------
 .../paramserv/DataPartitionScheme.java | 40 -----
 .../paramserv/DataPartitioner.java | 49 ------
 .../controlprogram/paramserv/LocalPSWorker.java | 40 ++---
 .../paramserv/LocalParamServer.java | 6 +-
 .../controlprogram/paramserv/ORScheme.java | 61 -------
 .../controlprogram/paramserv/PSWorker.java | 1 +
 .../controlprogram/paramserv/ParamServer.java | 40 ++---
 .../paramserv/ParamservUtils.java | 66 ++++++--
 .../controlprogram/paramserv/SparkPSBody.java | 44 +++++
 .../controlprogram/paramserv/SparkPSProxy.java | 84 ++++++++
 .../controlprogram/paramserv/SparkPSWorker.java | 158 +++++++++++++++++
 .../paramserv/dp/DCLocalScheme.java | 62 +++++++
 .../paramserv/dp/DCSparkScheme.java | 47 ++++++
 .../paramserv/dp/DRLocalScheme.java | 63 +++++++
 .../paramserv/dp/DRRLocalScheme.java | 58 +++++++
 .../paramserv/dp/DRRSparkScheme.java | 45 +++++
 .../paramserv/dp/DRSparkScheme.java | 69 ++++++++
 .../paramserv/dp/DataPartitionLocalScheme.java | 40 +++++
 .../paramserv/dp/DataPartitionSparkScheme.java | 76 +++++++++
 .../dp/DataPartitionerSparkAggregator.java | 66 ++++++++
 .../dp/DataPartitionerSparkMapper.java | 70 ++++++++
 .../paramserv/dp/LocalDataPartitioner.java | 52 ++++++
 .../paramserv/dp/ORLocalScheme.java | 62 +++++++
 .../paramserv/dp/ORSparkScheme.java | 60 +++++++
 .../paramserv/dp/SparkDataPartitioner.java | 106 ++++++++++++
 .../controlprogram/paramserv/rpc/PSRpcCall.java | 86 ++++++++++
 .../paramserv/rpc/PSRpcFactory.java | 61 +++++++
 .../paramserv/rpc/PSRpcHandler.java | 95 +++++++++++
 .../paramserv/rpc/PSRpcObject.java | 107 ++++++++++++
 .../paramserv/rpc/PSRpcResponse.java | 101 +++++++++++
 .../paramserv/spark/DCSparkScheme.java | 47 ------
 .../paramserv/spark/DRRSparkScheme.java | 45 -----
 .../paramserv/spark/DRSparkScheme.java | 69 --------
 .../spark/DataPartitionSparkScheme.java | 76 ---------
 .../spark/DataPartitionerSparkAggregator.java | 66 --------
 .../spark/DataPartitionerSparkMapper.java | 70 --------
 .../paramserv/spark/ORSparkScheme.java | 60 -------
 .../paramserv/spark/SparkDataPartitioner.java | 106 ------------
 .../paramserv/spark/SparkPSBody.java | 44 -----
 .../paramserv/spark/SparkPSProxy.java | 85 ----------
 .../paramserv/spark/SparkPSWorker.java | 168 -------------------
 .../paramserv/spark/rpc/PSRpcCall.java | 86 ----------
 .../paramserv/spark/rpc/PSRpcFactory.java | 57 -------
 .../paramserv/spark/rpc/PSRpcHandler.java | 95 -----------
 .../paramserv/spark/rpc/PSRpcObject.java | 107 ------------
 .../paramserv/spark/rpc/PSRpcResponse.java | 101 -----------
 .../cp/ParamservBuiltinCPInstruction.java | 16 +-
 .../sysml/runtime/util/ProgramConverter.java | 2 +-
 .../paramserv/BaseDataPartitionerTest.java | 20 +--
 .../paramserv/LocalDataPartitionerTest.java | 14 +-
 .../functions/paramserv/RpcObjectTest.java | 6 +-
 .../paramserv/SparkDataPartitionerTest.java | 12 +-
 .../paramserv/mnist_lenet_paramserv.dml | 118 +++++++------
 .../mnist_lenet_paramserv_minimum_version.dml | 114 +++++++------
 57 files changed, 1851 insertions(+), 1828 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java
deleted file mode 100644
index 00aaa21..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
-
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-/**
- * Disjoint_Contiguous data partitioner:
- *
- * for each worker, use a right indexing
- * operation X[beg:end,] to obtain contiguous,
- * non-overlapping partitions of rows.
- */ -public class DCScheme extends DataPartitionScheme { - - public static List<MatrixBlock> partition(int k, MatrixBlock mb) { - List<MatrixBlock> list = new ArrayList<>(); - long stepSize = (long) Math.ceil((double) mb.getNumRows() / k); - long begin = 1; - while (begin < mb.getNumRows()) { - long end = Math.min(begin - 1 + stepSize, mb.getNumRows()); - MatrixBlock pmo = ParamservUtils.sliceMatrixBlock(mb, begin, end); - list.add(pmo); - begin = end + 1; - } - return list; - } - - private List<MatrixObject> doPartitioning(int k, MatrixBlock mb) { - return partition(k, mb).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList()); - } - - @Override - public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { - List<MatrixObject> pfs = doPartitioning(workersNum, features); - List<MatrixObject> pls = doPartitioning(workersNum, labels); - return new Result(pfs, pls); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java deleted file mode 100644 index 90c62d6..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import java.util.stream.LongStream; - -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.util.DataConverter; - -/** - * Disjoint_Round_Robin data partitioner: - * for each worker, use a permutation multiply - * or simpler a removeEmpty such as removeEmpty - * (target=X, margin=rows, select=(seq(1,nrow(X))%%k)==id) - */ -public class DRRScheme extends DataPartitionScheme { - - public static MatrixBlock removeEmpty(MatrixBlock mb, int k, int workerId) { - double[] data = LongStream.range(0, mb.getNumRows()).mapToDouble(l -> l % k == workerId ? 
1 : 0).toArray(); - MatrixBlock select = DataConverter.convertToMatrixBlock(data, true); - return mb.removeEmptyOperations(new MatrixBlock(), true, true, select); - } - - private MatrixObject internalRemoveEmpty(MatrixBlock mb, int k, int workerId) { - MatrixObject result = ParamservUtils.newMatrixObject(removeEmpty(mb, k, workerId)); - result.enableCleanup(false); - return result; - } - - @Override - public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { - List<MatrixObject> pfs = IntStream.range(0, workersNum).mapToObj(i -> internalRemoveEmpty(features, workersNum, i)).collect(Collectors.toList()); - List<MatrixObject> pls = IntStream.range(0, workersNum).mapToObj(i -> internalRemoveEmpty(labels, workersNum, i)).collect(Collectors.toList()); - return new Result(pfs, pls); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java deleted file mode 100644 index 062a7ab..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv; - -import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -/** - * Data partitioner Disjoint_Random: - * for each worker, use a permutation multiply P[beg:end,] %*% X, - * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))), - * i.e., sampling without replacement to ensure disjointness. 
- */ -public class DRScheme extends DataPartitionScheme { - - private List<MatrixBlock> partition(int k, MatrixBlock mb, MatrixBlock permutation) { - int batchSize = (int) Math.ceil((double) mb.getNumRows() / k); - return IntStream.range(0, k).mapToObj(i -> { - int begin = i * batchSize; - int end = Math.min((i + 1) * batchSize, mb.getNumRows()); - MatrixBlock slicedPerm = permutation.slice(begin, end - 1); - return slicedPerm.aggregateBinaryOperations(slicedPerm, mb, new MatrixBlock(), InstructionUtils.getMatMultOperator(k)); - }).collect(Collectors.toList()); - } - - private List<MatrixObject> internalDoPartitioning(int k, MatrixBlock mb, MatrixBlock permutation) { - return partition(k, mb, permutation).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList()); - } - - @Override - public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { - // Generate a single permutation matrix (workers use slices) - MatrixBlock permutation = ParamservUtils.generatePermutation(features.getNumRows(), SEED); - List<MatrixObject> pfs = internalDoPartitioning(workersNum, features, permutation); - List<MatrixObject> pls = internalDoPartitioning(workersNum, labels, permutation); - return new Result(pfs, pls); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java deleted file mode 100644 index f2ea0aa..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysml.runtime.controlprogram.paramserv; - -import java.util.List; - -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -public abstract class DataPartitionScheme { - - public final class Result { - public final List<MatrixObject> pFeatures; - public final List<MatrixObject> pLabels; - - public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels) { - this.pFeatures = pFeatures; - this.pLabels = pLabels; - } - } - - public abstract Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels); -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java deleted file mode 100644 index 3f28cd1..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysml.runtime.controlprogram.paramserv; - -import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -public class DataPartitioner { - - private DataPartitionScheme _scheme; - - public DataPartitioner(Statement.PSScheme scheme) { - switch (scheme) { - case DISJOINT_CONTIGUOUS: - _scheme = new DCScheme(); - break; - case DISJOINT_ROUND_ROBIN: - _scheme = new DRRScheme(); - break; - case DISJOINT_RANDOM: - _scheme = new DRScheme(); - break; - case OVERLAP_RESHUFFLE: - _scheme = new ORScheme(); - break; - } - } - - public DataPartitionScheme.Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { - return _scheme.doPartitioning(workersNum, features, labels); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index f76fddb..04050b2 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -54,15 +54,17 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { incWorkerNumber(); try { long dataSize = _features.getNumRows(); - int totalIter = (int) Math.ceil((double) dataSize / _batchSize); + int batchIter = (int) Math.ceil((double) dataSize / _batchSize); switch (_freq) { case BATCH: - computeBatch(dataSize, totalIter); + computeBatch(dataSize, batchIter); break; case EPOCH: - computeEpoch(dataSize, totalIter); + computeEpoch(dataSize, batchIter); break; + default: + throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq)); } if (LOG.isDebugEnabled()) { @@ -74,25 +76,23 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { return null; } - private void computeEpoch(long dataSize, int totalIter) { + private void computeEpoch(long dataSize, int batchIter) { for (int i = 0; i < _epochs; i++) { // Pull the global parameters from ps ListObject params = pullModel(); ListObject accGradients = null; - for (int j = 0; j < totalIter; j++) { - _ec.setVariable(Statement.PS_MODEL, params); - - ListObject gradients = computeGradients(dataSize, totalIter, i, j); + for (int j = 0; j < batchIter; j++) { + ListObject gradients = computeGradients(params, dataSize, batchIter, i, j); + boolean localUpdate = j < batchIter - 1; // Accumulate the intermediate gradients - accGradients = ParamservUtils.accrueGradients(accGradients, gradients); + accGradients = ParamservUtils.accrueGradients(accGradients, gradients, !localUpdate); // Update the local model with gradients - if( j < totalIter - 1 ) - params = updateModel(params, gradients, i, j, totalIter); - ParamservUtils.cleanupListObject(_ec, gradients); - + if(localUpdate) + params = updateModel(params, gradients, i, j, batchIter); + accNumBatches(1); } @@ -107,7 +107,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { } } - private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) { + private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int batchIter) { Timing tUpd = DMLScript.STATISTICS ? 
new Timing(true) : null; globalParams = _ps.updateLocalModel(_ec, gradients, globalParams); @@ -117,7 +117,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { if (LOG.isDebugEnabled()) { LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. " + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", - getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter)); + getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter)); } return globalParams; } @@ -127,8 +127,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { for (int j = 0; j < totalIter; j++) { ListObject globalParams = pullModel(); - _ec.setVariable(Statement.PS_MODEL, globalParams); - ListObject gradients = computeGradients(dataSize, totalIter, i, j); + ListObject gradients = computeGradients(globalParams, dataSize, totalIter, i, j); // Push the gradients to ps pushGradients(gradients); @@ -163,7 +162,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { } } - private ListObject computeGradients(long dataSize, int totalIter, int i, int j) { + private ListObject computeGradients(ListObject params, long dataSize, int batchIter, int i, int j) { + _ec.setVariable(Statement.PS_MODEL, params); long begin = j * _batchSize + 1; long end = Math.min((j + 1) * _batchSize, dataSize); @@ -180,7 +180,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. " + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", getWorkerName(), bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs, - j + 1, totalIter)); + j + 1, batchIter)); } // Invoke the update function @@ -189,7 +189,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { accGradientComputeTime(tGrad); // Get the gradients - ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); + ListObject gradients = _ec.getListObject(_output.getName()); ParamservUtils.cleanupData(_ec, Statement.PS_FEATURES); ParamservUtils.cleanupData(_ec, Statement.PS_LABELS); http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java index 0c73acb..a2904fe 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -30,7 +30,11 @@ public class LocalParamServer extends ParamServer { super(); } - public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { + public static LocalParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { + return new LocalParamServer(model, aggFunc, updateType, ec, workerNum); + } + + private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { super(model, aggFunc, updateType, ec, workerNum); } 
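
For reference, the LocalPSWorker hunks above keep the 1-based, inclusive batch indexing derived from dataSize and _batchSize. The following standalone sketch is not part of the commit; the row count and batch size are assumed example values, used only to show how batchIter and the per-iteration [begin, end] row ranges are computed:

// Standalone sketch of LocalPSWorker's per-batch row ranges (1-based, inclusive).
// dataSize and batchSize are made-up example values, not taken from the commit.
public class BatchRangeSketch {
    public static void main(String[] args) {
        long dataSize = 10;  // assumed number of feature rows
        long batchSize = 4;  // assumed mini-batch size
        int batchIter = (int) Math.ceil((double) dataSize / batchSize);
        for (int j = 0; j < batchIter; j++) {
            long begin = j * batchSize + 1;                      // first row of batch j
            long end = Math.min((j + 1) * batchSize, dataSize);  // last row of batch j
            System.out.println("batch " + (j + 1) + "/" + batchIter + ": rows [" + begin + ", " + end + "]");
        }
    }
}
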
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java deleted file mode 100644 index 2692efa..0000000 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.runtime.controlprogram.paramserv; - -import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; - -/** - * Data partitioner Overlap_Reshuffle: - * for each worker, use a new permutation multiply P %*% X, - * where P is constructed for example with P=table(seq(1,nrow(X),sample(nrow(X), nrow(X)))) - */ -public class ORScheme extends DataPartitionScheme { - - public static List<MatrixBlock> partition(int k, MatrixBlock mb, List<MatrixBlock> permutations) { - return IntStream.range(0, k).mapToObj(i -> { - MatrixBlock permutation = permutations.get(i); - return permutation.aggregateBinaryOperations(permutation, mb, new MatrixBlock(), - InstructionUtils.getMatMultOperator(k)); - }).collect(Collectors.toList()); - } - - private List<MatrixObject> doPartitioning(int k, MatrixBlock mb, List<MatrixBlock> permutations) { - return partition(k, mb, permutations).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList()); - } - - @Override - public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { - // Generate a different permutation matrix for each worker - List<MatrixBlock> permutations = IntStream.range(0, workersNum) - .mapToObj(i -> ParamservUtils.generatePermutation(features.getNumRows(), SEED+i)) - .collect(Collectors.toList()); - List<MatrixObject> pfs = doPartitioning(workersNum, features, permutations); - List<MatrixObject> pls = doPartitioning(workersNum, labels, permutations); - return new Result(pfs, pls); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index 5f2d552..63600d1 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -36,6 +36,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; +// TODO use the validate features and labels to calculate the model precision when training public abstract class PSWorker implements Serializable { private static final long serialVersionUID = -3510485051178200118L; http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index 0f5f70d..2b2249e 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -81,15 +81,10 @@ public abstract class ParamServer setupAggFunc(_ec, aggFunc); // broadcast initial model - try { - broadcastModel(); - } - catch (InterruptedException e) { - throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e); - } + broadcastModel(true); } - public void setupAggFunc(ExecutionContext ec, String aggFunc) { + protected void setupAggFunc(ExecutionContext ec, String aggFunc) { String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX); String ns = cfn[0]; String fname = cfn[1]; @@ -140,11 +135,9 @@ public abstract class ParamServer // Accumulate the intermediate gradients if( ACCRUE_BSP_GRADIENTS ) - _accGradients = ParamservUtils.accrueGradients( - _accGradients, gradients, true); + _accGradients = ParamservUtils.accrueGradients(_accGradients, gradients, true); else updateGlobalModel(gradients); - ParamservUtils.cleanupListObject(_ec, gradients); if (allFinished()) { // Update the global model with accrued gradients @@ -155,7 +148,7 @@ public abstract class ParamServer // Broadcast the updated model resetFinishedStates(); - broadcastModel(); + broadcastModel(true); if (LOG.isDebugEnabled()) LOG.debug("Global parameter is broadcasted successfully."); } @@ -199,7 +192,7 @@ public abstract class ParamServer _inst.processInstruction(ec); // Get the new model - ListObject newModel = (ListObject) ec.getVariable(_outputName); + ListObject newModel = ec.getListObject(_outputName); // Clean up the list according to the data referencing status ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, newModel.getStatus()); @@ -218,23 +211,26 @@ public abstract class ParamServer private void setFinishedState(int workerID) { _finishedStates[workerID] = true; } - - private void broadcastModel() throws InterruptedException { - Timing tBroad = DMLScript.STATISTICS ? 
new Timing(true) : null; - - //broadcast copy of the model to all workers, cleaned up by workers - for (BlockingQueue<ListObject> q : _modelMap.values()) - q.put(ParamservUtils.copyList(_model)); - if (DMLScript.STATISTICS) - Statistics.accPSModelBroadcastTime((long) tBroad.stop()); + /** + * Broadcast the model for all workers + */ + private void broadcastModel(boolean par) { + IntStream stream = IntStream.range(0, _modelMap.size()); + (par ? stream.parallel() : stream).forEach(workerID -> { + try { + broadcastModel(workerID); + } catch (InterruptedException e) { + throw new DMLRuntimeException("Paramserv func: some error occurred when broadcasting model", e); + } + }); } private void broadcastModel(int workerID) throws InterruptedException { Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; //broadcast copy of model to specific worker, cleaned up by worker - _modelMap.get(workerID).put(ParamservUtils.copyList(_model)); + _modelMap.get(workerID).put(ParamservUtils.copyList(_model, false)); if (DMLScript.STATISTICS) Statistics.accPSModelBroadcastTime((long) tBroad.stop()); http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index e9292d1..f8b5dda 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -58,8 +58,8 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator; -import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.Data; @@ -86,12 +86,10 @@ public class ParamservUtils { * Deep copy the list object * * @param lo list object + * @param cleanup clean up the given list object * @return a new copied list object */ - public static ListObject copyList(ListObject lo) { - if (lo.getLength() == 0) { - return lo; - } + public static ListObject copyList(ListObject lo, boolean cleanup) { List<Data> newData = IntStream.range(0, lo.getLength()).mapToObj(i -> { Data oldData = lo.slice(i); if (oldData instanceof MatrixObject) @@ -101,7 +99,10 @@ public class ParamservUtils { else return oldData; }).collect(Collectors.toList()); - return new ListObject(newData, lo.getNames()); + ListObject result = new ListObject(newData, lo.getNames()); + if (cleanup) + ParamservUtils.cleanupListObject(lo); + return result; } /** @@ -197,6 +198,12 @@ public class ParamservUtils { return mb.slice((int) rl - 1, (int) rh - 1); } + /** + * Generate the permutation + * @param numEntries permutation size + * @param 
seed seed used to generate random number + * @return permutation matrix + */ public static MatrixBlock generatePermutation(int numEntries, long seed) { // Create a sequence and sample w/o replacement // (no need to materialize the sequence because ctable only uses its meta data) @@ -208,6 +215,12 @@ public class ParamservUtils { new MatrixBlock(numEntries, numEntries, true)); } + /** + * Get the namespace and function name of a given physical func name + * @param funcName physical func name (e.g., "ns:func") + * @param prefix prefix + * @return an string array of size 2 where array[0] is namespace and array[1] is name + */ public static String[] getCompleteFuncName(String funcName, String prefix) { String[] keys = DMLProgram.splitFunctionKey(funcName); String ns = (keys.length==2) ? keys[0] : null; @@ -373,9 +386,9 @@ public class ParamservUtils { Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null; // Get input RDD JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>) - sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo); + sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo); JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>) - sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo); + sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo); DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows()); JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils @@ -408,21 +421,38 @@ public class ParamservUtils { return result; } - public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) { - return accrueGradients(accGradients, gradients, false); + /** + * Accumulate the given gradients into the accrued gradients + * + * @param accGradients accrued gradients list object + * @param gradients given gradients list object + * @param cleanup clean up the given gradients list object + * @return new accrued gradients list object + */ + public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean cleanup) { + return accrueGradients(accGradients, gradients, false, cleanup); } - - public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par) { + + /** + * Accumulate the given gradients into the accrued gradients + * + * @param accGradients accrued gradients list object + * @param gradients given gradients list object + * @param par parallel execution + * @param cleanup clean up the given gradients list object + * @return new accrued gradients list object + */ + public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par, boolean cleanup) { if (accGradients == null) - return ParamservUtils.copyList(gradients); + return ParamservUtils.copyList(gradients, cleanup); IntStream range = IntStream.range(0, accGradients.getLength()); (par ? 
range.parallel() : range).forEach(i -> { - MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead(); - MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead(); + MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireReadAndRelease(); + MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireReadAndRelease(); mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2); - ((MatrixObject) accGradients.getData().get(i)).release(); - ((MatrixObject) gradients.getData().get(i)).release(); }); + if (cleanup) + ParamservUtils.cleanupListObject(gradients); return accGradients; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java new file mode 100644 index 0000000..58690a6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; + +/** + * Wrapper class containing all needed for launching spark remote worker + */ +public class SparkPSBody { + + private ExecutionContext _ec; + + public SparkPSBody() {} + + public SparkPSBody(ExecutionContext ec) { + _ec = ec; + } + + public ExecutionContext getEc() { + return _ec; + } + + public void setEc(ExecutionContext ec) { + this._ec = ec; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java new file mode 100644 index 0000000..fd88b83 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcObject.PULL; +import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcObject.PUSH; + +import java.io.IOException; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.util.LongAccumulator; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall; +import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public class SparkPSProxy extends ParamServer { + + private final TransportClient _client; + private final long _rpcTimeout; + private final LongAccumulator _aRPC; + + public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) { + super(); + _client = client; + _rpcTimeout = rpcTimeout; + _aRPC = aRPC; + } + + private void accRpcRequestTime(Timing tRpc) { + if (DMLScript.STATISTICS) + _aRPC.add((long) tRpc.stop()); + } + + @Override + public void push(int workerID, ListObject value) { + Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; + PSRpcResponse response; + try { + response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout)); + } catch (IOException e) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e); + } + accRpcRequestTime(tRpc); + if (!response.isSuccessful()) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage())); + } + } + + @Override + public ListObject pull(int workerID) { + Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null; + PSRpcResponse response; + try { + response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout)); + } catch (IOException e) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e); + } + accRpcRequestTime(tRpc); + if (!response.isSuccessful()) { + throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. 
\n%s", workerID, response.getErrorMessage())); + } + return response.getResultModel(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java new file mode 100644 index 0000000..bc8fc9e --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.util.LongAccumulator; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.codegen.CodegenUtils; +import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory; +import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.util.ProgramConverter; + +import scala.Tuple2; + +public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> { + + private static final long serialVersionUID = -8674739573419648732L; + + private final String _program; + private final HashMap<String, byte[]> _clsMap; + private final SparkConf _conf; + private final int _port; // rpc port + private final String _aggFunc; + private final LongAccumulator _aSetup; // accumulator for setup time + private final LongAccumulator _aWorker; // accumulator for worker number + private final LongAccumulator _aUpdate; // accumulator for model update + private final LongAccumulator _aIndex; // accumulator for batch indexing + private final LongAccumulator _aGrad; // accumulator for gradients computing + private final LongAccumulator _aRPC; // accumulator for rpc request + private final LongAccumulator _nBatches; //number of executed batches + private final LongAccumulator _nEpochs; //number of executed epoches + + public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, 
LongAccumulator aEpochs) { + _updFunc = updFunc; + _aggFunc = aggFunc; + _freq = freq; + _epochs = epochs; + _batchSize = batchSize; + _program = program; + _clsMap = clsMap; + _conf = conf; + _port = port; + _aSetup = aSetup; + _aWorker = aWorker; + _aUpdate = aUpdate; + _aIndex = aIndex; + _aGrad = aGrad; + _aRPC = aRPC; + _nBatches = aBatches; + _nEpochs = aEpochs; + } + + @Override + public String getWorkerName() { + return String.format("Spark worker_%d", _workerID); + } + + @Override + public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception { + Timing tSetup = new Timing(true); + configureWorker(input); + accSetupTime(tSetup); + + call(); // Launch the worker + } + + private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException { + _workerID = input._1; + + // Initialize codegen class cache (before program parsing) + for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) { + CodegenUtils.getClassSync(e.getKey(), e.getValue()); + } + + // Deserialize the body to initialize the execution context + SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID); + _ec = body.getEc(); + + // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end + RemoteParForUtils.setupBufferPool(_workerID); + + // Create the ps proxy + _ps = PSRpcFactory.createSparkPSProxy(_conf, _port, _aRPC); + + // Initialize the update function + setupUpdateFunction(_updFunc, _ec); + + // Initialize the agg function + _ps.setupAggFunc(_ec, _aggFunc); + + // Lazy initialize the matrix of features and labels + setFeatures(ParamservUtils.newMatrixObject(input._2._1, false)); + setLabels(ParamservUtils.newMatrixObject(input._2._2, false)); + } + + + @Override + protected void incWorkerNumber() { + _aWorker.add(1); + } + + @Override + protected void accLocalModelUpdateTime(Timing time) { + if( time != null ) + _aUpdate.add((long) time.stop()); + } + + @Override + protected void accBatchIndexingTime(Timing time) { + if( time != null ) + _aIndex.add((long) time.stop()); + } + + @Override + protected void accGradientComputeTime(Timing time) { + if( time != null ) + _aGrad.add((long) time.stop()); + } + + @Override + protected void accNumEpochs(int n) { + _nEpochs.add(n); + } + + @Override + protected void accNumBatches(int n) { + _nBatches.add(n); + } + + private void accSetupTime(Timing time) { + if( time != null ) + _aSetup.add((long) time.stop()); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java new file mode 100644 index 0000000..9a3e502 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +/** + * Disjoint_Contiguous data partitioner: + * + * for each worker, use a right indexing + * operation X[beg:end,] to obtain contiguous, + * non-overlapping partitions of rows. + */ +public class DCLocalScheme extends DataPartitionLocalScheme { + + public static List<MatrixBlock> partition(int k, MatrixBlock mb) { + List<MatrixBlock> list = new ArrayList<>(); + long stepSize = (long) Math.ceil((double) mb.getNumRows() / k); + long begin = 1; + while (begin < mb.getNumRows()) { + long end = Math.min(begin - 1 + stepSize, mb.getNumRows()); + MatrixBlock pmo = ParamservUtils.sliceMatrixBlock(mb, begin, end); + list.add(pmo); + begin = end + 1; + } + return list; + } + + private List<MatrixObject> doPartitioning(int k, MatrixBlock mb) { + return partition(k, mb).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList()); + } + + @Override + public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { + List<MatrixObject> pfs = doPartitioning(workersNum, features); + List<MatrixObject> pls = doPartitioning(workersNum, labels); + return new Result(pfs, pls); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java new file mode 100644 index 0000000..f42e0b6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.List; + +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +/** + * Spark Disjoint_Contiguous data partitioner: + * <p> + * For each row, find out the shifted place according to the workerID indicator + */ +public class DCSparkScheme extends DataPartitionSparkScheme { + + private static final long serialVersionUID = -2786906947020788787L; + + protected DCSparkScheme() { + // No-args constructor used for deserialization + } + + @Override + public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features); + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels); + return new Result(pfs, pls); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java new file mode 100644 index 0000000..464be99 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +/** + * Data partitioner Disjoint_Random: + * for each worker, use a permutation multiply P[beg:end,] %*% X, + * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))), + * i.e., sampling without replacement to ensure disjointness. 
+ */ +public class DRLocalScheme extends DataPartitionLocalScheme { + + private List<MatrixBlock> partition(int k, MatrixBlock mb, MatrixBlock permutation) { + int batchSize = (int) Math.ceil((double) mb.getNumRows() / k); + return IntStream.range(0, k).mapToObj(i -> { + int begin = i * batchSize; + int end = Math.min((i + 1) * batchSize, mb.getNumRows()); + MatrixBlock slicedPerm = permutation.slice(begin, end - 1); + return slicedPerm.aggregateBinaryOperations(slicedPerm, mb, new MatrixBlock(), InstructionUtils.getMatMultOperator(k)); + }).collect(Collectors.toList()); + } + + private List<MatrixObject> internalDoPartitioning(int k, MatrixBlock mb, MatrixBlock permutation) { + return partition(k, mb, permutation).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList()); + } + + @Override + public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { + // Generate a single permutation matrix (workers use slices) + MatrixBlock permutation = ParamservUtils.generatePermutation(features.getNumRows(), SEED); + List<MatrixObject> pfs = internalDoPartitioning(workersNum, features, permutation); + List<MatrixObject> pls = internalDoPartitioning(workersNum, labels, permutation); + return new Result(pfs, pls); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java new file mode 100644 index 0000000..2061903 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.util.DataConverter; + +/** + * Disjoint_Round_Robin data partitioner: + * for each worker, use a permutation multiply or, more simply, + * a removeEmpty operation such as + * removeEmpty(target=X, margin=rows, select=(seq(1,nrow(X))%%k)==id) + */ +public class DRRLocalScheme extends DataPartitionLocalScheme { + + public static MatrixBlock removeEmpty(MatrixBlock mb, int k, int workerId) { + double[] data = LongStream.range(0, mb.getNumRows()).mapToDouble(l -> l % k == workerId ? 1 : 0).toArray(); + MatrixBlock select = DataConverter.convertToMatrixBlock(data, true); + return mb.removeEmptyOperations(new MatrixBlock(), true, true, select); + } + + private MatrixObject internalRemoveEmpty(MatrixBlock mb, int k, int workerId) { + MatrixObject result = ParamservUtils.newMatrixObject(removeEmpty(mb, k, workerId)); + result.enableCleanup(false); + return result; + } + + @Override + public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { + List<MatrixObject> pfs = IntStream.range(0, workersNum).mapToObj(i -> internalRemoveEmpty(features, workersNum, i)).collect(Collectors.toList()); + List<MatrixObject> pls = IntStream.range(0, workersNum).mapToObj(i -> internalRemoveEmpty(labels, workersNum, i)).collect(Collectors.toList()); + return new Result(pfs, pls); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java new file mode 100644 index 0000000..025f774 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.List; + +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +/** + * Spark Disjoint_Round_Robin data partitioner: + * <p> + * Each row keeps its global (shifted) position and is routed to a worker according to the round-robin worker indicator, without shuffling + */ +public class DRRSparkScheme extends DataPartitionSparkScheme { + + private static final long serialVersionUID = -3130831851505549672L; + + protected DRRSparkScheme() { + // No-args constructor used for deserialization + } + + @Override + public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features); + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels); + return new Result(pfs, pls); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java new file mode 100644 index 0000000..df61af9 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +/** + * Spark data partitioner Disjoint_Random: + * + * For the current row block, compute the shifted (permuted) position of each row + * and emit it to its target worker, i.e., workerID => (shifted row ID, row matrix) + */ +public class DRSparkScheme extends DataPartitionSparkScheme { + + private static final long serialVersionUID = -7655310624144544544L; + + protected DRSparkScheme() { + // No-args constructor used for deserialization + } + + @Override + public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(rblkID, features); + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(rblkID, labels); + return new Result(pfs, pls); + } + + private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int rblkID, MatrixBlock mb) { + MatrixBlock partialPerm = _globalPerms.get(0).getBlock(rblkID, 1); + + // For each row, look up its shifted (permuted) position + return IntStream.range(0, mb.getNumRows()).mapToObj(r -> { + MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); + long shiftedPosition = (long) partialPerm.getValue(r, 0); + + // Get the shifted block and position + int shiftedBlkID = (int) (shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE + 1); + + MatrixBlock indicator = _workerIndicator.getBlock(shiftedBlkID, 1); + int workerID = (int) indicator.getValue((int) shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE, 0); + return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB)); + }).collect(Collectors.toList()); + } + +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java new file mode 100644 index 0000000..8d03345 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.List; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +public abstract class DataPartitionLocalScheme { + + public final class Result { + public final List<MatrixObject> pFeatures; + public final List<MatrixObject> pLabels; + + public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels) { + this.pFeatures = pFeatures; + this.pLabels = pLabels; + } + } + + public abstract Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels); +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java new file mode 100644 index 0000000..7992ac8 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.io.Serializable; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.LongStream; + +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +public abstract class DataPartitionSparkScheme implements Serializable { + + protected final class Result { + protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures; // WorkerID => (rowID, matrix) + protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels; + + protected Result(List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures, List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels) { + this.pFeatures = pFeatures; + this.pLabels = pLabels; + } + } + + private static final long serialVersionUID = -3462829818083371171L; + + protected List<PartitionedBroadcast<MatrixBlock>> _globalPerms; // a list of global permutations + protected PartitionedBroadcast<MatrixBlock> _workerIndicator; // a matrix indicating to which worker the given row belongs + + protected void setGlobalPermutation(List<PartitionedBroadcast<MatrixBlock>> gps) { + _globalPerms = gps; + } + + protected void setWorkerIndicator(PartitionedBroadcast<MatrixBlock> wi) { + _workerIndicator = wi; + } + + /** + * Do non-reshuffled data partitioning according to worker indicator + * @param rblkID row block ID + * @param mb Matrix + * @return list of tuple (workerID, (row block ID, matrix row)) + */ + protected List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> nonShuffledPartition(int rblkID, MatrixBlock mb) { + MatrixBlock indicator = _workerIndicator.getBlock(rblkID, 1); + return LongStream.range(0, mb.getNumRows()).mapToObj(r -> { + int workerID = (int) indicator.getValue((int) r, 0); + MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); + long shiftedPosition = r + (rblkID - 1) * OptimizerUtils.DEFAULT_BLOCKSIZE; + return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB)); + }).collect(Collectors.toList()); + } + + public abstract Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels); +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java new file mode 100644 index 0000000..0314ccf --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.io.Serializable; +import java.util.LinkedList; + +import org.apache.spark.api.java.function.PairFunction; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +public class DataPartitionerSparkAggregator implements PairFunction<Tuple2<Integer,LinkedList<Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>>, Integer, Tuple2<MatrixBlock, MatrixBlock>>, Serializable { + + private static final long serialVersionUID = -1245300852709085117L; + private long _fcol; + private long _lcol; + + public DataPartitionerSparkAggregator() { + + } + + public DataPartitionerSparkAggregator(long fcol, long lcol) { + _fcol = fcol; + _lcol = lcol; + } + + /** + * Combines the single-row feature and label blocks of one worker row-wise into a feature matrix and a label matrix + * @param input workerID => ordered list [(rowID, (features, labels))] + * @return workerID => (features, labels) + * @throws Exception if the row-wise aggregation fails + */ + @Override + public Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<Integer, LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> input) throws Exception { + MatrixBlock fmb = new MatrixBlock(input._2.size(), (int) _fcol, false); + MatrixBlock lmb = new MatrixBlock(input._2.size(), (int) _lcol, false); + + for (int i = 0; i < input._2.size(); i++) { + MatrixBlock tmpFMB = input._2.get(i)._2._1; + MatrixBlock tmpLMB = input._2.get(i)._2._2; + // Row-wise aggregation + fmb = fmb.leftIndexingOperations(tmpFMB, i, i, 0, (int) _fcol - 1, fmb, MatrixObject.UpdateType.INPLACE_PINNED); + lmb = lmb.leftIndexingOperations(tmpLMB, i, i, 0, (int) _lcol - 1, lmb, MatrixObject.UpdateType.INPLACE_PINNED); + } + return new Tuple2<>(input._1, new Tuple2<>(fmb, lmb)); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java new file mode 100644 index 0000000..bd30121 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +public class DataPartitionerSparkMapper implements PairFlatMapFunction<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>, Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable { + + private static final long serialVersionUID = 1710721606050403296L; + private int _workersNum; + + private SparkDataPartitioner _dp; + + protected DataPartitionerSparkMapper() { + // No-args constructor used for deserialization + } + + public DataPartitionerSparkMapper(Statement.PSScheme scheme, int workersNum, SparkExecutionContext sec, int numEntries) { + _workersNum = workersNum; + _dp = new SparkDataPartitioner(scheme, sec, numEntries, workersNum); + } + + /** + * Partitions the rows of the given row block across the workers according to the configured scheme + * @param input rowBlockID => (features, labels) + * @return workerID => (rowID, (single-row features, single-row labels)) + * @throws Exception if the data partitioning fails + */ + @Override + public Iterator<Tuple2<Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> call(Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>> input) + throws Exception { + List<Tuple2<Integer, Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>> partitions = new LinkedList<>(); + MatrixBlock features = input._2._1; + MatrixBlock labels = input._2._2; + DataPartitionSparkScheme.Result result = _dp.doPartitioning(_workersNum, features, labels, input._1); + for (int i = 0; i < result.pFeatures.size(); i++) { + Tuple2<Integer, Tuple2<Long, MatrixBlock>> ft = result.pFeatures.get(i); + Tuple2<Integer, Tuple2<Long, MatrixBlock>> lt = result.pLabels.get(i); + partitions.add(new Tuple2<>(ft._1, new Tuple2<>(ft._2._1, new Tuple2<>(ft._2._2, lt._2._2)))); + } + return partitions.iterator(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java new file mode 100644 index 0000000..68cf9b6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +public class LocalDataPartitioner { + + private DataPartitionLocalScheme _scheme; + + public LocalDataPartitioner(Statement.PSScheme scheme) { + switch (scheme) { + case DISJOINT_CONTIGUOUS: + _scheme = new DCLocalScheme(); + break; + case DISJOINT_ROUND_ROBIN: + _scheme = new DRRLocalScheme(); + break; + case DISJOINT_RANDOM: + _scheme = new DRLocalScheme(); + break; + case OVERLAP_RESHUFFLE: + _scheme = new ORLocalScheme(); + break; + default: + throw new DMLRuntimeException(String.format("LocalDataPartitioner: unsupported data partition scheme '%s'", scheme)); + } + } + + public DataPartitionLocalScheme.Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { + return _scheme.doPartitioning(workersNum, features, labels); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java new file mode 100644 index 0000000..b7d8b97 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +/** + * Data partitioner Overlap_Reshuffle: + * for each worker, use a new permutation multiply P %*% X, + * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))) + */ +public class ORLocalScheme extends DataPartitionLocalScheme { + + public static List<MatrixBlock> partition(int k, MatrixBlock mb, List<MatrixBlock> permutations) { + return IntStream.range(0, k).mapToObj(i -> { + MatrixBlock permutation = permutations.get(i); + return permutation.aggregateBinaryOperations(permutation, mb, new MatrixBlock(), + InstructionUtils.getMatMultOperator(k)); + }).collect(Collectors.toList()); + } + + private List<MatrixObject> doPartitioning(int k, MatrixBlock mb, List<MatrixBlock> permutations) { + return partition(k, mb, permutations).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList()); + } + + @Override + public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) { + // Generate a different permutation matrix for each worker + List<MatrixBlock> permutations = IntStream.range(0, workersNum) + .mapToObj(i -> ParamservUtils.generatePermutation(features.getNumRows(), SEED+i)) + .collect(Collectors.toList()); + List<MatrixObject> pfs = doPartitioning(workersNum, features, permutations); + List<MatrixObject> pls = doPartitioning(workersNum, labels, permutations); + return new Result(pfs, pls); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java new file mode 100644 index 0000000..08b49b0 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv.dp; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +/** + * Spark data partitioner Overlap_Reshuffle: + * + * Each worker receives all rows, with every row placed at the position given by that worker's own permutation + */ +public class ORSparkScheme extends DataPartitionSparkScheme { + + private static final long serialVersionUID = 6867567406403580311L; + + protected ORSparkScheme() { + // No-args constructor used for deserialization + } + + @Override + public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) { + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(numWorkers, rblkID, features); + List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(numWorkers, rblkID, labels); + return new Result(pfs, pls); + } + + private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int numWorkers, int rblkID, MatrixBlock mb) { + return IntStream.range(0, numWorkers).boxed().flatMap(workerID -> { + MatrixBlock partialPerm = _globalPerms.get(workerID).getBlock(rblkID, 1); + return IntStream.range(0, mb.getNumRows()).mapToObj(r -> { + MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); + long shiftedPosition = (long) partialPerm.getValue(r, 0); + return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB)); + }); + }).collect(Collectors.toList()); + } +}
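As an illustrative sketch (not part of this commit), the local data-partitioning path introduced above can be exercised roughly as follows; the toy 4x2 feature matrix, 4x1 label matrix, and the class name LocalPartitioningSketch are made up for the example, and DataConverter.convertToMatrixBlock(double[][]) is assumed for constructing the inputs:

import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.DataConverter;

public class LocalPartitioningSketch {
	public static void main(String[] args) {
		// Toy input (hypothetical values): 4 rows of features (2 columns) and labels (1 column)
		MatrixBlock features = DataConverter.convertToMatrixBlock(
			new double[][] {{1, 2}, {3, 4}, {5, 6}, {7, 8}});
		MatrixBlock labels = DataConverter.convertToMatrixBlock(
			new double[][] {{0}, {1}, {0}, {1}});

		// Round-robin partitioning into 2 workers: worker 0 gets rows 1 and 3,
		// worker 1 gets rows 2 and 4 (1-based), each wrapped in a MatrixObject
		LocalDataPartitioner partitioner =
			new LocalDataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
		DataPartitionLocalScheme.Result result =
			partitioner.doPartitioning(2, features, labels);

		System.out.println(result.pFeatures.size() + " feature partitions, "
			+ result.pLabels.size() + " label partitions");
	}
}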

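A similar sketch (again hypothetical, not part of this commit) for the per-worker aggregation contract of DataPartitionerSparkAggregator, which row-wise combines a worker's ordered single-row blocks into one feature matrix and one label matrix; the 1x2 feature rows, 1x1 label rows, and worker id 0 are made-up values, and the aggregator is invoked directly on the driver for illustration rather than inside a Spark job:

import java.util.LinkedList;

import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.DataConverter;

import scala.Tuple2;

public class AggregatorSketch {
	public static void main(String[] args) throws Exception {
		// Two single-row (features, labels) pairs for worker 0, ordered by row id
		LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>> rows = new LinkedList<>();
		rows.add(new Tuple2<>(0L, new Tuple2<>(
			DataConverter.convertToMatrixBlock(new double[][] {{1, 2}}),
			DataConverter.convertToMatrixBlock(new double[][] {{0}}))));
		rows.add(new Tuple2<>(1L, new Tuple2<>(
			DataConverter.convertToMatrixBlock(new double[][] {{3, 4}}),
			DataConverter.convertToMatrixBlock(new double[][] {{1}}))));

		// Combine the single-row blocks into one 2x2 feature block and one 2x1 label block
		DataPartitionerSparkAggregator agg = new DataPartitionerSparkAggregator(2, 1);
		Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> combined =
			agg.call(new Tuple2<>(0, rows));

		System.out.println("worker " + combined._1 + ": features "
			+ combined._2._1.getNumRows() + "x" + combined._2._1.getNumColumns()
			+ ", labels " + combined._2._2.getNumRows() + "x" + combined._2._2.getNumColumns());
	}
}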