[SYSTEMML-2085] Initial version of local backend for paramserv builtin

Closes #771.
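For reference, the new builtin is driven from DML as in the added test scripts; a minimal sketch follows (script path, variable names, and hyper-parameter values here are illustrative; see mnist_lenet_paramserv.dml below for the complete example):

  # 'upd' and 'agg' name DML functions with the fixed signatures
  #   gradients(features, labels, hyperparams, model) -> list of gradients
  #   aggregation(model, gradients, hyperparams)      -> updated model list
  model = list(W=W, b=b)            # named list holding the model matrices
  hyperparams = list(lr=0.01)       # named list of hyper-parameters
  model2 = paramserv(model=model, features=X, labels=Y,
    val_features=X_val, val_labels=Y_val,
    upd="./myScript.dml::gradients", agg="./myScript.dml::aggregation",
    mode="LOCAL", utype="BSP", freq="BATCH", epochs=10, batchsize=64,
    k=4, scheme="DISJOINT_CONTIGUOUS", hyperparams=hyperparams)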
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/97018d4e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/97018d4e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/97018d4e

Branch: refs/heads/master
Commit: 97018d4e688ba7eeaaa4567ca1e174a3c5525468
Parents: c7a9e01
Author: EdgarLGB <[email protected]>
Authored: Mon May 28 23:17:18 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon May 28 23:17:20 2018 -0700

----------------------------------------------------------------------
 .../ParameterizedBuiltinFunctionExpression.java |   6 +-
 .../java/org/apache/sysml/parser/Statement.java |   1 +
 .../context/ExecutionContext.java               |  13 +-
 .../controlprogram/paramserv/LocalPSWorker.java |  97 +++++
 .../paramserv/LocalParamServer.java             |  59 +++
 .../controlprogram/paramserv/PSWorker.java      | 131 +++++++
 .../controlprogram/paramserv/ParamServer.java   | 232 +++++++++++
 .../paramserv/ParamservUtils.java               |  97 +++++
 .../runtime/instructions/cp/CPOperand.java      |   2 +-
 .../runtime/instructions/cp/ListObject.java     |  14 +
 .../cp/MatrixIndexingCPInstruction.java         |   4 +-
 .../cp/ParamservBuiltinCPInstruction.java       | 257 ++++++++++++-
 .../test/integration/AutomatedTestBase.java     |  18 +-
 .../functions/paramserv/ParamservFuncTest.java  |  29 +-
 .../paramserv/mnist_lenet_paramserv.dml         | 383 +++++++++++++++++++
 .../mnist_lenet_paramserv_minimum_version.dml   | 377 ++++++++++++++++++
 .../functions/paramserv/paramserv-all-args.dml  |   4 +-
 .../functions/paramserv/paramserv-ipa-test.dml  |  47 ---
 .../paramserv/paramserv-minimum-version.dml     |  52 +++
 .../functions/paramserv/paramserv-miss-args.dml |   4 +-
 .../functions/paramserv/paramserv-nn-test.dml   |  52 +++
 .../paramserv-without-optional-args.dml         |   4 +-
 22 files changed, 1805 insertions(+), 78 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 3d74f8d..99aec78 100644
--- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -341,12 +341,12 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
 			.collect(Collectors.toSet());
 		checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, utypes, conditional);
 		Set<String> frequencies = Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet());
-		checkStringParam(false, fname, Statement.PS_FREQUENCY, frequencies, conditional);
+		checkStringParam(true, fname, Statement.PS_FREQUENCY, frequencies, conditional);
 		checkDataValueType(false, fname, Statement.PS_EPOCHS, DataType.SCALAR, ValueType.INT, conditional);
 		checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional);
-		checkDataValueType(false, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
+		checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
 		Set<String> schemes = Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet());
-
checkStringParam(false, fname, Statement.PS_SCHEME, schemes, conditional); + checkStringParam(true, fname, Statement.PS_SCHEME, schemes, conditional); checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional); Set<String> checkpointings = Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet()); checkStringParam(true, fname, Statement.PS_CHECKPOINTING, checkpointings, conditional); http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/Statement.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/Statement.java b/src/main/java/org/apache/sysml/parser/Statement.java index 4853a47..1987d31 100644 --- a/src/main/java/org/apache/sysml/parser/Statement.java +++ b/src/main/java/org/apache/sysml/parser/Statement.java @@ -71,6 +71,7 @@ public abstract class Statement implements ParseInfo public static final String PS_UPDATE_FUN = "upd"; public static final String PS_AGGREGATION_FUN = "agg"; public static final String PS_MODE = "mode"; + public static final String PS_GRADIENTS = "gradients"; public enum PSModeType { LOCAL, REMOTE_SPARK } http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index 67b2a83..6807848 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -42,6 +42,7 @@ import org.apache.sysml.runtime.instructions.cp.CPInstruction; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; +import org.apache.sysml.runtime.instructions.cp.ListObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory; import org.apache.sysml.runtime.instructions.gpu.context.GPUContext; @@ -443,7 +444,17 @@ public class ExecutionContext { public void setScalarOutput(String varName, ScalarObject so) { setVariable(varName, so); } - + + public ListObject getListObject(String name) { + Data dat = getVariable(name); + //error handling if non existing or no list + if (dat == null) + throw new DMLRuntimeException("Variable '" + name + "' does not exist in the symbol table."); + if (!(dat instanceof ListObject)) + throw new DMLRuntimeException("Variable '" + name + "' is not a list."); + return (ListObject) dat; + } + public void releaseMatrixOutputForGPUInstruction(String varName) { MatrixObject mo = getMatrixObject(varName); if(mo.getGPUObject(getGPUContext(0)) == null || !mo.getGPUObject(getGPUContext(0)).isAllocated()) { http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java new file mode 100644 
index 0000000..181b866 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public class LocalPSWorker extends PSWorker implements Runnable { + + protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName()); + + public LocalPSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, + ListObject hyperParams, ExecutionContext ec, ParamServer ps) { + super(workerID, updFunc, freq, epochs, batchSize, hyperParams, ec, ps); + } + + @Override + public void run() { + + long dataSize = _features.getNumRows(); + + for (int i = 0; i < _epochs; i++) { + int totalIter = (int) Math.ceil(dataSize / _batchSize); + for (int j = 0; j < totalIter; j++) { + // Pull the global parameters from ps + // Need to copy the global parameter + ListObject globalParams = ParamservUtils.copyList((ListObject) _ps.pull(_workerID)); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format( + "Local worker_%d: Successfully pull the global parameters [size:%d kb] from ps.", _workerID, + globalParams.getDataSize() / 1024)); + } + _ec.setVariable(Statement.PS_MODEL, globalParams); + + long begin = j * _batchSize + 1; + long end = Math.min(begin + _batchSize, dataSize); + + // Get batch features and labels + MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); + MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); + _ec.setVariable(Statement.PS_FEATURES, bFeatures); + _ec.setVariable(Statement.PS_LABELS, bLabels); + + if (LOG.isDebugEnabled()) { + LOG.debug(String.format( + "Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. 
[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", + _workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1, + _epochs, j + 1, totalIter)); + } + + // Invoke the update function + _inst.processInstruction(_ec); + + // Get the gradients + ListObject gradients = (ListObject) _ec.getVariable(_outputs.get(0).getName()); + + // Push the gradients to ps + _ps.push(_workerID, gradients); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Successfully push the gradients [size:%d kb] to ps.", + _workerID, gradients.getDataSize() / 1024)); + } + + ParamservUtils.cleanupListObject(_ec, globalParams); + ParamservUtils.cleanupData(bFeatures); + ParamservUtils.cleanupData(bLabels); + } + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); + } + } + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Job finished.", _workerID)); + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java new file mode 100644 index 0000000..d060a91 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public class LocalParamServer extends ParamServer { + + public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, + Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum, + ListObject hyperParams) { + super(model, aggFunc, freq, updateType, ec, workerNum, hyperParams); + } + + @Override + public void push(long workerID, ListObject gradients) { + synchronized (_lock) { + _queue.add(new Gradient(workerID, gradients)); + _lock.notifyAll(); + } + } + + @Override + public Data pull(long workerID) { + synchronized (_lock) { + while (getPulledState((int) workerID)) { + try { + _lock.wait(); + } catch (InterruptedException e) { + throw new DMLRuntimeException( + String.format("Local worker_%d: failed to pull the global parameters.", workerID), e); + } + } + setPulledState((int) workerID, true); + } + return getResult(); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java new file mode 100644 index 0000000..9ace823 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.ArrayList; +import java.util.stream.Collectors; + +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DataIdentifier; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +@SuppressWarnings("unused") +public abstract class PSWorker { + + long _workerID = -1; + int _epochs; + long _batchSize; + MatrixObject _features; + MatrixObject _labels; + ExecutionContext _ec; + ParamServer _ps; + private String _updFunc; + private Statement.PSFrequency _freq; + private MatrixObject _valFeatures; + private MatrixObject _valLabels; + + ArrayList<DataIdentifier> _outputs; + FunctionCallCPInstruction _inst; + + public PSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, + ListObject hyperParams, ExecutionContext ec, ParamServer ps) { + this._workerID = workerID; + this._updFunc = updFunc; + this._freq = freq; + this._epochs = epochs; + this._batchSize = batchSize; + this._ec = ExecutionContextFactory.createContext(ec.getProgram()); + if (hyperParams != null) { + this._ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams); + } + this._ps = ps; + + // Get the update function + String[] keys = DMLProgram.splitFunctionKey(updFunc); + String _funcName = keys[0]; + String _funcNS = null; + if (keys.length == 2) { + _funcNS = keys[0]; + _funcName = keys[1]; + } + FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(_funcNS, _funcName); + ArrayList<DataIdentifier> _inputs = func.getInputParams(); + _outputs = func.getOutputParams(); + CPOperand[] _boundInputs = _inputs.stream() + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); + ArrayList<String> _inputNames = _inputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + ArrayList<String> _outputNames = _outputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + _inst = new FunctionCallCPInstruction(_funcNS, _funcName, _boundInputs, _inputNames, _outputNames, + "update function"); + + // Check the inputs of the update function + checkInput(_inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES); + checkInput(_inputs, Expression.DataType.MATRIX, Statement.PS_LABELS); + checkInput(_inputs, Expression.DataType.LIST, Statement.PS_MODEL); + if (hyperParams != null) { + checkInput(_inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS); + } + + // Check the output of the update function + if (_outputs.size() != 1) { + throw new DMLRuntimeException( + String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc)); + } + if (_outputs.get(0).getDataType() != Expression.DataType.LIST) { + throw new DMLRuntimeException( + String.format("The output of the '%s' function should be type of list.", updFunc)); + } + } + + private void 
checkInput(ArrayList<DataIdentifier> _inputs, Expression.DataType dt, String pname) { + if (_inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) { + throw new DMLRuntimeException( + String.format("The '%s' function should provide an input of '%s' type named '%s'.", _updFunc, dt, pname)); + } + } + + public void setFeatures(MatrixObject features) { + this._features = features; + } + + public void setLabels(MatrixObject labels) { + this._labels = labels; + } + + public void setValFeatures(MatrixObject valFeatures) { + this._valFeatures = valFeatures; + } + + public void setValLabels(MatrixObject valLabels) { + this._valLabels = valLabels; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java new file mode 100644 index 0000000..6e1cd13 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DataIdentifier; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.Statement; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; +import org.apache.sysml.runtime.instructions.cp.ListObject; + +public abstract class ParamServer { + + public class Gradient { + final long _workerID; + final ListObject _gradients; + + public Gradient(long workerID, ListObject gradients) { + this._workerID = workerID; + this._gradients = gradients; + } + } + + Queue<Gradient> _queue; + final Object _lock = new Object(); + private ListObject _model; + private AggregationService _aggService; + private Thread _aggThread; + private boolean[] _pulledStates; + + protected ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq, + Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum, ListObject hyperParams) { + this._queue = new ConcurrentLinkedQueue<>(); + this._model = model; + this._aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum, hyperParams); + this._pulledStates = new boolean[workerNum]; + this._aggThread = new Thread(_aggService); + } + + public abstract void push(long workerID, ListObject value); + + public abstract Data pull(long workerID); + + public void start() { + _aggService._alive = true; + _aggThread.start(); + } + + public void stop() { + _aggService._alive = false; + try { + _aggThread.join(); + } catch (InterruptedException e) { + throw new DMLRuntimeException("Parameter server: failed when stopping the server.", e); + } + } + + public ListObject getResult() { + return _model; + } + + public boolean getPulledState(int workerID) { + return _pulledStates[workerID]; + } + + public void setPulledState(int workerID, boolean state) { + _pulledStates[workerID] = state; + } + + private void resetPulledStates() { + _pulledStates = new boolean[_pulledStates.length]; + } + + /** + * Inner aggregation service which is for updating the model + */ + @SuppressWarnings("unused") + private class AggregationService implements Runnable { + + protected final Log LOG = LogFactory.getLog(AggregationService.class.getName()); + + protected ExecutionContext _ec; + private Statement.PSFrequency _freq; + private Statement.PSUpdateType _updateType; + private FunctionCallCPInstruction _inst; + private DataIdentifier _output; + private boolean _alive; + private boolean[] _finishedStates; // Workers' finished states + + AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType, + ExecutionContext ec, int workerNum, ListObject hyperParams) { + _ec = ExecutionContextFactory.createContext(ec.getProgram()); + _freq = freq; + _updateType = updateType; + if 
(hyperParams != null) { + _ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams); + } + _finishedStates = new boolean[workerNum]; + + // Fetch the aggregation function + String[] keys = DMLProgram.splitFunctionKey(aggFunc); + String funcName = keys[0]; + String funcNS = null; + if (keys.length == 2) { + funcNS = keys[0]; + funcName = keys[1]; + } + FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(funcNS, funcName); + ArrayList<DataIdentifier> inputs = func.getInputParams(); + ArrayList<DataIdentifier> outputs = func.getOutputParams(); + + // Check the output of the aggregation function + if (outputs.size() != 1) { + throw new DMLRuntimeException(String.format( + "The output of the '%s' function should provide one list containing the updated model.", + aggFunc)); + } + if (outputs.get(0).getDataType() != Expression.DataType.LIST) { + throw new DMLRuntimeException( + String.format("The output of the '%s' function should be type of list.", aggFunc)); + } + _output = outputs.get(0); + + CPOperand[] boundInputs = inputs.stream() + .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())) + .toArray(CPOperand[]::new); + ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) + .collect(Collectors.toCollection(ArrayList::new)); + _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, + "aggregate function"); + } + + boolean isAlive() { + return _alive; + } + + private boolean allFinished() { + return !ArrayUtils.contains(_finishedStates, false); + } + + private void resetFinishedStates() { + Arrays.fill(_finishedStates, false); + } + + private void setFinishedState(int workerID) { + _finishedStates[workerID] = true; + } + + @Override + public void run() { + synchronized (_lock) { + while (isAlive()) { + do { + while (_queue.isEmpty()) { + try { + _lock.wait(); + } catch (InterruptedException e) { + throw new DMLRuntimeException( + "Aggregation service: error when waiting for the coming gradients.", e); + } + } + Gradient p = _queue.remove(); + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", + p._gradients.getDataSize() / 1024, p._workerID)); + } + + setFinishedState((int) p._workerID); + + // Populate the variables table with the gradients and model + _ec.setVariable(Statement.PS_GRADIENTS, p._gradients); + _ec.setVariable(Statement.PS_MODEL, _model); + + // Invoke the aggregate function + _inst.processInstruction(_ec); + + // Get the output + ListObject newModel = (ListObject) _ec.getVariable(_output.getName()); + + // Update the model with the new output + ParamservUtils.cleanupListObject(_ec, _model); + ParamservUtils.cleanupListObject(_ec, p._gradients); + _model = newModel; + + } while (!allFinished()); + + // notify all the workers to get the updated model + resetPulledStates(); + resetFinishedStates(); + _lock.notifyAll(); + if (LOG.isDebugEnabled()) { + LOG.debug("Global parameter is broadcasted successfully."); + } + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java new file mode 100644 index 0000000..54c5d6c --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv; + +import java.util.HashSet; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.sysml.parser.Expression; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.controlprogram.caching.FrameObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.MetaDataFormat; +import org.apache.sysml.runtime.matrix.data.InputInfo; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.OutputInfo; + +public class ParamservUtils { + + /** + * Deep copy the list object + * + * @param lo list object + * @return a new copied list object + */ + public static ListObject copyList(ListObject lo) { + if (lo.getLength() == 0) { + return lo; + } + List<Data> newData = lo.getNames().stream().map(name -> { + Data oldData = lo.slice(name); + if (oldData instanceof MatrixObject) { + MatrixObject mo = (MatrixObject) oldData; + return sliceMatrix(mo, 1, mo.getNumRows()); + } else if (oldData instanceof ListObject || oldData instanceof FrameObject) { + throw new DMLRuntimeException("Copy list: does not support list or frame."); + } else { + return oldData; + } + }).collect(Collectors.toList()); + return new ListObject(newData, lo.getNames()); + } + + public static void cleanupListObject(ExecutionContext ec, ListObject lo) { + ec.getVariables().removeAllIn(new HashSet<>(lo.getNames())); + lo.getData().forEach(ParamservUtils::cleanupData); + } + + public static void cleanupData(Data data) { + if( !(data instanceof CacheableData) ) + return; + CacheableData<?> cd = (CacheableData<?>) data; + cd.enableCleanup(true); + cd.clearData(); + } + + /** + * Slice the matrix + * @param mo input matrix + * @param rl low boundary + * @param rh high boundary + * @return new sliced matrix + */ + public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) { + MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, null, + new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1), + OutputInfo.BinaryBlockOutputInfo, 
InputInfo.BinaryBlockInputInfo)); + MatrixBlock tmp = mo.acquireRead(); + result.acquireModify(tmp.slice((int)rl-1, (int)rh-1, 0, + tmp.getNumColumns()-1, new MatrixBlock())); + mo.release(); + result.release(); + return result; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java index 1ca8eab..22b79b0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java @@ -46,7 +46,7 @@ public class CPOperand this(name, vt, dt, false); } - private CPOperand(String name, ValueType vt, DataType dt, boolean literal) { + public CPOperand(String name, ValueType vt, DataType dt, boolean literal) { _name = name; _valueType = vt; _dataType = dt; http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java index 95f03b5..670190c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java @@ -25,6 +25,7 @@ import java.util.List; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; public class ListObject extends Data { private static final long serialVersionUID = 3652422061598967358L; @@ -107,6 +108,19 @@ public class ListObject extends Data { return (_names == null) ? 
null : _names.get(ix); } + public boolean isNamedList() { + return _names != null; + } + + public List<Data> getData() { + return _data; + } + + public long getDataSize() { + return _data.stream().filter(data -> data instanceof CacheableData) + .map(data -> ((CacheableData) data).getDataSize()).reduce((l1, l2) -> l1 + l2).get(); + } + @Override public String getDebugName() { return toString(); http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java index 51cc4c1..4e5d4c0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java @@ -34,8 +34,8 @@ import org.apache.sysml.utils.Statistics; public final class MatrixIndexingCPInstruction extends IndexingCPInstruction { - protected MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, - CPOperand cu, CPOperand out, String opcode, String istr) { + public MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, + CPOperand out, String opcode, String istr) { super(in, rl, ru, cl, cu, out, opcode, istr); } http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index ddc56ae..3ab0fc8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -19,14 +19,62 @@ package org.apache.sysml.runtime.instructions.cp; +import static org.apache.sysml.parser.Statement.PSFrequency; +import static org.apache.sysml.parser.Statement.PSModeType; +import static org.apache.sysml.parser.Statement.PSScheme; +import static org.apache.sysml.parser.Statement.PSUpdateType; +import static org.apache.sysml.parser.Statement.PS_AGGREGATION_FUN; +import static org.apache.sysml.parser.Statement.PS_BATCH_SIZE; +import static org.apache.sysml.parser.Statement.PS_EPOCHS; +import static org.apache.sysml.parser.Statement.PS_FEATURES; +import static org.apache.sysml.parser.Statement.PS_FREQUENCY; +import static org.apache.sysml.parser.Statement.PS_HYPER_PARAMS; +import static org.apache.sysml.parser.Statement.PS_LABELS; +import static org.apache.sysml.parser.Statement.PS_MODE; +import static org.apache.sysml.parser.Statement.PS_MODEL; +import static org.apache.sysml.parser.Statement.PS_PARALLELISM; +import static org.apache.sysml.parser.Statement.PS_SCHEME; +import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN; +import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE; +import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES; +import static org.apache.sysml.parser.Statement.PS_VAL_LABELS; + +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; +import java.util.stream.Collectors; 
+import java.util.stream.IntStream; -import org.apache.sysml.parser.Statement; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; +import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.utils.NativeHelper; public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { + private static final int DEFAULT_BATCH_SIZE = 64; + private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH; + private static final int DEFAULT_LEVEL_PARALLELISM = InfrastructureAnalyzer.getLocalParallelism(); + private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS; + + //internal local debug level + private static final boolean LDEBUG = false; + + static { + // for internal debugging only + if (LDEBUG) { + Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel((Level) Level.DEBUG); + } + } + protected ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) { super(op, paramsMap, out, opcode, istr); @@ -34,8 +82,209 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc @Override public void processInstruction(ExecutionContext ec) { - ListObject model = (ListObject) ec.getVariable(getParam(Statement.PS_MODEL)); - ListObject outList = model.slice(0, model.getLength() - 1); - ec.setVariable(output.getName(), outList); + + PSModeType mode = PSModeType.valueOf(getParam(PS_MODE)); + int workerNum = getWorkerNum(mode); + String updFunc = getParam(PS_UPDATE_FUN); + String aggFunc = getParam(PS_AGGREGATION_FUN); + PSFrequency freq = getFrequency(); + PSUpdateType updateType = getUpdateType(); + int epochs = Integer.valueOf(getParam(PS_EPOCHS)); + if (epochs <= 0) { + throw new DMLRuntimeException( + String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.", + PS_EPOCHS)); + } + long batchSize = getBatchSize(); + + // Create the parameter server + ListObject model = ec.getListObject(getParam(PS_MODEL)); + ListObject hyperParams = getHyperParams(ec); + ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, ec, hyperParams); + + // Create the local workers + List<LocalPSWorker> workers = IntStream.range(0, workerNum) + .mapToObj(i -> new LocalPSWorker((long) i, updFunc, freq, epochs, batchSize, hyperParams, ec, ps)) + .collect(Collectors.toList()); + + // Do data partition + doDataPartition(ec, workers); + + // Create the worker threads + List<Thread> threads = workers.stream().map(Thread::new).collect(Collectors.toList()); + + // Start the ps + ps.start(); + + // Start the workers + threads.forEach(Thread::start); + + // Wait for the workers stopping + threads.forEach(thread -> { + try { + thread.join(); + } catch (InterruptedException e) { + throw new DMLRuntimeException("Paramserv function: Failed to join the worker threads.", e); + } + }); + + ps.stop(); + + // Create the output + 
ListObject result = ps.getResult(); + ec.setVariable(output.getName(), result); + } + + private PSUpdateType getUpdateType() { + PSUpdateType updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE)); + switch (updType) { + case ASP: + case SSP: + throw new DMLRuntimeException(String.format("Not support update type '%s'.", updType)); + case BSP: + break; + } + return updType; + } + + private PSFrequency getFrequency() { + if (!getParameterMap().containsKey(PS_FREQUENCY)) { + return DEFAULT_UPDATE_FREQUENCY; + } + PSFrequency freq = PSFrequency.valueOf(getParam(PS_FREQUENCY)); + switch (freq) { + case EPOCH: + throw new DMLRuntimeException("Not support epoch update frequency."); + case BATCH: + break; + } + return freq; + } + + /** + * Get the worker numbers according to the vcores + * + * @param mode execution mode + * @return worker numbers + */ + private int getWorkerNum(PSModeType mode) { + int workerNum = DEFAULT_LEVEL_PARALLELISM; + if (getParameterMap().containsKey(PS_PARALLELISM)) { + workerNum = Integer.valueOf(getParam(PS_PARALLELISM)); + } + switch (mode) { + case LOCAL: + //FIXME: this is a workaround for a maximum number of buffers in openblas + //However, the root cause is a missing function preparation for each worker + //(i.e., deep copy with unique file names, and reduced degree of parallelism) + int vcores = InfrastructureAnalyzer.getLocalParallelism(); + if ("openblas".equals(NativeHelper.getCurrentBLAS())) { + workerNum = Math.min(workerNum, vcores / 2); + } else { + workerNum = Math.min(workerNum, vcores); + } + break; + case REMOTE_SPARK: + throw new DMLRuntimeException("Do not support remote spark."); + } + return workerNum; + } + + /** + * Create a server which serves the local or remote workers + * + * @return parameter server + */ + private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType, + int workerNum, ListObject model, ExecutionContext ec, ListObject hyperParams) { + ParamServer ps = null; + switch (mode) { + case LOCAL: + ps = new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum, hyperParams); + break; + case REMOTE_SPARK: + throw new DMLRuntimeException("Do not support remote spark."); + } + return ps; + } + + private long getBatchSize() { + if (!getParameterMap().containsKey(PS_BATCH_SIZE)) { + return DEFAULT_BATCH_SIZE; + } + long batchSize = Integer.valueOf(getParam(PS_BATCH_SIZE)); + if (batchSize <= 0) { + throw new DMLRuntimeException(String.format( + "Paramserv function: the number of argument '%s' could not be less than or equal to 0.", + PS_BATCH_SIZE)); + } + return batchSize; + } + + private ListObject getHyperParams(ExecutionContext ec) { + ListObject hyperparams = null; + if (getParameterMap().containsKey(PS_HYPER_PARAMS)) { + hyperparams = ec.getListObject(getParam(PS_HYPER_PARAMS)); + } + return hyperparams; + } + + private void doDataPartition(ExecutionContext ec, List<LocalPSWorker> workers) { + MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES)); + MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS)); + MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES)); + MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); + PSScheme scheme = DEFAULT_SCHEME; + if (getParameterMap().containsKey(PS_SCHEME)) { + scheme = PSScheme.valueOf(getParam(PS_SCHEME)); + } + switch (scheme) { + case DISJOINT_CONTIGUOUS: + disjointContiguous(features, labels, valFeatures, valLabels, workers); + break; + case DISJOINT_RANDOM: + case 
OVERLAP_RESHUFFLE: + case DISJOINT_ROUND_ROBIN: + throw new DMLRuntimeException( + String.format("Paramserv function: the scheme '%s' is not supported.", scheme)); + } + } + + private void disjointContiguous(MatrixObject features, MatrixObject labels, MatrixObject valFeatures, + MatrixObject valLabels, List<LocalPSWorker> workers) { + // training data + List<MatrixObject> pfs = disjointContiguous(workers.size(), features); + List<MatrixObject> pls = disjointContiguous(workers.size(), labels); + if (pfs.size() < workers.size()) { + LOG.warn(String.format( + "There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.", + pfs.size(), workers.size(), pfs.size())); + workers = workers.subList(0, pfs.size()); + } + for (int i = 0; i < workers.size(); i++) { + workers.get(i).setFeatures(pfs.get(i)); + workers.get(i).setLabels(pls.get(i)); + } + + // validation data + List<MatrixObject> pvfs = disjointContiguous(workers.size(), valFeatures); + List<MatrixObject> pvls = disjointContiguous(workers.size(), valLabels); + for (int i = 0; i < workers.size(); i++) { + workers.get(i).setValFeatures(pvfs.get(i)); + workers.get(i).setValLabels(pvls.get(i)); + } + } + + private List<MatrixObject> disjointContiguous(int workerNum, MatrixObject mo) { + List<MatrixObject> list = new ArrayList<>(); + long stepSize = (long) Math.ceil(mo.getNumRows() / workerNum); + long begin = 1; + while (begin < mo.getNumRows()) { + long end = Math.min(begin + stepSize, mo.getNumRows()); + MatrixObject pmo = ParamservUtils.sliceMatrix(mo, begin, end); + list.add(pmo); + begin = end + 1; + } + return list; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 47ea66e..43f5229 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -1250,9 +1250,11 @@ public abstract class AutomatedTestBase if (exceptionExpected) fail("expected exception which has not been raised: " + expectedException); } catch (Exception e) { - if (exceptionExpected && e.getClass().equals(expectedException) && errMessage != null - && !e.getMessage().contains(errMessage)) { - fail("expected exception message has not been raised: " + errMessage); + if (errMessage != null && !errMessage.equals("")) { + boolean result = rCompareException(exceptionExpected, errMessage, e, false); + if (exceptionExpected && !result) { + fail(String.format("expected exception message '%s' has not been raised.", errMessage)); + } } if (!exceptionExpected || (expectedException != null && !(e.getClass().equals(expectedException)))) { e.printStackTrace(); @@ -1269,6 +1271,16 @@ public abstract class AutomatedTestBase } } + private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) { + if (e.getCause() != null) { + result |= rCompareException(exceptionExpected, errMessage, e.getCause(), result); + } + if (exceptionExpected && errMessage != null && e.getMessage().contains(errMessage)) { + result = true; + } + return result; + } + public void cleanupScratchSpace() { try 
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java index 1b227f1..6370099 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java @@ -32,7 +32,8 @@ public class ParamservFuncTest extends AutomatedTestBase { private static final String TEST_NAME4 = "paramserv-wrong-type-args"; private static final String TEST_NAME5 = "paramserv-wrong-args"; private static final String TEST_NAME6 = "paramserv-wrong-args2"; - private static final String TEST_NAME7 = "paramserv-ipa-test"; + private static final String TEST_NAME7 = "paramserv-nn-test"; + private static final String TEST_NAME8 = "paramserv-minimum-version"; private static final String TEST_DIR = "functions/paramserv/"; private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/"; @@ -48,53 +49,59 @@ public class ParamservFuncTest extends AutomatedTestBase { addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {})); addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {})); addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {})); + addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {})); } @Test public void testParamservWithAllArgs() { - runDMLTest(TEST_NAME1, true, false, null, null); + runDMLTest(TEST_NAME1, false, null, null); } @Test public void testParamservWithoutOptionalArgs() { - runDMLTest(TEST_NAME2, true, false, null, null); + runDMLTest(TEST_NAME2, false, null, null); } @Test public void testParamservMissArgs() { final String errmsg = "Named parameter 'features' missing. Please specify the input."; - runDMLTest(TEST_NAME3, true, true, DMLException.class, errmsg); + runDMLTest(TEST_NAME3, true, DMLException.class, errmsg); } @Test public void testParamservWrongTypeArgs() { final String errmsg = "Input to PARAMSERV::model must be of type 'LIST'. 
It should not be of type 'MATRIX'"; - runDMLTest(TEST_NAME4, true, true, DMLException.class, errmsg); + runDMLTest(TEST_NAME4, true, DMLException.class, errmsg); } @Test public void testParamservWrongArgs() { final String errmsg = "Function PARAMSERV does not support value 'NSP' as the 'utype' parameter."; - runDMLTest(TEST_NAME5, true, true, DMLException.class, errmsg); + runDMLTest(TEST_NAME5, true, DMLException.class, errmsg); } @Test public void testParamservWrongArgs2() { final String errmsg = "Invalid parameters for PARAMSERV: [modelList, val_featur=X_val]"; - runDMLTest(TEST_NAME6, true, true, DMLException.class, errmsg); + runDMLTest(TEST_NAME6, true, DMLException.class, errmsg); } @Test - public void testParamservIpaTest() { - runDMLTest(TEST_NAME7, true, false, null, "1"); + public void testParamservNNTest() { + runDMLTest(TEST_NAME7, false, null, null); } - private void runDMLTest(String testname, boolean newWay, boolean exceptionExpected, Class<?> exceptionClass, + @Test + public void testParamservMinimumVersionTest() { + runDMLTest(TEST_NAME8, false, null, null); + } + + private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass, String errmsg) { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); programArgs = new String[] { "-explain" }; fullDMLScriptName = HOME + testname + ".dml"; - runTest(newWay, exceptionExpected, exceptionClass, errmsg, -1); + runTest(true, exceptionExpected, exceptionClass, errmsg, -1); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml new file mode 100644 index 0000000..2a3bbe2 --- /dev/null +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -0,0 +1,383 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +/* + * MNIST LeNet Example + */ +# Imports +source("nn/layers/affine.dml") as affine +source("nn/layers/conv2d_builtin.dml") as conv2d +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss +source("nn/layers/dropout.dml") as dropout +source("nn/layers/l2_reg.dml") as l2_reg +source("nn/layers/max_pool2d_builtin.dml") as max_pool2d +source("nn/layers/relu.dml") as relu +source("nn/layers/softmax.dml") as softmax +source("nn/optim/sgd_nesterov.dml") as sgd_nesterov + +train = function(matrix[double] X, matrix[double] Y, + matrix[double] X_val, matrix[double] Y_val, + int C, int Hin, int Win, int epochs, int workers) + return (matrix[double] W1, matrix[double] b1, + matrix[double] W2, matrix[double] b2, + matrix[double] W3, matrix[double] b3, + matrix[double] W4, matrix[double] b4) { + /* + * Trains a convolutional net using the "LeNet" architecture. + * + * The input matrix, X, has N examples, each represented as a 3D + * volume unrolled into a single vector. The targets, Y, have K + * classes, and are one-hot encoded. + * + * Inputs: + * - X: Input data matrix, of shape (N, C*Hin*Win). + * - Y: Target matrix, of shape (N, K). + * - X_val: Input validation data matrix, of shape (N, C*Hin*Win). + * - Y_val: Target validation matrix, of shape (N, K). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - epochs: Total number of full training loops over the full data set. + * + * Outputs: + * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). + * - b1: 1st layer biases vector, of shape (F1, 1). + * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). + * - b2: 2nd layer biases vector, of shape (F2, 1). + * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). + * - b3: 3rd layer biases vector, of shape (1, N3). + * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). + * - b4: 4th layer biases vector, of shape (1, K). 
+ */ + N = nrow(X) + K = ncol(Y) + + # Create network: + # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax + Hf = 5 # filter height + Wf = 5 # filter width + stride = 1 + pad = 2 # For same dimensions, (Hf - stride) / 2 + + F1 = 32 # num conv filters in conv1 + F2 = 64 # num conv filters in conv2 + N3 = 512 # num nodes in affine3 + # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes) + + [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win) + [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2)) + [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) + [W4, b4] = affine::init(N3, K) # inputs: (N, N3) + W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu + + # Initialize SGD w/ Nesterov momentum optimizer + lr = 0.01 # learning rate + mu = 0.9 #0.5 # momentum + decay = 0.95 # learning rate decay constant + vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1) + vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2) + vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3) + vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4) + + # Regularization + lambda = 5e-04 + + # Create the model object + modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) + + # Create the hyper parameter list + params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) + + # Use paramserv function + modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype="BSP", freq="BATCH", epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE") + + W1 = as.matrix(modelList2["W1"]) + b1 = as.matrix(modelList2["b1"]) + W2 = as.matrix(modelList2["W2"]) + b2 = as.matrix(modelList2["b2"]) + W3 = as.matrix(modelList2["W3"]) + b3 = as.matrix(modelList2["b3"]) + W4 = as.matrix(modelList2["W4"]) + b4 = as.matrix(modelList2["b4"]) + +} + +# Should always use 'features' (batch features), 'labels' (batch labels), +# 'hyperparams', 'model' as the arguments +# and return the gradients of type list +gradients = function(matrix[double] features, + matrix[double] labels, + list[unknown] hyperparams, + list[unknown] model) + return (list[unknown] gradients) { + +# PB: not be able to get scalar from list + + C = 1 + Hin = 28 + Win = 28 + Hf = 5 + Wf = 5 + stride = 1 + pad = 2 + lambda = 5e-04 + F1 = 32 + F2 = 64 + N3 = 512 + W1 = as.matrix(model["W1"]) + b1 = as.matrix(model["b1"]) + W2 = as.matrix(model["W2"]) + b2 = as.matrix(model["b2"]) + W3 = as.matrix(model["W3"]) + b3 = as.matrix(model["b3"]) + W4 = as.matrix(model["W4"]) + b4 = as.matrix(model["b4"]) + + # Compute forward pass + ## layer 1: conv1 -> relu1 -> pool1 + [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, + stride, stride, pad, pad) + outr1 = relu::forward(outc1) + [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 2: conv2 -> relu2 -> pool2 + [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf, + stride, stride, pad, pad) + outr2 = relu::forward(outc2) + [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 3: affine3 -> relu3 -> dropout + outa3 = affine::forward(outp2, W3, b3) + outr3 = relu::forward(outa3) + [outd3, maskd3] = dropout::forward(outr3, 0.5, -1) + ## layer 4: affine4 -> softmax + outa4 = affine::forward(outd3, W4, b4) + probs = softmax::forward(outa4) + + # Compute data backward pass + ## loss: + dprobs = cross_entropy_loss::backward(probs, labels) + ## layer 4: affine4 -> softmax + douta4 = softmax::backward(dprobs, outa4) + [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4) + ## layer 3: affine3 -> relu3 -> dropout + doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3) + douta3 = relu::backward(doutr3, outa3) + [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3) + ## layer 2: conv2 -> relu2 -> pool2 + doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + doutc2 = relu::backward(doutr2, outc2) + [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1, + Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad) + ## layer 1: conv1 -> relu1 -> pool1 + doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + doutc1 = relu::backward(doutr1, outc1) + [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win, + Hf, Wf, stride, stride, pad, pad) + + # Compute regularization backward pass + dW1_reg = l2_reg::backward(W1, lambda) + dW2_reg = l2_reg::backward(W2, lambda) + dW3_reg = l2_reg::backward(W3, lambda) + dW4_reg = l2_reg::backward(W4, lambda) + dW1 = dW1 + dW1_reg + dW2 = dW2 + dW2_reg + dW3 = dW3 + dW3_reg + dW4 = dW4 + dW4_reg + + gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4) +} + +# PB: how to handle the velocity? 
(put into the model) +# Should use the arguments named 'model', 'gradients', 'hyperparams' +# and return always a model of type list +aggregation = function(list[unknown] model, + list[unknown] gradients, + list[unknown] hyperparams) + return (list[unknown] modelResult) { + + W1 = as.matrix(model["W1"]) + W2 = as.matrix(model["W2"]) + W3 = as.matrix(model["W3"]) + W4 = as.matrix(model["W4"]) + b1 = as.matrix(model["b1"]) + b2 = as.matrix(model["b2"]) + b3 = as.matrix(model["b3"]) + b4 = as.matrix(model["b4"]) + dW1 = as.matrix(gradients["dW1"]) + dW2 = as.matrix(gradients["dW2"]) + dW3 = as.matrix(gradients["dW3"]) + dW4 = as.matrix(gradients["dW4"]) + db1 = as.matrix(gradients["db1"]) + db2 = as.matrix(gradients["db2"]) + db3 = as.matrix(gradients["db3"]) + db4 = as.matrix(gradients["db4"]) + vW1 = as.matrix(model["vW1"]) + vW2 = as.matrix(model["vW2"]) + vW3 = as.matrix(model["vW3"]) + vW4 = as.matrix(model["vW4"]) + vb1 = as.matrix(model["vb1"]) + vb2 = as.matrix(model["vb2"]) + vb3 = as.matrix(model["vb3"]) + vb4 = as.matrix(model["vb4"]) + lr = 0.01 + mu = 0.9 + + # Optimize with SGD w/ Nesterov momentum + [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1) + [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1) + [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2) + [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2) + [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3) + [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3) + [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4) + [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4) + + modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) + } + +predict = function(matrix[double] X, int C, int Hin, int Win, + matrix[double] W1, matrix[double] b1, + matrix[double] W2, matrix[double] b2, + matrix[double] W3, matrix[double] b3, + matrix[double] W4, matrix[double] b4) + return (matrix[double] probs) { + /* + * Computes the class probability predictions of a convolutional + * net using the "LeNet" architecture. + * + * The input matrix, X, has N examples, each represented as a 3D + * volume unrolled into a single vector. + * + * Inputs: + * - X: Input data matrix, of shape (N, C*Hin*Win). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). + * - b1: 1st layer biases vector, of shape (F1, 1). + * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). + * - b2: 2nd layer biases vector, of shape (F2, 1). + * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). + * - b3: 3rd layer biases vector, of shape (1, N3). + * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). + * - b4: 4th layer biases vector, of shape (1, K). + * + * Outputs: + * - probs: Class probabilities, of shape (N, K). 
+ */ + N = nrow(X) + + # Network: + # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax + Hf = 5 # filter height + Wf = 5 # filter width + stride = 1 + pad = 2 # For same dimensions, (Hf - stride) / 2 + + F1 = nrow(W1) # num conv filters in conv1 + F2 = nrow(W2) # num conv filters in conv2 + N3 = ncol(W3) # num nodes in affine3 + K = ncol(W4) # num nodes in affine4, equal to number of target dimensions (num classes) + + # Compute predictions over mini-batches + probs = matrix(0, rows=N, cols=K) + batch_size = 64 + iters = ceil(N / batch_size) + for(i in 1:iters) { + # Get next batch + beg = ((i-1) * batch_size) %% N + 1 + end = min(N, beg + batch_size - 1) + X_batch = X[beg:end,] + + # Compute forward pass + ## layer 1: conv1 -> relu1 -> pool1 + [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad) + outr1 = relu::forward(outc1) + [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 2: conv2 -> relu2 -> pool2 + [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, + stride, stride, pad, pad) + outr2 = relu::forward(outc2) + [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 3: affine3 -> relu3 + outa3 = affine::forward(outp2, W3, b3) + outr3 = relu::forward(outa3) + ## layer 4: affine4 -> softmax + outa4 = affine::forward(outr3, W4, b4) + probs_batch = softmax::forward(outa4) + + # Store predictions + probs[beg:end,] = probs_batch + } +} + +eval = function(matrix[double] probs, matrix[double] Y) + return (double loss, double accuracy) { + /* + * Evaluates a convolutional net using the "LeNet" architecture. + * + * The probs matrix contains the class probability predictions + * of K classes over N examples. The targets, Y, have K classes, + * and are one-hot encoded. + * + * Inputs: + * - probs: Class probabilities, of shape (N, K). + * - Y: Target matrix, of shape (N, K). + * + * Outputs: + * - loss: Scalar loss, of shape (1). + * - accuracy: Scalar accuracy, of shape (1). + */ + # Compute loss & accuracy + loss = cross_entropy_loss::forward(probs, Y) + correct_pred = rowIndexMax(probs) == rowIndexMax(Y) + accuracy = mean(correct_pred) +} + +generate_dummy_data = function() + return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) { + /* + * Generate a dummy dataset similar to the MNIST dataset. + * + * Outputs: + * - X: Input data matrix, of shape (N, D). + * - Y: Target matrix, of shape (N, K). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. 
+ */ + # Generate dummy input data + N = 1024 # num examples + C = 1 # num input channels + Hin = 28 # input height + Win = 28 # input width + K = 10 # num target classes + X = rand(rows=N, cols=C*Hin*Win, pdf="normal") + classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform")) + Y = table(seq(1, N), classes) # one-hot encoding +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml new file mode 100644 index 0000000..2ef7411 --- /dev/null +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml @@ -0,0 +1,377 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +/* + * MNIST LeNet Example + */ +# Imports +source("nn/layers/affine.dml") as affine +source("nn/layers/conv2d_builtin.dml") as conv2d +source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss +source("nn/layers/dropout.dml") as dropout +source("nn/layers/l2_reg.dml") as l2_reg +source("nn/layers/max_pool2d_builtin.dml") as max_pool2d +source("nn/layers/relu.dml") as relu +source("nn/layers/softmax.dml") as softmax +source("nn/optim/sgd_nesterov.dml") as sgd_nesterov + +train = function(matrix[double] X, matrix[double] Y, + matrix[double] X_val, matrix[double] Y_val, + int C, int Hin, int Win, int epochs, int workers) + return (matrix[double] W1, matrix[double] b1, + matrix[double] W2, matrix[double] b2, + matrix[double] W3, matrix[double] b3, + matrix[double] W4, matrix[double] b4) { + /* + * Trains a convolutional net using the "LeNet" architecture. + * + * The input matrix, X, has N examples, each represented as a 3D + * volume unrolled into a single vector. The targets, Y, have K + * classes, and are one-hot encoded. + * + * Inputs: + * - X: Input data matrix, of shape (N, C*Hin*Win). + * - Y: Target matrix, of shape (N, K). + * - X_val: Input validation data matrix, of shape (N, C*Hin*Win). + * - Y_val: Target validation matrix, of shape (N, K). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - epochs: Total number of full training loops over the full data set. + * + * Outputs: + * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). + * - b1: 1st layer biases vector, of shape (F1, 1). + * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). 
+ * - b2: 2nd layer biases vector, of shape (F2, 1). + * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). + * - b3: 3rd layer biases vector, of shape (1, N3). + * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). + * - b4: 4th layer biases vector, of shape (1, K). + */ + N = nrow(X) + K = ncol(Y) + + # Create network: + # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax + Hf = 5 # filter height + Wf = 5 # filter width + stride = 1 + pad = 2 # For same dimensions, (Hf - stride) / 2 + + F1 = 32 # num conv filters in conv1 + F2 = 64 # num conv filters in conv2 + N3 = 512 # num nodes in affine3 + # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes) + + [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win) + [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2)) + [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2)) + [W4, b4] = affine::init(N3, K) # inputs: (N, N3) + W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu + + # Initialize SGD w/ Nesterov momentum optimizer + lr = 0.01 # learning rate + mu = 0.9 #0.5 # momentum + decay = 0.95 # learning rate decay constant + vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1) + vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2) + vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3) + vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4) + + # Regularization + lambda = 5e-04 + + # Create the model object + modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) + + # Create the hyper parameter list + params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) + + # Use paramserv function + modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params) + + W1 = as.matrix(modelList2["W1"]) + b1 = as.matrix(modelList2["b1"]) + W2 = as.matrix(modelList2["W2"]) + b2 = as.matrix(modelList2["b2"]) + W3 = as.matrix(modelList2["W3"]) + b3 = as.matrix(modelList2["b3"]) + W4 = as.matrix(modelList2["W4"]) + b4 = as.matrix(modelList2["b4"]) + +} + +gradients = function(matrix[double] features, + matrix[double] labels, + list[unknown] hyperparams, + list[unknown] model) + return (list[unknown] gradients) { + + C = 1 + Hin = 28 + Win = 28 + Hf = 5 + Wf = 5 + stride = 1 + pad = 2 + lambda = 5e-04 + F1 = 32 + F2 = 64 + N3 = 512 + W1 = as.matrix(model["W1"]) + b1 = as.matrix(model["b1"]) + W2 = as.matrix(model["W2"]) + b2 = as.matrix(model["b2"]) + W3 = as.matrix(model["W3"]) + b3 = as.matrix(model["b3"]) + W4 = as.matrix(model["W4"]) + b4 = as.matrix(model["b4"]) + + # Compute forward pass + ## layer 1: conv1 -> relu1 -> pool1 + [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, + stride, stride, pad, pad) + outr1 = relu::forward(outc1) + [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 2: conv2 -> relu2 -> pool2 + [outc2, Houtc2, Woutc2] = 
conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, + stride, stride, pad, pad) + outr2 = relu::forward(outc2) + [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 3: affine3 -> relu3 -> dropout + outa3 = affine::forward(outp2, W3, b3) + outr3 = relu::forward(outa3) + [outd3, maskd3] = dropout::forward(outr3, 0.5, -1) + ## layer 4: affine4 -> softmax + outa4 = affine::forward(outd3, W4, b4) + probs = softmax::forward(outa4) + + # Compute data backward pass + ## loss: + dprobs = cross_entropy_loss::backward(probs, labels) + ## layer 4: affine4 -> softmax + douta4 = softmax::backward(dprobs, outa4) + [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4) + ## layer 3: affine3 -> relu3 -> dropout + doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3) + douta3 = relu::backward(doutr3, outa3) + [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3) + ## layer 2: conv2 -> relu2 -> pool2 + doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + doutc2 = relu::backward(doutr2, outc2) + [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1, + Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad) + ## layer 1: conv1 -> relu1 -> pool1 + doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + doutc1 = relu::backward(doutr1, outc1) + [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win, + Hf, Wf, stride, stride, pad, pad) + + # Compute regularization backward pass + dW1_reg = l2_reg::backward(W1, lambda) + dW2_reg = l2_reg::backward(W2, lambda) + dW3_reg = l2_reg::backward(W3, lambda) + dW4_reg = l2_reg::backward(W4, lambda) + dW1 = dW1 + dW1_reg + dW2 = dW2 + dW2_reg + dW3 = dW3 + dW3_reg + dW4 = dW4 + dW4_reg + + gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4) + +} + +# how to handle the velocity? 
+aggregation = function(list[unknown] model, + list[unknown] gradients, + list[unknown] hyperparams) + return (list[unknown] modelResult) { + + W1 = as.matrix(model["W1"]) + W2 = as.matrix(model["W2"]) + W3 = as.matrix(model["W3"]) + W4 = as.matrix(model["W4"]) + b1 = as.matrix(model["b1"]) + b2 = as.matrix(model["b2"]) + b3 = as.matrix(model["b3"]) + b4 = as.matrix(model["b4"]) + dW1 = as.matrix(gradients["dW1"]) + dW2 = as.matrix(gradients["dW2"]) + dW3 = as.matrix(gradients["dW3"]) + dW4 = as.matrix(gradients["dW4"]) + db1 = as.matrix(gradients["db1"]) + db2 = as.matrix(gradients["db2"]) + db3 = as.matrix(gradients["db3"]) + db4 = as.matrix(gradients["db4"]) + vW1 = as.matrix(model["vW1"]) + vW2 = as.matrix(model["vW2"]) + vW3 = as.matrix(model["vW3"]) + vW4 = as.matrix(model["vW4"]) + vb1 = as.matrix(model["vb1"]) + vb2 = as.matrix(model["vb2"]) + vb3 = as.matrix(model["vb3"]) + vb4 = as.matrix(model["vb4"]) + lr = 0.01 + mu = 0.9 + + # Optimize with SGD w/ Nesterov momentum + [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1) + [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1) + [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2) + [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2) + [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3) + [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3) + [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4) + [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4) + + modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) + } + +predict = function(matrix[double] X, int C, int Hin, int Win, + matrix[double] W1, matrix[double] b1, + matrix[double] W2, matrix[double] b2, + matrix[double] W3, matrix[double] b3, + matrix[double] W4, matrix[double] b4) + return (matrix[double] probs) { + /* + * Computes the class probability predictions of a convolutional + * net using the "LeNet" architecture. + * + * The input matrix, X, has N examples, each represented as a 3D + * volume unrolled into a single vector. + * + * Inputs: + * - X: Input data matrix, of shape (N, C*Hin*Win). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. + * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf). + * - b1: 1st layer biases vector, of shape (F1, 1). + * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf). + * - b2: 2nd layer biases vector, of shape (F2, 1). + * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3). + * - b3: 3rd layer biases vector, of shape (1, N3). + * - W4: 4th layer weights (parameters) matrix, of shape (N3, K). + * - b4: 4th layer biases vector, of shape (1, K). + * + * Outputs: + * - probs: Class probabilities, of shape (N, K). 
+ */ + N = nrow(X) + + # Network: + # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax + Hf = 5 # filter height + Wf = 5 # filter width + stride = 1 + pad = 2 # For same dimensions, (Hf - stride) / 2 + + F1 = nrow(W1) # num conv filters in conv1 + F2 = nrow(W2) # num conv filters in conv2 + N3 = ncol(W3) # num nodes in affine3 + K = ncol(W4) # num nodes in affine4, equal to number of target dimensions (num classes) + + # Compute predictions over mini-batches + probs = matrix(0, rows=N, cols=K) + batch_size = 64 + iters = ceil(N / batch_size) + for(i in 1:iters) { + # Get next batch + beg = ((i-1) * batch_size) %% N + 1 + end = min(N, beg + batch_size - 1) + X_batch = X[beg:end,] + + # Compute forward pass + ## layer 1: conv1 -> relu1 -> pool1 + [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride, + pad, pad) + outr1 = relu::forward(outc1) + [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 2: conv2 -> relu2 -> pool2 + [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, + stride, stride, pad, pad) + outr2 = relu::forward(outc2) + [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2, + strideh=2, stridew=2, pad=0, pad=0) + ## layer 3: affine3 -> relu3 + outa3 = affine::forward(outp2, W3, b3) + outr3 = relu::forward(outa3) + ## layer 4: affine4 -> softmax + outa4 = affine::forward(outr3, W4, b4) + probs_batch = softmax::forward(outa4) + + # Store predictions + probs[beg:end,] = probs_batch + } +} + +eval = function(matrix[double] probs, matrix[double] Y) + return (double loss, double accuracy) { + /* + * Evaluates a convolutional net using the "LeNet" architecture. + * + * The probs matrix contains the class probability predictions + * of K classes over N examples. The targets, Y, have K classes, + * and are one-hot encoded. + * + * Inputs: + * - probs: Class probabilities, of shape (N, K). + * - Y: Target matrix, of shape (N, K). + * + * Outputs: + * - loss: Scalar loss, of shape (1). + * - accuracy: Scalar accuracy, of shape (1). + */ + # Compute loss & accuracy + loss = cross_entropy_loss::forward(probs, Y) + correct_pred = rowIndexMax(probs) == rowIndexMax(Y) + accuracy = mean(correct_pred) +} + +generate_dummy_data = function() + return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) { + /* + * Generate a dummy dataset similar to the MNIST dataset. + * + * Outputs: + * - X: Input data matrix, of shape (N, D). + * - Y: Target matrix, of shape (N, K). + * - C: Number of input channels (dimensionality of input depth). + * - Hin: Input height. + * - Win: Input width. 
+ */ + # Generate dummy input data + N = 1024 # num examples + C = 1 # num input channels + Hin = 28 # input height + Win = 28 # input width + K = 10 # num target classes + X = rand(rows=N, cols=C*Hin*Win, pdf="normal") + classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform")) + Y = table(seq(1, N), classes) # one-hot encoding +} + http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-all-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-all-args.dml index bcb3ac3..ec6e087 100644 --- a/src/test/scripts/functions/paramserv/paramserv-all-args.dml +++ b/src/test/scripts/functions/paramserv/paramserv-all-args.dml @@ -20,7 +20,7 @@ #------------------------------------------------------------- e1 = "element1" -paramsList = list(e1) +paramsList = list(e1=e1) X = matrix(1, rows=2, cols=3) Y = matrix(2, rows=2, cols=3) X_val = matrix(3, rows=2, cols=3) @@ -35,7 +35,7 @@ aggregation = function (matrix[double] input) return (matrix[double] output) { } e2 = "element2" -hps = list(e2) +hps = list(e2=e2) # Use paramserv function paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE") http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml b/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml deleted file mode 100644 index 5aed767..0000000 --- a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml +++ /dev/null @@ -1,47 +0,0 @@ -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -#------------------------------------------------------------- - -e1 = "element1" -paramsList = list(e1) -X = matrix(1, rows=2, cols=3) -Y = matrix(2, rows=2, cols=3) -X_val = matrix(3, rows=2, cols=3) -Y_val = matrix(4, rows=2, cols=3) - -gradients = function (matrix[double] input) return (matrix[double] output) { - output = input -} - -aggregation = function (matrix[double] input) return (matrix[double] output) { - output = input -} - -e2 = "element2" -hps = list(e2) - -# Use paramserv function -paramsList2 = list(1, 2, 3) - -if (length(paramsList2) == 3) { - paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE") -} - -print(length(paramsList2)) \ No newline at end of file
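For reference, below is a minimal driver sketch showing how the helper functions defined in
mnist_lenet_paramserv.dml above could be exercised end-to-end with the local parameter server
(mode="LOCAL", utype="BSP", freq="BATCH", as set inside train). This sketch is not part of the
commit; the actual driver, paramserv-nn-test.dml, is listed in the file stats but not shown here.
The namespace alias mnist_lenet and the epoch/worker counts (2 and 2) are illustrative assumptions.

# Hypothetical driver sketch (illustrative only; epochs=2 and workers=2 are assumed values)
source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet

# Generate dummy MNIST-like training and validation data
[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
[X_val, Y_val, C, Hin, Win] = mnist_lenet::generate_dummy_data()

# Train via the paramserv builtin, then predict and evaluate on the validation set
[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, 2, 2)
probs = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
[loss, accuracy] = mnist_lenet::eval(probs, Y_val)
print("Validation accuracy: " + accuracy)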
