This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 371f59df20 [SYSTEMDS-3018] Add Federated KMeans Planning Test
371f59df20 is described below

commit 371f59df20cc6e8cd888152e136ab88824ccba15
Author: sebwrede <[email protected]>
AuthorDate: Thu Jul 7 10:35:01 2022 +0200

    [SYSTEMDS-3018] Add Federated KMeans Planning Test
    
    Closes #1659.
---
 .../hops/fedplanner/FederatedPlannerCostbased.java |  50 +-----
 .../hops/fedplanner/FederatedPlannerUtils.java     |  40 ++++-
 .../hops/fedplanner/PrivacyConstraintLoader.java   |  20 ++-
 .../runtime/instructions/FEDInstructionParser.java |   1 +
 .../privacy/propagation/PrivacyPropagator.java     |   6 +-
 .../fedplanning/FederatedKMeansPlanningTest.java   | 169 +++++++++++++++++++++
 .../fedplanning/FederatedKMeansPlanningTest.dml    |  26 ++++
 .../FederatedKMeansPlanningTestReference.dml       |  25 +++
 8 files changed, 285 insertions(+), 52 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index 4a2cab2854..d809544f6b 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -203,34 +203,13 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
                                        rewriteStatementBlock(prog, 
sbFuncBlock, paramMap);
 
                                        FunctionStatement funcStatement = 
(FunctionStatement) sbFuncBlock.getStatement(0);
-                                       mapFunctionOutputs((FunctionOp) sbHop, 
funcStatement);
+                                       
FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement, 
transientWrites);
                                }
                        }
                }
                return new ArrayList<>(Collections.singletonList(sb));
        }
 
-       /**
-        * Saves the HOPs (TWrite) of the function return values for
-        * the variable name used when calling the function.
-        *
-        * Example:
-        * <code>
-        *     f = function() return (matrix[double] model) {a = rand(1, 1);}
-        *     b = f();
-        * </code>
-        * This function saves the HOP writing to <code>a</code> for identifier 
<code>b</code>.
-        *
-        * @param sbHop The <code>FunctionOp</code> for the call
-        * @param funcStatement The <code>FunctionStatement</code> of the 
called function
-        */
-       private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement 
funcStatement) {
-               for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i) 
{
-                       Hop outputWrite = 
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
-                       transientWrites.put(sbHop.getOutputVariableNames()[i], 
outputWrite);
-               }
-       }
-
        /**
         * Set final fedouts of all hops starting from terminal hops.
         */
@@ -368,7 +347,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
 
        private ArrayList<Hop> getHopInputs(Hop currentHop, Map<String, Hop> 
paramMap){
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) )
-                       return getTransientInputs(currentHop, paramMap);
+                       return 
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites, 
localVariableMap);
                else
                        return currentHop.getInput();
        }
@@ -392,7 +371,7 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
                ArrayList<HopRel> hopRels = new ArrayList<>();
                ArrayList<Hop> inputHops = currentHop.getInput();
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) ) {
-                       inputHops = getTransientInputs(currentHop, paramMap);
+                       inputHops = 
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites, 
localVariableMap);
                        if (inputHops == null) {
                                // check if transient read on a runtime 
variable (only when planning during dynamic recompilation)
                                return createHopRelsFromRuntimeVars(currentHop, 
hopRels);
@@ -427,29 +406,6 @@ public class FederatedPlannerCostbased extends AFederatedPlanner {
                return hopRels;
        }
 
-       /**
-        * Get transient inputs from either paramMap or transientWrites.
-        * Inputs from paramMap has higher priority than inputs from 
transientWrites.
-        * @param currentHop hop for which inputs are read from maps
-        * @param paramMap of local parameters
-        * @return inputs of currentHop
-        */
-       private ArrayList<Hop> getTransientInputs(Hop currentHop, Map<String, 
Hop> paramMap){
-               Hop tWriteHop = null;
-               if ( paramMap != null)
-                       tWriteHop = paramMap.get(currentHop.getName());
-               if ( tWriteHop == null )
-                       tWriteHop = transientWrites.get(currentHop.getName());
-               if ( tWriteHop == null ) {
-                       if(localVariableMap.get(currentHop.getName()) != null)
-                               return null;
-                       else
-                               throw new DMLRuntimeException("Transient write 
not found for " + currentHop);
-               }
-               else
-                       return new 
ArrayList<>(Collections.singletonList(tWriteHop));
-       }
-
        /**
         * Generate a collection of FOUT HopRels representing the different 
possible FType outputs.
         * For each FType output, only the minimum cost input combination is 
chosen.
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
index 45b711a41d..42c5f648f1 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java
@@ -21,30 +21,42 @@ package org.apache.sysds.hops.fedplanner;
 
 import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.FunctionStatement;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+/**
+ * Utility class for federated planners.
+ */
 public class FederatedPlannerUtils {
+
        /**
         * Get transient inputs from either paramMap or transientWrites.
         * Inputs from paramMap has higher priority than inputs from 
transientWrites.
         * @param currentHop hop for which inputs are read from maps
         * @param paramMap of local parameters
         * @param transientWrites map of transient writes
+        * @param localVariableMap map of local variables
         * @return inputs of currentHop
         */
-       public static ArrayList<Hop> getTransientInputs(Hop currentHop, 
Map<String, Hop> paramMap, Map<String,Hop> transientWrites){
+       public static ArrayList<Hop> getTransientInputs(Hop currentHop, 
Map<String, Hop> paramMap,
+               Map<String,Hop> transientWrites, LocalVariableMap 
localVariableMap){
                Hop tWriteHop = null;
                if ( paramMap != null)
                        tWriteHop = paramMap.get(currentHop.getName());
                if ( tWriteHop == null )
                        tWriteHop = transientWrites.get(currentHop.getName());
-               if ( tWriteHop == null )
-                       throw new DMLRuntimeException("Transient write not 
found for " + currentHop);
+               if ( tWriteHop == null ) {
+                       if(localVariableMap.get(currentHop.getName()) != null)
+                               return null;
+                       else
+                               throw new DMLRuntimeException("Transient write 
not found for " + currentHop);
+               }
                else
                        return new 
ArrayList<>(Collections.singletonList(tWriteHop));
        }
@@ -64,4 +76,26 @@ public class FederatedPlannerUtils {
                }
                return paramMap;
        }
+
+       /**
+        * Saves the HOPs (TWrite) of the function return values for
+        * the variable name used when calling the function.
+        *
+        * Example:
+        * <code>
+        *     f = function() return (matrix[double] model) {a = rand(1, 1);}
+        *     b = f();
+        * </code>
+        * This function saves the HOP writing to <code>a</code> for identifier 
<code>b</code>.
+        *
+        * @param sbHop The <code>FunctionOp</code> for the call
+        * @param funcStatement The <code>FunctionStatement</code> of the 
called function
+        * @param transientWrites map of transient writes
+        */
+       public static void mapFunctionOutputs(FunctionOp sbHop, 
FunctionStatement funcStatement, Map<String,Hop> transientWrites) {
+               for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i) 
{
+                       Hop outputWrite = 
transientWrites.get(funcStatement.getOutputParams().get(i).getName());
+                       transientWrites.put(sbHop.getOutputVariableNames()[i], 
outputWrite);
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
index 82e4316988..69a2b638ed 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java
@@ -30,6 +30,7 @@ import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.ForStatement;
 import org.apache.sysds.parser.ForStatementBlock;
 import org.apache.sysds.parser.FunctionStatement;
@@ -41,6 +42,7 @@ import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.parser.WhileStatement;
 import org.apache.sysds.parser.WhileStatementBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -48,6 +50,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -62,6 +65,7 @@ import java.io.InputStreamReader;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -71,6 +75,7 @@ public class PrivacyConstraintLoader {
 
        private final Map<Long, Hop> memo = new HashMap<>();
        private final Map<String, Hop> transientWrites = new HashMap<>();
+       private LocalVariableMap localVariableMap = new LocalVariableMap();
 
        public void loadConstraints(DMLProgram prog){
                rewriteStatementBlocks(prog, prog.getStatementBlocks(), null);
@@ -119,10 +124,17 @@ public class PrivacyConstraintLoader {
                loadPrivacyConstraint(forSB.getFromHops(), paramMap);
                loadPrivacyConstraint(forSB.getToHops(), paramMap);
                loadPrivacyConstraint(forSB.getIncrementHops(), paramMap);
+
+               // add iter variable to local variable map allowing us to 
reason over transient reads in the HOP DAG
+               DataIdentifier iterVar = ((ForStatement) 
forSB.getStatement(0)).getIterablePredicate().getIterVar();
+               LocalVariableMap tmpLocalVariableMap = localVariableMap;
+               localVariableMap = (LocalVariableMap) localVariableMap.clone();
+               localVariableMap.put(iterVar.getName(), new IntObject(-1));
                for(Statement statement : forSB.getStatements()) {
                        ForStatement forStatement = ((ForStatement) statement);
                        rewriteStatementBlocks(prog, forStatement.getBody(), 
paramMap);
                }
+               localVariableMap = tmpLocalVariableMap;
        }
 
        private void rewriteFunctionStatementBlock(DMLProgram prog, 
FunctionStatementBlock funcSB, Map<String, Hop> paramMap) {
@@ -144,6 +156,9 @@ public class PrivacyConstraintLoader {
                                        paramMap = funcParamMap;
                                        FunctionStatementBlock sbFuncBlock = 
prog.getBuiltinFunctionDictionary().getFunction(funcName);
                                        rewriteStatementBlock(prog, 
sbFuncBlock, paramMap);
+
+                                       FunctionStatement funcStatement = 
(FunctionStatement) sbFuncBlock.getStatement(0);
+                                       
FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement, 
transientWrites);
                                }
                        }
                }
@@ -167,7 +182,10 @@ public class PrivacyConstraintLoader {
                        transientWrites.put(currentHop.getName(), currentHop);
                }
                else if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTREAD) ){
-                       
currentHop.setPrivacy(FederatedPlannerUtils.getTransientInputs(currentHop, 
paramMap, transientWrites).get(0).getPrivacy());
+                       ArrayList<Hop> tInputs = 
FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites, 
localVariableMap);
+                       if ( tInputs != null && tInputs.get(0) != null ){
+                               
currentHop.setPrivacy(tInputs.get(0).getPrivacy());
+                       }
                } else {
                        PrivacyPropagator.hopPropagation(currentHop);
                }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 17f448588e..81d2983da1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -58,6 +58,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "uamax"   , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uacmax"  , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uamin"   , 
FEDType.AggregateUnary );
+               String2FEDInstructionType.put( "uarmin"  , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uasqk+"  , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uarsqk+" , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uacsqk+" , 
FEDType.AggregateUnary );
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
index 94834ebc6e..673bc9d9cd 100644
--- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
+++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java
@@ -33,8 +33,11 @@ import org.apache.sysds.hops.DataGenOp;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.FunctionOp;
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.IndexingOp;
+import org.apache.sysds.hops.LeftIndexingOp;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.NaryOp;
+import org.apache.sysds.hops.ParameterizedBuiltinOp;
 import org.apache.sysds.hops.ReorgOp;
 import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.UnaryOp;
@@ -202,7 +205,8 @@ public class PrivacyPropagator
        private static OperatorType getOpType(Hop hop){
                if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop 
instanceof ReorgOp
                        || hop instanceof DataOp || hop instanceof LiteralOp || 
hop instanceof NaryOp
-                       || hop instanceof DataGenOp || hop instanceof 
FunctionOp )
+                       || hop instanceof DataGenOp || hop instanceof 
FunctionOp || hop instanceof IndexingOp
+                       || hop instanceof ParameterizedBuiltinOp || hop 
instanceof LeftIndexingOp )
                        return OperatorType.NonAggregate;
                else if ( hop instanceof AggBinaryOp || hop instanceof 
AggUnaryOp  || hop instanceof UnaryOp )
                        return OperatorType.Aggregate;
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedKMeansPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedKMeansPlanningTest.java
new file mode 100644
index 0000000000..3a437d3249
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedKMeansPlanningTest.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy.fedplanning;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+
+public class FederatedKMeansPlanningTest extends AutomatedTestBase {
+       private static final Log LOG = 
LogFactory.getLog(FederatedKMeansPlanningTest.class.getName());
+
+       private final static String TEST_DIR = "functions/privacy/fedplanning/";
+       private final static String TEST_NAME = "FederatedKMeansPlanningTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedKMeansPlanningTest.class.getSimpleName() + "/";
+       private static File TEST_CONF_FILE;
+
+       private final static int blocksize = 1024;
+       public final int rows = 1000;
+       public final int cols = 100;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Test
+       @Ignore
+       public void runKMeansFOUTTest(){
+               String[] expectedHeavyHitters = new String[]{};
+               setTestConf("SystemDS-config-fout.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       @Ignore
+       public void runKMeansHeuristicTest(){
+               String[] expectedHeavyHitters = new String[]{};
+               setTestConf("SystemDS-config-heuristic.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       public void runKMeansCostBasedTest(){
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_*", "fed_uack+", "fed_bcumoffk+"};
+               setTestConf("SystemDS-config-cost-based.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       @Test
+       public void runRuntimeTest(){
+               String[] expectedHeavyHitters = new String[]{};
+               TEST_CONF_FILE = new 
File("src/test/config/SystemDS-config.xml");
+               loadAndRunTest(expectedHeavyHitters, TEST_NAME);
+       }
+
+       private void setTestConf(String test_conf){
+               TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, test_conf);
+       }
+
+       /**
+        * Override default configuration with custom test configuration to 
ensure
+        * scratch space and local temporary directory locations are also 
updated.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               // Instrumentation in this test's output log to show custom 
configuration file used for template.
+               LOG.info("This test case overrides default configuration with " 
+ TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
+
+       private void writeInputMatrices(){
+               writeStandardRowFedMatrix("X1", 65, null);
+               writeStandardRowFedMatrix("X2", 75, null);
+       }
+
+       private void writeStandardMatrix(String matrixName, long seed, int 
numRows, PrivacyConstraint privacyConstraint){
+               double[][] matrix = getRandomMatrix(numRows, cols, 0, 1, 1, 
seed);
+               writeStandardMatrix(matrixName, numRows, privacyConstraint, 
matrix);
+       }
+
+       private void writeStandardMatrix(String matrixName, int numRows, 
PrivacyConstraint privacyConstraint, double[][] matrix){
+               MatrixCharacteristics mc = new MatrixCharacteristics(numRows, 
cols, blocksize, (long) numRows * cols);
+               writeInputMatrixWithMTD(matrixName, matrix, false, mc, 
privacyConstraint);
+       }
+
+       private void writeStandardRowFedMatrix(String matrixName, long seed, 
PrivacyConstraint privacyConstraint){
+               int halfRows = rows/2;
+               writeStandardMatrix(matrixName, seed, halfRows, 
privacyConstraint);
+       }
+
+       private void loadAndRunTest(String[] expectedHeavyHitters, String 
testName){
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = Types.ExecMode.SINGLE_NODE;
+
+               Thread t1 = null, t2 = null;
+
+               try {
+                       getAndLoadTestConfiguration(testName);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       writeInputMatrices();
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2);
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + testName + ".dml";
+                       programArgs = new String[] { "-stats", "-explain", 
"hops", "-nvargs",
+                               "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "Y=" + input("Y"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
+                       runTest(true, false, null, -1);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + testName + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"),
+                               "Y=" + input("Y"), "Z=" + expected("Z")};
+                       runTest(true, false, null, -1);
+
+                       // compare via files
+                       compareResults(1e-9);
+                       if 
(!heavyHittersContainsAllString(expectedHeavyHitters))
+                               fail("The following expected heavy hitters are 
missing: "
+                                       + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
+               }
+               finally {
+                       TestUtils.shutdownThreads(t1, t2);
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+
+}
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTest.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTest.dml
new file mode 100644
index 0000000000..8b259dbb99
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+        ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, 
$c)))
+
+[C, Y] = kmeans(X=X,k=4, runs=1, max_iter=120, seed=93)
+write(C, $Z);
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTestReference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTestReference.dml
new file mode 100644
index 0000000000..b2510c560e
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedKMeansPlanningTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+  X = rbind(read($X1), read($X2))
+  [C, Y] = kmeans(X=X,k=4, runs=1, max_iter=120, seed=93)
+  write(C, $Z);

Reply via email to