Repository: systemml
Updated Branches:
  refs/heads/master 19b310c6b -> 78e9d836e


[SYSTEMML-2299] Cleanup paramserv language API, incl defaults

Closes #817.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/78e9d836
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/78e9d836
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/78e9d836

Branch: refs/heads/master
Commit: 78e9d836ea16296fcf3bbd647b60638ce2bc24c3
Parents: 19b310c
Author: EdgarLGB <[email protected]>
Authored: Sun Aug 5 14:36:09 2018 +0200
Committer: Matthias Boehm <[email protected]>
Committed: Sun Aug 5 17:56:22 2018 -0700

----------------------------------------------------------------------
 .../ParameterizedBuiltinFunctionExpression.java | 10 ++--
 .../controlprogram/paramserv/LocalPSWorker.java |  5 +-
 .../controlprogram/paramserv/PSWorker.java      |  8 +--
 .../cp/ParamservBuiltinCPInstruction.java       | 58 ++++++++++----------
 .../paramserv/mnist_lenet_paramserv.dml         | 10 ++--
 .../mnist_lenet_paramserv_minimum_version.dml   | 12 ++--
 .../paramserv-without-optional-args.dml         | 12 ++++
 7 files changed, 59 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 33c5b0e..7c516e6 100644
--- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -330,12 +330,12 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                checkDataType(fname, Statement.PS_MODEL, DataType.LIST, 
conditional); // check the model which is the only non-parameterized argument
                checkDataType(fname, Statement.PS_FEATURES, DataType.MATRIX, 
conditional);
                checkDataType(fname, Statement.PS_LABELS, DataType.MATRIX, 
conditional);
-               checkDataType(fname, Statement.PS_VAL_FEATURES, 
DataType.MATRIX, conditional);
-               checkDataType(fname, Statement.PS_VAL_LABELS, DataType.MATRIX, 
conditional);
+               checkDataValueType(true, fname, Statement.PS_VAL_FEATURES, 
DataType.MATRIX, ValueType.DOUBLE, conditional);
+               checkDataValueType(true, fname, Statement.PS_VAL_LABELS, 
DataType.MATRIX, ValueType.DOUBLE, conditional);
                checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, 
DataType.SCALAR, ValueType.STRING, conditional);
                checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, 
DataType.SCALAR, ValueType.STRING, conditional);
-               checkStringParam(false, fname, Statement.PS_MODE, conditional);
-               checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, 
conditional);
+               checkStringParam(true, fname, Statement.PS_MODE, conditional);
+               checkStringParam(true, fname, Statement.PS_UPDATE_TYPE, 
conditional);
                checkStringParam(true, fname, Statement.PS_FREQUENCY, 
conditional);
                checkDataValueType(false, fname, Statement.PS_EPOCHS, 
DataType.SCALAR, ValueType.INT, conditional);
                checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, 
DataType.SCALAR, ValueType.INT, conditional);
@@ -860,7 +860,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                        if (optional) {
                                return;
                        }
-                       raiseValidateError(String.format("Named parameter '%s' 
is missing. Please specify the input.", fname),
+                       raiseValidateError(String.format("Named parameter '%s' 
is missing. Please specify the input.", pname),
                                        conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
                } else if (data.getOutput().getDataType() != dt || 
data.getOutput().getValueType() != vt)
                        raiseValidateError(String.format("Input to %s::%s must 
be of type '%s', '%s'.It should not be of type '%s', '%s'.",

http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index 04050b2..5ab4e07 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -39,9 +39,8 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
 
        protected LocalPSWorker() {}
 
-       public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq, int epochs, long batchSize,
-               MatrixObject valFeatures, MatrixObject valLabels, 
ExecutionContext ec, ParamServer ps) {
-               super(workerID, updFunc, freq, epochs, batchSize, valFeatures, 
valLabels, ec, ps);
+       public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, 
ParamServer ps) {
+               super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 63600d1..7c73a71 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -50,23 +50,17 @@ public abstract class PSWorker implements Serializable
        protected FunctionCallCPInstruction _inst;
        protected MatrixObject _features;
        protected MatrixObject _labels;
-
-       protected MatrixObject _valFeatures;
-       protected MatrixObject _valLabels;
        protected String _updFunc;
        protected Statement.PSFrequency _freq;
 
        protected PSWorker() {}
 
-       protected PSWorker(int workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize,
-               MatrixObject valFeatures, MatrixObject valLabels, 
ExecutionContext ec, ParamServer ps) {
+       protected PSWorker(int workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
                _workerID = workerID;
                _updFunc = updFunc;
                _freq = freq;
                _epochs = epochs;
                _batchSize = batchSize;
-               _valFeatures = valFeatures;
-               _valLabels = valLabels;
                _ec = ec;
                _ps = ps;
                setupUpdateFunction(updFunc, ec);

http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 83ec3f7..b6bb6fb 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -36,8 +36,6 @@ import static 
org.apache.sysml.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysml.parser.Statement.PS_SCHEME;
 import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
-import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;
 
 import java.util.HashMap;
 import java.util.HashSet;
@@ -65,14 +63,14 @@ import 
org.apache.sysml.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import 
org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
-import 
org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody;
 import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSWorker;
+import 
org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import 
org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
 import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
@@ -83,8 +81,10 @@ import org.apache.sysml.utils.Statistics;
 public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruction {
 
        private static final int DEFAULT_BATCH_SIZE = 64;
-       private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = 
PSFrequency.BATCH;
+       private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = 
PSFrequency.EPOCH;
        private static final PSScheme DEFAULT_SCHEME = 
PSScheme.DISJOINT_CONTIGUOUS;
+       private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
+       private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
 
        //internal local debug level
        private static final boolean LDEBUG = false;
@@ -113,13 +113,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        case REMOTE_SPARK:
                                runOnSpark((SparkExecutionContext) ec, mode);
                                break;
+                       default:
+                               throw new 
DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
                }
        }
 
        private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
                Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
 
-               PSScheme scheme = getScheme();
                int workerNum = getWorkerNum(mode);
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
@@ -129,9 +130,6 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                // Level of par is 1 in spark backend because one worker will 
be launched per task
                ExecutionContext newEC = 
ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);
 
-               MatrixObject features = 
sec.getMatrixObject(getParam(PS_FEATURES));
-               MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
-
                // Create the agg service's execution context
                ExecutionContext aggServiceEC = 
ParamservUtils.copyExecutionContext(newEC, 1).get(0);
 
@@ -172,24 +170,25 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                if (DMLScript.STATISTICS)
                        Statistics.accPSSetupTime((long) tSetup.stop());
 
+               MatrixObject features = 
sec.getMatrixObject(getParam(PS_FEATURES));
+               MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
                try {
-                       ParamservUtils.doPartitionOnSpark(sec, features, 
labels, scheme, workerNum) // Do data partitioning
+                       ParamservUtils.doPartitionOnSpark(sec, features, 
labels, getScheme(), workerNum) // Do data partitioning
                                .foreach(worker); // Run remote workers
                } catch (Exception e) {
                        throw new DMLRuntimeException("Paramserv function 
failed: ", e);
                } finally {
-                       // Stop the netty server
-                       server.close();
+                       server.close(); // Stop the netty server
                }
 
                // Accumulate the statistics for remote workers
                if (DMLScript.STATISTICS) {
-                       Statistics.accPSSetupTime(aSetup.value().longValue());
-                       Statistics.incWorkerNumber(aWorker.value().longValue());
-                       
Statistics.accPSLocalModelUpdateTime(aUpdate.value().longValue());
-                       
Statistics.accPSBatchIndexingTime(aIndex.value().longValue());
-                       
Statistics.accPSGradientComputeTime(aGrad.value().longValue());
-                       
Statistics.accPSRpcRequestTime(aRPC.value().longValue());
+                       Statistics.accPSSetupTime(aSetup.value());
+                       Statistics.incWorkerNumber(aWorker.value());
+                       Statistics.accPSLocalModelUpdateTime(aUpdate.value());
+                       Statistics.accPSBatchIndexingTime(aIndex.value());
+                       Statistics.accPSGradientComputeTime(aGrad.value());
+                       Statistics.accPSRpcRequestTime(aRPC.value());
                }
 
                // Fetch the final model from ps
@@ -205,11 +204,9 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
 
-               int k = getParLevel(workerNum);
-
                // Get the compiled execution context
                LocalVariableMap newVarsMap = createVarsMap(ec);
-               ExecutionContext newEC = 
ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, k);
+               ExecutionContext newEC = 
ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 
getParLevel(workerNum));
 
                // Create workers' execution context
                List<ExecutionContext> workerECs = 
ParamservUtils.copyExecutionContext(newEC, workerNum);
@@ -219,17 +216,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
                PSFrequency freq = getFrequency();
                PSUpdateType updateType = getUpdateType();
-               int epochs = getEpochs();
 
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
                ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, 
model, aggServiceEC);
 
                // Create the local workers
-               MatrixObject valFeatures = 
ec.getMatrixObject(getParam(PS_VAL_FEATURES));
-               MatrixObject valLabels = 
ec.getMatrixObject(getParam(PS_VAL_LABELS));
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
-                       .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, 
epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
+                       .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, 
getEpochs(), getBatchSize(), workerECs.get(i), ps))
                        .collect(Collectors.toList());
 
                // Do data partition
@@ -251,8 +245,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        for (Future<Void> ret : es.invokeAll(workers))
                                ret.get(); //error handling
                        // Fetch the final model from ps
-                       ListObject result = ps.getResult();
-                       ec.setVariable(output.getName(), result);
+                       ec.setVariable(output.getName(), ps.getResult());
                } catch (InterruptedException | ExecutionException e) {
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
                } finally {
@@ -271,6 +264,9 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private PSModeType getPSMode() {
+               if (!getParameterMap().containsKey(PS_MODE)) {
+                       return DEFAULT_MODE;
+               }
                PSModeType mode;
                try {
                        mode = PSModeType.valueOf(getParam(PS_MODE));
@@ -294,6 +290,9 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private PSUpdateType getUpdateType() {
+               if (!getParameterMap().containsKey(PS_UPDATE_TYPE)) {
+                       return DEFAULT_TYPE;
+               }
                PSUpdateType updType;
                try {
                        updType = 
PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
@@ -301,7 +300,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        throw new DMLRuntimeException(String.format("Paramserv 
function: not support update type '%s'.", getParam(PS_UPDATE_TYPE)));
                }
                if (updType == PSUpdateType.SSP)
-                       throw new DMLRuntimeException("Not support update type 
SSP.");
+                       throw new DMLRuntimeException("Paramserv function: Not 
support update type SSP.");
                return updType;
        }
 
@@ -318,7 +317,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private int getRemainingCores() {
-               return InfrastructureAnalyzer.getLocalParallelism() - 1;
+               return InfrastructureAnalyzer.getLocalParallelism();
        }
 
        /**
@@ -330,7 +329,6 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        private int getWorkerNum(PSModeType mode) {
                switch (mode) {
                        case LOCAL:
-                               // default worker number: available cores - 1 
(assign one process for agg service)
                                return 
getParameterMap().containsKey(PS_PARALLELISM) ?
                                        
Integer.valueOf(getParam(PS_PARALLELISM)) : getRemainingCores();
                        case REMOTE_SPARK:

http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index 5ccda12..028440f 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -123,10 +123,10 @@ train = function(matrix[double] X, matrix[double] Y,
 # Should always use 'features' (batch features), 'labels' (batch labels),
 # 'hyperparams', 'model' as the arguments
 # and return the gradients of type list
-gradients = function(matrix[double] features,
-                     matrix[double] labels,
+gradients = function(list[unknown] model,
                      list[unknown] hyperparams,
-                     list[unknown] model)
+                     matrix[double] features,
+                     matrix[double] labels)
           return (list[unknown] gradients) {
 
   C = as.integer(as.scalar(hyperparams["C"]))
@@ -205,8 +205,8 @@ gradients = function(matrix[double] features,
 # Should use the arguments named 'model', 'gradients', 'hyperparams'
 # and return always a model of type list
 aggregation = function(list[unknown] model,
-                       list[unknown] gradients,
-                       list[unknown] hyperparams)
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
    return (list[unknown] modelResult) {
      W1 = as.matrix(model[1])
      W2 = as.matrix(model[2])

http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index e7056f0..25ca23e 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -107,7 +107,7 @@ train = function(matrix[double] X, matrix[double] Y,
   params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
 
   # Use paramserv function
-  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation",
 mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params)
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation",
 epochs=epochs, hyperparams=params)
 
   W1 = as.matrix(modelList2[1])
   W2 = as.matrix(modelList2[2])
@@ -120,10 +120,10 @@ train = function(matrix[double] X, matrix[double] Y,
 
 }
 
-gradients = function(matrix[double] features,
-                     matrix[double] labels,
+gradients = function(list[unknown] model,
                      list[unknown] hyperparams,
-                     list[unknown] model)
+                     matrix[double] features,
+                     matrix[double] labels)
           return (list[unknown] gradients) {
 
   C = as.integer(as.scalar(hyperparams["C"]))
@@ -200,8 +200,8 @@ gradients = function(matrix[double] features,
 }
 
 aggregation = function(list[unknown] model,
-                       list[unknown] gradients,
-                       list[unknown] hyperparams)
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
    return (list[unknown] modelResult) {
      W1 = as.matrix(model[1])
      W2 = as.matrix(model[2])

http://git-wip-us.apache.org/repos/asf/systemml/blob/78e9d836/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
index 6d06ce2..425d364 100644
--- a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
@@ -38,6 +38,18 @@ e2 = "element2"
 params = list(e2=e2)
 
 # Use paramserv function
+# Remove the optional "val_features" and "val_labels"
+modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", 
agg="aggregation", mode="REMOTE_SPARK", utype="BSP", freq="EPOCH", epochs=100, 
batchsize=64, k=7, scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH")
+
+# Remove the optional "mode"
+modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", 
agg="aggregation", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, 
scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH")
+
+# Remove the optional "utype"
+modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", 
agg="aggregation", epochs=100, freq="EPOCH", batchsize=64, k=7, 
scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH")
+
+# Remove the optional "freq"
+modelList2 = paramserv(model=modelList, features=X, labels=Y, upd="gradients", 
agg="aggregation", utype="BSP", epochs=100, batchsize=64, k=7, 
scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH")
+
 # Remove the optional "hyperparams"
 modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", 
mode="REMOTE_SPARK", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, 
scheme="DISJOINT_ROUND_ROBIN", checkpointing="EPOCH")
 

Reply via email to