Repository: systemml
Updated Branches:
  refs/heads/master 0e01af524 -> 7027a532d


[SYSTEMML-2413] Fix paramserv contention on DAG recompilation

Closes #789.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7027a532
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7027a532
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7027a532

Branch: refs/heads/master
Commit: 7027a532dff62304cbde8d18a3f45633c078423b
Parents: 0e01af5
Author: EdgarLGB <[email protected]>
Authored: Thu Jun 21 22:32:07 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Jun 21 23:02:51 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/FunctionProgramBlock.java    |  2 +-
 .../controlprogram/paramserv/LocalPSWorker.java |  2 +-
 .../controlprogram/paramserv/PSWorker.java      | 17 ++---
 .../controlprogram/paramserv/ParamServer.java   | 37 ++++------
 .../paramserv/ParamservUtils.java               | 76 ++++++++++++--------
 .../cp/ParamservBuiltinCPInstruction.java       | 28 ++++----
 .../paramserv/ParamservRuntimeNegativeTest.java |  4 +-
 .../paramserv/mnist_lenet_paramserv.dml         |  9 ++-
 .../mnist_lenet_paramserv_minimum_version.dml   |  5 +-
 .../paramserv/paramserv-large-parallelism.dml   |  3 +-
 .../paramserv/paramserv-minimum-version.dml     |  3 +-
 .../paramserv/paramserv-nn-asp-batch.dml        |  5 +-
 .../paramserv/paramserv-nn-asp-epoch.dml        |  5 +-
 .../paramserv/paramserv-nn-bsp-batch-dc.dml     |  5 +-
 .../paramserv/paramserv-nn-bsp-batch-dr.dml     |  5 +-
 .../paramserv/paramserv-nn-bsp-batch-drr.dml    |  5 +-
 .../paramserv/paramserv-nn-bsp-batch-or.dml     |  5 +-
 .../paramserv/paramserv-nn-bsp-epoch.dml        |  5 +-
 18 files changed, 114 insertions(+), 107 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
index 5ea84a7..25b2017 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java
@@ -110,7 +110,7 @@ public class FunctionProgramBlock extends ProgramBlock
                }
                
                // for each program block
-               try {                                           
+               try {
                        for (int i=0 ; i < this._childBlocks.size() ; i++) {
                                ec.updateDebugState(i);
                                _childBlocks.get(i).execute(ec);

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index beb5e45..4f472ee 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -107,7 +107,7 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int totalIter) {
                Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
-               globalParams = _ps.updateModel(gradients, globalParams);
+               globalParams = _ps.updateModel(_ec, gradients, globalParams);
 
                if (DMLScript.STATISTICS)
                        Statistics.accPSLocalModelUpdateTime((long) 
tUpd.stop());

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index f46370d..a76dfec 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -19,12 +19,11 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
-import static 
org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.UPDATE_FUNC_PREFIX;
+import static 
org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX;
 
 import java.util.ArrayList;
 import java.util.stream.Collectors;
 
-import org.apache.sysml.parser.DMLProgram;
 import org.apache.sysml.parser.DataIdentifier;
 import org.apache.sysml.parser.Expression;
 import org.apache.sysml.parser.Statement;
@@ -66,14 +65,10 @@ public abstract class PSWorker {
                _ps = ps;
 
                // Get the update function
-               String[] keys = DMLProgram.splitFunctionKey(updFunc);
-               String funcName = keys[0];
-               String funcNS = null;
-               if (keys.length == 2) {
-                       funcNS = keys[0];
-                       funcName = keys[1];
-               }
-               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(funcNS, UPDATE_FUNC_PREFIX + _workerID 
+ "_" + funcName);
+               String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, 
PS_FUNC_PREFIX);
+               String ns = cfn[0];
+               String fname = cfn[1];
+               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(ns, fname);
                ArrayList<DataIdentifier> inputs = func.getInputParams();
                ArrayList<DataIdentifier> outputs = func.getOutputParams();
                CPOperand[] boundInputs = inputs.stream()
@@ -83,7 +78,7 @@ public abstract class PSWorker {
                        .collect(Collectors.toCollection(ArrayList::new));
                ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                        .collect(Collectors.toCollection(ArrayList::new));
-               _inst = new FunctionCallCPInstruction(funcNS, funcName, 
boundInputs, inputNames, outputNames, "update function");
+               _inst = new FunctionCallCPInstruction(ns, fname, boundInputs, 
inputNames, outputNames, "update function");
 
                // Check the inputs of the update function
                checkInput(false, inputs, Expression.DataType.MATRIX, 
Statement.PS_FEATURES);

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index 0ddfb40..4af72a4 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -19,7 +19,7 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
-import static 
org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.AGG_FUNC_PREFIX;
+import static 
org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -40,14 +40,12 @@ import 
org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.parser.DMLProgram;
 import org.apache.sysml.parser.DataIdentifier;
 import org.apache.sysml.parser.Expression;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.Data;
@@ -101,11 +99,7 @@ public abstract class ParamServer {
                return _model;
        }
 
-       public ListObject updateModel(ListObject gradients, ListObject model) {
-               //note: we use a new execution context to allow for concurrent 
execution of ASP local updates; 
-               //otherwise synchronized on the aggService instance would 
serialize those
-               ExecutionContext ec = 
ExecutionContextFactory.createContext(_aggService._ec.getProgram());
-               ec.setVariable(Statement.PS_HYPER_PARAMS, 
_aggService._ec.getVariable(Statement.PS_HYPER_PARAMS));
+       public ListObject updateModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
                return _aggService.updateModel(ec, gradients, model);
        }
 
@@ -138,14 +132,10 @@ public abstract class ParamServer {
                        _finishedStates = new boolean[workerNum];
 
                        // Fetch the aggregation function
-                       String[] keys = DMLProgram.splitFunctionKey(aggFunc);
-                       String funcName = keys[0];
-                       String funcNS = null;
-                       if (keys.length == 2) {
-                               funcNS = keys[0];
-                               funcName = keys[1];
-                       }
-                       FunctionProgramBlock func = 
_ec.getProgram().getFunctionProgramBlock(funcNS, AGG_FUNC_PREFIX + funcName);
+                       String[] cfn = 
ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
+                       String ns = cfn[0];
+                       String fname = cfn[1];
+                       FunctionProgramBlock func = 
_ec.getProgram().getFunctionProgramBlock(ns, fname);
                        ArrayList<DataIdentifier> inputs = 
func.getInputParams();
                        ArrayList<DataIdentifier> outputs = 
func.getOutputParams();
 
@@ -165,7 +155,7 @@ public abstract class ParamServer {
                                
.collect(Collectors.toCollection(ArrayList::new));
                        ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                                
.collect(Collectors.toCollection(ArrayList::new));
-                       _inst = new FunctionCallCPInstruction(funcNS, funcName, 
boundInputs, inputNames, outputNames, "aggregate function");
+                       _inst = new FunctionCallCPInstruction(ns, fname, 
boundInputs, inputNames, outputNames, "aggregate function");
                }
 
                private boolean allFinished() {
@@ -248,16 +238,13 @@ public abstract class ParamServer {
                        return null;
                }
 
-               /**
-                * A synchronized service method for updating model with 
gradients
-                *
-                * @param gradients A list object of gradients
-                * @return A updated list object of model
-                */
-               private synchronized ListObject updateModel(ListObject 
gradients, ListObject model) {
+               private ListObject updateModel(ListObject gradients, ListObject 
model) {
                        return updateModel(_ec, gradients, model);
                }
-               
+
+               /**
+                * A service method for updating model with gradients
+                */
                private ListObject updateModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
                        // Populate the variables table with the gradients and 
model
                        ec.setVariable(Statement.PS_GRADIENTS, gradients);

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index c4a3d98..6c76aa6 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -26,6 +26,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import org.apache.commons.lang.StringUtils;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.MultiThreadedHop;
 import org.apache.sysml.hops.OptimizerUtils;
@@ -38,6 +39,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
 import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
 import org.apache.sysml.runtime.controlprogram.Program;
 import org.apache.sysml.runtime.controlprogram.ProgramBlock;
@@ -58,8 +60,7 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo;
 
 public class ParamservUtils {
 
-       public static final String UPDATE_FUNC_PREFIX = "_worker_";
-       public static final String AGG_FUNC_PREFIX = "_agg_";
+       public static final String PS_FUNC_PREFIX = "_ps_";
 
        /**
         * Deep copy the list object
@@ -131,7 +132,17 @@ public class ParamservUtils {
                return permutation;
        }
 
-       public static ExecutionContext createExecutionContext(ExecutionContext 
ec, String updFunc, String aggFunc, int workerNum, int k) {
+       public static String[] getCompleteFuncName(String funcName, String 
prefix) {
+               String[] keys = DMLProgram.splitFunctionKey(funcName);
+               String ns = (keys.length==2) ? keys[0] : null;
+               String name = (keys.length==2) ? keys[1] : keys[0];
+               return StringUtils.isEmpty(prefix) ? 
+                       new String[]{ns, name} : new String[]{ns, prefix + name};
+       }
+
+       public static List<ExecutionContext> 
createExecutionContexts(ExecutionContext ec, LocalVariableMap varsMap,
+               String updFunc, String aggFunc, int workerNum, int k) {
+
                FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc);
                FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc);
 
@@ -142,27 +153,40 @@ public class ParamservUtils {
                // 2. Recompile the imported function blocks
                prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> 
recompileProgramBlocks(k, fvalue.getChildBlocks()));
 
-               // Copy function for workers
-               IntStream.range(0, workerNum).forEach(i -> 
copyFunction(updFunc, updPB, prog, UPDATE_FUNC_PREFIX + i + "_"));
+               // 3. Copy function for workers
+               List<ExecutionContext> workerECs = IntStream.range(0, workerNum)
+                       .mapToObj(i -> {
+                               FunctionProgramBlock newUpdFunc = 
copyFunction(updFunc, updPB);
+                               FunctionProgramBlock newAggFunc = 
copyFunction(aggFunc, aggPB);
+                               Program newProg = new Program();
+                               putFunction(newProg, newUpdFunc);
+                               putFunction(newProg, newAggFunc);
+                               return 
ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
+                       })
+                       .collect(Collectors.toList());
 
-               // Copy function for agg service
-               copyFunction(aggFunc, aggPB, prog, AGG_FUNC_PREFIX);
+               // 4. Copy function for agg service
+               FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB);
+               Program newProg = new Program();
+               putFunction(newProg, newAggFunc);
+               ExecutionContext aggEC = 
ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
 
-               return ExecutionContextFactory.createContext(prog);
+               List<ExecutionContext> result = new ArrayList<>(workerECs);
+               result.add(aggEC);
+               return result;
        }
 
-       private static void copyFunction(String funcName, FunctionProgramBlock 
updPB, Program prog, String prefix) {
-               String[] keys = DMLProgram.splitFunctionKey(funcName);
-               String namespace = null;
-               String func = keys[0];
-               if (keys.length == 2) {
-                       namespace = keys[0];
-                       func = keys[1];
-               }
-               FunctionProgramBlock copiedFunc = ProgramConverter
-                       .createDeepCopyFunctionProgramBlock(updPB, new 
HashSet<>(), new HashSet<>());
-               String fnameNew = prefix + func;
-               prog.addFunctionProgramBlock(namespace, fnameNew, copiedFunc);
+       private static FunctionProgramBlock copyFunction(String funcName, 
FunctionProgramBlock fpb) {
+               FunctionProgramBlock copiedFunc = 
ProgramConverter.createDeepCopyFunctionProgramBlock(fpb, new HashSet<>(), new 
HashSet<>());
+               String[] cfn = getCompleteFuncName(funcName, 
ParamservUtils.PS_FUNC_PREFIX);
+               copiedFunc._namespace = cfn[0];
+               copiedFunc._functionName = cfn[1];
+               return copiedFunc;
+       }
+
+       private static void putFunction(Program prog, FunctionProgramBlock fpb) 
{
+               prog.addFunctionProgramBlock(fpb._namespace, fpb._functionName, 
fpb);
+               prog.addProgramBlock(fpb);
        }
 
        private static void recompileProgramBlocks(int k, 
ArrayList<ProgramBlock> pbs) {
@@ -229,13 +253,9 @@ public class ParamservUtils {
 
 
        private static FunctionProgramBlock getFunctionBlock(ExecutionContext 
ec, String funcName) {
-               String[] keys = DMLProgram.splitFunctionKey(funcName);
-               String namespace = null;
-               String func = keys[0];
-               if (keys.length == 2) {
-                       namespace = keys[0];
-                       func = keys[1];
-               }
-               return ec.getProgram().getFunctionProgramBlock(namespace, func);
+               String[] cfn = getCompleteFuncName(funcName, null);
+               String ns = cfn[0];
+               String fname = cfn[1];
+               return ec.getProgram().getFunctionProgramBlock(ns, fname);
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index d2c0edd..be80127 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -56,10 +56,8 @@ import org.apache.log4j.Logger;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
-import org.apache.sysml.runtime.controlprogram.Program;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
 import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC;
 import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDR;
@@ -112,13 +110,15 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                int k = getParLevel(workerNum);
 
                // Get the compiled execution context
-               ExecutionContext newEC = 
ParamservUtils.createExecutionContext(ec, updFunc, aggFunc, workerNum, k);
+               // Create workers' execution context
+               LocalVariableMap newVarsMap = createVarsMap(ec);
+               List<ExecutionContext> newECs = 
ParamservUtils.createExecutionContexts(ec, newVarsMap, updFunc, aggFunc, 
workerNum, k);
 
                // Create workers' execution context
-               List<ExecutionContext> workerECs = 
createExecutionContext(workerNum, ec, newEC.getProgram());
+               List<ExecutionContext> workerECs = newECs.subList(0, 
newECs.size() - 1);
 
                // Create the agg service's execution context
-               ExecutionContext aggServiceEC = createExecutionContext(1, ec, 
newEC.getProgram()).get(0);
+               ExecutionContext aggServiceEC = newECs.get(newECs.size() - 1);
 
                PSFrequency freq = getFrequency();
                PSUpdateType updateType = getUpdateType();
@@ -165,16 +165,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                }
        }
 
-       private List<ExecutionContext> createExecutionContext(int size, 
ExecutionContext ec, Program program) {
-               return IntStream.range(0, size).mapToObj(i -> {
-                       // Put the hyperparam into the variables table
-                       LocalVariableMap varsMap = new LocalVariableMap();
-                       ListObject hyperParams = getHyperParams(ec);
-                       if (hyperParams != null) {
-                               varsMap.put(PS_HYPER_PARAMS, hyperParams);
-                       }
-                       return ExecutionContextFactory.createContext(varsMap, 
program);
-               }).collect(Collectors.toList());
+       private LocalVariableMap createVarsMap(ExecutionContext ec) {
+               // Put the hyperparam into the variables table
+               LocalVariableMap varsMap = new LocalVariableMap();
+               ListObject hyperParams = getHyperParams(ec);
+               if (hyperParams != null) {
+                       varsMap.put(PS_HYPER_PARAMS, hyperParams);
+               }
+               return varsMap;
        }
 
        private PSModeType getPSMode() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java
index 7238cf9..a81fd36 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservRuntimeNegativeTest.java
@@ -44,12 +44,12 @@ public class ParamservRuntimeNegativeTest extends 
AutomatedTestBase {
 
        @Test
        public void testParamservWorkerFailed() {
-               runDMLTest(TEST_NAME1, "Invalid lookup by name in unnamed list: 
worker_err.");
+               runDMLTest(TEST_NAME1, "Invalid indexing by name in unnamed 
list: worker_err.");
        }
 
        @Test
        public void testParamservAggServiceFailed() {
-               runDMLTest(TEST_NAME2, "Invalid lookup by name in unnamed list: 
agg_service_err");
+               runDMLTest(TEST_NAME2, "Invalid indexing by name in unnamed 
list: agg_service_err");
        }
 
        @Test

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index a10c846..acafc88 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -36,7 +36,7 @@ source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
 train = function(matrix[double] X, matrix[double] Y,
                  matrix[double] X_val, matrix[double] Y_val,
                  int C, int Hin, int Win, int epochs, int workers,
-                 string utype, string freq, string scheme)
+                 string utype, string freq, int batchsize, string scheme)
     return (matrix[double] W1, matrix[double] b1,
             matrix[double] W2, matrix[double] b2,
             matrix[double] W3, matrix[double] b3,
@@ -108,7 +108,7 @@ train = function(matrix[double] X, matrix[double] Y,
   params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
 
   # Use paramserv function
-  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation",
 mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=64, k=workers, 
scheme=scheme, hyperparams=params, checkpointing="NONE")
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation",
 mode="LOCAL", utype=utype, freq=freq, epochs=epochs, batchsize=batchsize, 
k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
 
   W1 = as.matrix(modelList2["W1"])
   b1 = as.matrix(modelList2["b1"])
@@ -256,7 +256,7 @@ aggregation = function(list[unknown] model,
      modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, 
b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
    }
 
-predict = function(matrix[double] X, int C, int Hin, int Win,
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
                    matrix[double] W1, matrix[double] b1,
                    matrix[double] W2, matrix[double] b2,
                    matrix[double] W3, matrix[double] b3,
@@ -302,9 +302,8 @@ predict = function(matrix[double] X, int C, int Hin, int 
Win,
 
   # Compute predictions over mini-batches
   probs = matrix(0, rows=N, cols=K)
-  batch_size = 64
   iters = ceil(N / batch_size)
-  for(i in 1:iters) {
+  parfor(i in 1:iters, check=0) {
     # Get next batch
     beg = ((i-1) * batch_size) %% N + 1
     end = min(N, beg + batch_size - 1)

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index d02e5d6..aeec3df 100644
--- 
a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -249,7 +249,7 @@ aggregation = function(list[unknown] model,
      modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, 
b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
    }
 
-predict = function(matrix[double] X, int C, int Hin, int Win,
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
                    matrix[double] W1, matrix[double] b1,
                    matrix[double] W2, matrix[double] b2,
                    matrix[double] W3, matrix[double] b3,
@@ -295,9 +295,8 @@ predict = function(matrix[double] X, int C, int Hin, int 
Win,
 
   # Compute predictions over mini-batches
   probs = matrix(0, rows=N, cols=K)
-  batch_size = 64
   iters = ceil(N / batch_size)
-  for(i in 1:iters) {
+  parfor(i in 1:iters, check=0) {
     # Get next batch
     beg = ((i-1) * batch_size) %% N + 1
     end = min(N, beg + batch_size - 1)

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml 
b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
index 4e2d2b7..de29a04 100644
--- a/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-large-parallelism.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 10
+batchsize = 64
 
 # Train
 [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers)
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml 
b/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
index 4d23b8c..a537fe9 100644
--- a/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 64
 
 # Train
 [W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers)
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
index baef6ee..2279d58 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "ASP", "BATCH", "DISJOINT_CONTIGUOUS")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "ASP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
index 860f53f..1824083 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "ASP", "EPOCH", "DISJOINT_CONTIGUOUS")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "ASP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
index dcbd2dd..2e09de4 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_CONTIGUOUS")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
index 96fe734..8444952 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_RANDOM")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_RANDOM")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
index e97dbff..ccb7ffc 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 4
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", "DISJOINT_ROUND_ROBIN")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_ROUND_ROBIN")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
index a2e95d3..4afc56b 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", "OVERLAP_RESHUFFLE")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "OVERLAP_RESHUFFLE")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/7027a532/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml 
b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
index 25d5f48..c542286 100644
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
@@ -39,12 +39,13 @@ Y_val = labels[1:val_size,]
 # Arguments
 epochs = 10
 workers = 2
+batchsize = 32
 
 # Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "EPOCH", "DISJOINT_CONTIGUOUS")
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, 
Hin, Win, epochs, workers, "BSP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS")
 
 # Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, 
b2, W3, b3, W4, b4)
 loss_val = cross_entropy_loss::forward(probs_val, Y_val)
 accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
 

Reply via email to