This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 46a30eaef2 [SYSTEMDS-3018] Federated Planner Extended 3
46a30eaef2 is described below

commit 46a30eaef2fb9e25f41c1b46405e60228783b230
Author: sebwrede <[email protected]>
AuthorDate: Tue Apr 19 17:34:28 2022 +0200

    [SYSTEMDS-3018] Federated Planner Extended 3
    
    This commit adds handling of DataOps (FEDERATED, TRANSIENTREAD and TRANSIENTWRITE) to the allowsFederated and getFederatedOut methods to ensure that transient reads and writes are allowed to be FOUT.
    It also changes the tests to load configuration files and removes the OptimizerUtils.FEDERATED_COMPILATION assignments.
    
    Closes #1586.
---
 .../sysds/hops/fedplanner/AFederatedPlanner.java   |  7 +++++++
 .../hops/fedplanner/FederatedPlannerCostbased.java | 14 ++++++-------
 .../fedplanning/FederatedL2SVMPlanningTest.java    |  9 ++-------
 .../fedplanning/FederatedMultiplyPlanningTest.java | 23 ++++++++++++++++------
 4 files changed, 32 insertions(+), 21 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
index b5adb09780..3403cc4bbe 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -78,6 +78,10 @@ public abstract class AFederatedPlanner {
                else if ( HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS) ){
                        return ft[0] == FType.COL || ft[0] == FType.ROW;
                }
+               else if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)
+                       || HopRewriteUtils.isData(hop, 
Types.OpOpData.TRANSIENTWRITE)
+                       || HopRewriteUtils.isData(hop, 
Types.OpOpData.TRANSIENTREAD))
+                       return true;
                else if(ft.length==1 && ft[0] != null) {
                        return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS)
                                || HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, 
AggOp.MIN, AggOp.MAX);
@@ -135,6 +139,9 @@ public abstract class AFederatedPlanner {
                }
                else if ( HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED) 
)
                        return deriveFType((DataOp)hop);
+               else if ( HopRewriteUtils.isData(hop, 
Types.OpOpData.TRANSIENTWRITE)
+                       || HopRewriteUtils.isData(hop, 
Types.OpOpData.TRANSIENTREAD) )
+                       return ft[0];
                return null;
        }
        
diff --git 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
index a4c0bb8760..ee39e468bd 100644
--- 
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
+++ 
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java
@@ -327,14 +327,12 @@ public class FederatedPlannerCostbased extends 
AFederatedPlanner {
                }
                if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.TRANSIENTWRITE) )
                        transientWrites.put(currentHop.getName(), currentHop);
-               else {
-                       if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.FEDERATED) )
-                               hopRels.add(new HopRel(currentHop, 
FederatedOutput.FOUT, deriveFType((DataOp)currentHop), hopRelMemo, inputHops));
-                       else
-                               hopRels.addAll(generateHopRels(currentHop, 
inputHops));
-                       if ( isLOUTSupported(currentHop) )
-                               hopRels.add(new HopRel(currentHop, 
FederatedOutput.LOUT, hopRelMemo, inputHops));
-               }
+               if ( HopRewriteUtils.isData(currentHop, 
Types.OpOpData.FEDERATED) )
+                       hopRels.add(new HopRel(currentHop, 
FederatedOutput.FOUT, deriveFType((DataOp)currentHop), hopRelMemo, inputHops));
+               else
+                       hopRels.addAll(generateHopRels(currentHop, inputHops));
+               if ( isLOUTSupported(currentHop) )
+                       hopRels.add(new HopRel(currentHop, 
FederatedOutput.LOUT, hopRelMemo, inputHops));
                return hopRels;
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
index 3b0ab91f49..e9ab6b6ad0 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java
@@ -23,7 +23,6 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -74,7 +73,8 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
        public void runL2SVMCostBasedTest(){
                //String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
                //      "fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
-               String[] expectedHeavyHitters = new String[]{ "fed_fedinit"};
+               String[] expectedHeavyHitters = new String[]{ "fed_fedinit", 
"fed_ba+*", "fed_tak+*", "fed_+*",
+                       "fed_max", "fed_1-*", "fed_>"};
                setTestConf("SystemDS-config-cost-based.xml");
                loadAndRunTest(expectedHeavyHitters);
        }
@@ -126,8 +126,6 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                Thread t1 = null, t2 = null;
 
                try {
-                       OptimizerUtils.FEDERATED_COMPILATION = true;
-
                        getAndLoadTestConfiguration(TEST_NAME);
                        String HOME = SCRIPT_DIR + TEST_DIR;
 
@@ -145,8 +143,6 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                                "Y=" + input("Y"), "r=" + rows, "c=" + cols, 
"Z=" + output("Z")};
                        runTest(true, false, null, -1);
 
-                       OptimizerUtils.FEDERATED_COMPILATION = false;
-
                        // Run reference dml script with normal matrix
                        fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                        programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"),
@@ -160,7 +156,6 @@ public class FederatedL2SVMPlanningTest extends 
AutomatedTestBase {
                                        + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
                }
                finally {
-                       OptimizerUtils.FEDERATED_COMPILATION = false;
                        TestUtils.shutdownThreads(t1, t2);
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 56a7dae1f6..6ec10232dd 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -19,7 +19,8 @@
 
 package org.apache.sysds.test.functions.privacy.fedplanning;
 
-import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
 import org.junit.Ignore;
@@ -33,6 +34,7 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
+import java.io.File;
 import java.util.Arrays;
 import java.util.Collection;
 
@@ -41,6 +43,8 @@ import static org.junit.Assert.fail;
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
+       private static final Log LOG = 
LogFactory.getLog(FederatedMultiplyPlanningTest.class.getName());
+
        private final static String TEST_DIR = "functions/privacy/fedplanning/";
        private final static String TEST_NAME = "FederatedMultiplyPlanningTest";
        private final static String TEST_NAME_2 = 
"FederatedMultiplyPlanningTest2";
@@ -52,6 +56,7 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
        private final static String TEST_NAME_8 = 
"FederatedMultiplyPlanningTest8";
        private final static String TEST_NAME_9 = 
"FederatedMultiplyPlanningTest9";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
+       private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + 
TEST_DIR, "SystemDS-config-cost-based.xml");
 
        private final static int blocksize = 1024;
        @Parameterized.Parameter()
@@ -223,8 +228,6 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                Thread t1 = null, t2 = null;
 
                try{
-                       OptimizerUtils.FEDERATED_COMPILATION = true;
-
                        getAndLoadTestConfiguration(testName);
                        String HOME = SCRIPT_DIR + TEST_DIR;
 
@@ -244,8 +247,6 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                        rewriteRealProgramArgs(testName, port1, port2);
                        runTest(true, false, null, -1);
 
-                       OptimizerUtils.FEDERATED_COMPILATION = false;
-
                        // Run reference dml script with normal matrix
                        fullDMLScriptName = HOME + testName + "Reference.dml";
                        programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
@@ -259,7 +260,6 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                                fail("The following expected heavy hitters are 
missing: "
                                        + 
Arrays.toString(missingHeavyHitters(expectedHeavyHitters)));
                } finally {
-                       OptimizerUtils.FEDERATED_COMPILATION = false;
                        TestUtils.shutdownThreads(t1, t2);
                        rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
@@ -289,5 +289,16 @@ public class FederatedMultiplyPlanningTest extends 
AutomatedTestBase {
                                "Y2=" + input("Y2"), "W1=" + input("W1"), "W2=" 
+ input("W2"), "Z=" + expected("Z")};
                }
        }
+
+       /**
+        * Override default configuration with custom test configuration to 
ensure
+        * scratch space and local temporary directory locations are also 
updated.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               // Instrumentation in this test's output log to show custom 
configuration file used for template.
+               LOG.info("This test case overrides default configuration with " 
+ TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
 }
 

Reply via email to