Repository: systemml Updated Branches: refs/heads/master 8a5bdba43 -> dfa27ba22
[MINOR] Fix paramserv accumulator handling (uninitialized stats) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/dfa27ba2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/dfa27ba2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/dfa27ba2 Branch: refs/heads/master Commit: dfa27ba22e3e64a253012f7f77eb918be5e0ef1a Parents: 8a5bdba Author: Matthias Boehm <[email protected]> Authored: Fri Jul 27 17:18:58 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Fri Jul 27 17:18:58 2018 -0700 ---------------------------------------------------------------------- .../paramserv/spark/SparkPSWorker.java | 32 ++++++++------------ .../cp/ParamservBuiltinCPInstruction.java | 14 ++++----- 2 files changed, 19 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/dfa27ba2/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java index 59203ad..5732a4d 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java @@ -26,7 +26,6 @@ import java.util.Map; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.util.LongAccumulator; -import org.apache.sysml.api.DMLScript; import org.apache.sysml.parser.Statement; import org.apache.sysml.runtime.codegen.CodegenUtils; import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; @@ -84,7 +83,7 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< @Override public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception { - Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null; + Timing tSetup = new Timing(true); configureWorker(input); accSetupTime(tSetup); @@ -130,43 +129,36 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2< @Override - public void incWorkerNumber() { - if (DMLScript.STATISTICS) - _aWorker.add(1); + protected void incWorkerNumber() { + _aWorker.add(1); } @Override - public void accLocalModelUpdateTime(Timing time) { - if (DMLScript.STATISTICS) - _aUpdate.add((long) time.stop()); + protected void accLocalModelUpdateTime(Timing time) { + _aUpdate.add((long) time.stop()); } @Override - public void accBatchIndexingTime(Timing time) { - if (DMLScript.STATISTICS) - _aIndex.add((long) time.stop()); + protected void accBatchIndexingTime(Timing time) { + _aIndex.add((long) time.stop()); } @Override - public void accGradientComputeTime(Timing time) { - if (DMLScript.STATISTICS) - _aGrad.add((long) time.stop()); + protected void accGradientComputeTime(Timing time) { + _aGrad.add((long) time.stop()); } @Override protected void accNumEpochs(int n) { - if (DMLScript.STATISTICS) - _nEpochs.add(n); + _nEpochs.add(n); } @Override protected void accNumBatches(int n) { - if (DMLScript.STATISTICS) - _nBatches.add(n); + _nBatches.add(n); } private void accSetupTime(Timing tSetup) { - if (DMLScript.STATISTICS) - _aSetup.add((long) tSetup.stop()); + _aSetup.add((long) tSetup.stop()); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/dfa27ba2/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index fe238bd..6220bb6 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -163,7 +163,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest"); LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches"); LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs"); - + // Create remote workers SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(), @@ -184,12 +184,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Accumulate the statistics for remote workers if (DMLScript.STATISTICS) { - Statistics.accPSSetupTime(aSetup.sum()); - Statistics.incWorkerNumber(aWorker.sum()); - Statistics.accPSLocalModelUpdateTime(aUpdate.sum()); - Statistics.accPSBatchIndexingTime(aIndex.sum()); - Statistics.accPSGradientComputeTime(aGrad.sum()); - Statistics.accPSRpcRequestTime(aRPC.sum()); + Statistics.accPSSetupTime(aSetup.value().longValue()); + Statistics.incWorkerNumber(aWorker.value().longValue()); + Statistics.accPSLocalModelUpdateTime(aUpdate.value().longValue()); + Statistics.accPSBatchIndexingTime(aIndex.value().longValue()); + Statistics.accPSGradientComputeTime(aGrad.value().longValue()); + Statistics.accPSRpcRequestTime(aRPC.value().longValue()); } // Fetch the final model from ps
