Repository: systemml Updated Branches: refs/heads/master 25be6a686 -> b06f390ec
[SYSTEMML-2392/8,2401/2/6] Paramserv statistics and various fixes Closes #787. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b06f390e Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b06f390e Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b06f390e Branch: refs/heads/master Commit: b06f390ecf75dcf24e8143aafdba533440326861 Parents: 25be6a6 Author: EdgarLGB <[email protected]> Authored: Sun Jun 17 18:21:48 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Jun 17 18:21:48 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/paramserv/LocalPSWorker.java | 78 ++++++-- .../controlprogram/paramserv/PSWorker.java | 4 +- .../controlprogram/paramserv/ParamServer.java | 62 +++++-- .../paramserv/ParamservUtils.java | 132 +++++++++++++- .../cp/ParamservBuiltinCPInstruction.java | 181 +++++-------------- .../java/org/apache/sysml/utils/Statistics.java | 47 ++++- 6 files changed, 327 insertions(+), 177 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index a64df72..beb5e45 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -20,14 +20,21 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import java.util.concurrent.Callable; +import java.util.stream.IntStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysml.api.DMLScript; import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; +import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; +import org.apache.sysml.utils.Statistics; public class LocalPSWorker extends PSWorker implements Callable<Void> { @@ -40,6 +47,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { @Override public Void call() throws Exception { + if (DMLScript.STATISTICS) + Statistics.incWorkerNumber(); + try { long dataSize = _features.getNumRows(); int totalIter = (int) Math.ceil((double) dataSize / _batchSize); @@ -65,26 +75,28 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { private void computeEpoch(long dataSize, int totalIter) { for (int i = 0; i < _epochs; i++) { // Pull the global parameters from ps - ListObject globalParams = pullModel(); - + ListObject params = pullModel(); + ListObject accGradients = null; + for (int j = 0; j < totalIter; j++) { - _ec.setVariable(Statement.PS_MODEL, globalParams); + _ec.setVariable(Statement.PS_MODEL, params); ListObject gradients = computeGradients(dataSize, totalIter, i, j); - if (j == totalIter - 1) { - // Push the gradients to ps - pushGradients(gradients); - ParamservUtils.cleanupListObject(_ec, globalParams); - } else { - // Update the local model with gradients - globalParams = _ps.updateModel(gradients, globalParams); - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated.", - _workerID, globalParams.getDataSize())); - } - } + // Accumulate the intermediate gradients + accGradients = (accGradients==null) ? + ParamservUtils.copyList(gradients) : + accrueGradients(accGradients, gradients); + + // Update the local model with gradients + if( j < totalIter - 1 ) + params = updateModel(params, gradients, i, j, totalIter); } + + // Push the gradients to ps + pushGradients(accGradients); + ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); + if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); } @@ -92,6 +104,22 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { } + private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) { + Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null; + + globalParams = _ps.updateModel(gradients, globalParams); + + if (DMLScript.STATISTICS) + Statistics.accPSLocalModelUpdateTime((long) tUpd.stop()); + + if (LOG.isDebugEnabled()) { + LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated. " + + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", + _workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter)); + } + return globalParams; + } + private void computeBatch(long dataSize, int totalIter) { for (int i = 0; i < _epochs; i++) { for (int j = 0; j < totalIter; j++) { @@ -103,7 +131,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { // Push the gradients to ps pushGradients(gradients); - ParamservUtils.cleanupListObject(_ec, globalParams); + ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL); } if (LOG.isDebugEnabled()) { LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1)); @@ -135,8 +163,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { long end = Math.min((j + 1) * _batchSize, dataSize); // Get batch features and labels + Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null; MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end); MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end); + if (DMLScript.STATISTICS) + Statistics.accPSBatchIndexingTime((long) tSlic.stop()); + _ec.setVariable(Statement.PS_FEATURES, bFeatures); _ec.setVariable(Statement.PS_LABELS, bLabels); @@ -148,7 +180,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { } // Invoke the update function + Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null; _inst.processInstruction(_ec); + if (DMLScript.STATISTICS) + Statistics.accPSGradientComputeTime((long) tGrad.stop()); // Get the gradients ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); @@ -157,4 +192,15 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { ParamservUtils.cleanupData(bLabels); return gradients; } + + private ListObject accrueGradients(ListObject accGradients, ListObject gradients) { + IntStream.range(0, accGradients.getLength()).forEach(i -> { + MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead(); + MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead(); + mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2); + ((MatrixObject) accGradients.getData().get(i)).release(); + ((MatrixObject) gradients.getData().get(i)).release(); + }); + return accGradients; + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java index d94831e..f46370d 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java @@ -19,6 +19,8 @@ package org.apache.sysml.runtime.controlprogram.paramserv; +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.UPDATE_FUNC_PREFIX; + import java.util.ArrayList; import java.util.stream.Collectors; @@ -71,7 +73,7 @@ public abstract class PSWorker { funcNS = keys[0]; funcName = keys[1]; } - FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(funcNS, funcName); + FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(funcNS, UPDATE_FUNC_PREFIX + _workerID + "_" + funcName); ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); CPOperand[] boundInputs = inputs.stream() http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index 7a39bec..0ddfb40 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -19,6 +19,8 @@ package org.apache.sysml.runtime.controlprogram.paramserv; +import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.AGG_FUNC_PREFIX; + import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -37,6 +39,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.concurrent.BasicThreadFactory; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysml.api.DMLScript; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.Expression; @@ -44,13 +47,16 @@ import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction; import org.apache.sysml.runtime.instructions.cp.ListObject; +import org.apache.sysml.utils.Statistics; public abstract class ParamServer { - + final BlockingQueue<Gradient> _gradientsQueue; final Map<Integer, BlockingQueue<ListObject>> _modelMap; private final AggregationService _aggService; @@ -73,8 +79,7 @@ public abstract class ParamServer { throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e); } BasicThreadFactory factory = new BasicThreadFactory.Builder() - .namingPattern("agg-service-pool-thread-%d") - .build(); + .namingPattern("agg-service-pool-thread-%d").build(); _es = Executors.newSingleThreadExecutor(factory); } @@ -91,11 +96,17 @@ public abstract class ParamServer { } public ListObject getResult() { + // All the model updating work has terminated, + // so we could return directly the result model return _model; } public ListObject updateModel(ListObject gradients, ListObject model) { - return _aggService.updateModel(gradients, model); + //note: we use a new execution context to allow for concurrent execution of ASP local updates; + //otherwise synchronized on the aggService instance would serialize those + ExecutionContext ec = ExecutionContextFactory.createContext(_aggService._ec.getProgram()); + ec.setVariable(Statement.PS_HYPER_PARAMS, _aggService._ec.getVariable(Statement.PS_HYPER_PARAMS)); + return _aggService.updateModel(ec, gradients, model); } public static class Gradient { @@ -115,11 +126,11 @@ public abstract class ParamServer { protected final Log LOG = LogFactory.getLog(AggregationService.class.getName()); - protected ExecutionContext _ec; - private Statement.PSUpdateType _updateType; - private FunctionCallCPInstruction _inst; - private DataIdentifier _output; - private boolean[] _finishedStates; // Workers' finished states + protected final ExecutionContext _ec; + private final Statement.PSUpdateType _updateType; + private final FunctionCallCPInstruction _inst; + private final DataIdentifier _output; + private final boolean[] _finishedStates; // Workers' finished states AggregationService(String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { _ec = ec; @@ -134,7 +145,7 @@ public abstract class ParamServer { funcNS = keys[0]; funcName = keys[1]; } - FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(funcNS, funcName); + FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(funcNS, AGG_FUNC_PREFIX + funcName); ArrayList<DataIdentifier> inputs = func.getInputParams(); ArrayList<DataIdentifier> outputs = func.getOutputParams(); @@ -170,14 +181,24 @@ public abstract class ParamServer { } private void broadcastModel() throws InterruptedException { + Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; + //broadcast copy of the model to all workers, cleaned up by workers for (BlockingQueue<ListObject> q : _modelMap.values()) q.put(ParamservUtils.copyList(_model)); + + if (DMLScript.STATISTICS) + Statistics.accPSModelBroadcastTime((long) tBroad.stop()); } - + private void broadcastModel(int workerID) throws InterruptedException { + Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null; + //broadcast copy of model to specific worker, cleaned up by worker _modelMap.get(workerID).put(ParamservUtils.copyList(_model)); + + if (DMLScript.STATISTICS) + Statistics.accPSModelBroadcastTime((long) tBroad.stop()); } @Override @@ -195,7 +216,10 @@ public abstract class ParamServer { } // Update and redistribute the model + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; _model = updateModel(grad._gradients, _model); + if (DMLScript.STATISTICS) + Statistics.accPSAggregationTime((long) tAgg.stop()); // Redistribute model according to update type switch(_updateType) { @@ -231,19 +255,23 @@ public abstract class ParamServer { * @return A updated list object of model */ private synchronized ListObject updateModel(ListObject gradients, ListObject model) { + return updateModel(_ec, gradients, model); + } + + private ListObject updateModel(ExecutionContext ec, ListObject gradients, ListObject model) { // Populate the variables table with the gradients and model - _ec.setVariable(Statement.PS_GRADIENTS, gradients); - _ec.setVariable(Statement.PS_MODEL, model); + ec.setVariable(Statement.PS_GRADIENTS, gradients); + ec.setVariable(Statement.PS_MODEL, model); // Invoke the aggregate function - _inst.processInstruction(_ec); + _inst.processInstruction(ec); // Get the output - ListObject newModel = (ListObject) _ec.getVariable(_output.getName()); + ListObject newModel = (ListObject) ec.getVariable(_output.getName()); // Update the model with the new output - ParamservUtils.cleanupListObject(_ec, model); - ParamservUtils.cleanupListObject(_ec, gradients); + ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL); + ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS); return newModel; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index 2b1dca5..c4a3d98 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -19,18 +19,35 @@ package org.apache.sysml.runtime.controlprogram.paramserv; +import java.io.IOException; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.MultiThreadedHop; import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DMLTranslator; import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.ForProgramBlock; +import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysml.runtime.controlprogram.IfProgramBlock; +import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; +import org.apache.sysml.runtime.controlprogram.Program; +import org.apache.sysml.runtime.controlprogram.ProgramBlock; +import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ListObject; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -41,6 +58,9 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo; public class ParamservUtils { + public static final String UPDATE_FUNC_PREFIX = "_worker_"; + public static final String AGG_FUNC_PREFIX = "_agg_"; + /** * Deep copy the list object * @@ -65,8 +85,8 @@ public class ParamservUtils { return new ListObject(newData, lo.getNames()); } - public static void cleanupListObject(ExecutionContext ec, ListObject lo) { - ec.getVariables().removeAllIn(new HashSet<>(lo.getNames())); + public static void cleanupListObject(ExecutionContext ec, String lName) { + ListObject lo = (ListObject) ec.removeVariable(lName); lo.getData().forEach(ParamservUtils::cleanupData); } @@ -110,4 +130,112 @@ public class ParamservUtils { seq.ctableOperations(null, sample, 1.0, permutation); return permutation; } + + public static ExecutionContext createExecutionContext(ExecutionContext ec, String updFunc, String aggFunc, int workerNum, int k) { + FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc); + FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc); + + Program prog = ec.getProgram(); + + // 1. Recompile the internal program blocks + recompileProgramBlocks(k, prog.getProgramBlocks()); + // 2. Recompile the imported function blocks + prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks())); + + // Copy function for workers + IntStream.range(0, workerNum).forEach(i -> copyFunction(updFunc, updPB, prog, UPDATE_FUNC_PREFIX + i + "_")); + + // Copy function for agg service + copyFunction(aggFunc, aggPB, prog, AGG_FUNC_PREFIX); + + return ExecutionContextFactory.createContext(prog); + } + + private static void copyFunction(String funcName, FunctionProgramBlock updPB, Program prog, String prefix) { + String[] keys = DMLProgram.splitFunctionKey(funcName); + String namespace = null; + String func = keys[0]; + if (keys.length == 2) { + namespace = keys[0]; + func = keys[1]; + } + FunctionProgramBlock copiedFunc = ProgramConverter + .createDeepCopyFunctionProgramBlock(updPB, new HashSet<>(), new HashSet<>()); + String fnameNew = prefix + func; + prog.addFunctionProgramBlock(namespace, fnameNew, copiedFunc); + } + + private static void recompileProgramBlocks(int k, ArrayList<ProgramBlock> pbs) { + // Reset the visit status from root + for (ProgramBlock pb : pbs) + DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock()); + + // Should recursively assign the level of parallelism + // and recompile the program block + try { + rAssignParallelism(pbs, k, false); + } catch (IOException e) { + throw new DMLRuntimeException(e); + } + } + + private static boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, int k, boolean recompiled) throws IOException { + for (ProgramBlock pb : pbs) { + if (pb instanceof ParForProgramBlock) { + ParForProgramBlock pfpb = (ParForProgramBlock) pb; + pfpb.setDegreeOfParallelism(k); + recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled); + } else if (pb instanceof ForProgramBlock) { + recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled); + } else if (pb instanceof WhileProgramBlock) { + recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled); + } else if (pb instanceof FunctionProgramBlock) { + recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled); + } else if (pb instanceof IfProgramBlock) { + IfProgramBlock ipb = (IfProgramBlock) pb; + recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled); + if (ipb.getChildBlocksElseBody() != null) + recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled); + } else { + StatementBlock sb = pb.getStatementBlock(); + for (Hop hop : sb.getHops()) + recompiled |= rAssignParallelism(hop, k, recompiled); + } + // Recompile the program block + if (recompiled) { + Recompiler.recompileProgramBlockInstructions(pb); + } + } + return recompiled; + } + + private static boolean rAssignParallelism(Hop hop, int k, boolean recompiled) { + if (hop.isVisited()) { + return recompiled; + } + if (hop instanceof MultiThreadedHop) { + // Reassign the level of parallelism + MultiThreadedHop mhop = (MultiThreadedHop) hop; + mhop.setMaxNumThreads(k); + recompiled = true; + } + ArrayList<Hop> inputs = hop.getInput(); + for (Hop h : inputs) { + recompiled |= rAssignParallelism(h, k, recompiled); + } + hop.setVisited(); + return recompiled; + } + + + private static FunctionProgramBlock getFunctionBlock(ExecutionContext ec, String funcName) { + String[] keys = DMLProgram.splitFunctionKey(funcName); + String namespace = null; + String func = keys[0]; + if (keys.length == 2) { + namespace = keys[0]; + func = keys[1]; + } + return ec.getProgram().getFunctionProgramBlock(namespace, func); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 65a1930..d2c0edd 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -39,9 +39,6 @@ import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE; import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES; import static org.apache.sysml.parser.Statement.PS_VAL_LABELS; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.concurrent.ExecutionException; @@ -56,21 +53,10 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.log4j.Level; import org.apache.log4j.Logger; -import org.apache.sysml.hops.Hop; -import org.apache.sysml.hops.MultiThreadedHop; -import org.apache.sysml.hops.recompile.Recompiler; -import org.apache.sysml.parser.DMLProgram; -import org.apache.sysml.parser.DMLTranslator; -import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.api.DMLScript; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.ForProgramBlock; -import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; -import org.apache.sysml.runtime.controlprogram.IfProgramBlock; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; -import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; import org.apache.sysml.runtime.controlprogram.Program; -import org.apache.sysml.runtime.controlprogram.ProgramBlock; -import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; @@ -82,9 +68,11 @@ import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerOR; import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer; import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer; -import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; +import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.utils.Statistics; public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction { @@ -111,21 +99,26 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc @Override public void processInstruction(ExecutionContext ec) { + Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null; + PSModeType mode = getPSMode(); int workerNum = getWorkerNum(mode); BasicThreadFactory factory = new BasicThreadFactory.Builder() - .namingPattern("workers-pool-thread-%d") - .build(); + .namingPattern("workers-pool-thread-%d").build(); ExecutorService es = Executors.newFixedThreadPool(workerNum, factory); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); - // Create the workers' execution context int k = getParLevel(workerNum); - List<ExecutionContext> workerECs = createExecutionContext(ec, updFunc, workerNum, k); + + // Get the compiled execution context + ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, updFunc, aggFunc, workerNum, k); + + // Create workers' execution context + List<ExecutionContext> workerECs = createExecutionContext(workerNum, ec, newEC.getProgram()); // Create the agg service's execution context - ExecutionContext aggServiceEC = createExecutionContext(ec, aggFunc, 1, 1).get(0); + ExecutionContext aggServiceEC = createExecutionContext(1, ec, newEC.getProgram()).get(0); PSFrequency freq = getFrequency(); PSUpdateType updateType = getUpdateType(); @@ -145,6 +138,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Do data partition PSScheme scheme = getScheme(); doDataPartitioning(scheme, ec, workers); + + if (DMLScript.STATISTICS) + Statistics.accPSSetupTime((long) tSetup.stop()); if (LOG.isDebugEnabled()) { LOG.debug(String.format("\nConfiguration of paramserv func: " @@ -169,6 +165,18 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } } + private List<ExecutionContext> createExecutionContext(int size, ExecutionContext ec, Program program) { + return IntStream.range(0, size).mapToObj(i -> { + // Put the hyperparam into the variables table + LocalVariableMap varsMap = new LocalVariableMap(); + ListObject hyperParams = getHyperParams(ec); + if (hyperParams != null) { + varsMap.put(PS_HYPER_PARAMS, hyperParams); + } + return ExecutionContextFactory.createContext(varsMap, program); + }).collect(Collectors.toList()); + } + private PSModeType getPSMode() { PSModeType mode; try { @@ -194,106 +202,6 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc return Math.max((int)Math.ceil((double)getRemainingCores()/workerNum), 1); } - private List<ExecutionContext> createExecutionContext(ExecutionContext ec, String funcName, int workerNum, int k) { - // Fetch the target function - String[] keys = DMLProgram.splitFunctionKey(funcName); - String namespace = null; - String func = keys[0]; - if (keys.length == 2) { - namespace = keys[0]; - func = keys[1]; - } - return createExecutionContext(ec, namespace, func, workerNum, k); - } - - private List<ExecutionContext> createExecutionContext(ExecutionContext ec, String namespace, String func, - int workerNum, int k) { - FunctionProgramBlock targetFunc = ec.getProgram().getFunctionProgramBlock(namespace, func); - return IntStream.range(0, workerNum).mapToObj(i -> { - // Put the hyperparam into the variables table - LocalVariableMap varsMap = new LocalVariableMap(); - ListObject hyperParams = getHyperParams(ec); - if (hyperParams != null) { - varsMap.put(PS_HYPER_PARAMS, hyperParams); - } - - // Deep copy the target func - FunctionProgramBlock copiedFunc = ProgramConverter - .createDeepCopyFunctionProgramBlock(targetFunc, new HashSet<>(), new HashSet<>()); - - // Reset the visit status from root - for( ProgramBlock pb : copiedFunc.getChildBlocks() ) - DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock()); - - // Should recursively assign the level of parallelism - // and recompile the program block - try { - rAssignParallelism(copiedFunc.getChildBlocks(), k, false); - } catch (IOException e) { - throw new DMLRuntimeException(e); - } - - Program prog = new Program(); - prog.addProgramBlock(copiedFunc); - prog.addFunctionProgramBlock(namespace, func, copiedFunc); - return ExecutionContextFactory.createContext(varsMap, prog); - - }).collect(Collectors.toList()); - } - - private boolean rAssignParallelism(ArrayList<ProgramBlock> pbs, int k, boolean recompiled) throws IOException { - for (ProgramBlock pb : pbs) { - if (pb instanceof ParForProgramBlock) { - ParForProgramBlock pfpb = (ParForProgramBlock) pb; - pfpb.setDegreeOfParallelism(k); - recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled); - } - else if (pb instanceof ForProgramBlock) { - recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled); - } - else if (pb instanceof WhileProgramBlock) { - recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled); - } - else if (pb instanceof FunctionProgramBlock) { - recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled); - } - else if (pb instanceof IfProgramBlock) { - IfProgramBlock ipb = (IfProgramBlock) pb; - recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled); - if (ipb.getChildBlocksElseBody() != null) - recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled); - } - else { - StatementBlock sb = pb.getStatementBlock(); - for (Hop hop : sb.getHops()) - recompiled |= rAssignParallelism(hop, k, recompiled); - } - // Recompile the program block - if (recompiled) { - Recompiler.recompileProgramBlockInstructions(pb); - } - } - return recompiled; - } - - private boolean rAssignParallelism(Hop hop, int k, boolean recompiled) { - if (hop.isVisited()) { - return recompiled; - } - if (hop instanceof MultiThreadedHop) { - // Reassign the level of parallelism - MultiThreadedHop mhop = (MultiThreadedHop) hop; - mhop.setMaxNumThreads(k); - recompiled = true; - } - ArrayList<Hop> inputs = hop.getInput(); - for (Hop h : inputs) { - recompiled |= rAssignParallelism(h, k, recompiled); - } - hop.setVisited(); - return recompiled; - } - private PSUpdateType getUpdateType() { PSUpdateType updType; try { @@ -310,13 +218,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc if (!getParameterMap().containsKey(PS_FREQUENCY)) { return DEFAULT_UPDATE_FREQUENCY; } - PSFrequency freq; try { - freq = PSFrequency.valueOf(getParam(PS_FREQUENCY)); + return PSFrequency.valueOf(getParam(PS_FREQUENCY)); } catch (IllegalArgumentException e) { - throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' update frequency.", getParam(PS_FREQUENCY))); + throw new DMLRuntimeException(String.format("Paramserv function: " + + "not support '%s' update frequency.", getParam(PS_FREQUENCY))); } - return freq; } private int getRemainingCores() { @@ -330,19 +237,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc * @return worker numbers */ private int getWorkerNum(PSModeType mode) { - int workerNum = -1; switch (mode) { case LOCAL: // default worker number: available cores - 1 (assign one process for agg service) - workerNum = getRemainingCores(); - if (getParameterMap().containsKey(PS_PARALLELISM)) { - workerNum = Math.min(workerNum, Integer.valueOf(getParam(PS_PARALLELISM))); - } - break; - case REMOTE_SPARK: - throw new DMLRuntimeException("Do not support remote spark."); + int workerNum = getRemainingCores(); + if (getParameterMap().containsKey(PS_PARALLELISM)) + workerNum = Integer.valueOf(getParam(PS_PARALLELISM)); + return workerNum; + default: + throw new DMLRuntimeException("Unsupported parameter server: "+mode.name()); } - return workerNum; } /** @@ -351,15 +255,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc * @return parameter server */ private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) { - ParamServer ps = null; switch (mode) { case LOCAL: - ps = new LocalParamServer(model, aggFunc, updateType, ec, workerNum); - break; - case REMOTE_SPARK: - throw new DMLRuntimeException("Do not support remote spark."); + return new LocalParamServer(model, aggFunc, updateType, ec, workerNum); + default: + throw new DMLRuntimeException("Unsupported parameter server: "+mode.name()); } - return ps; } private long getBatchSize() { http://git-wip-us.apache.org/repos/asf/systemml/blob/b06f390e/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 8cf3f02..0618b38 100644 --- a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -107,6 +107,15 @@ public class Statistics private static final LongAdder sparkBroadcast = new LongAdder(); private static final LongAdder sparkBroadcastCount = new LongAdder(); + // Paramserv function stats (time is in milli sec) + private static final LongAdder psNumWorkers = new LongAdder(); + private static final LongAdder psSetupTime = new LongAdder(); + private static final LongAdder psGradientComputeTime = new LongAdder(); + private static final LongAdder psAggregationTime = new LongAdder(); + private static final LongAdder psLocalModelUpdateTime = new LongAdder(); + private static final LongAdder psModelBroadcastTime = new LongAdder(); + private static final LongAdder psBatchIndexTime = new LongAdder(); + //PARFOR optimization stats (low frequency updates) private static long parforOptTime = 0; //in milli sec private static long parforOptCount = 0; //count @@ -516,7 +525,34 @@ public class Statistics public static void incSparkBroadcastCount(long c) { sparkBroadcastCount.add(c); } - + + public static void incWorkerNumber() { + psNumWorkers.increment(); + } + + public static void accPSSetupTime(long t) { + psSetupTime.add(t); + } + + public static void accPSGradientComputeTime(long t) { + psGradientComputeTime.add(t); + } + + public static void accPSAggregationTime(long t) { + psAggregationTime.add(t); + } + + public static void accPSLocalModelUpdateTime(long t) { + psLocalModelUpdateTime.add(t); + } + + public static void accPSModelBroadcastTime(long t) { + psModelBroadcastTime.add(t); + } + + public static void accPSBatchIndexingTime(long t) { + psBatchIndexTime.add(t); + } public static String getCPHeavyHitterCode( Instruction inst ) { @@ -850,6 +886,15 @@ public class Statistics ((double)sparkBroadcast.longValue())*1e-9, ((double)sparkCollect.longValue())*1e-9)); } + if (psNumWorkers.longValue() > 0) { + sb.append(String.format("Paramserv total num workers:\t%d.\n", psNumWorkers.longValue())); + sb.append(String.format("Paramserv setup time:\t\t%.3f secs.\n", psSetupTime.doubleValue() / 1000)); + sb.append(String.format("Paramserv grad compute time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000)); + sb.append(String.format("Paramserv model update time:\t%.3f/%.3f secs.\n", + psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 1000)); + sb.append(String.format("Paramserv model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000)); + sb.append(String.format("Paramserv batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000)); + } if( parforOptCount>0 ){ sb.append("ParFor loops optimized:\t\t" + getParforOptCount() + ".\n"); sb.append("ParFor optimize time:\t\t" + String.format("%.3f", ((double)getParforOptTime())/1000) + " sec.\n");
