[SYSTEMML-2419] Paramserv spark function shipping and worker setup

Closes #799.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cffefca3
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cffefca3
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cffefca3

Branch: refs/heads/master
Commit: cffefca30e89ce249c3030d23123c1b3aba1757a
Parents: 614adec
Author: EdgarLGB <[email protected]>
Authored: Sun Jul 15 22:01:54 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Jul 15 22:02:04 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/DMLScript.java    |    4 +-
 .../java/org/apache/sysml/hops/DataGenOp.java   |    5 +-
 .../org/apache/sysml/hops/OptimizerUtils.java   |    3 +-
 .../apache/sysml/hops/recompile/Recompiler.java |    5 +-
 src/main/java/org/apache/sysml/lops/Lop.java    |    2 +
 .../java/org/apache/sysml/lops/compile/Dag.java |    3 +-
 .../org/apache/sysml/parser/DMLTranslator.java  |    3 +-
 .../controlprogram/LocalVariableMap.java        |    2 +-
 .../controlprogram/ParForProgramBlock.java      |    2 +-
 .../context/ExecutionContext.java               |    4 +
 .../controlprogram/paramserv/PSWorker.java      |   30 +-
 .../paramserv/ParamservUtils.java               |   35 +-
 .../paramserv/spark/SparkPSBody.java            |   46 +
 .../paramserv/spark/SparkPSWorker.java          |   43 +-
 .../controlprogram/parfor/ProgramConverter.java | 1866 ------------------
 .../controlprogram/parfor/RemoteDPParForMR.java |    1 +
 .../parfor/RemoteDPParForSparkWorker.java       |   23 +-
 .../parfor/RemoteDPParWorkerReducer.java        |    1 +
 .../controlprogram/parfor/RemoteParForMR.java   |    1 +
 .../parfor/RemoteParForSparkWorker.java         |   23 +-
 .../parfor/RemoteParForUtils.java               |   25 +
 .../parfor/RemoteParWorkerMapper.java           |    1 +
 .../parfor/opt/OptimizerRuleBased.java          |   16 +-
 .../runtime/instructions/MRJobInstruction.java  |    2 +-
 .../cp/ParamservBuiltinCPInstruction.java       |   34 +-
 .../instructions/cp/VariableCPInstruction.java  |    2 +-
 .../sysml/runtime/util/ProgramConverter.java    | 1838 +++++++++++++++++
 .../paramserv/ParamservLocalNNTest.java         |   93 +
 .../functions/paramserv/ParamservNNTest.java    |   94 -
 .../paramserv/ParamservSparkNNTest.java         |   47 +
 .../functions/paramserv/SerializationTest.java  |   80 +
 .../parfor/ParForAdversarialLiteralsTest.java   |   13 +-
 .../paramserv/mnist_lenet_paramserv.dml         |    4 +-
 .../paramserv/paramserv-nn-asp-batch.dml        |    2 +-
 .../paramserv/paramserv-nn-asp-epoch.dml        |    2 +-
 .../paramserv/paramserv-nn-bsp-batch-dc.dml     |    2 +-
 .../paramserv/paramserv-nn-bsp-batch-dr.dml     |    2 +-
 .../paramserv/paramserv-nn-bsp-batch-drr.dml    |    2 +-
 .../paramserv/paramserv-nn-bsp-batch-or.dml     |    2 +-
 .../paramserv/paramserv-nn-bsp-epoch.dml        |    2 +-
 .../paramserv-spark-nn-bsp-batch-dc.dml         |   53 +
 .../functions/paramserv/ZPackageSuite.java      |    3 +-
 42 files changed, 2332 insertions(+), 2089 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java 
b/src/main/java/org/apache/sysml/api/DMLScript.java
index 227c14b..9c8a3eb 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -28,6 +28,7 @@ import java.io.InputStreamReader;
 import java.net.URI;
 import java.text.DateFormat;
 import java.text.SimpleDateFormat;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Date;
 import java.util.HashMap;
@@ -76,7 +77,6 @@ import 
org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
@@ -853,7 +853,7 @@ public class DMLScript
                
                LOG.debug("SystemML security check: "
                                + "local.user.name = " + userName + ", "
-                               + "local.user.groups = " + 
ProgramConverter.serializeStringCollection(groupNames) + ", "
+                               + "local.user.groups = " + 
Arrays.toString(groupNames.toArray()) + ", "
                                + MRConfigurationNames.MR_JOBTRACKER_ADDRESS + 
" = " + job.get(MRConfigurationNames.MR_JOBTRACKER_ADDRESS) + ", "
                                + 
MRConfigurationNames.MR_TASKTRACKER_TASKCONTROLLER + " = " + taskController + 
","
                                + MRConfigurationNames.MR_TASKTRACKER_GROUP + " 
= " + ttGroupName + ", "

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/hops/DataGenOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java 
b/src/main/java/org/apache/sysml/hops/DataGenOp.java
index 9e59235..b374185 100644
--- a/src/main/java/org/apache/sysml/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java
@@ -34,7 +34,6 @@ import org.apache.sysml.parser.DataExpression;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
 /**
@@ -108,8 +107,8 @@ public class DataGenOp extends MultiThreadedHop
                
                //generate base dir
                String scratch = ConfigurationManager.getScratchSpace();
-               _baseDir = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + 
DMLScript.getUUID() + Lop.FILE_SEPARATOR + 
-                      Lop.FILE_SEPARATOR + ProgramConverter.CP_ROOT_THREAD_ID 
+ Lop.FILE_SEPARATOR;
+               _baseDir = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + 
DMLScript.getUUID() + Lop.FILE_SEPARATOR 
+                       + Lop.FILE_SEPARATOR + Lop.CP_ROOT_THREAD_ID + 
Lop.FILE_SEPARATOR;
                
                //compute unknown dims and nnz
                refreshSizeInformation();

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 3f1a8cf..fb83df0 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -46,7 +46,6 @@ import 
org.apache.sysml.runtime.controlprogram.ForProgramBlock;
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.functionobjects.IntegerDivide;
 import org.apache.sysml.runtime.functionobjects.Modulus;
@@ -927,7 +926,7 @@ public class OptimizerUtils
        public static String getUniqueTempFileName() {
                return ConfigurationManager.getScratchSpace()
                        + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + 
DMLScript.getUUID()
-                       + Lop.FILE_SEPARATOR + 
ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR 
+                       + Lop.FILE_SEPARATOR + Lop.CP_ROOT_THREAD_ID + 
Lop.FILE_SEPARATOR 
                        + Dag.getNextUniqueFilenameSuffix();
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
index d5b0043..4175eac 100644
--- a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
@@ -82,7 +82,7 @@ import 
org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
+import org.apache.sysml.runtime.util.ProgramConverter;
 import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysml.runtime.instructions.Instruction;
@@ -589,8 +589,7 @@ public class Recompiler
                //update function names
                if( hop instanceof FunctionOp && 
((FunctionOp)hop).getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
                        FunctionOp fop = (FunctionOp) hop;
-                       fop.setFunctionName( fop.getFunctionName() +
-                                                    
ProgramConverter.CP_CHILD_THREAD + pid);
+                       fop.setFunctionName( fop.getFunctionName() + 
Lop.CP_CHILD_THREAD + pid);
                }
                
                if( hop.getInput() != null )

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/lops/Lop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java 
b/src/main/java/org/apache/sysml/lops/Lop.java
index ff2d515..9e81496 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -77,6 +77,8 @@ public abstract class Lop
        
        public static final String FILE_SEPARATOR = "/";
        public static final String PROCESS_PREFIX = "_p";
+       public static final String CP_ROOT_THREAD_ID = "_t0";
+       public static final String CP_CHILD_THREAD = "_t";
        
        //special delimiters w/ extended ASCII characters to avoid collisions 
        public static final String INSTRUCTION_DELIMITOR = "\u2021";

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/lops/compile/Dag.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java 
b/src/main/java/org/apache/sysml/lops/compile/Dag.java
index b1d6865..452f030 100644
--- a/src/main/java/org/apache/sysml/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java
@@ -69,7 +69,6 @@ import org.apache.sysml.parser.Expression;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.StatementBlock;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysml.runtime.instructions.CPInstructionParser;
 import org.apache.sysml.runtime.instructions.Instruction;
@@ -198,7 +197,7 @@ public class Dag<N extends Lop>
                        scratchFilePath = scratch + Lop.FILE_SEPARATOR
                                + Lop.PROCESS_PREFIX + DMLScript.getUUID()
                                + Lop.FILE_SEPARATOR + Lop.FILE_SEPARATOR
-                               + ProgramConverter.CP_ROOT_THREAD_ID + 
Lop.FILE_SEPARATOR;
+                               + Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR;
                }
                return scratchFilePath;
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 089edce..b9e5f9d 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -88,7 +88,6 @@ import 
org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
 import org.apache.sysml.runtime.controlprogram.Program;
 import org.apache.sysml.runtime.controlprogram.ProgramBlock;
 import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import org.apache.sysml.runtime.instructions.Instruction;
 
 
@@ -613,7 +612,7 @@ public class DMLTranslator
                                buff.append(Lop.PROCESS_PREFIX);
                                buff.append(DMLScript.getUUID());
                                buff.append(Lop.FILE_SEPARATOR);
-                               buff.append(ProgramConverter.CP_ROOT_THREAD_ID);
+                               buff.append(Lop.CP_ROOT_THREAD_ID);
                                buff.append(Lop.FILE_SEPARATOR);
                                buff.append("PackageSupport");
                                buff.append(Lop.FILE_SEPARATOR);

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
index 2081aae..62ae3d0 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
@@ -28,7 +28,7 @@ import java.util.StringTokenizer;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
+import org.apache.sysml.runtime.util.ProgramConverter;
 import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.utils.Statistics;

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index 83e9cde..ba490f3 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -65,7 +65,7 @@ import 
org.apache.sysml.runtime.controlprogram.parfor.DataPartitionerRemoteSpark
 import org.apache.sysml.runtime.controlprogram.parfor.LocalParWorker;
 import org.apache.sysml.runtime.controlprogram.parfor.LocalTaskQueue;
 import org.apache.sysml.runtime.controlprogram.parfor.ParForBody;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
+import org.apache.sysml.runtime.util.ProgramConverter;
 import org.apache.sysml.runtime.controlprogram.parfor.RemoteDPParForMR;
 import org.apache.sysml.runtime.controlprogram.parfor.RemoteDPParForSpark;
 import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForJobReturn;

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index d0d0b08..2e0addf 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -101,6 +101,10 @@ public class ExecutionContext {
        public Program getProgram(){
                return _prog;
        }
+
+       public void setProgram(Program prog) {
+               _prog = prog;
+       }
        
        public LocalVariableMap getVariables() {
                return _variables;

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index a76dfec..1ab5f5e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -34,23 +34,25 @@ import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
 
-@SuppressWarnings("unused")
 public abstract class PSWorker {
-
-       protected final int _workerID;
-       protected final int _epochs;
-       protected final long _batchSize;
-       protected final ExecutionContext _ec;
-       protected final ParamServer _ps;
-       protected final DataIdentifier _output;
-       protected final FunctionCallCPInstruction _inst;
+       protected int _workerID;
+       protected int _epochs;
+       protected long _batchSize;
+       protected ExecutionContext _ec;
+       protected ParamServer _ps;
+       protected DataIdentifier _output;
+       protected FunctionCallCPInstruction _inst;
        protected MatrixObject _features;
        protected MatrixObject _labels;
-       
-       private MatrixObject _valFeatures;
-       private MatrixObject _valLabels;
-       private final String _updFunc;
-       protected final Statement.PSFrequency _freq;
+
+       protected MatrixObject _valFeatures;
+       protected MatrixObject _valLabels;
+       protected String _updFunc;
+       protected Statement.PSFrequency _freq;
+
+       protected PSWorker() {
+
+       }
        
        protected PSWorker(int workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize,
                MatrixObject valFeatures, MatrixObject valLabels, 
ExecutionContext ec, ParamServer ps) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index ec6b6b2..ee15709 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -57,7 +57,6 @@ import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
 import 
org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
@@ -68,6 +67,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysml.runtime.util.ProgramConverter;
 
 import scala.Tuple2;
 
@@ -175,9 +175,8 @@ public class ParamservUtils {
                        new String[]{ns, name} : new String[]{ns, name};
        }
 
-       public static List<ExecutionContext> 
createExecutionContexts(ExecutionContext ec, LocalVariableMap varsMap,
-               String updFunc, String aggFunc, int workerNum, int k) {
-
+       public static ExecutionContext createExecutionContext(ExecutionContext 
ec, LocalVariableMap varsMap, String updFunc,
+               String aggFunc, int k) {
                FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc);
                FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc);
 
@@ -188,27 +187,21 @@ public class ParamservUtils {
                // 2. Recompile the imported function blocks
                prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> 
recompileProgramBlocks(k, fvalue.getChildBlocks()));
 
-               // 3. Copy function for workers
-               List<ExecutionContext> workerECs = IntStream.range(0, workerNum)
-                       .mapToObj(i -> {
-                               FunctionProgramBlock newUpdFunc = 
copyFunction(updFunc, updPB);
-                               FunctionProgramBlock newAggFunc = 
copyFunction(aggFunc, aggPB);
-                               Program newProg = new Program();
-                               putFunction(newProg, newUpdFunc);
-                               putFunction(newProg, newAggFunc);
-                               return 
ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
-                       })
-                       .collect(Collectors.toList());
-
-               // 4. Copy function for agg service
+               // 3. Copy the update and aggregation functions into a new program
+               FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB);
                FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB);
                Program newProg = new Program();
+               putFunction(newProg, newUpdFunc);
                putFunction(newProg, newAggFunc);
-               ExecutionContext aggEC = 
ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
+               return ExecutionContextFactory.createContext(new 
LocalVariableMap(varsMap), newProg);
+       }
 
-               List<ExecutionContext> result = new ArrayList<>(workerECs);
-               result.add(aggEC);
-               return result;
+       public static List<ExecutionContext> 
copyExecutionContext(ExecutionContext ec, int num) {
+               return IntStream.range(0, num).mapToObj(i -> {
+                       Program newProg = new Program();
+                       
ec.getProgram().getFunctionProgramBlocks().forEach((func, pb) -> 
putFunction(newProg, copyFunction(func, pb)));
+                       return ExecutionContextFactory.createContext(new 
LocalVariableMap(ec.getVariables()), newProg);
+               }).collect(Collectors.toList());
        }
 
        private static FunctionProgramBlock copyFunction(String funcName, 
FunctionProgramBlock fpb) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
new file mode 100644
index 0000000..ec10232
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark;
+
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+
+/**
+ * Wrapper class containing everything needed to launch a Spark remote worker.
+ */
+public class SparkPSBody {
+
+       private ExecutionContext _ec;
+
+       public SparkPSBody() {
+
+       }
+
+       public SparkPSBody(ExecutionContext ec) {
+               this._ec = ec;
+       }
+
+       public ExecutionContext getEc() {
+               return _ec;
+       }
+
+       public void setEc(ExecutionContext ec) {
+               this._ec = ec;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index 69da56c..466801f 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -19,30 +19,59 @@
 
 package org.apache.sysml.runtime.controlprogram.paramserv.spark;
 
+import java.io.IOException;
 import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
 
 import org.apache.spark.api.java.function.VoidFunction;
 import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker;
+import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.ProgramConverter;
 
 import scala.Tuple2;
 
-public class SparkPSWorker implements VoidFunction<Tuple2<Integer, 
Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
+public class SparkPSWorker extends PSWorker implements 
VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
 
        private static final long serialVersionUID = -8674739573419648732L;
 
-       public SparkPSWorker() {
+       private String _program;
+       private HashMap<String, byte[]> _clsMap;
+
+       protected SparkPSWorker() {
                // No-args constructor used for deserialization
        }
 
-       public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int 
epochs, long batchSize,
-                       MatrixObject valFeatures, MatrixObject valLabels, 
ExecutionContext ec, ParamServer ps) {
+       public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int 
epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) {
+               _updFunc = updFunc;
+               _freq = freq;
+               _epochs = epochs;
+               _batchSize = batchSize;
+               _program = program;
+               _clsMap = clsMap;
        }
 
        @Override
        public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> 
input) throws Exception {
+               configureWorker(input);
+       }
+
+       private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, 
MatrixBlock>> input) throws IOException {
+               _workerID = input._1;
+
+               // Initialize codegen class cache (before program parsing)
+               for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) {
+                       CodegenUtils.getClassSync(e.getKey(), e.getValue());
+               }
+
+               // Deserialize the body to initialize the execution context
+               SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, 
_workerID);
+               _ec = body.getEc();
+
+               // Initialize the buffer pool and register it in the JVM 
shutdown hook so that it is cleaned up at the end
+               RemoteParForUtils.setupBufferPool(_workerID);
        }
 }

Reply via email to