This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c5ab81c6cf [SYSTEMDS-3790] Extraction of optimal FedPlans and conflict
handling
c5ab81c6cf is described below
commit c5ab81c6cf8d4288762b2355f060755433e7e720
Author: min-guk <[email protected]>
AuthorDate: Sat Jan 25 10:22:27 2025 +0100
[SYSTEMDS-3790] Extraction of optimal FedPlans and conflict handling
Closes #2175.
---
.../sysds/hops/fedplanner/FederatedMemoTable.java | 214 +++++++--------------
.../hops/fedplanner/FederatedMemoTablePrinter.java | 139 +++++++++++++
.../fedplanner/FederatedPlanCostEnumerator.java | 112 ++++++++++-
.../fedplanner/FederatedPlanCostEstimator.java | 130 ++++++++++++-
.../federated/FederatedPlanCostEnumeratorTest.java | 18 +-
.../FederatedPlanCostEnumeratorTest1.dml} | 0
.../FederatedPlanCostEnumeratorTest2.dml} | 5 +-
.../FederatedPlanCostEnumeratorTest3.dml} | 5 +-
8 files changed, 460 insertions(+), 163 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
index 16240f0281..b2b58871f6 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java
@@ -20,7 +20,6 @@
package org.apache.sysds.hops.fedplanner;
import org.apache.sysds.hops.Hop;
-import org.apache.sysds.hops.OptimizerUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.ImmutablePair;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
@@ -29,8 +28,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
-import java.util.HashSet;
-import java.util.Set;
/**
* A Memoization Table for managing federated plans (FedPlan) based on
combinations of Hops and fedOutTypes.
@@ -71,12 +68,11 @@ public class FederatedMemoTable {
* Retrieves the minimum cost child plan considering the parent's
output type.
* The cost is calculated using getParentViewCost to account for
potential type mismatches.
*
- * @param childHopID ?
- * @param childFedOutType ?
- * @return ?
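+ * @param fedPlanPair The (Hop ID, FederatedOutput) key of the plan variants to search
+ * @return The FedPlan with the minimum total cost, or null if the entry has no plan variants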
*/
- public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput
childFedOutType) {
- FedPlanVariants fedPlanVariantList = hopMemoTable.get(new
ImmutablePair<>(childHopID, childFedOutType));
+ public FedPlan getMinCostFedPlan(Pair<Long, FederatedOutput>
fedPlanPair) {
+ FedPlanVariants fedPlanVariantList =
hopMemoTable.get(fedPlanPair);
return fedPlanVariantList._fedPlanVariants.stream()
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
.orElse(null);
@@ -86,6 +82,22 @@ public class FederatedMemoTable {
return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType));
}
+ public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput>
fedPlanPair) {
+ return hopMemoTable.get(fedPlanPair);
+ }
+
+ public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput
fedOutType) {
+ // Todo: Consider whether to verify if pruning has been
performed
+ FedPlanVariants fedPlanVariantList = hopMemoTable.get(new
ImmutablePair<>(hopID, fedOutType));
+ return fedPlanVariantList._fedPlanVariants.get(0);
+ }
+
+ public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput>
fedPlanPair) {
+ // Todo: Consider whether to verify if pruning has been
performed
+ FedPlanVariants fedPlanVariantList =
hopMemoTable.get(fedPlanPair);
+ return fedPlanVariantList._fedPlanVariants.get(0);
+ }
+
/**
* Checks if the memo table contains an entry for a given Hop and
fedOutType.
*
@@ -98,162 +110,77 @@ public class FederatedMemoTable {
}
/**
- * Prunes all entries in the memo table, retaining only the minimum-cost
- * FedPlan for each entry.
- */
- public void pruneMemoTable() {
- for (Map.Entry<Pair<Long, FederatedOutput>, FedPlanVariants>
entry : hopMemoTable.entrySet()) {
- List<FedPlan> fedPlanList =
entry.getValue().getFedPlanVariants();
- if (fedPlanList.size() > 1) {
- // Find the FedPlan with the minimum cost
- FedPlan minCostPlan = fedPlanList.stream()
-
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
- .orElse(null);
-
- // Retain only the minimum cost plan
- fedPlanList.clear();
- fedPlanList.add(minCostPlan);
- }
- }
- }
-
- /**
- * Recursively prints a tree representation of the DAG starting from
the given root FedPlan.
- * Includes information about hopID, fedOutType, TotalCost, SelfCost,
and NetCost for each node.
+ * Prunes the specified entry in the memo table, retaining only the
minimum-cost
+ * FedPlan for the given Hop ID and federated output type.
*
- * @param rootFedPlan The starting point FedPlan to print
+ * @param hopID The ID of the Hop to prune
+ * @param federatedOutput The federated output type associated with the
Hop
*/
- public void printFedPlanTree(FedPlan rootFedPlan) {
- Set<FedPlan> visited = new HashSet<>();
- printFedPlanTreeRecursive(rootFedPlan, visited, 0, true);
+ public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) {
+ hopMemoTable.get(new ImmutablePair<>(hopID,
federatedOutput)).prune();
}
/**
- * Helper method to recursively print the FedPlan tree.
- *
- * @param plan The current FedPlan to print
- * @param visited Set to keep track of visited FedPlans (prevents
cycles)
- * @param depth The current depth level for indentation
- * @param isLast Whether this node is the last child of its parent
+ * Represents common properties and costs associated with a Hop.
+ * This class holds a reference to the Hop and tracks its execution and
network transfer costs.
*/
- private void printFedPlanTreeRecursive(FedPlan plan, Set<FedPlan>
visited, int depth, boolean isLast) {
- if (plan == null || visited.contains(plan)) {
- return;
- }
-
- visited.add(plan);
-
- Hop hop = plan.getHopRef();
- StringBuilder sb = new StringBuilder();
-
- // Add FedPlan information
- sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
- .append(plan.getHopRef().getOpString())
- .append(" [")
- .append(plan.getFedOutType())
- .append("]");
-
- StringBuilder childs = new StringBuilder();
- childs.append(" (");
- boolean childAdded = false;
- for( Hop input : hop.getInput()){
- childs.append(childAdded?",":"");
- childs.append(input.getHopID());
- childAdded = true;
- }
- childs.append(")");
- if( childAdded )
- sb.append(childs.toString());
-
-
- sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
- plan.getTotalCost(),
- plan.getSelfCost(),
- plan.getNetTransferCost()));
-
- // Add matrix characteristics
- sb.append(" [")
- .append(hop.getDim1()).append(", ")
- .append(hop.getDim2()).append(", ")
- .append(hop.getBlocksize()).append(", ")
- .append(hop.getNnz());
-
- if (hop.getUpdateType().isInPlace()) {
- sb.append(",
").append(hop.getUpdateType().toString().toLowerCase());
- }
- sb.append("]");
-
- // Add memory estimates
- sb.append(" [")
-
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
-
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
-
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
-
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
-
- // Add reblock and checkpoint requirements
- if (hop.requiresReblock() && hop.requiresCheckpoint()) {
- sb.append(" [rblk, chkpt]");
- } else if (hop.requiresReblock()) {
- sb.append(" [rblk]");
- } else if (hop.requiresCheckpoint()) {
- sb.append(" [chkpt]");
- }
-
- // Add execution type
- if (hop.getExecType() != null) {
- sb.append(", ").append(hop.getExecType());
- }
-
- System.out.println(sb);
-
- // Process child nodes
- List<Pair<Long, FederatedOutput>> childRefs =
plan.getChildFedPlans();
- for (int i = 0; i < childRefs.size(); i++) {
- Pair<Long, FederatedOutput> childRef = childRefs.get(i);
- FedPlanVariants childVariants =
getFedPlanVariants(childRef.getLeft(), childRef.getRight());
- if (childVariants == null ||
childVariants.getFedPlanVariants().isEmpty())
- continue;
+ public static class HopCommon {
+ protected final Hop hopRef; // Reference to the
associated Hop
+ protected double selfCost; // Current execution cost
(compute + memory access)
+ protected double netTransferCost; // Network transfer cost
- boolean isLastChild = (i == childRefs.size() - 1);
- for (FedPlan childPlan :
childVariants.getFedPlanVariants()) {
- printFedPlanTreeRecursive(childPlan, visited,
depth + 1, isLastChild);
- }
+ protected HopCommon(Hop hopRef) {
+ this.hopRef = hopRef;
+ this.selfCost = 0;
+ this.netTransferCost = 0;
}
}
/**
- * Represents a collection of federated execution plan variants for a
specific Hop.
- * Contains cost information and references to the associated plans.
+ * Represents a collection of federated execution plan variants for a
specific Hop and FederatedOutput.
+ * This class contains cost information and references to the
associated plans.
+ * It uses HopCommon to store common properties and costs related to
the Hop.
*/
public static class FedPlanVariants {
- protected final Hop hopRef; // Reference to the
associated Hop
- protected double selfCost; // Current execution cost
(compute + memory access)
- protected double netTransferCost; // Network transfer cost
- private final FederatedOutput fedOutType; // Output
type (FOUT/LOUT)
- protected List<FedPlan> _fedPlanVariants; // List of plan
variants
+ protected HopCommon hopCommon; // Common properties and
costs for the Hop
+ private final FederatedOutput fedOutType; // Output type
(FOUT/LOUT)
+ protected List<FedPlan> _fedPlanVariants; // List of plan
variants
public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) {
- this.hopRef = hopRef;
+ this.hopCommon = new HopCommon(hopRef);
this.fedOutType = fedOutType;
- this.selfCost = 0;
- this.netTransferCost = 0;
this._fedPlanVariants = new ArrayList<>();
}
- public int size() {return _fedPlanVariants.size();}
public void addFedPlan(FedPlan fedPlan)
{_fedPlanVariants.add(fedPlan);}
public List<FedPlan> getFedPlanVariants() {return
_fedPlanVariants;}
+ public boolean isEmpty() {return _fedPlanVariants.isEmpty();}
+
+ public void prune() {
+ if (_fedPlanVariants.size() > 1) {
+ // Find the FedPlan with the minimum cost
+ FedPlan minCostPlan = _fedPlanVariants.stream()
+
.min(Comparator.comparingDouble(FedPlan::getTotalCost))
+ .orElse(null);
+
+ // Retain only the minimum cost plan
+ _fedPlanVariants.clear();
+ _fedPlanVariants.add(minCostPlan);
+ }
+ }
}
/**
* Represents a single federated execution plan with its associated
costs and dependencies.
- * Contains:
+ * This class contains:
* 1. selfCost: Cost of current hop (compute + input/output memory
access)
* 2. totalCost: Cumulative cost including this plan and all child plans
* 3. netTransferCost: Network transfer cost for this plan to parent
plan.
+ *
+ * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon
to manage common properties and costs.
*/
public static class FedPlan {
- private double totalCost; //
Total cost including child plans
+ private double totalCost; // Total cost
including child plans
private final FedPlanVariants fedPlanVariants; // Reference to
variant list
private final List<Pair<Long, FederatedOutput>> childFedPlans;
// Child plan references
@@ -264,25 +191,26 @@ public class FederatedMemoTable {
}
public void setTotalCost(double totalCost) {this.totalCost =
totalCost;}
- public void setSelfCost(double selfCost)
{fedPlanVariants.selfCost = selfCost;}
- public void setNetTransferCost(double netTransferCost)
{fedPlanVariants.netTransferCost = netTransferCost;}
-
- public Hop getHopRef() {return fedPlanVariants.hopRef;}
+ public void setSelfCost(double selfCost)
{fedPlanVariants.hopCommon.selfCost = selfCost;}
+ public void setNetTransferCost(double netTransferCost)
{fedPlanVariants.hopCommon.netTransferCost = netTransferCost;}
+
+ public Hop getHopRef() {return
fedPlanVariants.hopCommon.hopRef;}
+ public long getHopID() {return
fedPlanVariants.hopCommon.hopRef.getHopID();}
public FederatedOutput getFedOutType() {return
fedPlanVariants.fedOutType;}
public double getTotalCost() {return totalCost;}
- public double getSelfCost() {return fedPlanVariants.selfCost;}
- private double getNetTransferCost() {return
fedPlanVariants.netTransferCost;}
+ public double getSelfCost() {return
fedPlanVariants.hopCommon.selfCost;}
+ public double getNetTransferCost() {return
fedPlanVariants.hopCommon.netTransferCost;}
public List<Pair<Long, FederatedOutput>> getChildFedPlans()
{return childFedPlans;}
/**
* Calculates the conditional network transfer cost based on
output type compatibility.
* Returns 0 if output types match, otherwise returns the
network transfer cost.
- * @param parentFedOutType ?
- * @return ?
+ * @param parentFedOutType The federated output type of the
parent plan.
+ * @return The conditional network transfer cost.
*/
public double getCondNetTransferCost(FederatedOutput
parentFedOutType) {
if (parentFedOutType == getFedOutType()) return 0;
- return fedPlanVariants.netTransferCost;
+ return fedPlanVariants.hopCommon.netTransferCost;
}
}
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
new file mode 100644
index 0000000000..f7b3343a98
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.fedplanner;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class FederatedMemoTablePrinter {
+ /**
+ * Recursively prints a tree representation of the DAG starting from
the given root FedPlan.
+ * Includes information about hopID, fedOutType, TotalCost, SelfCost,
and NetCost for each node.
+ * Before the tree, the supplied additional total cost is printed once.
+ *
+ * @param rootFedPlan The starting point FedPlan to print
+ * @param memoTable The memoization table containing FedPlan variants
+ * @param additionalTotalCost The additional cost to be printed once
+ */
+ public static void printFedPlanTree(FederatedMemoTable.FedPlan
rootFedPlan, FederatedMemoTable memoTable,
+
double additionalTotalCost) {
+ System.out.println("Additional Cost: " + additionalTotalCost);
+ Set<FederatedMemoTable.FedPlan> visited = new HashSet<>();
+ printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0);
+ }
+
+ /**
+ * Helper method to recursively print the FedPlan tree.
+ *
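+ * @param plan The current FedPlan to print
+ * @param memoTable The memoization table used to look up the child plan variants
+ * @param visited Set of FedPlans already printed (prevents cycles and duplicate output)
+ * @param depth The current depth level (incremented for children but not yet used for indentation)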
+ */
+ private static void
printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable
memoTable,
+
Set<FederatedMemoTable.FedPlan> visited, int depth) {
+ if (plan == null || visited.contains(plan)) {
+ return;
+ }
+
+ visited.add(plan);
+
+ Hop hop = plan.getHopRef();
+ StringBuilder sb = new StringBuilder();
+
+ // Add FedPlan information
+ sb.append(String.format("(%d) ", plan.getHopRef().getHopID()))
+ .append(plan.getHopRef().getOpString())
+ .append(" [")
+ .append(plan.getFedOutType())
+ .append("]");
+
+ StringBuilder childs = new StringBuilder();
+ childs.append(" (");
+ boolean childAdded = false;
+ for( Hop input : hop.getInput()){
+ childs.append(childAdded?",":"");
+ childs.append(input.getHopID());
+ childAdded = true;
+ }
+ childs.append(")");
+ if( childAdded )
+ sb.append(childs.toString());
+
+
+ sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}",
+ plan.getTotalCost(),
+ plan.getSelfCost(),
+ plan.getNetTransferCost()));
+
+ // Add matrix characteristics
+ sb.append(" [")
+ .append(hop.getDim1()).append(", ")
+ .append(hop.getDim2()).append(", ")
+ .append(hop.getBlocksize()).append(", ")
+ .append(hop.getNnz());
+
+ if (hop.getUpdateType().isInPlace()) {
+ sb.append(",
").append(hop.getUpdateType().toString().toLowerCase());
+ }
+ sb.append("]");
+
+ // Add memory estimates
+ sb.append(" [")
+
.append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ")
+
.append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ")
+
.append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ")
+
.append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]");
+
+ // Add reblock and checkpoint requirements
+ if (hop.requiresReblock() && hop.requiresCheckpoint()) {
+ sb.append(" [rblk, chkpt]");
+ } else if (hop.requiresReblock()) {
+ sb.append(" [rblk]");
+ } else if (hop.requiresCheckpoint()) {
+ sb.append(" [chkpt]");
+ }
+
+ // Add execution type
+ if (hop.getExecType() != null) {
+ sb.append(", ").append(hop.getExecType());
+ }
+
+ System.out.println(sb);
+
+ // Process child nodes
+ List<Pair<Long, FEDInstruction.FederatedOutput>>
childFedPlanPairs = plan.getChildFedPlans();
+ for (int i = 0; i < childFedPlanPairs.size(); i++) {
+ Pair<Long, FEDInstruction.FederatedOutput>
childFedPlanPair = childFedPlanPairs.get(i);
+ FederatedMemoTable.FedPlanVariants childVariants =
memoTable.getFedPlanVariants(childFedPlanPair);
+ if (childVariants == null || childVariants.isEmpty())
+ continue;
+
+ for (FederatedMemoTable.FedPlan childPlan :
childVariants.getFedPlanVariants()) {
+ printFedPlanTreeRecursive(childPlan, memoTable,
visited, depth + 1);
+ }
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
index 73e8d5d693..be1cfa7cdf 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java
@@ -20,10 +20,14 @@
package org.apache.sysds.hops.fedplanner;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.Objects;
+import java.util.LinkedHashMap;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants;
@@ -36,12 +40,13 @@ import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
*/
public class FederatedPlanCostEnumerator {
/**
- * Entry point for federated plan enumeration. Creates a memo table and
returns
- * the minimum cost plan for the entire DAG.
+ * Entry point for federated plan enumeration. This method creates a
memo table
+ * and returns the minimum cost plan for the entire Directed Acyclic
Graph (DAG).
+ * It also resolves conflicts where FedPlans have different
FederatedOutput types.
*
- * @param rootHop ?
- * @param printTree ?
- * @return ?
+ * @param rootHop The root Hop node from which to start the plan
enumeration.
+ * @param printTree A boolean flag indicating whether to print the
federated plan tree.
+ * @return The optimal FedPlan with the minimum cost for the entire DAG.
*/
public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean
printTree) {
// Create new memo table to store all plan variants
@@ -52,8 +57,12 @@ public class FederatedPlanCostEnumerator {
// Return the minimum cost plan for the root node
FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(),
memoTable);
- memoTable.pruneMemoTable();
- if (printTree) memoTable.printFedPlanTree(optimalPlan);
+
+ // Detect and resolve conflicts where the same Hop is referenced with different
+ // FederatedOutput types; the returned value is the extra cost incurred by the resolution
+ double additionalTotalCost =
detectAndResolveConflictFedPlan(optimalPlan, memoTable);
+
+ // Optionally print the federated plan tree if requested
+ if (printTree)
FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable,
additionalTotalCost);
return optimalPlan;
}
@@ -106,6 +115,10 @@ public class FederatedPlanCostEnumerator {
FedPlan lOutPlan = memoTable.addFedPlan(hop,
FederatedOutput.LOUT, planChilds);
FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
}
+
+ // Prune this hop's memo table entries: keep only the minimum-cost plan per output type.
+ memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT);
+ memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT);
}
/**
@@ -133,4 +146,89 @@ public class FederatedPlanCostEnumerator {
}
return minlOutFedPlan;
}
+
+ /**
+ * Detects and resolves conflicts in federated plans starting from the
root plan.
+ * This function performs a breadth-first search (BFS) to traverse the
federated plan tree.
+ * It identifies conflicts where child references to the same Hop ID
+ * carry different federated output types.
+ * For each conflict, it records the plan ID and its conflicting parent
plans.
+ * The function ensures that each plan ID is associated with a
consistent federated output type
+ * by resolving these conflicts iteratively.
+ *
+ * The process involves:
+ * - Using a map to track conflicts, associating each plan ID with its
federated output type
+ * and a list of parent plans.
+ * - Storing detected conflicts in a linked map, each entry containing
a plan ID and its
+ * conflicting parent plans.
+ * - Performing BFS traversal starting from the root plan, checking
each child plan for conflicts.
+ * - If a conflict is detected (i.e., a plan ID has different output
types), the conflicting plan
+ * is removed from the BFS queue and added to the conflict map to
prevent duplicate calculations.
+ * - Resolving conflicts by ensuring a consistent federated output type
across the plan.
+ * - Re-running BFS with resolved conflicts to ensure all
inconsistencies are addressed.
+ *
+ * @param rootPlan The root federated plan from which to start the
conflict detection.
+ * @param memoTable The memoization table used to retrieve pruned
federated plans.
+ * @return The cumulative additional cost for resolving conflicts.
+ */
+ private static double detectAndResolveConflictFedPlan(FedPlan rootPlan,
FederatedMemoTable memoTable) {
+ // Map to track conflicts: maps a plan ID to its federated
output type and list of parent plans
+ Map<Long, Pair<FederatedOutput, List<FedPlan>>>
conflictCheckMap = new HashMap<>();
+
+ // LinkedMap to store detected conflicts, each with a plan ID
and its conflicting parent plans
+ LinkedHashMap<Long, List<FedPlan>> conflictLinkedMap = new
LinkedHashMap<>();
+
+ // LinkedHashMap used as an insertion-ordered BFS queue, seeded with the root plan
+ // (the Boolean values are placeholders and are never read)
+ LinkedHashMap<FedPlan, Boolean> bfsLinkedMap = new
LinkedHashMap<>();
+ bfsLinkedMap.put(rootPlan, true);
+
+ // Array to store cumulative additional cost for resolving
conflicts
+ double[] cumulativeAdditionalCost = new double[]{0.0};
+
+ while (!bfsLinkedMap.isEmpty()) {
+ // Perform BFS to detect conflicts in federated plans
+ while (!bfsLinkedMap.isEmpty()) {
+ FedPlan currentPlan =
bfsLinkedMap.keySet().iterator().next();
+ bfsLinkedMap.remove(currentPlan);
+
+ // Iterate over each child plan of the current
plan
+ for (Pair<Long, FederatedOutput> childPlanPair
: currentPlan.getChildFedPlans()) {
+ FedPlan childFedPlan =
memoTable.getFedPlanAfterPrune(childPlanPair);
+
+ // Check if the child plan ID is
already visited
+ if
(conflictCheckMap.containsKey(childPlanPair.getLeft())) {
+ // Retrieve the existing
conflict pair for the child plan
+ Pair<FederatedOutput,
List<FedPlan>> conflictChildPlanPair =
conflictCheckMap.get(childPlanPair.getLeft());
+ // Add the current plan to the
list of parent plans
+
conflictChildPlanPair.getRight().add(currentPlan);
+
+ // If the federated output type
differs, a conflict is detected
+ if
(conflictChildPlanPair.getLeft() != childPlanPair.getRight()) {
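+ // First detection for this hop: record it (with its parent plans) in the conflict map
+ // and try to drop it from the BFS queue, so the same FedPlan and its children are not
+ // processed twice.
+ // NOTE(review): childFedPlan is the plan for the newly seen output type, while the plan
+ // queued earlier for this hop has the other output type, so this remove() appears to
+ // be a no-op; confirm whether the previously queued plan should be removed instead.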
+ if
(!conflictLinkedMap.containsKey(childPlanPair.getLeft())) {
+
conflictLinkedMap.put(childPlanPair.getLeft(),
conflictChildPlanPair.getRight());
+
bfsLinkedMap.remove(childFedPlan);
+ }
+ }
+ } else {
+ // If no conflict exists,
create a new entry in the conflict check map
+ List<FedPlan> parentFedPlanList
= new ArrayList<>();
+
parentFedPlanList.add(currentPlan);
+
+ // Map the child plan ID to its
output type and list of parent plans
+
conflictCheckMap.put(childPlanPair.getLeft(), new
ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList));
+ // Add the child plan to the
BFS queue
+ bfsLinkedMap.put(childFedPlan,
true);
+ }
+ }
+ }
+ // Resolve these conflicts to ensure a consistent
federated output type across the plan
+ // Re-run BFS with resolved conflicts
+ bfsLinkedMap =
FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap,
cumulativeAdditionalCost);
+ conflictLinkedMap.clear();
+ }
+
+ // Return the cumulative additional cost for resolving conflicts
+ return cumulativeAdditionalCost[0];
+ }
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
index a716c3321d..7bc7339563 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java
@@ -24,6 +24,11 @@ import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+import java.util.LinkedHashMap;
+import java.util.NoSuchElementException;
+import java.util.List;
+import java.util.Map;
+
/**
* Cost estimator for federated execution plans.
* Calculates computation, memory access, and network transfer costs for
federated operations.
@@ -47,7 +52,7 @@ public class FederatedPlanCostEstimator {
* @param memoTable Table containing all plan variants
*/
public static void computeFederatedPlanCost(FedPlan currentPlan,
FederatedMemoTable memoTable) {
- double totalCost = 0;
+ double totalCost;
Hop currentHop = currentPlan.getHopRef();
// Step 1: Calculate current node costs if not already computed
@@ -62,11 +67,11 @@ public class FederatedPlanCostEstimator {
}
// Step 2: Process each child plan and add their costs
- for (Pair<Long, FederatedOutput> planRefMeta :
currentPlan.getChildFedPlans()) {
+ for (Pair<Long, FederatedOutput> childPlanPair :
currentPlan.getChildFedPlans()) {
// Find minimum cost child plan considering federation
type compatibility
// Note: This approach might lead to suboptimal or
wrong solutions when a child has multiple parents
// because we're selecting child plans independently
for each parent
- FedPlan planRef =
memoTable.getMinCostChildFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
+ FedPlan planRef =
memoTable.getMinCostFedPlan(childPlanPair);
// Add child plan cost (includes network transfer cost
if federation types differ)
totalCost += planRef.getTotalCost() +
planRef.getCondNetTransferCost(currentPlan.getFedOutType());
@@ -76,6 +81,125 @@ public class FederatedPlanCostEstimator {
currentPlan.setTotalCost(totalCost);
}
+	/**
+	 * Resolves conflicts in federated plans where different parent plans reference the same child hop
+	 * with different FederatedOutput types, so that all parents end up using a single output type for that hop.
+	 * Conflicts are processed in the iteration (insertion) order of {@code conflictFedPlanLinkedMap}; bottom-up
+	 * resolution relies on the caller inserting the entries in reverse BFS order (NOTE(review): confirm against caller).
+	 * For each conflicting hop, the additional cost of switching all parents to LOUT and to FOUT is computed,
+	 * the cheaper type is chosen (ties favor LOUT), all parent references are rewritten to it, and the chosen
+	 * additional cost is accumulated.
+	 *
+	 * @param memoTable The FederatedMemoTable containing all federated plan variants.
+	 * @param conflictFedPlanLinkedMap A map of conflicting child hop IDs to the parent plans that reference them.
+	 * @param cumulativeAdditionalCost Single-element accumulator; index 0 is increased by the additional cost of each resolution.
+	 * @return A LinkedHashMap holding the chosen (LOUT or FOUT) plan of each conflicting hop, each mapped to {@code true}.
+	 * @throws NoSuchElementException if a parent plan has no child reference to its conflicting hop.
+	 */
+	public static LinkedHashMap<FedPlan, Boolean> resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long, List<FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) {
+		// Resolved federated plans, kept in resolution order
+		LinkedHashMap<FedPlan, Boolean> resolvedFedPlanLinkedMap = new LinkedHashMap<>();
+
+		for (Map.Entry<Long, List<FedPlan>> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) {
+			long conflictHopID = conflictFedPlanPair.getKey();
+			List<FedPlan> conflictParentFedPlans = conflictFedPlanPair.getValue();
+
+			// Retrieve the conflicting federated plans for LOUT and FOUT types
+			FedPlan conflictLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT);
+			FedPlan conflictFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT);
+
+			// Additional costs incurred if all parents are switched to LOUT / FOUT
+			double lOutAdditionalCost = 0;
+			double fOutAdditionalCost = 0;
+
+			// Whether choosing LOUT / FOUT requires a network transfer to at least one parent.
+			// The network transfer cost is added only once, even if several parents need it.
+			boolean isLOutNetTransfer = false;
+			boolean isFOutNetTransfer = false;
+
+			for (FedPlan conflictParentFedPlan : conflictParentFedPlans) {
+				// Output type of the child that was assumed when this parent's cost was calculated
+				Pair<Long, FederatedOutput> calculatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream()
+					.filter(pair -> pair.getLeft().equals(conflictHopID))
+					.findFirst()
+					.orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID));
+
+				// CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged.
+				// CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later.
+				// CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost.
+				// CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later.
+				// CASE 5. Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later.
+				// CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost.
+				// CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later.
+				// CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged.
+				if (calculatedConflictPlanPair.getRight() == FederatedOutput.LOUT) {
+					// Switching to FOUT replaces the LOUT total cost with the FOUT total cost;
+					// keeping LOUT leaves the total cost unchanged.
+					fOutAdditionalCost += conflictFOutFedPlan.getTotalCost() - conflictLOutFedPlan.getTotalCost();
+
+					if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) {
+						// (CASE 1) No network transfer was charged before and none is needed for LOUT.
+						// (CASE 5) A FOUT child feeding an LOUT parent needs a transfer, added later.
+						isFOutNetTransfer = true;
+					} else {
+						// An LOUT child feeding a FOUT parent was charged the LOUT plan's transfer; remove it.
+						// (CASE 2) LOUT re-adds the transfer once later.
+						// (CASE 6) FOUT needs no transfer.
+						isLOutNetTransfer = true;
+						lOutAdditionalCost -= conflictLOutFedPlan.getNetTransferCost();
+						fOutAdditionalCost -= conflictLOutFedPlan.getNetTransferCost();
+					}
+				} else {
+					// Switching to LOUT replaces the FOUT total cost with the LOUT total cost;
+					// keeping FOUT leaves the total cost unchanged.
+					lOutAdditionalCost += conflictLOutFedPlan.getTotalCost() - conflictFOutFedPlan.getTotalCost();
+
+					if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) {
+						// (CASE 4) An LOUT child feeding a FOUT parent needs a transfer, added later.
+						// (CASE 8) No network transfer was charged before and none is needed for FOUT.
+						isLOutNetTransfer = true;
+					} else {
+						// A FOUT child feeding a non-FOUT parent was charged the FOUT plan's transfer; remove it.
+						// (CASE 3) LOUT needs no transfer.
+						// (CASE 7) FOUT re-adds the transfer once later.
+						isFOutNetTransfer = true;
+						lOutAdditionalCost -= conflictFOutFedPlan.getNetTransferCost();
+						fOutAdditionalCost -= conflictFOutFedPlan.getNetTransferCost();
+					}
+				}
+			}
+
+			// Add the network transfer cost once per option if any parent requires it
+			if (isLOutNetTransfer) {
+				lOutAdditionalCost += conflictLOutFedPlan.getNetTransferCost();
+			}
+			if (isFOutNetTransfer) {
+				fOutAdditionalCost += conflictFOutFedPlan.getNetTransferCost();
+			}
+
+			// Determine the optimal federated output type based on the calculated costs (ties favor LOUT)
+			FederatedOutput optimalFedOutType;
+			if (lOutAdditionalCost <= fOutAdditionalCost) {
+				optimalFedOutType = FederatedOutput.LOUT;
+				cumulativeAdditionalCost[0] += lOutAdditionalCost;
+				resolvedFedPlanLinkedMap.put(conflictLOutFedPlan, true);
+			} else {
+				optimalFedOutType = FederatedOutput.FOUT;
+				cumulativeAdditionalCost[0] += fOutAdditionalCost;
+				resolvedFedPlanLinkedMap.put(conflictFOutFedPlan, true);
+			}
+
+			// Rewrite every reference to the conflicting hop in every parent to the optimal type.
+			// Only the references are updated, not the costs themselves, and not recursively.
+			for (FedPlan conflictParentFedPlan : conflictParentFedPlans) {
+				List<Pair<Long, FederatedOutput>> childFedPlans = conflictParentFedPlan.getChildFedPlans();
+				for (int i = 0; i < childFedPlans.size(); i++) {
+					Pair<Long, FederatedOutput> childPlanPair = childFedPlans.get(i);
+					// Do not stop at the first match: a parent may reference the same hop more than once (e.g. b * b)
+					if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) {
+						childFedPlans.set(i, Pair.of(childPlanPair.getLeft(), optimalFedOutType));
+					}
+				}
+			}
+		}
+		return resolvedFedPlanLinkedMap;
+	}
+
/**
* Computes the cost for the current Hop node.
*
diff --git
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
index 56de8cf3c4..20485588d3 100644
---
a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
+++
b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java
@@ -39,7 +39,7 @@ import
org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator;
public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase
{
- private static final String TEST_DIR = "functions/federated/";
+ private static final String TEST_DIR = "functions/federated/privacy/";
private static final String HOME = SCRIPT_DIR + TEST_DIR;
private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/";
@@ -47,8 +47,15 @@ public class FederatedPlanCostEnumeratorTest extends
AutomatedTestBase
public void setUp() {}
@Test
- public void testDependencyAnalysis1() { runTest("cost.dml"); }
-
+ public void testFederatedPlanCostEnumerator1() {
runTest("FederatedPlanCostEnumeratorTest1.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator2() {
runTest("FederatedPlanCostEnumeratorTest2.dml"); }
+
+ @Test
+ public void testFederatedPlanCostEnumerator3() {
runTest("FederatedPlanCostEnumeratorTest3.dml"); }
+
+ // Todo: Need to write test scripts for the federated version
private void runTest( String scriptFilename ) {
int index = scriptFilename.lastIndexOf(".dml");
String testName = scriptFilename.substring(0, index > 0 ? index
: scriptFilename.length());
@@ -72,10 +79,7 @@ public class FederatedPlanCostEnumeratorTest extends
AutomatedTestBase
dmlt.constructHops(prog);
dmlt.rewriteHopsDAG(prog);
dmlt.constructLops(prog);
- /* TODO) In the current DAG, Hop's _outputMemEstimate
is not initialized
- // This leads to incorrect fedplan generation, so test
code needs to be modified
- // If needed, modify costEstimator to handle cases
where _outputMemEstimate is not initialized
- */
+
Hop hops =
prog.getStatementBlocks().get(0).getHops().get(0);
FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true);
}
diff --git a/src/test/scripts/functions/federated/cost.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest1.dml
similarity index 100%
copy from src/test/scripts/functions/federated/cost.dml
copy to
src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest1.dml
diff --git a/src/test/scripts/functions/federated/cost.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest2.dml
similarity index 95%
copy from src/test/scripts/functions/federated/cost.dml
copy to
src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest2.dml
index ec34d45bb6..3cc07eeb01 100644
--- a/src/test/scripts/functions/federated/cost.dml
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest2.dml
@@ -21,5 +21,6 @@
a = matrix(7,10,10);
b = a + a^2;
-c = sqrt(b);
-print(sum(c));
\ No newline at end of file
+c = a * b;
+d = b + sqrt(c);
+print(sum(d));
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/cost.dml
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest3.dml
similarity index 93%
rename from src/test/scripts/functions/federated/cost.dml
rename to
src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest3.dml
index ec34d45bb6..7fe002df75 100644
--- a/src/test/scripts/functions/federated/cost.dml
+++
b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest3.dml
@@ -20,6 +20,9 @@
#-------------------------------------------------------------
a = matrix(7,10,10);
-b = a + a^2;
+if (sum(a) > 0.5)
+ b = a + a^2;
+else
+ b = a * a;
c = sqrt(b);
print(sum(c));
\ No newline at end of file