Repository: systemml Updated Branches: refs/heads/master 52891d28e -> 69ef76c06
[SYSTEMML-2364,66-88] Extended paramserv data partitioning schemes Closes #781. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/69ef76c0 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/69ef76c0 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/69ef76c0 Branch: refs/heads/master Commit: 69ef76c06a6aea0c1e9b6750a37997a950a81794 Parents: 52891d2 Author: EdgarLGB <[email protected]> Authored: Sun Jun 10 22:11:41 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Jun 10 22:11:42 2018 -0700 ---------------------------------------------------------------------- .../paramserv/DataPartitioner.java | 47 +++++ .../paramserv/DataPartitionerDC.java | 55 ++++++ .../paramserv/DataPartitionerDR.java | 63 ++++++ .../paramserv/DataPartitionerDRR.java | 60 ++++++ .../paramserv/DataPartitionerOR.java | 60 ++++++ .../controlprogram/paramserv/LocalPSWorker.java | 10 +- .../paramserv/LocalParamServer.java | 3 +- .../controlprogram/paramserv/PSWorker.java | 25 +-- .../controlprogram/paramserv/ParamServer.java | 13 +- .../paramserv/ParamservUtils.java | 20 +- .../cp/ParamservBuiltinCPInstruction.java | 89 ++++----- .../paramserv/DataPartitionerTest.java | 192 +++++++++++++++++++ .../functions/paramserv/ParamservFuncTest.java | 156 --------------- .../functions/paramserv/ParamservNNTest.java | 94 +++++++++ .../paramserv/ParamservRecompilationTest.java | 52 +++++ .../paramserv/ParamservRuntimeNegativeTest.java | 67 +++++++ .../paramserv/ParamservSyntaxTest.java | 99 ++++++++++ .../paramserv/mnist_lenet_paramserv.dml | 5 +- .../paramserv/paramserv-nn-asp-batch.dml | 2 +- .../paramserv/paramserv-nn-asp-epoch.dml | 2 +- .../paramserv/paramserv-nn-bsp-batch-dc.dml | 52 +++++ .../paramserv/paramserv-nn-bsp-batch-dr.dml | 52 +++++ .../paramserv/paramserv-nn-bsp-batch-drr.dml | 52 +++++ .../paramserv/paramserv-nn-bsp-batch-or.dml | 52 +++++ .../paramserv/paramserv-nn-bsp-batch.dml | 
52 ----- .../paramserv/paramserv-nn-bsp-epoch.dml | 2 +- .../functions/paramserv/ZPackageSuite.java | 6 +- 27 files changed, 1085 insertions(+), 297 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java new file mode 100644 index 0000000..b94d765 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; + +public abstract class DataPartitioner { // Base class of the paramserv data partitioning schemes (DC, DR, DRR, OR). + + protected static final Log LOG = LogFactory.getLog(DataPartitioner.class.getName()); + + public abstract void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels); // Partitions features and labels according to the concrete scheme and hands the partitions to the workers. + + protected void setPartitionedData(List<LocalPSWorker> workers, List<MatrixObject> pfs, List<MatrixObject> pls) { // Assigns pfs.get(i)/pls.get(i) to worker i; assumes pls has as many entries as pfs (features and labels share row counts) - TODO confirm. + if (pfs.size() < workers.size()) { + if (LOG.isWarnEnabled()) { + LOG.warn(String.format("There is only %d batches of data but has %d workers. " + + "Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size())); + } + workers = workers.subList(0, pfs.size()); // NOTE(review): this only rebinds the local parameter; the caller's list still holds the surplus workers, which receive no features/labels - confirm callers handle this. + } + for (int i = 0; i < workers.size(); i++) { // partitions beyond workers.size() are not assigned to anyone + workers.get(i).setFeatures(pfs.get(i)); + workers.get(i).setLabels(pls.get(i)); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java new file mode 100644 index 0000000..0426855 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; + +/** + * Disjoint_Contiguous data partitioner: + * + * for each worker, use a right indexing + * operation X[beg:end,] to obtain contiguous, + * non-overlapping partitions of rows. + */ +public class DataPartitionerDC extends DataPartitioner { + @Override + public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { + int workerNum = workers.size(); + List<MatrixObject> pfs = doPartitioning(workerNum, features); + List<MatrixObject> pls = doPartitioning(workerNum, labels); + setPartitionedData(workers, pfs, pls); + } + + private List<MatrixObject> doPartitioning(int k, MatrixObject mo) { // Slices mo into at most k contiguous row ranges of ceil(nrow/k) rows each (the last one may be shorter). + List<MatrixObject> list = new ArrayList<>(); + long stepSize = (long) Math.ceil((double) mo.getNumRows() / k); // floating-point division: integer division truncated before ceil, producing more partitions than workers (trailing rows lost) or a zero step when nrow < k + long begin = 1; + while (begin <= mo.getNumRows()) { // '<=' so a trailing single-row partition (e.g. row 3 of 3 rows / 2 workers) is not dropped + long end = Math.min(begin - 1 + stepSize, mo.getNumRows()); + MatrixObject pmo = ParamservUtils.sliceMatrix(mo, begin, end); + list.add(pmo); + begin = end + 1; + } + return list; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java ---------------------------------------------------------------------- diff
--git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java new file mode 100644 index 0000000..ce5f71d --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +/** + * Data partitioner Disjoint_Random: + * for each worker, use a permutation multiply P[beg:end,] %*% X, + * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))), + * i.e., sampling without replacement to ensure disjointness.
+ */ +public class DataPartitionerDR extends DataPartitioner { + @Override + public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { + // Generate a single permutation matrix (workers use slices) + MatrixBlock permutation = ParamservUtils.generatePermutation((int)features.getNumRows()); + List<MatrixObject> pfs = doPartitioning(workers.size(), features, permutation); + List<MatrixObject> pls = doPartitioning(workers.size(), labels, permutation); + setPartitionedData(workers, pfs, pls); + } + + private List<MatrixObject> doPartitioning(int k, MatrixObject mo, MatrixBlock permutation) { // Returns P[beg:end,] %*% mo for consecutive row ranges of the permutation; at most k partitions. + MatrixBlock data = mo.acquireRead(); + int batchSize = (int) Math.ceil((double) mo.getNumRows() / k); // floating-point division: integer division truncated before ceil, dropping trailing rows (e.g. 10 rows / 3 workers) or giving batchSize 0 when nrow < k + List<MatrixObject> pMatrices = IntStream.range(0, k).filter(i -> (long) i * batchSize < mo.getNumRows()).mapToObj(i -> { // skip slices that would start past the last row; setPartitionedData handles fewer partitions than workers + int begin = i * batchSize + 1; + int end = (int) Math.min((i + 1) * batchSize, mo.getNumRows()); + MatrixBlock slicedPerm = permutation.slice(begin - 1, end - 1); + MatrixBlock output = slicedPerm.aggregateBinaryOperations(slicedPerm, + data, new MatrixBlock(), InstructionUtils.getMatMultOperator(k)); + MatrixObject result = ParamservUtils.newMatrixObject(); + result.acquireModify(output); + result.release(); + return result; + }).collect(Collectors.toList()); + mo.release(); + return pMatrices; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java new file mode 100644 index 0000000..9d2f666 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.util.DataConverter; + +/** + * Disjoint_Round_Robin data partitioner: + * for each worker, use a permutation multiply + * or simpler a removeEmpty such as removeEmpty + * (target=X, margin=rows, select=((seq(1,nrow(X))-1)%%k)==id), with 0-based worker ids + */ +public class DataPartitionerDRR extends DataPartitioner { + @Override + public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { // Builds exactly workers.size() partitions (one per worker id), so setPartitionedData never reduces the worker count here. + List<MatrixObject> pfs = IntStream.range(0, workers.size()) + .mapToObj(i -> removeEmpty(features, workers.size(), i)).collect(Collectors.toList()); + List<MatrixObject> pls = IntStream.range(0, workers.size()) + .mapToObj(i -> removeEmpty(labels, workers.size(), i)).collect(Collectors.toList()); + setPartitionedData(workers, pfs, pls); + } + + private MatrixObject removeEmpty(MatrixObject mo, int k, int workerId) { // Returns the rows of mo whose 0-based row index l satisfies l % k == workerId. + MatrixObject result = ParamservUtils.newMatrixObject(); + MatrixBlock tmp =
mo.acquireRead(); + double[] data = LongStream.range(0, mo.getNumRows()) + .mapToDouble(l -> l % k == workerId ? 1 : 0).toArray(); // selection vector: worker 0 keeps rows 1, k+1, 2k+1, ... (1-based) + MatrixBlock select = DataConverter.convertToMatrixBlock(data, true); + MatrixBlock resultMB = tmp.removeEmptyOperations(new MatrixBlock(), true, true, select); // NOTE(review): with emptyReturn=true, a worker that selects no rows (nrow < k) presumably receives a placeholder block rather than zero rows - confirm + mo.release(); + result.acquireModify(resultMB); + result.release(); + result.enableCleanup(false); + return result; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java new file mode 100644 index 0000000..fabbd74 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +/** + * Data partitioner Overlap_Reshuffle: + * for each worker, use a new permutation multiply P %*% X, + * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))) + */ +public class DataPartitionerOR extends DataPartitioner { + @Override + public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { + // Generate a different permutation matrix for each worker + List<MatrixBlock> permutation = IntStream.range(0, workers.size()).mapToObj(i -> + ParamservUtils.generatePermutation((int)features.getNumRows())).collect(Collectors.toList()); + List<MatrixObject> pfs = doPartitioning(workers.size(), features, permutation); + List<MatrixObject> pls = doPartitioning(workers.size(), labels, permutation); // same permutations as for features, so label rows stay aligned with feature rows + setPartitionedData(workers, pfs, pls); + } + + private List<MatrixObject> doPartitioning(int k, MatrixObject mo, List<MatrixBlock> lpermutation) { // Returns lpermutation.get(i) %*% mo for each worker i, i.e. one full reshuffled copy of mo per worker (partitions overlap). + MatrixBlock data = mo.acquireRead(); + List<MatrixObject> pMatrices = IntStream.range(0, k).mapToObj(i -> { + MatrixBlock permutation = lpermutation.get(i); + MatrixBlock output = permutation.aggregateBinaryOperations(permutation, + data, new MatrixBlock(), InstructionUtils.getMatMultOperator(k)); + MatrixObject result = ParamservUtils.newMatrixObject(); + result.acquireModify(output); + result.release(); + return result; + }).collect(Collectors.toList()); + mo.release(); + return pMatrices; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index 1583fbf..f9c2a00 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -34,8 +34,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName()); public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, - ExecutionContext ec, ParamServer ps) { - super(workerID, updFunc, freq, epochs, batchSize, ec, ps); + MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { + super(workerID, updFunc, freq, epochs, batchSize, valFeatures, valLabels, ec, ps); } @Override @@ -81,7 +81,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { globalParams = _ps.updateModel(gradients, globalParams); if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated.", - _workerID, globalParams.getDataSize())); + _workerID, globalParams.getDataSize())); } } } @@ -126,13 +126,13 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { _ps.push(_workerID, gradients); if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Successfully push the gradients " - + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); + + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024)); } } private ListObject computeGradients(long dataSize, int totalIter, int i, int j) { long begin = j * _batchSize + 1; - long end = Math.min(begin + _batchSize, dataSize); + long end = Math.min((j + 1) * _batchSize, dataSize); // Get batch 
features and labels MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java index bac507c..d20383d 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -29,8 +29,7 @@ import org.apache.sysml.runtime.instructions.cp.ListObject; public class LocalParamServer extends ParamServer { - public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, - int workerNum) { + public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { super(model, aggFunc, updateType, ec, workerNum); } http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index affa3c1..d94831e 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -51,13 +51,15 @@ public abstract class PSWorker { private final String _updFunc; protected final Statement.PSFrequency _freq; - protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, - int epochs, long batchSize, 
ExecutionContext ec, ParamServer ps) { + protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, + MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { _workerID = workerID; _updFunc = updFunc; _freq = freq; _epochs = epochs; _batchSize = batchSize; + _valFeatures = valFeatures; + _valLabels = valLabels; _ec = ec; _ps = ps; @@ -73,14 +75,13 @@ public abstract class PSWorker { ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); - _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, - "update function"); + .collect(Collectors.toCollection(ArrayList::new)); + _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, "update function"); // Check the inputs of the update function checkInput(false, inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES); @@ -118,11 +119,11 @@ public abstract class PSWorker { _labels = labels; } - public void setValFeatures(MatrixObject valFeatures) { - _valFeatures = valFeatures; + public MatrixObject getFeatures() { + return _features; } - public void setValLabels(MatrixObject valLabels) { - _valLabels = valLabels; + public MatrixObject getLabels() { + return _labels; } } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index d7cd78d..09b760f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -136,8 +136,7 @@ public abstract class ParamServer { // Check the output of the aggregation function if (outputs.size() != 1) { - throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", - aggFunc)); + throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", aggFunc)); } if (outputs.get(0).getDataType() != Expression.DataType.LIST) { throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", aggFunc)); @@ -145,12 +144,12 @@ public abstract class ParamServer { _output = outputs.get(0); CPOperand[] boundInputs = inputs.stream() - .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) - .toArray(CPOperand[]::new); + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) - .collect(Collectors.toCollection(ArrayList::new)); + .collect(Collectors.toCollection(ArrayList::new)); _inst = new FunctionCallCPInstruction(funcNS, funcName, 
boundInputs, inputNames, outputNames, "aggregate function"); } @@ -188,7 +187,7 @@ public abstract class ParamServer { } if (LOG.isDebugEnabled()) { LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", - grad._gradients.getDataSize() / 1024, grad._workerID)); + grad._gradients.getDataSize() / 1024, grad._workerID)); } // Update and redistribute the model http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index 37b184f..2b1dca5 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.parser.Expression; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; @@ -77,6 +78,10 @@ public class ParamservUtils { cd.clearData(); } + public static MatrixObject newMatrixObject() { + return new MatrixObject(Expression.ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(), new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); + } + /** * Slice the matrix * @@ -86,9 +91,7 @@ public class ParamservUtils { * @return new sliced matrix */ public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) { - MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, null, - new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1), - 
OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo)); + MatrixObject result = newMatrixObject(); MatrixBlock tmp = mo.acquireRead(); result.acquireModify(tmp.slice((int) rl - 1, (int) rh - 1)); mo.release(); @@ -96,4 +99,15 @@ public class ParamservUtils { result.enableCleanup(false); return result; } + + public static MatrixBlock generatePermutation(int numEntries) { + // Create a sequence and sample w/o replacement + MatrixBlock seq = MatrixBlock.seqOperations(1, numEntries, 1); + MatrixBlock sample = MatrixBlock.sampleOperations(numEntries, numEntries, false, -1); + + // Combine the sequence and sample as a table + MatrixBlock permutation = new MatrixBlock(numEntries, numEntries, true); + seq.ctableOperations(null, sample, 1.0, permutation); + return permutation; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 6e2b187..fd5bc56 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -72,10 +72,14 @@ import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDR; +import 
org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDRR; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerOR; import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.matrix.operators.Operator; @@ -127,16 +131,21 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers + MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES)); + MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); List<LocalPSWorker> workers = IntStream.range(0, workerNum) - .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), workerECs.get(i), ps)) + .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps)) .collect(Collectors.toList()); // Do data partition - doDataPartition(ec, workers); + PSScheme scheme = getScheme(); + doDataPartitioning(scheme, ec, workers); if (LOG.isDebugEnabled()) { - LOG.debug(String.format("\nConfiguration of paramserv func: \nmode: %s \nworkerNum: %d \nupdate frequency: %s \nstrategy: %s", - mode, workerNum, freq, updateType)); + LOG.debug(String.format("\nConfiguration of paramserv func: " + + "\nmode: %s \nworkerNum: %d \nupdate frequency: %s " + + "\nstrategy: %s \ndata partitioner: %s", + mode, workerNum, freq, updateType, scheme)); } // Launch the worker threads and wait for completion @@ -368,11 +377,26 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc return hyperparams; } - private void doDataPartition(ExecutionContext ec, List<LocalPSWorker> workers) { + private void doDataPartitioning(PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) { MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES)); MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS)); - MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES)); - MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); + switch (scheme) { + case DISJOINT_CONTIGUOUS: + doDataPartitioning(new DataPartitionerDC(), features, labels, workers); + break; + case DISJOINT_ROUND_ROBIN: + doDataPartitioning(new DataPartitionerDRR(), features, labels, workers); + break; + case DISJOINT_RANDOM: + doDataPartitioning(new DataPartitionerDR(), features, labels, workers); + break; + case OVERLAP_RESHUFFLE: + doDataPartitioning(new DataPartitionerOR(), features, labels, workers); + break; + } + } + + private PSScheme getScheme() { PSScheme scheme = DEFAULT_SCHEME; if (getParameterMap().containsKey(PS_SCHEME)) { try { @@ -381,53 +405,10 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc throw new DMLRuntimeException(String.format("Paramserv function: not support data partition scheme '%s'", getParam(PS_SCHEME))); } } - switch (scheme) { - case DISJOINT_CONTIGUOUS: - disjointContiguous(features, labels, valFeatures, valLabels, workers); - break; - case DISJOINT_RANDOM: - case OVERLAP_RESHUFFLE: - case DISJOINT_ROUND_ROBIN: - throw new DMLRuntimeException(String.format("Paramserv function: the scheme '%s' is not supported.", scheme)); - } + return scheme; } - private void disjointContiguous(MatrixObject features, MatrixObject labels, MatrixObject valFeatures, - MatrixObject valLabels, List<LocalPSWorker> workers) { - // training data - List<MatrixObject> pfs = disjointContiguous(workers.size(), features); - List<MatrixObject> pls = 
disjointContiguous(workers.size(), labels); - if (pfs.size() < workers.size()) { - if (LOG.isWarnEnabled()) { - LOG.warn(String.format("There is only %d batches of data but has %d workers. " - + "Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size())); - } - workers = workers.subList(0, pfs.size()); - } - for (int i = 0; i < workers.size(); i++) { - workers.get(i).setFeatures(pfs.get(i)); - workers.get(i).setLabels(pls.get(i)); - } - - // validation data - List<MatrixObject> pvfs = disjointContiguous(workers.size(), valFeatures); - List<MatrixObject> pvls = disjointContiguous(workers.size(), valLabels); - for (int i = 0; i < workers.size(); i++) { - workers.get(i).setValFeatures(pvfs.get(i)); - workers.get(i).setValLabels(pvls.get(i)); - } - } - - private List<MatrixObject> disjointContiguous(int workerNum, MatrixObject mo) { - List<MatrixObject> list = new ArrayList<>(); - long stepSize = (long) Math.ceil(mo.getNumRows() / workerNum); - long begin = 1; - while (begin < mo.getNumRows()) { - long end = Math.min(begin + stepSize, mo.getNumRows()); - MatrixObject pmo = ParamservUtils.sliceMatrix(mo, begin, end); - list.add(pmo); - begin = end + 1; - } - return list; + private void doDataPartitioning(DataPartitioner dp, MatrixObject features, MatrixObject labels, List<LocalPSWorker> workers) { + dp.doPartitioning(workers, features, labels); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java new file mode 100644 index 0000000..8cc7ab7 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java @@ -0,0 +1,192 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DataIdentifier; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysml.runtime.controlprogram.Program; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDRR; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.util.DataConverter; +import org.junit.Assert; 
+import org.junit.Test; + +//TODO test data partitioning on defined API not internal methods, +// potentially remove workers from API to make the data partitioner independent +//TODO test expected behavior not the internal implementation against itself +// (e.g., for DR check that each row is in at most one partition and all rows are distributed) + +public class DataPartitionerTest { + + @Test + public void testDataPartitionerDC() { + DataPartitioner dp = new DataPartitionerDC(); + List<LocalPSWorker> workers = IntStream.range(0, 2).mapToObj(i -> new LocalPSWorker(i, "updFunc", Statement.PSFrequency.BATCH, 1, 64, null, null, createMockExecutionContext(), null)).collect(Collectors.toList()); + double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + MatrixObject features = ParamservUtils.newMatrixObject(); + features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + features.refreshMetaData(); + features.release(); + MatrixObject labels = ParamservUtils.newMatrixObject(); + labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + labels.refreshMetaData(); + labels.release(); + dp.doPartitioning(workers, features, labels); + + double[] expected1 = new double[] { 1, 2, 3, 4, 5 }; + double[] realValue1 = workers.get(0).getFeatures().acquireRead().getDenseBlockValues(); + double[] realValue2 = workers.get(0).getLabels().acquireRead().getDenseBlockValues(); + Assert.assertArrayEquals(expected1, realValue1, 0); + Assert.assertArrayEquals(expected1, realValue2, 0); + + double[] expected2 = new double[] { 6, 7, 8, 9, 10 }; + double[] realValue3 = workers.get(1).getFeatures().acquireRead().getDenseBlockValues(); + double[] realValue4 = workers.get(1).getLabels().acquireRead().getDenseBlockValues(); + Assert.assertArrayEquals(expected2, realValue3, 0); + Assert.assertArrayEquals(expected2, realValue4, 0); + } + + @Test + public void testDataPartitionerDR() { +// DataPartitionerDR dp = new DataPartitionerDR(); +// double[] df = new double[] { 
1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; +// MatrixObject features = ParamservUtils.newMatrixObject(); +// features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); +// features.refreshMetaData(); +// features.release(); +// MatrixObject labels = ParamservUtils.newMatrixObject(); +// labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); +// labels.refreshMetaData(); +// labels.release(); +// +// MatrixBlock permutation = ParamservUtils.generatePermutation(df.length, df.length); +// +// List<MatrixObject> pfs = dp.doPartitioning(2, features, permutation); +// List<MatrixObject> pls = dp.doPartitioning(2, labels, permutation); +// +// double[] expected1 = IntStream.range(0, 5).mapToDouble(i -> permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray(); +// double[] realValue1 = pfs.get(0).acquireRead().getDenseBlockValues(); +// double[] realValue2 = pls.get(0).acquireRead().getDenseBlockValues(); +// Assert.assertArrayEquals(expected1, realValue1, 0); +// Assert.assertArrayEquals(expected1, realValue2, 0); +// +// double[] expected2 = IntStream.range(5, 10).mapToDouble(i -> permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray(); +// double[] realValue3 = pfs.get(1).acquireRead().getDenseBlockValues(); +// double[] realValue4 = pls.get(1).acquireRead().getDenseBlockValues(); +// Assert.assertArrayEquals(expected2, realValue3, 0); +// Assert.assertArrayEquals(expected2, realValue4, 0); + } + + @Test + public void testDataPartitionerDRR() { + DataPartitioner dp = new DataPartitionerDRR(); + List<LocalPSWorker> workers = IntStream.range(0, 2).mapToObj(i -> new LocalPSWorker(i, "updFunc", Statement.PSFrequency.BATCH, 1, 64, null, null, createMockExecutionContext(), null)).collect(Collectors.toList()); + double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + MatrixObject features = ParamservUtils.newMatrixObject(); + features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + features.refreshMetaData(); + 
features.release(); + MatrixObject labels = ParamservUtils.newMatrixObject(); + labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + labels.refreshMetaData(); + labels.release(); + dp.doPartitioning(workers, features, labels); + + //TODO test against four not two workers + double[] expected1 = new double[] { 1, 3, 5, 7, 9 }; + double[] realValue1 = workers.get(0).getFeatures().acquireRead().getDenseBlockValues(); + double[] realValue2 = workers.get(0).getLabels().acquireRead().getDenseBlockValues(); + Assert.assertArrayEquals(expected1, realValue1, 0); + Assert.assertArrayEquals(expected1, realValue2, 0); + + double[] expected2 = new double[] { 2, 4, 6, 8, 10 }; + double[] realValue3 = workers.get(1).getFeatures().acquireRead().getDenseBlockValues(); + double[] realValue4 = workers.get(1).getLabels().acquireRead().getDenseBlockValues(); + Assert.assertArrayEquals(expected2, realValue3, 0); + Assert.assertArrayEquals(expected2, realValue4, 0); + } + + @Test + public void testDataPartitionerOR() { +// DataPartitionerOR dp = new DataPartitionerOR(); +// double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; +// MatrixObject features = ParamservUtils.newMatrixObject(); +// features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); +// features.refreshMetaData(); +// features.release(); +// MatrixObject labels = ParamservUtils.newMatrixObject(); +// labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); +// labels.refreshMetaData(); +// labels.release(); +// +// MatrixBlock permutation = ParamservUtils.generatePermutation(df.length, df.length); +// +// List<MatrixObject> pfs = dp.doPartitioning(1, features, permutation); +// List<MatrixObject> pls = dp.doPartitioning(1, labels, permutation); +// +// double[] expected1 = IntStream.range(0, 10).mapToDouble(i -> permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray(); +// double[] realValue1 = pfs.get(0).acquireRead().getDenseBlockValues(); +// double[] realValue2 = 
pls.get(0).acquireRead().getDenseBlockValues(); +// Assert.assertArrayEquals(expected1, realValue1, 0); +// Assert.assertArrayEquals(expected1, realValue2, 0); + } + + private ExecutionContext createMockExecutionContext() { + Program prog = new Program(); + ArrayList<DataIdentifier> inputs = new ArrayList<>(); + DataIdentifier features = new DataIdentifier("features"); + features.setDataType(Expression.DataType.MATRIX); + features.setValueType(Expression.ValueType.DOUBLE); + inputs.add(features); + DataIdentifier labels = new DataIdentifier("labels"); + labels.setDataType(Expression.DataType.MATRIX); + labels.setValueType(Expression.ValueType.DOUBLE); + inputs.add(labels); + DataIdentifier model = new DataIdentifier("model"); + model.setDataType(Expression.DataType.LIST); + model.setValueType(Expression.ValueType.UNKNOWN); + inputs.add(model); + + ArrayList<DataIdentifier> outputs = new ArrayList<>(); + DataIdentifier gradients = new DataIdentifier("gradients"); + gradients.setDataType(Expression.DataType.LIST); + gradients.setValueType(Expression.ValueType.UNKNOWN); + outputs.add(gradients); + + FunctionProgramBlock fpb = new FunctionProgramBlock(prog, inputs, outputs); + prog.addProgramBlock(fpb); + prog.addFunctionProgramBlock(DMLProgram.DEFAULT_NAMESPACE, "updFunc", fpb); + return ExecutionContextFactory.createContext(prog); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java deleted file mode 100644 index 3185621..0000000 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software 
Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.test.integration.functions.paramserv; - -import org.apache.sysml.api.DMLException; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.apache.sysml.test.integration.TestConfiguration; -import org.junit.Test; - -public class ParamservFuncTest extends AutomatedTestBase { - - private static final String TEST_NAME1 = "paramserv-all-args"; - private static final String TEST_NAME2 = "paramserv-without-optional-args"; - private static final String TEST_NAME3 = "paramserv-miss-args"; - private static final String TEST_NAME4 = "paramserv-wrong-type-args"; - private static final String TEST_NAME5 = "paramserv-wrong-args"; - private static final String TEST_NAME6 = "paramserv-wrong-args2"; - private static final String TEST_NAME7 = "paramserv-nn-bsp-batch"; - private static final String TEST_NAME8 = "paramserv-minimum-version"; - private static final String TEST_NAME9 = "paramserv-worker-failed"; - private static final String TEST_NAME10 = "paramserv-agg-service-failed"; - private static final String TEST_NAME11 = "paramserv-large-parallelism"; - private static final String TEST_NAME12 = "paramserv-wrong-aggregate-func"; - private static final String TEST_NAME13 
= "paramserv-nn-asp-batch"; - private static final String TEST_NAME14 = "paramserv-nn-bsp-epoch"; - private static final String TEST_NAME15 = "paramserv-nn-asp-epoch"; - - private static final String TEST_DIR = "functions/paramserv/"; - private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/"; - - private final String HOME = SCRIPT_DIR + TEST_DIR; - - @Override - public void setUp() { - addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); - addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); - addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {})); - addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {})); - addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); - addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); - addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {})); - addTestConfiguration(TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {})); - addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {})); - addTestConfiguration(TEST_NAME11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME11, new String[] {})); - addTestConfiguration(TEST_NAME12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME12, new String[] {})); - addTestConfiguration(TEST_NAME13, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME13, new String[] {})); - addTestConfiguration(TEST_NAME14, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME14, new String[] {})); - addTestConfiguration(TEST_NAME15, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME15, new String[] 
{})); - } - - @Test - public void testParamservWithAllArgs() { - runDMLTest(TEST_NAME1, false, null, null); - } - - @Test - public void testParamservWithoutOptionalArgs() { - runDMLTest(TEST_NAME2, false, null, null); - } - - @Test - public void testParamservMissArgs() { - final String errmsg = "Named parameter 'features' missing. Please specify the input."; - runDMLTest(TEST_NAME3, true, DMLException.class, errmsg); - } - - @Test - public void testParamservWrongTypeArgs() { - final String errmsg = "Input to PARAMSERV::model must be of type 'LIST'. It should not be of type 'MATRIX'"; - runDMLTest(TEST_NAME4, true, DMLException.class, errmsg); - } - - @Test - public void testParamservWrongArgs() { - final String errmsg = "Paramserv function: not support update type 'NSP'."; - runDMLTest(TEST_NAME5, true, DMLException.class, errmsg); - } - - @Test - public void testParamservWrongArgs2() { - final String errmsg = "Invalid parameters for PARAMSERV: [modelList, val_featur=X_val]"; - runDMLTest(TEST_NAME6, true, DMLException.class, errmsg); - } - - @Test - public void testParamservNNBspBatchTest() { - runDMLTest(TEST_NAME7, false, null, null); - } - - @Test - public void testParamservMinimumVersionTest() { - runDMLTest(TEST_NAME8, false, null, null); - } - - @Test - public void testParamservWorkerFailedTest() { - runDMLTest(TEST_NAME9, true, DMLException.class, "Invalid lookup by name in unnamed list: worker_err."); - } - - @Test - public void testParamservAggServiceFailedTest() { - runDMLTest(TEST_NAME10, true, DMLException.class, "Invalid lookup by name in unnamed list: agg_service_err"); - } - - @Test - public void testParamservLargeParallelismTest() { - runDMLTest(TEST_NAME11, false, null, null); - } - - @Test - public void testParamservWrongAggregateFuncTest() { - runDMLTest(TEST_NAME12, true, DMLException.class, - "The 'gradients' function should provide an input of 'MATRIX' type named 'labels'."); - } - - @Test - public void testParamservASPTest() { - 
runDMLTest(TEST_NAME13, false, null, null); - } - - @Test - public void testParamservBSPEpochTest() { - runDMLTest(TEST_NAME14, false, null, null); - } - - @Test - public void testParamservASPEpochTest() { - runDMLTest(TEST_NAME15, false, null, null); - } - - private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) { - TestConfiguration config = getTestConfiguration(testname); - loadTestConfiguration(config); - programArgs = new String[] { "-explain" }; - fullDMLScriptName = HOME + testname + ".dml"; - runTest(true, exceptionExpected, exceptionClass, errmsg, -1); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservNNTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservNNTest.java new file mode 100644 index 0000000..d7afc9d --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservNNTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.junit.Test; + +public class ParamservNNTest extends AutomatedTestBase { + + private static final String TEST_NAME1 = "paramserv-nn-bsp-batch-dc"; + private static final String TEST_NAME2 = "paramserv-nn-asp-batch"; + private static final String TEST_NAME3 = "paramserv-nn-bsp-epoch"; + private static final String TEST_NAME4 = "paramserv-nn-asp-epoch"; + private static final String TEST_NAME5 = "paramserv-nn-bsp-batch-drr"; + private static final String TEST_NAME6 = "paramserv-nn-bsp-batch-dr"; + private static final String TEST_NAME7 = "paramserv-nn-bsp-batch-or"; + + private static final String TEST_DIR = "functions/paramserv/"; + private static final String TEST_CLASS_DIR = TEST_DIR + ParamservNNTest.class.getSimpleName() + "/"; + + private final String HOME = SCRIPT_DIR + TEST_DIR; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {})); + addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); + addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); + } + + @Test + public void testParamservBSPBatchDisjointContiguous() { + runDMLTest(TEST_NAME1); + } + + @Test + public 
void testParamservASPBatch() { + runDMLTest(TEST_NAME2); + } + + @Test + public void testParamservBSPEpoch() { + runDMLTest(TEST_NAME3); + } + + @Test + public void testParamservASPEpoch() { + runDMLTest(TEST_NAME4); + } + + @Test + public void testParamservBSPBatchDisjointRoundRobin() { + runDMLTest(TEST_NAME5); + } + + @Test + public void testParamservBSPBatchDisjointRandom() { + runDMLTest(TEST_NAME6); + } + + @Test + public void testParamservBSPBatchOverlapReshuffle() { + runDMLTest(TEST_NAME7); + } + + private void runDMLTest(String testname) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + programArgs = new String[] { "-explain" }; + fullDMLScriptName = HOME + testname + ".dml"; + runTest(true, false, null, null, -1); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRecompilationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRecompilationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRecompilationTest.java new file mode 100644 index 0000000..f02ab6a --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRecompilationTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.junit.Test; + +public class ParamservRecompilationTest extends AutomatedTestBase { + + private static final String TEST_NAME1 = "paramserv-large-parallelism"; + + private static final String TEST_DIR = "functions/paramserv/"; + private static final String TEST_CLASS_DIR = TEST_DIR + ParamservRecompilationTest.class.getSimpleName() + "/"; + + private final String HOME = SCRIPT_DIR + TEST_DIR; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); + } + + @Test + public void testParamservLargeParallelism() { + runDMLTest(TEST_NAME1, false, null, null); + } + + private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + programArgs = new String[] { "-explain" }; + fullDMLScriptName = HOME + testname + ".dml"; + runTest(true, exceptionExpected, exceptionClass, errmsg, -1); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java new file mode 100644 index 0000000..7238cf9 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import org.apache.sysml.api.DMLException; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.junit.Test; + +public class ParamservRuntimeNegativeTest extends AutomatedTestBase { + + private static final String TEST_NAME1 = "paramserv-worker-failed"; + private static final String TEST_NAME2 = "paramserv-agg-service-failed"; + private static final String TEST_NAME3 = "paramserv-wrong-aggregate-func"; + + private static final String TEST_DIR = "functions/paramserv/"; + private static final String TEST_CLASS_DIR = TEST_DIR + ParamservRuntimeNegativeTest.class.getSimpleName() + "/"; + + private final String HOME = SCRIPT_DIR + TEST_DIR; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); + } + + @Test + public void testParamservWorkerFailed() { + runDMLTest(TEST_NAME1, "Invalid lookup by name in unnamed list: worker_err."); + } + + @Test + public void testParamservAggServiceFailed() { + runDMLTest(TEST_NAME2, "Invalid lookup by name in unnamed list: agg_service_err"); + } + + @Test + public void testParamservWrongAggregateFunc() { + runDMLTest(TEST_NAME3, "The 'gradients' function should provide an input of 'MATRIX' type named 'labels'."); + } + + private void runDMLTest(String testname, String errmsg) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + programArgs = new String[] { "-explain" }; + fullDMLScriptName = HOME + testname + ".dml"; + runTest(true, true, DMLException.class, errmsg, -1); + } +} 
http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSyntaxTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSyntaxTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSyntaxTest.java new file mode 100644 index 0000000..f1bead1 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSyntaxTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.paramserv; + +import org.apache.sysml.api.DMLException; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.junit.Test; + +public class ParamservSyntaxTest extends AutomatedTestBase { + + private static final String TEST_NAME1 = "paramserv-all-args"; + private static final String TEST_NAME2 = "paramserv-without-optional-args"; + private static final String TEST_NAME3 = "paramserv-miss-args"; + private static final String TEST_NAME4 = "paramserv-wrong-type-args"; + private static final String TEST_NAME5 = "paramserv-wrong-args"; + private static final String TEST_NAME6 = "paramserv-wrong-args2"; + private static final String TEST_NAME7 = "paramserv-minimum-version"; + + private static final String TEST_DIR = "functions/paramserv/"; + private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSyntaxTest.class.getSimpleName() + "/"; + + private final String HOME = SCRIPT_DIR + TEST_DIR; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {})); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {})); + addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); + addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); + } + + @Test + public void testParamservWithAllArgs() { + runDMLTest(TEST_NAME1, false, null, null); + } + + @Test + public void testParamservWithoutOptionalArgs() { + 
runDMLTest(TEST_NAME2, false, null, null); + } + + @Test + public void testParamservMissArgs() { + final String errmsg = "Named parameter 'features' missing. Please specify the input."; + runDMLTest(TEST_NAME3, true, DMLException.class, errmsg); + } + + @Test + public void testParamservWrongTypeArgs() { + final String errmsg = "Input to PARAMSERV::model must be of type 'LIST'. It should not be of type 'MATRIX'"; + runDMLTest(TEST_NAME4, true, DMLException.class, errmsg); + } + + @Test + public void testParamservWrongArgs() { + final String errmsg = "Paramserv function: not support update type 'NSP'."; + runDMLTest(TEST_NAME5, true, DMLException.class, errmsg); + } + + @Test + public void testParamservWrongArgs2() { + final String errmsg = "Invalid parameters for PARAMSERV: [modelList, val_featur=X_val]"; + runDMLTest(TEST_NAME6, true, DMLException.class, errmsg); + } + + @Test + public void testParamservMinimumVersion() { + runDMLTest(TEST_NAME7, false, null, null); + } + + private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + programArgs = new String[] { "-explain" }; + fullDMLScriptName = HOME + testname + ".dml"; + runTest(true, exceptionExpected, exceptionClass, errmsg, -1); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml index 041c2bf..a10c846 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -35,7 +35,8 @@ source("nn/optim/sgd_nesterov.dml") as sgd_nesterov train = function(matrix[double] X, matrix[double] Y, matrix[double] 
X_val, matrix[double] Y_val, - int C, int Hin, int Win, int epochs, int workers, string utype, string freq) + int C, int Hin, int Win, int epochs, int workers, + string utype, string freq, string scheme) return (matrix[double] W1, matrix[double] b1, matrix[double] W2, matrix[double] b2, matrix[double] W3, matrix[double] b3, @@ -107,7 +108,7 @@ train = function(matrix[double] X, matrix[double] Y, params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) # Use paramserv function - modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") + modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=64, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE") W1 = as.matrix(modelList2["W1"]) b1 = as.matrix(modelList2["b1"]) http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml index 346cc08..baef6ee 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml @@ -41,7 +41,7 @@ epochs = 10 
workers = 2 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", "DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml index 8d553ae..860f53f 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml @@ -41,7 +41,7 @@ epochs = 10 workers = 2 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", "DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml new file mode 100644 index 0000000..dcbd2dd --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_CONTIGUOUS") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml new file mode 100644 index 0000000..96fe734 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_RANDOM") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml new file mode 100644 index 0000000..e97dbff --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 4 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_ROUND_ROBIN") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml ---------------------------------------------------------------------- diff --git 
a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml new file mode 100644 index 0000000..a2e95d3 --- /dev/null +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml @@ -0,0 +1,52 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss + +# Generate the training data +[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() +n = nrow(images) + +# Generate the training data +[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() + +# Split into training and validation +val_size = n * 0.1 +X = images[(val_size+1):n,] +X_val = images[1:val_size,] +Y = labels[(val_size+1):n,] +Y_val = labels[1:val_size,] + +# Arguments +epochs = 10 +workers = 2 + +# Train +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "OVERLAP_RESHUFFLE") + +# Compute validation loss & accuracy +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +loss_val = cross_entropy_loss::forward(probs_val, Y_val) +accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) + +# Output results +print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml deleted file mode 100644 index 7b6523b..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch.dml +++ /dev/null @@ -1,52 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet -source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss - -# Generate the training data -[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data() -n = nrow(images) - -# Generate the training data -[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data() - -# Split into training and validation -val_size = n * 0.1 -X = images[(val_size+1):n,] -X_val = images[1:val_size,] -Y = labels[(val_size+1):n,] -Y_val = labels[1:val_size,] - -# Arguments -epochs = 10 -workers = 2 - -# Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH") - -# Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) -loss_val = cross_entropy_loss::forward(probs_val, Y_val) -accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) - -# Output results -print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml ---------------------------------------------------------------------- diff --git 
a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml index d0a6570..25d5f48 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml @@ -41,7 +41,7 @@ epochs = 10 workers = 2 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH", "DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) http://git-wip-us.apache.org/repos/asf/systemml/blob/69ef76c0/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java index ad3d526..89350bc 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java @@ -26,7 +26,11 @@ import org.junit.runners.Suite; * won't run two of them at once. */ @RunWith(Suite.class) @Suite.SuiteClasses({ - ParamservFuncTest.class + DataPartitionerTest.class, + ParamservSyntaxTest.class, + ParamservRecompilationTest.class, + ParamservRuntimeNegativeTest.class, + ParamservNNTest.class })
