Repository: systemml Updated Branches: refs/heads/master 19b310c6b -> 78e9d836e
[SYSTEMML-2299] Cleanup paramserv language API, incl defaults Closes #817. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/78e9d836 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/78e9d836 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/78e9d836 Branch: refs/heads/master Commit: 78e9d836ea16296fcf3bbd647b60638ce2bc24c3 Parents: 19b310c Author: EdgarLGB <[email protected]> Authored: Sun Aug 5 14:36:09 2018 +0200 Committer: Matthias Boehm <[email protected]> Committed: Sun Aug 5 17:56:22 2018 -0700 ---------------------------------------------------------------------- .../ParameterizedBuiltinFunctionExpression.java | 10 ++-- .../controlprogram/paramserv/LocalPSWorker.java | 5 +- .../controlprogram/paramserv/PSWorker.java | 8 +-- .../cp/ParamservBuiltinCPInstruction.java | 58 ++++++++++---------- .../paramserv/mnist_lenet_paramserv.dml | 10 ++-- .../mnist_lenet_paramserv_minimum_version.dml | 12 ++-- .../paramserv-without-optional-args.dml | 12 ++++ 7 files changed, 59 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java index 33c5b0e..7c516e6 100644 --- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java @@ -330,12 +330,12 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier checkDataType(fname, Statement.PS_MODEL, DataType.LIST, conditional); // check the model which is the only non-parameterized argument checkDataType(fname, Statement.PS_FEATURES, DataType.MATRIX, conditional); checkDataType(fname, Statement.PS_LABELS, DataType.MATRIX, conditional); - checkDataType(fname, Statement.PS_VAL_FEATURES, DataType.MATRIX, conditional); - checkDataType(fname, Statement.PS_VAL_LABELS, DataType.MATRIX, conditional); + checkDataValueType(true, fname, Statement.PS_VAL_FEATURES, DataType.MATRIX, ValueType.DOUBLE, conditional); + checkDataValueType(true, fname, Statement.PS_VAL_LABELS, DataType.MATRIX, ValueType.DOUBLE, conditional); checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, DataType.SCALAR, ValueType.STRING, conditional); checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, DataType.SCALAR, ValueType.STRING, conditional); - checkStringParam(false, fname, Statement.PS_MODE, conditional); - checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, conditional); + checkStringParam(true, fname, Statement.PS_MODE, conditional); + checkStringParam(true, fname, Statement.PS_UPDATE_TYPE, conditional); checkStringParam(true, fname, Statement.PS_FREQUENCY, conditional); checkDataValueType(false, fname, Statement.PS_EPOCHS, DataType.SCALAR, ValueType.INT, conditional); checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional); @@ -860,7 +860,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier if (optional) { return; } - raiseValidateError(String.format("Named parameter '%s' is missing. Please specify the input.", fname), + raiseValidateError(String.format("Named parameter '%s' is missing. Please specify the input.", pname), conditional, LanguageErrorCodes.INVALID_PARAMETERS); } else if (data.getOutput().getDataType() != dt || data.getOutput().getValueType() != vt) raiseValidateError(String.format("Input to %s::%s must be of type '%s', '%s'.It should not be of type '%s', '%s'.", http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index 04050b2..5ab4e07 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -39,9 +39,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { protected LocalPSWorker() {} - public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, - MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { - super(workerID, updFunc, freq, epochs, batchSize, valFeatures, valLabels, ec, ps); + public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { + super(workerID, updFunc, freq, epochs, batchSize, ec, ps); } @Override http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index 63600d1..7c73a71 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -50,23 +50,17 @@ public abstract class PSWorker implements Serializable protected FunctionCallCPInstruction _inst; protected MatrixObject _features; protected MatrixObject _labels; - - protected MatrixObject _valFeatures; - protected MatrixObject _valLabels; protected String _updFunc; protected Statement.PSFrequency _freq; protected PSWorker() {} - protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, - MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { + protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) { _workerID = workerID; _updFunc = updFunc; _freq = freq; _epochs = epochs; _batchSize = batchSize; - _valFeatures = valFeatures; - _valLabels = valLabels; _ec = ec; _ps = ps; setupUpdateFunction(updFunc, ec); http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 83ec3f7..b6bb6fb 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -36,8 +36,6 @@ import static org.apache.sysml.parser.Statement.PS_PARALLELISM; import static org.apache.sysml.parser.Statement.PS_SCHEME; import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN; import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE; -import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES; -import static org.apache.sysml.parser.Statement.PS_VAL_LABELS; import java.util.HashMap; import java.util.HashSet; @@ -65,14 +63,14 @@ import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme; -import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner; import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody; import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme; +import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner; import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; @@ -83,8 +81,10 @@ import org.apache.sysml.utils.Statistics; public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { private static final int DEFAULT_BATCH_SIZE = 64; - private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH; + private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.EPOCH; private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS; + private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL; + private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP; //internal local debug level private static final boolean LDEBUG = false; @@ -113,13 +113,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc case REMOTE_SPARK: runOnSpark((SparkExecutionContext) ec, mode); break; + default: + throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode)); } } private void runOnSpark(SparkExecutionContext sec, PSModeType mode) { Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null; - PSScheme scheme = getScheme(); int workerNum = getWorkerNum(mode); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); @@ -129,9 +130,6 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Level of par is 1 in spark backend because one worker will be launched per task ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); - MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES)); - MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS)); - // Create the agg service's execution context ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0); @@ -172,24 +170,25 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc if (DMLScript.STATISTICS) Statistics.accPSSetupTime((long) tSetup.stop()); + MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES)); + MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS)); try { - ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning + ParamservUtils.doPartitionOnSpark(sec, features, labels, getScheme(), workerNum) // Do data partitioning .foreach(worker); // Run remote workers } catch (Exception e) { throw new DMLRuntimeException("Paramserv function failed: ", e); } finally { - // Stop the netty server - server.close(); + server.close(); // Stop the netty server } // Accumulate the statistics for remote workers if (DMLScript.STATISTICS) { - Statistics.accPSSetupTime(aSetup.value().longValue()); - Statistics.incWorkerNumber(aWorker.value().longValue()); - Statistics.accPSLocalModelUpdateTime(aUpdate.value().longValue()); - Statistics.accPSBatchIndexingTime(aIndex.value().longValue()); - Statistics.accPSGradientComputeTime(aGrad.value().longValue()); - Statistics.accPSRpcRequestTime(aRPC.value().longValue()); + Statistics.accPSSetupTime(aSetup.value()); + Statistics.incWorkerNumber(aWorker.value()); + Statistics.accPSLocalModelUpdateTime(aUpdate.value()); + Statistics.accPSBatchIndexingTime(aIndex.value()); + Statistics.accPSGradientComputeTime(aGrad.value()); + Statistics.accPSRpcRequestTime(aRPC.value()); } // Fetch the final model from ps @@ -205,11 +204,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); - int k = getParLevel(workerNum); - // Get the compiled execution context LocalVariableMap newVarsMap = createVarsMap(ec); - ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, k); + ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, getParLevel(workerNum)); // Create workers' execution context List<ExecutionContext> workerECs = ParamservUtils.copyExecutionContext(newEC, workerNum); @@ -219,17 +216,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); - int epochs = getEpochs(); // Create the parameter server ListObject model = ec.getListObject(getParam(PS_MODEL)); ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC); // Create the local workers - MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES)); - MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS)); List<LocalPSWorker> workers = IntStream.range(0, workerNum) - .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps)) + .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, getEpochs(), getBatchSize(), workerECs.get(i), ps)) .collect(Collectors.toList()); // Do data partition @@ -251,8 +245,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc for (Future<Void> ret : es.invokeAll(workers)) ret.get(); //error handling // Fetch the final model from ps - ListObject result = ps.getResult(); - ec.setVariable(output.getName(), result); + ec.setVariable(output.getName(), ps.getResult()); } catch (InterruptedException | ExecutionException e) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e); } finally { @@ -271,6 +264,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } private PSModeType getPSMode() { + if (!getParameterMap().containsKey(PS_MODE)) { + return DEFAULT_MODE; + } PSModeType mode; try { mode = PSModeType.valueOf(getParam(PS_MODE)); @@ -294,6 +290,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } private PSUpdateType getUpdateType() { + if (!getParameterMap().containsKey(PS_UPDATE_TYPE)) { + return DEFAULT_TYPE; + } PSUpdateType updType; try { updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE)); @@ -301,7 +300,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc throw new DMLRuntimeException(String.format("Paramserv function: not support update type '%s'.", getParam(PS_UPDATE_TYPE))); } if (updType == PSUpdateType.SSP) - throw new DMLRuntimeException("Not support update type SSP."); + throw new DMLRuntimeException("Paramserv function: Not support update type SSP."); return updType; } @@ -318,7 +317,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } private int getRemainingCores() { - return InfrastructureAnalyzer.getLocalParallelism() - 1; + return InfrastructureAnalyzer.getLocalParallelism(); } /** @@ -330,7 +329,6 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private int getWorkerNum(PSModeType mode) { switch (mode) { case LOCAL: - // default worker number: available cores - 1 (assign one process for agg service) return getParameterMap().containsKey(PS_PARALLELISM) ? Integer.valueOf(getParam(PS_PARALLELISM)) : getRemainingCores(); case REMOTE_SPARK: http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml index 5ccda12..028440f 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml @@ -123,10 +123,10 @@ train = function(matrix[double] X, matrix[double] Y, # Should always use 'features' (batch features), 'labels' (batch labels), # 'hyperparams', 'model' as the arguments # and return the gradients of type list -gradients = function(matrix[double] features, - matrix[double] labels, +gradients = function(list[unknown] model, list[unknown] hyperparams, - list[unknown] model) + matrix[double] features, + matrix[double] labels) return (list[unknown] gradients) { C = as.integer(as.scalar(hyperparams["C"])) @@ -205,8 +205,8 @@ gradients = function(matrix[double] features, # Should use the arguments named 'model', 'gradients', 'hyperparams' # and return always a model of type list aggregation = function(list[unknown] model, - list[unknown] gradients, - list[unknown] hyperparams) + list[unknown] hyperparams, + list[unknown] gradients) return (list[unknown] modelResult) { W1 = as.matrix(model[1]) W2 = as.matrix(model[2]) http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml index e7056f0..25ca23e 100644 --- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml +++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml @@ -107,7 +107,7 @@ train = function(matrix[double] X, matrix[double] Y, params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3) # Use paramserv function - modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params) + modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", epochs=epochs, hyperparams=params) W1 = as.matrix(modelList2[1]) W2 = as.matrix(modelList2[2]) @@ -120,10 +120,10 @@ train = function(matrix[double] X, matrix[double] Y, } -gradients = function(matrix[double] features, - matrix[double] labels, +gradients = function(list[unknown] model, list[unknown] hyperparams, - list[unknown] model) + matrix[double] features, + matrix[double] labels) return (list[unknown] gradients) { C = as.integer(as.scalar(hyperparams["C"])) @@ -200,8 +200,8 @@ gradients = function(matrix[double] features, } aggregation = function(list[unknown] model, - list[unknown] gradients, - list[unknown] hyperparams) + list[unknown] hyperparams, + list[unknown] gradients) return (list[unknown] modelResult) { W1 = as.matrix(model[1]) W2 = as.matrix(model[2]) http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml index 6d06ce2..425d364 100644 --- a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml +++ b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml @@ -38,6 +38,18 @@ e2 = "element2" params = list(e2=e2) # Use paramserv function +# Remove the optional "val_features" and "val_labels" +modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH") + +# Remove the optional "mode" +modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", agg="aggregation", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH") + +# Remove the optional "utype" +modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", agg="aggregation", epochs=100, freq="EPOCH", batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH") + +# Remove the optional "freq" +modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", agg="aggregation", utype="BSP", epochs=100, batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH") + # Remove the optional "hyperparams" modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH")
