This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 82536c1  [SYSTEMDS-3101] Fix federated spoof instruction (federated output)
82536c1 is described below

commit 82536c1841b546db4f519086d2d7a6cba011603c
Author: ywcb00 <[email protected]>
AuthorDate: Sat Sep 18 22:16:43 2021 +0200

    [SYSTEMDS-3101] Fix federated spoof instruction (federated output)
    
    Closes #1380.
    
    Other cleanups:
    Closes #1336.
    Closes #1365.
---
 .../instructions/fed/SpoofFEDInstruction.java      | 493 ++++++++++++---------
 .../codegen/FederatedCodegenMultipleFedMOTest.java |   6 +-
 .../codegen/FederatedOuterProductTmplTest.java     |   8 +-
 .../codegen/FederatedRowwiseTmplTest.java          |   2 +-
 .../pipelines/BuiltinTopkEvaluateTest.java         |   1 -
 5 files changed, 289 insertions(+), 221 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index ecf310c..331ecfc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.concurrent.Future;
+import java.util.stream.IntStream;
 
 public class SpoofFEDInstruction extends FEDInstruction
 {
@@ -82,37 +83,38 @@ public class SpoofFEDInstruction extends FEDInstruction
 
        @Override
        public void processInstruction(ExecutionContext ec) {
+               FederationMap fedMap = null;
+               for(CPOperand cpo : _inputs) { // searching for the first 
federated matrix to obtain the federation map
+                       Data tmpData = ec.getVariable(cpo);
+                       if(tmpData instanceof MatrixObject && 
((MatrixObject)tmpData).isFederatedExcept(FType.BROADCAST)) {
+                               fedMap = 
((MatrixObject)tmpData).getFedMapping();
+                               break;
+                       }
+               }
+
                Class<?> scla = _op.getClass().getSuperclass();
                SpoofFEDType spoofType = null;
                if(scla == SpoofCellwise.class)
-                       spoofType = new SpoofFEDCellwise(_op, _output);
+                       spoofType = new SpoofFEDCellwise(_op, _output, 
fedMap.getType());
                else if(scla == SpoofRowwise.class)
-                       spoofType = new SpoofFEDRowwise(_op, _output);
+                       spoofType = new SpoofFEDRowwise(_op, _output, 
fedMap.getType());
                else if(scla == SpoofMultiAggregate.class)
-                       spoofType = new SpoofFEDMultiAgg(_op, _output);
+                       spoofType = new SpoofFEDMultiAgg(_op, _output, 
fedMap.getType());
                else if(scla == SpoofOuterProduct.class)
-                       spoofType = new SpoofFEDOuterProduct(_op, _output);
+                       spoofType = new SpoofFEDOuterProduct(_op, _output, 
fedMap.getType(), _inputs);
                else
                        throw new DMLRuntimeException("Federated code 
generation only supported" +
                                " for cellwise, rowwise, multiaggregate, and 
outerproduct templates.");
 
+               processRequest(ec, fedMap, spoofType);
+       }
 
-               FederationMap fedMap = null;
-               long id = 0;
-               for(CPOperand cpo : _inputs) { // searching for the first 
federated matrix to obtain the federation map
-                       Data tmpData = ec.getVariable(cpo);
-                       if(tmpData instanceof MatrixObject && 
((MatrixObject)tmpData).isFederatedExcept(FType.BROADCAST)) {
-                               fedMap = 
((MatrixObject)tmpData).getFedMapping();
-                               id = ((MatrixObject)tmpData).getUniqueID();
-                               break;
-                       }
-               }
-
+       private void processRequest(ExecutionContext ec, FederationMap fedMap, 
SpoofFEDType spoofType) {
                ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
                ArrayList<FederatedRequest[]> frBroadcastSliced = new 
ArrayList<>();
                long[] frIds = new long[_inputs.length];
                int index = 0;
-               
+
                for(CPOperand cpo : _inputs) {
                        Data tmpData = ec.getVariable(cpo);
                        if(tmpData instanceof MatrixObject) {
@@ -121,7 +123,7 @@ public class SpoofFEDInstruction extends FEDInstruction
                                        frIds[index++] = 
mo.getFedMapping().getID();
                                }
                                else if(spoofType.needsBroadcastSliced(fedMap, 
mo.getNumRows(), mo.getNumColumns(), index)) {
-                                       FederatedRequest[] tmpFr = 
spoofType.broadcastSliced(mo, fedMap, id);
+                                       FederatedRequest[] tmpFr = 
spoofType.broadcastSliced(mo, fedMap);
                                        frIds[index++] = tmpFr[0].getID();
                                        frBroadcastSliced.add(tmpFr);
                                }
@@ -144,48 +146,71 @@ public class SpoofFEDInstruction extends FEDInstruction
 
                FederatedRequest frCompute = 
FederationUtils.callInstruction(instString, _output, _inputs, frIds);
 
-               // get partial results from federated workers
-               FederatedRequest frGet = new 
FederatedRequest(RequestType.GET_VAR, frCompute.getID());
+               FederatedRequest frGet = null;
+               FederatedRequest frCleanup = null;
+               if(!spoofType.isFedOutput()) {
+                       // get partial results from federated workers
+                       frGet = new FederatedRequest(RequestType.GET_VAR, 
frCompute.getID());
+                       // cleanup the federated request of callInstruction
+                       frCleanup = fedMap.cleanup(getTID(), frCompute.getID());
+               }
 
-               ArrayList<FederatedRequest> frCleanup = new ArrayList<>();
-               frCleanup.add(fedMap.cleanup(getTID(), frCompute.getID()));
-               for(FederatedRequest[] fr : frBroadcastSliced)
-                       frCleanup.add(fedMap.cleanup(getTID(), fr[0].getID()));
+               FederatedRequest[] frAll;
+               if(frGet == null) // no get request if output is kept federated
+                       frAll = ArrayUtils.addAll(
+                               frBroadcast.toArray(new FederatedRequest[0]), 
frCompute);
+               else
+                       frAll = ArrayUtils.addAll(
+                               frBroadcast.toArray(new FederatedRequest[0]), 
frCompute, frGet, frCleanup);
 
-               FederatedRequest[] frAll = ArrayUtils.addAll(ArrayUtils.addAll(
-                       frBroadcast.toArray(new FederatedRequest[0]), 
frCompute, frGet),
-                       frCleanup.toArray(new FederatedRequest[0]));
                Future<FederatedResponse>[] response = 
fedMap.executeMultipleSlices(
                        getTID(), true, frBroadcastSliced.toArray(new 
FederatedRequest[0][]), frAll);
 
                // setting the output with respect to the different aggregation 
types
                // of the different spoof templates
-               spoofType.setOutput(ec, response, fedMap);
+               spoofType.setOutput(ec, response, fedMap, frCompute.getID());
        }
 
 
+       // abstract class to differentiate between the different spoof templates
        private static abstract class SpoofFEDType {
                CPOperand _output;
+               FType _fedType;
 
-               protected SpoofFEDType(CPOperand out) {
+               protected SpoofFEDType(CPOperand out, FType fedType) {
                        _output = out;
+                       _fedType = fedType;
                }
-               
-               protected FederatedRequest[] broadcastSliced(MatrixObject mo, 
FederationMap fedMap, long id) {
+
+               /**
+                * performs the sliced broadcast of the given matrix object
+                *
+                * @param mo the matrix object to broadcast sliced
+                * @param fedMap the federated mapping
+                * @return FederatedRequest[] the resulting federated request 
array of the broadcast
+                */
+               protected FederatedRequest[] broadcastSliced(MatrixObject mo, 
FederationMap fedMap) {
                        return fedMap.broadcastSliced(mo, false);
                }
 
+               /**
+                * determine if a specific matrix object needs to be broadcast 
sliced
+                *
+                * @param fedMap the federated mapping
+                * @param rowNum the number of rows of the matrix object
+                * @param colNum the number of columns of the matrix object
+                * @param inputIndex the index of the matrix inside the 
instruction inputs
+                * @return boolean indicates if the matrix needs to be 
broadcast sliced
+                */
                protected boolean needsBroadcastSliced(FederationMap fedMap, 
long rowNum, long colNum, int inputIndex) {
-                       FType fedType = fedMap.getType();
-
                        //TODO fix check by num rows/cols
                        boolean retVal = (rowNum == 
fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
-                       if(fedType == FType.ROW)
-                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0) 
-                                       && (colNum == 1 || colNum == 
fedMap.getSize() || fedMap.getMaxIndexInRange(1) == 1));
-                       else if(fedType == FType.COL)
+                       if(_fedType == FType.ROW)
+                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0)
+                                       && (colNum == 1 || 
fedMap.getMaxIndexInRange(1) == 1));
+                       else if(_fedType == FType.COL)
                                retVal |= (colNum == 
fedMap.getMaxIndexInRange(1)
-                                       && (rowNum == 1 || rowNum == 
fedMap.getSize() || fedMap.getMaxIndexInRange(0) == 1));
+                                       && (rowNum == 1 || 
fedMap.getMaxIndexInRange(0) == 1));
                        else {
                                throw new DMLRuntimeException("Only row 
partitioned or column" +
                                        " partitioned federated input supported 
yet.");
@@ -193,236 +218,281 @@ public class SpoofFEDInstruction extends FEDInstruction
                        return retVal;
                }
 
-               protected abstract void setOutput(ExecutionContext ec,
-                       Future<FederatedResponse>[] response, FederationMap 
fedMap);
+               /**
+                * set the output by either calling setFedOutput to keep the 
output federated
+                * or calling aggResult to aggregate the partial results locally
+                */
+               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response,
+                       FederationMap fedMap, long frComputeID) {
+                       if(isFedOutput())
+                               setFedOutput(ec, fedMap, frComputeID);
+                       else
+                               aggResult(ec, response, fedMap);
+               }
+
+               // determine if the output can be kept on the federated sites
+               protected abstract boolean isFedOutput();
+               // set the output by deriving a new federated mapping
+               protected abstract void setFedOutput(ExecutionContext ec, 
FederationMap fedMap, long frComputeID);
+               // aggregate the partial results locally
+               protected abstract void aggResult(ExecutionContext ec, 
Future<FederatedResponse>[] response,
+                       FederationMap fedMap);
        }
 
+       // CELLWISE TEMPLATE
        private static class SpoofFEDCellwise extends SpoofFEDType {
                private final SpoofCellwise _op;
+               private final CellType _cellType;
 
-               SpoofFEDCellwise(SpoofOperator op, CPOperand out) {
-                       super(out);
+               SpoofFEDCellwise(SpoofOperator op, CPOperand out, FType 
fedType) {
+                       super(out, fedType);
                        _op = (SpoofCellwise)op;
+                       _cellType = _op.getCellType();
                }
 
-               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
-                       FType fedType = fedMap.getType();
-                       AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
-                       CellType cellType = ((SpoofCellwise)_op).getCellType();
-                       if(cellType == CellType.FULL_AGG) { // full aggregation
-                               AggregateUnaryOperator aop = null;
-                               if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-                               else if(aggOp == AggOp.MIN)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
-                               else if(aggOp == AggOp.MAX)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
-                               else
-                                       throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
-                               ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
-                       }
-                       else if(cellType == CellType.ROW_AGG) { // row 
aggregation
-                               if(fedType == FType.ROW) {
-                                       // bind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                               }
-                               else if(fedType == FType.COL) {
-                                       AggregateUnaryOperator aop = null;
-                                       if(aggOp == AggOp.SUM || aggOp == 
AggOp.SUM_SQ)
-                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
-                                       else if(aggOp == AggOp.MIN)
-                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
-                                       else if(aggOp == AggOp.MAX)
-                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
-                                       else
-                                               throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-                               }
-                               else {
-                                       throw new 
DMLRuntimeException("Aggregation type for federated spoof instructions not 
supported yet.");
-                               }
-                       }
-                       else if(cellType == CellType.COL_AGG) { // col 
aggregation
-                               if(fedType == FType.ROW) {
-                                       AggregateUnaryOperator aop = null;
-                                       if(aggOp == AggOp.SUM || aggOp == 
AggOp.SUM_SQ)
-                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
-                                       else if(aggOp == AggOp.MIN)
-                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
-                                       else if(aggOp == AggOp.MAX)
-                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
-                                       else
-                                               throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-                               }
-                               else if(fedType == FType.COL) {
-                                       // cbind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
-                               }
-                               else {
-                                       throw new 
DMLRuntimeException("Aggregation type for federated spoof instructions not 
supported yet.");
-                               }
+               protected boolean isFedOutput() {
+                       boolean retVal = false;
+                       retVal |= (_cellType == CellType.ROW_AGG && _fedType == 
FType.ROW);
+                       retVal |= (_cellType == CellType.COL_AGG && _fedType == 
FType.COL);
+                       retVal |= (_cellType == CellType.NO_AGG);
+                       return retVal;
+               }
+
+               protected void setFedOutput(ExecutionContext ec, FederationMap 
fedMap, long frComputeID) {
+                       // derive output federated mapping
+                       MatrixObject out = ec.getMatrixObject(_output);
+                       FederationMap newFedMap = 
modifyFedRanges(fedMap.copyWithNewID(frComputeID));
+                       out.setFedMapping(newFedMap);
+               }
+
+               private FederationMap modifyFedRanges(FederationMap fedMap) {
+                       if(_cellType == CellType.ROW_AGG || _cellType == 
CellType.COL_AGG) {
+                               int dim = (_cellType == CellType.COL_AGG ? 0 : 
1);
+                               // crop federation map to a vector
+                               IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
+                                       
fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+                                       
fedMap.getFederatedRanges()[i].setEndDim(dim, 1);
+                               });
                        }
-                       else if(cellType == CellType.NO_AGG) { // no aggregation
-                               if(fedType == FType.ROW) //rbind
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                               else if(fedType == FType.COL) //cbind
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
-                               else
-                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
-                                               " partitioned federated 
matrices supported yet.");
+                       return fedMap;
+               }
+
+               protected void aggResult(ExecutionContext ec, 
Future<FederatedResponse>[] response,
+                       FederationMap fedMap) {
+                       AggOp aggOp = _op.getAggOp();
+
+                       // build up the instruction for aggregation
+                       // 
(uak+/uamin/uamax/uark+/uarmin/uarmax/uack+/uacmin/uacmax)
+                       String aggInst = "ua";
+                       switch(_cellType) {
+                               case FULL_AGG: break;
+                               case ROW_AGG: aggInst += "r"; break;
+                               case COL_AGG: aggInst += "c"; break;
+                               case NO_AGG:
+                               default:
+                                       throw new 
DMLRuntimeException("Aggregation type not supported yet.");
                        }
-                       else {
-                               throw new DMLRuntimeException("Aggregation type 
not supported yet.");
+
+                       switch(aggOp) {
+                               case SUM:
+                               case SUM_SQ: aggInst += "k+"; break;
+                               case MIN:    aggInst += "min"; break;
+                               case MAX:    aggInst += "max"; break;
+                               default:
+                                       throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
                        }
+
+                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
+                       if(_cellType == CellType.FULL_AGG)
+                               ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
+                       else
+                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
                }
        }
 
+       // ROWWISE TEMPLATE
        private static class SpoofFEDRowwise extends SpoofFEDType {
                private final SpoofRowwise _op;
+               private final RowType _rowType;
 
-               SpoofFEDRowwise(SpoofOperator op, CPOperand out) {
-                       super(out);
+               SpoofFEDRowwise(SpoofOperator op, CPOperand out, FType fedType) 
{
+                       super(out, fedType);
                        _op = (SpoofRowwise)op;
+                       _rowType = _op.getRowType();
+               }
+
+               protected boolean isFedOutput() {
+                       boolean retVal = false;
+                       retVal |= (_rowType == RowType.NO_AGG);
+                       retVal |= (_rowType == RowType.NO_AGG_B1);
+                       retVal |= (_rowType == RowType.NO_AGG_CONST);
+                       retVal &= (_fedType == FType.ROW);
+                       return retVal;
                }
 
-               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
-                       RowType rowType = ((SpoofRowwise)_op).getRowType();
-                       if(rowType == RowType.FULL_AGG) { // full aggregation
-                               // aggregate partial results from federated 
responses as sum
-                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+               protected void setFedOutput(ExecutionContext ec, FederationMap 
fedMap, long frComputeID) {
+                       // derive output federated mapping
+                       MatrixObject out = ec.getMatrixObject(_output);
+                       FederationMap newFedMap = 
modifyFedRanges(fedMap.copyWithNewID(frComputeID), out.getNumColumns());
+                       out.setFedMapping(newFedMap);
+               }
+
+               private static FederationMap modifyFedRanges(FederationMap 
fedMap, long cols) {
+                       IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
+                               fedMap.getFederatedRanges()[i].setBeginDim(1, 
0);
+                               fedMap.getFederatedRanges()[i].setEndDim(1, 
cols);
+                       });
+                       return fedMap;
+               }
+
+               protected void aggResult(ExecutionContext ec, 
Future<FederatedResponse>[] response,
+                       FederationMap fedMap) {
+                       if(_fedType != FType.ROW)
+                               throw new DMLRuntimeException("Only row 
partitioned federated matrices supported yet.");
+
+                       // build up the instruction for aggregation 
(uak+/uark+/uack+)
+                       String aggInst = "ua";
+                       if(_rowType == RowType.FULL_AGG) // full aggregation
+                               aggInst += "k+";
+                       else if(_rowType == RowType.ROW_AGG) // row aggregation
+                               aggInst += "rk+";
+                       else if(_rowType.isColumnAgg()) // col aggregation
+                               aggInst += "ck+";
+                       else
+                               throw new DMLRuntimeException("AggregationType 
not supported yet.");
+
+                       // aggregate partial results from federated responses 
as sum/rowSum/colSum
+                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
+                       if(_rowType == RowType.FULL_AGG)
                                ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
-                       }
-                       else if(rowType == RowType.ROW_AGG) { // row aggregation
-                               // aggregate partial results from federated 
responses as rowSum
-                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-                       }
-                       else if(rowType == RowType.COL_AGG
-                               || rowType == RowType.COL_AGG_T
-                               || rowType == RowType.COL_AGG_B1
-                               || rowType == RowType.COL_AGG_B1_T
-                               || rowType == RowType.COL_AGG_B1R
-                               || rowType == RowType.COL_AGG_CONST) { // col 
aggregation
-                               // aggregate partial results from federated 
responses as colSum
-                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+                       else
                                ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-                       }
-                       else if(rowType == RowType.NO_AGG
-                               || rowType == RowType.NO_AGG_B1
-                               || rowType == RowType.NO_AGG_CONST) { // no 
aggregation
-                               if(fedMap.getType() == FType.ROW) {
-                                       // bind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                               }
-                               else {
-                                       throw new DMLRuntimeException("Only row 
partitioned federated matrices supported yet.");
-                               }
-                       }
-                       else {
-                               throw new DMLRuntimeException("AggregationType 
not supported yet.");
-                       }
                }
        }
 
+       // MULTIAGGREGATE TEMPLATE
        private static class SpoofFEDMultiAgg extends SpoofFEDType {
                private final SpoofMultiAggregate _op;
 
-               SpoofFEDMultiAgg(SpoofOperator op, CPOperand out) {
-                       super(out);
+               SpoofFEDMultiAgg(SpoofOperator op, CPOperand out, FType 
fedType) {
+                       super(out, fedType);
                        _op = (SpoofMultiAggregate)op;
                }
 
-               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
-                       MatrixBlock[] partRes = 
FederationUtils.getResults(response);
-                       SpoofCellwise.AggOp[] aggOps = 
((SpoofMultiAggregate)_op).getAggOps();
-                       for(int counter = 1; counter < partRes.length; 
counter++) {
-                               
SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], 
partRes[counter]);
-                       }
-                       ec.setMatrixOutput(_output.getName(), partRes[0]);
+               protected boolean isFedOutput() {
+                       return false;
                }
-       }
 
+               protected void setFedOutput(ExecutionContext ec, FederationMap 
fedMap, long frComputeID) {
+                       throw new DMLRuntimeException("SpoofFEDMultiAgg cannot 
create a federated output.");
+               }
 
+               protected void aggResult(ExecutionContext ec, 
Future<FederatedResponse>[] response,
+                       FederationMap fedMap) {
+                               MatrixBlock[] partRes = 
FederationUtils.getResults(response);
+                               SpoofCellwise.AggOp[] aggOps = _op.getAggOps();
+                               for(int counter = 1; counter < partRes.length; 
counter++) {
+                                       
SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], 
partRes[counter]);
+                               }
+                               ec.setMatrixOutput(_output.getName(), 
partRes[0]);
+                       }
+       }
+
+       // OUTER PRODUCT TEMPLATE
        private static class SpoofFEDOuterProduct extends SpoofFEDType {
                private final SpoofOuterProduct _op;
+               private final OutProdType _outProdType;
+               private CPOperand[] _inputs;
 
-               SpoofFEDOuterProduct(SpoofOperator op, CPOperand out) {
-                       super(out);
+               SpoofFEDOuterProduct(SpoofOperator op, CPOperand out, FType 
fedType, CPOperand[] inputs) {
+                       super(out, fedType);
                        _op = (SpoofOuterProduct)op;
+                       _outProdType = _op.getOuterProdType();
+                       _inputs = inputs;
+               }
+
+               protected FederatedRequest[] broadcastSliced(MatrixObject mo, 
FederationMap fedMap) {
+                       return fedMap.broadcastSliced(mo, (_fedType == 
FType.COL));
                }
 
                protected boolean needsBroadcastSliced(FederationMap fedMap, 
long rowNum, long colNum, int inputIndex) {
                        boolean retVal = false;
-                       FType fedType = fedMap.getType();
                        
                        retVal |= (rowNum == fedMap.getMaxIndexInRange(0) && 
colNum == fedMap.getMaxIndexInRange(1));
                        
-                       if(fedType == FType.ROW)
+                       if(_fedType == FType.ROW)
                                retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0)) && (inputIndex != 2); // input at index 2 is V
-                       else if(fedType == FType.COL)
+                       else if(_fedType == FType.COL)
                                retVal |= (rowNum == 
fedMap.getMaxIndexInRange(1)) && (inputIndex != 1); // input at index 1 is U
                        else
                                throw new DMLRuntimeException("Only row 
partitioned or column" +
                                        " partitioned federated input supported 
yet.");
-                       
+
                        return retVal;
                }
 
-               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
-                       FType fedType = fedMap.getType();
-                       OutProdType outProdType = 
((SpoofOuterProduct)_op).getOuterProdType();
-                       if(outProdType == OutProdType.LEFT_OUTER_PRODUCT) {
-                               if(fedType == FType.ROW) {
-                                       // aggregate partial results from 
federated responses as elementwise sum
-                                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-                               }
-                               else if(fedType == FType.COL) {
-                                       // bind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                               }
-                               else {
-                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
-                                               " partitioned federated 
matrices supported yet.");
-                               }
+               protected boolean isFedOutput() {
+                       boolean retVal = false;
+                       retVal |= (_outProdType == 
OutProdType.LEFT_OUTER_PRODUCT && _fedType == FType.COL);
+                       retVal |= (_outProdType == 
OutProdType.RIGHT_OUTER_PRODUCT && _fedType == FType.ROW);
+                       retVal |= (_outProdType == 
OutProdType.CELLWISE_OUTER_PRODUCT);
+                       return retVal;
+               }
+
+               protected void setFedOutput(ExecutionContext ec, FederationMap 
fedMap, long frComputeID) {
+                       FederationMap newFedMap = 
fedMap.copyWithNewID(frComputeID);
+                       long[] outDims = new long[2];
+
+                       // find the resulting output dimensions
+                       MatrixObject X = ec.getMatrixObject(_inputs[0]);
+                       switch(_outProdType) {
+                               case LEFT_OUTER_PRODUCT: // LEFT: nrows of 
transposed X, ncols of U
+                                       newFedMap = newFedMap.transpose();
+                                       outDims[0] = X.getNumColumns();
+                                       outDims[1] = 
ec.getMatrixObject(_inputs[1]).getNumColumns();
+                                       break;
+                               case RIGHT_OUTER_PRODUCT: // RIGHT: nrows of X, 
ncols of V
+                                       outDims[0] = X.getNumRows();
+                                       outDims[1] = 
ec.getMatrixObject(_inputs[2]).getNumColumns();
+                                       break;
+                               case CELLWISE_OUTER_PRODUCT: // BASIC: preserve 
dimensions of X
+                                       outDims[0] = X.getNumRows();
+                                       outDims[1] = X.getNumColumns();
+                                       break;
+                               default:
+                                       throw new DMLRuntimeException("Outer 
Product Type " + _outProdType + " not supported yet.");
                        }
-                       else if(outProdType == OutProdType.RIGHT_OUTER_PRODUCT) 
{
-                               if(fedType == FType.ROW) {
-                                       // bind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                               }
-                               else if(fedType == FType.COL) {
+
+                       // derive output federated mapping
+                       MatrixObject out = ec.getMatrixObject(_output);
+                       int dim = (newFedMap.getType() == FType.ROW ? 1 : 0);
+                       newFedMap = modifyFedRanges(newFedMap, dim, 
outDims[dim]);
+                       out.setFedMapping(newFedMap);
+               }
+
+               private static FederationMap modifyFedRanges(FederationMap 
fedMap, int dim, long value) {
+                       IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
+                               fedMap.getFederatedRanges()[i].setBeginDim(dim, 
0);
+                               fedMap.getFederatedRanges()[i].setEndDim(dim, 
value);
+                       });
+                       return fedMap;
+               }
+
+               protected void aggResult(ExecutionContext ec, 
Future<FederatedResponse>[] response,
+                       FederationMap fedMap) {
+                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                       switch(_outProdType) {
+                               case LEFT_OUTER_PRODUCT:
+                               case RIGHT_OUTER_PRODUCT:
                                        // aggregate partial results from 
federated responses as elementwise sum
-                                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                                        ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
-                               }
-                               else {
-                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
-                                               " partitioned federated 
matrices supported yet.");
-                               }
-                       }
-                       else if(outProdType == 
OutProdType.CELLWISE_OUTER_PRODUCT) {
-                               if(fedType == FType.ROW) {
-                                       // rbind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                               }
-                               else if(fedType == FType.COL) {
-                                       // cbind partial results from federated 
responses
-                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
-                               }
-                               else {
-                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
-                                               " partitioned federated 
matrices supported yet.");
-                               }
-                       }
-                       else if(outProdType == OutProdType.AGG_OUTER_PRODUCT) {
-                               // aggregate partial results from federated 
responses as sum
-                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-                               ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
-                       }
-                       else {
-                               throw new DMLRuntimeException("Outer Product 
Type " + outProdType + " not supported yet.");
+                                       break;
+                               case AGG_OUTER_PRODUCT:
+                                       // aggregate partial results from 
federated responses as sum
+                                       ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
+                                       break;
+                               default:
+                                       throw new DMLRuntimeException("Outer 
Product Type " + _outProdType + " not supported yet.");
                        }
                }
        }
@@ -458,5 +528,4 @@ public class SpoofFEDInstruction extends FEDInstruction
                }
                return retVal;
        }
-
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
index 65f1728..61722db 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -104,7 +104,7 @@ public class FederatedCodegenMultipleFedMOTest extends 
AutomatedTestBase
                        // row partitioned
                        // {201, 6, 4, 6, 4, true},
                        {202, 6, 4, 6, 4, true},
-                       // {203, 20, 1, 20, 1, true},
+                       // FIXME: [SYSTEMDS-3110] {203, 20, 1, 20, 1, true},
                        // col partitioned
                        {201, 6, 4, 6, 4, false},
                        {202, 6, 4, 6, 4, false},
@@ -123,9 +123,9 @@ public class FederatedCodegenMultipleFedMOTest extends 
AutomatedTestBase
                        {308, 1000, 2000, 10, 2000, false},
                        // {310, 1000, 2000, 10, 2000, false},
                        // row and col partitioned
-                       // {311, 1000, 2000, 1000, 10, true}, // not working 
yet - ArrayIndexOutOfBoundsException in dotProduct
+                       // {311, 1000, 2000, 1000, 10, true}, // FIXME: 
ArrayIndexOutOfBoundsException in dotProduct
                        {312, 1000, 2000, 10, 2000, false},
-                       // {313, 4000, 2000, 4000, 10, true}, // not working 
yet - ArrayIndexOutOfBoundsException in dotProduct
+                       // {313, 4000, 2000, 4000, 10, true}, // FIXME: 
ArrayIndexOutOfBoundsException in dotProduct
                        {314, 4000, 2000, 10, 2000, false},
 
                        // combined tests
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
index edc9ab7..cef5fd5 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -86,14 +86,14 @@ public class FederatedOuterProductTmplTest extends 
AutomatedTestBase
                        {9, 1000, 2000, true},
 
                        // column partitioned
-                       //FIXME {1, 2000, 2000, false},
+                       {1, 2000, 2000, false},
                        // {2, 4000, 2000, false},
                        // {3, 1000, 1000, false},
-                       //FIXME {4, 4000, 2000, false},
-                       //FIXME {5, 4000, 2000, false},
+                       {4, 4000, 2000, false},
+                       {5, 4000, 2000, false},
                        // {6, 4000, 2000, false},
                        //FIXME {7, 2000, 2000, false},
-                       //FIXME {8, 1000, 2000, false},
+                       {8, 1000, 2000, false},
                        // {9, 1000, 2000, false},
                });
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index b4bff76..89475d8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -117,7 +117,7 @@ public class FederatedRowwiseTmplTest extends 
AutomatedTestBase
        }
 
        @Test
-       public void federatedCodegenCellwiseHybrid() {
+       public void federatedCodegenRowwiseHybrid() {
                testFederatedCodegenRowwise(ExecMode.HYBRID);
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
 
b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
index 71160b7..f2e873c 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
@@ -25,7 +25,6 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
 import org.junit.Ignore;
-import org.junit.Test;
 
 public class BuiltinTopkEvaluateTest extends AutomatedTestBase {
        //      private final static String TEST_NAME1 = "prioritized";

Reply via email to