Repository: systemml Updated Branches: refs/heads/master 0e01af524 -> 7027a532d
[SYSTEMML-2413] Fix paramserv contention on DAG recompilation Closes #789. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7027a532 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7027a532 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7027a532 Branch: refs/heads/master Commit: 7027a532dff62304cbde8d18a3f45633c078423b Parents: 0e01af5 Author: EdgarLGB <[email protected]> Authored: Thu Jun 21 22:32:07 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Jun 21 23:02:51 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/FunctionProgramBlock.java | 2 +- .../controlprogram/paramserv/LocalPSWorker.java | 2 +- .../controlprogram/paramserv/PSWorker.java | 17 ++--- .../controlprogram/paramserv/ParamServer.java | 37 ++++------ .../paramserv/ParamservUtils.java | 76 ++++++++++++-------- .../cp/ParamservBuiltinCPInstruction.java | 28 ++++---- .../paramserv/ParamservRuntimeNegativeTest.java | 4 +- .../paramserv/mnist_lenet_paramserv.dml | 9 ++- .../mnist_lenet_paramserv_minimum_version.dml | 5 +- .../paramserv/paramserv-large-parallelism.dml | 3 +- .../paramserv/paramserv-minimum-version.dml | 3 +- .../paramserv/paramserv-nn-asp-batch.dml | 5 +- .../paramserv/paramserv-nn-asp-epoch.dml | 5 +- .../paramserv/paramserv-nn-bsp-batch-dc.dml | 5 +- .../paramserv/paramserv-nn-bsp-batch-dr.dml | 5 +- .../paramserv/paramserv-nn-bsp-batch-drr.dml | 5 +- .../paramserv/paramserv-nn-bsp-batch-or.dml | 5 +- .../paramserv/paramserv-nn-bsp-epoch.dml | 5 +- 18 files changed, 114 insertions(+), 107 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java index 5ea84a7..25b2017 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java @@ -110,7 +110,7 @@ public class FunctionProgramBlock extends ProgramBlock } // for each program block - try { + try { for (int i=0 ; i < this._childBlocks.size() ; i++) { ec.updateDebugState(i); _childBlocks.get(i).execute(ec); http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index beb5e45..4f472ee 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -107,7 +107,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) { Timing tUpd = DMLScript.STATISTICS ? 
new Timing(true) : null; - globalParams = _ps.updateModel(gradients, globalParams); + globalParams = _ps.updateModel(_ec, gradients, globalParams); if (DMLScript.STATISTICS) Statistics.accPSLocalModelUpdateTime((long) tUpd.stop()); http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index f46370d..a76dfec 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -19,12 +19,11 @@ package org.apache.sysml.runtime.controlprogram.paramserv; -import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.UPDATE_FUNC_PREFIX; +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX; import java.util.ArrayList; import java.util.stream.Collectors; -import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.Expression; import org.apache.sysml.parser.Statement; @@ -66,14 +65,10 @@ public abstract class PSWorker { _ps = ps; // Get the update function - String[] keys = DMLProgram.splitFunctionKey(updFunc); - String funcName = keys[0]; - String funcNS = null; - if (keys.length == 2) { - funcNS = keys[0]; - funcName = keys[1]; - } - FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(funcNS, UPDATE_FUNC_PREFIX + _workerID + "_" + funcName); + String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX); + String ns = cfn[0]; + String fname = cfn[1]; + FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname); ArrayList<DataIdentifier> inputs = func.getInputParams(); 
ArrayList<DataIdentifier> outputs = func.getOutputParams(); CPOperand[] boundInputs = inputs.stream() @@ -83,7 +78,7 @@ public abstract class PSWorker { .collect(Collectors.toCollection(ArrayList::new)); ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName) .collect(Collectors.toCollection(ArrayList::new)); - _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, "update function"); + _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "update function"); // Check the inputs of the update function checkInput(false, inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES); http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index 0ddfb40..4af72a4 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -19,7 +19,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv; -import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.AGG_FUNC_PREFIX; +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX; import java.util.ArrayList; import java.util.Arrays; @@ -40,14 +40,12 @@ import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; -import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.Expression; import org.apache.sysml.parser.Statement; import 
org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.Data; @@ -101,11 +99,7 @@ public abstract class ParamServer { return _model; } - public ListObject updateModel(ListObject gradients, ListObject model) { - //note: we use a new execution context to allow for concurrent execution of ASP local updates; - //otherwise synchronized on the aggService instance would serialize those - ExecutionContext ec = ExecutionContextFactory.createContext(_aggService._ec.getProgram()); - ec.setVariable(Statement.PS_HYPER_PARAMS, _aggService._ec.getVariable(Statement.PS_HYPER_PARAMS)); + public ListObject updateModel(ExecutionContext ec, ListObject gradients, ListObject model) { return _aggService.updateModel(ec, gradients, model); } @@ -138,14 +132,10 @@ public abstract class ParamServer { _finishedStates = new boolean[workerNum]; // Fetch the aggregation function - String[] keys = DMLProgram.splitFunctionKey(aggFunc); - String funcName = keys[0]; - String funcNS = null; - if (keys.length == 2) { - funcNS = keys[0]; - funcName = keys[1]; - } - FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(funcNS, AGG_FUNC_PREFIX + funcName); + String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX); + String ns = cfn[0]; + String fname = cfn[1]; + FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname); ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); @@ -165,7 +155,7 @@ public abstract class ParamServer { .collect(Collectors.toCollection(ArrayList::new)); ArrayList<String> outputNames 
= outputs.stream().map(DataIdentifier::getName) .collect(Collectors.toCollection(ArrayList::new)); - _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames, "aggregate function"); + _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function"); } private boolean allFinished() { @@ -248,16 +238,13 @@ public abstract class ParamServer { return null; } - /** - * A synchronized service method for updating model with gradients - * - * @param gradients A list object of gradients - * @return A updated list object of model - */ - private synchronized ListObject updateModel(ListObject gradients, ListObject model) { + private ListObject updateModel(ListObject gradients, ListObject model) { return updateModel(_ec, gradients, model); } - + + /** + * A service method for updating model with gradients + */ private ListObject updateModel(ExecutionContext ec, ListObject gradients, ListObject model) { // Populate the variables table with the gradients and model ec.setVariable(Statement.PS_GRADIENTS, gradients); http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index c4a3d98..6c76aa6 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.commons.lang.StringUtils; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.MultiThreadedHop; import org.apache.sysml.hops.OptimizerUtils; @@ -38,6 
+39,7 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.ForProgramBlock; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.IfProgramBlock; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.ProgramBlock; @@ -58,8 +60,7 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo; public class ParamservUtils { - public static final String UPDATE_FUNC_PREFIX = "_worker_"; - public static final String AGG_FUNC_PREFIX = "_agg_"; + public static final String PS_FUNC_PREFIX = "_ps_"; /** * Deep copy the list object @@ -131,7 +132,17 @@ public class ParamservUtils { return permutation; } - public static ExecutionContext createExecutionContext(ExecutionContext ec, String updFunc, String aggFunc, int workerNum, int k) { + public static String[] getCompleteFuncName(String funcName, String prefix) { + String[] keys = DMLProgram.splitFunctionKey(funcName); + String ns = (keys.length==2) ? keys[0] : null; + String name = (keys.length==2) ? keys[1] : keys[0]; + return StringUtils.isEmpty(prefix) ? + new String[]{ns, name} : new String[]{ns, prefix + name}; + } + + public static List<ExecutionContext> createExecutionContexts(ExecutionContext ec, LocalVariableMap varsMap, + String updFunc, String aggFunc, int workerNum, int k) { + FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc); FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc); @@ -142,27 +153,40 @@ public class ParamservUtils { // 2. 
Recompile the imported function blocks prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks())); - // Copy function for workers - IntStream.range(0, workerNum).forEach(i -> copyFunction(updFunc, updPB, prog, UPDATE_FUNC_PREFIX + i + "_")); + // 3. Copy function for workers + List<ExecutionContext> workerECs = IntStream.range(0, workerNum) + .mapToObj(i -> { + FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB); + FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB); + Program newProg = new Program(); + putFunction(newProg, newUpdFunc); + putFunction(newProg, newAggFunc); + return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg); + }) + .collect(Collectors.toList()); - // Copy function for agg service - copyFunction(aggFunc, aggPB, prog, AGG_FUNC_PREFIX); + // 4. Copy function for agg service + FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB); + Program newProg = new Program(); + putFunction(newProg, newAggFunc); + ExecutionContext aggEC = ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg); - return ExecutionContextFactory.createContext(prog); + List<ExecutionContext> result = new ArrayList<>(workerECs); + result.add(aggEC); + return result; } - private static void copyFunction(String funcName, FunctionProgramBlock updPB, Program prog, String prefix) { - String[] keys = DMLProgram.splitFunctionKey(funcName); - String namespace = null; - String func = keys[0]; - if (keys.length == 2) { - namespace = keys[0]; - func = keys[1]; - } - FunctionProgramBlock copiedFunc = ProgramConverter - .createDeepCopyFunctionProgramBlock(updPB, new HashSet<>(), new HashSet<>()); - String fnameNew = prefix + func; - prog.addFunctionProgramBlock(namespace, fnameNew, copiedFunc); + private static FunctionProgramBlock copyFunction(String funcName, FunctionProgramBlock fpb) { + FunctionProgramBlock copiedFunc = 
ProgramConverter.createDeepCopyFunctionProgramBlock(fpb, new HashSet<>(), new HashSet<>()); + String[] cfn = getCompleteFuncName(funcName, ParamservUtils.PS_FUNC_PREFIX); + copiedFunc._namespace = cfn[0]; + copiedFunc._functionName = cfn[1]; + return copiedFunc; + } + + private static void putFunction(Program prog, FunctionProgramBlock fpb) { + prog.addFunctionProgramBlock(fpb._namespace, fpb._functionName, fpb); + prog.addProgramBlock(fpb); } private static void recompileProgramBlocks(int k, ArrayList<ProgramBlock> pbs) { @@ -229,13 +253,9 @@ public class ParamservUtils { private static FunctionProgramBlock getFunctionBlock(ExecutionContext ec, String funcName) { - String[] keys = DMLProgram.splitFunctionKey(funcName); - String namespace = null; - String func = keys[0]; - if (keys.length == 2) { - namespace = keys[0]; - func = keys[1]; - } - return ec.getProgram().getFunctionProgramBlock(namespace, func); + String[] cfn = getCompleteFuncName(funcName, null); + String ns = cfn[0]; + String fname = cfn[1]; + return ec.getProgram().getFunctionProgramBlock(ns, fname); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index d2c0edd..be80127 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -56,10 +56,8 @@ import org.apache.log4j.Logger; import org.apache.sysml.api.DMLScript; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; -import org.apache.sysml.runtime.controlprogram.Program; import 
org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner; import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC; import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDR; @@ -112,13 +110,15 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc int k = getParLevel(workerNum); // Get the compiled execution context - ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, updFunc, aggFunc, workerNum, k); + // Create workers' execution context + LocalVariableMap newVarsMap = createVarsMap(ec); + List<ExecutionContext> newECs = ParamservUtils.createExecutionContexts(ec, newVarsMap, updFunc, aggFunc, workerNum, k); // Create workers' execution context - List<ExecutionContext> workerECs = createExecutionContext(workerNum, ec, newEC.getProgram()); + List<ExecutionContext> workerECs = newECs.subList(0, newECs.size() - 1); // Create the agg service's execution context - ExecutionContext aggServiceEC = createExecutionContext(1, ec, newEC.getProgram()).get(0); + ExecutionContext aggServiceEC = newECs.get(newECs.size() - 1); PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); @@ -165,16 +165,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } } - private List<ExecutionContext> createExecutionContext(int size, ExecutionContext ec, Program program) { - return IntStream.range(0, size).mapToObj(i -> { - // Put the hyperparam into the variables table - LocalVariableMap varsMap = new LocalVariableMap(); - ListObject hyperParams = getHyperParams(ec); - if (hyperParams != null) { - varsMap.put(PS_HYPER_PARAMS, hyperParams); - } - return ExecutionContextFactory.createContext(varsMap, program); - 
}).collect(Collectors.toList()); + private LocalVariableMap createVarsMap(ExecutionContext ec) { + // Put the hyperparam into the variables table + LocalVariableMap varsMap = new LocalVariableMap(); + ListObject hyperParams = getHyperParams(ec); + if (hyperParams != null) { + varsMap.put(PS_HYPER_PARAMS, hyperParams); + } + return varsMap; } private PSModeType getPSMode() { http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java index 7238cf9..a81fd36 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java @@ -44,12 +44,12 @@ public class ParamservRuntimeNegativeTest extends AutomatedTestBase { @Test public void testParamservWorkerFailed() { - runDMLTest(TEST_NAME1, "Invalid lookup by name in unnamed list: worker_err."); + runDMLTest(TEST_NAME1, "Invalid indexing by name in unnamed list: worker_err."); } @Test public void testParamservAggServiceFailed() { - runDMLTest(TEST_NAME2, "Invalid lookup by name in unnamed list: agg_service_err"); + runDMLTest(TEST_NAME2, "Invalid indexing by name in unnamed list: agg_service_err"); } @Test http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml index a10c846..acafc88 100644 --- 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -36,7 +36,7 @@ source("nn/optim/sgd_nesterov.dml") as sgd_nesterov train = function(matrix[double] X, matrix[double] Y, matrix[double] X_val, matrix[double] Y_val, int C, int Hin, int Win, int epochs, int workers, - string utype, string freq, string scheme) + string utype, string freq, int batchsize, string scheme) return (matrix[double] W1, matrix[double] b1, matrix[double] W2, matrix[double] b2, matrix[double] W3, matrix[double] b3, @@ -108,7 +108,7 @@ train = function(matrix[double] X, matrix[double] Y, params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) # Use paramserv function - modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=64, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE") + modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=batchsize, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE") W1 = as.matrix(modelList2["W1"]) b1 = as.matrix(modelList2["b1"]) @@ -256,7 +256,7 @@ aggregation = function(list[unknown] model, modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) } -predict = function(matrix[double] X, int C, int Hin, int Win, +predict = 
function(matrix[double] X, int C, int Hin, int Win, int batch_size, matrix[double] W1, matrix[double] b1, matrix[double] W2, matrix[double] b2, matrix[double] W3, matrix[double] b3, @@ -302,9 +302,8 @@ predict = function(matrix[double] X, int C, int Hin, int Win, # Compute predictions over mini-batches probs = matrix(0, rows=N, cols=K) - batch_size = 64 iters = ceil(N / batch_size) - for(i in 1:iters) { + parfor(i in 1:iters, check=0) { # Get next batch beg = ((i-1) * batch_size) %% N + 1 end = min(N, beg + batch_size - 1) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml index d02e5d6..aeec3df 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml @@ -249,7 +249,7 @@ aggregation = function(list[unknown] model, modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4) } -predict = function(matrix[double] X, int C, int Hin, int Win, +predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, matrix[double] W1, matrix[double] b1, matrix[double] W2, matrix[double] b2, matrix[double] W3, matrix[double] b3, @@ -295,9 +295,8 @@ predict = function(matrix[double] X, int C, int Hin, int Win, # Compute predictions over mini-batches probs = matrix(0, rows=N, cols=K) - batch_size = 64 iters = ceil(N / batch_size) - for(i in 1:iters) { + parfor(i in 1:iters, check=0) { # Get next batch beg = ((i-1) * batch_size) %% N + 1 end = min(N, beg + batch_size - 1) 
http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml index 4e2d2b7..de29a04 100644 --- a/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml +++ b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 10 +batchsize = 64 # Train [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers) # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml b/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml index 4d23b8c..a537fe9 100644 --- a/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml +++ b/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 2 +batchsize = 64 # Train [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers) # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, 
Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml index baef6ee..2279d58 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 2 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", "DISJOINT_CONTIGUOUS") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", batchsize,"DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml index 860f53f..1824083 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments 
epochs = 10 workers = 2 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", "DISJOINT_CONTIGUOUS") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml index dcbd2dd..2e09de4 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 2 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_CONTIGUOUS") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) 
http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml index 96fe734..8444952 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 2 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_RANDOM") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_RANDOM") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml index e97dbff..ccb7ffc 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 4 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, 
epochs, workers, "BSP", "BATCH", "DISJOINT_ROUND_ROBIN") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_ROUND_ROBIN") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml index a2e95d3..4afc56b 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 2 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", "OVERLAP_RESHUFFLE") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "OVERLAP_RESHUFFLE") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val)) http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml 
---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml index 25d5f48..c542286 100644 --- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml +++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml @@ -39,12 +39,13 @@ Y_val = labels[1:val_size,] # Arguments epochs = 10 workers = 2 +batchsize = 32 # Train -[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH", "DISJOINT_CONTIGUOUS") +[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS") # Compute validation loss & accuracy -probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4) +probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4) loss_val = cross_entropy_loss::forward(probs_val, Y_val) accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
