[SYSTEMML-2419] Paramserv Spark function shipping and worker setup. Closes #799.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/cffefca3 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/cffefca3 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/cffefca3 Branch: refs/heads/master Commit: cffefca30e89ce249c3030d23123c1b3aba1757a Parents: 614adec Author: EdgarLGB <[email protected]> Authored: Sun Jul 15 22:01:54 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Jul 15 22:02:04 2018 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/api/DMLScript.java | 4 +- .../java/org/apache/sysml/hops/DataGenOp.java | 5 +- .../org/apache/sysml/hops/OptimizerUtils.java | 3 +- .../apache/sysml/hops/recompile/Recompiler.java | 5 +- src/main/java/org/apache/sysml/lops/Lop.java | 2 + .../java/org/apache/sysml/lops/compile/Dag.java | 3 +- .../org/apache/sysml/parser/DMLTranslator.java | 3 +- .../controlprogram/LocalVariableMap.java | 2 +- .../controlprogram/ParForProgramBlock.java | 2 +- .../context/ExecutionContext.java | 4 + .../controlprogram/paramserv/PSWorker.java | 30 +- .../paramserv/ParamservUtils.java | 35 +- .../paramserv/spark/SparkPSBody.java | 46 + .../paramserv/spark/SparkPSWorker.java | 43 +- .../controlprogram/parfor/ProgramConverter.java | 1866 ------------------ .../controlprogram/parfor/RemoteDPParForMR.java | 1 + .../parfor/RemoteDPParForSparkWorker.java | 23 +- .../parfor/RemoteDPParWorkerReducer.java | 1 + .../controlprogram/parfor/RemoteParForMR.java | 1 + .../parfor/RemoteParForSparkWorker.java | 23 +- .../parfor/RemoteParForUtils.java | 25 + .../parfor/RemoteParWorkerMapper.java | 1 + .../parfor/opt/OptimizerRuleBased.java | 16 +- .../runtime/instructions/MRJobInstruction.java | 2 +- .../cp/ParamservBuiltinCPInstruction.java | 34 +- .../instructions/cp/VariableCPInstruction.java | 2 +- .../sysml/runtime/util/ProgramConverter.java | 1838 +++++++++++++++++ 
.../paramserv/ParamservLocalNNTest.java | 93 + .../functions/paramserv/ParamservNNTest.java | 94 - .../paramserv/ParamservSparkNNTest.java | 47 + .../functions/paramserv/SerializationTest.java | 80 + .../parfor/ParForAdversarialLiteralsTest.java | 13 +- .../paramserv/mnist_lenet_paramserv.dml | 4 +- .../paramserv/paramserv-nn-asp-batch.dml | 2 +- .../paramserv/paramserv-nn-asp-epoch.dml | 2 +- .../paramserv/paramserv-nn-bsp-batch-dc.dml | 2 +- .../paramserv/paramserv-nn-bsp-batch-dr.dml | 2 +- .../paramserv/paramserv-nn-bsp-batch-drr.dml | 2 +- .../paramserv/paramserv-nn-bsp-batch-or.dml | 2 +- .../paramserv/paramserv-nn-bsp-epoch.dml | 2 +- .../paramserv-spark-nn-bsp-batch-dc.dml | 53 + .../functions/paramserv/ZPackageSuite.java | 3 +- 42 files changed, 2332 insertions(+), 2089 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/api/DMLScript.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java index 227c14b..9c8a3eb 100644 --- a/src/main/java/org/apache/sysml/api/DMLScript.java +++ b/src/main/java/org/apache/sysml/api/DMLScript.java @@ -28,6 +28,7 @@ import java.io.InputStreamReader; import java.net.URI; import java.text.DateFormat; import java.text.SimpleDateFormat; +import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.HashMap; @@ -76,7 +77,6 @@ import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool; @@ -853,7 +853,7 @@ public class DMLScript LOG.debug("SystemML security check: " + "local.user.name = " + userName + ", " - + "local.user.groups = " + ProgramConverter.serializeStringCollection(groupNames) + ", " + + "local.user.groups = " + Arrays.toString(groupNames.toArray()) + ", " + MRConfigurationNames.MR_JOBTRACKER_ADDRESS + " = " + job.get(MRConfigurationNames.MR_JOBTRACKER_ADDRESS) + ", " + MRConfigurationNames.MR_TASKTRACKER_TASKCONTROLLER + " = " + taskController + "," + MRConfigurationNames.MR_TASKTRACKER_GROUP + " = " + ttGroupName + ", " http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/hops/DataGenOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java b/src/main/java/org/apache/sysml/hops/DataGenOp.java index 9e59235..b374185 100644 --- a/src/main/java/org/apache/sysml/hops/DataGenOp.java +++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java @@ -34,7 +34,6 @@ import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.util.UtilFunctions; /** @@ -108,8 +107,8 @@ public class DataGenOp extends MultiThreadedHop //generate base dir String scratch = ConfigurationManager.getScratchSpace(); - _baseDir = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + Lop.FILE_SEPARATOR + - Lop.FILE_SEPARATOR + ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR; + _baseDir = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + 
Lop.FILE_SEPARATOR + + Lop.FILE_SEPARATOR + Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR; //compute unknown dims and nnz refreshSizeInformation(); http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index 3f1a8cf..fb83df0 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -46,7 +46,6 @@ import org.apache.sysml.runtime.controlprogram.ForProgramBlock; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.functionobjects.IntegerDivide; import org.apache.sysml.runtime.functionobjects.Modulus; @@ -927,7 +926,7 @@ public class OptimizerUtils public static String getUniqueTempFileName() { return ConfigurationManager.getScratchSpace() + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() - + Lop.FILE_SEPARATOR + ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR + + Lop.FILE_SEPARATOR + Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR + Dag.getNextUniqueFilenameSuffix(); } http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java index d5b0043..4175eac 100644 --- 
a/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java +++ b/src/main/java/org/apache/sysml/hops/recompile/Recompiler.java @@ -82,7 +82,7 @@ import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; +import org.apache.sysml.runtime.util.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysml.runtime.instructions.Instruction; @@ -589,8 +589,7 @@ public class Recompiler //update function names if( hop instanceof FunctionOp && ((FunctionOp)hop).getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) { FunctionOp fop = (FunctionOp) hop; - fop.setFunctionName( fop.getFunctionName() + - ProgramConverter.CP_CHILD_THREAD + pid); + fop.setFunctionName( fop.getFunctionName() + Lop.CP_CHILD_THREAD + pid); } if( hop.getInput() != null ) http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/lops/Lop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java index ff2d515..9e81496 100644 --- a/src/main/java/org/apache/sysml/lops/Lop.java +++ b/src/main/java/org/apache/sysml/lops/Lop.java @@ -77,6 +77,8 @@ public abstract class Lop public static final String FILE_SEPARATOR = "/"; public static final String PROCESS_PREFIX = "_p"; + public static final String CP_ROOT_THREAD_ID = "_t0"; + public static final String CP_CHILD_THREAD = "_t"; //special delimiters w/ extended ASCII characters to avoid collisions public static final String INSTRUCTION_DELIMITOR = "\u2021"; 
http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/lops/compile/Dag.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java index b1d6865..452f030 100644 --- a/src/main/java/org/apache/sysml/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java @@ -69,7 +69,6 @@ import org.apache.sysml.parser.Expression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.instructions.CPInstructionParser; import org.apache.sysml.runtime.instructions.Instruction; @@ -198,7 +197,7 @@ public class Dag<N extends Lop> scratchFilePath = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + Lop.FILE_SEPARATOR + Lop.FILE_SEPARATOR - + ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR; + + Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR; } return scratchFilePath; } http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 089edce..b9e5f9d 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -88,7 +88,6 @@ import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.ProgramBlock; import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; 
-import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.instructions.Instruction; @@ -613,7 +612,7 @@ public class DMLTranslator buff.append(Lop.PROCESS_PREFIX); buff.append(DMLScript.getUUID()); buff.append(Lop.FILE_SEPARATOR); - buff.append(ProgramConverter.CP_ROOT_THREAD_ID); + buff.append(Lop.CP_ROOT_THREAD_ID); buff.append(Lop.FILE_SEPARATOR); buff.append("PackageSupport"); buff.append(Lop.FILE_SEPARATOR); http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java index 2081aae..62ae3d0 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java @@ -28,7 +28,7 @@ import java.util.StringTokenizer; import org.apache.sysml.api.DMLScript; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; +import org.apache.sysml.runtime.util.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.utils.Statistics; http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java index 83e9cde..ba490f3 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java +++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java @@ -65,7 +65,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.DataPartitionerRemoteSpark import org.apache.sysml.runtime.controlprogram.parfor.LocalParWorker; import org.apache.sysml.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysml.runtime.controlprogram.parfor.ParForBody; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; +import org.apache.sysml.runtime.util.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.RemoteDPParForMR; import org.apache.sysml.runtime.controlprogram.parfor.RemoteDPParForSpark; import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForJobReturn; http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index d0d0b08..2e0addf 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -101,6 +101,10 @@ public class ExecutionContext { public Program getProgram(){ return _prog; } + + public void setProgram(Program prog) { + _prog = prog; + } public LocalVariableMap getVariables() { return _variables; http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index a76dfec..1ab5f5e 100644 --- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -34,23 +34,25 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; -@SuppressWarnings("unused") public abstract class PSWorker { - - protected final int _workerID; - protected final int _epochs; - protected final long _batchSize; - protected final ExecutionContext _ec; - protected final ParamServer _ps; - protected final DataIdentifier _output; - protected final FunctionCallCPInstruction _inst; + protected int _workerID; + protected int _epochs; + protected long _batchSize; + protected ExecutionContext _ec; + protected ParamServer _ps; + protected DataIdentifier _output; + protected FunctionCallCPInstruction _inst; protected MatrixObject _features; protected MatrixObject _labels; - - private MatrixObject _valFeatures; - private MatrixObject _valLabels; - private final String _updFunc; - protected final Statement.PSFrequency _freq; + + protected MatrixObject _valFeatures; + protected MatrixObject _valLabels; + protected String _updFunc; + protected Statement.PSFrequency _freq; + + protected PSWorker() { + + } protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index ec6b6b2..ee15709 100644 --- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -57,7 +57,6 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator; import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ListObject; @@ -68,6 +67,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; +import org.apache.sysml.runtime.util.ProgramConverter; import scala.Tuple2; @@ -175,9 +175,8 @@ public class ParamservUtils { new String[]{ns, name} : new String[]{ns, name}; } - public static List<ExecutionContext> createExecutionContexts(ExecutionContext ec, LocalVariableMap varsMap, - String updFunc, String aggFunc, int workerNum, int k) { - + public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap, String updFunc, + String aggFunc, int k) { FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc); FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc); @@ -188,27 +187,21 @@ public class ParamservUtils { // 2. Recompile the imported function blocks prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks())); - // 3. 
Copy function for workers - List<ExecutionContext> workerECs = IntStream.range(0, workerNum) - .mapToObj(i -> { - FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB); - FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB); - Program newProg = new Program(); - putFunction(newProg, newUpdFunc); - putFunction(newProg, newAggFunc); - return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg); - }) - .collect(Collectors.toList()); - - // 4. Copy function for agg service + // 3. Copy function + FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB); FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB); Program newProg = new Program(); + putFunction(newProg, newUpdFunc); putFunction(newProg, newAggFunc); - ExecutionContext aggEC = ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg); + return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg); + } - List<ExecutionContext> result = new ArrayList<>(workerECs); - result.add(aggEC); - return result; + public static List<ExecutionContext> copyExecutionContext(ExecutionContext ec, int num) { + return IntStream.range(0, num).mapToObj(i -> { + Program newProg = new Program(); + ec.getProgram().getFunctionProgramBlocks().forEach((func, pb) -> putFunction(newProg, copyFunction(func, pb))); + return ExecutionContextFactory.createContext(new LocalVariableMap(ec.getVariables()), newProg); + }).collect(Collectors.toList()); } private static FunctionProgramBlock copyFunction(String funcName, FunctionProgramBlock fpb) { http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java 
new file mode 100644 index 0000000..ec10232 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.controlprogram.paramserv.spark; + +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; + +/** + * Wrapper holding everything needed to launch a Spark remote paramserv worker (currently only its execution context). + */ +public class SparkPSBody { + + private ExecutionContext _ec; + + public SparkPSBody() { + // NOTE(review): no-args constructor, presumably used when the body is rebuilt from its serialized form - confirm + } + + public SparkPSBody(ExecutionContext ec) { + this._ec = ec; + } + + public ExecutionContext getEc() { + return _ec; + } + + public void setEc(ExecutionContext ec) { + this._ec = ec; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/cffefca3/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java index 69da56c..466801f 100644 ---
a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java @@ -19,30 +19,59 @@ package org.apache.sysml.runtime.controlprogram.paramserv.spark; +import java.io.IOException; import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; import org.apache.spark.api.java.function.VoidFunction; import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; +import org.apache.sysml.runtime.codegen.CodegenUtils; +import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker; +import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.util.ProgramConverter; import scala.Tuple2; -public class SparkPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable { +public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable { private static final long serialVersionUID = -8674739573419648732L; - public SparkPSWorker() { + private String _program; + private HashMap<String, byte[]> _clsMap; + // _program: serialized SparkPSBody parsed in configureWorker; _clsMap: codegen classes (presumably name -> bytecode) registered before parsing + protected SparkPSWorker() { // No-args constructor used for deserialization } - public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, - MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) { + public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) { + _updFunc = updFunc; + _freq = freq; + _epochs = epochs; + _batchSize = batchSize; + _program = program; + _clsMap = clsMap; } @Override
public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception { + configureWorker(input); + } + + private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException { + _workerID = input._1; + + // Initialize codegen class cache (before program parsing) + for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) { + CodegenUtils.getClassSync(e.getKey(), e.getValue()); + } + + // Deserialize the body to initialize the execution context + SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID); + _ec = body.getEc(); + + // Initialize the buffer pool and register it in the JVM shutdown hook so that it is cleaned up at the end + RemoteParForUtils.setupBufferPool(_workerID); } }
