This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new fd178aa  [SYSTEMDS-2738] Federated rdiag, rev and uppertri instructions
fd178aa is described below

commit fd178aa36f02daea0372e4aab6b8f77711bc7a98
Author: Olga <ovcharenko.fo...@gmail.com>
AuthorDate: Fri Nov 20 14:55:34 2020 +0100

    [SYSTEMDS-2738] Federated rdiag, rev and uppertri instructions
    
    Federated support for diagonal (rdiag), reverse (rev), and lower/upper triangle (lowertri/uppertri).
    
    This commit also adds the following TODOs:
    
    - Add a sorting method to the federation map that sorts the federated instances.
    - The federated triangle extraction uses slice and allocates many matrices; optimize
    this by leveraging the underlying dense and sparse blocks.
    
    Closes #1112
---
 .../controlprogram/federated/FederationMap.java    |  17 ++
 .../instructions/fed/FEDInstructionUtils.java      |   7 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 108 +++++++-
 .../instructions/fed/ReorgFEDInstruction.java      | 272 +++++++++++++++++++--
 .../federated/primitives/FederatedRdiagTest.java   | 145 +++++++++++
 .../federated/primitives/FederatedRevTest.java     | 160 ++++++++++++
 .../federated/primitives/FederatedTriTest.java     | 149 +++++++++++
 .../federated/FederatedLmPipelineReference.dml     |   2 +-
 .../functions/federated/FederatedRdiagTest.dml     |  27 ++
 .../federated/FederatedRdiagTestReference.dml      |  25 ++
 .../functions/federated/FederatedRevTest.dml       |  32 +++
 .../federated/FederatedRevTestReference.dml        |  26 ++
 .../functions/federated/FederatedTriTest.dml       |  33 +++
 .../federated/FederatedTriTestReference.dml        |  26 ++
 14 files changed, 1006 insertions(+), 23 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 482ade7..e933979 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -391,6 +391,23 @@ public class FederationMap {
                                Arrays.stream(frset).forEach(fr -> 
fr.setTID(tid));
        }
 
+       /**
+        * Mirrors all federated ranges of this map along the row dimension: a partition that
+        * covered rows [b, e) afterwards covers rows [n-e, n-b), with n = getMaxIndexInRange(0).
+        * This is the row mapping of a reverse (rev) whose workers each reversed their local rows.
+        *
+        * Mirroring the ranges, instead of swapping the data objects between ranges, is correct
+        * for unevenly sized partitions and does not depend on the order in which
+        * {@link #getFederatedRanges()} returns the ranges. The mirrored ranges are new objects,
+        * so range instances that may be shared with other maps are left untouched. Only the
+        * row dimension (dimension 0) is changed.
+        */
+       public void reverseFedMap() {
+               final long nrows = getMaxIndexInRange(0);
+               FederatedRange[] fedRanges = getFederatedRanges();
+
+               // detach all entries first, so that a mirrored range cannot overwrite an entry that
+               // has not been processed yet (e.g., [0,5) and [5,10) exchange their ranges)
+               FederatedData[] fedData = new FederatedData[fedRanges.length];
+               for(int i = 0; i < fedRanges.length; i++)
+                       fedData[i] = _fedMap.remove(fedRanges[i]);
+
+               for(int i = 0; i < fedRanges.length; i++) {
+                       long[] begin = fedRanges[i].getBeginDims().clone();
+                       long[] end = fedRanges[i].getEndDims().clone();
+                       long rowBegin = begin[0];
+                       begin[0] = nrows - end[0];
+                       end[0] = nrows - rowBegin;
+                       _fedMap.put(new FederatedRange(begin, end), fedData[i]);
+               }
+       }
+
        private static class MappingTask implements Callable<Void> {
                private final FederatedRange _range;
                private final FederatedData _data;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index af17637..e6a64cb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -91,7 +91,8 @@ public class FEDInstructionUtils {
                }
                else if (inst instanceof UnaryCPInstruction && ! (inst 
instanceof IndexingCPInstruction)) {
                        UnaryCPInstruction instruction = (UnaryCPInstruction) 
inst;
-                       if(inst instanceof ReorgCPInstruction && 
inst.getOpcode().equals("r'")) {
+                       if(inst instanceof ReorgCPInstruction && 
(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
+                               || inst.getOpcode().equals("rev"))) {
                                ReorgCPInstruction rinst = (ReorgCPInstruction) 
inst;
                                CacheableData<?> mo = 
ec.getCacheableData(rinst.input1);
 
@@ -129,7 +130,9 @@ public class FEDInstructionUtils {
                }
                else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
                        ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction) inst;
-                       if((pinst.getOpcode().equals("replace") || 
pinst.getOpcode().equals("rmempty")) && pinst.getTarget(ec).isFederated()) {
+                       if((pinst.getOpcode().equals("replace") || 
pinst.getOpcode().equals("rmempty")
+                               || pinst.getOpcode().equals("lowertri") || 
pinst.getOpcode().equals("uppertri"))
+                               && pinst.getTarget(ec).isFederated()) {
                                fedinst = 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
                        }
                        else if((pinst.getOpcode().equals("transformdecode") || 
pinst.getOpcode().equals("transformapply")) &&
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 6588909..a2b63e9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -104,7 +104,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                LinkedHashMap<String, String> paramsMap = 
constructParameterMap(parts);
 
                // determine the appropriate value function
-               if(opcode.equalsIgnoreCase("replace") || 
opcode.equalsIgnoreCase("rmempty")) {
+               if(opcode.equalsIgnoreCase("replace") || 
opcode.equalsIgnoreCase("rmempty")
+                       || opcode.equalsIgnoreCase("lowertri") || 
opcode.equalsIgnoreCase("uppertri")) {
                        ValueFunction func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                        return new ParameterizedBuiltinFEDInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
                }
@@ -137,6 +138,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                }
                else if(opcode.equals("rmempty"))
                        rmempty(ec);
+               else if(opcode.equals("lowertri") || opcode.equals("uppertri"))
+                       triangle(ec, opcode);
                else if(opcode.equalsIgnoreCase("transformdecode"))
                        transformDecode(ec);
                else if(opcode.equalsIgnoreCase("transformapply"))
@@ -145,22 +148,117 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                        throw new DMLRuntimeException("Unknown opcode : " + 
opcode);
                }
        }
+       
+       /**
+        * Extracts the lower or upper triangle (opcodes lowertri/uppertri) of a federated matrix.
+        * Every worker runs the {@link Tri} UDF on its own partition; the federation map that
+        * mapParallel returns under a freshly allocated data ID becomes the output's mapping.
+        *
+        * @param ec     execution context holding the federated target and the output
+        * @param opcode either "lowertri" or "uppertri"
+        */
+       private void triangle(ExecutionContext ec, String opcode) {
+               final boolean lower = opcode.equals("lowertri");
+               final boolean diag = Boolean.parseBoolean(params.get("diag"));
+               final boolean values = Boolean.parseBoolean(params.get("values"));
+
+               MatrixObject target = (MatrixObject) getTarget(ec);
+               FederationMap inMap = target.getFedMapping();
+               final boolean rowFed = target.isFederated(FederationMap.FType.ROW);
+               final long outID = FederationUtils.getNextFedDataID();
+
+               FederationMap triFedMap = inMap.mapParallel(outID, (range, data) -> {
+                       try {
+                               // the partition's [begin, end) along its partitioning dimension selects
+                               // the part of the local block that crosses the diagonal
+                               int dim = rowFed ? 0 : 1;
+                               int[] slice = {range.getBeginDimsInt()[dim], range.getEndDimsInt()[dim]};
+                               Tri udf = new Tri(data.getVarID(), outID, slice, rowFed, lower, diag, values);
+                               FederatedResponse response = data.executeFederatedOperation(
+                                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, udf)).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               return null;
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+               });
+
+               ec.getMatrixObject(output).setFedMapping(triFedMap);
+       }
+
+       /**
+        * Federated UDF that extracts the lower or upper triangle of one federated partition.
+        * {@code _slice} holds the partition's [begin, end) range along its partitioning
+        * dimension, which selects the sub-block of the local matrix that crosses the diagonal:
+        * its columns for a row partition, its rows for a column partition. The triangle is
+        * extracted from that sub-block; the remaining local columns (row partition) or rows
+        * (column partition) lie either completely inside or completely outside the triangle.
+        */
+       private static class Tri extends FederatedUDF {
+               private static final long serialVersionUID = 6254009025304038215L;
+
+               private final long _outputID;
+               private final int[] _slice;
+               private final boolean _rowFed;
+               private final boolean _lower;
+               private final boolean _diag;
+               private final boolean _values;
+
+               private Tri(long input, long outputID, int[] slice, boolean rowFed, boolean lower, boolean diag, boolean values) {
+                       super(new long[] {input});
+                       _outputID = outputID;
+                       _slice = slice;
+                       _rowFed = rowFed;
+                       _lower = lower;
+                       _diag = diag;
+                       _values = values;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       final int nRows = mb.getNumRows();
+                       final int nCols = mb.getNumColumns();
+                       final int begin = _slice[0];
+                       final int end = _slice[1]; // exclusive
+
+                       // sub-block of the local matrix that crosses the diagonal
+                       MatrixBlock diagBlock = _rowFed ?
+                               mb.slice(0, nRows - 1, begin, end - 1, new MatrixBlock()) :
+                               mb.slice(begin, end - 1);
+                       MatrixBlock tri = diagBlock.extractTriangular(new MatrixBlock(), _lower, _diag, _values);
+
+                       // todo: optimize to not allocate and slice all these matrix blocks, but leveraging underlying dense or sparse blocks.
+                       MatrixBlock ret = new MatrixBlock(nRows, nCols, 0.0);
+                       if(_rowFed) {
+                               ret.copy(0, nRows - 1, begin, end - 1, tri, false);
+                               // columns right of the diagonal block belong entirely to the upper triangle,
+                               // columns left of it entirely to the lower triangle
+                               if(!_lower && end < nCols)
+                                       ret.copy(0, nRows - 1, end, nCols - 1, fullySelected(mb, 0, nRows - 1, end, nCols - 1), false);
+                               else if(_lower && begin > 0)
+                                       ret.copy(0, nRows - 1, 0, begin - 1, fullySelected(mb, 0, nRows - 1, 0, begin - 1), false);
+                       }
+                       else {
+                               ret.copy(begin, end - 1, 0, nCols - 1, tri, false);
+                               // rows above the diagonal block belong entirely to the upper triangle and keep
+                               // their row positions; rows below it belong entirely to the lower triangle and
+                               // only exist if this is not the last partition (end < nRows)
+                               if(!_lower && begin > 0)
+                                       ret.copy(0, begin - 1, 0, nCols - 1, fullySelected(mb, 0, begin - 1, 0, nCols - 1), false);
+                               else if(_lower && end < nRows)
+                                       ret.copy(end, nRows - 1, 0, nCols - 1, fullySelected(mb, end, nRows - 1, 0, nCols - 1), false);
+                       }
+                       MatrixObject mout = ExecutionContext.createMatrixObject(ret);
+                       ec.setVariable(String.valueOf(_outputID), mout);
+
+                       return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
+               }
+
+               /**
+                * Returns the region [rl..ru] x [cl..cu] of the local block, which lies completely
+                * inside the requested triangle: its values if {@code _values} is set, otherwise a
+                * block of ones. NOTE(review): assumes extractTriangular likewise yields 1 for every
+                * selected cell when values=FALSE -- confirm.
+                */
+               private MatrixBlock fullySelected(MatrixBlock mb, int rl, int ru, int cl, int cu) {
+                       return _values ?
+                               mb.slice(rl, ru, cl, cu, new MatrixBlock()) :
+                               new MatrixBlock(ru - rl + 1, cu - cl + 1, 1.0);
+               }
+       }
 
        private void rmempty(ExecutionContext ec) {
                String margin = params.get("margin");
                if( !(margin.equals("rows") || margin.equals("cols")) )
-                       throw new DMLRuntimeException("Unspupported margin 
identifier '"+margin+"'.");
+                       throw new DMLRuntimeException("Unsupported margin 
identifier '"+margin+"'.");
 
                MatrixObject mo = (MatrixObject) getTarget(ec);
                MatrixObject select = params.containsKey("select") ? 
ec.getMatrixObject(params.get("select")) : null;
                MatrixObject out = ec.getMatrixObject(output);
 
                boolean marginRow = params.get("margin").equals("rows");
-               boolean k = ((marginRow && 
mo.getFedMapping().getType().isColPartitioned()) ||
+               boolean isNotAligned = ((marginRow && 
mo.getFedMapping().getType().isColPartitioned()) ||
                        (!marginRow && 
mo.getFedMapping().getType().isRowPartitioned()));
 
                MatrixBlock s = new MatrixBlock();
-               if(select == null && k) {
+               if(select == null && isNotAligned) {
                        List<MatrixBlock> colSums = new ArrayList<>();
                        mo.getFedMapping().forEachParallel((range, data) -> {
                                try {
@@ -209,7 +307,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                        mo.getFedMapping().execute(getTID(), true, fr1);
                        
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
                }
-               else if (!k) {
+               else if (!isNotAligned) {
                        //construct commands: broadcast , fed rmempty, clean 
broadcast
                        FederatedRequest[] fr1 = 
mo.getFedMapping().broadcastSliced(select, !marginRow);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index d9e9d97..b2f3a53 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -19,28 +19,60 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import java.util.AbstractMap;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.DiagIndex;
+import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.SortIndex;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.util.IndexRange;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
        
-       public ReorgFEDInstruction(CPOperand in1, CPOperand out, String opcode, String istr) {
-               super(FEDType.Reorg, null, in1, out, opcode, istr);
+       /**
+        * Creates a federated reorg instruction for one of the opcodes handled by
+        * parseInstruction (r', rdiag, rev).
+        *
+        * @param op     reorg operator of the instruction; processInstruction casts it to
+        *               {@link ReorgOperator}
+        * @param in1    federated input operand
+        * @param out    output operand
+        * @param opcode instruction opcode
+        * @param istr   complete instruction string
+        */
+       public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
+               super(FEDType.Reorg, op, in1, out, opcode, istr);
        }
 
        public static ReorgFEDInstruction parseInstruction ( String str ) {
+               CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+               CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, 
Types.DataType.UNKNOWN);
+
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
                if ( opcode.equalsIgnoreCase("r'") ) {
                        InstructionUtils.checkNumFields(str, 2, 3);
-                       CPOperand in = new CPOperand(parts[1]);
-                       CPOperand out = new CPOperand(parts[2]);
-                       return new ReorgFEDInstruction(in, out, opcode, str);
+                       in.split(parts[1]);
+                       out.split(parts[2]);
+                       int k = Integer.parseInt(parts[3]);
+                       return new ReorgFEDInstruction(new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str);
+               }
+               else if ( opcode.equalsIgnoreCase("rdiag") ) {
+                       parseUnaryInstruction(str, in, out); //max 2 operands
+                       return new ReorgFEDInstruction(new 
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
+               }
+               else if ( opcode.equalsIgnoreCase("rev") ) {
+                       parseUnaryInstruction(str, in, out); //max 2 operands
+                       return new ReorgFEDInstruction(new 
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
                }
                else {
                        throw new DMLRuntimeException("ReorgFEDInstruction: 
unsupported opcode: "+opcode);
@@ -50,20 +82,230 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
        @Override
        public void processInstruction(ExecutionContext ec) {
                MatrixObject mo1 = ec.getMatrixObject(input1);
-               
+               ReorgOperator r_op = (ReorgOperator) _optr;
+
                if( !mo1.isFederated() )
                        throw new DMLRuntimeException("Federated Reorg: "
                                + "Federated input expected, but invoked w/ 
"+mo1.isFederated());
 
-               //execute transpose at federated site
-               FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
-                       new CPOperand[]{input1}, new 
long[]{mo1.getFedMapping().getID()});
-               mo1.getFedMapping().execute(getTID(), true, fr1);
+               if(instOpcode.equals("r'")) {
+                       //execute transpose at federated site
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {input1},
+                               new long[] {mo1.getFedMapping().getID()});
+                       mo1.getFedMapping().execute(getTID(), true, fr1);
+
+                       //drive output federated mapping
+                       MatrixObject out = ec.getMatrixObject(output);
+                       out.getDataCharacteristics().set(mo1.getNumColumns(), 
mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
+                       
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+               }
+               else if(instOpcode.equalsIgnoreCase("rev")) {
+                       //execute transpose at federated site
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {input1},
+                               new long[] {mo1.getFedMapping().getID()});
+                       mo1.getFedMapping().execute(getTID(), true, fr1);
+
+                       if(mo1.isFederated(FederationMap.FType.ROW))
+                               mo1.getFedMapping().reverseFedMap();
+
+                       //derive output federated mapping
+                       MatrixObject out = ec.getMatrixObject(output);
+                       out.getDataCharacteristics().set(mo1.getNumRows(), 
mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
+                       
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+               }
+               else if (instOpcode.equals("rdiag")) {
+                       RdiagResult result;
+                       // diag(diag(X))
+                       if (mo1.getNumColumns() == 1 && mo1.getNumRows() != 1) {
+                               result = rdiagV2M(mo1, r_op);
+                       } else {
+                               result = rdiagM2V(mo1, r_op);
+                       }
+
+                       FederationMap diagFedMap = result.getFedMap();
+                       Map<FederatedRange, int[]> dcs = result.getDcs();
+
+                       //update fed ranges
+                       for(int i = 0; i < 
diagFedMap.getFederatedRanges().length; i++) {
+                               int[] newRange = 
dcs.get(diagFedMap.getFederatedRanges()[i]);
+
+                               
diagFedMap.getFederatedRanges()[i].setBeginDim(0,
+                                       
(diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+                                               i == 0) ? 0 : 
diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
+                               diagFedMap.getFederatedRanges()[i].setEndDim(0,
+                                       
diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+                               
diagFedMap.getFederatedRanges()[i].setBeginDim(1,
+                                       
(diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+                                               i == 0) ? 0 : 
diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
+                               diagFedMap.getFederatedRanges()[i].setEndDim(1,
+                                       
diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+                       }
+
+                       //update output mapping and data characteristics
+                       MatrixObject rdiag = ec.getMatrixObject(output);
+                       rdiag.getDataCharacteristics()
+                               .set(diagFedMap.getMaxIndexInRange(0), 
diagFedMap.getMaxIndexInRange(1),
+                                       (int) mo1.getBlocksize());
+                       rdiag.setFedMapping(diagFedMap);
+               }
+       }
+
+       /**
+        * Result of a federated rdiag: the federation map of the output and, keyed by the
+        * federated range it was computed from, the {rows, cols} of each worker's output block.
+        * Static, because it does not need the enclosing instruction.
+        */
+       private static class RdiagResult {
+               private final FederationMap fedMap;
+               private final Map<FederatedRange, int[]> dcs;
+
+               public RdiagResult(FederationMap fedMap, Map<FederatedRange, int[]> dcs) {
+                       this.fedMap = fedMap;
+                       this.dcs = dcs;
+               }
+
+               public FederationMap getFedMap() {
+                       return fedMap;
+               }
+
+               public Map<FederatedRange, int[]> getDcs() {
+                       return dcs;
+               }
+       }
+
+       /**
+        * rdiag of a federated column vector (vector -> diagonal matrix). Every worker builds its
+        * block of the diagonal matrix with the {@link DiagMatrix} UDF, given its [begin, end)
+        * range along the partitioning dimension and the vector length, and reports the block's
+        * {rows, cols}, which are collected per federated range.
+        */
+       private RdiagResult rdiagV2M (MatrixObject mo1, ReorgOperator r_op) {
+               final FederationMap inMap = mo1.getFedMapping();
+               final boolean rowFed = mo1.isFederated(FederationMap.FType.ROW);
 
-               //drive output federated mapping
-               MatrixObject out = ec.getMatrixObject(output);
-               out.getDataCharacteristics().set(mo1.getNumColumns(),
-                       mo1.getNumRows(), (int)mo1.getBlocksize(), mo1.getNnz());
-               out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()).transpose());
+               final long outID = FederationUtils.getNextFedDataID();
+               final int len = (int) mo1.getNumRows();
+               final Map<FederatedRange, int[]> blockDims = new HashMap<>();
+
+               FederationMap outMap = inMap.mapParallel(outID, (range, data) -> {
+                       try {
+                               int dim = rowFed ? 0 : 1;
+                               int[] slice = {range.getBeginDimsInt()[dim], range.getEndDimsInt()[dim]};
+                               DiagMatrix udf = new DiagMatrix(data.getVarID(), outID, r_op, slice, rowFed, len);
+                               FederatedResponse response = data.executeFederatedOperation(
+                                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, udf)).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               int[] dims = (int[]) response.getData()[0];
+                               // mapParallel may invoke this lambda concurrently, hence the lock
+                               synchronized(blockDims) {
+                                       blockDims.put(range, dims);
+                               }
+                               return null;
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+               });
+               return new RdiagResult(outMap, blockDims);
+       }
+
+       /**
+        * rdiag of a federated matrix (matrix -> diagonal vector). Every worker extracts the
+        * diagonal of its part with the {@link Rdiag} UDF, given its [begin, end) range along
+        * the partitioning dimension, and reports the {rows, cols} of its result, which are
+        * collected per federated range.
+        */
+       private RdiagResult rdiagM2V (MatrixObject mo1, ReorgOperator r_op) {
+               final FederationMap inMap = mo1.getFedMapping();
+               final boolean rowFed = mo1.isFederated(FederationMap.FType.ROW);
+
+               final long outID = FederationUtils.getNextFedDataID();
+               final Map<FederatedRange, int[]> blockDims = new HashMap<>();
+
+               FederationMap outMap = inMap.mapParallel(outID, (range, data) -> {
+                       try {
+                               int dim = rowFed ? 0 : 1;
+                               int[] slice = {range.getBeginDimsInt()[dim], range.getEndDimsInt()[dim]};
+                               Rdiag udf = new Rdiag(data.getVarID(), outID, r_op, slice, rowFed);
+                               FederatedResponse response = data.executeFederatedOperation(
+                                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, udf)).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               int[] dims = (int[]) response.getData()[0];
+                               // mapParallel may invoke this lambda concurrently, hence the lock
+                               synchronized(blockDims) {
+                                       blockDims.put(range, dims);
+                               }
+                               return null;
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+               });
+               return new RdiagResult(outMap, blockDims);
+       }
+
+       /**
+        * Federated UDF for rdiag of a matrix partition (matrix -> diagonal vector). It slices the
+        * sub-block of the local matrix given by {@code _slice} (end exclusive) -- columns for a
+        * row partition, rows for a column partition -- applies the reorg operator, stores the
+        * result (flagged as diagonal) under the output ID and responds with its {rows, cols}.
+        */
+       private static class Rdiag extends FederatedUDF {
+
+               private static final long serialVersionUID = -3466926635958851402L;
+               private final long _outputID;
+               private final ReorgOperator _r_op;
+               private final int[] _slice;
+               private final boolean _rowFed;
+
+               private Rdiag(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed) {
+                       super(new long[] {input});
+                       _outputID = outputID;
+                       _r_op = r_op;
+                       _slice = slice;
+                       _rowFed = rowFed;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock local = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       int lo = _slice[0];
+                       int hi = _slice[1] - 1;
+
+                       MatrixBlock part;
+                       if(_rowFed)
+                               part = local.slice(0, local.getNumRows() - 1, lo, hi, new MatrixBlock());
+                       else
+                               part = local.slice(lo, hi);
+                       MatrixBlock diag = part.reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0);
+
+                       MatrixObject result = ExecutionContext.createMatrixObject(diag);
+                       result.setDiag(true);
+                       ec.setVariable(String.valueOf(_outputID), result);
+
+                       int[] dims = new int[] {diag.getNumRows(), diag.getNumColumns()};
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, dims);
+               }
+       }
+
+       /**
+        * Federated UDF for rdiag of a column-vector partition (vector -> diagonal matrix).
+        * The reorg operator expands the local part of the vector into a diagonal block. For a
+        * row partition, that block is placed at the columns [_slice[0], _slice[1]) of a block
+        * with the partition's rows and {@code _len} (full vector length) columns. A column
+        * partition of a column vector holds the whole vector, so its diagonal block already is
+        * the complete result. Responds with the {rows, cols} of the result.
+        */
+       private static class DiagMatrix extends FederatedUDF {
+
+               private static final long serialVersionUID = -3466926635958851402L;
+               private final long _outputID;
+               private final ReorgOperator _r_op;
+               private final int _len;
+               private final int[] _slice;
+               private final boolean _rowFed;
+
+               private DiagMatrix(long input, long outputID, ReorgOperator r_op, int[] slice, boolean rowFed, int len) {
+                       super(new long[] {input});
+                       _outputID = outputID;
+                       _r_op = r_op;
+                       _len = len;
+                       _rowFed = rowFed;
+                       _slice = slice;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       MatrixBlock tmp = mb.reorgOperations(_r_op, new MatrixBlock(), 0, 0, 0);
+                       MatrixBlock res;
+                       if(_rowFed) {
+                               res = new MatrixBlock(mb.getNumRows(), _len, 0.0);
+                               res.copy(0, res.getNumRows()-1, _slice[0], _slice[1]-1, tmp, false);
+                       }
+                       else {
+                               // this UDF is only used for single-column inputs, so a column partition holds
+                               // the entire vector and tmp already is the full _len x _len diagonal matrix
+                               // (a _len x _slice[1] block cannot hold it)
+                               res = tmp;
+                       }
+                       MatrixObject mout = ExecutionContext.createMatrixObject(res);
+                       mout.setDiag(true);
+                       ec.setVariable(String.valueOf(_outputID), mout);
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new int[]{res.getNumRows(), res.getNumColumns()});
+               }
        }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
new file mode 100644
index 0000000..eedbbbb
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRdiagTest extends AutomatedTestBase {
+
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME = "FederatedRdiagTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedRdiagTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {1000, 1000},
+                       {1000,1}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+       }
+
+       @Test
+       public void federatedRdiagCP() { 
federatedRdiag(Types.ExecMode.SINGLE_NODE); }
+
+       @Test
+       @Ignore
+       public void federatedRdiagSP() { federatedRdiag(Types.ExecMode.SPARK); }
+
+       public void federatedRdiag(Types.ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               int r = rows / 4;
+
+               double[][] X1 = getRandomMatrix(r, cols, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, cols, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, cols, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, cols, 1, 5, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, cols, 
blocksize, r*cols);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               // reference file should not be written to hdfs, so we set 
platform here
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               // Run reference dml script with normal matrix for Row/Col
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
+                       input("X1"), input("X2"), input("X3"), input("X4"), 
expected("S")};
+               runTest(null);
+
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
+                       "rows=" + rows,
+                       "cols=" + cols,
+                       "out_S=" + output("S")};
+               runTest(null);
+
+               // compare all sums via files
+               compareResults(0.01);
+
+               Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
new file mode 100644
index 0000000..36996db
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * Tests federated {@code rev()} ({@code fed_rev}) on a {@code rows x cols} matrix whose four
+ * partitions are served by four local federated workers, split either row-wise or column-wise.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedRevTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME = "FederatedRevTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRevTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // {rows, cols, rowPartitioned}
+               return Arrays.asList(new Object[][] {
+                       {100, 12, true},
+                       {100, 12, false}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+       }
+
+       @Test
+       public void testRevCP() {
+               runRevTest(ExecMode.SINGLE_NODE);
+       }
+
+       /**
+        * Writes the four partitions, starts the workers, and compares {@code rev()} on the
+        * federated matrix with the reference script run on the locally combined inputs.
+        *
+        * @param execMode execution mode used for both script runs
+        */
+       private void runRevTest(ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices: each worker holds a quarter of the rows (row-partitioned)
+               // or a quarter of the columns (column-partitioned)
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               // zero out rows 1-3 of the third partition so the input contains all-zero rows
+               for(int k : new int[] {1, 2, 3}) {
+                       Arrays.fill(X3[k], 0);
+               }
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               // stop the workers and restore the static platform settings even when a run or an
+               // assertion fails, so later tests do not inherit leaked threads or a changed mode
+               try {
+                       rtplatform = execMode;
+                       if(rtplatform == ExecMode.SPARK) {
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       }
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                               Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
+
+                       runTest(null);
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+                               "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+                       runTest(null);
+
+                       // compare via files
+                       compareResults(0.01);
+
+                       Assert.assertTrue(heavyHittersContainsString("fed_rev"));
+
+                       // check that federated input files are still existing
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
new file mode 100644
index 0000000..e7ae2d4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTriTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedTriTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME = "FederatedTriTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedTriTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {20, 20, true}
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+       }
+
+       @Test
+       public void testTriCP() { runTriTest(ExecMode.SINGLE_NODE); }
+
+       private void runTriTest(ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               ExecMode platformOld = rtplatform;
+
+               if(rtplatform == ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int r = rows;
+               int c = cols / 4;
+               if(rowPartitioned) {
+                       r = rows / 4;
+                       c = cols;
+               }
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               if(rtplatform == ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
input("X1"), input("X2"), input("X3"), input("X4"),
+                       Boolean.toString(rowPartitioned).toUpperCase(), 
expected("S")};
+
+               runTest(null);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
+
+               runTest(null);
+
+               // compare via files
+               compareResults(1e-9);
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml 
b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
index 1934aae..ffdca07 100644
--- a/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
+++ b/src/test/scripts/functions/federated/FederatedLmPipelineReference.dml
@@ -19,8 +19,8 @@
 #
 #-------------------------------------------------------------
 
-Fin = rbind(read($1), read($2))
 
+Fin = rbind(read($1), read($2))
 y = read($5)
 
 # one hot encoding categorical, other passthrough
diff --git a/src/test/scripts/functions/federated/FederatedRdiagTest.dml 
b/src/test/scripts/functions/federated/FederatedRdiagTest.dml
new file mode 100644
index 0000000..a03ce4e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRdiagTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Federated input: four row partitions of $rows/4 rows and all $cols columns,
+# one per worker address (the ranges list appears to hold consecutive begin/end pairs).
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+
+# diag() on the federated matrix; the companion Java test asserts a fed_rdiag heavy hitter.
+s = diag(A);
+write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedRdiagTestReference.dml 
b/src/test/scripts/functions/federated/FederatedRdiagTestReference.dml
new file mode 100644
index 0000000..92bdd8f
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRdiagTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference: stack the four row partitions locally and apply diag().
+X = rbind(read($1), read($2), read($3), read($4));
+S = diag(X);
+write(S, $5);
diff --git a/src/test/scripts/functions/federated/FederatedRevTest.dml 
b/src/test/scripts/functions/federated/FederatedRevTest.dml
new file mode 100644
index 0000000..d43edd1
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRevTest.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements.  See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership.  The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License.  You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied.  See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ #
+ #-------------------------------------------------------------
+# Federated input from four workers: row-wise when $rP is TRUE ($rows/4 rows each,
+# all columns), otherwise column-wise ($cols/4 columns each, all rows).
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+            ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+}
+
+# rev() on the federated matrix; the companion Java test asserts a fed_rev heavy hitter.
+s = rev(A);
+write(s, $out_S);
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedRevTestReference.dml 
b/src/test/scripts/functions/federated/FederatedRevTestReference.dml
new file mode 100644
index 0000000..e07591d
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedRevTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements.  See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership.  The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License.  You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied.  See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ #
+ #-------------------------------------------------------------
+
+# Reference: combine the four partitions locally (row-wise if $5 is TRUE,
+# otherwise column-wise) and apply rev().
+if ($5) {
+  X = rbind(read($1), read($2), read($3), read($4));
+} else {
+  X = cbind(read($1), read($2), read($3), read($4));
+}
+
+R = rev(X);
+write(R, $6);
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/FederatedTriTest.dml 
b/src/test/scripts/functions/federated/FederatedTriTest.dml
new file mode 100644
index 0000000..a661a4a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTriTest.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Federated input from four workers: row-wise when $rP is TRUE ($rows/4 rows each,
+# all columns), otherwise column-wise ($cols/4 columns each, all rows).
+if ($rP) {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+      list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows, $cols/4), list(0, $cols/4), list($rows, $cols/2),
+      list(0, $cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+# lower triangle of the federated matrix, same arguments as the reference script
+S = lower.tri(target=A, diag=FALSE, values=TRUE);
+write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/FederatedTriTestReference.dml 
b/src/test/scripts/functions/federated/FederatedTriTestReference.dml
new file mode 100644
index 0000000..bb5d98d
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTriTestReference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference: combine the four partitions locally (row-wise if $5 is TRUE,
+# otherwise column-wise) and apply lower.tri with the federated script's arguments.
+if ($5) {
+  X = rbind(read($1), read($2), read($3), read($4));
+} else {
+  X = cbind(read($1), read($2), read($3), read($4));
+}
+
+L = lower.tri(target=X, diag=FALSE, values=TRUE);
+write(L, $6);

Reply via email to