This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 0968c3b1fb2e88d279e71a3cde6154a163d2baf0
Author: sebwrede <[email protected]>
AuthorDate: Thu Dec 17 11:37:17 2020 +0100

    [SYSTEMDS-2759] Add federated lmCG test for rewrite debugging
    
    Closes #1126.
---
 .../org/apache/sysds/test/AutomatedTestBase.java   |   2 +-
 .../test/functions/builtin/BuiltinLmTest.java      |   2 +-
 .../test/functions/privacy/FederatedLmCGTest.java  | 143 +++++++++++++++++++++
 .../scripts/functions/privacy/FederatedLmCG.dml    |  27 ++++
 .../scripts/functions/privacy/FederatedLmCG2.dml   |  28 ++++
 5 files changed, 200 insertions(+), 2 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 3ae6eea..4f08e88 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -106,7 +106,7 @@ public abstract class AutomatedTestBase {
        public static final double GPU_TOLERANCE = 1e-9;
 
        public static final int FED_WORKER_WAIT = 1000; // in ms
-       public static final int FED_WORKER_WAIT_S = 30; // in ms
+       public static final int FED_WORKER_WAIT_S = 40; // in ms
        
 
        // With OpenJDK 8u242 on Windows, the new changes in JDK are not 
allowing
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java
index 9eeee44..df87f61 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinLmTest.java
@@ -130,7 +130,7 @@ public class BuiltinLmTest extends AutomatedTestBase
 
 
                        fullDMLScriptName = HOME + dml_test_name + ".dml";
-                       programArgs = new String[]{"-args", input("A"), 
input("B"), output("C") };
+                       programArgs = new String[]{"-explain", "-args", 
input("A"), input("B"), output("C") };
                        fullRScriptName = HOME + TEST_NAME + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " "  + expectedDir();
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java 
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
new file mode 100644
index 0000000..d019f32
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+public class FederatedLmCGTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "lmCGFederated";
+       private final static String TEST_DIR = "functions/privacy/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedLmCGTest.class.getSimpleName() + "/";
+
+       private final static int rows = 10;
+       private final static int cols = 3;
+       private final static double spSparse = 0.3;
+       private final static double spDense = 0.7;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME,new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"C"}));
+       }
+
+       @Test
+       public void testLmMatrixDenseCPlmCG1() {
+               runLmTest(false, ExecType.CP, false);
+       }
+
+       @Test
+       public void testLmMatrixSparseCPlmCG1() {
+               runLmTest(true, ExecType.CP, false);
+       }
+
+       @Test
+       public void testLmMatrixDenseCPlmCG2() {
+               runLmTest(false, ExecType.CP, true);
+       }
+
+       @Test
+       public void testLmMatrixSparseCPlmCG2() {
+               runLmTest(true, ExecType.CP, true);
+       }
+
+       @Test
+       public void testLmMatrixDenseSPlmCG() {
+               runLmTest(false, ExecType.SPARK, true);
+       }
+
+       @Test
+       public void testLmMatrixSparseSPlmCG() {
+               runLmTest(true, ExecType.SPARK, true);
+       }
+
+       private void runLmTest(boolean sparse, ExecType instType, boolean 
doubleFederated)
+       {
+               ExecMode platformOld = setExecMode(instType);
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       double sparsity = sparse ? spSparse : spDense;
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       Thread t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       Thread t2 = startLocalFedWorkerThread(port2);
+
+                       fullDMLScriptName = HOME + "FederatedLmCG" + 
(doubleFederated?"2":"") + ".dml";
+                       
+                       if (doubleFederated){
+                               programArgs = new String[]{
+                                       "-explain", "-nvargs",
+                                       "X1="+TestUtils.federatedAddress(port1, 
input("X1")),
+                                       "X2="+TestUtils.federatedAddress(port2, 
input("X2")),
+                                       "y1=" + 
TestUtils.federatedAddress(port1, input("y1")),
+                                       "y2=" + 
TestUtils.federatedAddress(port2, input("y2")),
+                                       "C="+output("C"),
+                                       "r=" + rows, "c=" + cols};
+                       } else {
+                               programArgs = new String[]{
+                                       "-explain", "-nvargs",
+                                       "X1="+TestUtils.federatedAddress(port1, 
input("X1")),
+                                       "X2="+TestUtils.federatedAddress(port2, 
input("X2")),
+                                       "y=" + input("y"),
+                                       "C="+output("C"),
+                                       "r=" + rows, "c=" + cols};
+                       }
+
+                       //generate actual dataset
+                       int halfRows = rows / 2;
+                       double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 
sparsity, 7);
+                       writeInputMatrixWithMTD("X1", X1, false);
+                       double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 
sparsity, 8);
+                       writeInputMatrixWithMTD("X2", X2, false);
+
+                       if ( doubleFederated ){
+                               double[][] y1 = getRandomMatrix(halfRows, 1, 0, 
10, 1.0, 3);
+                               double[][] y2 = getRandomMatrix(halfRows, 1, 0, 
10, 1.0, 4);
+                               writeInputMatrixWithMTD("y1", y1, false);
+                               writeInputMatrixWithMTD("y2", y2, false);
+                       } else {
+                               double[][] y = getRandomMatrix(rows, 1, 0, 10, 
1.0, 3);
+                               writeInputMatrixWithMTD("y", y, false);
+                       }
+
+                       runTest(true, false, null, -1);
+
+                       //check expected operations
+                       
Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+                       
+                       TestUtils.shutdownThreads(t1, t2);
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/privacy/FederatedLmCG.dml 
b/src/test/scripts/functions/privacy/FederatedLmCG.dml
new file mode 100644
index 0000000..5eb103e
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedLmCG.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+        ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, 
$c)))
+y = read($y)
+C = lmCG(X = X, y = y, reg = 1e-12, verbose=FALSE)
+write(C, $C)
+
diff --git a/src/test/scripts/functions/privacy/FederatedLmCG2.dml 
b/src/test/scripts/functions/privacy/FederatedLmCG2.dml
new file mode 100644
index 0000000..707a370
--- /dev/null
+++ b/src/test/scripts/functions/privacy/FederatedLmCG2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), 
list($r, $c)))
+y = federated(addresses=list($y1, $y2),
+              ranges=list(list(0, 0), list($r / 2, 0), list($r / 2, 0), 
list($r, 0)))
+C = lmCG(X = X, y = y, reg = 1e-12, maxi = 2, verbose=FALSE)
+write(C, $C)
+

Reply via email to