This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 10b493c [SYSTEMDS-2634] Reduced number of RPC calls in federated
backend
10b493c is described below
commit 10b493cb71c72a1c6f65470166a8ab4842c239a4
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Aug 23 16:57:24 2020 +0200
[SYSTEMDS-2634] Reduced number of RPC calls in federated backend
This patch improves the performance of the federated runtime backend by
merging the execution and cleanup RPC request batches into a single
batch of requests. Since every batch returns only a single response, we
now carefully select which GET_VAR, error, or other response to
return. Overall, this reduced the number of RPC calls by almost 2x and
removed unnecessary synchronization barriers.
---
.../controlprogram/caching/CacheableData.java | 2 +-
.../federated/FederatedWorkerHandler.java | 22 ++++++++++++++++++----
.../controlprogram/federated/FederationMap.java | 9 ++++++++-
.../controlprogram/federated/FederationUtils.java | 1 -
.../fed/AggregateBinaryFEDInstruction.java | 16 ++++++++--------
.../fed/AggregateUnaryFEDInstruction.java | 6 +++---
.../fed/BinaryMatrixMatrixFEDInstruction.java | 19 ++++++++-----------
.../fed/BinaryMatrixScalarFEDInstruction.java | 13 ++++++++-----
.../instructions/fed/MMChainFEDInstruction.java | 10 ++++++----
.../instructions/fed/TsmmFEDInstruction.java | 4 ++--
10 files changed, 62 insertions(+), 40 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 720534a..4d0d5d9 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -680,7 +680,7 @@ public abstract class CacheableData<T extends CacheBlock>
extends Data
//clear federated matrix
if( _fedMapping != null )
- _fedMapping.cleanup(tid, _fedMapping.getID());
+ _fedMapping.execCleanup(tid, _fedMapping.getID());
// change object state EMPTY
setDirty(false);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index f4af303..0dcb846 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -91,10 +91,24 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
PrivacyMonitor.clearCheckedConstraints();
- response = executeCommand(request);
- conditionalAddCheckedConstraints(request, response);
- if (!response.isSuccessful()){
- log.error("Command " + request.getType() + "
failed: " + response.getErrorMessage() + "full command: \n" +
request.toString());
+ //execute command and handle privacy constraints
+ FederatedResponse tmp = executeCommand(request);
+ conditionalAddCheckedConstraints(request, tmp);
+
+ //select the response for the entire batch of requests
+ if (!tmp.isSuccessful()) {
+ log.error("Command " + request.getType() + "
failed: "
+ + tmp.getErrorMessage() + "full
command: \n" + request.toString());
+ response = (response == null ||
response.isSuccessful())
+ ? tmp : response; //return first error
+ }
+ else if( request.getType() == RequestType.GET_VAR ) {
+ if( response != null && response.isSuccessful()
)
+ log.error("Multiple GET_VAR are not
supported in single batch of requests.");
+ response = tmp; //return last get result
+ }
+ else if( response == null && i == requests.length-1 ) {
+ response = tmp; //return last
}
}
ctx.writeAndFlush(response).addListener(new CloseListener());
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index b272bf9..72d1196 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -170,7 +170,14 @@ public class FederationMap
return readResponses;
}
- public void cleanup(long tid, long... id) {
+ public FederatedRequest cleanup(long tid, long... id) {
+ FederatedRequest request = new
FederatedRequest(RequestType.EXEC_INST, -1,
+
VariableCPInstruction.prepareRemoveInstruction(id).toString());
+ request.setTID(tid);
+ return request;
+ }
+
+ public void execCleanup(long tid, long... id) {
FederatedRequest request = new
FederatedRequest(RequestType.EXEC_INST, -1,
VariableCPInstruction.prepareRemoveInstruction(id).toString());
request.setTID(tid);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index faae560..7df7c51 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -32,7 +32,6 @@ import
org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 34caec2..c28a163 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -68,10 +68,10 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
new CPOperand[]{input1, input2},
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()});
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 =
mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2);
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo2.getFedMapping().cleanup(getTID(), fr1.getID(),
fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else if(mo1.isFederated(FType.ROW)) { // MV + MM
@@ -81,16 +81,16 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
if( mo2.getNumColumns() == 1 ) { //MV
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.rbind(tmp);
- mo1.getFedMapping().cleanup(getTID(),
fr1.getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //MM
//execute federated operations and aggregate
- mo1.getFedMapping().execute(getTID(), true,
fr1, fr2);
- mo1.getFedMapping().cleanup(getTID(),
fr1.getID());
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(),
(int)mo1.getBlocksize());
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(),
mo2.getNumColumns()));
@@ -104,10 +104,10 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new
long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 =
mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp =
mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(),
fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //other combinations
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index e87bf57..60fe40b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -55,19 +55,19 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
public void processInstruction(ExecutionContext ec) {
AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
MatrixObject in = ec.getMatrixObject(input1);
+ FederationMap map = in.getFedMapping();
//create federated commands for aggregation
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()});
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
//execute federated commands and cleanups
- FederationMap map = in.getFedMapping();
- Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1,
fr2);
+ Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1,
fr2, fr3);
if( output.isScalar() )
ec.setVariable(output.getName(),
FederationUtils.aggScalar(aop, tmp));
else
ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aop, tmp, map));
- map.cleanup(getTID(), fr1.getID());
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 63c2d71..bceb6ae 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -42,39 +42,36 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
FederatedRequest fr2 = null;
if( mo2.isFederated() ) {
- if(mo1.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)){
+ if(mo1.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()});
mo1.getFedMapping().execute(getTID(), true,
fr2);
-
- } else{
+ }
+ else {
throw new DMLRuntimeException("Matrix-matrix
binary operations "
+ " with a federated right input are
not supported yet.");
}
-
- }
+ }
else {
//matrix-matrix binary oFederatedRequest fr2 =
null;perations -> lhs fed input -> fed output
-
if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) {
//MV row vector
FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
//execute federated instruction and cleanup
intermediates
- mo1.getFedMapping().execute(getTID(), true,
fr1, fr2);
- mo1.getFedMapping().cleanup(getTID(),
fr1[0].getID());
+ mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
}
else { //MM or MV col vector
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated instruction and cleanup
intermediates
- mo1.getFedMapping().execute(getTID(), true,
fr1, fr2);
- mo1.getFedMapping().cleanup(getTID(),
fr1.getID());
+ mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
}
}
-
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getDataCharacteristics());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 75bfe33..b6ea1fb 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -39,17 +39,20 @@ public class BinaryMatrixScalarFEDInstruction extends
BinaryFEDInstruction
CPOperand scalar = input2.isScalar() ? input2 : input1;
MatrixObject mo = ec.getMatrixObject(matrix);
- //execute federated matrix-scalar operation and cleanups
+ //prepare federated request matrix-scalar
FederatedRequest fr1 = !scalar.isLiteral() ?
mo.getFedMapping().broadcast(ec.getScalarInput(scalar))
: null;
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{matrix, (fr1 != null)?scalar:null},
new long[]{mo.getFedMapping().getID(), (fr1 !=
null)?fr1.getID():-1});
- mo.getFedMapping().execute(getTID(), true, (fr1!=null) ?
- new FederatedRequest[]{fr1, fr2}: new
FederatedRequest[]{fr2});
- if( fr1 != null )
- mo.getFedMapping().cleanup(getTID(), fr1.getID());
+ //execute federated matrix-scalar operation and cleanups
+ if( fr1 != null ) {
+ FederatedRequest fr3 =
mo.getFedMapping().cleanup(getTID(), fr1.getID());
+ mo.getFedMapping().execute(getTID(), true, fr1, fr2,
fr3);
+ }
+ else
+ mo.getFedMapping().execute(getTID(), true, fr2);
//derive new fed mapping for output
MatrixObject out = ec.getMatrixObject(output);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 2dee64b..99a305b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -86,11 +86,12 @@ public class MMChainFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1, input2}, new
long[]{mo1.getFedMapping().getID(), fr1.getID()});
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping()
+ .cleanup(getTID(), fr1.getID(), fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID(),
fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //XtwXv | XtXvy
@@ -101,11 +102,12 @@ public class MMChainFEDInstruction extends
UnaryFEDInstruction {
new CPOperand[]{input1, input2, input3},
new long[]{mo1.getFedMapping().getID(),
fr1.getID(), fr0[0].getID()});
FederatedRequest fr3 = new
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ FederatedRequest fr4 = mo1.getFedMapping()
+ .cleanup(getTID(), fr0[0].getID(), fr1.getID(),
fr2.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3, fr4);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(),
fr1.getID(), fr2.getID());
ec.setMatrixOutput(output.getName(), ret);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 292bced..fbe88d6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -67,11 +67,11 @@ public class TsmmFEDInstruction extends
BinaryFEDInstruction {
FederatedRequest fr1 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new
long[]{mo1.getFedMapping().getID()});
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated operations and aggregate
- Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2);
+ Future<FederatedResponse>[] tmp =
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
MatrixBlock ret = FederationUtils.aggAdd(tmp);
- mo1.getFedMapping().cleanup(getTID(), fr1.getID());
ec.setMatrixOutput(output.getName(), ret);
}
else { //other combinations