Repository: systemml
Updated Branches:
  refs/heads/master 8a5bdba43 -> dfa27ba22


[MINOR] Fix paramserv accumulator handling (uninitialized stats)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/dfa27ba2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/dfa27ba2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/dfa27ba2

Branch: refs/heads/master
Commit: dfa27ba22e3e64a253012f7f77eb918be5e0ef1a
Parents: 8a5bdba
Author: Matthias Boehm <[email protected]>
Authored: Fri Jul 27 17:18:58 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Fri Jul 27 17:18:58 2018 -0700

----------------------------------------------------------------------
 .../paramserv/spark/SparkPSWorker.java          | 32 ++++++++------------
 .../cp/ParamservBuiltinCPInstruction.java       | 14 ++++-----
 2 files changed, 19 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/dfa27ba2/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index 59203ad..5732a4d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -26,7 +26,6 @@ import java.util.Map;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.VoidFunction;
 import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.codegen.CodegenUtils;
 import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
@@ -84,7 +83,7 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
        
        @Override
        public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
-               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
+               Timing tSetup = new Timing(true);
                configureWorker(input);
                accSetupTime(tSetup);
 
@@ -130,43 +129,36 @@ public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<
        
 
        @Override
-       public void incWorkerNumber() {
-               if (DMLScript.STATISTICS)
-                       _aWorker.add(1);
+       protected void incWorkerNumber() {
+               _aWorker.add(1);
        }
        
        @Override
-       public void accLocalModelUpdateTime(Timing time) {
-               if (DMLScript.STATISTICS)
-                       _aUpdate.add((long) time.stop());
+       protected void accLocalModelUpdateTime(Timing time) {
+               _aUpdate.add((long) time.stop());
        }
 
        @Override
-       public void accBatchIndexingTime(Timing time) {
-               if (DMLScript.STATISTICS)
-                       _aIndex.add((long) time.stop());
+       protected void accBatchIndexingTime(Timing time) {
+               _aIndex.add((long) time.stop());
        }
 
        @Override
-       public void accGradientComputeTime(Timing time) {
-               if (DMLScript.STATISTICS)
-                       _aGrad.add((long) time.stop());
+       protected void accGradientComputeTime(Timing time) {
+               _aGrad.add((long) time.stop());
        }
        
        @Override
        protected void accNumEpochs(int n) {
-               if (DMLScript.STATISTICS)
-                       _nEpochs.add(n);
+               _nEpochs.add(n);
        }
        
        @Override
        protected void accNumBatches(int n) {
-               if (DMLScript.STATISTICS)
-                       _nBatches.add(n);
+               _nBatches.add(n);
        }
        
        private void accSetupTime(Timing tSetup) {
-               if (DMLScript.STATISTICS)
-                       _aSetup.add((long) tSetup.stop());
+               _aSetup.add((long) tSetup.stop());
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/dfa27ba2/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index fe238bd..6220bb6 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -163,7 +163,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
                LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
                LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
-
+               
                // Create remote workers
                SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), 
                        getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
@@ -184,12 +184,12 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
                // Accumulate the statistics for remote workers
                if (DMLScript.STATISTICS) {
-                       Statistics.accPSSetupTime(aSetup.sum());
-                       Statistics.incWorkerNumber(aWorker.sum());
-                       Statistics.accPSLocalModelUpdateTime(aUpdate.sum());
-                       Statistics.accPSBatchIndexingTime(aIndex.sum());
-                       Statistics.accPSGradientComputeTime(aGrad.sum());
-                       Statistics.accPSRpcRequestTime(aRPC.sum());
+                       Statistics.accPSSetupTime(aSetup.value().longValue());
+                       Statistics.incWorkerNumber(aWorker.value().longValue());
+                       Statistics.accPSLocalModelUpdateTime(aUpdate.value().longValue());
+                       Statistics.accPSBatchIndexingTime(aIndex.value().longValue());
+                       Statistics.accPSGradientComputeTime(aGrad.value().longValue());
+                       Statistics.accPSRpcRequestTime(aRPC.value().longValue());
                }
 
                // Fetch the final model from ps

Reply via email to