tobiasrieger commented on a change in pull request #1075:
URL: https://github.com/apache/systemds/pull/1075#discussion_r499147919



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -91,16 +95,87 @@ public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String>
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               PSModeType mode = getPSMode();
-               switch (mode) {
-                       case LOCAL:
-                               runLocally(ec, mode);
-                               break;
-                       case REMOTE_SPARK:
-                               runOnSpark((SparkExecutionContext) ec, mode);
-                               break;
-                       default:
-                               throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+               // check if the input is federated
+               if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
+                               ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+                       runFederated(ec);
+               }
+               // if not federated check mode
+               else {
+                       PSModeType mode = getPSMode();
+                       switch (mode) {
+                               case LOCAL:
+                                       runLocally(ec, mode);
+                                       break;
+                               case REMOTE_SPARK:
+                                       runOnSpark((SparkExecutionContext) ec, mode);
+                                       break;
+                               default:
+                                       throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+                       }
+               }
+       }
+
+       private void runFederated(ExecutionContext ec) {
+               System.out.println("PARAMETER SERVER");
+               System.out.println("[+] Running in federated mode");
+
+               // get inputs
+               PSFrequency freq = getFrequency();
+               PSUpdateType updateType = getUpdateType();
+               String updFunc = getParam(PS_UPDATE_FUN);
+               String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+               // partition federated data
+               DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+                               .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
+               List<MatrixObject> pFeatures = result.pFeatures;
+               List<MatrixObject> pLabels = result.pLabels;
+               int workerNum = result.workerNum;
+
+               // setup threading
+               BasicThreadFactory factory = new BasicThreadFactory.Builder()
+                               .namingPattern("workers-pool-thread-%d").build();
+               ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
+
+               // Get the compiled execution context
+               LocalVariableMap newVarsMap = createVarsMap(ec);
+               // Level of par is 1 because one worker will be launched per task
+               // TODO: Fix recompilation
+               ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 1, true);
+               // Create workers' execution context
+               List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);

Review comment:
That is a little more complicated. Each FederatedPSControlThread needs its
own execution context for synchronisation. The FederatedPSControlThread then
sets up everything the federated worker needs in its execution context.

To come back to the question: the only redundancy is that the EC of the
FedPSThread does not need any code other than the gradients and aggregation
functions, but for now nothing is filtered out of it. The code does get
filtered when it is serialised and sent to the federated worker, though.
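
To illustrate the idea, here is a minimal, self-contained sketch. It does not
use the SystemDS classes: `Context`, `copyContexts`, `filtered` and
`runWorkers` are hypothetical stand-ins, and the function names in the list
are made up. It only shows the shape of the approach: every control thread
gets its own copy of the context, the copy still carries all functions, and
the filtering down to the needed functions happens at the point where the
context would be handed to the remote worker.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class PerWorkerContextSketch {
    // Hypothetical stand-in for an execution context:
    // a variable map plus the names of the functions it carries.
    static final class Context {
        final Map<String, Object> vars;
        final List<String> functions;

        Context(Map<String, Object> vars, List<String> functions) {
            // Defensive copies, so no two contexts share mutable state.
            this.vars = new HashMap<>(vars);
            this.functions = new ArrayList<>(functions);
        }

        Context copy() {
            return new Context(vars, functions);
        }

        // Keep only the functions the remote worker needs (assumed to
        // happen at "serialisation" time in this sketch).
        Context filtered(List<String> needed) {
            List<String> kept = new ArrayList<>(functions);
            kept.retainAll(needed);
            return new Context(vars, kept);
        }
    }

    // One unfiltered copy per control thread.
    static List<Context> copyContexts(Context base, int workerNum) {
        List<Context> copies = new ArrayList<>();
        for (int i = 0; i < workerNum; i++)
            copies.add(base.copy());
        return copies;
    }

    // Runs one task per worker; each task touches only its own context and
    // returns how many functions would be sent to the remote worker.
    static List<Integer> runWorkers(Context base, int workerNum, List<String> needed)
            throws Exception {
        ExecutorService es = Executors.newFixedThreadPool(workerNum);
        try {
            List<Context> contexts = copyContexts(base, workerNum);
            List<Future<Integer>> futures = new ArrayList<>();
            for (int i = 0; i < workerNum; i++) {
                final int id = i;
                final Context own = contexts.get(i);
                futures.add(es.submit(() -> {
                    own.vars.put("workerId", id); // private to this thread
                    return own.filtered(needed).functions.size();
                }));
            }
            List<Integer> sizes = new ArrayList<>();
            for (Future<Integer> f : futures)
                sizes.add(f.get());
            return sizes;
        } finally {
            es.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        Map<String, Object> vars = new HashMap<>();
        vars.put("lr", 0.01);
        Context base = new Context(vars,
            List.of("gradients", "aggregation", "train", "validate"));
        System.out.println(runWorkers(base, 3, List.of("gradients", "aggregation")));
    }
}
```

In this sketch the per-thread copies are larger than strictly necessary, which
is the redundancy mentioned above, but the payload handed to each worker only
contains the two needed functions.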



