This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new f487c18  [SYSTEMDS-2747] Federated PNMF test, extended quaternary operations
f487c18 is described below

commit f487c187af9d4202787620dfa34bddd603f93585
Author: ywcb00 <[email protected]>
AuthorDate: Sat Feb 20 17:14:57 2021 +0100

    [SYSTEMDS-2747] Federated PNMF test, extended quaternary operations
    
    Closes #1175.
---
 .../fed/QuaternaryWCeMMFEDInstruction.java         |  87 ++++++-------
 .../fed/QuaternaryWDivMMFEDInstruction.java        | 135 +++++++++++----------
 .../fed/QuaternaryWSLossFEDInstruction.java        |  81 +++++++------
 .../fed/QuaternaryWSigmoidFEDInstruction.java      |  46 +++----
 .../fed/QuaternaryWUMMFEDInstruction.java          |  42 ++++---
 .../apache/sysds/runtime/lineage/LineageItem.java  |   2 +-
 .../federated/algorithms/FederatedAlsCGTest.java   |   6 +-
 ...eratedAlsCGTest.java => FederatedPNMFTest.java} |  78 +++++-------
 .../FederatedWeightedCrossEntropyTest.java         |   6 +-
 .../FederatedWeightedDivMatrixMultTest.java        |  22 ++--
 .../primitives/FederatedWeightedSigmoidTest.java   |   6 +-
 .../FederatedWeightedSquaredLossTest.java          |   6 +-
 .../FederatedWeightedUnaryMatrixMultTest.java      |  20 +--
 .../functions/federated/FederatedPNMFTest.dml      |  32 +++++
 .../federated/FederatedPNMFTestReference.dml       |  31 +++++
 15 files changed, 331 insertions(+), 269 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
index 8566b39..68efe5d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -29,6 +29,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
@@ -59,58 +60,60 @@ public class QuaternaryWCeMMFEDInstruction extends 
QuaternaryFEDInstruction
                MatrixObject U = ec.getMatrixObject(input2);
                MatrixObject V = ec.getMatrixObject(input3);
                ScalarObject eps = null;
-               
+
                if(qop.hasFourInputs()) {
                        eps = (_input4.getDataType() == DataType.SCALAR) ?
                                ec.getScalarInput(_input4) :
                                new 
DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
                }
 
-               if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
-                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = ("
-                               +X.isFederated()+", "+U.isFederated()+", 
"+V.isFederated()+")");
-               
-               FederationMap fedMap = X.getFedMapping();
-               FederatedRequest[] fr1 = fedMap.broadcastSliced(U, false);
-               FederatedRequest fr2 = fedMap.broadcast(V);
-               FederatedRequest fr3 = null;
-               FederatedRequest frComp = null;
+               if(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated()) {
+                       FederationMap fedMap = X.getFedMapping();
+                       FederatedRequest[] fr1 = fedMap.broadcastSliced(U, 
false);
+                       FederatedRequest fr2 = fedMap.broadcast(V);
+                       FederatedRequest fr3 = null;
+                       FederatedRequest frComp = null;
 
-               // broadcast the scalar epsilon if there are four inputs
-               if(eps != null) {
-                       fr3 = fedMap.broadcast(eps);
-                       // change the is_literal flag from true to false 
because when broadcasted it is no literal anymore
-                       instString = instString.replace("true", "false");
-                       frComp = FederationUtils.callInstruction(instString, 
output,
-                               new CPOperand[]{input1, input2, input3, 
_input4},
-                               new long[]{fedMap.getID(), fr1[0].getID(), 
fr2.getID(), fr3.getID()});
-               }
-               else {
-                       frComp = FederationUtils.callInstruction(instString, 
output,
-                       new CPOperand[]{input1, input2, input3},
-                       new long[]{fedMap.getID(), fr1[0].getID(), 
fr2.getID()});
-               }
-               
-               FederatedRequest frGet = new 
FederatedRequest(RequestType.GET_VAR, frComp.getID());
-               FederatedRequest frClean1 = fedMap.cleanup(getTID(), 
frComp.getID());
-               FederatedRequest frClean2 = fedMap.cleanup(getTID(), 
fr1[0].getID());
-               FederatedRequest frClean3 = fedMap.cleanup(getTID(), 
fr2.getID());
+                       // broadcast the scalar epsilon if there are four inputs
+                       if(eps != null) {
+                               fr3 = fedMap.broadcast(eps);
+                               // change the is_literal flag from true to 
false because when broadcasted it is no literal anymore
+                               instString = instString.replace("true", 
"false");
+                               frComp = 
FederationUtils.callInstruction(instString, output,
+                                       new CPOperand[]{input1, input2, input3, 
_input4},
+                                       new long[]{fedMap.getID(), 
fr1[0].getID(), fr2.getID(), fr3.getID()});
+                       }
+                       else {
+                               frComp = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2, input3},
+                               new long[]{fedMap.getID(), fr1[0].getID(), 
fr2.getID()});
+                       }
+
+                       FederatedRequest frGet = new 
FederatedRequest(RequestType.GET_VAR, frComp.getID());
+                       FederatedRequest frClean1 = fedMap.cleanup(getTID(), 
frComp.getID());
+                       FederatedRequest frClean2 = fedMap.cleanup(getTID(), 
fr1[0].getID());
+                       FederatedRequest frClean3 = fedMap.cleanup(getTID(), 
fr2.getID());
 
-               Future<FederatedResponse>[] response;
-               if(fr3 != null) {
-                       FederatedRequest frClean4 = fedMap.cleanup(getTID(), 
fr3.getID());
-                       // execute federated instructions
-                       response = fedMap.execute(getTID(), true, fr1, fr2, fr3,
-                               frComp, frGet, frClean1, frClean2, frClean3, 
frClean4);
+                       Future<FederatedResponse>[] response;
+                       if(fr3 != null) {
+                               FederatedRequest frClean4 = 
fedMap.cleanup(getTID(), fr3.getID());
+                               // execute federated instructions
+                               response = fedMap.execute(getTID(), true, fr1, 
fr2, fr3,
+                                       frComp, frGet, frClean1, frClean2, 
frClean3, frClean4);
+                       }
+                       else {
+                               // execute federated instructions
+                               response = fedMap.execute(getTID(), true, fr1, 
fr2,
+                                       frComp, frGet, frClean1, frClean2, 
frClean3);
+                       }
+
+                       //aggregate partial results from federated responses
+                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                       ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, response));
                }
                else {
-                       // execute federated instructions
-                       response = fedMap.execute(getTID(), true, fr1, fr2,
-                               frComp, frGet, frClean1, frClean2, frClean3);
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = ("
+                               + X.isFederated() + ", " + U.isFederated() + ", 
" + V.isFederated() + ")");
                }
-               
-               //aggregate partial results from federated responses
-               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-               ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, 
response));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
index 5ba2b59..877b9c5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
@@ -86,79 +86,82 @@ public class QuaternaryWDivMMFEDInstruction extends 
QuaternaryFEDInstruction
                        }
                }
 
-               if(!(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated()))
-                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = ("
-                               +X.isFederated()+", "+U.isFederated()+", 
"+V.isFederated() + ")");
-
-               FederationMap fedMap = X.getFedMapping();
-               FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
-               FederatedRequest frInit2 = fedMap.broadcast(V);
+               if(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated()) {
+                       FederationMap fedMap = X.getFedMapping();
+                       FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, 
false);
+                       FederatedRequest frInit2 = fedMap.broadcast(V);
 
-               FederatedRequest frInit3 = null;
-               FederatedRequest frInit3Arr[] = null;
-               FederatedRequest frCompute1 = null;
-               // broadcast scalar epsilon if there are four inputs
-               if(eps != null) {
-                       frInit3 = fedMap.broadcast(eps);
-                       // change the is_literal flag from true to false 
because when broadcasted it is no literal anymore
-                       instString = instString.replace("true", "false");
-                       frCompute1 = 
FederationUtils.callInstruction(instString, output,
-                               new CPOperand[]{input1, input2, input3, 
_input4},
-                               new long[]{fedMap.getID(), frInit1[0].getID(), 
frInit2.getID(), frInit3.getID()});
-               }
-               else if(MX != null) {
-                       frInit3Arr = fedMap.broadcastSliced(MX, false);
-                       frCompute1 = 
FederationUtils.callInstruction(instString, output,
-                               new CPOperand[]{input1, input2, input3, 
_input4},
-                               new long[]{fedMap.getID(), frInit1[0].getID(), 
frInit2.getID(), frInit3Arr[0].getID()});
-               }
-               else {
-                       frCompute1 = 
FederationUtils.callInstruction(instString, output,
-                               new CPOperand[]{input1, input2, input3},
-                               new long[]{fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
-               }
+                       FederatedRequest frInit3 = null;
+                       FederatedRequest frInit3Arr[] = null;
+                       FederatedRequest frCompute1 = null;
+                       // broadcast scalar epsilon if there are four inputs
+                       if(eps != null) {
+                               frInit3 = fedMap.broadcast(eps);
+                               // change the is_literal flag from true to 
false because when broadcasted it is no literal anymore
+                               instString = instString.replace("true", 
"false");
+                               frCompute1 = 
FederationUtils.callInstruction(instString, output,
+                                       new CPOperand[]{input1, input2, input3, 
_input4},
+                                       new long[]{fedMap.getID(), 
frInit1[0].getID(), frInit2.getID(), frInit3.getID()});
+                       }
+                       else if(MX != null) {
+                               frInit3Arr = fedMap.broadcastSliced(MX, false);
+                               frCompute1 = 
FederationUtils.callInstruction(instString, output,
+                                       new CPOperand[]{input1, input2, input3, 
_input4},
+                                       new long[]{fedMap.getID(), 
frInit1[0].getID(), frInit2.getID(), frInit3Arr[0].getID()});
+                       }
+                       else {
+                               frCompute1 = 
FederationUtils.callInstruction(instString, output,
+                                       new CPOperand[]{input1, input2, input3},
+                                       new long[]{fedMap.getID(), 
frInit1[0].getID(), frInit2.getID()});
+                       }
 
-               // get partial results from federated workers
-               FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+                       // get partial results from federated workers
+                       FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
 
-               FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
-               FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
-               FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
+                       FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
+                       FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
+                       FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
 
-               // execute federated instructions
-               Future<FederatedResponse>[] response;
-               if(frInit3 != null) {
-                       FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), 
frInit3.getID());
-                       response = fedMap.execute(getTID(), true,
-                               frInit1, frInit2, frInit3,
-                               frCompute1, frGet1,
-                               frCleanup1, frCleanup2, frCleanup3, frCleanup4);
-               }
-               else if(frInit3Arr != null) {
-                       FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), 
frInit3Arr[0].getID());
-                       fedMap.execute(getTID(), true, frInit1, frInit2);
-                       response = fedMap.execute(getTID(), true, frInit3Arr,
-                               frCompute1, frGet1,
-                               frCleanup1, frCleanup2, frCleanup3, frCleanup4);
-               }
-               else {
-                       response = fedMap.execute(getTID(), true,
-                               frInit1, frInit2,
-                               frCompute1, frGet1,
-                               frCleanup1, frCleanup2, frCleanup3);
-               }
+                       // execute federated instructions
+                       Future<FederatedResponse>[] response;
+                       if(frInit3 != null) {
+                               FederatedRequest frCleanup4 = 
fedMap.cleanup(getTID(), frInit3.getID());
+                               response = fedMap.execute(getTID(), true,
+                                       frInit1, frInit2, frInit3,
+                                       frCompute1, frGet1,
+                                       frCleanup1, frCleanup2, frCleanup3, 
frCleanup4);
+                       }
+                       else if(frInit3Arr != null) {
+                               FederatedRequest frCleanup4 = 
fedMap.cleanup(getTID(), frInit3Arr[0].getID());
+                               fedMap.execute(getTID(), true, frInit1, 
frInit2);
+                               response = fedMap.execute(getTID(), true, 
frInit3Arr,
+                                       frCompute1, frGet1,
+                                       frCleanup1, frCleanup2, frCleanup3, 
frCleanup4);
+                       }
+                       else {
+                               response = fedMap.execute(getTID(), true,
+                                       frInit1, frInit2,
+                                       frCompute1, frGet1,
+                                       frCleanup1, frCleanup2, frCleanup3);
+                       }
 
-               if(wdivmm_type.isLeft()) {
-                       // aggregate partial results from federated responses
-                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-                       ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-               }
-               else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) {
-                       // bind partial results from federated responses
-                       ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+                       if(wdivmm_type.isLeft()) {
+                               // aggregate partial results from federated 
responses
+                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                               ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                       }
+                       else if(wdivmm_type.isRight() || wdivmm_type.isBasic()) 
{
+                               // bind partial results from federated responses
+                               ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+                       }
+                       else {
+                               throw new DMLRuntimeException("Federated WDivMM 
only supported for BASIC, LEFT or RIGHT variants.");
+                       }
                }
                else {
-                       throw new DMLRuntimeException("Federated WDivMM only 
supported for BASIC, LEFT or RIGHT variants.");
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = ("
+                               + X.isFederated() + ", " + U.isFederated() + ", 
" + V.isFederated() + ")");
                }
        }
 }
+
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
index 664fbdc..f65d4f0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSLossFEDInstruction.java
@@ -25,6 +25,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -69,51 +70,53 @@ public class QuaternaryWSLossFEDInstruction extends 
QuaternaryFEDInstruction {
                        W = ec.getMatrixObject(_input4);
                }
 
-               if(!(X.isFederated() && !U.isFederated() && !V.isFederated() && 
(W == null || !W.isFederated())))
-                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V, W) = (" + X.isFederated() + ", "
-                               + U.isFederated() + ", " + V.isFederated() + ", 
" + (W != null ? W.isFederated() : "none") + ")");
+               if(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated() && (W == null || !W.isFederated())) {
+                       FederationMap fedMap = X.getFedMapping();
+                       FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, 
false);
+                       FederatedRequest frInit2 = fedMap.broadcast(V);
 
-               FederationMap fedMap = X.getFedMapping();
-               FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
-               FederatedRequest frInit2 = fedMap.broadcast(V);
+                       FederatedRequest[] frInit3 = null;
+                       FederatedRequest frCompute1 = null;
+                       if(W != null) {
+                               frInit3 = fedMap.broadcastSliced(W, false);
+                               frCompute1 = 
FederationUtils.callInstruction(instString,
+                                       output,
+                                       new CPOperand[] {input1, input2, 
input3, _input4},
+                                       new long[] {fedMap.getID(), 
frInit1[0].getID(), frInit2.getID(), frInit3[0].getID()});
+                       }
+                       else {
+                               frCompute1 = 
FederationUtils.callInstruction(instString,
+                                       output,
+                                       new CPOperand[] {input1, input2, 
input3},
+                                       new long[] {fedMap.getID(), 
frInit1[0].getID(), frInit2.getID()});
+                       }
 
-               FederatedRequest[] frInit3 = null;
-               FederatedRequest frCompute1 = null;
-               if(W != null) {
-                       frInit3 = fedMap.broadcastSliced(W, false);
-                       frCompute1 = FederationUtils.callInstruction(instString,
-                               output,
-                               new CPOperand[] {input1, input2, input3, 
_input4},
-                               new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID(), frInit3[0].getID()});
-               }
-               else {
-                       frCompute1 = FederationUtils.callInstruction(instString,
-                               output,
-                               new CPOperand[] {input1, input2, input3},
-                               new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
-               }
+                       FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+                       FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
+                       FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
+                       FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
 
-               FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
-               FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
-               FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
-               FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
+                       Future<FederatedResponse>[] response;
+                       if(frInit3 != null) {
+                               FederatedRequest frCleanup4 = 
fedMap.cleanup(getTID(), frInit3[0].getID());
+                               // execute federated instructions
+                               fedMap.execute(getTID(), true, frInit1, 
frInit2);
+                               response = fedMap
+                                       .execute(getTID(), true, frInit3, 
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
+                       }
+                       else {
+                               // execute federated instructions
+                               response = fedMap
+                                       .execute(getTID(), true, frInit1, 
frInit2, frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
+                       }
 
-               Future<FederatedResponse>[] response;
-               if(frInit3 != null) {
-                       FederatedRequest frCleanup4 = fedMap.cleanup(getTID(), 
frInit3[0].getID());
-                       // execute federated instructions
-                       fedMap.execute(getTID(), true, frInit1, frInit2);
-                       response = fedMap
-                               .execute(getTID(), true, frInit3, frCompute1, 
frGet1, frCleanup1, frCleanup2, frCleanup3, frCleanup4);
+                       // aggregate partial results from federated responses
+                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                       ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, response));
                }
                else {
-                       // execute federated instructions
-                       response = fedMap
-                               .execute(getTID(), true, frInit1, frInit2, 
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V, W) = (" + X.isFederated() + ", "
+                               + U.isFederated() + ", " + V.isFederated() + ", 
" + (W != null ? W.isFederated() : "none") + ")");
                }
-
-               // aggregate partial results from federated responses
-               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-               ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, 
response));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
index 9884e3b..95caaef 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWSigmoidFEDInstruction.java
@@ -28,6 +28,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -58,32 +59,33 @@ public class QuaternaryWSigmoidFEDInstruction extends 
QuaternaryFEDInstruction {
                MatrixObject U = ec.getMatrixObject(input2);
                MatrixObject V = ec.getMatrixObject(input3);
 
-               if(!(X.isFederated() && !U.isFederated() && !V.isFederated()))
-                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = (" + X.isFederated() + ", "
-                               + U.isFederated() + ", " + V.isFederated() + 
")");
+               if(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated()) {
+                       FederationMap fedMap = X.getFedMapping();
+                       FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, 
false);
+                       FederatedRequest frInit2 = fedMap.broadcast(V);
 
-               FederationMap fedMap = X.getFedMapping();
-               FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
-               FederatedRequest frInit2 = fedMap.broadcast(V);
+                       FederatedRequest frCompute1 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {input1, input2, input3},
+                               new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
 
-               FederatedRequest frCompute1 = 
FederationUtils.callInstruction(instString,
-                       output,
-                       new CPOperand[] {input1, input2, input3},
-                       new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
+                       // get partial results from federated workers
+                       FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
 
-               // get partial results from federated workers
-               FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+                       FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
+                       FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
+                       FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
 
-               FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
-               FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
-               FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
-
-               // execute federated instructions
-               Future<FederatedResponse>[] response = fedMap
-                       .execute(getTID(), true, frInit1, frInit2, frCompute1, 
frGet1, frCleanup1, frCleanup2, frCleanup3);
-
-               // bind partial results from federated responses
-               ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+                       // execute federated instructions
+                       Future<FederatedResponse>[] response = fedMap
+                               .execute(getTID(), true, frInit1, frInit2, 
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
 
+                       // bind partial results from federated responses
+                       ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+               }
+               else {
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = (" 
+                               + X.isFederated() + ", " + U.isFederated() + ", 
" + V.isFederated() + ")");
+               }
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
index 82bc9e2..2512439 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.java
@@ -60,30 +60,32 @@ public class QuaternaryWUMMFEDInstruction extends 
QuaternaryFEDInstruction {
                MatrixObject U = ec.getMatrixObject(input2);
                MatrixObject V = ec.getMatrixObject(input3);
 
-               if(!(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated()))
-                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = (" + X.isFederated() + ", "
-                               + U.isFederated() + ", " + V.isFederated() + 
")");
+               if(X.isFederated(FType.ROW) && !U.isFederated() && 
!V.isFederated()) {
+                       FederationMap fedMap = X.getFedMapping();
+                       FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, 
false);
+                       FederatedRequest frInit2 = fedMap.broadcast(V);
 
-               FederationMap fedMap = X.getFedMapping();
-               FederatedRequest[] frInit1 = fedMap.broadcastSliced(U, false);
-               FederatedRequest frInit2 = fedMap.broadcast(V);
+                       FederatedRequest frCompute1 = 
FederationUtils.callInstruction(instString,
+                               output, new CPOperand[] {input1, input2, 
input3},
+                               new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
 
-               FederatedRequest frCompute1 = 
FederationUtils.callInstruction(instString,
-                       output, new CPOperand[] {input1, input2, input3},
-                       new long[] {fedMap.getID(), frInit1[0].getID(), 
frInit2.getID()});
+                       // get partial results from federated workers
+                       FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
 
-               // get partial results from federated workers
-               FederatedRequest frGet1 = new 
FederatedRequest(RequestType.GET_VAR, frCompute1.getID());
+                       FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
+                       FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
+                       FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
 
-               FederatedRequest frCleanup1 = fedMap.cleanup(getTID(), 
frCompute1.getID());
-               FederatedRequest frCleanup2 = fedMap.cleanup(getTID(), 
frInit1[0].getID());
-               FederatedRequest frCleanup3 = fedMap.cleanup(getTID(), 
frInit2.getID());
+                       // execute federated instructions
+                       Future<FederatedResponse>[] response = fedMap
+                               .execute(getTID(), true, frInit1, frInit2, 
frCompute1, frGet1, frCleanup1, frCleanup2, frCleanup3);
 
-               // execute federated instructions
-               Future<FederatedResponse>[] response = fedMap
-                       .execute(getTID(), true, frInit1, frInit2, frCompute1, 
frGet1, frCleanup1, frCleanup2, frCleanup3);
-
-               // bind partial results from federated responses
-               ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+                       // bind partial results from federated responses
+                       ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+               }
+               else {
+                       throw new DMLRuntimeException("Unsupported federated 
inputs (X, U, V) = (" 
+                               + X.isFederated() + ", " + U.isFederated() + ", 
" + V.isFederated() + ")");
+               }
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index 2210fff..cd2346d 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -354,7 +354,7 @@ public class LineageItem {
        
        // Compare a dedup patch with a sub-DAG, and map the inputs of the 
sub-dag
        // to the placeholder inputs of the dedup patch
-       private boolean equalsDedupPatch(LineageItem dli1, LineageItem dli2, 
Map<Integer, LineageItem> phMap) {
+       private static boolean equalsDedupPatch(LineageItem dli1, LineageItem 
dli2, Map<Integer, LineageItem> phMap) {
                Stack<LineageItem> s1 = new Stack<>();
                Stack<LineageItem> s2 = new Stack<>();
                s1.push(dli1);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
index 4909f7c..9263beb 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
@@ -46,7 +46,7 @@ public class FederatedAlsCGTest extends AutomatedTestBase
 
        private final static String OUTPUT_NAME = "Z";
        private final static double TOLERANCE = 0.01;
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -112,9 +112,9 @@ public class FederatedAlsCGTest extends AutomatedTestBase
                double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 
sparsity, 2);
 
                writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(
-                       fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+                       fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
                writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(
-                       fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+                       fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
similarity index 69%
copy from 
src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
copy to 
src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
index 4909f7c..00358c7 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedAlsCGTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPNMFTest.java
@@ -1,18 +1,18 @@
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.     See the NOTICE file
+ * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
- * regarding copyright ownership.      The ASF licenses this file
+ * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
- * with the License.   You may obtain a copy of the License at
+ * with the License.  You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *   http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.    See the License for the
+ * KIND, either express or implied.  See the License for the
  * specific language governing permissions and limitations
  * under the License.
  */
@@ -38,15 +38,15 @@ import java.util.HashMap;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-public class FederatedAlsCGTest extends AutomatedTestBase
+public class FederatedPNMFTest extends AutomatedTestBase
 {
-       private final static String TEST_NAME = "FederatedAlsCGTest";
+       private final static String TEST_NAME = "FederatedPNMFTest";
        private final static String TEST_DIR = "functions/federated/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedAlsCGTest.class.getSimpleName() + "/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedPNMFTest.class.getSimpleName() + "/";
 
        private final static String OUTPUT_NAME = "Z";
-       private final static double TOLERANCE = 0.01;
-       private final static int blocksize = 1024;
+       private final static double TOLERANCE = 0.2;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -55,14 +55,8 @@ public class FederatedAlsCGTest extends AutomatedTestBase
        @Parameterized.Parameter(2)
        public int rank;
        @Parameterized.Parameter(3)
-       public String regression;
-       @Parameterized.Parameter(4)
-       public double lambda;
-       @Parameterized.Parameter(5)
        public int max_iter;
-       @Parameterized.Parameter(6)
-       public double threshold;
-       @Parameterized.Parameter(7)
+       @Parameterized.Parameter(4)
        public double sparsity;
 
        @Override
@@ -74,9 +68,8 @@ public class FederatedAlsCGTest extends AutomatedTestBase
        public static Collection<Object[]> data() {
                // rows must be even
                return Arrays.asList(new Object[][] {
-                       // {rows, cols, rank, regression, lambda, max_iter, 
threshold, sparsity}
-                       {30, 15, 10, "L2", 0.0000001, 50, 0.000001, 1},
-                       {30, 15, 10, "wL2", 0.0000001, 50, 0.000001, 1}
+                       // {rows, cols, rank, max_iter, sparsity}
+                       {1000, 750, 420, 10, 1}
                });
        }
 
@@ -86,35 +79,35 @@ public class FederatedAlsCGTest extends AutomatedTestBase
        }
 
        @Test
-       public void federatedAlsCGSingleNode() {
-               federatedAlsCG(TEST_NAME, ExecMode.SINGLE_NODE);
+       public void federatedPNMFSingleNode() {
+               federatedPNMF(ExecMode.SINGLE_NODE);
        }
 
-//     @Test
-//     public void federatedAlsCGSpark() {
-//             federatedAlsCG(TEST_NAME, ExecMode.SPARK);
-//     }
+       @Test
+       public void federatedPNMFSpark() {
+               federatedPNMF(ExecMode.SPARK);
+       }
 
 // -----------------------------------------------------------------------------
 
-       public void federatedAlsCG(String testname, ExecMode execMode)
+       public void federatedPNMF(ExecMode execMode)
        {
                // store the previous platform config to restore it after the 
test
                ExecMode platform_old = setExecMode(execMode);
 
-               getAndLoadTestConfiguration(testname);
+               getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
                int fed_rows = rows / 2;
                int fed_cols = cols;
 
+               // generate dataset
+               // matrix handled by two federated workers
                double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 
sparsity, 13);
                double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 1, 2, 
sparsity, 2);
 
-               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(
-                       fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
-               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(
-                       fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
@@ -123,27 +116,22 @@ public class FederatedAlsCGTest extends AutomatedTestBase
                Thread thread1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
                Thread thread2 = startLocalFedWorkerThread(port2);
 
-               getAndLoadTestConfiguration(testname);
+               getAndLoadTestConfiguration(TEST_NAME);
 
                // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + testname + "Reference.dml";
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                programArgs = new String[] {"-stats", "-nvargs",
-                       "in_X1=" + input("X1"), "in_X2=" + input("X2"), 
"in_rank=" + Integer.toString(rank),
-                       "in_reg=" + regression, "in_lambda=" + 
Double.toString(lambda),
-                       "in_maxi=" + Integer.toString(max_iter), "in_thr=" + 
Double.toString(threshold),
+                       "in_X1=" + input("X1"), "in_X2=" + input("X2"), 
"in_rank=" + Integer.toString(rank), "in_max_iter=" + 
Integer.toString(max_iter),
                        "out_Z=" + expected(OUTPUT_NAME)};
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
-               fullDMLScriptName = HOME + testname + ".dml";
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-stats", "-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
                        "in_rank=" + Integer.toString(rank),
-                       "in_reg=" + regression,
-                       "in_lambda=" + Double.toString(lambda),
-                       "in_maxi=" + Integer.toString(max_iter),
-                       "in_thr=" + Double.toString(threshold),
+                       "in_max_iter=" + Integer.toString(max_iter),
                        "rows=" + fed_rows, "cols=" + fed_cols,
                        "out_Z=" + output(OUTPUT_NAME)};
                runTest(true, false, null, -1);
@@ -156,15 +144,13 @@ public class FederatedAlsCGTest extends AutomatedTestBase
                TestUtils.shutdownThreads(thread1, thread2);
 
                // check for federated operations
-               Assert.assertTrue(heavyHittersContainsString("fed_!="));
-               Assert.assertTrue(heavyHittersContainsString("fed_fedinit"));
+               Assert.assertTrue(heavyHittersContainsString("fed_wcemm"));
                Assert.assertTrue(heavyHittersContainsString("fed_wdivmm"));
-               Assert.assertTrue(heavyHittersContainsString("fed_wsloss"));
+               Assert.assertTrue(heavyHittersContainsString("fed_fedinit"));
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-
                resetExecMode(platform_old);
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
index bf676a3..655124d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedCrossEntropyTest.java
@@ -47,7 +47,7 @@ public class FederatedWeightedCrossEntropyTest extends 
AutomatedTestBase
 
        private final static String OUTPUT_NAME = "Z";
        private final static double TOLERANCE = 1e-9;
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -124,8 +124,8 @@ public class FederatedWeightedCrossEntropyTest extends 
AutomatedTestBase
                double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
                double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
 
-               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
-               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
 
                writeInputMatrixWithMTD("U", U, true);
                writeInputMatrixWithMTD("V", V, true);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
index 39a79bb..15c192b 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedDivMatrixMultTest.java
@@ -60,7 +60,7 @@ public class FederatedWeightedDivMatrixMultTest extends 
AutomatedTestBase
 
        private final static double TOLERANCE = 1e-9;
 
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -256,11 +256,11 @@ public class FederatedWeightedDivMatrixMultTest extends 
AutomatedTestBase
                double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
                double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
 
-               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
-               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
 
-               writeInputMatrixWithMTD("U", U, true, new 
MatrixCharacteristics(rows, rank, blocksize, rows * rank));
-               writeInputMatrixWithMTD("V", V, true, new 
MatrixCharacteristics(cols, rank, blocksize, rows * rank));
+               writeInputMatrixWithMTD("U", U, true, new 
MatrixCharacteristics(rows, rank, BLOCKSIZE, rows * rank));
+               writeInputMatrixWithMTD("V", V, true, new 
MatrixCharacteristics(cols, rank, BLOCKSIZE, cols * rank)); // V is cols x rank, so nnz is bounded by cols * rank (was rows * rank)
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
@@ -270,7 +270,7 @@ public class FederatedWeightedDivMatrixMultTest extends 
AutomatedTestBase
                Thread thread2 = startLocalFedWorkerThread(port2);
 
                getAndLoadTestConfiguration(test_name);
-               
+
                try {
                        // Run reference dml script with normal matrix
                        fullDMLScriptName = HOME + test_name + "Reference.dml";
@@ -278,7 +278,7 @@ public class FederatedWeightedDivMatrixMultTest extends 
AutomatedTestBase
                                "in_U=" + input("U"), "in_V=" + input("V"), 
"in_W=" + Double.toString(epsilon),
                                "out_Z=" + expected(OUTPUT_NAME)};
                        runTest(true, false, null, -1);
-       
+
                        // Run actual dml script with federated matrix
                        fullDMLScriptName = HOME + test_name + ".dml";
                        programArgs = new String[] {"-stats", "-nvargs",
@@ -289,22 +289,22 @@ public class FederatedWeightedDivMatrixMultTest extends 
AutomatedTestBase
                                "in_W=" + Double.toString(epsilon),
                                "rows=" + fed_rows, "cols=" + fed_cols, 
"out_Z=" + output(OUTPUT_NAME)};
                        runTest(true, false, null, -1);
-       
+
                        // compare the results via files
                        HashMap<CellIndex, Double> refResults   = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
                        HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
                        TestUtils.compareMatrices(fedResults, refResults, 
TOLERANCE, "Fed", "Ref");
-                       
+
                        // check for federated operations
                        
Assert.assertTrue(heavyHittersContainsString("fed_wdivmm"));
-       
+
                        // check that federated input files are still existing
                        
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
                        
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
                }
                finally {
                        TestUtils.shutdownThreads(thread1, thread2);
-                       
+
                        resetExecMode(platform_old);
                }
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
index e73ce82..ec800b0 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSigmoidTest.java
@@ -50,7 +50,7 @@ public class FederatedWeightedSigmoidTest extends 
AutomatedTestBase {
 
        private final static double TOLERANCE = 0;
 
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -151,11 +151,11 @@ public class FederatedWeightedSigmoidTest extends 
AutomatedTestBase {
                writeInputMatrixWithMTD("X1",
                        X1,
                        false,
-                       new MatrixCharacteristics(fed_rows, fed_cols, 
blocksize, fed_rows * fed_cols));
+                       new MatrixCharacteristics(fed_rows, fed_cols, 
BLOCKSIZE, fed_rows * fed_cols));
                writeInputMatrixWithMTD("X2",
                        X2,
                        false,
-                       new MatrixCharacteristics(fed_rows, fed_cols, 
blocksize, fed_rows * fed_cols));
+                       new MatrixCharacteristics(fed_rows, fed_cols, 
BLOCKSIZE, fed_rows * fed_cols));
 
                writeInputMatrixWithMTD("U", U, true);
                writeInputMatrixWithMTD("V", V, true);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
index 9b0f7a7..782891c 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedSquaredLossTest.java
@@ -50,7 +50,7 @@ public class FederatedWeightedSquaredLossTest extends 
AutomatedTestBase {
 
        private final static double TOLERANCE = 1e-8;
 
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -138,11 +138,11 @@ public class FederatedWeightedSquaredLossTest extends 
AutomatedTestBase {
                writeInputMatrixWithMTD("X1",
                        X1,
                        false,
-                       new MatrixCharacteristics(fed_rows, fed_cols, 
blocksize, fed_rows * fed_cols));
+                       new MatrixCharacteristics(fed_rows, fed_cols, 
BLOCKSIZE, fed_rows * fed_cols));
                writeInputMatrixWithMTD("X2",
                        X2,
                        false,
-                       new MatrixCharacteristics(fed_rows, fed_cols, 
blocksize, fed_rows * fed_cols));
+                       new MatrixCharacteristics(fed_rows, fed_cols, 
BLOCKSIZE, fed_rows * fed_cols));
 
                writeInputMatrixWithMTD("U", U, true);
                writeInputMatrixWithMTD("V", V, true);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
index 581d27d..8cc582a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedWeightedUnaryMatrixMultTest.java
@@ -51,7 +51,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends 
AutomatedTestBase
 
        private final static double TOLERANCE = 0;
 
-       private final static int blocksize = 1024;
+       private final static int BLOCKSIZE = 1024;
 
        @Parameterized.Parameter()
        public int rows;
@@ -147,11 +147,11 @@ public class FederatedWeightedUnaryMatrixMultTest extends 
AutomatedTestBase
                double[][] U = getRandomMatrix(rows, rank, 0, 1, 1, 512);
                double[][] V = getRandomMatrix(cols, rank, 0, 1, 1, 5040);
 
-               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
-               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, blocksize, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
 
-               writeInputMatrixWithMTD("U", U, false, new 
MatrixCharacteristics(rows, rank, blocksize, rows * rank));
-               writeInputMatrixWithMTD("V", V, false, new 
MatrixCharacteristics(cols, rank, blocksize, rows * rank));
+               writeInputMatrixWithMTD("U", U, false, new 
MatrixCharacteristics(rows, rank, BLOCKSIZE, rows * rank));
+               writeInputMatrixWithMTD("V", V, false, new 
MatrixCharacteristics(cols, rank, BLOCKSIZE, cols * rank)); // V is cols x rank, so nnz is bounded by cols * rank (was rows * rank)
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
@@ -169,7 +169,7 @@ public class FederatedWeightedUnaryMatrixMultTest extends 
AutomatedTestBase
                                "in_U=" + input("U"), "in_V=" + input("V"),
                                "out_Z=" + expected(OUTPUT_NAME)};
                        runTest(true, false, null, -1);
-       
+
                        // Run actual dml script with federated matrix
                        fullDMLScriptName = HOME + test_name + ".dml";
                        programArgs = new String[] {"-stats", "-nvargs",
@@ -179,22 +179,22 @@ public class FederatedWeightedUnaryMatrixMultTest extends 
AutomatedTestBase
                                "in_V=" + input("V"),
                                "rows=" + fed_rows, "cols=" + fed_cols, 
"out_Z=" + output(OUTPUT_NAME)};
                        runTest(true, false, null, -1);
-       
+
                        // compare the results via files
                        HashMap<CellIndex, Double> refResults   = 
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
                        HashMap<CellIndex, Double> fedResults = 
readDMLMatrixFromOutputDir(OUTPUT_NAME);
                        TestUtils.compareMatrices(fedResults, refResults, 
TOLERANCE, "Fed", "Ref");
-                       
+
                        // check for federated operations
                        
Assert.assertTrue(heavyHittersContainsString("fed_wumm"));
-       
+
                        // check that federated input files are still existing
                        
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
                        
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
                }
                finally {
                        TestUtils.shutdownThreads(thread1, thread2);
-                       
+
                        resetExecMode(platform_old);
                }
        }
diff --git a/src/test/scripts/functions/federated/FederatedPNMFTest.dml 
b/src/test/scripts/functions/federated/FederatedPNMFTest.dml
new file mode 100644
index 0000000..e8b01c9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPNMFTest.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+  ranges=list(list(0, 0), list($rows, $cols), list($rows, 0), list($rows * 2, $cols)));
+
+rank = $in_rank;
+max_iter = $in_max_iter;
+
+[W, H] = pnmf(X = X, rnk = rank, maxi = max_iter);
+
+Z = W %*% H;
+
+write(Z, $out_Z);
diff --git 
a/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml 
b/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml
new file mode 100644
index 0000000..b501cf9
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPNMFTestReference.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($in_X1), read($in_X2));
+
+rank = $in_rank;
+max_iter = $in_max_iter;
+
+[W, H] = pnmf(X = X, rnk = rank, maxi = max_iter);
+
+Z = W %*% H;
+
+write(Z, $out_Z);

Reply via email to