Repository: systemml Updated Branches: refs/heads/master eb179b151 -> 63a1e2ac5
[SYSTEMML-2403] Fix accuracy issue paramserv BSP batch updates Closes #791. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/63a1e2ac Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/63a1e2ac Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/63a1e2ac Branch: refs/heads/master Commit: 63a1e2ac59f3201ab99a6e5e71636133eec96b1b Parents: eb179b1 Author: EdgarLGB <[email protected]> Authored: Sat Jul 7 18:40:25 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 7 18:40:26 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/paramserv/LocalPSWorker.java | 19 ++--------- .../controlprogram/paramserv/ParamServer.java | 34 +++++++++++++++----- .../paramserv/ParamservUtils.java | 24 ++++++++++++++ 3 files changed, 52 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index 0ed7c81..366284c 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -20,7 +20,6 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import java.util.concurrent.Callable; -import java.util.stream.IntStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,10 +29,7 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing; -import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.ListObject; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.utils.Statistics; public class LocalPSWorker extends PSWorker implements Callable<Void> { @@ -84,13 +80,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { ListObject gradients = computeGradients(dataSize, totalIter, i, j); // Accumulate the intermediate gradients - accGradients = (accGradients==null) ? - ParamservUtils.copyList(gradients) : - accrueGradients(accGradients, gradients); + accGradients = ParamservUtils.accrueGradients(accGradients, gradients); // Update the local model with gradients if( j < totalIter - 1 ) params = updateModel(params, gradients, i, j, totalIter); + ParamservUtils.cleanupListObject(gradients); } // Push the gradients to ps @@ -193,14 +188,4 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { return gradients; } - private ListObject accrueGradients(ListObject accGradients, ListObject gradients) { - IntStream.range(0, accGradients.getLength()).forEach(i -> { - MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead(); - MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead(); - mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2); - ((MatrixObject) accGradients.getData().get(i)).release(); - ((MatrixObject) gradients.getData().get(i)).release(); - }); - return accGradients; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java index abec267..432d4fc 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java @@ -49,7 +49,8 @@ import org.apache.sysml.utils.Statistics; public abstract class ParamServer { - protected final Log LOG = LogFactory.getLog(ParamServer.class.getName()); + protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName()); + protected static final boolean ACCRUE_BSP_GRADIENTS = true; // worker input queues and global model protected final Map<Integer, BlockingQueue<ListObject>> _modelMap; @@ -61,6 +62,7 @@ public abstract class ParamServer private final FunctionCallCPInstruction _inst; private final String _outputName; private final boolean[] _finishedStates; // Workers' finished states + private ListObject _accGradients = null; protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) { // init worker queues and global model @@ -126,17 +128,25 @@ public abstract class ParamServer gradients.getDataSize() / 1024, workerID)); } - // Update and redistribute the model - Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; - _model = updateLocalModel(_ec, gradients, _model); - if (DMLScript.STATISTICS) - Statistics.accPSAggregationTime((long) tAgg.stop()); - - // Redistribute model according to update type switch(_updateType) { case BSP: { setFinishedState(workerID); + + // Accumulate the intermediate gradients + if( ACCRUE_BSP_GRADIENTS ) + _accGradients = ParamservUtils.accrueGradients( + _accGradients, gradients, true); + else + updateGlobalModel(gradients); + ParamservUtils.cleanupListObject(gradients); + if (allFinished()) { + // Update the global model with accrued gradients + if( ACCRUE_BSP_GRADIENTS ) { + updateGlobalModel(_accGradients); + _accGradients = null; + } + // Broadcast the updated model resetFinishedStates(); broadcastModel(); @@ -146,6 +156,7 @@ public abstract class ParamServer break; } case ASP: { + updateGlobalModel(gradients); broadcastModel(workerID); break; } @@ -158,6 +169,13 @@ public abstract class ParamServer } } + private void updateGlobalModel(ListObject gradients) { + Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null; + _model = updateLocalModel(_ec, gradients, _model); + if (DMLScript.STATISTICS) + Statistics.accPSAggregationTime((long) tAgg.stop()); + } + /** * A service method for updating model with gradients * http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index ecfac66..3aee170 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -50,6 +50,7 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; +import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.ListObject; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -57,6 +58,7 @@ import org.apache.sysml.runtime.matrix.MetaDataFormat; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; public class ParamservUtils { @@ -88,6 +90,10 @@ public class ParamservUtils { public static void cleanupListObject(ExecutionContext ec, String lName) { ListObject lo = (ListObject) ec.removeVariable(lName); + cleanupListObject(lo); + } + + public static void cleanupListObject(ListObject lo) { lo.getData().forEach(ParamservUtils::cleanupData); } @@ -258,4 +264,22 @@ public class ParamservUtils { String fname = cfn[1]; return ec.getProgram().getFunctionProgramBlock(ns, fname); } + + public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) { + return accrueGradients(accGradients, gradients, false); + } + + public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par) { + if (accGradients == null) + return ParamservUtils.copyList(gradients); + IntStream range = IntStream.range(0, accGradients.getLength()); + (par ? range.parallel() : range).forEach(i -> { + MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead(); + MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead(); + mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2); + ((MatrixObject) accGradients.getData().get(i)).release(); + ((MatrixObject) gradients.getData().get(i)).release(); + }); + return accGradients; + } }
