This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 11965c8 [SYSTEMDS-2922] Federated codegen outer-product operations
11965c8 is described below
commit 11965c802b9f563056a2b4c2bafa7896bce09235
Author: ywcb00 <[email protected]>
AuthorDate: Sat May 22 21:34:34 2021 +0200
[SYSTEMDS-2922] Federated codegen outer-product operations
Closes #1283.
---
.../instructions/fed/FEDInstructionUtils.java | 9 +-
.../instructions/fed/SpoofFEDInstruction.java | 388 ++++++++++++++-------
.../codegen/FederatedOuterProductTmplTest.java | 200 +++++++++++
.../codegen/FederatedOuterProductTmplTest.dml | 108 ++++++
.../FederatedOuterProductTmplTestReference.dml | 106 ++++++
5 files changed, 674 insertions(+), 137 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 1ff7fbf..721143f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -229,8 +230,8 @@ public class FEDInstructionUtils {
else if(inst instanceof SpoofCPInstruction) {
SpoofCPInstruction instruction = (SpoofCPInstruction)
inst;
Class<?> scla =
instruction.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class)
- && instruction.isFederated(ec))
+ if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class
+ || scla ==
SpoofOuterProduct.class) && instruction.isFederated(ec))
|| (scla == SpoofRowwise.class &&
instruction.isFederated(ec, FType.ROW))) {
fedinst =
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
}
@@ -339,8 +340,8 @@ public class FEDInstructionUtils {
else if(inst instanceof SpoofSPInstruction) {
SpoofSPInstruction instruction = (SpoofSPInstruction)
inst;
Class<?> scla =
instruction.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class)
- && instruction.isFederated(ec))
+ if(((scla == SpoofCellwise.class || scla ==
SpoofMultiAggregate.class
+ || scla ==
SpoofOuterProduct.class) && instruction.isFederated(ec))
|| (scla == SpoofRowwise.class &&
instruction.isFederated(ec, FType.ROW))) {
fedinst =
SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 01d6051..13b1785 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -24,10 +24,12 @@ import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;
import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
-import org.apache.sysds.runtime.codegen.SpoofRowwise;
-import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
+import org.apache.sysds.runtime.codegen.SpoofOuterProduct.OutProdType;
+import org.apache.sysds.runtime.codegen.SpoofRowwise;
+import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -54,16 +56,14 @@ public class SpoofFEDInstruction extends FEDInstruction
private final CPOperand _output;
private SpoofFEDInstruction(SpoofOperator op, CPOperand[] in,
- CPOperand out, String opcode, String instStr)
- {
+ CPOperand out, String opcode, String instStr) {
super(FEDInstruction.FEDType.SpoofFused, opcode, instStr);
_op = op;
_inputs = in;
_output = out;
}
- public static SpoofFEDInstruction parseInstruction(String str)
- {
+ public static SpoofFEDInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
CPOperand[] inputCpo = new CPOperand[parts.length - 3 - 2];
@@ -79,8 +79,21 @@ public class SpoofFEDInstruction extends FEDInstruction
}
@Override
- public void processInstruction(ExecutionContext ec)
- {
+ public void processInstruction(ExecutionContext ec) {
+ Class<?> scla = _op.getClass().getSuperclass();
+ SpoofFEDType spoofType = null;
+ if(scla == SpoofCellwise.class)
+ spoofType = new SpoofFEDCellwise(_op, _output);
+ else if(scla == SpoofRowwise.class)
+ spoofType = new SpoofFEDRowwise(_op, _output);
+ else if(scla == SpoofMultiAggregate.class)
+ spoofType = new SpoofFEDMultiAgg(_op, _output);
+ else if(scla == SpoofOuterProduct.class)
+ spoofType = new SpoofFEDOuterProduct(_op, _output);
+ else
+ throw new DMLRuntimeException("Federated code
generation only supported" +
+ " for cellwise, rowwise, multiaggregate, and
outerproduct templates.");
+
ArrayList<CPOperand> inCpoMat = new ArrayList<>();
ArrayList<CPOperand> inCpoScal = new ArrayList<>();
ArrayList<MatrixObject> inMo = new ArrayList<>();
@@ -112,8 +125,8 @@ public class SpoofFEDInstruction extends FEDInstruction
int index = 0;
frIds[index++] = fedMap.getID(); // insert federation map id at
the beginning
for(MatrixObject mo : inMo) {
- if(needsBroadcastSliced(fedMap, mo.getNumRows(),
mo.getNumColumns())) {
- FederatedRequest[] tmpFr =
fedMap.broadcastSliced(mo, false);
+ if(spoofType.needsBroadcastSliced(fedMap,
mo.getNumRows(), mo.getNumColumns(), index)) {
+ FederatedRequest[] tmpFr =
spoofType.broadcastSliced(mo, fedMap);
frIds[index++] = tmpFr[0].getID();
frBroadcastSliced.add(tmpFr);
}
@@ -132,7 +145,8 @@ public class SpoofFEDInstruction extends FEDInstruction
// change the is_literal flag from true to false because when
broadcasted it is not a literal anymore
instString = instString.replace("true", "false");
- CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new
CPOperand[0]), inCpoScal.toArray(new CPOperand[0]));
+ CPOperand[] inCpo = ArrayUtils.addAll(inCpoMat.toArray(new
CPOperand[0]),
+ inCpoScal.toArray(new CPOperand[0]));
FederatedRequest frCompute =
FederationUtils.callInstruction(instString, _output, inCpo, frIds);
// get partial results from federated workers
@@ -151,163 +165,271 @@ public class SpoofFEDInstruction extends FEDInstruction
Future<FederatedResponse>[] response =
fedMap.executeMultipleSlices(
getTID(), true, frBroadcastSliced.toArray(new
FederatedRequest[0][]), frAll);
- if(_op.getClass().getSuperclass() == SpoofCellwise.class)
- setOutputCellwise(ec, response, fedMap);
- else if(_op.getClass().getSuperclass() == SpoofRowwise.class)
- setOutputRowwise(ec, response, fedMap);
-
- else if(_op.getClass().getSuperclass() ==
SpoofMultiAggregate.class)
- setOutputMultiAgg(ec, response, fedMap);
- else
- throw new DMLRuntimeException("Federated code
generation only supported for cellwise, rowwise, and multiaggregate
templates.");
+ // setting the output with respect to the different aggregation
types
+ // of the different spoof templates
+ spoofType.setOutput(ec, response, fedMap);
}
- private static boolean needsBroadcastSliced(FederationMap fedMap, long
rowNum, long colNum) {
- if(rowNum == fedMap.getMaxIndexInRange(0) && colNum ==
fedMap.getMaxIndexInRange(1))
- return true;
- if(fedMap.getType() == FType.ROW) {
- return (rowNum == fedMap.getMaxIndexInRange(0) &&
(colNum == 1 || colNum == fedMap.getSize()))
- || (colNum > 1 && rowNum == fedMap.getSize());
+ private static abstract class SpoofFEDType {
+ CPOperand _output;
+
+ protected SpoofFEDType(CPOperand out) {
+ _output = out;
}
- else if(fedMap.getType() == FType.COL) {
- return ((rowNum == 1 || rowNum == fedMap.getSize()) &&
colNum == fedMap.getMaxIndexInRange(1))
- || (rowNum > 1 && colNum == fedMap.getSize());
+
+ protected FederatedRequest[] broadcastSliced(MatrixObject mo,
FederationMap fedMap) {
+ return fedMap.broadcastSliced(mo, false);
+ }
+
+ protected boolean needsBroadcastSliced(FederationMap fedMap,
long rowNum, long colNum, int inputIndex) {
+ FType fedType = fedMap.getType();
+ boolean retVal = (rowNum ==
fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
+ if(fedType == FType.ROW)
+ retVal |= (rowNum ==
fedMap.getMaxIndexInRange(0) && (colNum == 1 || colNum == fedMap.getSize()));
+ else if(fedType == FType.COL)
+ retVal |= ((rowNum == 1 || rowNum ==
fedMap.getSize()) && colNum == fedMap.getMaxIndexInRange(1));
+ else
+ throw new DMLRuntimeException("Only row
partitioned or column" +
+ " partitioned federated input supported
yet.");
+ return retVal;
}
- throw new DMLRuntimeException("Only row partitioned or column
partitioned federated input supported yet.");
+
+ protected abstract void setOutput(ExecutionContext ec,
+ Future<FederatedResponse>[] response, FederationMap
fedMap);
}
- private void setOutputCellwise(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap)
- {
- FType fedType = fedMap.getType();
- AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
- CellType cellType = ((SpoofCellwise)_op).getCellType();
- if(cellType == CellType.FULL_AGG) { // full aggregation
- AggregateUnaryOperator aop = null;
- if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ) {
- // aggregate partial results from federated
responses as sum
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- }
- else if(aggOp == AggOp.MIN) {
- // aggregate partial results from federated
responses as min
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
- }
- else if(aggOp == AggOp.MAX) {
- // aggregate partial results from federated
responses as max
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
- }
- else {
- throw new DMLRuntimeException("Aggregation
operation not supported yet.");
- }
- ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+ private static class SpoofFEDCellwise extends SpoofFEDType {
+ private final SpoofCellwise _op;
+
+ SpoofFEDCellwise(SpoofOperator op, CPOperand out) {
+ super(out);
+ _op = (SpoofCellwise)op;
}
- else if(cellType == CellType.ROW_AGG) { // row aggregation
- if(fedType == FType.ROW) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
- }
- else if(fedType == FType.COL) {
+
+ protected void setOutput(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap) {
+ FType fedType = fedMap.getType();
+ AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
+ CellType cellType = ((SpoofCellwise)_op).getCellType();
+ if(cellType == CellType.FULL_AGG) { // full aggregation
AggregateUnaryOperator aop = null;
if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
else if(aggOp == AggOp.MIN)
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
else if(aggOp == AggOp.MAX)
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
else
throw new
DMLRuntimeException("Aggregation operation not supported yet.");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
}
- else {
- throw new DMLRuntimeException("Aggregation type
for federated spoof instructions not supported yet.");
+ else if(cellType == CellType.ROW_AGG) { // row
aggregation
+ if(fedType == FType.ROW) {
+ // bind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+ else if(fedType == FType.COL) {
+ AggregateUnaryOperator aop = null;
+ if(aggOp == AggOp.SUM || aggOp ==
AggOp.SUM_SQ)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+ else if(aggOp == AggOp.MIN)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
+ else if(aggOp == AggOp.MAX)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
+ else
+ throw new
DMLRuntimeException("Aggregation operation not supported yet.");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else {
+ throw new
DMLRuntimeException("Aggregation type for federated spoof instructions not
supported yet.");
+ }
}
- }
- else if(cellType == CellType.COL_AGG) { // col aggregation
- if(fedType == FType.ROW) {
- AggregateUnaryOperator aop = null;
- if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
- else if(aggOp == AggOp.MIN)
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
- else if(aggOp == AggOp.MAX)
- aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
- else
- throw new
DMLRuntimeException("Aggregation operation not supported yet.");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ else if(cellType == CellType.COL_AGG) { // col
aggregation
+ if(fedType == FType.ROW) {
+ AggregateUnaryOperator aop = null;
+ if(aggOp == AggOp.SUM || aggOp ==
AggOp.SUM_SQ)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+ else if(aggOp == AggOp.MIN)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
+ else if(aggOp == AggOp.MAX)
+ aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
+ else
+ throw new
DMLRuntimeException("Aggregation operation not supported yet.");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else if(fedType == FType.COL) {
+ // cbind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
+ }
+ else {
+ throw new
DMLRuntimeException("Aggregation type for federated spoof instructions not
supported yet.");
+ }
}
- else if(fedType == FType.COL) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
+ else if(cellType == CellType.NO_AGG) { // no aggregation
+ if(fedType == FType.ROW) //rbind
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ else if(fedType == FType.COL) //cbind
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
+ else
+ throw new DMLRuntimeException("Only row
partitioned or column" +
+ " partitioned federated
matrices supported yet.");
}
else {
- throw new DMLRuntimeException("Aggregation type
for federated spoof instructions not supported yet.");
+ throw new DMLRuntimeException("Aggregation type
not supported yet.");
}
}
- else if(cellType == CellType.NO_AGG) { // no aggregation
- if(fedType == FType.ROW) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+
+ private static class SpoofFEDRowwise extends SpoofFEDType {
+ private final SpoofRowwise _op;
+
+ SpoofFEDRowwise(SpoofOperator op, CPOperand out) {
+ super(out);
+ _op = (SpoofRowwise)op;
+ }
+
+ protected void setOutput(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap) {
+ RowType rowType = ((SpoofRowwise)_op).getRowType();
+ if(rowType == RowType.FULL_AGG) { // full aggregation
+ // aggregate partial results from federated
responses as sum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+ }
+ else if(rowType == RowType.ROW_AGG) { // row aggregation
+ // aggregate partial results from federated
responses as rowSum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
}
- else if(fedType == FType.COL) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
+ else if(rowType == RowType.COL_AGG
+ || rowType == RowType.COL_AGG_T
+ || rowType == RowType.COL_AGG_B1
+ || rowType == RowType.COL_AGG_B1_T
+ || rowType == RowType.COL_AGG_B1R
+ || rowType == RowType.COL_AGG_CONST) { // col
aggregation
+ // aggregate partial results from federated
responses as colSum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else if(rowType == RowType.NO_AGG
+ || rowType == RowType.NO_AGG_B1
+ || rowType == RowType.NO_AGG_CONST) { // no
aggregation
+ if(fedMap.getType() == FType.ROW) {
+ // bind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+ else {
+ throw new DMLRuntimeException("Only row
partitioned federated matrices supported yet.");
+ }
}
else {
- throw new DMLRuntimeException("Only row
partitioned or column partitioned federated matrices supported yet.");
+ throw new DMLRuntimeException("AggregationType
not supported yet.");
}
}
- else {
- throw new DMLRuntimeException("Aggregation type not
supported yet.");
+ }
+
+ private static class SpoofFEDMultiAgg extends SpoofFEDType {
+ private final SpoofMultiAggregate _op;
+
+ SpoofFEDMultiAgg(SpoofOperator op, CPOperand out) {
+ super(out);
+ _op = (SpoofMultiAggregate)op;
+ }
+
+ protected void setOutput(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap) {
+ MatrixBlock[] partRes =
FederationUtils.getResults(response);
+ SpoofCellwise.AggOp[] aggOps =
((SpoofMultiAggregate)_op).getAggOps();
+ for(int counter = 1; counter < partRes.length;
counter++) {
+
SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0],
partRes[counter]);
+ }
+ ec.setMatrixOutput(_output.getName(), partRes[0]);
}
}
- private void setOutputRowwise(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap)
- {
- RowType rowType = ((SpoofRowwise)_op).getRowType();
- if(rowType == RowType.FULL_AGG) { // full aggregation
- // aggregate partial results from federated responses
as sum
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
+
+ private static class SpoofFEDOuterProduct extends SpoofFEDType {
+ private final SpoofOuterProduct _op;
+
+ SpoofFEDOuterProduct(SpoofOperator op, CPOperand out) {
+ super(out);
+ _op = (SpoofOuterProduct)op;
}
- else if(rowType == RowType.ROW_AGG) { // row aggregation
- // aggregate partial results from federated responses
as rowSum
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+
+ protected FederatedRequest[] broadcastSliced(MatrixObject mo,
FederationMap fedMap) {
+ return fedMap.broadcastSliced(mo, (fedMap.getType() ==
FType.COL));
}
- else if(rowType == RowType.COL_AGG
- || rowType == RowType.COL_AGG_T
- || rowType == RowType.COL_AGG_B1
- || rowType == RowType.COL_AGG_B1_T
- || rowType == RowType.COL_AGG_B1R
- || rowType == RowType.COL_AGG_CONST) { // col
aggregation
- // aggregate partial results from federated responses
as colSum
- AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
- ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+
+ protected boolean needsBroadcastSliced(FederationMap fedMap,
long rowNum, long colNum, int inputIndex) {
+ boolean retVal = false;
+ FType fedType = fedMap.getType();
+
+ retVal |= (rowNum == fedMap.getMaxIndexInRange(0) &&
colNum == fedMap.getMaxIndexInRange(1));
+
+ if(fedType == FType.ROW)
+ retVal |= (rowNum ==
fedMap.getMaxIndexInRange(0)) && (inputIndex != 2); // input at index 2 is V
+ else if(fedType == FType.COL)
+ retVal |= (rowNum ==
fedMap.getMaxIndexInRange(1)) && (inputIndex != 1); // input at index 1 is U
+ else
+ throw new DMLRuntimeException("Only row
partitioned or column" +
+ " partitioned federated input supported
yet.");
+
+ return retVal;
}
- else if(rowType == RowType.NO_AGG
- || rowType == RowType.NO_AGG_B1
- || rowType == RowType.NO_AGG_CONST) { // no aggregation
- if(fedMap.getType() == FType.ROW) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+
+ protected void setOutput(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap) {
+ FType fedType = fedMap.getType();
+ OutProdType outProdType =
((SpoofOuterProduct)_op).getOuterProdType();
+ if(outProdType == OutProdType.LEFT_OUTER_PRODUCT) {
+ if(fedType == FType.ROW) {
+ // aggregate partial results from
federated responses as elementwise sum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else if(fedType == FType.COL) {
+ // bind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+ else {
+ throw new DMLRuntimeException("Only row
partitioned or column" +
+ " partitioned federated
matrices supported yet.");
+ }
+ }
+ else if(outProdType == OutProdType.RIGHT_OUTER_PRODUCT)
{
+ if(fedType == FType.ROW) {
+ // bind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+ else if(fedType == FType.COL) {
+ // aggregate partial results from
federated responses as elementwise sum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.aggMatrix(aop, response, fedMap));
+ }
+ else {
+ throw new DMLRuntimeException("Only row
partitioned or column" +
+ " partitioned federated
matrices supported yet.");
+ }
+ }
+ else if(outProdType ==
OutProdType.CELLWISE_OUTER_PRODUCT) {
+ if(fedType == FType.ROW) {
+ // rbind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, false));
+ }
+ else if(fedType == FType.COL) {
+ // cbind partial results from federated
responses
+ ec.setMatrixOutput(_output.getName(),
FederationUtils.bind(response, true));
+ }
+ else {
+ throw new DMLRuntimeException("Only row
partitioned or column" +
+ " partitioned federated
matrices supported yet.");
+ }
+ }
+ else if(outProdType == OutProdType.AGG_OUTER_PRODUCT) {
+ // aggregate partial results from federated
responses as sum
+ AggregateUnaryOperator aop =
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ ec.setVariable(_output.getName(),
FederationUtils.aggScalar(aop, response));
}
else {
- throw new DMLRuntimeException("Only row
partitioned federated matrices supported yet.");
+ throw new DMLRuntimeException("Outer Product
Type " + outProdType + " not supported yet.");
}
}
- else {
- throw new DMLRuntimeException("AggregationType not
supported yet.");
- }
}
-
- private void setOutputMultiAgg(ExecutionContext ec,
Future<FederatedResponse>[] response, FederationMap fedMap)
- {
- MatrixBlock[] partRes = FederationUtils.getResults(response);
- SpoofCellwise.AggOp[] aggOps =
((SpoofMultiAggregate)_op).getAggOps();
- for(int counter = 1; counter < partRes.length; counter++) {
- SpoofMultiAggregate.aggregatePartialResults(aggOps,
partRes[0], partRes[counter]);
- }
- ec.setMatrixOutput(_output.getName(), partRes[0]);
- }
-
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
new file mode 100644
index 0000000..d3460da
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.codegen;
+
+import java.io.File;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Ignore;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedOuterProductTmplTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME = "FederatedOuterProductTmplTest";
+
+ private final static String TEST_DIR = "functions/federated/codegen/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedOuterProductTmplTest.class.getSimpleName() + "/";
+
+ private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+
+ private final static String OUTPUT_NAME = "Z";
+ private final static double TOLERANCE = 1e-7;
+ private final static int BLOCKSIZE = 1024;
+
+ @Parameterized.Parameter()
+ public int test_num;
+ @Parameterized.Parameter(1)
+ public int rows;
+ @Parameterized.Parameter(2)
+ public int cols;
+ @Parameterized.Parameter(3)
+ public boolean row_partitioned;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{OUTPUT_NAME}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows must be even for row partitioned X
+ // cols must be even for col partitioned X
+ return Arrays.asList(new Object[][] {
+ // {test_num, rows, cols, row_partitioned}
+
+ // row partitioned
+ {1, 2000, 2000, true},
+ {2, 4000, 2000, true},
+ {3, 1000, 1000, true},
+ {4, 4000, 2000, true},
+ // {5, 4000, 2000, true},
+ {6, 4000, 2000, true},
+ // {7, 2000, 2000, true},
+ // {8, 1000, 2000, true},
+ {9, 1000, 2000, true},
+
+ // column partitioned
+ {1, 2000, 2000, false},
+ // {2, 4000, 2000, false},
+ // {3, 1000, 1000, false},
+ {4, 4000, 2000, false},
+ {5, 4000, 2000, false},
+ // {6, 4000, 2000, false},
+ {7, 2000, 2000, false},
+ {8, 1000, 2000, false},
+ // {9, 1000, 2000, false},
+ });
+ }
+
+ @BeforeClass
+ public static void init() {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @Test
+ @Ignore
+ public void federatedCodegenOuterProductSingleNode() {
+ testFederatedCodegenOuterProduct(ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedCodegenOuterProductSpark() {
+ testFederatedCodegenOuterProduct(ExecMode.SPARK);
+ }
+
+ @Test
+ public void federatedCodegenOuterProductHybrid() {
+ testFederatedCodegenOuterProduct(ExecMode.HYBRID);
+ }
+
+ private void testFederatedCodegenOuterProduct(ExecMode exec_mode) {
+ // store the previous platform config to restore it after the
test
+ ExecMode platform_old = setExecMode(exec_mode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ int fed_rows = rows;
+ int fed_cols = cols;
+ if(row_partitioned)
+ fed_rows /= 2;
+ else
+ fed_cols /= 2;
+
+ // generate dataset
+ // matrix handled by two federated workers
+ double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1,
3);
+ double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1,
7);
+
+ writeInputMatrixWithMTD("X1", X1, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+ writeInputMatrixWithMTD("X2", X2, false, new
MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread thread1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ Thread thread2 = startLocalFedWorkerThread(port2);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + input("X1"), "in_X2=" + input("X2"),
+ "in_rp=" +
Boolean.toString(row_partitioned).toUpperCase(),
+ "in_test_num=" + Integer.toString(test_num),
+ "out_Z=" + expected(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_rp=" +
Boolean.toString(row_partitioned).toUpperCase(),
+ "in_test_num=" + Integer.toString(test_num),
+ "rows=" + rows, "cols=" + cols,
+ "out_Z=" + output(OUTPUT_NAME)};
+ runTest(true, false, null, -1);
+
+ // compare the results via files
+ HashMap<CellIndex, Double> refResults =
readDMLMatrixFromExpectedDir(OUTPUT_NAME);
+ HashMap<CellIndex, Double> fedResults =
readDMLMatrixFromOutputDir(OUTPUT_NAME);
+ TestUtils.compareMatrices(fedResults, refResults, TOLERANCE,
"Fed", "Ref");
+
+ TestUtils.shutdownThreads(thread1, thread2);
+
+ // check for federated operations
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofOP"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ resetExecMode(platform_old);
+ }
+
+ /**
+ * Override default configuration with custom test configuration to
ensure
+ * scratch space and local temporary directory locations are also
updated.
+ */
+ @Override
+ protected File getConfigTemplateFile() {
+ // Instrumentation in this test's output log to show custom
configuration file used for template.
+ File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR,
TEST_CONF);
+ return TEST_CONF_FILE;
+ }
+}
diff --git a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTest.dml b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTest.dml
new file mode 100644
index 0000000..1f724ca
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTest.dml
@@ -0,0 +1,108 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num; # selects which outer-product pattern (1-9) is evaluated below
+row_part = $in_rp; # TRUE: X is split row-wise across the two workers, FALSE: column-wise
+
+if(row_part) {
+	X = federated(addresses=list($in_X1, $in_X2),
+		ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)));
+}
+else {
+	X = federated(addresses=list($in_X1, $in_X2),
+		ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2), list($rows, $cols)));
+}
+
+if(test_num == 1) { # wcemm
+	# X ... 2000x2000 matrix
+
+	U = matrix(seq(1, 20000), rows=2000, cols=10);
+	V = matrix(seq(20001, 40000), rows=2000, cols=10);
+	eps = 0.1;
+	Z = as.matrix(sum(X * log(U %*% t(V) + eps)));
+}
+else if(test_num == 2) { # wdivmm
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10);
+	V = matrix(seq(51, 20050), rows=2000, cols=10);
+	eps = 0.1;
+	Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+else if(test_num == 3) { # wdivmmbasic
+	# X ... 1000x1000 matrix
+
+	U = matrix(seq(1, 10000), rows=1000, cols=10);
+	V = matrix(seq(-7499, 2500), rows=1000, cols=10);
+	eps = 0.1;
+	Z = X / ((U %*% t(V)) + eps);
+}
+else if(test_num == 4) { # wdivmmNeq
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10) / 1000;
+	V = matrix(seq(501, 20500), rows=2000, cols=10) / 1000;
+	eps = 0.1;
+	Z = ((X!=0) * (U %*% t(V) + eps)) %*% V;
+}
+else if(test_num == 5) { # wdivmmRight
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10);
+	V = matrix(seq(1, 20000), rows=2000, cols=10);
+	eps = 0.1; # not used by this pattern (no epsilon in the divisor)
+	Z = (X / (U %*% t(V))) %*% V;
+}
+else if(test_num == 6) { # wdivmmRightNotranspose
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10);
+	V = matrix(seq(-1, 19998), rows=10, cols=2000);
+	eps = 0.1;
+	Z = (X / ((U %*% V) + eps)) %*% t(V);
+}
+else if(test_num == 7) { # wdivmmTransposeOut
+	# X ... 2000x2000 matrix
+
+	U = matrix(seq(600, 20599), rows=2000, cols=10);
+	V = matrix(seq(0, 19999), rows=10, cols=2000);
+	eps = 0.1;
+	Z = (t(U) %*% (X / ((U %*% V) + eps)));
+}
+else if(test_num == 8) { # wsigmoid
+	# X ... 1000x2000 matrix
+
+	U = matrix(seq(1, 10000), rows=1000, cols=10);
+	V = matrix(seq(1, 20000), rows=2000, cols=10);
+	eps = 0.1; # not used by this pattern
+	Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+else if(test_num == 9) { # wdivmmLeftEps
+	# X ... 1000x2000 matrix
+
+	U = matrix(seq(1, 10000), rows=1000, cols=10);
+	V = matrix(seq(1, 20000), rows=2000, cols=10);
+	eps = 0.4;
+	Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+else { stop("Unsupported test_num: " + test_num); } # fail fast instead of writing an undefined Z
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
new file mode 100644
index 0000000..e592dcb
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedOuterProductTmplTestReference.dml
@@ -0,0 +1,106 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num; # selects which outer-product pattern (1-9) is evaluated below
+row_part = $in_rp; # TRUE: the two local parts are stacked with rbind, FALSE: joined with cbind
+
+if(row_part) {
+	X = rbind(read($in_X1), read($in_X2));
+}
+else {
+	X = cbind(read($in_X1), read($in_X2));
+}
+
+if(test_num == 1) { # wcemm
+	# X ... 2000x2000 matrix
+
+	U = matrix(seq(1, 20000), rows=2000, cols=10);
+	V = matrix(seq(20001, 40000), rows=2000, cols=10);
+	eps = 0.1;
+	Z = as.matrix(sum(X * log(U %*% t(V) + eps)));
+}
+else if(test_num == 2) { # wdivmm
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10);
+	V = matrix(seq(51, 20050), rows=2000, cols=10);
+	eps = 0.1;
+	Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+else if(test_num == 3) { # wdivmmbasic
+	# X ... 1000x1000 matrix
+
+	U = matrix(seq(1, 10000), rows=1000, cols=10);
+	V = matrix(seq(-7499, 2500), rows=1000, cols=10);
+	eps = 0.1;
+	Z = X / ((U %*% t(V)) + eps);
+}
+else if(test_num == 4) { # wdivmmNeq
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10) / 1000;
+	V = matrix(seq(501, 20500), rows=2000, cols=10) / 1000;
+	eps = 0.1;
+	Z = ((X!=0) * (U %*% t(V) + eps)) %*% V;
+}
+else if(test_num == 5) { # wdivmmRight
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10);
+	V = matrix(seq(1, 20000), rows=2000, cols=10);
+	eps = 0.1; # not used by this pattern (no epsilon in the divisor)
+	Z = (X / (U %*% t(V))) %*% V;
+}
+else if(test_num == 6) { # wdivmmRightNotranspose
+	# X ... 4000x2000 matrix
+
+	U = matrix(seq(1, 40000), rows=4000, cols=10);
+	V = matrix(seq(-1, 19998), rows=10, cols=2000);
+	eps = 0.1;
+	Z = (X / ((U %*% V) + eps)) %*% t(V);
+}
+else if(test_num == 7) { # wdivmmTransposeOut
+	# X ... 2000x2000 matrix
+
+	U = matrix(seq(600, 20599), rows=2000, cols=10);
+	V = matrix(seq(0, 19999), rows=10, cols=2000);
+	eps = 0.1;
+	Z = (t(U) %*% (X / ((U %*% V) + eps)));
+}
+else if(test_num == 8) { # wsigmoid
+	# X ... 1000x2000 matrix
+
+	U = matrix(seq(1, 10000), rows=1000, cols=10);
+	V = matrix(seq(1, 20000), rows=2000, cols=10);
+	eps = 0.1; # not used by this pattern
+	Z = X * (1 / (1 + exp(-(U %*% t(V)))));
+}
+else if(test_num == 9) { # wdivmmLeftEps
+	# X ... 1000x2000 matrix
+
+	U = matrix(seq(1, 10000), rows=1000, cols=10);
+	V = matrix(seq(1, 20000), rows=2000, cols=10);
+	eps = 0.4;
+	Z = t(t(U) %*% (X / (U %*% t(V) + eps)));
+}
+else { stop("Unsupported test_num: " + test_num); } # fail fast instead of writing an undefined Z
+
+write(Z, $out_Z);