This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e9ba92b  [SYSTEMDS-3009] Additional federated MSVM algorithm tests
e9ba92b is described below

commit e9ba92b36a8776b12aff95abb294fc7b6054d914
Author: Olga <[email protected]>
AuthorDate: Sat Jun 5 15:40:05 2021 +0200

    [SYSTEMDS-3009] Additional federated MSVM algorithm tests
    
    Closes #1294.
---
 .../controlprogram/caching/CacheableData.java      |   3 +-
 .../federated/algorithms/FederatedMSVMTest.java    | 127 +++++++++++++++++++++
 .../functions/federated/FederatedMSVMTest.dml      |  35 ++++++
 .../federated/FederatedMSVMTestReference.dml       |  32 ++++++
 4 files changed, 196 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index f549d15..5d42e48 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -1033,7 +1033,8 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 
        // Federated read
        protected T readBlobFromFederated(FederationMap fedMap) throws IOException {
-               LOG.info("Pulling data from federated sites");
+               if( LOG.isDebugEnabled() ) //common if instructions keep federated outputs
+                       LOG.debug("Pulling data from federated sites");
                MetaDataFormat iimd = (MetaDataFormat) _metaData;
                DataCharacteristics dc = iimd.getDataCharacteristics();
                return readBlobFromFederated(fedMap, dc.getDims());
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
new file mode 100644
index 0000000..c5344f9
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedMSVMTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.algorithms;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedMSVMTest extends AutomatedTestBase {
+
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME = "FederatedMSVMTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMSVMTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows have to be even and > 1
+               return Arrays.asList(new Object[][] {
+                       // {2, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000},
+                       {2000, 10}});
+       }
+
+       @Test
+       public void federatedMSVM2CP() {
+               federatedMSVM(Types.ExecMode.SINGLE_NODE, false);
+       }
+       
+       @Test
+       public void federatedMSVM1CP() {
+               federatedMSVM(Types.ExecMode.SINGLE_NODE, true);
+       }
+
+       public void federatedMSVM(Types.ExecMode execMode, boolean singleWorker) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int halfRows = rows / 2;
+               // We have two matrices handled by a single federated worker
+               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+               double[][] Y = getRandomMatrix(rows, 1, 0, 10, 1, 3);
+               Y = TestUtils.round(Y);
+
+               writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows));
+
+               // empty script name because we don't execute any script, just start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2);
+
+               TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"),
+                       String.valueOf(singleWorker).toUpperCase(), expected("Z")};
+               runTest(true, false, null, -1);
+
+               // Run actual dml script with federated matrix
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats",  "30", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+                       "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+                       "in_Y=" + input("Y"), "single=" + String.valueOf(singleWorker).toUpperCase(), "out=" + output("Z")};
+               runTest(true, false, null, -1);
+
+               // compare via files
+               compareResults(1e-9);
+
+               TestUtils.shutdownThreads(t1, t2);
+
+               rtplatform = platformOld;
+               DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedMSVMTest.dml b/src/test/scripts/functions/federated/FederatedMSVMTest.dml
new file mode 100644
index 0000000..3d9cc8c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedMSVMTest.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+Y = read($in_Y)
+
+if( $single ) {
+  X = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows/2, $cols)))
+  Y = Y[1:nrow(X),]
+}
+else {
+  X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+}
+
+model = msvm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, maxIterations = 100, verbose = FALSE)
+
+write(model, $out)
diff --git a/src/test/scripts/functions/federated/FederatedMSVMTestReference.dml b/src/test/scripts/functions/federated/FederatedMSVMTestReference.dml
new file mode 100644
index 0000000..19fad3a
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedMSVMTestReference.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+Y = read($3)
+if( $4 ) {
+  X = read($1)
+  Y = Y[1:nrow(X),]
+}
+else
+  X = rbind(read($1), read($2))
+
+model = msvm(X=X,  Y=Y, intercept = FALSE, epsilon = 1e-12, lambda = 1, maxIterations = 100, verbose = FALSE)
+
+write(model, $5)

Reply via email to