This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 11965c8  [SYSTEMDS-2922] Federated codegen outer-product operations
11965c8 is described below

commit 11965c802b9f563056a2b4c2bafa7896bce09235
Author: ywcb00 <[email protected]>
AuthorDate: Sat May 22 21:34:34 2021 +0200

    [SYSTEMDS-2922] Federated codegen outer-product operations
    
    Closes #1283.
---
 .../instructions/fed/FEDInstructionUtils.java      |   9 +-
 .../instructions/fed/SpoofFEDInstruction.java      | 388 ++++++++++++++-------
 .../codegen/FederatedOuterProductTmplTest.java     | 200 +++++++++++
 .../codegen/FederatedOuterProductTmplTest.dml      | 108 ++++++
 .../FederatedOuterProductTmplTestReference.dml     | 106 ++++++
 5 files changed, 674 insertions(+), 137 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 1ff7fbf..721143f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
 
 import org.apache.sysds.runtime.codegen.SpoofCellwise;
 import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
 import org.apache.sysds.runtime.codegen.SpoofRowwise;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -229,8 +230,8 @@ public class FEDInstructionUtils {
                else if(inst instanceof SpoofCPInstruction) {
                        SpoofCPInstruction instruction = (SpoofCPInstruction) 
inst;
                        Class<?> scla = 
instruction.getOperatorClass().getSuperclass();
-                       if(((scla == SpoofCellwise.class || scla == 
SpoofMultiAggregate.class)
-                                       && instruction.isFederated(ec))
+                       if(((scla == SpoofCellwise.class || scla == 
SpoofMultiAggregate.class
+                                               || scla == 
SpoofOuterProduct.class) && instruction.isFederated(ec))
                                || (scla == SpoofRowwise.class && 
instruction.isFederated(ec, FType.ROW))) {
                                fedinst = 
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
                        }
@@ -339,8 +340,8 @@ public class FEDInstructionUtils {
                else if(inst instanceof SpoofSPInstruction) {
                        SpoofSPInstruction instruction = (SpoofSPInstruction) 
inst;
                        Class<?> scla = 
instruction.getOperatorClass().getSuperclass();
-                       if(((scla == SpoofCellwise.class || scla == 
SpoofMultiAggregate.class)
-                                       && instruction.isFederated(ec))
+                       if(((scla == SpoofCellwise.class || scla == 
SpoofMultiAggregate.class
+                                               || scla == 
SpoofOuterProduct.class) && instruction.isFederated(ec))
                                || (scla == SpoofRowwise.class && 
instruction.isFederated(ec, FType.ROW))) {
                                fedinst = 
SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 01d6051..13b1785 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -24,10 +24,12 @@ import org.apache.sysds.runtime.codegen.CodegenUtils;
 import org.apache.sysds.runtime.codegen.SpoofCellwise;
 import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;
 import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
-import org.apache.sysds.runtime.codegen.SpoofRowwise;
-import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
 import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
 import org.apache.sysds.runtime.codegen.SpoofOperator;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
+import org.apache.sysds.runtime.codegen.SpoofRowwise;
+import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -54,16 +56,14 @@ public class SpoofFEDInstruction extends FEDInstruction
        private final CPOperand _output;
 
        private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in,
-               CPOperand out, String opcode, String instStr)
-       {
+               CPOperand out, String opcode, String instStr) {
                super(FEDInstruction.FEDType.SpoofFused, opcode, instStr);
                _op = op;
                _inputs = in;
                _output = out;
        }
 
-       public static SpoofFEDInstruction parseInstruction(String str)
-       {
+       public static SpoofFEDInstruction parseInstruction(String str) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
 
                CPOperand[] inputCpo = new CPOperand[parts.length - 3 - 2];
@@ -79,8 +79,21 @@ public class SpoofFEDInstruction extends FEDInstruction
        }
 
        @Override
-       public void processInstruction(ExecutionContext ec)
-       {
+       public void processInstruction(ExecutionContext ec) {
+               Class<?> scla = _op.getClass().getSuperclass();
+               SpoofFEDType spoofType = null;
+               if(scla == SpoofCellwise.class)
+                       spoofType = new SpoofFEDCellwise(_op, _output);
+               else if(scla == SpoofRowwise.class)
+                       spoofType = new SpoofFEDRowwise(_op, _output);
+               else if(scla == SpoofMultiAggregate.class)
+                       spoofType = new SpoofFEDMultiAgg(_op, _output);
+               else if(scla == SpoofOuterProduct.class)
+                       spoofType = new SpoofFEDOuterProduct(_op, _output);
+               else
+                       throw new DMLRuntimeException("Federated code 
generation only supported" +
+                               " for cellwise, rowwise, multiaggregate, and 
outerproduct templates.");
+
                ArrayList<CPOperand> inCpoMat = new ArrayList<>();
                ArrayList<CPOperand> inCpoScal = new ArrayList<>();
                ArrayList<MatrixObject> inMo = new ArrayList<>();
@@ -112,8 +125,8 @@ public class SpoofFEDInstruction extends FEDInstruction
                int index = 0;
                frIds[index++] = fedMap.getID(); // insert federation map id at 
the beginning
                for(MatrixObject mo : inMo) {
-                       if(needsBroadcastSliced(fedMap, mo.getNumRows(), 
mo.getNumColumns())) {
-                               FederatedRequest[] tmpFr = 
fedMap.broadcastSliced(mo, false);
+                       if(spoofType.needsBroadcastSliced(fedMap, 
mo.getNumRows(), mo.getNumColumns(), index)) {
+                               FederatedRequest[] tmpFr = 
spoofType.broadcastSliced(mo, fedMap);
                                frIds[index++] = tmpFr[0].getID();
                                frBroadcastSliced.add(tmpFr);
                        }
@@ -132,7 +145,8 @@ public class SpoofFEDInstruction extends FEDInstruction
                // change the is_literal flag from true to false because when 
broadcasted it is not a literal anymore
                instString = instString.replace("true", "false");
 
-               CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new 
CPOperand[0]), inCpoScal.toArray(new CPOperand[0]));
+               CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new 
CPOperand[0]),
+                       inCpoScal.toArray(new CPOperand[0]));
                FederatedRequest frCompute = 
FederationUtils.callInstruction(instString, _output, inCpo, frIds);
 
                // get partial results from federated workers
@@ -151,163 +165,271 @@ public class SpoofFEDInstruction extends FEDInstruction
                Future<FederatedResponse>[] response = 
fedMap.executeMultipleSlices(
                        getTID(), true, frBroadcastSliced.toArray(new 
FederatedRequest[0][]), frAll);
 
-               if(_op.getClass().getSuperclass() == SpoofCellwise.class)
-                       setOutputCellwise(ec, response, fedMap);
-               else if(_op.getClass().getSuperclass() == SpoofRowwise.class)
-                       setOutputRowwise(ec, response, fedMap);
-
-               else if(_op.getClass().getSuperclass() == 
SpoofMultiAggregate.class)
-                       setOutputMultiAgg(ec, response, fedMap);
-               else
-                       throw new DMLRuntimeException("Federated code 
generation only supported for cellwise, rowwise, and multiaggregate 
templates.");
+               // setting the output with respect to the different aggregation 
types
+               // of the different spoof templates
+               spoofType.setOutput(ec, response, fedMap);
        }
 
-       private static boolean needsBroadcastSliced(FederationMap fedMap, long 
rowNum, long colNum) {
-               if(rowNum == fedMap.getMaxIndexInRange(0) && colNum == 
fedMap.getMaxIndexInRange(1))
-                       return true;
 
-               if(fedMap.getType() == FType.ROW) {
-                       return (rowNum == fedMap.getMaxIndexInRange(0) && 
(colNum == 1 || colNum == fedMap.getSize()))
-                               || (colNum > 1 && rowNum == fedMap.getSize());
+       private static abstract class SpoofFEDType {
+               CPOperand _output;
+
+               protected SpoofFEDType(CPOperand out) {
+                       _output = out;
                }
-               else if(fedMap.getType() == FType.COL) {
-                       return ((rowNum == 1 || rowNum == fedMap.getSize()) && 
colNum == fedMap.getMaxIndexInRange(1))
-                               || (rowNum > 1 && colNum == fedMap.getSize());
+               
+               protected FederatedRequest[] broadcastSliced(MatrixObject mo, 
FederationMap fedMap) {
+                       return fedMap.broadcastSliced(mo, false);
+               }
+
+               protected boolean needsBroadcastSliced(FederationMap fedMap, 
long rowNum, long colNum, int inputIndex) {
+                       FType fedType = fedMap.getType();
+                       boolean retVal = (rowNum == 
fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
+                       if(fedType == FType.ROW)
+                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0) && (colNum == 1 || colNum == fedMap.getSize()));
+                       else if(fedType == FType.COL)
+                               retVal |= ((rowNum == 1 || rowNum == 
fedMap.getSize()) && colNum == fedMap.getMaxIndexInRange(1));
+                       else
+                               throw new DMLRuntimeException("Only row 
partitioned or column" +
+                                       " partitioned federated input supported 
yet.");
+                       return retVal;
                }
-               throw new DMLRuntimeException("Only row partitioned or column 
partitioned federated input supported yet.");
+
+               protected abstract void setOutput(ExecutionContext ec,
+                       Future<FederatedResponse>[] response, FederationMap 
fedMap);
        }
 
-       private void setOutputCellwise(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap)
-       {
-               FType fedType = fedMap.getType();
-               AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
-               CellType cellType = ((SpoofCellwise)_op).getCellType();
-               if(cellType == CellType.FULL_AGG) { // full aggregation
-                       AggregateUnaryOperator aop = null;
-                       if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ) {
-                               // aggregate partial results from federated 
responses as sum
-                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-                       }
-                       else if(aggOp == AggOp.MIN) {
-                               // aggregate partial results from federated 
responses as min
-                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
-                       }
-                       else if(aggOp == AggOp.MAX) {
-                               // aggregate partial results from federated 
responses as max
-                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
-                       }
-                       else {
-                               throw new DMLRuntimeException("Aggregation 
operation not supported yet.");
-                       }
-                       ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
+       private static class SpoofFEDCellwise extends SpoofFEDType {
+               private final SpoofCellwise _op;
+
+               SpoofFEDCellwise(SpoofOperator op, CPOperand out) {
+                       super(out);
+                       _op = (SpoofCellwise)op;
                }
-               else if(cellType == CellType.ROW_AGG) { // row aggregation
-                       if(fedType == FType.ROW) {
-                               // bind partial results from federated responses
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
-                       }
-                       else if(fedType == FType.COL) {
+
+               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
+                       FType fedType = fedMap.getType();
+                       AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
+                       CellType cellType = ((SpoofCellwise)_op).getCellType();
+                       if(cellType == CellType.FULL_AGG) { // full aggregation
                                AggregateUnaryOperator aop = null;
                                if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                                else if(aggOp == AggOp.MIN)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
+                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
                                else if(aggOp == AggOp.MAX)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
+                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
                                else
                                        throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                               ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
                        }
-                       else {
-                               throw new DMLRuntimeException("Aggregation type 
for federated spoof instructions not supported yet.");
+                       else if(cellType == CellType.ROW_AGG) { // row 
aggregation
+                               if(fedType == FType.ROW) {
+                                       // bind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+                               }
+                               else if(fedType == FType.COL) {
+                                       AggregateUnaryOperator aop = null;
+                                       if(aggOp == AggOp.SUM || aggOp == 
AggOp.SUM_SQ)
+                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+                                       else if(aggOp == AggOp.MIN)
+                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
+                                       else if(aggOp == AggOp.MAX)
+                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
+                                       else
+                                               throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                               }
+                               else {
+                                       throw new 
DMLRuntimeException("Aggregation type for federated spoof instructions not 
supported yet.");
+                               }
                        }
-               }
-               else if(cellType == CellType.COL_AGG) { // col aggregation
-                       if(fedType == FType.ROW) {
-                               AggregateUnaryOperator aop = null;
-                               if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
-                               else if(aggOp == AggOp.MIN)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
-                               else if(aggOp == AggOp.MAX)
-                                       aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
-                               else
-                                       throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                       else if(cellType == CellType.COL_AGG) { // col 
aggregation
+                               if(fedType == FType.ROW) {
+                                       AggregateUnaryOperator aop = null;
+                                       if(aggOp == AggOp.SUM || aggOp == 
AggOp.SUM_SQ)
+                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+                                       else if(aggOp == AggOp.MIN)
+                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
+                                       else if(aggOp == AggOp.MAX)
+                                               aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
+                                       else
+                                               throw new 
DMLRuntimeException("Aggregation operation not supported yet.");
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                               }
+                               else if(fedType == FType.COL) {
+                                       // cbind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
+                               }
+                               else {
+                                       throw new 
DMLRuntimeException("Aggregation type for federated spoof instructions not 
supported yet.");
+                               }
                        }
-                       else if(fedType == FType.COL) {
-                               // bind partial results from federated responses
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
+                       else if(cellType == CellType.NO_AGG) { // no aggregation
+                               if(fedType == FType.ROW) //rbind
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+                               else if(fedType == FType.COL) //cbind
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
+                               else
+                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
+                                               " partitioned federated 
matrices supported yet.");
                        }
                        else {
-                               throw new DMLRuntimeException("Aggregation type 
for federated spoof instructions not supported yet.");
+                               throw new DMLRuntimeException("Aggregation type 
not supported yet.");
                        }
                }
-               else if(cellType == CellType.NO_AGG) { // no aggregation
-                       if(fedType == FType.ROW) {
-                               // bind partial results from federated responses
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+       }
+
+       private static class SpoofFEDRowwise extends SpoofFEDType {
+               private final SpoofRowwise _op;
+
+               SpoofFEDRowwise(SpoofOperator op, CPOperand out) {
+                       super(out);
+                       _op = (SpoofRowwise)op;
+               }
+
+               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
+                       RowType rowType = ((SpoofRowwise)_op).getRowType();
+                       if(rowType == RowType.FULL_AGG) { // full aggregation
+                               // aggregate partial results from federated 
responses as sum
+                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                               ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
+                       }
+                       else if(rowType == RowType.ROW_AGG) { // row aggregation
+                               // aggregate partial results from federated 
responses as rowSum
+                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
                        }
-                       else if(fedType == FType.COL) {
-                               // bind partial results from federated responses
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
+                       else if(rowType == RowType.COL_AGG
+                               || rowType == RowType.COL_AGG_T
+                               || rowType == RowType.COL_AGG_B1
+                               || rowType == RowType.COL_AGG_B1_T
+                               || rowType == RowType.COL_AGG_B1R
+                               || rowType == RowType.COL_AGG_CONST) { // col 
aggregation
+                               // aggregate partial results from federated 
responses as colSum
+                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                       }
+                       else if(rowType == RowType.NO_AGG
+                               || rowType == RowType.NO_AGG_B1
+                               || rowType == RowType.NO_AGG_CONST) { // no 
aggregation
+                               if(fedMap.getType() == FType.ROW) {
+                                       // bind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+                               }
+                               else {
+                                       throw new DMLRuntimeException("Only row 
partitioned federated matrices supported yet.");
+                               }
                        }
                        else {
-                               throw new DMLRuntimeException("Only row 
partitioned or column partitioned federated matrices supported yet.");
+                               throw new DMLRuntimeException("AggregationType 
not supported yet.");
                        }
                }
-               else {
-                       throw new DMLRuntimeException("Aggregation type not 
supported yet.");
+       }
+
+       private static class SpoofFEDMultiAgg extends SpoofFEDType {
+               private final SpoofMultiAggregate _op;
+
+               SpoofFEDMultiAgg(SpoofOperator op, CPOperand out) {
+                       super(out);
+                       _op = (SpoofMultiAggregate)op;
+               }
+
+               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
+                       MatrixBlock[] partRes = 
FederationUtils.getResults(response);
+                       SpoofCellwise.AggOp[] aggOps = 
((SpoofMultiAggregate)_op).getAggOps();
+                       for(int counter = 1; counter < partRes.length; 
counter++) {
+                               
SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], 
partRes[counter]);
+                       }
+                       ec.setMatrixOutput(_output.getName(), partRes[0]);
                }
        }
 
-       private void setOutputRowwise(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap)
-       {
-               RowType rowType = ((SpoofRowwise)_op).getRowType();
-               if(rowType == RowType.FULL_AGG) { // full aggregation
-                       // aggregate partial results from federated responses 
as sum
-                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
-                       ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
+
+       private static class SpoofFEDOuterProduct extends SpoofFEDType {
+               private final SpoofOuterProduct _op;
+
+               SpoofFEDOuterProduct(SpoofOperator op, CPOperand out) {
+                       super(out);
+                       _op = (SpoofOuterProduct)op;
                }
-               else if(rowType == RowType.ROW_AGG) { // row aggregation
-                       // aggregate partial results from federated responses 
as rowSum
-                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
-                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+
+               protected FederatedRequest[] broadcastSliced(MatrixObject mo, 
FederationMap fedMap) {
+                       return fedMap.broadcastSliced(mo, (fedMap.getType() == 
FType.COL));
                }
-               else if(rowType == RowType.COL_AGG
-                       || rowType == RowType.COL_AGG_T
-                       || rowType == RowType.COL_AGG_B1
-                       || rowType == RowType.COL_AGG_B1_T
-                       || rowType == RowType.COL_AGG_B1R
-                       || rowType == RowType.COL_AGG_CONST) { // col 
aggregation
-                       // aggregate partial results from federated responses 
as colSum
-                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
-                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+
+               protected boolean needsBroadcastSliced(FederationMap fedMap, 
long rowNum, long colNum, int inputIndex) {
+                       boolean retVal = false;
+                       FType fedType = fedMap.getType();
+                       
+                       retVal |= (rowNum == fedMap.getMaxIndexInRange(0) && 
colNum == fedMap.getMaxIndexInRange(1));
+                       
+                       if(fedType == FType.ROW)
+                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(0)) && (inputIndex != 2); // input at index 2 is V
+                       else if(fedType == FType.COL)
+                               retVal |= (rowNum == 
fedMap.getMaxIndexInRange(1)) && (inputIndex != 1); // input at index 1 is U
+                       else
+                               throw new DMLRuntimeException("Only row 
partitioned or column" +
+                                       " partitioned federated input supported 
yet.");
+                       
+                       return retVal;
                }
-               else if(rowType == RowType.NO_AGG
-                       || rowType == RowType.NO_AGG_B1
-                       || rowType == RowType.NO_AGG_CONST) { // no aggregation
-                       if(fedMap.getType() == FType.ROW) {
-                               // bind partial results from federated responses
-                               ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+
+               protected void setOutput(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap) {
+                       FType fedType = fedMap.getType();
+                       OutProdType outProdType = 
((SpoofOuterProduct)_op).getOuterProdType();
+                       if(outProdType == OutProdType.LEFT_OUTER_PRODUCT) {
+                               if(fedType == FType.ROW) {
+                                       // aggregate partial results from 
federated responses as elementwise sum
+                                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                               }
+                               else if(fedType == FType.COL) {
+                                       // bind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+                               }
+                               else {
+                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
+                                               " partitioned federated 
matrices supported yet.");
+                               }
+                       }
+                       else if(outProdType == OutProdType.RIGHT_OUTER_PRODUCT) 
{
+                               if(fedType == FType.ROW) {
+                                       // bind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+                               }
+                               else if(fedType == FType.COL) {
+                                       // aggregate partial results from 
federated responses as elementwise sum
+                                       AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
+                               }
+                               else {
+                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
+                                               " partitioned federated 
matrices supported yet.");
+                               }
+                       }
+                       else if(outProdType == 
OutProdType.CELLWISE_OUTER_PRODUCT) {
+                               if(fedType == FType.ROW) {
+                                       // rbind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, false));
+                               }
+                               else if(fedType == FType.COL) {
+                                       // cbind partial results from federated 
responses
+                                       ec.setMatrixOutput(_output.getName(), 
FederationUtils.bind(response, true));
+                               }
+                               else {
+                                       throw new DMLRuntimeException("Only row 
partitioned or column" +
+                                               " partitioned federated 
matrices supported yet.");
+                               }
+                       }
+                       else if(outProdType == OutProdType.AGG_OUTER_PRODUCT) {
+                               // aggregate partial results from federated 
responses as sum
+                               AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+                               ec.setVariable(_output.getName(), 
FederationUtils.aggScalar(aop, response));
                        }
                        else {
-                               throw new DMLRuntimeException("Only row 
partitioned federated matrices supported yet.");
+                               throw new DMLRuntimeException("Outer Product 
Type " + outProdType + " not supported yet.");
                        }
                }
-               else {
-                       throw new DMLRuntimeException("AggregationType not 
supported yet.");
-               }
        }
-       
-       private void setOutputMultiAgg(ExecutionContext ec, 
Future<FederatedResponse>[] response, FederationMap fedMap)
-       {
-               MatrixBlock[] partRes = FederationUtils.getResults(response);
-               SpoofCellwise.AggOp[] aggOps = 
((SpoofMultiAggregate)_op).getAggOps();
-               for(int counter = 1; counter < partRes.length; counter++) {
-                       SpoofMultiAggregate.aggregatePartialResults(aggOps, 
partRes[0], partRes[counter]);
-               }
-               ec.setMatrixOutput(_output.getName(), partRes[0]);
-       }
-
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
new file mode 100644
index 0000000..d3460da
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.codegen;
+
+import java.io.File;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedOuterProductTmplTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "FederatedOuterProductTmplTest";
+
+       private final static String TEST_DIR = "functions/federated/codegen/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedOuterProductTmplTest.class.getSimpleName() + "/";
+
+       private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+
+       private final static String OUTPUT_NAME = "Z";
+       private final static double TOLERANCE = 1e-7;
+       private final static int BLOCKSIZE = 1024;
+
+       @Parameterized.Parameter()
+       public int test_num;
+       @Parameterized.Parameter(1)
+       public int rows;
+       @Parameterized.Parameter(2)
+       public int cols;
+       @Parameterized.Parameter(3)
+       public boolean row_partitioned;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows must be even for row partitioned X
+               // cols must be even for col partitioned X
+               return Arrays.asList(new Object[][] {
+                       // {test_num, rows, cols, row_partitioned}
+
+                       // row partitioned
+                       {1, 2000, 2000, true},
+                       {2, 4000, 2000, true},
+                       {3, 1000, 1000, true},
+                       {4, 4000, 2000, true},
+                       // {5, 4000, 2000, true},
+                       {6, 4000, 2000, true},
+                       // {7, 2000, 2000, true},
+                       // {8, 1000, 2000, true},
+                       {9, 1000, 2000, true},
+
+                       // column partitioned
+                       {1, 2000, 2000, false},
+                       // {2, 4000, 2000, false},
+                       // {3, 1000, 1000, false},
+                       {4, 4000, 2000, false},
+                       {5, 4000, 2000, false},
+                       // {6, 4000, 2000, false},
+                       {7, 2000, 2000, false},
+                       {8, 1000, 2000, false},
+                       // {9, 1000, 2000, false},
+               });
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @Test
+       @Ignore
+       public void federatedCodegenOuterProductSingleNode() {
+               testFederatedCodegenOuterProduct(ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       @Ignore
+       public void federatedCodegenOuterProductSpark() {
+               testFederatedCodegenOuterProduct(ExecMode.SPARK);
+       }
+       
+       @Test
+       public void federatedCodegenOuterProductHybrid() {
+               testFederatedCodegenOuterProduct(ExecMode.HYBRID);
+       }
+       
+       private void testFederatedCodegenOuterProduct(ExecMode exec_mode) {
+               // store the previous platform config to restore it after the test
+               ExecMode platform_old = setExecMode(exec_mode);
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               int fed_rows = rows;
+               int fed_cols = cols;
+               if(row_partitioned)
+                       fed_rows /= 2;
+               else
+                       fed_cols /= 2;
+
+               // generate dataset
+               // matrix handled by two federated workers
+               double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1, 3);
+               double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1, 7);
+
+               writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+               writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread thread1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread thread2 = startLocalFedWorkerThread(port2);
+
+               getAndLoadTestConfiguration(TEST_NAME);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+                       "in_rp=" + Boolean.toString(row_partitioned).toUpperCase(),
+                       "in_test_num=" + Integer.toString(test_num),
+                       "out_Z=" + expected(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "in_rp=" + Boolean.toString(row_partitioned).toUpperCase(),
+                       "in_test_num=" + Integer.toString(test_num),
+                       "rows=" + rows, "cols=" + cols,
+                       "out_Z=" + output(OUTPUT_NAME)};
+               runTest(true, false, null, -1);
+
+               // compare the results via files
+               HashMap<CellIndex, Double> refResults = readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+               HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir(OUTPUT_NAME);
+               TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed", "Ref");
+
+               TestUtils.shutdownThreads(thread1, thread2);
+
+               // check for federated operations
+               Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP"));
+
+               // check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               resetExecMode(platform_old);
+       }
+
+       /**
+        * Override default configuration with custom test configuration to ensure
+        * scratch space and local temporary directory locations are also updated.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               // Instrumentation in this test's output log to show custom configuration file used for template.
+               File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, TEST_CONF);
+               return TEST_CONF_FILE;
+       }
+}
diff --git a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTest.dml b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTest.dml
new file mode 100644
index 0000000..1f724ca
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTest.dml
@@ -0,0 +1,108 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num;
+row_part = $in_rp;
+
+if(row_part) {
+  X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)));
+}
+else {
+  X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2), list($rows, $cols)));
+}
+
+if(test_num == 1) { # wcemm
+  # X ... 2000x2000 matrix
+  
+  U = matrix(seq(1, 20000), rows=2000, cols=10);
+  V = matrix(seq(20001, 40000), rows=2000, cols=10);
+  eps = 0.1;
+  Z = as.matrix(sum(X * log(U %*% t(V) + eps)));
+}
+else if(test_num == 2) { # wdivmm
+  # X ... 4000x2000 matrix
+  
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  V = matrix(seq(51, 20050), rows=2000, cols=10);
+  eps = 0.1;
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+else if(test_num == 3) { # wdivmmbasic
+  # X 1000x1000 matrix
+
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  V = matrix(seq(-7499, 2500), rows=1000, cols=10);
+  eps = 0.1;
+  Z = X / ((U %*% t(V)) + eps);
+}
+else if(test_num == 4) { # wdivmmNeq
+  # X ...4000x2000 matrix
+  
+  U = matrix(seq(1, 40000), rows=4000, cols=10) / 1000;
+  V = matrix(seq(501, 20500), rows=2000, cols=10) / 1000;
+  eps = 0.1;
+  Z = ((X!=0) * (U %*% t(V) + eps)) %*% V;
+}
+else if(test_num == 5) { # wdivmmRight
+  # X ... 4000x2000 matrix
+
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  eps = 0.1;
+  Z = (X / (U %*% t(V))) %*% V;
+}
+else if(test_num == 6) { # wdivmmRightNotranspose
+  # X ... 4000x2000 matrix
+  
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  V = matrix(seq(-1, 19998), rows=10, cols=2000);
+  eps = 0.1;
+  Z = (X / ((U %*% V) + eps)) %*% t(V);
+}
+else if(test_num == 7) { # wdivmmTransposeOut
+  # X ... 2000x2000 matrix
+
+  U = matrix(seq(600, 20599), rows=2000, cols=10);
+  V = matrix(seq(0, 19999), rows=10, cols=2000);
+  eps = 0.1;
+  Z = (t(U) %*% (X / ((U %*% V) + eps)));
+}
+else if(test_num == 8) { # wsigmoid
+  # X ... 1000x2000 matrix
+
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  eps = 0.1;
+  Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+else if(test_num == 9) { #wdivmmLeftEps
+  # X ... 1000x2000 matrix
+
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  eps = 0.4;
+
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
new file mode 100644
index 0000000..e592dcb
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
@@ -0,0 +1,106 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num;
+row_part = $in_rp;
+
+if(row_part) {
+  X = rbind(read($in_X1), read($in_X2));
+}
+else {
+  X = cbind(read($in_X1), read($in_X2));
+}
+
+if(test_num == 1) { # wcemm
+  # X ... 2000x2000 matrix
+  
+  U = matrix(seq(1, 20000), rows=2000, cols=10);
+  V = matrix(seq(20001, 40000), rows=2000, cols=10);
+  eps = 0.1;
+  Z = as.matrix(sum(X * log(U %*% t(V) + eps)));
+}
+else if(test_num == 2) { # wdivmm
+  # X ... 4000x2000 matrix
+  
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  V = matrix(seq(51, 20050), rows=2000, cols=10);
+  eps = 0.1;
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+else if(test_num == 3) { # wdivmmbasic
+  # X 1000x1000 matrix
+
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  V = matrix(seq(-7499, 2500), rows=1000, cols=10);
+  eps = 0.1;
+  Z = X / ((U %*% t(V)) + eps);
+}
+else if(test_num == 4) { # wdivmmNeq
+  # X ...4000x2000 matrix
+  
+  U = matrix(seq(1, 40000), rows=4000, cols=10) / 1000;
+  V = matrix(seq(501, 20500), rows=2000, cols=10) / 1000;
+  eps = 0.1;
+  Z = ((X!=0) * (U %*% t(V) + eps)) %*% V;
+}
+else if(test_num == 5) { # wdivmmRight
+  # X ... 4000x2000 matrix
+
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  eps = 0.1;
+  Z = (X / (U %*% t(V))) %*% V;
+}
+else if(test_num == 6) { # wdivmmRightNotranspose
+  # X ... 4000x2000 matrix
+  
+  U = matrix(seq(1, 40000), rows=4000, cols=10);
+  V = matrix(seq(-1, 19998), rows=10, cols=2000);
+  eps = 0.1;
+  Z = (X / ((U %*% V) + eps)) %*% t(V);
+}
+else if(test_num == 7) { # wdivmmTransposeOut
+  # X ... 2000x2000 matrix
+
+  U = matrix(seq(600, 20599), rows=2000, cols=10);
+  V = matrix(seq(0, 19999), rows=10, cols=2000);
+  eps = 0.1;
+  Z = (t(U) %*% (X / ((U %*% V) + eps)));
+}
+else if(test_num == 8) { # wsigmoid
+  # X ... 1000x2000 matrix
+
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  eps = 0.1;
+  Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+else if(test_num == 9) { #wdivmmLeftEps
+  # X ... 1000x2000 matrix
+
+  U = matrix(seq(1, 10000), rows=1000, cols=10);
+  V = matrix(seq(1, 20000), rows=2000, cols=10);
+  eps = 0.4;
+
+  Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+
+write(Z, $out_Z);

Reply via email to