This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 42f3b15 [SYSTEMDS-2863] Federated ctable instruction
42f3b15 is described below
commit 42f3b1594bba4f4f49316c9866912bb978e6dd8d
Author: Olga <[email protected]>
AuthorDate: Sat Jan 30 16:00:47 2021 +0100
[SYSTEMDS-2863] Federated ctable instruction
Closes #1184
---
.../instructions/cp/CtableCPInstruction.java | 4 +-
.../instructions/fed/CtableFEDInstruction.java | 340 +++++++++++++++++++++
.../runtime/instructions/fed/FEDInstruction.java | 1 +
.../instructions/fed/FEDInstructionUtils.java | 9 +
.../federated/algorithms/FederatedCorTest.java | 4 -
.../federated/primitives/FederatedCtableTest.java | 207 +++++++++++++
.../federated/FederatedCtableFedOutput.dml | 47 +++
.../FederatedCtableFedOutputReference.dml | 46 +++
.../functions/federated/FederatedCtableTest.dml | 40 +++
.../federated/FederatedCtableTestReference.dml | 37 +++
10 files changed, 729 insertions(+), 6 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
index 4869e3e..f44aeb4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
@@ -74,7 +74,7 @@ public class CtableCPInstruction extends
ComputationCPInstruction {
CPOperand out = new CPOperand(parts[6]);
boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
-
+
// ctable does not require any operator, so we simply pass-in a
dummy operator with null functionobject
return new CtableCPInstruction(in1, in2, in3, out,
dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0],
Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst);
}
@@ -174,7 +174,7 @@ public class CtableCPInstruction extends
ComputationCPInstruction {
if( checkGuardedRepresentationChange(matBlock1, matBlock2,
resultBlock) ) {
resultBlock.examSparsity();
}
-
+
ec.setMatrixOutput(output.getName(), resultBlock);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
new file mode 100644
index 0000000..2681759
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.And;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+
+/**
+ * Federated ctable (contingency table) instruction. One federated input drives the request: the
+ * other matrix inputs are broadcast sliced along its federated ranges, every worker computes the
+ * ctable of its slices, and the worker results either form a row-partitioned federated output or
+ * are summed up locally.
+ */
+public class CtableFEDInstruction extends ComputationFEDInstruction {
+	private final CPOperand _outDim1;
+	private final CPOperand _outDim2;
+	private final boolean _isExpand;
+	private final boolean _ignoreZeros;
+
+	private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String outputDim1,
+		boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros,
+		String opcode, String istr) {
+		super(FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
+		_outDim1 = new CPOperand(outputDim1, ValueType.FP64, DataType.SCALAR, dim1Literal);
+		_outDim2 = new CPOperand(outputDim2, ValueType.FP64, DataType.SCALAR, dim2Literal);
+		_isExpand = isExpand;
+		_ignoreZeros = ignoreZeros;
+	}
+
+	/**
+	 * Parses a ctable instruction string into a federated ctable instruction.
+	 *
+	 * @param inst instruction string: opcode {@code ctable}, three inputs, the two output dimension
+	 *             operands, the output operand and the ignore-zeros flag
+	 * @return the parsed instruction; the expand variant is not supported and always disabled
+	 * @throws DMLRuntimeException if the opcode is not {@code ctable}
+	 */
+	public static CtableFEDInstruction parseInstruction(String inst) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
+		InstructionUtils.checkNumFields(parts, 7);
+
+		String opcode = parts[0];
+
+		//handle opcode
+		if(!(opcode.equalsIgnoreCase("ctable"))) {
+			throw new DMLRuntimeException("Unexpected opcode in CtableFEDInstruction: " + inst);
+		}
+
+		//handle operands
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand in3 = new CPOperand(parts[3]);
+
+		//handle known dimension information
+		String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX);
+		String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX);
+
+		CPOperand out = new CPOperand(parts[6]);
+		boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
+
+		// ctable does not require any operator, so we simply pass-in a null operator
+		return new CtableFEDInstruction(in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]),
+			dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), false, ignoreZeros, opcode, inst);
+	}
+
+	/**
+	 * Runs ctable on the federated workers. The first federated input among input1 and input2 (or,
+	 * if neither of them is federated, the matrix weights input3) drives the request; the other
+	 * matrix inputs are broadcast sliced along its federated ranges.
+	 *
+	 * @throws DMLRuntimeException if none of the inputs is federated
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		// the objects bound to input1/input2; output dims are always derived from these, independent
+		// of the swaps below (input1 determines the output rows, input2 the output columns)
+		MatrixObject in1 = ec.getMatrixObject(input1);
+		MatrixObject in2 = ec.getMatrixObject(input2);
+
+		MatrixObject mo1 = in1;
+		MatrixObject mo2 = in2;
+		boolean reversed = false;
+		if(!mo1.isFederated() && mo2.isFederated()) {
+			mo1 = in2;
+			mo2 = in1;
+			reversed = true;
+		}
+
+		MatrixObject mo3 = input3 != null && input3.isMatrix() ? ec.getMatrixObject(input3) : null;
+
+		// only the weights are federated: drive the request through the weights' federation map
+		boolean reversedWeights = mo3 != null && mo3.isFederated() && !(mo1.isFederated() || mo2.isFederated());
+		if(reversedWeights) {
+			mo3 = mo1;
+			mo1 = ec.getMatrixObject(input3);
+		}
+
+		if(!mo1.isFederated())
+			throw new DMLRuntimeException("Federated ctable requires at least one federated input: " + instString);
+
+		// per-worker output dims, computed only after the swaps so that mo1 has a federation map
+		FederatedRange[] ranges = mo1.getFedMapping().getFederatedRanges();
+		Long[] dims1 = getOutputDimension(in1, input1, _outDim1, ranges);
+		Long[] dims2 = getOutputDimension(in2, input2, _outDim2, ranges);
+
+		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0
+			&& dims1.length == Arrays.stream(dims1).distinct().count();
+
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+	}
+
+	/**
+	 * Sends the ctable requests to the workers of mo1 and sets the output, either as federated
+	 * matrix or as the locally aggregated worker results.
+	 *
+	 * @param mo1 federated input driving the request
+	 * @param mo2 second index input (broadcast sliced along mo1)
+	 * @param mo3 remaining matrix input (broadcast sliced along mo1), or null for scalar weights
+	 */
+	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
+		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
+		FederationMap fedMap = mo1.getFedMapping();
+		Future<FederatedResponse>[] ffr;
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest fr2, fr3;
+		if(mo3 == null) {
+			// variable ids in operand order (input1, input2)
+			long[] varIDs = reversed ?
+				new long[] {fr1[0].getID(), fedMap.getID()} :
+				new long[] {fedMap.getID(), fr1[0].getID()};
+			fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2}, varIDs);
+			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+			FederatedRequest fr4 = fedMap.cleanup(getTID(), fr1[0].getID());
+			ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4);
+		}
+		else {
+			FederatedRequest[] fr4 = fedMap.broadcastSliced(mo3, false);
+			// variable ids in operand order (input1, input2, input3)
+			long[] varIDs;
+			if(reversedWeights) // mo1 = input3 (federated), mo2 = input2, mo3 = input1
+				varIDs = new long[] {fr4[0].getID(), fr1[0].getID(), fedMap.getID()};
+			else if(reversed) // mo1 = input2 (federated), mo2 = input1, mo3 = input3
+				varIDs = new long[] {fr1[0].getID(), fedMap.getID(), fr4[0].getID()};
+			else // mo1 = input1 (federated), mo2 = input2, mo3 = input3
+				varIDs = new long[] {fedMap.getID(), fr1[0].getID(), fr4[0].getID()};
+			fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+				varIDs);
+			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+			FederatedRequest fr5 = fedMap.cleanup(getTID(), fr1[0].getID(), fr4[0].getID());
+			ffr = fedMap.execute(getTID(), true, fr1, fr4, fr2, fr3, fr5);
+		}
+
+		if(fedOutput && isFedOutput(ffr, dims1)) {
+			MatrixObject out = ec.getMatrixObject(output);
+			// modifyFedRanges mutates the ranges of the map it is given, so adjust a copy instead of
+			// the input's own federation map
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr2.getID()), dims1, dims2);
+			setFedOutput(mo1, out, newFedMap, dims1, dims2, fr2.getID());
+		}
+		else {
+			ec.setMatrixOutput(output.getName(), aggResult(ffr));
+		}
+	}
+
+	/**
+	 * Checks whether the non-zero rows of consecutive worker results do not intersect, i.e., whether
+	 * the result can stay federated (row partitioned).
+	 *
+	 * @param ffr   worker responses holding the per-worker ctable results
+	 * @param dims1 per-worker row extents
+	 * @return true if the results can form a federated output; false otherwise or if the check failed
+	 */
+	boolean isFedOutput(Future<FederatedResponse>[] ffr, Long[] dims1) {
+		boolean fedOutput = true;
+
+		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
+		BinaryOperator andOp = new BinaryOperator(And.getAndFnObject());
+		try {
+			MatrixBlock curr;
+			MatrixBlock prev = (MatrixBlock) ffr[0].get().getData()[0];
+			for(int i = 1; i < ffr.length && fedOutput; i++) {
+				curr = (MatrixBlock) ffr[i].get().getData()[0];
+				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
+
+				// no intersection
+				if(curr.getNumRows() == (i + 1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
+					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
+					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
+					continue;
+
+				// check intersect with AND and compare number of nnz; the AND is evaluated in place on
+				// the local copy prevExtend, never on curr, because curr is the worker result that
+				// aggResult re-reads if the output does not stay federated
+				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), 0.0);
+				prevExtend.copy(0, prev.getNumRows() - 1, 0, prev.getNumColumns() - 1, prev, true);
+
+				MatrixBlock intersect = prevExtend.binaryOperationsInPlace(andOp, curr);
+				if(intersect.getNonZeros() != 0)
+					fedOutput = false;
+				prev = sliced;
+			}
+		}
+		catch(Exception e) {
+			// the layout could not be verified: fall back to local aggregation of the worker results
+			fedOutput = false;
+		}
+		return fedOutput;
+	}
+
+	/**
+	 * Sets the output as federated matrix whose workers keep the last fedSize rows of their local
+	 * ctable result.
+	 */
+	private void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1,
+		Long[] dims2, long outId) {
+		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / dims1.length;
+
+		// output dims: largest row extent x largest column extent over all workers
+		long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
+		long d2 = Collections.max(Arrays.asList(dims2), Long::compare);
+
+		// set output; the nnz of the ctable result is not known here (-1), it is unrelated to the input's nnz
+		out.getDataCharacteristics().set(d1, d2, (int) mo1.getBlocksize(), -1);
+		out.setFedMapping(fedMap.copyWithNewID(outId));
+
+		long varID = FederationUtils.getNextFedDataID();
+		out.getFedMapping().mapParallel(varID, (range, data) -> {
+			try {
+				FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
+					FederatedRequest.RequestType.EXEC_UDF, -1,
+					new SliceOutput(data.getVarID(), fedSize))).get();
+				if(!response.isSuccessful())
+					response.throwExceptionFromResponse();
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException(e);
+			}
+			return null;
+		});
+	}
+
+	/**
+	 * Sums the worker results after padding them with zeros to the largest dimensions seen so far.
+	 *
+	 * @throws DMLRuntimeException if a worker response cannot be obtained or combined
+	 */
+	private MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
+		MatrixBlock resultBlock = new MatrixBlock(1, 1, 0);
+		BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+		int dim1 = 0, dim2 = 0;
+		for(Future<FederatedResponse> response : ffr) {
+			try {
+				MatrixBlock mb = (MatrixBlock) response.get().getData()[0];
+				dim1 = Math.max(dim1, mb.getNumRows());
+				dim2 = Math.max(dim2, mb.getNumColumns());
+
+				// set next and prev to same output dimensions
+				MatrixBlock prev = new MatrixBlock(dim1, dim2, 0.0);
+				prev.copy(0, resultBlock.getNumRows() - 1, 0, resultBlock.getNumColumns() - 1, resultBlock, true);
+
+				MatrixBlock next = new MatrixBlock(dim1, dim2, 0.0);
+				next.copy(0, mb.getNumRows() - 1, 0, mb.getNumColumns() - 1, mb, true);
+
+				// add worker results
+				resultBlock = prev.binaryOperationsInPlace(plus, next);
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException(e);
+			}
+		}
+		return resultBlock;
+	}
+
+	/**
+	 * Rewrites the ranges of the given map in place: rows become contiguous ranges ending at the
+	 * per-worker row extents, columns start at 0 and end at the per-worker column extents.
+	 *
+	 * @return the same (modified) map
+	 */
+	private FederationMap modifyFedRanges(FederationMap fedMap, Long[] dims1, Long[] dims2) {
+		FederatedRange[] ranges = fedMap.getFederatedRanges();
+		for(int i = 0; i < ranges.length; i++) {
+			ranges[i].setBeginDim(0, i == 0 ? 0 : ranges[i - 1].getEndDims()[0]);
+			ranges[i].setEndDim(0, dims1[i]);
+			ranges[i].setBeginDim(1, 0);
+			ranges[i].setEndDim(1, dims2[i]);
+		}
+		return fedMap;
+	}
+
+	/**
+	 * Computes the maximum value of the given input per federated range, i.e., the output extent
+	 * each worker produces along the dimension indexed by this input.
+	 */
+	private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp,
+		FederatedRange[] federatedRanges) {
+		Long[] fedDims = new Long[federatedRanges.length];
+
+		if(!in.isFederated()) {
+			// local input: slice it along the federated ranges and take the max per slice
+			MatrixBlock mb = in.acquireReadAndRelease();
+			IntStream.range(0, federatedRanges.length).forEach(i -> {
+				MatrixBlock sliced = mb
+					.slice(federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] - 1);
+				fedDims[i] = (long) sliced.max();
+			});
+			return fedDims;
+		}
+
+		String maxInstString = constructMaxInstString(inOp.getName(), outOp.getName());
+
+		// get max per worker
+		FederationMap map = in.getFedMapping();
+		FederatedRequest fr1 = FederationUtils.callInstruction(maxInstString, outOp,
+			new CPOperand[] {inOp}, new long[] {in.getFedMapping().getID()});
+		FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+		FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
+		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
+
+		return computeOutputDims(tmp);
+	}
+
+	/**
+	 * Extracts the scalar max values returned by the workers.
+	 *
+	 * @throws DMLRuntimeException if a worker response cannot be obtained
+	 */
+	private Long[] computeOutputDims(Future<FederatedResponse>[] tmp) {
+		Long[] fedDims = new Long[tmp.length];
+		for(int i = 0; i < tmp.length; i++) {
+			try {
+				fedDims[i] = ((ScalarObject) tmp[i].get().getData()[0]).getLongValue();
+			}
+			catch(Exception e) {
+				// a missing entry would only surface later as an NPE when unboxing the dims
+				throw new DMLRuntimeException(e);
+			}
+		}
+		return fedDims;
+	}
+
+	/**
+	 * Builds a {@code uamax} instruction string from this instruction's string, reading matrix
+	 * {@code in} and writing scalar {@code out}. The trailing "16" is presumably the thread count
+	 * of the unary aggregate — NOTE(review): confirm.
+	 */
+	private String constructMaxInstString(String in, String out) {
+		String maxInstrString = instString.replace("ctable", "uamax");
+		String[] instParts = maxInstrString.split(Lop.OPERAND_DELIMITOR);
+		String[] maxInstParts = new String[] {instParts[0], instParts[1],
+			InstructionUtils.concatOperandParts(in, DataType.MATRIX.name(), (ValueType.FP64).name()),
+			InstructionUtils.concatOperandParts(out, DataType.SCALAR.name(), (ValueType.FP64).name()), "16"};
+		return String.join(Lop.OPERAND_DELIMITOR, maxInstParts);
+	}
+
+	/** Worker-side UDF keeping only the last {@code fedSize} rows of the referenced matrix. */
+	private static class SliceOutput extends FederatedUDF {
+
+		private static final long serialVersionUID = -2808597461054603816L;
+		private final long _fedSize;
+
+		protected SliceOutput(long input, long fedSize) {
+			super(new long[] {input});
+			_fedSize = fedSize;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixObject mo = (MatrixObject) data[0];
+			MatrixBlock mb = mo.acquireReadAndRelease();
+
+			MatrixBlock sliced = mb.slice((int) (mb.getNumRows() - _fedSize), mb.getNumRows() - 1);
+			mo.acquireModify(sliced);
+			mo.release();
+
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {});
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+}
\ No newline at end of file
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index a91cb0c..871bdb1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -32,6 +32,7 @@ public abstract class FEDInstruction extends Instruction {
AggregateTernary,
Append,
Binary,
+ Ctable,
CumulativeAggregate,
Init,
MultiReturnParameterizedBuiltin,
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 7439b6f..3708d0a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -31,6 +31,7 @@ import
org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
@@ -223,6 +224,14 @@ public class FEDInstructionUtils {
if(instruction.getOperatorClass().getSuperclass() ==
SpoofCellwise.class && instruction.isFederated(ec))
fedinst =
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
}
+ else if(inst instanceof CtableCPInstruction) {
+ CtableCPInstruction cinst = (CtableCPInstruction) inst;
+ if(inst.getOpcode().equalsIgnoreCase("ctable")
+ && (
ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
+ || (cinst.input2.isMatrix() &&
ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
+ || (cinst.input3.isMatrix() &&
ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
+ fedinst =
CtableFEDInstruction.parseInstruction(cinst.getInstructionString());
+ }
//set thread id for federated context management
if( fedinst != null ) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
index 20b0147..73bf8e9 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
@@ -107,10 +107,6 @@ public class FederatedCorTest extends AutomatedTestBase {
Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
Thread t4 = startLocalFedWorkerThread(port4);
- rtplatform = execMode;
- if(rtplatform == ExecMode.SPARK)
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
new file mode 100644
index 0000000..a5793b5
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class FederatedCtableTest extends AutomatedTestBase {
+	private final static String TEST_DIR = "functions/federated/";
+	private final static String TEST_NAME1 = "FederatedCtableTest";
+	private final static String TEST_NAME2 = "FederatedCtableFedOutput";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCtableTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public int maxVal1;
+	@Parameterized.Parameter(3)
+	public int maxVal2;
+	@Parameterized.Parameter(4)
+	public boolean reversedInputs;
+	@Parameterized.Parameter(5)
+	public boolean weighted;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"F"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"F"}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		return Arrays.asList(new Object[][] {
+			// rows, cols, maxVal1, maxVal2, reversedInputs, weighted
+			{12, 4, 4, 7, true, true}, {12, 4, 4, 7, true, false},
+			{12, 4, 4, 7, false, true}, {12, 4, 4, 7, false, false},
+
+			{100, 14, 4, 7, true, true}, {100, 14, 4, 7, true, false},
+			{100, 14, 4, 7, false, true}, {100, 14, 4, 7, false, false},
+
+			// {1000, 14, 4, 7, true, true}, {1000, 14, 4, 7, true, false},
+			// {1000, 14, 4, 7, false, true}, {1000, 14, 4, 7, false, false}
+		});
+	}
+
+	@Test
+	public void federatedCtableSinglenode() { runCtable(Types.ExecMode.SINGLE_NODE, false, false); }
+
+	@Test
+	public void federatedCtableFedOutputSinglenode() { runCtable(Types.ExecMode.SINGLE_NODE, true, false); }
+
+	@Test
+	public void federatedCtableMatrixInputSinglenode() { runCtable(Types.ExecMode.SINGLE_NODE, false, true); }
+
+	/**
+	 * Starts four local federated workers, runs the reference and the federated script and compares
+	 * their results. Workers and exec mode are released even if the run or the checks fail.
+	 *
+	 * @param execMode    execution mode of the coordinator
+	 * @param fedOutput   run the federated-output scripts instead of the plain ctable scripts
+	 * @param matrixInput use matrices with {@code cols} columns instead of column vectors
+	 */
+	public void runCtable(Types.ExecMode execMode, boolean fedOutput, boolean matrixInput) {
+		String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
+		Types.ExecMode platformOld = setExecMode(execMode);
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		int port3 = getRandomAvailablePort();
+		int port4 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+		Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+		Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+		Thread t4 = startLocalFedWorkerThread(port4);
+
+		try {
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			if(fedOutput)
+				runFedCtable(HOME, TEST_NAME, port1, port2, port3, port4);
+			else
+				runNonFedCtable(HOME, TEST_NAME, matrixInput, port1, port2, port3, port4);
+			checkResults();
+		}
+		finally {
+			// a failing assertion must not leak the worker threads or the exec mode into later tests
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
+			resetExecMode(platformOld);
+		}
+	}
+
+	private void runNonFedCtable(String HOME, String TEST_NAME, boolean matrixInput, int port1, int port2,
+		int port3, int port4) {
+		int r = rows / 4;
+		// local copy instead of overwriting the parameter field
+		int c = matrixInput ? cols : 1;
+		double[][] X1 = TestUtils.floor(getRandomMatrix(r, c, 1, maxVal1, 1, 3));
+		double[][] X2 = TestUtils.floor(getRandomMatrix(r, c, 1, maxVal1, 1, 7));
+		double[][] X3 = TestUtils.floor(getRandomMatrix(r, c, 1, maxVal1, 1, 8));
+		double[][] X4 = TestUtils.floor(getRandomMatrix(r, c, 1, maxVal1, 1, 9));
+
+		// all values are >= 1, so the inputs are fully dense
+		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
+
+		double[][] Y = TestUtils.floor(getRandomMatrix(rows, c, 1, maxVal2, 1, 9));
+		writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, c, blocksize, rows * c));
+
+		// Run reference dml script with normal matrix
+		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+		programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+			input("Y"), Boolean.toString(reversedInputs).toUpperCase(),
+			Boolean.toString(weighted).toUpperCase(), expected("F")};
+		runTest(true, false, null, -1);
+
+		// Run actual dml script with federated matrix
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-stats", "100", "-nvargs",
+			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "in_Y=" + input("Y"),
+			"rows=" + rows, "cols=" + c, "revIn=" + Boolean.toString(reversedInputs).toUpperCase(),
+			"weighted=" + Boolean.toString(weighted).toUpperCase(), "out=" + output("F")};
+		runTest(true, false, null, -1);
+	}
+
+	private void runFedCtable(String HOME, String TEST_NAME, int port1, int port2, int port3, int port4) {
+		int r = rows / 4;
+		int c = cols;
+
+		double[][] X1 = getRandomMatrix(r, c, 1, 3, 1, 3);
+		double[][] X2 = getRandomMatrix(r, c, 1, 3, 1, 7);
+		double[][] X3 = getRandomMatrix(r, c, 1, 3, 1, 8);
+		double[][] X4 = getRandomMatrix(r, c, 1, 3, 1, 9);
+
+		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
+		writeInputMatrixWithMTD("X1", X1, false, mc);
+		writeInputMatrixWithMTD("X2", X2, false, mc);
+		writeInputMatrixWithMTD("X3", X3, false, mc);
+		writeInputMatrixWithMTD("X4", X4, false, mc);
+
+		//execute main test
+		fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+		programArgs = new String[] {"-stats", "100", "-args",
+			input("X1"), input("X2"), input("X3"), input("X4"), Boolean.toString(reversedInputs).toUpperCase(),
+			Boolean.toString(weighted).toUpperCase(), expected("F")};
+		runTest(true, false, null, -1);
+
+		// Run actual dml script with federated matrix
+		fullDMLScriptName = HOME + TEST_NAME + ".dml";
+		programArgs = new String[] {"-stats", "100", "-nvargs",
+			"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+			"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+			"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+			"in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+			"rows=" + rows, "cols=" + cols, "revIn=" + Boolean.toString(reversedInputs).toUpperCase(),
+			"weighted=" + Boolean.toString(weighted).toUpperCase(), "out=" + output("F")
+		};
+		runTest(true, false, null, -1);
+	}
+
+	void checkResults() {
+		// compare via files
+		compareResults(1e-9);
+
+		// check for federated operations
+		Assert.assertTrue(heavyHittersContainsString("fed_ctable"));
+
+		// check that federated input files are still existing
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+		Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+	}
+
+}
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
new file mode 100644
index 0000000..9c21ed5
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -0,0 +1,47 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+	ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+		list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+numRows = nrow(X);
+numCols = ncol(X);
+
+# coordinates of a one-hot encoded X: the row index of every cell (row-major order) and
+# the cell value shifted by the cumulative maxima of the preceding columns
+colMax = colMaxs(X);
+offsets = t(cumsum(t(colMax))) - colMax;
+rix = matrix(seq(1, numRows) %*% matrix(1, 1, numCols), numRows*numCols, 1);
+cix = matrix(X + offsets, numRows*numCols, 1);
+
+W = rix + cix;
+
+if($revIn) {
+	if($weighted) {
+		X2 = table(cix, rix, W);
+	}
+	else {
+		X2 = table(cix, rix);
+	}
+}
+else {
+	if($weighted) {
+		X2 = table(rix, cix, W);
+	}
+	else {
+		X2 = table(rix, cix);
+	}
+}
+
+write(X2, $out);
diff --git
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
new file mode 100644
index 0000000..e0721df
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4));
+
+numRows = nrow(X);
+numCols = ncol(X);
+
+# coordinates of a one-hot encoded X: the row index of every cell (row-major order) and
+# the cell value shifted by the cumulative maxima of the preceding columns
+colMax = colMaxs(X);
+offsets = t(cumsum(t(colMax))) - colMax;
+rix = matrix(seq(1, numRows) %*% matrix(1, 1, numCols), numRows*numCols, 1);
+cix = matrix(X + offsets, numRows*numCols, 1);
+
+W = rix + cix;
+
+# $5: reversed inputs, $6: weighted
+if($5) {
+	if($6) {
+		X2 = table(cix, rix, W);
+	}
+	else {
+		X2 = table(cix, rix);
+	}
+}
+else {
+	if($6) {
+		X2 = table(rix, cix, W);
+	}
+	else {
+		X2 = table(rix, cix);
+	}
+}
+
+write(X2, $7);
diff --git a/src/test/scripts/functions/federated/FederatedCtableTest.dml
b/src/test/scripts/functions/federated/FederatedCtableTest.dml
new file mode 100644
index 0000000..48ed734
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableTest.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+	ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+		list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+
+Y = read($in_Y);
+W = Y * 2;
+
+# revIn: pass the local Y as first and the federated X as second input, which
+# exercises the reversed federated ctable path (only the second input federated)
+if($revIn)
+	if($weighted)
+		F = table(Y, X, W);
+	else
+		F = table(Y, X);
+else
+	if($weighted)
+		F = table(X, Y, W);
+	else
+		F = table(X, Y);
+
+write(F, $out);
diff --git
a/src/test/scripts/functions/federated/FederatedCtableTestReference.dml
b/src/test/scripts/functions/federated/FederatedCtableTestReference.dml
new file mode 100644
index 0000000..749a99a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableTestReference.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2), read($3), read($4));
+Y = read($5);
+W = Y * 2;
+
+# $6: reversed inputs (swap the table operands), $7: weighted
+if($6)
+	if($7)
+		F = table(Y, X, W);
+	else
+		F = table(Y, X);
+else
+	if($7)
+		F = table(X, Y, W);
+	else
+		F = table(X, Y);
+
+write(F, $8);