This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 8565b1d [SYSTEMDS-2766] Federated covariance
8565b1d is described below
commit 8565b1db948fc7222e3227a84622c6714fa6e425
Author: Olga <[email protected]>
AuthorDate: Mon Jan 11 01:11:47 2021 +0100
[SYSTEMDS-2766] Federated covariance
Closes #1150
---
.../instructions/fed/BinaryFEDInstruction.java | 28 ++
.../instructions/fed/CovarianceFEDInstruction.java | 312 +++++++++++++++++++++
.../instructions/fed/FEDInstructionUtils.java | 3 +
...EmptyTest.java => FederatedCovarianceTest.java} | 125 +++++----
.../primitives/FederatedRemoveEmptyTest.java | 1 -
.../federated/FederatedCovarianceAlignedTest.dml | 31 ++
.../FederatedCovarianceAlignedTestReference.dml | 27 ++
.../federated/FederatedCovarianceTest.dml | 28 ++
.../federated/FederatedCovarianceTestReference.dml | 27 ++
9 files changed, 529 insertions(+), 53 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 9f0c91a..1adaf09 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -35,6 +35,11 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
super(type, op, in1, in2, out, opcode, istr);
}
+	/**
+	 * Creates a federated binary instruction that carries a third input operand
+	 * (used, for example, by the weighted form of covariance).
+	 *
+	 * @param fedType federated instruction type
+	 * @param operator operator applied by the instruction
+	 * @param first first input operand
+	 * @param second second input operand
+	 * @param third third input operand
+	 * @param target output operand
+	 * @param opc opcode
+	 * @param instr full instruction string
+	 */
+	public BinaryFEDInstruction(FEDInstruction.FEDType fedType, Operator operator,
+		CPOperand first, CPOperand second, CPOperand third, CPOperand target, String opc, String instr)
+	{
+		super(fedType, operator, first, second, third, target, opc, instr);
+	}
+
public static BinaryFEDInstruction parseInstruction(String str) {
if(str.startsWith(ExecType.SPARK.name())) {
// rewrite the spark instruction to a cp instruction
@@ -67,6 +72,29 @@ public abstract class BinaryFEDInstruction extends
ComputationFEDInstruction {
throw new DMLRuntimeException("Federated binary
operations not yet supported:" + opcode);
}
+ protected static String parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand out) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instr);
+ InstructionUtils.checkNumFields ( parts, 3, 4 );
+ String opcode = parts[0];
+ in1.split(parts[1]);
+ in2.split(parts[2]);
+ out.split(parts[3]);
+ return opcode;
+ }
+
+ protected static String parseBinaryInstruction(String instr, CPOperand
in1, CPOperand in2, CPOperand in3, CPOperand out) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(instr);
+ InstructionUtils.checkNumFields ( parts, 4 );
+
+ String opcode = parts[0];
+ in1.split(parts[1]);
+ in2.split(parts[2]);
+ in3.split(parts[3]);
+ out.split(parts[4]);
+
+ return opcode;
+ }
+
protected static void checkOutputDataType(CPOperand in1, CPOperand in2,
CPOperand out) {
// check for valid data type of output
if( (in1.getDataType() == DataType.MATRIX || in2.getDataType()
== DataType.MATRIX) && out.getDataType() != DataType.MATRIX )
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
new file mode 100644
index 0000000..dd38a2f
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CovarianceFEDInstruction.java
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
+
+import org.apache.commons.lang3.tuple.ImmutableTriple;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.COV;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.COVOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+/**
+ * Federated instruction for {@code cov(X, Y)} and its weighted form {@code cov(X, Y, W)}.
+ *
+ * <p>The execution strategy depends on which inputs are federated (see {@link #processInstruction}):
+ * <ul>
+ * <li>X and Y federated and aligned, W absent or federated and aligned with Y: every worker computes the
+ * covariance of its local partitions, the per-partition means are fetched separately, and the partial
+ * results are merged at the coordinator.</li>
+ * <li>X and Y federated and aligned, W local or not aligned: W is first broadcast in row slices that
+ * match X's partitions, then processed as above.</li>
+ * <li>Otherwise (one of X and Y federated): row slices of the local input, and of W if present, are
+ * shipped to the workers, which return partial {@link CM_COV_Object}s that are merged with the
+ * {@link COV} function object.</li>
+ * </ul>
+ */
+public class CovarianceFEDInstruction extends BinaryFEDInstruction {
+	/** Thread count appended to the worker-side {@code uamean} instructions. */
+	private static final int MEAN_NUM_THREADS = 16;
+
+	private CovarianceFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
+		String istr) {
+		super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
+	}
+
+	private CovarianceFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+		String opcode, String istr) {
+		super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, in3, out, opcode, istr);
+	}
+
+	/**
+	 * Parses a {@code cov} instruction with two inputs (X, Y) or three inputs (X, Y, weights).
+	 *
+	 * @param str instruction string
+	 * @return the parsed federated covariance instruction
+	 * @throws DMLRuntimeException if the opcode is not {@code cov} or the operand count is invalid
+	 */
+	public static CovarianceFEDInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		if(!opcode.equalsIgnoreCase("cov"))
+			throw new DMLRuntimeException(
+				"CovarianceFEDInstruction.parseInstruction():: Unknown opcode " + opcode);
+
+		CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+		CPOperand in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+		CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+		COVOperator cov = new COVOperator(COV.getCOMFnObject());
+
+		if(parts.length == 4) {
+			parseBinaryInstruction(str, in1, in2, out);
+			return new CovarianceFEDInstruction(cov, in1, in2, out, opcode, str);
+		}
+		if(parts.length == 5) {
+			CPOperand in3 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+			parseBinaryInstruction(str, in1, in2, in3, out);
+			return new CovarianceFEDInstruction(cov, in1, in2, in3, out, opcode, str);
+		}
+		throw new DMLRuntimeException("Invalid number of arguments in Instruction: " + str);
+	}
+
+	/**
+	 * Dispatches to one of the three strategies described on the class.
+	 *
+	 * @throws DMLRuntimeException if both X and Y are federated but not aligned
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject mo1 = ec.getMatrixObject(input1);
+		MatrixObject mo2 = ec.getMatrixObject(input2);
+		MatrixObject weights = input3 != null ? ec.getMatrixObject(input3) : null;
+
+		if(!(mo1.isFederated() && mo2.isFederated())) {
+			// one federated and one local input
+			processCov(ec, mo1, mo2);
+			return;
+		}
+		if(!mo1.getFedMapping().isAligned(mo2.getFedMapping(), false))
+			throw new DMLRuntimeException("Not supported matrix-matrix binary operation: covariance.");
+
+		boolean weightsAligned = weights == null
+			|| (weights.isFederated() && weights.getFedMapping().isAligned(mo2.getFedMapping(), false));
+		if(weightsAligned) {
+			processAlignedFedCov(ec, mo1, mo2, weights);
+		}
+		else {
+			// weights are local or partitioned differently: broadcast them in row slices first
+			processFedCovWeights(ec, mo1, mo2, weights);
+		}
+	}
+
+	/** All inputs are federated and aligned: run the cov instruction directly on every worker. */
+	private void processAlignedFedCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		FederationMap fedMap = mo1.getFedMapping();
+		FederatedRequest fr1;
+		if(mo3 == null)
+			fr1 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+				new long[] {fedMap.getID(), mo2.getFedMapping().getID()});
+		else
+			fr1 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+				new long[] {fedMap.getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+
+		FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+		FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
+		Future<FederatedResponse>[] covTmp = fedMap.execute(getTID(), fr1, fr2, fr3);
+
+		aggregateAndSetOutput(ec, mo1, mo2, covTmp);
+	}
+
+	/**
+	 * X and Y are federated and aligned but the weights are not: broadcast the weights in row slices
+	 * matching X's partitions, then run the weighted cov on every worker.
+	 */
+	private void processFedCovWeights(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		FederationMap fedMap = mo1.getFedMapping();
+		// every worker receives its own slice; the slices must arrive before the cov instruction runs
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo3, false);
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {input1, input2, input3},
+			new long[] {fedMap.getID(), mo2.getFedMapping().getID(), fr1[0].getID()});
+		FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+		FederatedRequest fr4 = fedMap.cleanup(getTID(), fr2.getID(), fr1[0].getID());
+		Future<FederatedResponse>[] covTmp = fedMap.execute(getTID(), fr1, fr2, fr3, fr4);
+
+		aggregateAndSetOutput(ec, mo1, mo2, covTmp);
+	}
+
+	/**
+	 * Fetches per-partition means of both inputs, merges them with the per-partition covariances
+	 * in {@code covTmp}, and stores the scalar result in the output variable.
+	 */
+	private void aggregateAndSetOutput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2,
+		Future<FederatedResponse>[] covTmp) {
+		FederationMap fedMap = mo1.getFedMapping();
+		// both means go through mo1's mapping so their order matches covTmp and the ranges below
+		Future<FederatedResponse>[] meanTmp1 = processMean(fedMap, 0, fedMap.getID());
+		Future<FederatedResponse>[] meanTmp2 = processMean(fedMap, 1, mo2.getFedMapping().getID());
+
+		ImmutableTriple<Double[], Double[], Double[]> res = getResponses(covTmp, meanTmp1, meanTmp2);
+		double result = aggCov(res.left, res.middle, res.right, fedMap.getFederatedRanges());
+		ec.setVariable(output.getName(), new DoubleObject(result));
+	}
+
+	/**
+	 * Exactly one of X and Y is federated. Ships row slices of the local input (and of the weights,
+	 * if present) to the workers and merges the returned partial {@link CM_COV_Object}s.
+	 */
+	private void processCov(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
+		COVOperator cop = (COVOperator) _optr;
+		boolean secondFederated = !mo1.isFederated() && mo2.isFederated();
+		MatrixObject mo = secondFederated ? mo2 : mo1;
+		String localName = secondFederated ? input1.getName() : input2.getName();
+
+		List<CM_COV_Object> partials;
+		MatrixBlock mb = ec.getMatrixInput(localName);
+		MatrixBlock wtBlock = null;
+		try {
+			if(input3 != null)
+				wtBlock = ec.getMatrixInput(input3.getName());
+			partials = computePartialCov(mo.getFedMapping(), mb, wtBlock, cop);
+		}
+		finally {
+			ec.releaseMatrixInput(localName);
+			if(wtBlock != null)
+				ec.releaseMatrixInput(input3.getName());
+		}
+
+		Optional<CM_COV_Object> res = partials.stream()
+			.reduce((arg0, arg1) -> (CM_COV_Object) cop.fn.execute(arg0, arg1));
+		try {
+			ec.setScalarOutput(output.getName(), new DoubleObject(res.get().getRequiredResult(cop)));
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException(e);
+		}
+	}
+
+	/**
+	 * Runs one covariance UDF per federated partition, passing the rows of {@code mb} (and of
+	 * {@code wtBlock}, if not null) that correspond to that partition's row range.
+	 */
+	private static List<CM_COV_Object> computePartialCov(FederationMap fedMapping, MatrixBlock mb,
+		MatrixBlock wtBlock, COVOperator cop) {
+		List<CM_COV_Object> globalCmobj = new ArrayList<>();
+		long varID = FederationUtils.getNextFedDataID();
+		fedMapping.mapParallel(varID, (range, data) -> {
+			int rl = range.getBeginDimsInt()[0];
+			int ru = range.getEndDimsInt()[0] - 1;
+			FederatedUDF udf = (wtBlock == null) ?
+				new COVFunction(data.getVarID(), mb.slice(rl, ru), cop) :
+				new COVWeightsFunction(data.getVarID(), mb.slice(rl, ru), cop, wtBlock.slice(rl, ru));
+			try {
+				FederatedResponse response = data.executeFederatedOperation(
+					new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, udf)).get();
+				if(!response.isSuccessful())
+					response.throwExceptionFromResponse();
+				synchronized(globalCmobj) {
+					globalCmobj.add((CM_COV_Object) response.getData()[0]);
+				}
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException(e);
+			}
+			return null;
+		});
+		return globalCmobj;
+	}
+
+	/**
+	 * Unpacks the scalar per-worker results. Index {@code i} of every array refers to the same worker.
+	 *
+	 * @return (covariances, means of the first input, means of the second input)
+	 */
+	private static ImmutableTriple<Double[], Double[], Double[]> getResponses(Future<FederatedResponse>[] covFfr,
+		Future<FederatedResponse>[] mean1Ffr, Future<FederatedResponse>[] mean2Ffr) {
+		Double[] cov = new Double[covFfr.length];
+		Double[] mean1 = new Double[mean1Ffr.length];
+		Double[] mean2 = new Double[mean2Ffr.length];
+		IntStream.range(0, covFfr.length).forEach(i -> {
+			try {
+				cov[i] = ((ScalarObject) covFfr[i].get().getData()[0]).getDoubleValue();
+				mean1[i] = ((ScalarObject) mean1Ffr[i].get().getData()[0]).getDoubleValue();
+				mean2[i] = ((ScalarObject) mean2Ffr[i].get().getData()[0]).getDoubleValue();
+			}
+			catch(Exception e) {
+				throw new DMLRuntimeException("CovarianceFEDInstruction: incorrect means or cov.", e);
+			}
+		});
+		return new ImmutableTriple<>(cov, mean1, mean2);
+	}
+
+	/**
+	 * Merges per-partition covariances into the covariance of the full inputs.
+	 *
+	 * <p>Let partition i have n_i rows, covariance C_i and means mx_i, my_i, and let N = sum(n_i).
+	 * The global means are mx = sum(n_i * mx_i) / N and my = sum(n_i * my_i) / N. The global
+	 * co-moment is sum((n_i - 1) * C_i) + sum(n_i * (mx_i - mx) * (my_i - my)), and the result is
+	 * co-moment / (N - 1). This assumes each C_i is normalized by (n_i - 1), matching the local
+	 * cov instruction.
+	 *
+	 * <p>NOTE(review): for cov(X, Y, W), row counts and unweighted means are only exact for unit
+	 * weights. A weighted merge needs per-partition weight sums and weighted means.
+	 *
+	 * @param covValues per-partition covariances, in the order of {@code ranges}
+	 * @param mean1 per-partition means of the first input
+	 * @param mean2 per-partition means of the second input
+	 * @param ranges partition ranges
+	 * @return the merged covariance
+	 */
+	private static double aggCov(Double[] covValues, Double[] mean1, Double[] mean2, FederatedRange[] ranges) {
+		long total = 0;
+		double globalMean1 = 0;
+		double globalMean2 = 0;
+		for(int i = 0; i < covValues.length; i++) {
+			long n = numRows(ranges[i]);
+			if(n == 0)
+				continue; // an empty partition contributes nothing and has no defined mean
+			total += n;
+			globalMean1 += n * mean1[i];
+			globalMean2 += n * mean2[i];
+		}
+		globalMean1 /= total;
+		globalMean2 /= total;
+
+		double comoment = 0;
+		for(int i = 0; i < covValues.length; i++) {
+			long n = numRows(ranges[i]);
+			if(n == 0)
+				continue;
+			// a single row has zero co-moment; a cov normalized by (n - 1) would be 0/0 there
+			if(n > 1)
+				comoment += (n - 1) * covValues[i];
+			comoment += n * (mean1[i] - globalMean1) * (mean2[i] - globalMean2);
+		}
+		return comoment / (total - 1);
+	}
+
+	/** Number of rows covered by a federated range. */
+	private static long numRows(FederatedRange range) {
+		return (long) range.getEndDimsInt()[0] - range.getBeginDimsInt()[0];
+	}
+
+	/**
+	 * Computes the mean of one cov input on every worker of {@code fedMap}.
+	 *
+	 * @param fedMap mapping whose workers run the instruction (mo1's mapping, so that results line up)
+	 * @param var 0 for the first input, 1 for the second input
+	 * @param inputID federated variable ID of the selected input on those workers
+	 * @return one future per worker holding the scalar mean
+	 */
+	private Future<FederatedResponse>[] processMean(FederationMap fedMap, int var, long inputID) {
+		// instString layout: <execType>°cov°<in1>°<in2>[°<weights>]°<out>
+		String[] parts = instString.split("°");
+		String meanInstr = String.join("°", parts[0], "uamean", parts[2 + var], parts[parts.length - 1],
+			String.valueOf(MEAN_NUM_THREADS));
+		CPOperand input = (var == 0) ? input1 : input2;
+
+		FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output,
+			new CPOperand[] {input}, new long[] {inputID});
+		FederatedRequest meanFr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanFr1.getID());
+		FederatedRequest meanFr3 = fedMap.cleanup(getTID(), meanFr1.getID());
+		return fedMap.execute(getTID(), meanFr1, meanFr2, meanFr3);
+	}
+
+	/** Worker-side UDF: covariance of the local partition and a shipped row slice of the other input. */
+	private static class COVFunction extends FederatedUDF {
+
+		private static final long serialVersionUID = -501036588060113499L;
+		private final MatrixBlock _mo2;
+		private final COVOperator _op;
+
+		public COVFunction(long input, MatrixBlock mo2, COVOperator op) {
+			super(new long[] {input});
+			_op = op;
+			_mo2 = mo2;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.covOperations(_op, _mo2));
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+
+	/** Worker-side UDF: weighted covariance of the local partition, a shipped row slice, and weight slice. */
+	private static class COVWeightsFunction extends FederatedUDF {
+		private static final long serialVersionUID = -1768739786192949573L;
+		private final COVOperator _op;
+		private final MatrixBlock _mo2;
+		private final MatrixBlock _weights;
+
+		protected COVWeightsFunction(long input, MatrixBlock mo2, COVOperator op, MatrixBlock weights) {
+			super(new long[] {input});
+			_mo2 = mo2;
+			_op = op;
+			_weights = weights;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+				mb.covOperations(_op, _mo2, _weights));
+		}
+
+		@Override
+		public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+			return null;
+		}
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 613ff31..1417c90 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -136,6 +136,9 @@ public class FEDInstructionUtils {
fedinst =
AppendFEDInstruction.parseInstruction(inst.getInstructionString());
else if(instruction.getOpcode().equals("qpick"))
fedinst =
QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
+ else if("cov".equals(instruction.getOpcode())
&& (ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
+
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
+ fedinst =
CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
else
fedinst =
BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCovarianceTest.java
similarity index 51%
copy from
src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
copy to
src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCovarianceTest.java
index 10a6711..557341a 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCovarianceTest.java
@@ -25,7 +25,6 @@ import java.util.Collection;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -36,69 +35,64 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedRemoveEmptyTest extends AutomatedTestBase {
- // private static final Log LOG =
LogFactory.getLog(FederatedRightIndexTest.class.getName());
-
- private final static String TEST_NAME = "FederatedRemoveEmptyTest";
+public class FederatedCovarianceTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "FederatedCovarianceTest";
+ private final static String TEST_NAME2 =
"FederatedCovarianceAlignedTest";
private final static String TEST_DIR = "functions/federated/";
- private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedRemoveEmptyTest.class.getSimpleName() + "/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedCovarianceTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
- @Parameterized.Parameter()
+ @Parameterized.Parameter
public int rows;
@Parameterized.Parameter(1)
public int cols;
- @Parameterized.Parameter(2)
- public boolean rowPartitioned;
-
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {20, 12, true},
- {20, 12, false}
+ {20, 1},
+// {100, 1}, {1000, 1}
});
}
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
}
@Test
- public void testRemoveEmptyCP() {
- runAggregateOperationTest(ExecMode.SINGLE_NODE);
- }
+ public void testCovCP() { runCovTest(ExecMode.SINGLE_NODE, false); }
+
+ @Test
+ public void testAlignedCovCP() { runCovTest(ExecMode.SINGLE_NODE,
true); }
- private void runAggregateOperationTest(ExecMode execMode) {
+ private void runCovTest(ExecMode execMode, boolean alignedFedInput) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
if(rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ String TEST_NAME = alignedFedInput ? TEST_NAME2 : TEST_NAME1;
+
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
// write input matrices
- int r = rows;
- int c = cols / 4;
- if(rowPartitioned) {
- r = rows / 4;
- c = cols;
- }
+ int r = r = rows / 4;
+ int c = cols;
+
+ // empty script name because we don't execute any script, just
start the worker
+ fullDMLScriptName = "";
double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
- for(int k : new int[] {1, 2, 3}) {
- Arrays.fill(X3[k], 0);
- }
-
MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
writeInputMatrixWithMTD("X2", X2, false, mc);
@@ -124,36 +118,63 @@ public class FederatedRemoveEmptyTest extends
AutomatedTestBase {
TestConfiguration config =
availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- // Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
- Boolean.toString(rowPartitioned).toUpperCase(),
expected("S")};
-
- runTest(null);
-
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
- "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
- "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
- "rP=" + Boolean.toString(rowPartitioned).toUpperCase(),
"out_S=" + output("S")};
-
- runTest(null);
+ if(alignedFedInput) {
+ double[][] Y1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] Y2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] Y3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] Y4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ writeInputMatrixWithMTD("Y1", Y1, false, mc);
+ writeInputMatrixWithMTD("Y2", Y2, false, mc);
+ writeInputMatrixWithMTD("Y3", Y3, false, mc);
+ writeInputMatrixWithMTD("Y4", Y4, false, mc);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
+ input("Y1"), input("Y2"), input("Y3"),
input("Y4"), expected("S")};
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")),
+ "in_Y1=" + TestUtils.federatedAddress(port1,
input("Y1")),
+ "in_Y2=" + TestUtils.federatedAddress(port2,
input("Y2")),
+ "in_Y3=" + TestUtils.federatedAddress(port3,
input("Y3")),
+ "in_Y4=" + TestUtils.federatedAddress(port4,
input("Y4")),
+ "rows=" + rows, "cols=" + cols, "out_S=" +
output("S")};
+ runTest(null);
+
+ } else {
+ double[][] Y = getRandomMatrix(rows, c, 1, 5, 1, 3);
+ writeInputMatrixWithMTD("Y", Y, false, new
MatrixCharacteristics(rows, cols, blocksize, r*c));
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats",
"100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ input("Y"), expected("S"),};
+ runTest(null);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")),
+ "Y=" + input("Y"), "rows=" + rows, "cols=" +
cols, "out_S=" + output("S")};
+ runTest(null);
+ }
// compare via files
- compareResults(1e-9);
-
- // check that federated input files are still existing
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+ compareResults(1e-2);
+ Assert.assertTrue(heavyHittersContainsString("fed_cov"));
TestUtils.shutdownThreads(t1, t2, t3, t4);
-
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-
}
}
+
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index 10a6711..89f67b2 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -49,7 +49,6 @@ public class FederatedRemoveEmptyTest extends
AutomatedTestBase {
public int rows;
@Parameterized.Parameter(1)
public int cols;
-
@Parameterized.Parameter(2)
public boolean rowPartitioned;
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedTest.dml
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTest.dml
new file mode 100644
index 0000000..9f64ad0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTest.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# A: federated matrix built from four workers ($in_X1..$in_X4). Each entry of
+# 'ranges' is a begin/end pair, giving four row blocks of $rows/4 rows x $cols columns.
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+# B: second federated matrix with the same four row blocks ($in_Y1..$in_Y4),
+# so A and B are partitioned identically (aligned federated inputs).
+B = federated(addresses=list($in_Y1, $in_Y2, $in_Y3, $in_Y4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+
+# Covariance of the two aligned federated inputs, written as a scalar to $out_S.
+s = cov(A, B);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceAlignedTestReference.dml
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTestReference.dml
new file mode 100644
index 0000000..9039286
--- /dev/null
+++
b/src/test/scripts/functions/federated/FederatedCovarianceAlignedTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Local reference: rebuild A and B by row-binding the same partition files
+# the federated test reads ($1..$4 for A, $5..$8 for B).
+A = rbind(read($1), read($2), read($3), read($4));
+B = rbind(read($5), read($6), read($7), read($8));
+
+# Expected covariance, written to $9 for comparison with the federated result.
+s = cov(A, B);
+write(s, $9);
diff --git a/src/test/scripts/functions/federated/FederatedCovarianceTest.dml
b/src/test/scripts/functions/federated/FederatedCovarianceTest.dml
new file mode 100644
index 0000000..ee1315a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# A: federated matrix built from four workers ($in_X1..$in_X4). Each entry of
+# 'ranges' is a begin/end pair, giving four row blocks of $rows/4 rows x $cols columns.
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows,
$cols)));
+# B: ordinary local matrix, so only one cov input is federated.
+B = read($Y);
+
+# Covariance of a federated and a local input, written as a scalar to $out_S.
+s = cov(A, B);
+write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/FederatedCovarianceTestReference.dml
b/src/test/scripts/functions/federated/FederatedCovarianceTestReference.dml
new file mode 100644
index 0000000..f3c3a3a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedCovarianceTestReference.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Local reference: rebuild A by row-binding the four partition files ($1..$4)
+# and read the same second input ($5) that the federated test uses.
+A = rbind(read($1), read($2), read($3), read($4));
+B = read($5);
+
+# Expected covariance, written to $6 for comparison with the federated result.
+s = cov(A, B);
+write(s, $6);