tobiasrieger commented on a change in pull request #1075:
URL: https://github.com/apache/systemds/pull/1075#discussion_r499147919
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -91,16 +95,87 @@ public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String>
@Override
public void processInstruction(ExecutionContext ec) {
- PSModeType mode = getPSMode();
- switch (mode) {
- case LOCAL:
- runLocally(ec, mode);
- break;
- case REMOTE_SPARK:
- runOnSpark((SparkExecutionContext) ec, mode);
- break;
- default:
- throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+ // check if the input is federated
+ if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
+ ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+ runFederated(ec);
+ }
+ // if not federated check mode
+ else {
+ PSModeType mode = getPSMode();
+ switch (mode) {
+ case LOCAL:
+ runLocally(ec, mode);
+ break;
+ case REMOTE_SPARK:
+ runOnSpark((SparkExecutionContext) ec, mode);
+ break;
+ default:
+ throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+ }
+ }
+ }
+
+ private void runFederated(ExecutionContext ec) {
+ System.out.println("PARAMETER SERVER");
+ System.out.println("[+] Running in federated mode");
+
+ // get inputs
+ PSFrequency freq = getFrequency();
+ PSUpdateType updateType = getUpdateType();
+ String updFunc = getParam(PS_UPDATE_FUN);
+ String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+ // partition federated data
+ DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+ .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
+ List<MatrixObject> pFeatures = result.pFeatures;
+ List<MatrixObject> pLabels = result.pLabels;
+ int workerNum = result.workerNum;
+
+ // setup threading
+ BasicThreadFactory factory = new BasicThreadFactory.Builder()
+ .namingPattern("workers-pool-thread-%d").build();
+ ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
+
+ // Get the compiled execution context
+ LocalVariableMap newVarsMap = createVarsMap(ec);
+ // Level of par is 1 because one worker will be launched per task
+ // TODO: Fix recompilation
+ ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 1, true);
+ // Create workers' execution context
+ List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
Review comment:
It is a little more complicated than that. Each FederatedPSControlThread
needs its own execution context for synchronisation, and it then sets up
everything the federated worker needs in that execution context.
To come back to the question: the only redundancy is that the EC of the
FederatedPSControlThread needs no code other than the gradients and
aggregation functions, but for now nothing filters the rest out. The code
does get filtered when it is serialised and sent to the federated worker,
though.
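
To illustrate why each control thread gets its own context, here is a minimal
sketch in plain Java. It is not SystemDS code: `Context`, `ControlThread` and
`runAll` are hypothetical stand-ins for ExecutionContext,
FederatedPSControlThread and the setup in runFederated, and the "context" is
reduced to a variable map. It only shows the pattern of copying a template
context once per worker so that threads never write to shared state.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class PerThreadContextSketch {

    // Hypothetical stand-in for an execution context: just a variable map.
    static final class Context {
        final Map<String, Object> vars;

        Context(Map<String, Object> vars) {
            // Defensive copy, so two contexts never share the same map.
            this.vars = new HashMap<>(vars);
        }

        Context copy() {
            return new Context(vars);
        }
    }

    // Hypothetical stand-in for a control thread that owns one context.
    static final class ControlThread implements Callable<String> {
        private final int id;
        private final Context ctx;

        ControlThread(int id, Context ctx) {
            this.id = id;
            this.ctx = ctx;
        }

        @Override
        public String call() {
            // Worker-specific setup only touches this thread's private copy.
            ctx.vars.put("workerId", id);
            return "worker-" + ctx.vars.get("workerId");
        }
    }

    // Copies the template once per worker and runs one control thread each.
    static List<String> runAll(Context template, int workerNum) {
        ExecutorService es = Executors.newFixedThreadPool(workerNum);
        try {
            List<Future<String>> futures = new ArrayList<>();
            for (int i = 0; i < workerNum; i++) {
                futures.add(es.submit(new ControlThread(i, template.copy())));
            }
            // Collect results in submission order.
            List<String> results = new ArrayList<>();
            for (Future<String> f : futures) {
                results.add(f.get());
            }
            return results;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("worker failed", e);
        } finally {
            es.shutdown();
        }
    }

    public static void main(String[] args) {
        Map<String, Object> shared = new HashMap<>();
        shared.put("model", "initial");
        Context template = new Context(shared);

        System.out.println(runAll(template, 3));
        // The template is untouched because every worker wrote to its own copy.
        System.out.println(template.vars.containsKey("workerId"));
    }
}
```

The copies here are full copies of the template, which mirrors the redundancy
mentioned above: nothing is filtered at copy time in this sketch either.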