This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 42f3b15  [SYSTEMDS-2863] Federated ctable instruction
42f3b15 is described below

commit 42f3b1594bba4f4f49316c9866912bb978e6dd8d
Author: Olga <[email protected]>
AuthorDate: Sat Jan 30 16:00:47 2021 +0100

    [SYSTEMDS-2863] Federated ctable instruction
    
    Closes #1184
---
 .../instructions/cp/CtableCPInstruction.java       |   4 +-
 .../instructions/fed/CtableFEDInstruction.java     | 340 +++++++++++++++++++++
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |   9 +
 .../federated/algorithms/FederatedCorTest.java     |   4 -
 .../federated/primitives/FederatedCtableTest.java  | 207 +++++++++++++
 .../federated/FederatedCtableFedOutput.dml         |  47 +++
 .../FederatedCtableFedOutputReference.dml          |  46 +++
 .../functions/federated/FederatedCtableTest.dml    |  40 +++
 .../federated/FederatedCtableTestReference.dml     |  37 +++
 10 files changed, 729 insertions(+), 6 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
index 4869e3e..f44aeb4 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CtableCPInstruction.java
@@ -74,7 +74,7 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
 
                CPOperand out = new CPOperand(parts[6]);
                boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
-               
+
                // ctable does not require any operator, so we simply pass-in a 
dummy operator with null functionobject
                return new CtableCPInstruction(in1, in2, in3, out, 
dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], 
Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst);
        }
@@ -174,7 +174,7 @@ public class CtableCPInstruction extends 
ComputationCPInstruction {
                if( checkGuardedRepresentationChange(matBlock1, matBlock2, 
resultBlock) ) {
                        resultBlock.examSparsity();
                }
-               
+
                ec.setMatrixOutput(output.getName(), resultBlock);
        }
        
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
new file mode 100644
index 0000000..2681759
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.And;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+
+public class CtableFEDInstruction extends ComputationFEDInstruction {
+       private final CPOperand _outDim1;
+       private final CPOperand _outDim2;
+       private final boolean _isExpand;
+       private final boolean _ignoreZeros;
+
+       private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand 
in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, 
boolean dim2Literal, boolean isExpand,
+               boolean ignoreZeros, String opcode, String istr) {
+               super(FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
+               _outDim1 = new CPOperand(outputDim1, ValueType.FP64, 
DataType.SCALAR, dim1Literal);
+               _outDim2 = new CPOperand(outputDim2, ValueType.FP64, 
DataType.SCALAR, dim2Literal);
+               _isExpand = isExpand;
+               _ignoreZeros = ignoreZeros;
+       }
+
+       public static CtableFEDInstruction parseInstruction(String inst) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(inst);
+               InstructionUtils.checkNumFields(parts, 7);
+
+               String opcode = parts[0];
+
+               //handle opcode
+               if(!(opcode.equalsIgnoreCase("ctable"))) {
+                       throw new DMLRuntimeException("Unexpected opcode in 
CtableFEDInstruction: " + inst);
+               }
+
+               //handle operands
+               CPOperand in1 = new CPOperand(parts[1]);
+               CPOperand in2 = new CPOperand(parts[2]);
+               CPOperand in3 = new CPOperand(parts[3]);
+
+               //handle known dimension information
+               String[] dim1Fields = 
parts[4].split(Instruction.LITERAL_PREFIX);
+               String[] dim2Fields = 
parts[5].split(Instruction.LITERAL_PREFIX);
+
+               CPOperand out = new CPOperand(parts[6]);
+               boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
+
+               // ctable does not require any operator, so we simply pass-in a 
dummy operator with null functionobject
+               return new CtableFEDInstruction(in1,
+                       in2, in3, out, dim1Fields[0], 
Boolean.parseBoolean(dim1Fields[1]),
+                       dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), 
false, ignoreZeros, opcode, inst);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               MatrixObject mo2 = ec.getMatrixObject(input2);
+
+               boolean reversed = false;
+               if(!mo1.isFederated() && mo2.isFederated()) {
+                       mo1 = ec.getMatrixObject(input2);
+                       mo2 = ec.getMatrixObject(input1);
+                       reversed = true;
+               }
+
+               // get new output dims
+               Long[] dims1 = getOutputDimension(mo1, input1, _outDim1, 
mo1.getFedMapping().getFederatedRanges());
+               Long[] dims2 = getOutputDimension(mo2, input2, _outDim2, 
mo1.getFedMapping().getFederatedRanges());
+
+               MatrixObject mo3 = input3 != null && input3.isMatrix() ? 
ec.getMatrixObject(input3) : null;
+
+               boolean reversedWeights = mo3 != null && mo3.isFederated() && 
!(mo1.isFederated() || mo2.isFederated());
+               if(reversedWeights) {
+                       mo3 = mo1;
+                       mo1 = ec.getMatrixObject(input3);
+               }
+
+               long dim1 = Collections.max(Arrays.asList(dims1), 
Long::compare);
+               boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 
&& dims1.length == Arrays.stream(dims1).distinct().count();
+
+               processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, 
fedOutput, dims1, dims2);
+       }
+
+       private void processRequest(ExecutionContext ec, MatrixObject mo1, 
MatrixObject mo2, MatrixObject mo3,
+               boolean reversed, boolean reversedWeights, boolean fedOutput, 
Long[] dims1, Long[] dims2) {
+               Future<FederatedResponse>[] ffr;
+
+               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+               FederatedRequest fr2, fr3;
+               if(mo3 == null) {
+                       if(!reversed)
+                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
+                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID()});
+                       else
+                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
+                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID()});
+
+                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+                       FederatedRequest fr4 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr2, fr3, fr4);
+
+               } else {
+                       FederatedRequest[] fr4 = 
mo1.getFedMapping().broadcastSliced(mo3, false);
+                       if(!reversed && !reversedWeights)
+                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+                       else if(reversed && !reversedWeights)
+                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID(), fr4[0].getID()});
+                       else
+                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fr4[0].getID(), mo1.getFedMapping().getID()});
+
+                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+                       FederatedRequest fr5 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr4[0].getID());
+                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr4, fr2, fr3, fr5);
+               }
+
+               if(fedOutput && isFedOutput(ffr, dims1)) {
+                       MatrixObject out = ec.getMatrixObject(output);
+                       FederationMap newFedMap = 
modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
+                       setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+               } else {
+                       ec.setMatrixOutput(output.getName(), aggResult(ffr));
+               }
+       }
+
+       boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
+               boolean fedOutput = true;
+
+               long fedSize = Collections.max(Arrays.asList(dims1), 
Long::compare) / ffr.length;
+               try {
+                       MatrixBlock curr;
+                       MatrixBlock prev =(MatrixBlock) 
ffr[0].get().getData()[0];
+                       for(int i = 1; i < ffr.length && fedOutput; i++) {
+                               curr = (MatrixBlock) ffr[i].get().getData()[0];
+                               MatrixBlock sliced = curr.slice((int) 
(curr.getNumRows() - fedSize), curr.getNumRows() - 1);
+
+                               // no intersection
+                               if(curr.getNumRows() == (i+1) * 
prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
+                                       && (curr.getNumRows() - 
sliced.getNumRows()) == i * prev.getNumRows()
+                                       && curr.getNonZeros() - 
sliced.getNonZeros() == 0)
+                                       continue;
+
+                               // check intersect with AND and compare number 
of nnz
+                               MatrixBlock prevExtend = new 
MatrixBlock(curr.getNumRows(), curr.getNumColumns(), 0.0);
+                               prevExtend.copy(0, prev.getNumRows()-1, 0, 
prev.getNumColumns()-1, prev, true);
+
+                               MatrixBlock  intersect = 
curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), 
prevExtend);
+                               if(intersect.getNonZeros() != 0)
+                                       fedOutput = false;
+                               prev = sliced;
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+               }
+               return fedOutput;
+       }
+
+
+       private void setFedOutput(MatrixObject mo1, MatrixObject out, 
FederationMap fedMap, Long[] dims1, long outId) {
+               long fedSize = Collections.max(Arrays.asList(dims1), 
Long::compare) / dims1.length;
+
+               long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
+               long d2 = Collections.max(Arrays.asList(dims1), Long::compare);
+
+               // set output
+               out.getDataCharacteristics().set(d1, d2, (int) 
mo1.getBlocksize(), mo1.getNnz());
+               out.setFedMapping(fedMap.copyWithNewID(outId));
+
+               long varID = FederationUtils.getNextFedDataID();
+               out.getFedMapping().mapParallel(varID, (range, data) -> {
+                       try {
+                               FederatedResponse response = 
data.executeFederatedOperation(new FederatedRequest(
+                                       FederatedRequest.RequestType.EXEC_UDF, 
-1,
+                                       new SliceOutput(data.getVarID(), 
fedSize))).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       return null;
+               });
+       }
+
+       private MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
+               MatrixBlock resultBlock = new MatrixBlock(1, 1, 0);
+               int dim1 = 0, dim2 = 0;
+               for(int i = 0; i < ffr.length; i++) {
+                       try {
+                               MatrixBlock mb = ((MatrixBlock) 
ffr[i].get().getData()[0]);
+                               dim1 = mb.getNumRows()  > dim1 ? 
mb.getNumRows() : dim1;
+                               dim2 = mb.getNumColumns()  > dim2 ? 
mb.getNumColumns() : dim2;
+
+                               // set next and prev to same output dimensions
+                               MatrixBlock prev = new MatrixBlock(dim1, dim2, 
0.0);
+                               prev.copy(0, resultBlock.getNumRows()-1, 0, 
resultBlock.getNumColumns()-1, resultBlock, true);
+
+                               MatrixBlock next = new MatrixBlock(dim1, dim2, 
0.0);
+                               next.copy(0, mb.getNumRows()-1, 0, 
mb.getNumColumns()-1, mb, true);
+
+                               // add worker results
+                               BinaryOperator plus = 
InstructionUtils.parseBinaryOperator("+");
+                               resultBlock = 
prev.binaryOperationsInPlace(plus, next);
+                       }
+                       catch(Exception e) {
+                               e.printStackTrace();
+                       }
+               }
+               return resultBlock;
+       }
+
+       private FederationMap modifyFedRanges(FederationMap fedMap, Long[] 
dims1, Long[] dims2) {
+               IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
+                       fedMap.getFederatedRanges()[i]
+                               .setBeginDim(0, i == 0 ? 0 : 
fedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
+                       fedMap.getFederatedRanges()[i].setEndDim(0, dims1[i]);
+                       fedMap.getFederatedRanges()[i]
+                               .setBeginDim(1, i == 0 ? 0 : 
fedMap.getFederatedRanges()[i - 1].getBeginDims()[1]);
+                       fedMap.getFederatedRanges()[i].setEndDim(1, dims2[i]);
+               });
+               return fedMap;
+       }
+
+       private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, 
CPOperand outOp, FederatedRange[] federatedRanges) {
+               Long[] fedDims = new Long[federatedRanges.length];
+
+               if(!in.isFederated()) {
+                       //slice
+                       MatrixBlock mb = in.acquireReadAndRelease();
+                       IntStream.range(0, federatedRanges.length).forEach(i -> 
{
+                               MatrixBlock sliced = mb
+                                       
.slice(federatedRanges[i].getBeginDimsInt()[0], 
federatedRanges[i].getEndDimsInt()[0] - 1);
+                               fedDims[i] = (long) sliced.max();
+                       });
+                       return fedDims;
+               }
+
+               String maxInstString = constructMaxInstString(inOp.getName(), 
outOp.getName());
+
+               // get max per worker
+               FederationMap map = in.getFedMapping();
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(maxInstString, outOp,
+                       new CPOperand[]{inOp}, new 
long[]{in.getFedMapping().getID()});
+               FederatedRequest fr2 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+               FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, 
fr2, fr3);
+
+               return computeOutputDims(tmp);
+       }
+
+       private Long[] computeOutputDims(Future<FederatedResponse>[] tmp) {
+               Long[] fedDims = new Long[tmp.length];
+               for(int i = 0; i < tmp.length; i ++)
+                       try {
+                               fedDims[i] = ((ScalarObject) 
tmp[i].get().getData()[0]).getLongValue();
+                       }
+                       catch(Exception e) {
+                               e.printStackTrace();
+                       }
+               return fedDims;
+       }
+
+       private String constructMaxInstString(String in, String out) {
+               String maxInstrString = instString.replace("ctable", "uamax");
+               String[] instParts = 
maxInstrString.split(Lop.OPERAND_DELIMITOR);
+               String[] maxInstParts = new String[] {instParts[0], 
instParts[1],
+                       InstructionUtils.concatOperandParts(in, 
DataType.MATRIX.name(), (ValueType.FP64).name()),
+                       InstructionUtils.concatOperandParts(out, 
DataType.SCALAR.name(), (ValueType.FP64).name()), "16"};
+               return String.join(Lop.OPERAND_DELIMITOR, maxInstParts);
+       }
+
+       private static class SliceOutput extends FederatedUDF {
+
+               private static final long serialVersionUID = 
-2808597461054603816L;
+               private final long _fedSize;
+
+               protected SliceOutput(long input, long fedSize) {
+                       super(new long[] {input});
+                       _fedSize = fedSize;
+               }
+
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixObject mo = (MatrixObject) data[0];
+                       MatrixBlock mb = mo.acquireReadAndRelease();
+
+                       MatrixBlock sliced = mb.slice((int) 
(mb.getNumRows()-_fedSize), mb.getNumRows()-1);
+                       mo.acquireModify(sliced);
+                       mo.release();
+
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {});
+               }
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index a91cb0c..871bdb1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -32,6 +32,7 @@ public abstract class FEDInstruction extends Instruction {
                AggregateTernary,
                Append,
                Binary,
+               Ctable,
                CumulativeAggregate,
                Init,
                MultiReturnParameterizedBuiltin,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 7439b6f..3708d0a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -31,6 +31,7 @@ import 
org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
@@ -223,6 +224,14 @@ public class FEDInstructionUtils {
                        if(instruction.getOperatorClass().getSuperclass() == 
SpoofCellwise.class && instruction.isFederated(ec))
                                fedinst = 
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
                }
+               else if(inst instanceof CtableCPInstruction) {
+                       CtableCPInstruction cinst = (CtableCPInstruction) inst;
+                       if(inst.getOpcode().equalsIgnoreCase("ctable")
+                               && ( 
ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
+                               || (cinst.input2.isMatrix() && 
ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
+                               || (cinst.input3.isMatrix() && 
ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
+                               fedinst = 
CtableFEDInstruction.parseInstruction(cinst.getInstructionString());
+               }
 
                //set thread id for federated context management
                if( fedinst != null ) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
index 20b0147..73bf8e9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
@@ -107,10 +107,6 @@ public class FederatedCorTest extends AutomatedTestBase {
                Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
                Thread t4 = startLocalFedWorkerThread(port4);
 
-               rtplatform = execMode;
-               if(rtplatform == ExecMode.SPARK)
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
new file mode 100644
index 0000000..a5793b5
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedCtableTest extends AutomatedTestBase {
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME1 = "FederatedCtableTest";
+       private final static String TEST_NAME2 = "FederatedCtableFedOutput";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedCtableTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public int maxVal1;
+       @Parameterized.Parameter(3)
+       public int maxVal2;
+       @Parameterized.Parameter(4)
+       public boolean reversedInputs;
+       @Parameterized.Parameter(5)
+       public boolean weighted;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"F"}));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"F"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {12, 4, 4, 7, true, true}, {12, 4, 4, 7, true, false},
+                       {12, 4, 4, 7, false, true}, {12, 4, 4, 7, false, false},
+
+                       {100, 14, 4, 7, true, true}, {100, 14, 4, 7, true, 
false},
+                       {100, 14, 4, 7, false, true}, {100, 14, 4, 7, false, 
false},
+
+                       // {1000, 14, 4, 7, true, true}, {1000, 14, 4, 7, true, 
false},
+                       // {1000, 14, 4, 7, false, true}, {1000, 14, 4, 7, 
false, false}
+               });
+       }
+
+       @Test
+       public void federatedCtableSinglenode() { 
runCtable(Types.ExecMode.SINGLE_NODE, false, false); }
+
+       @Test
+       public void federatedCtableFedOutputSinglenode() { 
runCtable(Types.ExecMode.SINGLE_NODE, true, false); }
+
+       @Test
+       public void federatedCtableMatrixInputSinglenode() { 
runCtable(Types.ExecMode.SINGLE_NODE, false, true); }
+
+
+       public void runCtable(Types.ExecMode execMode, boolean fedOutput, 
boolean matrixInput) {
+               String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
+               Types.ExecMode platformOld = setExecMode(execMode);
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               if(fedOutput)
+                       runFedCtable(HOME, TEST_NAME, port1, port2, port3, 
port4);
+               else
+                       runNonFedCtable(HOME, TEST_NAME, matrixInput, port1, 
port2, port3, port4);
+               checkResults();
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+               resetExecMode(platformOld);
+       }
+
+       /**
+        * Runs the ctable test in which X is federated across four workers and Y is a local matrix.
+        * The reference script is executed first on the plain input files and writes the expected
+        * result; afterwards the federated script is executed on the same files served by the workers.
+        *
+        * @param HOME        directory prefix of the DML scripts
+        * @param TEST_NAME   base name of the scripts (".dml" and "Reference.dml" are appended)
+        * @param matrixInput if false, the inputs are reduced to a single column
+        * @param port1       port of the worker serving X1
+        * @param port2       port of the worker serving X2
+        * @param port3       port of the worker serving X3
+        * @param port4       port of the worker serving X4
+        */
+       private void runNonFedCtable(String HOME, String TEST_NAME, boolean matrixInput, int port1, int port2, int port3, int port4) {
+               int r = rows / 4;
+               // the (possibly reduced) column count is stored back because it is also passed as "cols=" below
+               cols = matrixInput ? cols : 1;
+               double[][] X1 = TestUtils.floor(getRandomMatrix(r, cols, 1, maxVal1, 1, 3));
+               double[][] X2 = TestUtils.floor(getRandomMatrix(r, cols, 1, maxVal1, 1, 7));
+               double[][] X3 = TestUtils.floor(getRandomMatrix(r, cols, 1, maxVal1, 1, 8));
+               double[][] X4 = TestUtils.floor(getRandomMatrix(r, cols, 1, maxVal1, 1, 9));
+
+               // the generator is called with min=1 and sparsity=1, so every cell is expected to be
+               // non-zero; report the full cell count as nnz (same convention as runFedCtable)
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, cols, blocksize, (long) r * cols);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               double[][] Y = TestUtils.floor(getRandomMatrix(rows, cols, 1, maxVal2, 1, 9));
+               writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, cols, blocksize, (long) rows * cols));
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+                       input("Y"), Boolean.toString(reversedInputs).toUpperCase(), Boolean.toString(weighted).toUpperCase(), expected("F")};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "in_Y=" + input("Y"),
+                       "rows=" + rows, "cols=" + cols, "revIn=" + Boolean.toString(reversedInputs).toUpperCase(),
+                       "weighted=" + Boolean.toString(weighted).toUpperCase(), "out=" + output("F")};
+               runTest(true, false, null, -1);
+       }
+
+       private void runFedCtable(String HOME, String TEST_NAME, int port1, int 
port2, int port3, int port4) {
+               int r = rows / 4;
+               int c = cols;
+
+               double[][] X1 = getRandomMatrix(r, c, 1, 3, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 1, 3, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 1, 3, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 1, 3, 1, 9);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
+               writeInputMatrixWithMTD("X1", X1, false, mc);
+               writeInputMatrixWithMTD("X2", X2, false, mc);
+               writeInputMatrixWithMTD("X3", X3, false, mc);
+               writeInputMatrixWithMTD("X4", X4, false, mc);
+
+               //execute main test
+               fullDMLScriptName = HOME + TEST_NAME2 + "Reference.dml";
+               programArgs = new String[]{"-stats", "100", "-args",
+                       input("X1"), input("X2"), input("X3"), input("X4"), 
Boolean.toString(reversedInputs).toUpperCase(),
+                       Boolean.toString(weighted).toUpperCase(), 
expected("F")};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
+                       "rows=" + rows, "cols=" + cols, "revIn=" + 
Boolean.toString(reversedInputs).toUpperCase(),
+                       "weighted=" + Boolean.toString(weighted).toUpperCase(), 
"out=" + output("F")
+               };
+               runTest(true, false, null, -1);
+       }
+
+       /**
+        * Verifies a completed run: compares the result files with a tolerance of 1e-9, asserts that
+        * "fed_ctable" appears among the heavy hitters, and asserts that the four input files
+        * X1..X4 still exist.
+        */
+       void checkResults() {
+               // compare via files
+               compareResults(1e-9);
+
+               // the federated ctable instruction has to be reported in the statistics
+               Assert.assertTrue(heavyHittersContainsString("fed_ctable"));
+
+               // the federated input files must still be present after the run
+               for(String name : new String[] {"X1", "X2", "X3", "X4"})
+                       Assert.assertTrue(HDFSTool.existsFileOnHDFS(input(name)));
+       }
+
+}
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml 
b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
new file mode 100644
index 0000000..9c21ed5
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -0,0 +1,47 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Federated ctable test where both table() inputs are derived from the federated matrix X.
+# Named arguments: in_X1..in_X4 (worker addresses), rows, cols, revIn, weighted, out.
+
+# X consists of four row partitions of $rows/4 rows each, one per address
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+
+m = nrow(X);
+n = ncol(X);
+
+# prepare offset vectors and one-hot encoded X
+maxs = colMaxs(X);
+# rix: the row number of every cell, reshaped into an (m*n) x 1 vector
+rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+# cix: X shifted per column by the sum of the maxima of all preceding columns, reshaped likewise
+cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+
+# weights used by the three-argument table() variants
+W = rix + cix;
+
+# revIn swaps the order of the two index vectors; weighted adds W as third argument
+if($revIn)
+  if($weighted)
+    X2 = table(cix, rix, W);
+  else
+    X2 = table(cix, rix);
+else
+  if($weighted)
+      X2 = table(rix, cix, W);
+    else
+      X2 = table(rix, cix);
+
+write(X2, $out);
diff --git 
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml 
b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
new file mode 100644
index 0000000..e0721df
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference (non-federated) counterpart of FederatedCtableFedOutput.dml.
+# Positional arguments: $1..$4 input files, $5 reversed-inputs flag, $6 weighted flag, $7 output.
+
+# stack the four input parts row-wise into one local matrix
+X = rbind(read($1), read($2), read($3), read($4));
+
+m = nrow(X);
+n = ncol(X);
+
+# prepare offset vectors and one-hot encoded X
+maxs = colMaxs(X);
+
+# rix: the row number of every cell, reshaped into an (m*n) x 1 vector
+rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1)
+# cix: X shifted per column by the sum of the maxima of all preceding columns, reshaped likewise
+cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+
+# weights used by the three-argument table() variants
+W = rix + cix;
+
+# $5 swaps the order of the two index vectors; $6 adds W as third argument
+if($5)
+  if($6)
+    X2 = table(cix, rix, W);
+  else
+    X2 = table(cix, rix);
+else
+  if($6)
+    X2 = table(rix, cix, W);
+  else
+    X2 = table(rix, cix);
+
+write(X2, $7);
diff --git a/src/test/scripts/functions/federated/FederatedCtableTest.dml 
b/src/test/scripts/functions/federated/FederatedCtableTest.dml
new file mode 100644
index 0000000..48ed734
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableTest.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Federated ctable test with a federated input X and a local input Y.
+# Named arguments: in_X1..in_X4 (worker addresses), in_Y, rows, cols, revIn, weighted, out.
+
+# X consists of four row partitions of $rows/4 rows each, one per address
+X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+
+Y = read($in_Y);
+# weights for the three-argument table() variants
+W = Y * 2;
+
+# NOTE(review): both $revIn branches below call table() with the same argument order (X, Y),
+# so the reversed-input case is not actually exercised here - confirm whether table(Y, X, ...)
+# was intended (the reference script would have to change accordingly).
+if($revIn)
+  if($weighted)
+    F = table(X, Y, W);
+  else
+    F = table(X, Y);
+else
+  if($weighted)
+    F = table(X, Y, W);
+  else
+    F = table(X, Y);
+
+write(F, $out);
diff --git 
a/src/test/scripts/functions/federated/FederatedCtableTestReference.dml 
b/src/test/scripts/functions/federated/FederatedCtableTestReference.dml
new file mode 100644
index 0000000..749a99a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCtableTestReference.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference (non-federated) counterpart of FederatedCtableTest.dml.
+# Positional arguments: $1..$4 parts of X, $5 Y, $6 reversed-inputs flag, $7 weighted flag, $8 output.
+
+# stack the four input parts row-wise into one local matrix
+X = rbind(read($1), read($2), read($3), read($4));
+Y = read($5);
+# weights for the three-argument table() variants
+W = Y * 2;
+
+# NOTE(review): both $6 branches below call table() with the same argument order (X, Y),
+# so the reversed-input flag has no effect in this script - confirm whether table(Y, X, ...)
+# was intended (the federated script would have to change accordingly).
+if($6)
+  if($7)
+    F = table(X, Y, W);
+  else
+    F = table(X, Y);
+else
+  if($7)
+    F = table(X, Y, W);
+  else
+    F = table(X, Y);
+
+write(F, $8);

Reply via email to