mboehm7 commented on a change in pull request #1336:
URL: https://github.com/apache/systemds/pull/1336#discussion_r672207342



##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
##########
@@ -337,6 +338,7 @@ protected void weighAndPushGradients(ListObject gradients) {
 
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
+               //_ps.push(_workerID, modell)

Review comment:
       what's this? please push either the gradients or the model, 
depending on the configuration

##########
File path: src/main/java/org/apache/sysds/parser/Statement.java
##########
@@ -72,6 +72,9 @@
        public static final String PS_MODE = "mode";
        public static final String PS_GRADIENTS = "gradients";
        public static final String PS_SEED = "seed";
+       public static final String PS_MODELAVG = "modelAvg";
+       public static final String PS_MODELS = "models";

Review comment:
       why do we need this besides modelAvg - remove if unnecessary.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
                        catch(ExecutionException | InterruptedException ex) {
                                throw new DMLRuntimeException(ex);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
 
+       protected ListObject createLocalModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
+               // Populate the variables table with the gradients and model
+               ec.setVariable(Statement.PS_GRADIENTS, gradients);
+               ec.setVariable(Statement.PS_MODEL, model);
+
+               // Invoke the aggregate function
+               _inst.processInstruction(ec);
+
+               // Get the new model
+               ListObject newModel = ec.getListObject(_outputName);
+
+               // Clean up the list according to the data referencing status
+               ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, 
newModel.getStatus());
+               ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+               return newModel;
+       }
+
        private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int batchIter) {
                Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
                globalParams = _ps.updateLocalModel(_ec, gradients, 
globalParams);
 
                accLocalModelUpdateTime(tUpd);
-               
+
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: local global parameter 
[size:%d kb] updated. "
-                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]",
-                               getWorkerName(), globalParams.getDataSize(), i 
+ 1, _epochs, j + 1, batchIter));
+                                                       + "[Epoch:%d  Total 
epoch:%d  Iteration:%d  Total iteration:%d]",
+                                       getWorkerName(), 
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));

Review comment:
       fix the corrupted formatting.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
                        catch(ExecutionException | InterruptedException ex) {
                                throw new DMLRuntimeException(ex);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
 
+       protected ListObject createLocalModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
+               // Populate the variables table with the gradients and model
+               ec.setVariable(Statement.PS_GRADIENTS, gradients);
+               ec.setVariable(Statement.PS_MODEL, model);
+
+               // Invoke the aggregate function
+               _inst.processInstruction(ec);
+
+               // Get the new model
+               ListObject newModel = ec.getListObject(_outputName);
+
+               // Clean up the list according to the data referencing status
+               ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, 
newModel.getStatus());
+               ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+               return newModel;
+       }
+
        private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int batchIter) {
                Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
                globalParams = _ps.updateLocalModel(_ec, gradients, 
globalParams);
 
                accLocalModelUpdateTime(tUpd);
-               
+
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: local global parameter 
[size:%d kb] updated. "
-                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]",
-                               getWorkerName(), globalParams.getDataSize(), i 
+ 1, _epochs, j + 1, batchIter));
+                                                       + "[Epoch:%d  Total 
epoch:%d  Iteration:%d  Total iteration:%d]",
+                                       getWorkerName(), 
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
                }
                return globalParams;
        }
-
        private void computeBatch(long dataSize, int totalIter) {
                for (int i = 0; i < _epochs; i++) {
                        for (int j = 0; j < totalIter; j++) {
                                ListObject globalParams = pullModel();
 
                                ListObject gradients = 
computeGradients(globalParams, dataSize, totalIter, i, j);
-
                                // Push the gradients to ps
                                pushGradients(gradients);
                                ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
-                               
+
                                accNumBatches(1);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
-
        private ListObject pullModel() {
                // Pull the global parameters from ps
                ListObject globalParams = _ps.pull(_workerID);
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully pull the 
global parameters "
-                               + "[size:%d kb] from ps.", getWorkerName(), 
globalParams.getDataSize() / 1024));
+                                       + "[size:%d kb] from ps.", 
getWorkerName(), globalParams.getDataSize() / 1024));
                }
                return globalParams;
        }
-
        private void pushGradients(ListObject gradients) {
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully push the 
gradients "
-                               + "[size:%d kb] to ps.", getWorkerName(), 
gradients.getDataSize() / 1024));
+                                       + "[size:%d kb] to ps.", 
getWorkerName(), gradients.getDataSize() / 1024));
+               }
+       }
+       private void pushModelToServer(ListObject modell) {
+               // Push the Model to ps
+               _ps.push(_workerID, modell);

Review comment:
       please, globally correct the spelling of variables: it's model not 
modell.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+
                                                }
-                                               
                                                // Broadcast the updated model
                                                resetFinishedStates();
+
                                                broadcastModel(true);
                                                if (LOG.isDebugEnabled())
-                                                       LOG.debug("Global 
parameter is broadcasted successfully.");
+                                                       LOG.debug("Global 
Averaging parameter is broadcasted successfully ");

Review comment:
       what is an Averaging parameter - the model broadcast should be 
unaffected by the introduction of model averaging.

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -327,12 +317,12 @@ private void runLocally(ExecutionContext ec, PSModeType 
mode) {
                MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) 
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
                MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? 
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
                ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC, getValFunction(),
-                               num_batches_per_epoch, val_features, 
val_labels);
+                               num_batches_per_epoch, val_features, 
val_labels,parseBoolean(modelAvg));
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
                        .mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
-                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps))
+                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps,parseBoolean(modelAvg)))

Review comment:
       missing spaces before additional arg

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -468,21 +458,21 @@ private int getWorkerNum(PSModeType mode) {
         * @return parameter server
         */
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
-               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec)
+               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec,boolean modelAvg)
        {
-               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null);
+               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null,modelAvg );
        }
 
        // When this creation is used the parameter server is able to validate 
after each epoch
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
                PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec, String valFunc,
-               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels,boolean modelAvg)
        {
-               switch (mode) {
+                       switch (mode) {

Review comment:
       wrong formatting.

##########
File path: src/test/scripts/functions/federated/paramserv/TwoNN.dml
##########
@@ -126,7 +126,7 @@ train = function(matrix[double] X, matrix[double] y,
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
                  int num_workers, int epochs, string utype, string freq, int 
batch_size, string scheme, string runtime_balancing, string weighting,
-                 double eta, int seed = -1)
+                 double eta, int seed = -1,boolean modelAvg)

Review comment:
       formatting.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -191,9 +219,9 @@ private ListObject computeGradients(ListObject params, long 
dataSize, int batchI
 
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: got batch data [size:%d 
kb] of index from %d to %d [last index: %d]. "
-                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]", getWorkerName(),
-                               bFeatures.getDataSize() / 1024 + 
bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
-                               j + 1, batchIter));
+                                                       + "[Epoch:%d  Total 
epoch:%d  Iteration:%d  Total iteration:%d]", getWorkerName(),

Review comment:
       see formatting above.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
##########
@@ -360,7 +362,26 @@ protected void computeWithBatchUpdates() {
                        }
                }
        }
+       //****************************************  ATEFEH 
*********************************************************************

Review comment:
       we do not use author tags - so please remove such comments with your 
name.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -89,19 +97,19 @@ private void computeEpoch(long dataSize, int batchIter) {
                        try {
                                for (int j = 0; j < batchIter; j++) {
                                        ListObject gradients = 
computeGradients(params, dataSize, batchIter, i, j);
-       
+
                                        boolean localUpdate = j < batchIter - 1;
-                                       
-                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates 
+
+                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates
                                        // and gradient computation, sequential 
over gradient matrices to avoid deadlocks)
                                        ListObject accGradientsPrev = 
accGradients.get();
                                        accGradients = _tpool.submit(() -> 
ParamservUtils.accrueGradients(

Review comment:
       we only need to accrue gradients if we aim to exchange them - if model 
averaging is enabled this can be avoided.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
##########
@@ -33,24 +33,32 @@ public LocalParamServer() {
 
        public static LocalParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels,boolean 
modelAvg)
        {
                return new LocalParamServer(model, aggFunc, updateType, freq, 
ec,
-                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels);
+                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels,modelAvg);
        }
 
        private LocalParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels,boolean 
modelAvg)
        {
-               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels);
+               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels,modelAvg);
        }
 
        @Override
        public void push(int workerID, ListObject gradients) {
-               updateGlobalModel(workerID, gradients);
+               updModel_avgModel(workerID, gradients);
        }
 
+       /*
+       public void push(int workerID, ListObject values) {

Review comment:
       remove such commented code.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -89,19 +97,19 @@ private void computeEpoch(long dataSize, int batchIter) {
                        try {
                                for (int j = 0; j < batchIter; j++) {
                                        ListObject gradients = 
computeGradients(params, dataSize, batchIter, i, j);
-       
+
                                        boolean localUpdate = j < batchIter - 1;
-                                       
-                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates 
+
+                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates
                                        // and gradient computation, sequential 
over gradient matrices to avoid deadlocks)
                                        ListObject accGradientsPrev = 
accGradients.get();
                                        accGradients = _tpool.submit(() -> 
ParamservUtils.accrueGradients(
-                                               accGradientsPrev, gradients, 
false, !localUpdate));
-       
+                                                       accGradientsPrev, 
gradients, false, !localUpdate));

Review comment:
       same as above

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
                        catch(ExecutionException | InterruptedException ex) {
                                throw new DMLRuntimeException(ex);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
 
+       protected ListObject createLocalModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
+               // Populate the variables table with the gradients and model
+               ec.setVariable(Statement.PS_GRADIENTS, gradients);
+               ec.setVariable(Statement.PS_MODEL, model);
+
+               // Invoke the aggregate function
+               _inst.processInstruction(ec);
+
+               // Get the new model
+               ListObject newModel = ec.getListObject(_outputName);
+
+               // Clean up the list according to the data referencing status
+               ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, 
newModel.getStatus());
+               ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+               return newModel;
+       }
+
        private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int batchIter) {
                Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
                globalParams = _ps.updateLocalModel(_ec, gradients, 
globalParams);
 
                accLocalModelUpdateTime(tUpd);
-               
+
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: local global parameter 
[size:%d kb] updated. "
-                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]",
-                               getWorkerName(), globalParams.getDataSize(), i 
+ 1, _epochs, j + 1, batchIter));
+                                                       + "[Epoch:%d  Total 
epoch:%d  Iteration:%d  Total iteration:%d]",
+                                       getWorkerName(), 
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
                }
                return globalParams;
        }
-
        private void computeBatch(long dataSize, int totalIter) {
                for (int i = 0; i < _epochs; i++) {
                        for (int j = 0; j < totalIter; j++) {
                                ListObject globalParams = pullModel();
 
                                ListObject gradients = 
computeGradients(globalParams, dataSize, totalIter, i, j);
-
                                // Push the gradients to ps
                                pushGradients(gradients);
                                ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
-                               
+
                                accNumBatches(1);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
-
        private ListObject pullModel() {
                // Pull the global parameters from ps
                ListObject globalParams = _ps.pull(_workerID);
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully pull the 
global parameters "
-                               + "[size:%d kb] from ps.", getWorkerName(), 
globalParams.getDataSize() / 1024));
+                                       + "[size:%d kb] from ps.", 
getWorkerName(), globalParams.getDataSize() / 1024));

Review comment:
       see above.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -35,24 +31,30 @@
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.utils.Statistics;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+
 public class LocalPSWorker extends PSWorker implements Callable<Void> {
 
        protected static final Log LOG = 
LogFactory.getLog(LocalPSWorker.class.getName());
        private static final long serialVersionUID = 5195390748495357295L;
+       private ListObject modell;
+       private String _outputName;
 
        protected LocalPSWorker() {}
 
        public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq,
-               int epochs, long batchSize, ExecutionContext ec, ParamServer ps)
+                                                int epochs, long batchSize, 
ExecutionContext ec, ParamServer ps,boolean modelavg)

Review comment:
       please, avoid corrupting the existing formatting.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+

Review comment:
       do not introduce new lines before closing curly braces.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -62,7 +64,13 @@ public Void call() throws Exception {
 
                        switch (_freq) {
                                case BATCH:
-                                       computeBatch(dataSize, batchIter);
+                                       if (_modelAvg){
+                                               
computeBatch_Avg(dataSize,batchIter);
+                                       }
+
+                                       else
+                                               computeBatch(dataSize, 
batchIter);

Review comment:
       the formatting seems off again.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
##########
@@ -33,24 +33,32 @@ public LocalParamServer() {
 
        public static LocalParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels,boolean 
modelAvg)

Review comment:
       all the additional parameters in these constructors are missing a space 
before the new parameter

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -41,17 +35,43 @@
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.*;

Review comment:
       avoid wild-card imports.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+
                                                }
-                                               
                                                // Broadcast the updated model
                                                resetFinishedStates();
+
                                                broadcastModel(true);
                                                if (LOG.isDebugEnabled())
-                                                       LOG.debug("Global 
parameter is broadcasted successfully.");
+                                                       LOG.debug("Global 
Averaging parameter is broadcasted successfully ");
                                        }
                                        break;
                                }
-                               case ASP: {

Review comment:
       you can't just delete the support for ASP here!!!

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
##########
@@ -479,4 +479,51 @@ public static ListObject accrueGradients(ListObject 
accGradients, ListObject gra
                        ParamservUtils.cleanupListObject(gradients);
                return accGradients;
        }
+
+
+       /**
+                * Accumulate the given models into the accrued accrueModels
+        *
+        * @param accModels accrued models list object
+        * @param models given models list object
+        * @param cleanup clean up the given models list object
+        * @return new accrued models list object
+        */
+       public static ListObject accrueModels(ListObject accModels, ListObject 
models, boolean cleanup) {
+               return accrueModels(accModels, models, false, cleanup);
+       }
+
+       /**
+        * Accumulate the given models into the accrued models
+        *
+        * @param accModels accrued models list object
+        * @param models given models list object
+        * @param par parallel execution
+        * @param cleanup clean up the given models list object
+        * @return new accrued models list object
+        */
+       public static ListObject accrueModels(ListObject accModels, ListObject 
models, boolean par, boolean cleanup) {
+               if (accModels == null)
+                       return ParamservUtils.copyList(models, cleanup);
+               IntStream range = IntStream.range(0, accModels.getLength());
+               (par ? range.parallel() : range).forEach(i -> {
+                       MatrixBlock mb1 = ((MatrixObject) 
accModels.getData().get(i)).acquireReadAndRelease();
+                       MatrixBlock mb2 = ((MatrixObject) 
models.getData().get(i)).acquireReadAndRelease();
+                       mb1.binaryOperationsInPlace(new 
BinaryOperator(Plus.getPlusFnObject()), mb2);
+               });
+               if (cleanup)
+                       ParamservUtils.cleanupListObject(models);
+               return accModels;
+       }
+
+
+       //*******************************************   ATEFEH 
********************************************************

Review comment:
       delete

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+
                                                }
-                                               
                                                // Broadcast the updated model
                                                resetFinishedStates();
+
                                                broadcastModel(true);
                                                if (LOG.isDebugEnabled())
-                                                       LOG.debug("Global 
parameter is broadcasted successfully.");
+                                                       LOG.debug("Global 
Averaging parameter is broadcasted successfully ");
                                        }
                                        break;
                                }
-                               case ASP: {
-                                       updateGlobalModel(gradients);
-                                       // This works similarly to the one for 
BSP, but divides the sync counter by
-                                       // the number of workers, creating 
"Pseudo Epochs"
-                                       if (_numBatchesPerEpoch != -1 &&
-                                               ((_freq == 
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
-                                               (_freq == 
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) 
_numBatchesPerEpoch == 0))) {
-
-                                               if(LOG.isInfoEnabled())
-                                                       LOG.info("[+] 
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
-
-                                               time_epoch();
-
-                                               if(_validationPossible)
-                                                       validate();
-
-                                               _epochCounter++;
-                                               _syncCounter = 0;
-                                       }
-
-                                       broadcastModel(workerID);
-                                       break;
-                               }
+                               case ASP:
+                                       throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name()+"in the case of 
averaging model");
                                default:
                                        throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name());
                        }
-               } 
+               }
                catch (Exception e) {
                        throw new DMLRuntimeException("Aggregation or 
validation service failed: ", e);
                }
        }
+       private void averageGlobalModel(ListObject accModel) {
+               Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+               _model = averageModel(_ec,accModel, _model);
+
+               if (DMLScript.STATISTICS && tAgg != null)
+                       Statistics.accPSAggregationTime((long) tAgg.stop());
+       }
+       
/*********************************************************************************************************************
+        * A service method for averaging model with models
+        *
+        * @param ec execution context
+        * @param accModels list of models
+        * @param model old model
+        * @return new model (accModel)
+        */
+
+       public static  ListObject averageModel(ExecutionContext ec, ListObject 
accModels,ListObject model) {

Review comment:
       the formatting of the entire method is completely off; also we still 
need the model update via gradients (please make sure both gradient update and 
model averaging are still working).

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
##########
@@ -118,28 +118,31 @@ public Instruction preprocessInstruction(ExecutionContext 
ec) {
        }
 
        @Override
-       public void processInstruction(ExecutionContext ec) {
+       public void

Review comment:
       ???

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -327,12 +317,12 @@ private void runLocally(ExecutionContext ec, PSModeType 
mode) {
                MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) 
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
                MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? 
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
                ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC, getValFunction(),
-                               num_batches_per_epoch, val_features, 
val_labels);
+                               num_batches_per_epoch, val_features, 
val_labels,parseBoolean(modelAvg));

Review comment:
       missing spaces before additional arg

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -74,23 +94,31 @@
        private int _syncCounter = 0;
        private int _epochCounter = 0 ;
        private int _numBatchesPerEpoch;
+       private boolean _modelAvg;
 
        private int _numWorkers;
+       private ListObject _accModel = null;
+       private Object sum;
+       private BinaryOperator _op2;
+       private MatrixObject AvgModel;
+
 
        protected ParamServer() {}
 
        protected ParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
-               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc,
-               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+                                                 Statement.PSFrequency freq, 
ExecutionContext ec, int workerNum, String valFunc,

Review comment:
       fix the corrupted formatting.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -127,12 +155,12 @@ protected void setupAggFunc(ExecutionContext ec, String 
aggFunc) {
                _outputName = outputs.get(0).getName();
 
                CPOperand[] boundInputs = inputs.stream()
-                       .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
-                       .toArray(CPOperand[]::new);
+                               .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))

Review comment:
       same as above.

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -87,6 +67,9 @@
 import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.utils.Statistics;
 
+import static java.lang.Boolean.parseBoolean;
+import static org.apache.sysds.parser.Statement.*;

Review comment:
       again, no wild-card imports.

##########
File path: 
src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
##########
@@ -19,23 +19,25 @@
 
 package org.apache.sysds.test.functions.federated.paramserv;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.List;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.codegen.SpoofFusedOp;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Statistics;
+import org.dmg.pmml.True;

Review comment:
       Why are you introducing this dependency?

##########
File path: src/test/scripts/functions/federated/paramserv/CNN.dml
##########
@@ -67,7 +67,7 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
  */
 train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
   matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
-       int Win, int seed = -1) return (list[unknown] model)
+       int Win, int seed = -1,boolean modelAvg) return (list[unknown] model)

Review comment:
       formatting

##########
File path: 
src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
##########
@@ -197,12 +200,15 @@ private void federatedParamserv(ExecMode mode) {
                                        "channels=" + C,
                                        "hin=" + Hin,
                                        "win=" + Win,
-                                       "seed=" + _seed));
+                                       "seed=" + _seed,
+                                       "modelAvg="+ modelAvg));
 
                        programArgs = programArgsList.toArray(new String[0]);
-                       LOG.debug(runTest(null));
-                       Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
-                       
+
+                       runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+                       //      Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());

Review comment:
       please, do not disable assertions of existing tests.

##########
File path: src/test/scripts/functions/federated/paramserv/CNN.dml
##########
@@ -161,9 +161,11 @@ train = function(matrix[double] X, matrix[double] y, 
matrix[double] X_val,
 train_paramserv = function(matrix[double] X, matrix[double] y,
   matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
   string utype, string freq, int batch_size, string scheme, string 
runtime_balancing,
-  string weighting, double eta, int C, int Hin, int Win, int seed = -1)
+  string weighting, double eta, int C, int Hin, int Win, int seed = -1,boolean 
modelAvg)

Review comment:
       see above.

##########
File path: src/test/scripts/functions/federated/paramserv/CNN.dml
##########
@@ -161,9 +161,11 @@ train = function(matrix[double] X, matrix[double] y, 
matrix[double] X_val,
 train_paramserv = function(matrix[double] X, matrix[double] y,
   matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
   string utype, string freq, int batch_size, string scheme, string 
runtime_balancing,
-  string weighting, double eta, int C, int Hin, int Win, int seed = -1)
+  string weighting, double eta, int C, int Hin, int Win, int seed = -1,boolean 
modelAvg)
   return (list[unknown] model)
 {
+

Review comment:
       see above.

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -588,4 +578,8 @@ private String getValFunction() {
        private int getSeed() {
                return (getParameterMap().containsKey(PS_SEED)) ? 
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
        }
+       private boolean getModelAvg() {
+               return getParameterMap().containsKey(PS_MODELAVG) && 
parseBoolean(getParam(PS_MODELAVG));
+       }
+

Review comment:
       do not add such blank lines.

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -468,21 +458,21 @@ private int getWorkerNum(PSModeType mode) {
         * @return parameter server
         */
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
-               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec)
+               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec,boolean modelAvg)
        {
-               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null);
+               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null,modelAvg );
        }
 
        // When this creation is used the parameter server is able to validate 
after each epoch
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
                PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec, String valFunc,
-               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels,boolean modelAvg)

Review comment:
       same as above.

##########
File path: src/test/scripts/functions/federated/paramserv/TwoNN.dml
##########
@@ -150,13 +150,15 @@ train_paramserv = function(matrix[double] X, 
matrix[double] y,
   model = list(W1, W2, W3, b1, b2, b3)
   # Create the hyper parameter list
   hyperparams = list(learning_rate=eta)
+
+while (FALSE) {}
   # Use paramserv function
   model = paramserv(model=model, features=X, labels=y, val_features=X_val, 
val_labels=y_val,
     
upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
     
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
     val="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::validate",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, 
hyperparams=hyperparams, seed=seed)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, 
hyperparams=hyperparams, seed=seed,modelAvg=modelAvg)

Review comment:
       formatting.

##########
File path: src/test/scripts/functions/federated/paramserv/TwoNN.dml
##########
@@ -58,7 +58,7 @@ source("nn/optim/sgd.dml") as sgd
 train = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
                  int epochs, int batch_size, double eta,
-                 int seed = -1)
+                 int seed = -1 , boolean modelAvg )

Review comment:
       formatting




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to