This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new f487c18 [SYSTEMDS-2747] Federated PNMF test, extended quaternary operations
f487c18 is described below
commit f487c187af9d4202787620dfa34bddd603f93585
Author: ywcb00 <[email protected]>
AuthorDate: Sat Feb 20 17:14:57 2021 +0100
[SYSTEMDS-2747] Federated PNMF test, extended quaternary operations
Closes #1175.
---
.../fed/QuaternaryWCeMMFEDInstruction.java | 87 ++++++-------
.../fed/QuaternaryWDivMMFEDInstruction.java | 135 +++++++++++----------
.../fed/QuaternaryWSLossFEDInstruction.java | 81 +++++++------
.../fed/QuaternaryWSigmoidFEDInstruction.java | 46 +++----
.../fed/QuaternaryWUMMFEDInstruction.java | 42 ++++---
.../apache/sysds/runtime/lineage/LineageItem.java | 2 +-
.../federated/algorithms/FederatedAlsCGTest.java | 6 +-
...eratedAlsCGTest.java => FederatedPNMFTest.java} | 78 +++++-------
.../FederatedWeightedCrossEntropyTest.java | 6 +-
.../FederatedWeightedDivMatrixMultTest.java | 22 ++--
.../primitives/FederatedWeightedSigmoidTest.java | 6 +-
.../FederatedWeightedSquaredLossTest.java | 6 +-
.../FederatedWeightedUnaryMatrixMultTest.java | 20 +--
.../functions/federated/FederatedPNMFTest.dml | 32 +++++
.../federated/FederatedPNMFTestReference.dml | 31 +++++
15 files changed, 331 insertions(+), 269 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
index 8566b39..68efe5d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -29,6 +29,7 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
@@ -59,58 +60,60 @@ public class QuaternaryWCeMMFEDInstruction extends
QuaternaryFEDInstruction
MatrixObject U = ec.getMatrixObject(input2);
MatrixObject V = ec.getMatrixObject(input3);
ScalarObject eps = null;
-
+
if(qop.hasFourInputs()) {
eps = (_input4.getDataType() == DataType.SCALAR) ?
ec.getScalarInput(_input4) :
new
DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
}
- if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
- throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = ("
- +X.isFederated()+", "+U.isFederated()+",
"+V.isFederated()+")");
-
- FederationMap fedMap = X.getFedMapping();
- FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
- FederatedRequest fr2 = fedMap.broadcast(V);
- FederatedRequest fr3 = null;
- FederatedRequest frComp = null;
+ if(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated()) {
+ FederationMap fedMap = X.getFedMapping();
+ FederatedRequest[] fr1 = fedMap.broadcastSliced(U,
false);
+ FederatedRequest fr2 = fedMap.broadcast(V);
+ FederatedRequest fr3 = null;
+ FederatedRequest frComp = null;
- // broadcast the scalar epsilon if there are four inputs
- if(eps != null) {
- fr3 = fedMap.broadcast(eps);
- // change the is_literal flag from true to false
because when broadcasted it is no literal anymore
- instString = instString.replace("true", "false");
- frComp = FederationUtils.callInstruction(instString,
output,
- new CPOperand[]{input1, input2, input3,
_input4},
- new long[]{fedMap.getID(), fr1[0].getID(),
fr2.getID(), fr3.getID()});
- }
- else {
- frComp = FederationUtils.callInstruction(instString,
output,
- new CPOperand[]{input1, input2, input3},
- new long[]{fedMap.getID(), fr1[0].getID(),
fr2.getID()});
- }
-
- FederatedRequest frGet = new
FederatedRequest(RequestType.GET_VAR, frComp.getID());
- FederatedRequest frClean1 = fedMap.cleanup(getTID(),
frComp.getID());
- FederatedRequest frClean2 = fedMap.cleanup(getTID(),
fr1[0].getID());
- FederatedRequest frClean3 = fedMap.cleanup(getTID(),
fr2.getID());
+ // broadcast the scalar epsilon if there are four inputs
+ if(eps != null) {
+ fr3 = fedMap.broadcast(eps);
+ // change the is_literal flag from true to
false because when broadcasted it is no literal anymore
+ instString = instString.replace("true",
"false");
+ frComp =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3,
_input4},
+ new long[]{fedMap.getID(),
fr1[0].getID(), fr2.getID(), fr3.getID()});
+ }
+ else {
+ frComp =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3},
+ new long[]{fedMap.getID(), fr1[0].getID(),
fr2.getID()});
+ }
+
+ FederatedRequest frGet = new
FederatedRequest(RequestType.GET_VAR, frComp.getID());
+ FederatedRequest frClean1 = fedMap.cleanup(getTID(),
frComp.getID());
+ FederatedRequest frClean2 = fedMap.cleanup(getTID(),
fr1[0].getID());
+ FederatedRequest frClean3 = fedMap.cleanup(getTID(),
fr2.getID());
- Future<FederatedResponse>[] response;
- if(fr3 != null) {
- FederatedRequest frClean4 = fedMap.cleanup(getTID(),
fr3.getID());
- // execute federated instructions
- response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
- frComp, frGet, frClean1, frClean2, frClean3,
frClean4);
+ Future<FederatedResponse>[] response;
+ if(fr3 != null) {
+ FederatedRequest frClean4 =
fedMap.cleanup(getTID(), fr3.getID());
+ // execute federated instructions
+ response = fedMap.execute(getTID(), true, fr1,
fr2, fr3,
+ frComp, frGet, frClean1, frClean2,
frClean3, frClean4);
+ }
+ else {
+ // execute federated instructions
+ response = fedMap.execute(getTID(), true, fr1,
fr2,
+ frComp, frGet, frClean1, frClean2,
frClean3);
+ }
+
+ //aggregate partial results from federated responses
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setVariable(output.getName(),
FederationUtils.aggScalar(aop, response));
}
else {
- // execute federated instructions
- response = fedMap.execute(getTID(), true, fr1, fr2,
- frComp, frGet, frClean1, frClean2, frClean3);
+ throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = ("
+ + X.isFederated() + ", " + U.isFederated() + ",
" + V.isFederated() + ")");
}
-
- //aggregate partial results from federated responses
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setVariable(output.getName(), FederationUtils.aggScalar(aop,
response));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
index 5ba2b59..877b9c5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
@@ -86,79 +86,82 @@ public class QuaternaryWDivMMFEDInstruction extends
QuaternaryFEDInstruction
}
}
- if(!(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated()))
- throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = ("
- +X.isFederated()+", "+U.isFederated()+",
"+V.isFederated() + ")");
-
- FederationMap fedMap = X.getFedMapping();
- FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
- FederatedRequest frInit2 = fedMap.broadcast(V);
+ if(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated()) {
+ FederationMap fedMap = X.getFedMapping();
+ FederatedRequest[] frInit1 = fedMap.broadcastSliced(U,
false);
+ FederatedRequest frInit2 = fedMap.broadcast(V);
- FederatedRequest frInit3 = null;
- FederatedRequest frInit3Arr[] = null;
- FederatedRequest frCompute1 = null;
- // broadcast scalar epsilon if there are four inputs
- if(eps != null) {
- frInit3 = fedMap.broadcast(eps);
- // change the is_literal flag from true to false
because when broadcasted it is no literal anymore
- instString = instString.replace("true", "false");
- frCompute1 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2, input3,
_input4},
- new long[]{fedMap.getID(), frInit1[0].getID(),
frInit2.getID(), frInit3.getID()});
- }
- else if(MX != null) {
- frInit3Arr = fedMap.broadcastSliced(MX, false);
- frCompute1 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2, input3,
_input4},
- new long[]{fedMap.getID(), frInit1[0].getID(),
frInit2.getID(), frInit3Arr[0].getID()});
- }
- else {
- frCompute1 =
FederationUtils.callInstruction(instString, output,
- new CPOperand[]{input1, input2, input3},
- new long[]{fedMap.getID(), frInit1[0].getID(),
frInit2.getID()});
- }
+ FederatedRequest frInit3 = null;
+ FederatedRequest frInit3Arr[] = null;
+ FederatedRequest frCompute1 = null;
+ // broadcast scalar epsilon if there are four inputs
+ if(eps != null) {
+ frInit3 = fedMap.broadcast(eps);
+ // change the is_literal flag from true to
false because when broadcasted it is no literal anymore
+ instString = instString.replace("true",
"false");
+ frCompute1 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3,
_input4},
+ new long[]{fedMap.getID(),
frInit1[0].getID(), frInit2.getID(), frInit3.getID()});
+ }
+ else if(MX != null) {
+ frInit3Arr = fedMap.broadcastSliced(MX, false);
+ frCompute1 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3,
_input4},
+ new long[]{fedMap.getID(),
frInit1[0].getID(), frInit2.getID(), frInit3Arr[0].getID()});
+ }
+ else {
+ frCompute1 =
FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2, input3},
+ new long[]{fedMap.getID(),
frInit1[0].getID(), frInit2.getID()});
+ }
- // get partial results from federated workers
- FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+ // get partial results from federated workers
+ FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
- FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
- FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
- FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
+ FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
+ FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
+ FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
- // execute federated instructions
- Future<FederatedResponse>[] response;
- if(frInit3 != null) {
- FederatedRequest frCleanup4 = fedMap.cleanup(getTID(),
frInit3.getID());
- response = fedMap.execute(getTID(), true,
- frInit1, frInit2, frInit3,
- frCompute1, frGet1,
- frCleanup1, frCleanup2, frCleanup3, frCleanup4);
- }
- else if(frInit3Arr != null) {
- FederatedRequest frCleanup4 = fedMap.cleanup(getTID(),
frInit3Arr[0].getID());
- fedMap.execute(getTID(), true, frInit1, frInit2);
- response = fedMap.execute(getTID(), true, frInit3Arr,
- frCompute1, frGet1,
- frCleanup1, frCleanup2, frCleanup3, frCleanup4);
- }
- else {
- response = fedMap.execute(getTID(), true,
- frInit1, frInit2,
- frCompute1, frGet1,
- frCleanup1, frCleanup2, frCleanup3);
- }
+ // execute federated instructions
+ Future<FederatedResponse>[] response;
+ if(frInit3 != null) {
+ FederatedRequest frCleanup4 =
fedMap.cleanup(getTID(), frInit3.getID());
+ response = fedMap.execute(getTID(), true,
+ frInit1, frInit2, frInit3,
+ frCompute1, frGet1,
+ frCleanup1, frCleanup2, frCleanup3,
frCleanup4);
+ }
+ else if(frInit3Arr != null) {
+ FederatedRequest frCleanup4 =
fedMap.cleanup(getTID(), frInit3Arr[0].getID());
+ fedMap.execute(getTID(), true, frInit1,
frInit2);
+ response = fedMap.execute(getTID(), true,
frInit3Arr,
+ frCompute1, frGet1,
+ frCleanup1, frCleanup2, frCleanup3,
frCleanup4);
+ }
+ else {
+ response = fedMap.execute(getTID(), true,
+ frInit1, frInit2,
+ frCompute1, frGet1,
+ frCleanup1, frCleanup2, frCleanup3);
+ }
- if(wdivmm_type.isLeft()) {
- // aggregate partial results from federated responses
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) {
- // bind partial results from federated responses
- ec.setMatrixOutput(output.getName(),
FederationUtils.bind(response, false));
+ if(wdivmm_type.isLeft()) {
+ // aggregate partial results from federated
responses
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setMatrixOutput(output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else if(wdivmm_type.isRight() || wdivmm_type.isBasic())
{
+ // bind partial results from federated responses
+ ec.setMatrixOutput(output.getName(),
FederationUtils.bind(response, false));
+ }
+ else {
+ throw new DMLRuntimeException("Federated WDivMM
only supported for BASIC, LEFT or RIGHT variants.");
+ }
}
else {
- throw new DMLRuntimeException("Federated WDivMM only
supported for BASIC, LEFT or RIGHT variants.");
+ throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = ("
+ + X.isFederated() + ", " + U.isFederated() + ",
" + V.isFederated() + ")");
}
}
}
+
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
index 664fbdc..f65d4f0 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
@@ -25,6 +25,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -69,51 +70,53 @@ public class QuaternaryWSLossFEDInstruction extends
QuaternaryFEDInstruction {
W = ec.getMatrixObject(_input4);
}
- if(!(X.isFederated() && !U.isFederated() && !V.isFederated() &&
(W == null || !W.isFederated())))
- throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V, W) = (" + X.isFederated() + ", "
- + U.isFederated() + ", " + V.isFederated() + ",
" + (W != null ? W.isFederated() : "none") + ")");
+ if(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated() && (W == null || !W.isFederated())) {
+ FederationMap fedMap = X.getFedMapping();
+ FederatedRequest[] frInit1 = fedMap.broadcastSliced(U,
false);
+ FederatedRequest frInit2 = fedMap.broadcast(V);
- FederationMap fedMap = X.getFedMapping();
- FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
- FederatedRequest frInit2 = fedMap.broadcast(V);
+ FederatedRequest[] frInit3 = null;
+ FederatedRequest frCompute1 = null;
+ if(W != null) {
+ frInit3 = fedMap.broadcastSliced(W, false);
+ frCompute1 =
FederationUtils.callInstruction(instString,
+ output,
+ new CPOperand[] {input1, input2,
input3, _input4},
+ new long[] {fedMap.getID(),
frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()});
+ }
+ else {
+ frCompute1 =
FederationUtils.callInstruction(instString,
+ output,
+ new CPOperand[] {input1, input2,
input3},
+ new long[] {fedMap.getID(),
frInit1[0].getID(), frInit2.getID()});
+ }
- FederatedRequest[] frInit3 = null;
- FederatedRequest frCompute1 = null;
- if(W != null) {
- frInit3 = fedMap.broadcastSliced(W, false);
- frCompute1 = FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {input1, input2, input3,
_input4},
- new long[] {fedMap.getID(), frInit1[0].getID(),
frInit2.getID(), frInit3[0].getID()});
- }
- else {
- frCompute1 = FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {input1, input2, input3},
- new long[] {fedMap.getID(), frInit1[0].getID(),
frInit2.getID()});
- }
+ FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+ FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
+ FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
+ FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
- FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
- FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
- FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
- FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
+ Future<FederatedResponse>[] response;
+ if(frInit3 != null) {
+ FederatedRequest frCleanup4 =
fedMap.cleanup(getTID(), frInit3[0].getID());
+ // execute federated instructions
+ fedMap.execute(getTID(), true, frInit1,
frInit2);
+ response = fedMap
+ .execute(getTID(), true, frInit3,
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
+ }
+ else {
+ // execute federated instructions
+ response = fedMap
+ .execute(getTID(), true, frInit1,
frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
+ }
- Future<FederatedResponse>[] response;
- if(frInit3 != null) {
- FederatedRequest frCleanup4 = fedMap.cleanup(getTID(),
frInit3[0].getID());
- // execute federated instructions
- fedMap.execute(getTID(), true, frInit1, frInit2);
- response = fedMap
- .execute(getTID(), true, frInit3, frCompute1,
frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
+ // aggregate partial results from federated responses
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setVariable(output.getName(),
FederationUtils.aggScalar(aop, response));
}
else {
- // execute federated instructions
- response = fedMap
- .execute(getTID(), true, frInit1, frInit2,
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
+ throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V, W) = (" + X.isFederated() + ", "
+ + U.isFederated() + ", " + V.isFederated() + ",
" + (W != null ? W.isFederated() : "none") + ")");
}
-
- // aggregate partial results from federated responses
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setVariable(output.getName(), FederationUtils.aggScalar(aop,
response));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
index 9884e3b..95caaef 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
@@ -28,6 +28,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -58,32 +59,33 @@ public class QuaternaryWSigmoidFEDInstruction extends
QuaternaryFEDInstruction {
MatrixObject U = ec.getMatrixObject(input2);
MatrixObject V = ec.getMatrixObject(input3);
- if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
- throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = (" + X.isFederated() + ", "
- + U.isFederated() + ", " + V.isFederated() +
")");
+ if(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated()) {
+ FederationMap fedMap = X.getFedMapping();
+ FederatedRequest[] frInit1 = fedMap.broadcastSliced(U,
false);
+ FederatedRequest frInit2 = fedMap.broadcast(V);
- FederationMap fedMap = X.getFedMapping();
- FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
- FederatedRequest frInit2 = fedMap.broadcast(V);
+ FederatedRequest frCompute1 =
FederationUtils.callInstruction(instString,
+ output,
+ new CPOperand[] {input1, input2, input3},
+ new long[] {fedMap.getID(), frInit1[0].getID(),
frInit2.getID()});
- FederatedRequest frCompute1 =
FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {input1, input2, input3},
- new long[] {fedMap.getID(), frInit1[0].getID(),
frInit2.getID()});
+ // get partial results from federated workers
+ FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
- // get partial results from federated workers
- FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+ FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
+ FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
+ FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
- FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
- FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
- FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
-
- // execute federated instructions
- Future<FederatedResponse>[] response = fedMap
- .execute(getTID(), true, frInit1, frInit2, frCompute1,
frGet1, frCleanup1, frCleanup2, frCleanup3);
-
- // bind partial results from federated responses
- ec.setMatrixOutput(output.getName(),
FederationUtils.bind(response, false));
+ // execute federated instructions
+ Future<FederatedResponse>[] response = fedMap
+ .execute(getTID(), true, frInit1, frInit2,
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
+ // bind partial results from federated responses
+ ec.setMatrixOutput(output.getName(),
FederationUtils.bind(response, false));
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = ("
+ + X.isFederated() + ", " + U.isFederated() + ",
" + V.isFederated() + ")");
+ }
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
index 82bc9e2..2512439 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
@@ -60,30 +60,32 @@ public class QuaternaryWUMMFEDInstruction extends
QuaternaryFEDInstruction {
MatrixObject U = ec.getMatrixObject(input2);
MatrixObject V = ec.getMatrixObject(input3);
- if(!(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated()))
- throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = (" + X.isFederated() + ", "
- + U.isFederated() + ", " + V.isFederated() +
")");
+ if(X.isFederated(FType.ROW) && !U.isFederated() &&
!V.isFederated()) {
+ FederationMap fedMap = X.getFedMapping();
+ FederatedRequest[] frInit1 = fedMap.broadcastSliced(U,
false);
+ FederatedRequest frInit2 = fedMap.broadcast(V);
- FederationMap fedMap = X.getFedMapping();
- FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
- FederatedRequest frInit2 = fedMap.broadcast(V);
+ FederatedRequest frCompute1 =
FederationUtils.callInstruction(instString,
+ output, new CPOperand[] {input1, input2,
input3},
+ new long[] {fedMap.getID(), frInit1[0].getID(),
frInit2.getID()});
- FederatedRequest frCompute1 =
FederationUtils.callInstruction(instString,
- output, new CPOperand[] {input1, input2, input3},
- new long[] {fedMap.getID(), frInit1[0].getID(),
frInit2.getID()});
+ // get partial results from federated workers
+ FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
- // get partial results from federated workers
- FederatedRequest frGet1 = new
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+ FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
+ FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
+ FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
- FederatedRequest frCleanup1 = fedMap.cleanup(getTID(),
frCompute1.getID());
- FederatedRequest frCleanup2 = fedMap.cleanup(getTID(),
frInit1[0].getID());
- FederatedRequest frCleanup3 = fedMap.cleanup(getTID(),
frInit2.getID());
+ // execute federated instructions
+ Future<FederatedResponse>[] response = fedMap
+ .execute(getTID(), true, frInit1, frInit2,
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
- // execute federated instructions
- Future<FederatedResponse>[] response = fedMap
- .execute(getTID(), true, frInit1, frInit2, frCompute1,
frGet1, frCleanup1, frCleanup2, frCleanup3);
-
- // bind partial results from federated responses
- ec.setMatrixOutput(output.getName(),
FederationUtils.bind(response, false));
+ // bind partial results from federated responses
+ ec.setMatrixOutput(output.getName(),
FederationUtils.bind(response, false));
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported federated
inputs (X, U, V) = ("
+ + X.isFederated() + ", " + U.isFederated() + ",
" + V.isFederated() + ")");
+ }
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index 2210fff..cd2346d 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -354,7 +354,7 @@ public class LineageItem {
// Compare a dedup patch with a sub-DAG, and map the inputs of the
sub-dag
// to the placeholder inputs of the dedup patch
- private boolean equalsDedupPatch(LineageItem dli1, LineageItem dli2,
Map<Integer, LineageItem> phMap) {
+ private static boolean equalsDedupPatch(LineageItem dli1, LineageItem
dli2, Map<Integer, LineageItem> phMap) {
Stack<LineageItem> s1 = new Stack<>();
Stack<LineageItem> s2 = new Stack<>();
s1.push(dli1);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
index 4909f7c..9263beb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
@@ -46,7 +46,7 @@ public class FederatedAlsCGTest extends AutomatedTestBase
private final static String OUTPUT_NAME = "Z";
private final static double TOLERANCE = 0.01;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -112,9 +112,9 @@ public class FederatedAlsCGTest extends AutomatedTestBase
double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2,
sparsity, 2);
writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(
- fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(
- fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
similarity index 69%
copy from src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
copy to src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
index 4909f7c..00358c7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
@@ -1,18 +1,18 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
+ * or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
+ * regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
+ * KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
@@ -38,15 +38,15 @@ import java.util.HashMap;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedAlsCGTest extends AutomatedTestBase
+public class FederatedPNMFTest extends AutomatedTestBase
{
- private final static String TEST_NAME = "FederatedAlsCGTest";
+ private final static String TEST_NAME = "FederatedPNMFTest";
private final static String TEST_DIR = "functions/federated/";
- private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedAlsCGTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedPNMFTest.class.getSimpleName() + "/";
private final static String OUTPUT_NAME = "Z";
- private final static double TOLERANCE = 0.01;
- private final static int blocksize = 1024;
+ private final static double TOLERANCE = 0.2;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -55,14 +55,8 @@ public class FederatedAlsCGTest extends AutomatedTestBase
@Parameterized.Parameter(2)
public int rank;
@Parameterized.Parameter(3)
- public String regression;
- @Parameterized.Parameter(4)
- public double lambda;
- @Parameterized.Parameter(5)
public int max_iter;
- @Parameterized.Parameter(6)
- public double threshold;
- @Parameterized.Parameter(7)
+ @Parameterized.Parameter(4)
public double sparsity;
@Override
@@ -74,9 +68,8 @@ public class FederatedAlsCGTest extends AutomatedTestBase
public static Collection<Object[]> data() {
// rows must be even
return Arrays.asList(new Object[][] {
- // {rows, cols, rank, regression, lambda, max_iter,
threshold, sparsity}
- {30, 15, 10, "L2", 0.0000001, 50, 0.000001, 1},
- {30, 15, 10, "wL2", 0.0000001, 50, 0.000001, 1}
+ // {rows, cols, rank, max_iter, sparsity}
+ {1000, 750, 420, 10, 1}
});
}
@@ -86,35 +79,35 @@ public class FederatedAlsCGTest extends AutomatedTestBase
}
@Test
- public void federatedAlsCGSingleNode() {
- federatedAlsCG(TEST_NAME, ExecMode.SINGLE_NODE);
+ public void federatedPNMFSingleNode() {
+ federatedPNMF(ExecMode.SINGLE_NODE);
}
-// @Test
-// public void federatedAlsCGSpark() {
-// federatedAlsCG(TEST_NAME, ExecMode.SPARK);
-// }
+ @Test
+ public void federatedPNMFSpark() {
+ federatedPNMF(ExecMode.SPARK);
+ }
//
-----------------------------------------------------------------------------
- public void federatedAlsCG(String testname, ExecMode execMode)
+ public void federatedPNMF(ExecMode execMode)
{
// store the previous platform config to restore it after the
test
ExecMode platform_old = setExecMode(execMode);
- getAndLoadTestConfiguration(testname);
+ getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
int fed_rows = rows / 2;
int fed_cols = cols;
+ // generate dataset
+ // matrix handled by two federated workers
double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2,
sparsity, 13);
double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2,
sparsity, 2);
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(
- fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(
- fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
@@ -123,27 +116,22 @@ public class FederatedAlsCGTest extends AutomatedTestBase
Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
Thread thread2 = startLocalFedWorkerThread(port2);
- getAndLoadTestConfiguration(testname);
+ getAndLoadTestConfiguration(TEST_NAME);
// Run reference dml script with normal matrix
- fullDMLScriptName = HOME + testname + "Reference.dml";
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-stats", "-nvargs",
- "in_X1=" + input("X1"), "in_X2=" + input("X2"),
"in_rank=" + Integer.toString(rank),
- "in_reg=" + regression, "in_lambda=" +
Double.toString(lambda),
- "in_maxi=" + Integer.toString(max_iter), "in_thr=" +
Double.toString(threshold),
+ "in_X1=" + input("X1"), "in_X2=" + input("X2"),
"in_rank=" + Integer.toString(rank), "in_max_iter=" +
Integer.toString(max_iter),
"out_Z=" + expected(OUTPUT_NAME)};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
- fullDMLScriptName = HOME + testname + ".dml";
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
"in_rank=" + Integer.toString(rank),
- "in_reg=" + regression,
- "in_lambda=" + Double.toString(lambda),
- "in_maxi=" + Integer.toString(max_iter),
- "in_thr=" + Double.toString(threshold),
+ "in_max_iter=" + Integer.toString(max_iter),
"rows=" + fed_rows, "cols=" + fed_cols,
"out_Z=" + output(OUTPUT_NAME)};
runTest(true, false, null, -1);
@@ -156,15 +144,13 @@ public class FederatedAlsCGTest extends AutomatedTestBase
TestUtils.shutdownThreads(thread1, thread2);
// check for federated operations
- Assert.assertTrue(heavyHittersContainsString("fed_!="));
- Assert.assertTrue(heavyHittersContainsString("fed_fedinit"));
+ Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
Assert.assertTrue(heavyHittersContainsString("fed_wdivmm"));
- Assert.assertTrue(heavyHittersContainsString("fed_wsloss"));
+ Assert.assertTrue(heavyHittersContainsString("fed_fedinit"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-
resetExecMode(platform_old);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
index bf676a3..655124d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
@@ -47,7 +47,7 @@ public class FederatedWeightedCrossEntropyTest extends
AutomatedTestBase
private final static String OUTPUT_NAME = "Z";
private final static double TOLERANCE = 1e-9;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -124,8 +124,8 @@ public class FederatedWeightedCrossEntropyTest extends
AutomatedTestBase
double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("U", U, true);
writeInputMatrixWithMTD("V", V, true);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
index 39a79bb..15c192b 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
@@ -60,7 +60,7 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase
private final static double TOLERANCE = 1e-9;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -256,11 +256,11 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase
double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
- writeInputMatrixWithMTD("U", U, true, new
MatrixCharacteristics(rows, rank, blocksize, rows * rank));
- writeInputMatrixWithMTD("V", V, true, new
MatrixCharacteristics(cols, rank, blocksize, rows * rank));
+ writeInputMatrixWithMTD("U", U, true, new
MatrixCharacteristics(rows, rank, BLOCKSIZE, rows * rank));
+ writeInputMatrixWithMTD("V", V, true, new
MatrixCharacteristics(cols, rank, BLOCKSIZE, rows * rank));
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
@@ -270,7 +270,7 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase
Thread thread2 = startLocalFedWorkerThread(port2);
getAndLoadTestConfiguration(test_name);
-
+
try {
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + test_name + "Reference.dml";
@@ -278,7 +278,7 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase
"in_U=" + input("U"), "in_V=" + input("V"),
"in_W=" + Double.toString(epsilon),
"out_Z=" + expected(OUTPUT_NAME)};
runTest(true, false, null, -1);
-
+
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + test_name + ".dml";
programArgs = new String[] {"-stats", "-nvargs",
@@ -289,22 +289,22 @@ public class FederatedWeightedDivMatrixMultTest extends
AutomatedTestBase
"in_W=" + Double.toString(epsilon),
"rows=" + fed_rows, "cols=" + fed_cols,
"out_Z=" + output(OUTPUT_NAME)};
runTest(true, false, null, -1);
-
+
// compare the results via files
HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults,
TOLERANCE, "Fed", "Ref");
-
+
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_wdivmm"));
-
+
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
finally {
TestUtils.shutdownThreads(thread1, thread2);
-
+
resetExecMode(platform_old);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
index e73ce82..ec800b0 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
@@ -50,7 +50,7 @@ public class FederatedWeightedSigmoidTest extends
AutomatedTestBase {
private final static double TOLERANCE = 0;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -151,11 +151,11 @@ public class FederatedWeightedSigmoidTest extends
AutomatedTestBase {
writeInputMatrixWithMTD("X1",
X1,
false,
- new MatrixCharacteristics(fed_rows, fed_cols,
blocksize, fed_rows * fed_cols));
+ new MatrixCharacteristics(fed_rows, fed_cols,
BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("X2",
X2,
false,
- new MatrixCharacteristics(fed_rows, fed_cols,
blocksize, fed_rows * fed_cols));
+ new MatrixCharacteristics(fed_rows, fed_cols,
BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("U", U, true);
writeInputMatrixWithMTD("V", V, true);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
index 9b0f7a7..782891c 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
@@ -50,7 +50,7 @@ public class FederatedWeightedSquaredLossTest extends
AutomatedTestBase {
private final static double TOLERANCE = 1e-8;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -138,11 +138,11 @@ public class FederatedWeightedSquaredLossTest extends
AutomatedTestBase {
writeInputMatrixWithMTD("X1",
X1,
false,
- new MatrixCharacteristics(fed_rows, fed_cols,
blocksize, fed_rows * fed_cols));
+ new MatrixCharacteristics(fed_rows, fed_cols,
BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("X2",
X2,
false,
- new MatrixCharacteristics(fed_rows, fed_cols,
blocksize, fed_rows * fed_cols));
+ new MatrixCharacteristics(fed_rows, fed_cols,
BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("U", U, true);
writeInputMatrixWithMTD("V", V, true);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
index 581d27d..8cc582a 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
@@ -51,7 +51,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends
AutomatedTestBase
private final static double TOLERANCE = 0;
- private final static int blocksize = 1024;
+ private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
public int rows;
@@ -147,11 +147,11 @@ public class FederatedWeightedUnaryMatrixMultTest extends
AutomatedTestBase
double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
- writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
- writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
- writeInputMatrixWithMTD("U", U, false, new
MatrixCharacteristics(rows, rank, blocksize, rows * rank));
- writeInputMatrixWithMTD("V", V, false, new
MatrixCharacteristics(cols, rank, blocksize, rows * rank));
+ writeInputMatrixWithMTD("U", U, false, new
MatrixCharacteristics(rows, rank, BLOCKSIZE, rows * rank));
+ writeInputMatrixWithMTD("V", V, false, new
MatrixCharacteristics(cols, rank, BLOCKSIZE, rows * rank));
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
@@ -169,7 +169,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends
AutomatedTestBase
"in_U=" + input("U"), "in_V=" + input("V"),
"out_Z=" + expected(OUTPUT_NAME)};
runTest(true, false, null, -1);
-
+
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + test_name + ".dml";
programArgs = new String[] {"-stats", "-nvargs",
@@ -179,22 +179,22 @@ public class FederatedWeightedUnaryMatrixMultTest extends
AutomatedTestBase
"in_V=" + input("V"),
"rows=" + fed_rows, "cols=" + fed_cols,
"out_Z=" + output(OUTPUT_NAME)};
runTest(true, false, null, -1);
-
+
// compare the results via files
HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
TestUtils.compareMatrices(fedResults, refResults,
TOLERANCE, "Fed", "Ref");
-
+
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_wumm"));
-
+
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
finally {
TestUtils.shutdownThreads(thread1, thread2);
-
+
resetExecMode(platform_old);
}
}
diff --git a/src/test/scripts/functions/federated/FederatedPNMFTest.dml
b/src/test/scripts/functions/federated/FederatedPNMFTest.dml
new file mode 100644
index 0000000..e8b01c9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPNMFTest.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2,
$cols)));  # two row partitions: worker 1 holds rows [0, rows), worker 2 holds rows [rows, 2*rows)
+
+rank = $in_rank;  # factorization rank (in_rank passed by the Java test)
+max_iter = $in_max_iter;  # iteration count, forwarded to pnmf as maxi
+
+[W, H] = pnmf(X = X, rnk = rank, maxi = max_iter);  # factorize the federated X with the pnmf builtin
+
+Z = W %*% H;  # product of the factors, i.e. the rank-`rank` approximation of X
+
+write(Z, $out_Z);  # test output; presumably compared against the reference script's Z within TOLERANCE - confirm
diff --git
a/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml
b/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml
new file mode 100644
index 0000000..b501cf9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2));  # local baseline: the same two partitions stacked row-wise, no federation
+
+rank = $in_rank;  # factorization rank (in_rank passed by the Java test)
+max_iter = $in_max_iter;  # iteration count, forwarded to pnmf as maxi
+
+[W, H] = pnmf(X = X, rnk = rank, maxi = max_iter);  # same pnmf call as the federated script, on a local matrix
+
+Z = W %*% H;  # product of the factors, i.e. the rank-`rank` approximation of X
+
+write(Z, $out_Z);  # written to the test's expected/ directory as the reference result