This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 371f59df20 [SYSTEMDS-3018] Add Federated KMeans Planning Test
371f59df20 is described below
commit 371f59df20cc6e8cd888152e136ab88824ccba15
Author: sebwrede <[email protected]>
AuthorDate: Thu Jul 7 10:35:01 2022 +0200
[SYSTEMDS-3018] Add Federated KMeans Planning Test
Closes #1659.
---
.../hops/fedplanner/FederatedPlannerCostbased.java | 50 +-----
.../hops/fedplanner/FederatedPlannerUtils.java | 40 ++++-
.../hops/fedplanner/PrivacyConstraintLoader.java | 20 ++-
.../runtime/instructions/FEDInstructionParser.java | 1 +
.../privacy/propagation/PrivacyPropagator.java | 6 +-
.../fedplanning/FederatedKMeansPlanningTest.java | 169 +++++++++++++++++++++
.../fedplanning/FederatedKMeansPlanningTest.dml | 26 ++++
.../FederatedKMeansPlanningTestReference.dml | 25 +++
8 files changed, 285 insertions(+), 52 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index 4a2cab2854..d809544f6b 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -203,34 +203,13 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
rewriteStatementBlock(prog,
sbFuncBlock, paramMap);
FunctionStatement funcStatement =
(FunctionStatement) sbFuncBlock.getStatement(0);
- mapFunctionOutputs((FunctionOp) sbHop,
funcStatement);
+
FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement,
transientWrites);
}
}
}
return new ArrayList<>(Collections.singletonList(sb));
}
- /**
- * Saves the HOPs (TWrite) of the function return values for
- * the variable name used when calling the function.
- *
- * Example:
- * <code>
- * f = function() return (matrix[double] model) {a = rand(1, 1);}
- * b = f();
- * </code>
- * This function saves the HOP writing to <code>a</code> for identifier
<code>b</code>.
- *
- * @param sbHop The <code>FunctionOp</code> for the call
- * @param funcStatement The <code>FunctionStatement</code> of the
called function
- */
- private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement
funcStatement) {
- for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i)
{
- Hop outputWrite =
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
- transientWrites.put(sbHop.getOutputVariableNames()[i],
outputWrite);
- }
- }
-
/**
* Set final fedouts of all hops starting from terminal hops.
*/
@@ -368,7 +347,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
private ArrayList<Hop> getHopInputs(Hop currentHop, Map<String, Hop>
paramMap){
if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTREAD) )
- return getTransientInputs(currentHop, paramMap);
+ return
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites,
localVariableMap);
else
return currentHop.getInput();
}
@@ -392,7 +371,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
ArrayList<HopRel> hopRels = new ArrayList<>();
ArrayList<Hop> inputHops = currentHop.getInput();
if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTREAD) ) {
- inputHops = getTransientInputs(currentHop, paramMap);
+ inputHops =
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites,
localVariableMap);
if (inputHops == null) {
// check if transient read on a runtime
variable (only when planning during dynamic recompilation)
return createHopRelsFromRuntimeVars(currentHop,
hopRels);
@@ -427,29 +406,6 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
return hopRels;
}
- /**
- * Get transient inputs from either paramMap or transientWrites.
- * Inputs from paramMap has higher priority than inputs from
transientWrites.
- * @param currentHop hop for which inputs are read from maps
- * @param paramMap of local parameters
- * @return inputs of currentHop
- */
- private ArrayList<Hop> getTransientInputs(Hop currentHop, Map<String,
Hop> paramMap){
- Hop tWriteHop = null;
- if ( paramMap != null)
- tWriteHop = paramMap.get(currentHop.getName());
- if ( tWriteHop == null )
- tWriteHop = transientWrites.get(currentHop.getName());
- if ( tWriteHop == null ) {
- if(localVariableMap.get(currentHop.getName()) != null)
- return null;
- else
- throw new DMLRuntimeException("Transient write
not found for " + currentHop);
- }
- else
- return new
ArrayList<>(Collections.singletonList(tWriteHop));
- }
-
/**
* Generate a collection of FOUT HopRels representing the different
possible FType outputs.
* For each FType output, only the minimum cost input combination is
chosen.
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
index 45b711a41d..42c5f648f1 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
@@ -21,30 +21,42 @@ package org.apache.sysds.hops.fedplanner;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+/**
+ * Utility class for federated planners.
+ */
public class FederatedPlannerUtils {
+
/**
* Get transient inputs from either paramMap or transientWrites.
* Inputs from paramMap has higher priority than inputs from
transientWrites.
* @param currentHop hop for which inputs are read from maps
* @param paramMap of local parameters
* @param transientWrites map of transient writes
+ * @param localVariableMap map of local variables
* @return inputs of currentHop
*/
- public static ArrayList<Hop> getTransientInputs(Hop currentHop,
Map<String, Hop> paramMap, Map<String,Hop> transientWrites){
+ public static ArrayList<Hop> getTransientInputs(Hop currentHop,
Map<String, Hop> paramMap,
+ Map<String,Hop> transientWrites, LocalVariableMap
localVariableMap){
Hop tWriteHop = null;
if ( paramMap != null)
tWriteHop = paramMap.get(currentHop.getName());
if ( tWriteHop == null )
tWriteHop = transientWrites.get(currentHop.getName());
- if ( tWriteHop == null )
- throw new DMLRuntimeException("Transient write not
found for " + currentHop);
+ if ( tWriteHop == null ) {
+ if(localVariableMap.get(currentHop.getName()) != null)
+ return null;
+ else
+ throw new DMLRuntimeException("Transient write
not found for " + currentHop);
+ }
else
return new
ArrayList<>(Collections.singletonList(tWriteHop));
}
@@ -64,4 +76,26 @@ public class FederatedPlannerUtils {
}
return paramMap;
}
+
+ /**
+ * Saves the HOPs (TWrite) of the function return values for
+ * the variable name used when calling the function.
+ *
+ * Example:
+ * <code>
+ * f = function() return (matrix[double] model) {a = rand(1, 1);}
+ * b = f();
+ * </code>
+ * This function saves the HOP writing to <code>a</code> for identifier
<code>b</code>.
+ *
+ * @param sbHop The <code>FunctionOp</code> for the call
+ * @param funcStatement The <code>FunctionStatement</code> of the
called function
+ * @param transientWrites map of transient writes
+ */
+ public static void mapFunctionOutputs(FunctionOp sbHop,
FunctionStatement funcStatement, Map<String,Hop> transientWrites) {
+ for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i)
{
+ Hop outputWrite =
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
+ transientWrites.put(sbHop.getOutputVariableNames()[i],
outputWrite);
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
index 82e4316988..69a2b638ed 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
@@ -30,6 +30,7 @@ import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
@@ -41,6 +42,7 @@ import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -48,6 +50,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageItem;
@@ -62,6 +65,7 @@ import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -71,6 +75,7 @@ public class PrivacyConstraintLoader {
private final Map<Long, Hop> memo = new HashMap<>();
private final Map<String, Hop> transientWrites = new HashMap<>();
+ private LocalVariableMap localVariableMap = new LocalVariableMap();
public void loadConstraints(DMLProgram prog){
rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
@@ -119,10 +124,17 @@ public class PrivacyConstraintLoader {
loadPrivacyConstraint(forSB.getFromHops(), paramMap);
loadPrivacyConstraint(forSB.getToHops(), paramMap);
loadPrivacyConstraint(forSB.getIncrementHops(), paramMap);
+
+ // add iter variable to local variable map allowing us to
reason over transient reads in the HOP DAG
+ DataIdentifier iterVar = ((ForStatement)
forSB.getStatement(0)).getIterablePredicate().getIterVar();
+ LocalVariableMap tmpLocalVariableMap = localVariableMap;
+ localVariableMap = (LocalVariableMap) localVariableMap.clone();
+ localVariableMap.put(iterVar.getName(), new IntObject(-1));
for(Statement statement : forSB.getStatements()) {
ForStatement forStatement = ((ForStatement) statement);
rewriteStatementBlocks(prog, forStatement.getBody(),
paramMap);
}
+ localVariableMap = tmpLocalVariableMap;
}
private void rewriteFunctionStatementBlock(DMLProgram prog,
FunctionStatementBlock funcSB, Map<String, Hop> paramMap) {
@@ -144,6 +156,9 @@ public class PrivacyConstraintLoader {
paramMap = funcParamMap;
FunctionStatementBlock sbFuncBlock =
prog.getBuiltinFunctionDictionary().getFunction(funcName);
rewriteStatementBlock(prog,
sbFuncBlock, paramMap);
+
+ FunctionStatement funcStatement =
(FunctionStatement) sbFuncBlock.getStatement(0);
+
FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement,
transientWrites);
}
}
}
@@ -167,7 +182,10 @@ public class PrivacyConstraintLoader {
transientWrites.put(currentHop.getName(), currentHop);
}
else if ( HopRewriteUtils.isData(currentHop,
Types.OpOpData.TRANSIENTREAD) ){
-
currentHop.setPrivacy(FederatedPlannerUtils.getTransientInputs(currentHop,
paramMap, transientWrites).get(0).getPrivacy());
+ ArrayList<Hop> tInputs =
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites,
localVariableMap);
+ if ( tInputs != null && tInputs.get(0) != null ){
+
currentHop.setPrivacy(tInputs.get(0).getPrivacy());
+ }
} else {
PrivacyPropagator.hopPropagation(currentHop);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 17f448588e..81d2983da1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -58,6 +58,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "uamax" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uacmax" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uamin" ,
FEDType.AggregateUnary );
+ String2FEDInstructionType.put( "uarmin" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uasqk+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uarsqk+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uacsqk+" ,
FEDType.AggregateUnary );
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 94834ebc6e..673bc9d9cd 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -33,8 +33,11 @@ import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
@@ -202,7 +205,8 @@ public class PrivacyPropagator
private static OperatorType getOpType(Hop hop){
if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop
instanceof ReorgOp
|| hop instanceof DataOp || hop instanceof LiteralOp ||
hop instanceof NaryOp
- || hop instanceof DataGenOp || hop instanceof
FunctionOp )
+ || hop instanceof DataGenOp || hop instanceof
FunctionOp || hop instanceof IndexingOp
+ || hop instanceof ParameterizedBuiltinOp || hop
instanceof LeftIndexingOp )
return OperatorType.NonAggregate;
else if ( hop instanceof AggBinaryOp || hop instanceof
AggUnaryOp || hop instanceof UnaryOp )
return OperatorType.Aggregate;
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedKMeansPlanningTest.java
new file mode 100644
index 0000000000..3a437d3249
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedKMeansPlanningTest.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+public class FederatedKMeansPlanningTest extends AutomatedTestBase {
+ private static final Log LOG =
LogFactory.getLog(FederatedKMeansPlanningTest.class.getName());
+
+ private final static String TEST_DIR = "functions/privacy/fedplanning/";
+ private final static String TEST_NAME = "FederatedKMeansPlanningTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedKMeansPlanningTest.class.getSimpleName() + "/";
+ private static File TEST_CONF_FILE;
+
+ private final static int blocksize = 1024;
+ public final int rows = 1000;
+ public final int cols = 100;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Test
+ @Ignore
+ public void runKMeansFOUTTest(){
+ String[] expectedHeavyHitters = new String[]{};
+ setTestConf("SystemDS-config-fout.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+ }
+
+ @Test
+ @Ignore
+ public void runKMeansHeuristicTest(){
+ String[] expectedHeavyHitters = new String[]{};
+ setTestConf("SystemDS-config-heuristic.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+ }
+
+ @Test
+ public void runKMeansCostBasedTest(){
+ String[] expectedHeavyHitters = new String[]{ "fed_fedinit",
"fed_ba+*", "fed_*", "fed_uack+", "fed_bcumoffk+"};
+ setTestConf("SystemDS-config-cost-based.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+ }
+
+ @Test
+ public void runRuntimeTest(){
+ String[] expectedHeavyHitters = new String[]{};
+ TEST_CONF_FILE = new
File("src/test/config/SystemDS-config.xml");
+ loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+ }
+
+ private void setTestConf(String test_conf){
+ TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+ }
+
+ /**
+ * Override default configuration with custom test configuration to
ensure
+ * scratch space and local temporary directory locations are also
updated.
+ */
+ @Override
+ protected File getConfigTemplateFile() {
+ // Instrumentation in this test's output log to show custom
configuration file used for template.
+ LOG.info("This test case overrides default configuration with "
+ TEST_CONF_FILE.getPath());
+ return TEST_CONF_FILE;
+ }
+
+ private void writeInputMatrices(){
+ writeStandardRowFedMatrix("X1", 65, null);
+ writeStandardRowFedMatrix("X2", 75, null);
+ }
+
+ private void writeStandardMatrix(String matrixName, long seed, int
numRows, PrivacyConstraint privacyConstraint){
+ double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1,
seed);
+ writeStandardMatrix(matrixName, numRows, privacyConstraint,
matrix);
+ }
+
+ private void writeStandardMatrix(String matrixName, int numRows,
PrivacyConstraint privacyConstraint, double[][] matrix){
+ MatrixCharacteristics mc = new MatrixCharacteristics(numRows,
cols, blocksize, (long) numRows * cols);
+ writeInputMatrixWithMTD(matrixName, matrix, false, mc,
privacyConstraint);
+ }
+
+ private void writeStandardRowFedMatrix(String matrixName, long seed,
PrivacyConstraint privacyConstraint){
+ int halfRows = rows/2;
+ writeStandardMatrix(matrixName, seed, halfRows,
privacyConstraint);
+ }
+
+ private void loadAndRunTest(String[] expectedHeavyHitters, String
testName){
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+
+ Thread t1 = null, t2 = null;
+
+ try {
+ getAndLoadTestConfiguration(testName);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ writeInputMatrices();
+
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[] { "-stats", "-explain",
"hops", "-nvargs",
+ "X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "Y=" + input("Y"), "r=" + rows, "c=" + cols,
"Z=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + testName + "Reference.dml";
+ programArgs = new String[] {"-nvargs", "X1=" +
input("X1"), "X2=" + input("X2"),
+ "Y=" + input("Y"), "Z=" + expected("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+ if
(!heavyHittersContainsAllString(expectedHeavyHitters))
+ fail("The following expected heavy hitters are
missing: "
+ +
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+ }
+ finally {
+ TestUtils.shutdownThreads(t1, t2);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+
+}
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTest.dml
new file mode 100644
index 0000000000..8b259dbb99
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r,
$c)))
+
+[C, Y] = kmeans(X=X,k=4, runs=1, max_iter=120, seed=93)
+write(C, $Z);
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTestReference.dml
new file mode 100644
index 0000000000..b2510c560e
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+ X = rbind(read($X1), read($X2))
+ [C, Y] = kmeans(X=X,k=4, runs=1, max_iter=120, seed=93)
+ write(C, $Z);