This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5fd67d2 [SYSTEMDS-2604] Federated compile and federated output
propagation. This commit adds compilation of federated instructions when
OptimizerUtils.FEDERATED_COMPILATION is set to true. Additionally, a federated
output flag is added and propagated through HOPs, LOPs, and instructions.
Closes #1199.
5fd67d2 is described below
commit 5fd67d24177f2d78d7d704a511ee3bd9c92d0942
Author: sebwrede <[email protected]>
AuthorDate: Tue Mar 2 11:58:37 2021 +0100
[SYSTEMDS-2604] Federated compile and federated output propagation
This commit adds compilation of federated instructions when
OptimizerUtils.FEDERATED_COMPILATION is set to true.
Additionally, a federated output flag is added and propagated through HOPs,
LOPs, and instructions.
Closes #1199.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 11 ++--
.../java/org/apache/sysds/hops/AggUnaryOp.java | 9 ++-
src/main/java/org/apache/sysds/hops/BinaryOp.java | 12 ++--
src/main/java/org/apache/sysds/hops/DataGenOp.java | 2 +-
src/main/java/org/apache/sysds/hops/DataOp.java | 13 ++++-
src/main/java/org/apache/sysds/hops/DnnOp.java | 2 +-
.../java/org/apache/sysds/hops/FunctionOp.java | 2 +-
src/main/java/org/apache/sysds/hops/Hop.java | 49 +++++++++++++++++
.../java/org/apache/sysds/hops/IndexingOp.java | 2 +-
.../java/org/apache/sysds/hops/LeftIndexingOp.java | 2 +-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 5 ++
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 2 +-
src/main/java/org/apache/sysds/hops/ReorgOp.java | 8 ++-
src/main/java/org/apache/sysds/hops/TernaryOp.java | 4 +-
src/main/java/org/apache/sysds/hops/UnaryOp.java | 5 +-
src/main/java/org/apache/sysds/lops/Binary.java | 27 ++++-----
src/main/java/org/apache/sysds/lops/Lop.java | 21 ++++++-
.../java/org/apache/sysds/lops/LopProperties.java | 4 ++
.../org/apache/sysds/lops/PartialAggregate.java | 2 +-
src/main/java/org/apache/sysds/lops/Transform.java | 6 +-
.../controlprogram/federated/FederationUtils.java | 8 ++-
.../runtime/instructions/FEDInstructionParser.java | 36 +++++++++++-
.../runtime/instructions/InstructionUtils.java | 63 ++++++++++++++++++++-
.../instructions/fed/BinaryFEDInstruction.java | 18 +++---
.../fed/BinaryMatrixMatrixFEDInstruction.java | 27 +++++++--
.../fed/BinaryMatrixScalarFEDInstruction.java | 6 +-
.../fed/ComputationFEDInstruction.java | 20 ++++---
.../runtime/instructions/fed/FEDInstruction.java | 6 ++
.../fed/QuantilePickFEDInstruction.java | 9 ++-
.../instructions/fed/ReorgFEDInstruction.java | 11 +++-
.../instructions/fed/TsmmFEDInstruction.java | 12 ++--
.../instructions/fed/UnaryFEDInstruction.java | 17 +++++-
.../privacy/propagation/PrivacyPropagator.java | 2 +-
.../primitives/FederatedBinaryMatrixTest.java | 14 ++++-
.../primitives/FederatedBinaryVectorTest.java | 12 +++-
.../primitives/FederatedMultiplyTest.java | 13 ++++-
.../federated/primitives/FederatedSumTest.java | 14 ++++-
.../privacy/FederatedWorkerHandlerTest.java | 2 +-
.../privacy/algorithms/FederatedL2SVMTest.java | 24 +++-----
.../FederatedMultiplyPlanningTest.java} | 64 +++++++++++-----------
.../privacy/FederatedMultiplyPlanningTest.dml | 28 ++++++++++
.../FederatedMultiplyPlanningTestReference.dml | 26 +++++++++
42 files changed, 488 insertions(+), 132 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index c279071..40fbb0e 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -102,6 +102,7 @@ public class AggBinaryOp extends MultiThreadedHop
outerOp = outOp;
getInput().add(0, in1);
getInput().add(1, in2);
+ updateETFed();
in1.getParent().add(this);
in2.getParent().add(this);
@@ -177,8 +178,8 @@ public class AggBinaryOp extends MultiThreadedHop
//matrix mult operation selection part 2 (specific
pattern)
MMTSJType mmtsj = checkTransposeSelf(); //determine
tsmm pattern
ChainType chain = checkMapMultChain(); //determine
mmchain pattern
-
- if( et == ExecType.CP || et == ExecType.GPU )
+
+ if( et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED )
{
//matrix mult operation selection part 3 (CP
type)
_method = optFindMMultMethodCP (
input1.getDim1(), input1.getDim2(),
@@ -251,7 +252,7 @@ public class AggBinaryOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
@@ -425,7 +426,9 @@ public class AggBinaryOp extends MultiThreadedHop
//pull binary aggregate into spark
_etype = ExecType.SPARK;
}
-
+
+ updateETFed();
+
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 4e87baa..f842502 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -123,7 +123,7 @@ public class AggUnaryOp extends MultiThreadedHop
ExecType et = optFindExecType();
Hop input = getInput().get(0);
- if ( et == ExecType.CP || et == ExecType.GPU )
+ if ( et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED )
{
Lop agg1 = null;
if( isTernaryAggregateRewriteApplicable() ) {
@@ -209,6 +209,7 @@ public class AggUnaryOp extends MultiThreadedHop
}
}
}
+ else throw new HopsException("ExecType " + et + " not
recognized in " + this.toString() );
}
catch (Exception e) {
throw new HopsException(this.printErrorLocation() + "In
AggUnary Hop, error constructing Lops " , e);
@@ -216,7 +217,7 @@ public class AggUnaryOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
//return created lops
return getLops();
}
@@ -381,7 +382,9 @@ public class AggUnaryOp extends MultiThreadedHop
//pull unary aggregate into spark
_etype = ExecType.SPARK;
}
-
+
+ updateETFed();
+
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 10e1c8d..36cb051 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -97,6 +97,7 @@ public class BinaryOp extends MultiThreadedHop
op = o;
getInput().add(0, inp1);
getInput().add(1, inp2);
+ updateETFed();
inp1.getParent().add(this);
inp2.getParent().add(this);
@@ -225,7 +226,7 @@ public class BinaryOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
@@ -442,7 +443,7 @@ public class BinaryOp extends MultiThreadedHop
setLineNumbers(softmax);
setLops(softmax);
}
- else if ( et == ExecType.CP || et == ExecType.GPU )
+ else if ( et == ExecType.CP || et == ExecType.GPU || et
== ExecType.FED )
{
Lop binary = null;
@@ -462,7 +463,8 @@ public class BinaryOp extends MultiThreadedHop
binary = new
Binary(getInput(0).constructLops(), getInput(1).constructLops(),
op, getDataType(),
getValueType(), et,
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
-
+
+ setFederatedOutput(binary);
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
@@ -496,7 +498,7 @@ public class BinaryOp extends MultiThreadedHop
setOutputDimensions(binary);
setLineNumbers(binary);
setLops(binary);
- }
+ } else throw new HopsException("Lop construction not
implemented for ExecType " + et);
}
}
@@ -740,6 +742,8 @@ public class BinaryOp extends MultiThreadedHop
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
+
+ updateETFed();
//spark-specific decision refinement (execute unary scalar w/
spark input and
//single parent also in spark because it's likely cheap and
reduces intermediates)
diff --git a/src/main/java/org/apache/sysds/hops/DataGenOp.java
b/src/main/java/org/apache/sysds/hops/DataGenOp.java
index 8fdf98d..32b17b9 100644
--- a/src/main/java/org/apache/sysds/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataGenOp.java
@@ -203,7 +203,7 @@ public class DataGenOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java
b/src/main/java/org/apache/sysds/hops/DataOp.java
index e3467c5..28d458d 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -322,7 +322,7 @@ public class DataOp extends Hop {
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
@@ -487,7 +487,16 @@ public class DataOp extends Hop {
return _etype;
}
-
+
+ /**
+ * True if execution is federated, if output is federated, or if
OpOpData is federated.
+ * @return true if federated
+ */
+ @Override
+ public boolean isFederated() {
+ return super.isFederated() || getOp() == OpOpData.FEDERATED;
+ }
+
@Override
public void refreshSizeInformation() {
if( _op == OpOpData.PERSISTENTWRITE || _op ==
OpOpData.TRANSIENTWRITE ) {
diff --git a/src/main/java/org/apache/sysds/hops/DnnOp.java
b/src/main/java/org/apache/sysds/hops/DnnOp.java
index 54978f1..2a5895c 100644
--- a/src/main/java/org/apache/sysds/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysds/hops/DnnOp.java
@@ -155,7 +155,7 @@ public class DnnOp extends MultiThreadedHop {
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java
b/src/main/java/org/apache/sysds/hops/FunctionOp.java
index 1b6c2fc..be6d851 100644
--- a/src/main/java/org/apache/sysds/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java
@@ -300,7 +300,7 @@ public class FunctionOp extends Hop
setLops(fcall);
//note: no reblock lop because outputs directly bound
-
+
return getLops();
}
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 5be1a63..dcd258e 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -81,6 +81,13 @@ public abstract class Hop implements ParseInfo {
protected ExecType _etype = null; //currently used exec type
protected ExecType _etypeForced = null; //exec type forced via platform
or external optimizer
+
+ /**
+ * Boolean defining if the output of the operation should be federated.
+ * If it is true, the output should be kept at federated sites.
+ * If it is false, the output should be retrieved by the coordinator.
+ */
+ protected boolean _federatedOutput = false;
// Estimated size for the output produced from this Hop
protected double _outputMemEstimate = OptimizerUtils.INVALID_SIZE;
@@ -734,6 +741,40 @@ public abstract class Hop implements ParseInfo {
return et;
}
+ /**
+ * Set the execution type to FED if any input is federated and federated compilation is activated.
+ * Federated compilation is activated via OptimizerUtils.FEDERATED_COMPILATION.
+ */
+ protected void updateETFed(){
+ if ( inputIsFED() )
+ _etype = ExecType.FED;
+ }
+
+ /**
+ * Returns true if any input is federated (see isFederated()) and configures each such input to keep its output federated.
+ * This method can only return true if OptimizerUtils.FEDERATED_COMPILATION is activated.
+ * @return true if any input is federated
+ */
+ protected boolean inputIsFED(){
+ if ( !OptimizerUtils.FEDERATED_COMPILATION ) return false;
+ boolean fedFound = false;
+ for ( Hop input : _input ){
+ if ( input.isFederated() ){
+ input._federatedOutput = true;
+ fedFound = true;
+ }
+ }
+ return fedFound;
+ }
+
+ /**
+ * Returns true if the execution is federated and/or if the output is
federated.
+ * @return true if federated
+ */
+ public boolean isFederated(){
+ return getExecType() == ExecType.FED || hasFederatedOutput();
+ }
+
public ArrayList<Hop> getParent() {
return _parent;
}
@@ -780,6 +821,10 @@ public abstract class Hop implements ParseInfo {
return _privacyConstraint;
}
+ public boolean hasFederatedOutput(){
+ return _federatedOutput;
+ }
+
public void setUpdateType(UpdateType update){
_updateType = update;
}
@@ -1413,6 +1458,10 @@ public abstract class Hop implements ParseInfo {
lop.setPrivacyConstraint(getPrivacy());
}
+ protected void setFederatedOutput(Lop lop){
+ lop.setFederatedOutput(_federatedOutput);
+ }
+
/**
* Set parse information.
*
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 55870fd..7de2d45 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -169,7 +169,7 @@ public class IndexingOp extends Hop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
diff --git a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
index 7aa26bd..d465cad 100644
--- a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java
@@ -157,7 +157,7 @@ public class LeftIndexingOp extends Hop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 55a9617..edf8dfc 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -194,6 +194,11 @@ public class OptimizerUtils
* out of while, for, and parfor loops.
*/
public static boolean ALLOW_CODE_MOTION = false;
+
+ /**
+ * Compile federated instructions based on input federation state and
privacy constraints.
+ */
+ public static boolean FEDERATED_COMPILATION = false;
/**
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index cc51375..ee54561 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -211,7 +211,7 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java
b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index fbdb5e9..badb057 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -61,6 +61,7 @@ public class ReorgOp extends MultiThreadedHop
_op = o;
getInput().add(0, inp);
inp.getParent().add(this);
+ updateETFed();
//compute unknown dims and nnz
refreshSizeInformation();
@@ -76,6 +77,8 @@ public class ReorgOp extends MultiThreadedHop
getInput().add(i, in);
in.getParent().add(this);
}
+
+ updateETFed();
//compute unknown dims and nnz
refreshSizeInformation();
@@ -159,6 +162,7 @@ public class ReorgOp extends MultiThreadedHop
else { //general case
int k =
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
Transform transform1 = new
Transform(lin, _op, getDataType(), getValueType(), et, k);
+ setFederatedOutput(transform1);
setOutputDimensions(transform1);
setLineNumbers(transform1);
setLops(transform1);
@@ -220,7 +224,7 @@ public class ReorgOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
@@ -365,6 +369,8 @@ public class ReorgOp extends MultiThreadedHop
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
+
+ updateETFed();
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java
b/src/main/java/org/apache/sysds/hops/TernaryOp.java
index 47e42bb..6f5a55b 100644
--- a/src/main/java/org/apache/sysds/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java
@@ -196,7 +196,7 @@ public class TernaryOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
@@ -500,6 +500,8 @@ public class TernaryOp extends MultiThreadedHop
checkAndSetInvalidCPDimsAndSize();
}
+ updateETFed();
+
//mark for recompile (forever)
// additional condition: when execType=CP and additional
dimension inputs
// are provided (and those values are unknown at initial
compile time).
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index d52d034..90e472c 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -182,7 +182,7 @@ public class UnaryOp extends MultiThreadedHop
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
-
+
return getLops();
}
@@ -512,6 +512,9 @@ public class UnaryOp extends MultiThreadedHop
|| getInput().get(0).getDataType() == DataType.LIST ||
isMetadataOperation() )
{
_etype = ExecType.CP;
+ } else {
+ updateETFed();
+ setRequiresRecompileIfNecessary();
}
return _etype;
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java
b/src/main/java/org/apache/sysds/lops/Binary.java
index 5fba53d..84bd033 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -81,20 +81,17 @@ public class Binary extends Lop
@Override
public String getInstructions(String input1, String input2, String
output) {
- if( getExecType() == ExecType.CP ) {
- return InstructionUtils.concatOperands(
- getExecType().name(), getOpcode(),
- getInputs().get(0).prepInputOperand(input1),
- getInputs().get(1).prepInputOperand(input2),
- prepOutputOperand(output),
- String.valueOf(_numThreads));
- }
- else {
- return InstructionUtils.concatOperands(
- getExecType().name(), getOpcode(),
- getInputs().get(0).prepInputOperand(input1),
- getInputs().get(1).prepInputOperand(input2),
- prepOutputOperand(output));
- }
+ String baseInstruction = InstructionUtils.concatOperands(
+ getExecType().name(), getOpcode(),
+ getInputs().get(0).prepInputOperand(input1),
+ getInputs().get(1).prepInputOperand(input2),
+ prepOutputOperand(output)
+ );
+
+ if( getExecType() == ExecType.CP || (!federatedOutput &&
getExecType() == ExecType.FED) )
+ return InstructionUtils.concatOperands(baseInstruction,
String.valueOf(_numThreads));
+ else if ( getExecType() == ExecType.FED )
+ return InstructionUtils.concatOperands(baseInstruction,
String.valueOf(_numThreads), String.valueOf(federatedOutput));
+ else return baseInstruction;
}
}
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java
b/src/main/java/org/apache/sysds/lops/Lop.java
index d81b6a8..a92609c 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -113,6 +113,13 @@ public abstract class Lop
* Privacy Constraint
*/
protected PrivacyConstraint privacyConstraint;
+
+ /**
+ * Boolean defining if the output of the operation should be federated.
+ * If it is true, the output should be kept at federated sites.
+ * If it is false, the output should be retrieved by the coordinator.
+ */
+ protected boolean federatedOutput = false;
/**
* refers to #lops whose input is equal to the output produced by this
lop.
@@ -286,6 +293,10 @@ public abstract class Lop
public PrivacyConstraint getPrivacyConstraint(){
return privacyConstraint;
}
+
+ public void setFederatedOutput(boolean federatedOutput){
+ this.federatedOutput = federatedOutput;
+ }
public void setConsumerCount(int cc) {
consumerCount = cc;
@@ -342,13 +353,21 @@ public abstract class Lop
}
/**
- * Method to get the execution type (CP, CP_FILE, MR, SPARK, GPU,
INVALID) of LOP
+ * Method to get the execution type (CP, CP_FILE, MR, SPARK, GPU, FED,
INVALID) of LOP
*
* @return execution type
*/
public ExecType getExecType() {
return lps.getExecType();
}
+
+ /**
+ * Set the execution type of LOP.
+ * @param newExecType new execution type
+ */
+ public void setExecType(ExecType newExecType){
+ lps.setExecType(newExecType);
+ }
public boolean getProducesIntermediateOutput() {
return lps.getProducesIntermediateOutput();
diff --git a/src/main/java/org/apache/sysds/lops/LopProperties.java
b/src/main/java/org/apache/sysds/lops/LopProperties.java
index efcd40e..c33a4c0 100644
--- a/src/main/java/org/apache/sysds/lops/LopProperties.java
+++ b/src/main/java/org/apache/sysds/lops/LopProperties.java
@@ -61,6 +61,10 @@ public class LopProperties
public ExecType getExecType() {
return execType;
}
+
+ public void setExecType(ExecType newExecType){
+ execType = newExecType;
+ }
public boolean getProducesIntermediateOutput() {
return producesIntermediateOutput;
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index c28a9d5..c291782 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -233,7 +233,7 @@ public class PartialAggregate extends Lop
sb.append( OPERAND_DELIMITOR );
if( getExecType() == ExecType.SPARK )
sb.append( _aggtype );
- else if( getExecType() == ExecType.CP ) {
+ else if( getExecType() == ExecType.CP || getExecType() ==
ExecType.FED ) {
sb.append(_numThreads);
//number of outputs, valid for fed instruction
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java
b/src/main/java/org/apache/sysds/lops/Transform.java
index 544ea50..2c1df26 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -168,10 +168,14 @@ public class Transform extends Lop
sb.append( OPERAND_DELIMITOR );
sb.append( this.prepOutputOperand(output));
- if( getExecType()==ExecType.CP
+ if( (getExecType()==ExecType.CP || getExecType()==ExecType.FED)
&& (_operation == ReOrgOp.TRANS || _operation ==
ReOrgOp.SORT) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
+ if ( federatedOutput ){
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( federatedOutput );
+ }
}
if( getExecType()==ExecType.SPARK && _operation ==
ReOrgOp.RESHAPE ) {
sb.append( OPERAND_DELIMITOR );
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index dc2fae5..f569364 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -65,8 +65,14 @@ public class FederationUtils {
return _idSeq.getNextID();
}
+ public static FederatedRequest callInstruction(String inst, CPOperand
varOldOut, CPOperand[] varOldIn, long[] varNewIn, boolean federatedOutput){
+ long id = getNextFedDataID();
+ String linst =
InstructionUtils.instructionStringFEDPrepare(inst, varOldOut, id, varOldIn,
varNewIn, federatedOutput);
+ return new FederatedRequest(RequestType.EXEC_INST, id, linst);
+ }
+
public static FederatedRequest callInstruction(String inst, CPOperand
varOldOut, CPOperand[] varOldIn, long[] varNewIn) {
- return callInstruction(inst, varOldOut, getNextFedDataID(),
varOldIn, varNewIn);
+ return callInstruction(inst,varOldOut, varOldIn, varNewIn,
false);
}
public static FederatedRequest[] callInstruction(String[] inst,
CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 2abeaeb..e6f430d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -21,9 +21,13 @@ package org.apache.sysds.runtime.instructions;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction;
+import org.apache.sysds.runtime.instructions.fed.TsmmFEDInstruction;
import java.util.HashMap;
@@ -32,8 +36,28 @@ public class FEDInstructionParser extends InstructionParser
public static final HashMap<String, FEDType> String2FEDInstructionType;
static {
String2FEDInstructionType = new HashMap<>();
- String2FEDInstructionType.put("fedinit", FEDType.Init);
- String2FEDInstructionType.put("ba+*",
FEDType.AggregateBinary);
+ String2FEDInstructionType.put( "fedinit" , FEDType.Init );
+ String2FEDInstructionType.put( "tsmm" , FEDType.Tsmm );
+ String2FEDInstructionType.put( "ba+*" ,
FEDType.AggregateBinary );
+
+ String2FEDInstructionType.put( "uak+" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uark+" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uack+" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uasqk+" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uarsqk+" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uacsqk+" ,
FEDType.AggregateUnary );
+
+ // Arithmetic Instruction Opcodes
+ String2FEDInstructionType.put( "+" , FEDType.Binary );
+ String2FEDInstructionType.put( "-" , FEDType.Binary );
+ String2FEDInstructionType.put( "*" , FEDType.Binary );
+ String2FEDInstructionType.put( "/" , FEDType.Binary );
+
+ // Reorg Instruction Opcodes (repositioning of existing values)
+ String2FEDInstructionType.put( "r'" , FEDType.Reorg );
+ String2FEDInstructionType.put( "rdiag" , FEDType.Reorg );
+ String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
+
}
public static FEDInstruction parseSingleInstruction (String str ) {
@@ -56,6 +80,14 @@ public class FEDInstructionParser extends InstructionParser
return InitFEDInstruction.parseInstruction(str);
case AggregateBinary:
return
AggregateBinaryFEDInstruction.parseInstruction(str);
+ case AggregateUnary:
+ return
AggregateUnaryFEDInstruction.parseInstruction(str);
+ case Tsmm:
+ return TsmmFEDInstruction.parseInstruction(str);
+ case Binary:
+ return
BinaryFEDInstruction.parseInstruction(str);
+ case Reorg:
+ return
ReorgFEDInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid
FEDERATED Instruction Type: " + fedtype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 77b53f9..fb60d6b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -146,7 +146,10 @@ public class InstructionUtils
public static int checkNumFields( String[] parts, int... expected ) {
int numParts = parts.length;
int numFields = numParts - 1; //account for opcode
-
+ return checkMatchingNumField(numFields, expected);
+ }
+
+ private static int checkMatchingNumField(int numFields, int...
expected){
if (Arrays.stream(expected).noneMatch((i) -> numFields == i)) {
StringBuilder sb = new StringBuilder();
sb.append("checkNumFields() -- expected number (");
@@ -158,10 +161,15 @@ public class InstructionUtils
sb.append(") != is not equal to actual number
(").append(numFields).append(").");
throw new DMLRuntimeException(sb.toString());
}
-
return numFields;
}
+ public static int checkNumFields( String str, int... expected ) {
+ int numParts = str.split(Instruction.OPERAND_DELIM).length;
+ int numFields = numParts - 2; // -2 accounts for execType and
opcode
+ return checkMatchingNumField(numFields, expected);
+ }
+
public static int checkNumFields( String str, int expected1, int
expected2 ) {
//note: split required for empty tokens
int numParts = str.split(Instruction.OPERAND_DELIM).length;
@@ -1062,4 +1070,55 @@ public class InstructionUtils
parts[1] = opcode;
return InstructionUtils.concatOperands(parts[0], parts[1],
createOperand(op1), createOperand(op2), createOperand(out));
}
+
+ /**
+ * Prepare instruction string for sending in a FederatedRequest as a CP
instruction.
+ * This involves replacing the coordinator operand names with the
worker operand names,
+ * changing the execution type to CP, and removing the federated output flag if necessary.
+ * @param inst instruction string to prepare for federated request
+ * @param varOldOut current output operand (to be replaced)
+ * @param id new output operand (always a number)
+ * @param varOldIn current input operands (to be replaced; null entries are skipped)
+ * @param varNewIn new input operand names (always numbers)
+ * @param federatedOutput federated output flag
+ * @return instruction string prepared for federated request
+ */
+ public static String instructionStringFEDPrepare(String inst, CPOperand
varOldOut, long id, CPOperand[] varOldIn, long[] varNewIn, boolean
federatedOutput){
+ String linst = replaceExecTypeWithCP(inst);
+ linst = replaceOutputOperand(linst, varOldOut, id);
+ linst = replaceInputOperand(linst, varOldIn, varNewIn);
+ linst = removeFEDOutputFlag(linst, federatedOutput);
+ return linst;
+ }
+
+ private static String replaceExecTypeWithCP(String inst){
+ return inst.replace(Types.ExecType.SPARK.name(),
Types.ExecType.CP.name())
+ .replace(Types.ExecType.FED.name(),
Types.ExecType.CP.name());
+ }
+
+ private static String replaceOutputOperand(String linst, CPOperand
varOldOut, long id){
+ return replaceOperand(linst, varOldOut, Long.toString(id));
+ }
+
+ private static String replaceInputOperand(String linst, CPOperand[]
varOldIn, long[] varNewIn){
+ for(int i=0; i<varOldIn.length; i++)
+ if( varOldIn[i] != null ) {
+ linst = replaceOperand(linst, varOldIn[i],
Long.toString(varNewIn[i]));
+ linst =
linst.replace("="+varOldIn[i].getName(), "="+varNewIn[i]); //parameterized
+ }
+ return linst;
+ }
+
+ private static String removeFEDOutputFlag(String linst, boolean
federatedOutput){
+ if ( federatedOutput ){
+ linst = linst.substring(0,
linst.lastIndexOf(Lop.OPERAND_DELIMITOR));
+ }
+ return linst;
+ }
+
+ private static String replaceOperand(String linst, CPOperand
oldOperand, String newOperandName){
+ return linst.replace(
+
Lop.OPERAND_DELIMITOR+oldOperand.getName()+Lop.DATATYPE_PREFIX,
+
Lop.OPERAND_DELIMITOR+newOperandName+Lop.DATATYPE_PREFIX);
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 1adaf09..659281a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -31,8 +31,13 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr, boolean federatedOutput) {
+ super(type, op, in1, in2, out, opcode, istr, federatedOutput);
+ }
+
+ protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
- super(type, op, in1, in2, out, opcode, istr);
+ this(type, op, in1, in2, out, opcode, istr, false);
}
public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
@@ -47,11 +52,12 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
}
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- InstructionUtils.checkNumFields(parts, 3, 4);
+ InstructionUtils.checkNumFields(parts, 3, 4, 5);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
+ boolean federatedOutput = parts.length > 5 &&
Boolean.parseBoolean(parts[5]);
checkOutputDataType(in1, in2, out);
Operator operator =
InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
@@ -61,13 +67,11 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
if( in1.getDataType() == DataType.SCALAR && in2.getDataType()
== DataType.SCALAR )
throw new DMLRuntimeException("Federated binary scalar
scalar operations not yet supported");
else if( in1.getDataType() == DataType.MATRIX &&
in2.getDataType() == DataType.MATRIX )
- return new BinaryMatrixMatrixFEDInstruction(operator,
in1, in2, out, opcode, str);
+ return new BinaryMatrixMatrixFEDInstruction(operator,
in1, in2, out, opcode, str, federatedOutput);
else if( in1.getDataType() == DataType.TENSOR &&
in2.getDataType() == DataType.TENSOR )
throw new DMLRuntimeException("Federated binary tensor
tensor operations not yet supported");
- else if( in1.isMatrix() && in2.isScalar() )
- return new BinaryMatrixScalarFEDInstruction(operator,
in1, in2, out, opcode, str);
- else if( in2.isMatrix() && in1.isScalar() )
- return new BinaryMatrixScalarFEDInstruction(operator,
in1, in2, out, opcode, str);
+ else if( in1.isMatrix() && in2.isScalar() || in2.isMatrix() &&
in1.isScalar() )
+ return new BinaryMatrixScalarFEDInstruction(operator,
in1, in2, out, opcode, str, federatedOutput);
else
throw new DMLRuntimeException("Federated binary
operations not yet supported:" + opcode);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index cbe9bad..77ade9a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -33,8 +33,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
{
protected BinaryMatrixMatrixFEDInstruction(Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
- super(FEDType.Binary, op, in1, in2, out, opcode, istr);
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr, boolean federatedOutput) {
+ super(FEDType.Binary, op, in1, in2, out, opcode, istr,
federatedOutput);
}
@Override
@@ -55,7 +55,7 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
if( mo2.isFederated() ) {
if(mo1.isFederated() &&
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()});
+ new long[]{mo1.getFedMapping().getID(),
mo2.getFedMapping().getID()}, _federatedOutput);
mo1.getFedMapping().execute(getTID(), true,
fr2);
}
else if ( !mo1.isFederated() ){
@@ -70,12 +70,27 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
}
}
else { // matrix-matrix binary operations -> lhs fed input ->
fed output
- if((mo1.isFederated(FType.ROW) && mo2.getNumRows() == 1
&& mo2.getNumColumns() > 1)
+ if(mo1.isFederated(FType.FULL)) {
+ // full federated (row and col)
+ if(mo1.getFedMapping().getSize() == 1) {
+ // only one partition (MM on a single
fed worker)
+ FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
+ fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
+ new long[]{mo1.getFedMapping().getID(),
fr1.getID()}, _federatedOutput);
+ FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
+ //execute federated instruction and
cleanup intermediates
+ mo1.getFedMapping().execute(getTID(),
true, fr1, fr2, fr3);
+ }
+ else {
+ throw new
DMLRuntimeException("Matrix-matrix binary operations with a full partitioned
federated input with multiple partitions are not supported yet.");
+ }
+ }
+ else if((mo1.isFederated(FType.ROW) && mo2.getNumRows()
== 1 && mo2.getNumColumns() > 1)
|| (mo1.isFederated(FType.COL) &&
mo2.getNumRows() > 1 && mo2.getNumColumns() == 1)) {
// MV row partitioned row vector, MV col
partitioned col vector
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new long[]{mo1.getFedMapping().getID(),
fr1.getID()});
+ new long[]{mo1.getFedMapping().getID(),
fr1.getID()}, _federatedOutput);
FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1.getID());
//execute federated instruction and cleanup
intermediates
mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
@@ -85,7 +100,7 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
// row partitioned MM or col partitioned MM
FederatedRequest[] fr1 =
mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
- new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()});
+ new long[]{mo1.getFedMapping().getID(),
fr1[0].getID()}, _federatedOutput);
FederatedRequest fr3 =
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
//execute federated instruction and cleanup
intermediates
mo1.getFedMapping().execute(getTID(), true,
fr1, fr2, fr3);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 895db4a..441a00b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -29,8 +29,8 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
{
protected BinaryMatrixScalarFEDInstruction(Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
- super(FEDType.Binary, op, in1, in2, out, opcode, istr);
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr, boolean federatedOutput) {
+ super(FEDType.Binary, op, in1, in2, out, opcode, istr,
federatedOutput);
}
@Override
@@ -44,7 +44,7 @@ public class BinaryMatrixScalarFEDInstruction extends
BinaryFEDInstruction
mo.getFedMapping().broadcast(ec.getScalarInput(scalar))
: null;
FederatedRequest fr2 =
FederationUtils.callInstruction(instString, output,
new CPOperand[]{matrix, (fr1 != null)?scalar:null},
- new long[]{mo.getFedMapping().getID(), (fr1 !=
null)?fr1.getID():-1});
+ new long[]{mo.getFedMapping().getID(), (fr1 !=
null)?fr1.getID():-1}, _federatedOutput);
//execute federated matrix-scalar operation and cleanups
if( fr1 != null ) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
index ccaec24..692455c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
@@ -39,20 +39,26 @@ public abstract class ComputationFEDInstruction extends
FEDInstruction implement
protected ComputationFEDInstruction(FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
- super(type, op, opcode, istr);
+ this(type, op, in1, in2, null, out, opcode, istr, false);
+ }
+
+ protected ComputationFEDInstruction(FEDType type, Operator op,
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr, boolean federatedOutput) {
+ this(type, op, in1, in2, null, out, opcode, istr,
federatedOutput);
+ }
+
+ protected ComputationFEDInstruction(FEDType type, Operator op,
+ CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String istr, boolean federatedOutput){
+ super(type, op, opcode, istr, federatedOutput);
input1 = in1;
input2 = in2;
- input3 = null;
+ input3 = in3;
output = out;
}
protected ComputationFEDInstruction(FEDType type, Operator op,
CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String istr) {
- super(type, op, opcode, istr);
- input1 = in1;
- input2 = in2;
- input3 = in3;
- output = out;
+ this(type, op, in1, in2, in3, out, opcode, istr, false);
}
public String getOutputVariableName() {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 1d3c54c..c75c798 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -51,16 +51,22 @@ public abstract class FEDInstruction extends Instruction {
protected final FEDType _fedType;
protected long _tid = -1; //main
+ protected boolean _federatedOutput = false;
protected FEDInstruction(FEDType type, String opcode, String istr) {
this(type, null, opcode, istr);
}
protected FEDInstruction(FEDType type, Operator op, String opcode,
String istr) {
+ this(type, op, opcode, istr, false);
+ }
+
+ protected FEDInstruction(FEDType type, Operator op, String opcode,
String istr, boolean federatedOutput) {
super(op);
_fedType = type;
instString = istr;
instOpcode = opcode;
+ _federatedOutput = federatedOutput;
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index 1d9cbdd..04b50ac 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -51,11 +51,16 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
}
private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
in2, CPOperand out, OperationTypes type,
- boolean inmem, String opcode, String istr) {
- super(FEDType.QPick, op, in, in2, out, opcode, istr);
+ boolean inmem, String opcode, String istr, boolean
federatedOutput) {
+ super(FEDType.QPick, op, in, in2, out, opcode, istr,
federatedOutput);
_type = type;
}
+ private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
in2, CPOperand out, OperationTypes type,
+ boolean inmem, String opcode, String istr) {
+ this(op, in, in2, out, type, inmem, opcode, istr, false);
+ }
+
public static QuantilePickFEDInstruction parseInstruction ( String str
) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index 31184f7..f4999f8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -53,6 +53,10 @@ import
org.apache.sysds.runtime.matrix.operators.ReorgOperator;
public class ReorgFEDInstruction extends UnaryFEDInstruction {
+ public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out,
String opcode, String istr, boolean federatedOutput) {
+ super(FEDType.Reorg, op, in1, out, opcode, istr,
federatedOutput);
+ }
+
public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out,
String opcode, String istr) {
super(FEDType.Reorg, op, in1, out, opcode, istr);
}
@@ -64,11 +68,12 @@ public class ReorgFEDInstruction extends
UnaryFEDInstruction {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
if ( opcode.equalsIgnoreCase("r'") ) {
- InstructionUtils.checkNumFields(str, 2, 3);
+ InstructionUtils.checkNumFields(str, 2, 3, 4);
in.split(parts[1]);
out.split(parts[2]);
int k = Integer.parseInt(parts[3]);
- return new ReorgFEDInstruction(new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str);
+ boolean federatedOutput = parts.length > 4 &&
Boolean.parseBoolean(parts[4]);
+ return new ReorgFEDInstruction(new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str,
federatedOutput);
}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
@@ -97,7 +102,7 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction
{
FederatedRequest fr1 =
FederationUtils.callInstruction(instString,
output,
new CPOperand[] {input1},
- new long[] {mo1.getFedMapping().getID()});
+ new long[] {mo1.getFedMapping().getID()},
_federatedOutput);
mo1.getFedMapping().execute(getTID(), true, fr1);
//drive output federated mapping
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 62438c0..bb14774 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -39,11 +39,15 @@ public class TsmmFEDInstruction extends
BinaryFEDInstruction {
@SuppressWarnings("unused")
private final int _numThreads;
- public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type,
int k, String opcode, String istr) {
- super(FEDType.Tsmm, null, in, null, out, opcode, istr);
+ public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type,
int k, String opcode, String istr, boolean federatedOutput) {
+ super(FEDType.Tsmm, null, in, null, out, opcode, istr,
federatedOutput);
_type = type;
_numThreads = k;
}
+
+ public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type,
int k, String opcode, String istr) {
+ this(in, out, type, k, opcode, istr, false);
+ }
public static TsmmFEDInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
@@ -51,11 +55,11 @@ public class TsmmFEDInstruction extends
BinaryFEDInstruction {
if(!opcode.equalsIgnoreCase("tsmm"))
throw new
DMLRuntimeException("TsmmFedInstruction.parseInstruction():: Unknown opcode " +
opcode);
- InstructionUtils.checkNumFields(parts, 4);
+ InstructionUtils.checkNumFields(parts, 3, 4);
CPOperand in = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
MMTSJType type = MMTSJType.valueOf(parts[3]);
- int k = Integer.parseInt(parts[4]);
+ int k = (parts.length > 4) ? Integer.parseInt(parts[4]) : -1;
return new TsmmFEDInstruction(in, out, type, k, opcode, str);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index edd07c6..fa0754f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -29,15 +29,30 @@ public abstract class UnaryFEDInstruction extends
ComputationFEDInstruction {
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in,
CPOperand out, String opcode, String instr) {
this(type, op, in, null, null, out, opcode, instr);
}
+
+ protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in,
CPOperand out, String opcode, String instr,
+ boolean federatedOutput) {
+ this(type, op, in, null, null, out, opcode, instr,
federatedOutput);
+ }
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1,
CPOperand in2, CPOperand out, String opcode,
String instr) {
this(type, op, in1, in2, null, out, opcode, instr);
}
+
+ protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1,
CPOperand in2, CPOperand out, String opcode,
+ String instr, boolean federatedOutput) {
+ this(type, op, in1, in2, null, out, opcode, instr,
federatedOutput);
+ }
protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
String opcode, String instr) {
- super(type, op, in1, in2, in3, out, opcode, instr);
+ this(type, op, in1, in2, in3, out, opcode, instr, false);
+ }
+
+ protected UnaryFEDInstruction(FEDType type, Operator op, CPOperand in1,
CPOperand in2, CPOperand in3, CPOperand out,
+ String opcode, String instr, boolean federatedOutput) {
+ super(type, op, in1, in2, in3, out, opcode, instr,
federatedOutput);
}
static String parseUnaryInstruction(String instr, CPOperand in,
CPOperand out) {
diff --git
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 49d17fa..71e1d46 100644
---
a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++
b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -279,7 +279,7 @@ public class PrivacyPropagator
ec.releaseMatrixInput(inst.input1.getName(),
inst.input2.getName());
}
else {
- mergedPrivacyConstraint =
mergeNary(privacyConstraints, OperatorType.NonAggregate);
+ mergedPrivacyConstraint =
mergeNary(privacyConstraints, OperatorType.Aggregate);
inst.setPrivacyConstraint(mergedPrivacyConstraint);
}
inst.output.setPrivacyConstraint(mergedPrivacyConstraint);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
index 2517470..cb34339 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.federated.primitives;
+import org.apache.sysds.hops.OptimizerUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -68,7 +69,16 @@ public class FederatedBinaryMatrixTest extends
AutomatedTestBase {
federatedMultiply(Types.ExecMode.SINGLE_NODE);
}
- public void federatedMultiply(Types.ExecMode execMode) {
+ @Test
+ public void federatedMultiplyCPCompileToFED() {
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, true);
+ }
+
+ public void federatedMultiply(Types.ExecMode execMode){
+ federatedMultiply(execMode, false);
+ }
+
+ public void federatedMultiply(Types.ExecMode execMode, boolean
federatedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -108,6 +118,7 @@ public class FederatedBinaryMatrixTest extends
AutomatedTestBase {
runTest(null);
// Run actual dml script with federated matrix
+ OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
@@ -122,5 +133,6 @@ public class FederatedBinaryMatrixTest extends
AutomatedTestBase {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
index 6ac77c4..089c23e 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.federated.primitives;
+import org.apache.sysds.hops.OptimizerUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -66,10 +67,15 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
@Test
public void federatedMultiplyCP() {
- federatedMultiply(Types.ExecMode.SINGLE_NODE);
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, false);
}
- public void federatedMultiply(Types.ExecMode execMode) {
+ @Test
+ public void federatedMultiplyCPCompileToFED() {
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, true);
+ }
+
+ public void federatedMultiply(Types.ExecMode execMode, boolean
federatedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -109,6 +115,7 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
+ OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
@@ -123,5 +130,6 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index 8836203..7fc192d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.federated.primitives;
+import org.apache.sysds.hops.OptimizerUtils;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -68,6 +69,10 @@ public class FederatedMultiplyTest extends AutomatedTestBase
{
federatedMultiply(Types.ExecMode.SINGLE_NODE);
}
+ @Test
+ public void federatedMultiplyCPCompileToFED() {
+ federatedMultiply(Types.ExecMode.SINGLE_NODE, true);
+ }
@Test
@Ignore
@@ -76,7 +81,11 @@ public class FederatedMultiplyTest extends AutomatedTestBase
{
federatedMultiply(Types.ExecMode.SPARK);
}
- public void federatedMultiply(Types.ExecMode execMode) {
+ private void federatedMultiply(Types.ExecMode execMode){
+ federatedMultiply(execMode,false);
+ }
+
+ private void federatedMultiply(Types.ExecMode execMode, boolean
federatedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
rtplatform = execMode;
@@ -116,6 +125,7 @@ public class FederatedMultiplyTest extends
AutomatedTestBase {
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
+ OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
@@ -129,6 +139,7 @@ public class FederatedMultiplyTest extends
AutomatedTestBase {
TestUtils.shutdownThreads(t1, t2);
rtplatform = platformOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
index 3d03f7b..6262e03 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
@@ -24,6 +24,7 @@ import java.util.Collection;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -69,12 +70,21 @@ public class FederatedSumTest extends AutomatedTestBase {
}
@Test
+ public void federatedSumCPToFED() {
+ federatedSum(Types.ExecMode.SINGLE_NODE, true);
+ }
+
+ @Test
@Ignore
public void federatedSumSP() {
federatedSum(Types.ExecMode.SPARK);
}
- public void federatedSum(Types.ExecMode execMode) {
+ public void federatedSum(Types.ExecMode execMode){
+ federatedSum(execMode, false);
+ }
+
+ public void federatedSum(Types.ExecMode execMode, boolean
federatedCompilation) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -108,6 +118,7 @@ public class FederatedSumTest extends AutomatedTestBase {
}
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
+ OptimizerUtils.FEDERATED_COMPILATION = federatedCompilation;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-nvargs", "in=" +
TestUtils.federatedAddress(port, input("A")), "rows=" + rows,
"cols=" + cols, "out_S=" + output("S"), "out_R=" +
output("R"), "out_C=" + output("C")};
@@ -119,6 +130,7 @@ public class FederatedSumTest extends AutomatedTestBase {
TestUtils.shutdownThread(t);
rtplatform = platformOld;
+ OptimizerUtils.FEDERATED_COMPILATION = false;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
index c75e9a2..4c3d145 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedWorkerHandlerTest.java
@@ -279,7 +279,7 @@ public class FederatedWorkerHandlerTest extends
AutomatedTestBase {
@Test
public void matVecMultPrivateAggregationTest() {
- federatedMultiply(Types.ExecMode.SINGLE_NODE,
PrivacyLevel.PrivateAggregation, DMLRuntimeException.class);
+ federatedMultiply(Types.ExecMode.SINGLE_NODE,
PrivacyLevel.PrivateAggregation, null);
}
@Test
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
index 20802a1..bc4dec4 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/algorithms/FederatedL2SVMTest.java
@@ -57,16 +57,14 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
public void federatedL2SVMCPPrivateAggregationX1() throws JSONException
{
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.PrivateAggregation,
- false,null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
}
@Test
public void federatedL2SVMCPPrivateAggregationX2() throws JSONException
{
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.PrivateAggregation,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
}
@Test
@@ -196,8 +194,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.PrivateAggregation,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
}
@Test
@@ -205,8 +202,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.PrivateAggregation,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
}
@Test
@@ -214,8 +210,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.PrivateAggregation,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
}
@Test
@@ -224,8 +219,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.PrivateAggregation,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.PrivateAggregation);
}
// Privacy Level Combinations
@@ -261,8 +255,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X1", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.Private);
}
@Test
@@ -270,8 +263,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
Map<String, PrivacyConstraint> privacyConstraints = new
HashMap<>();
privacyConstraints.put("Y", new
PrivacyConstraint(PrivacyLevel.Private));
privacyConstraints.put("X2", new
PrivacyConstraint(PrivacyLevel.PrivateAggregation));
- federatedL2SVM(Types.ExecMode.SINGLE_NODE, privacyConstraints,
null, PrivacyLevel.Private,
- false, null, true, DMLRuntimeException.class);
+ federatedL2SVMNoException(Types.ExecMode.SINGLE_NODE,
privacyConstraints, null, PrivacyLevel.Private);
}
@Test
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
similarity index 70%
copy from
src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
copy to
src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 6ac77c4..e4da423 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -17,8 +17,10 @@
* under the License.
*/
-package org.apache.sysds.test.functions.federated.primitives;
+package org.apache.sysds.test.functions.privacy.fedplanning;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -34,12 +36,10 @@ import java.util.Collection;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedBinaryVectorTest extends AutomatedTestBase {
-
- private final static String TEST_DIR = "functions/federated/";
- // Using same test base as binary matrix test
- private final static String TEST_NAME = "FederatedBinaryMatrixTest";
- private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedBinaryVectorTest.class.getSimpleName() + "/";
+public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+ private final static String TEST_DIR = "functions/privacy/";
+ private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@Parameterized.Parameter()
@@ -57,18 +57,24 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
- // {2, 1000},
- // {10, 100},
- {100, 10},
- // {1000, 1}, {10, 2000}, {2000, 10}
+ {100, 10}
});
}
@Test
public void federatedMultiplyCP() {
+ OptimizerUtils.FEDERATED_COMPILATION = true;
federatedMultiply(Types.ExecMode.SINGLE_NODE);
}
+ private void writeStandardMatrix(String matrixName, long seed){
+ int halfRows = rows/2;
+ double[][] matrix = getRandomMatrix(halfRows, cols, 0, 1, 1,
seed);
+ writeInputMatrixWithMTD(matrixName, matrix, false,
+ new MatrixCharacteristics(halfRows, cols, blocksize,
halfRows * cols),
+ new
PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
+ }
+
public void federatedMultiply(Types.ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -80,19 +86,11 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- // write input matrices
- int halfRows = rows / 2;
- // We have two matrices handled by a single federated worker
- double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
- double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
- // And another two matrices handled by a single federated worker
- double[][] Y1 = getRandomMatrix(halfRows, 1, 0, 1, 1, 44);
- double[][] Y2 = getRandomMatrix(halfRows, 1, 0, 1, 1, 21);
-
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
- writeInputMatrixWithMTD("Y1", Y1, false, new
MatrixCharacteristics(halfRows, 1, blocksize, halfRows));
- writeInputMatrixWithMTD("Y2", Y2, false, new
MatrixCharacteristics(halfRows, 1, blocksize, halfRows));
+ // Write input matrices
+ writeStandardMatrix("X1", 42);
+ writeStandardMatrix("X2", 1340);
+ writeStandardMatrix("Y1", 44);
+ writeStandardMatrix("Y2", 21);
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
@@ -102,22 +100,25 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-nvargs", "X1=" + input("X1"),
"X2=" + input("X2"), "Y1=" + input("Y1"),
- "Y2=" + input("Y2"), "Z=" + expected("Z")};
- runTest(true, false, null, -1);
-
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-explain", "-nvargs", "X1=" +
TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")),
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
runTest(true, false, null, -1);
+ OptimizerUtils.FEDERATED_COMPILATION = false;
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" + input("X1"),
"X2=" + input("X2"), "Y1=" + input("Y1"),
+ "Y2=" + input("Y2"), "Z=" + expected("Z")};
+ runTest(true, false, null, -1);
+
// compare via files
compareResults(1e-9);
+ heavyHittersContainsString("fed_*", "fed_ba+*");
TestUtils.shutdownThreads(t1, t2);
@@ -125,3 +126,4 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
+
diff --git
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml
new file mode 100644
index 0000000..04b3804
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0),
list($r, $c)))
+Y = federated(addresses=list($Y1, $Y2),
+ ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0),
list($r, $c)))
+Z0 = X * Y
+Z = t(Z0) %*% X
+write(Z, $Z)
diff --git
a/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
new file mode 100644
index 0000000..ee595d7
--- /dev/null
+++
b/src/test/scripts/functions/privacy/FederatedMultiplyPlanningTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Y = rbind(read($Y1), read($Y2))
+Z0 = X * Y
+Z = t(Z0) %*% X
+write(Z, $Z)