atefeh-asayesh commented on a change in pull request #1336:
URL: https://github.com/apache/systemds/pull/1336#discussion_r676040171



##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
##########
@@ -33,24 +33,32 @@ public LocalParamServer() {
 
        public static LocalParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels,boolean 
modelAvg)
        {
                return new LocalParamServer(model, aggFunc, updateType, freq, 
ec,
-                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels);
+                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels,modelAvg);
        }
 
        private LocalParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels,boolean 
modelAvg)
        {
-               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels);
+               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels,modelAvg);
        }
 
        @Override
        public void push(int workerID, ListObject gradients) {
-               updateGlobalModel(workerID, gradients);
+               updModel_avgModel(workerID, gradients);
        }
 
+       /*
+       public void push(int workerID, ListObject values) {

Review comment:
       The commented-out code has been removed.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -41,17 +35,43 @@
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.*;

Review comment:
       Thanks for the comment. The wildcard imports have been removed.
   
   
   

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -127,12 +155,12 @@ protected void setupAggFunc(ExecutionContext ec, String 
aggFunc) {
                _outputName = outputs.get(0).getName();
 
                CPOperand[] boundInputs = inputs.stream()
-                       .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
-                       .toArray(CPOperand[]::new);
+                               .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))

Review comment:
       Fixed.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+

Review comment:
       Fixed.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+
                                                }
-                                               
                                                // Broadcast the updated model
                                                resetFinishedStates();
+
                                                broadcastModel(true);
                                                if (LOG.isDebugEnabled())
-                                                       LOG.debug("Global 
parameter is broadcasted successfully.");
+                                                       LOG.debug("Global 
Averaging parameter is broadcasted successfully ");

Review comment:
       Thanks for the comment. The log message has been kept the same as the 
previous one, and the new log message has been removed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to