atefeh-asayesh commented on a change in pull request #1336:
URL: https://github.com/apache/systemds/pull/1336#discussion_r676039848



##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
                        catch(ExecutionException | InterruptedException ex) {
                                throw new DMLRuntimeException(ex);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
 
+       protected ListObject createLocalModel(ExecutionContext ec, ListObject 
gradients, ListObject model) {
+               // Populate the variables table with the gradients and model
+               ec.setVariable(Statement.PS_GRADIENTS, gradients);
+               ec.setVariable(Statement.PS_MODEL, model);
+
+               // Invoke the aggregate function
+               _inst.processInstruction(ec);
+
+               // Get the new model
+               ListObject newModel = ec.getListObject(_outputName);
+
+               // Clean up the list according to the data referencing status
+               ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, 
newModel.getStatus());
+               ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+               return newModel;
+       }
+
        private ListObject updateModel(ListObject globalParams, ListObject 
gradients, int i, int j, int batchIter) {
                Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
 
                globalParams = _ps.updateLocalModel(_ec, gradients, 
globalParams);
 
                accLocalModelUpdateTime(tUpd);
-               
+
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: local global parameter 
[size:%d kb] updated. "
-                               + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]",
-                               getWorkerName(), globalParams.getDataSize(), i 
+ 1, _epochs, j + 1, batchIter));
+                                                       + "[Epoch:%d  Total 
epoch:%d  Iteration:%d  Total iteration:%d]",
+                                       getWorkerName(), 
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
                }
                return globalParams;
        }
-
        private void computeBatch(long dataSize, int totalIter) {
                for (int i = 0; i < _epochs; i++) {
                        for (int j = 0; j < totalIter; j++) {
                                ListObject globalParams = pullModel();
 
                                ListObject gradients = 
computeGradients(globalParams, dataSize, totalIter, i, j);
-
                                // Push the gradients to ps
                                pushGradients(gradients);
                                ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
-                               
+
                                accNumBatches(1);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
        }
-
        private ListObject pullModel() {
                // Pull the global parameters from ps
                ListObject globalParams = _ps.pull(_workerID);
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully pull the 
global parameters "
-                               + "[size:%d kb] from ps.", getWorkerName(), 
globalParams.getDataSize() / 1024));
+                                       + "[size:%d kb] from ps.", 
getWorkerName(), globalParams.getDataSize() / 1024));
                }
                return globalParams;
        }
-
        private void pushGradients(ListObject gradients) {
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
                if (LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully push the 
gradients "
-                               + "[size:%d kb] to ps.", getWorkerName(), 
gradients.getDataSize() / 1024));
+                                       + "[size:%d kb] to ps.", 
getWorkerName(), gradients.getDataSize() / 1024));
+               }
+       }
+       private void pushModelToServer(ListObject modell) {
+               // Push the Model to ps
+               _ps.push(_workerID, modell);

Review comment:
       Thanks for the comment. This variable has been removed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to