This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 693ef52fa8 [SYSTEMDS-3790] Extended optimizer for federated execution plans 693ef52fa8 is described below commit 693ef52fa8b137aaf91bf45cee57cc895bce5aef Author: min-guk <koreacho...@gmail.com> AuthorDate: Fri Apr 18 11:52:42 2025 +0200 [SYSTEMDS-3790] Extended optimizer for federated execution plans Closes #2238. --- scripts/staging/fedplanner/graph.py | 268 ++++++++ .../sysds/hops/fedplanner/FederatedMemoTable.java | 173 ++--- .../hops/fedplanner/FederatedMemoTablePrinter.java | 153 +++-- .../fedplanner/FederatedPlanCostEnumerator.java | 757 +++++++++++++++------ .../fedplanner/FederatedPlanCostEstimator.java | 466 +++++++------ .../federated/FederatedPlanCostEnumeratorTest.java | 157 +++-- .../component/federated/FederatedPlanVisualizer.py | 268 ++++++++ .../privacy/FederatedPlanCostEnumeratorTest10.dml | 33 + .../privacy/FederatedPlanCostEnumeratorTest4.dml | 28 + .../privacy/FederatedPlanCostEnumeratorTest5.dml | 26 + .../privacy/FederatedPlanCostEnumeratorTest6.dml | 34 + .../privacy/FederatedPlanCostEnumeratorTest7.dml | 28 + .../privacy/FederatedPlanCostEnumeratorTest8.dml | 49 ++ .../privacy/FederatedPlanCostEnumeratorTest9.dml | 58 ++ 14 files changed, 1842 insertions(+), 656 deletions(-) diff --git a/scripts/staging/fedplanner/graph.py b/scripts/staging/fedplanner/graph.py new file mode 100644 index 0000000000..b083c77913 --- /dev/null +++ b/scripts/staging/fedplanner/graph.py @@ -0,0 +1,268 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import sys +import re +import networkx as nx +import matplotlib.pyplot as plt + +try: + import pygraphviz + from networkx.drawing.nx_agraph import graphviz_layout + HAS_PYGRAPHVIZ = True +except ImportError: + HAS_PYGRAPHVIZ = False + print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n" + "If not installed, we will use an alternative layout (spring_layout).") + + +def parse_line(line: str): + """ + Parse a single line from the trace file to extract: + - Node ID + - Operation (hop name) + - Kind (e.g., FOUT, LOUT, NREF) + - Total cost + - Weight + - Refs (list of IDs that this node depends on) + """ + + # 1) Match a node ID in the form of "(R)" or "(<number>)" + match_id = re.match(r'^\((R|\d+)\)', line) + if not match_id: + return None + node_id = match_id.group(1) + + # 2) The remaining string after the node ID + after_id = line[match_id.end():].strip() + + # Extract operation (hop name) before the first "[" + match_label = re.search(r'^(.*?)\s*\[', after_id) + if match_label: + operation = match_label.group(1).strip() + else: + operation = after_id.strip() + + # 3) Extract the kind (content inside the first pair of brackets "[]") + match_bracket = re.search(r'\[([^\]]+)\]', after_id) + if match_bracket: + kind = match_bracket.group(1).strip() + else: + kind = "" + + # 4) Extract total and weight from the content inside curly braces "{}" + total = "" + weight = "" + match_curly = re.search(r'\{([^}]+)\}', line) + if match_curly: + curly_content = 
match_curly.group(1) + m_total = re.search(r'Total:\s*([\d\.]+)', curly_content) + m_weight = re.search(r'Weight:\s*([\d\.]+)', curly_content) + if m_total: + total = m_total.group(1) + if m_weight: + weight = m_weight.group(1) + + # 5) Extract reference nodes: look for the first parenthesis containing numbers after the hop name + match_refs = re.search(r'\(\s*(\d+(?:,\d+)*)\s*\)', after_id) + if match_refs: + ref_str = match_refs.group(1) + refs = [r.strip() for r in ref_str.split(',') if r.strip().isdigit()] + else: + refs = [] + + return { + 'node_id': node_id, + 'operation': operation, + 'kind': kind, + 'total': total, + 'weight': weight, + 'refs': refs + } + + +def build_dag_from_file(filename: str): + """ + Read a trace file line by line and build a directed acyclic graph (DAG) using NetworkX. + """ + G = nx.DiGraph() + with open(filename, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + info = parse_line(line) + if not info: + continue + + node_id = info['node_id'] + operation = info['operation'] + kind = info['kind'] + total = info['total'] + weight = info['weight'] + refs = info['refs'] + + # Add node with attributes + G.add_node(node_id, label=operation, kind=kind, total=total, weight=weight) + + # Add edges from references to this node + for r in refs: + if r not in G: + G.add_node(r, label=r, kind="", total="", weight="") + G.add_edge(r, node_id) + return G + + +def main(): + """ + Main function that: + - Reads a filename from command-line arguments + - Builds a DAG from the file + - Draws and displays the DAG using matplotlib + """ + + # Get filename from command-line argument + if len(sys.argv) < 2: + print("[ERROR] No filename provided.\nUsage: python plot_federated_dag.py <filename>") + sys.exit(1) + filename = sys.argv[1] + + print(f"[INFO] Running with filename '{filename}'") + + # Build the DAG + G = build_dag_from_file(filename) + + # Print debug info: nodes and edges + print("Nodes:", 
G.nodes(data=True)) + print("Edges:", list(G.edges())) + + # Decide on layout + if HAS_PYGRAPHVIZ: + # graphviz_layout with rankdir=BT (bottom to top), etc. + pos = graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8') + else: + # Fallback layout if pygraphviz is not installed + pos = nx.spring_layout(G, seed=42) + + # Dynamically adjust figure size based on number of nodes + node_count = len(G.nodes()) + fig_width = 10 + node_count / 10.0 + fig_height = 6 + node_count / 10.0 + plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300) + ax = plt.gca() + ax.set_facecolor('white') + + # Generate labels for each node in the format: + # node_id: operation_name + # C<total> (W<weight>) + labels = { + n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})" + for n in G.nodes() + } + + # Function to determine color based on 'kind' + def get_color(n): + k = G.nodes[n].get('kind', '').lower() + if k == 'fout': + return 'tomato' + elif k == 'lout': + return 'dodgerblue' + elif k == 'nref': + return 'mediumpurple' + else: + return 'mediumseagreen' + + # Determine node shapes based on operation name: + # - '^' (triangle) if the label contains "twrite" + # - 's' (square) if the label contains "tread" + # - 'o' (circle) otherwise + triangle_nodes = [n for n in G.nodes() if 'twrite' in G.nodes[n].get('label', '').lower()] + square_nodes = [n for n in G.nodes() if 'tread' in G.nodes[n].get('label', '').lower()] + other_nodes = [ + n for n in G.nodes() + if 'twrite' not in G.nodes[n].get('label', '').lower() and + 'tread' not in G.nodes[n].get('label', '').lower() + ] + + # Colors for each group + triangle_colors = [get_color(n) for n in triangle_nodes] + square_colors = [get_color(n) for n in square_nodes] + other_colors = [get_color(n) for n in other_nodes] + + # Draw nodes group-wise + node_collection_triangle = nx.draw_networkx_nodes( + G, pos, nodelist=triangle_nodes, node_size=800, + 
node_color=triangle_colors, node_shape='^', ax=ax + ) + node_collection_square = nx.draw_networkx_nodes( + G, pos, nodelist=square_nodes, node_size=800, + node_color=square_colors, node_shape='s', ax=ax + ) + node_collection_other = nx.draw_networkx_nodes( + G, pos, nodelist=other_nodes, node_size=800, + node_color=other_colors, node_shape='o', ax=ax + ) + + # Set z-order for nodes, edges, and labels + node_collection_triangle.set_zorder(1) + node_collection_square.set_zorder(1) + node_collection_other.set_zorder(1) + + edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax) + if isinstance(edge_collection, list): + for ec in edge_collection: + ec.set_zorder(2) + else: + edge_collection.set_zorder(2) + + label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax) + for text in label_dict.values(): + text.set_zorder(3) + + # Set the title + plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold") + + # Provide a small legend on the top-right or top-left + plt.text(1, 1, + "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))", + fontsize=12, ha='right', va='top', transform=ax.transAxes) + + # Example mini-legend for different 'kind' values + plt.scatter(0.05, 0.95, color='dodgerblue', s=200, transform=ax.transAxes) + plt.scatter(0.18, 0.95, color='tomato', s=200, transform=ax.transAxes) + plt.scatter(0.31, 0.95, color='mediumpurple', s=200, transform=ax.transAxes) + + plt.text(0.08, 0.95, "LOUT", fontsize=12, va='center', transform=ax.transAxes) + plt.text(0.21, 0.95, "FOUT", fontsize=12, va='center', transform=ax.transAxes) + plt.text(0.34, 0.95, "NREF", fontsize=12, va='center', transform=ax.transAxes) + + plt.axis("off") + + # Save the plot to a file with the same name as the input file, but with a .png extension + output_filename = f"{filename.rsplit('.', 1)[0]}.png" + plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight') + + plt.show() + + +if __name__ == '__main__': + main() diff 
--git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index b2b58871f6..b35723b817 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -19,15 +19,15 @@ package org.apache.sysds.hops.fedplanner; -import org.apache.sysds.hops.Hop; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.ArrayList; import java.util.Map; +import org.apache.sysds.hops.Hop; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; /** * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. @@ -38,48 +38,8 @@ public class FederatedMemoTable { // Maps Hop ID and fedOutType pairs to their plan variants private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>(); - /** - * Adds a new federated plan to the memo table. - * Creates a new variant list if none exists for the given Hop and fedOutType. 
- * - * @param hop The Hop node - * @param fedOutType The federated output type - * @param planChilds List of child plan references - * @return The newly created FedPlan - */ - public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List<Pair<Long, FederatedOutput>> planChilds) { - long hopID = hop.getHopID(); - FedPlanVariants fedPlanVariantList; - - if (contains(hopID, fedOutType)) { - fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); - } else { - fedPlanVariantList = new FedPlanVariants(hop, fedOutType); - hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); - } - - FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); - fedPlanVariantList.addFedPlan(newPlan); - - return newPlan; - } - - /** - * Retrieves the minimum cost child plan considering the parent's output type. - * The cost is calculated using getParentViewCost to account for potential type mismatches. - * - * @param fedPlanPair ??? - * @return min cost fed plan - */ - public FedPlan getMinCostFedPlan(Pair<Long, FederatedOutput> fedPlanPair) { - FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); - return fedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - } - - public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { - return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + public void addFedPlanVariants(long hopID, FederatedOutput fedOutType, FedPlanVariants fedPlanVariants) { + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariants); } public FedPlanVariants getFedPlanVariants(Pair<Long, FederatedOutput> fedPlanPair) { @@ -87,53 +47,47 @@ public class FederatedMemoTable { } public FedPlan getFedPlanAfterPrune(long hopID, FederatedOutput fedOutType) { - // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(new 
ImmutablePair<>(hopID, fedOutType)); return fedPlanVariantList._fedPlanVariants.get(0); } public FedPlan getFedPlanAfterPrune(Pair<Long, FederatedOutput> fedPlanPair) { - // Todo: Consider whether to verify if pruning has been performed FedPlanVariants fedPlanVariantList = hopMemoTable.get(fedPlanPair); return fedPlanVariantList._fedPlanVariants.get(0); } - /** - * Checks if the memo table contains an entry for a given Hop and fedOutType. - * - * @param hopID The Hop ID. - * @param fedOutType The associated fedOutType. - * @return True if the entry exists, false otherwise. - */ public boolean contains(long hopID, FederatedOutput fedOutType) { return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); } /** - * Prunes the specified entry in the memo table, retaining only the minimum-cost - * FedPlan for the given Hop ID and federated output type. - * - * @param hopID The ID of the Hop to prune - * @param federatedOutput The federated output type associated with the Hop - */ - public void pruneFedPlan(long hopID, FederatedOutput federatedOutput) { - hopMemoTable.get(new ImmutablePair<>(hopID, federatedOutput)).prune(); - } - - /** - * Represents common properties and costs associated with a Hop. - * This class holds a reference to the Hop and tracks its execution and network transfer costs. + * Represents a single federated execution plan with its associated costs and dependencies. + * This class contains: + * 1. selfCost: Cost of the current hop (computation + input/output memory access). + * 2. cumulativeCost: Total cost including this plan's selfCost and all child plans' cumulativeCost. + * 3. forwardingCost: Network transfer cost for this plan to the parent plan. + * + * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. 
*/ - public static class HopCommon { - protected final Hop hopRef; // Reference to the associated Hop - protected double selfCost; // Current execution cost (compute + memory access) - protected double netTransferCost; // Network transfer cost + public static class FedPlan { + private double cumulativeCost; // Total cost = sum of selfCost + cumulativeCost of child plans + private final FedPlanVariants fedPlanVariants; // Reference to variant list + private final List<Pair<Long, FederatedOutput>> childFedPlans; // Child plan references - protected HopCommon(Hop hopRef) { - this.hopRef = hopRef; - this.selfCost = 0; - this.netTransferCost = 0; + public FedPlan(double cumulativeCost, FedPlanVariants fedPlanVariants, List<Pair<Long, FederatedOutput>> childFedPlans) { + this.cumulativeCost = cumulativeCost; + this.fedPlanVariants = fedPlanVariants; + this.childFedPlans = childFedPlans; } + + public Hop getHopRef() {return fedPlanVariants.hopCommon.getHopRef();} + public long getHopID() {return fedPlanVariants.hopCommon.getHopRef().getHopID();} + public FederatedOutput getFedOutType() {return fedPlanVariants.getFedOutType();} + public double getCumulativeCost() {return cumulativeCost;} + public double getSelfCost() {return fedPlanVariants.hopCommon.getSelfCost();} + public double getForwardingCost() {return fedPlanVariants.hopCommon.getForwardingCost();} + public double getWeight() {return fedPlanVariants.hopCommon.getWeight();} + public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;} } /** @@ -146,21 +100,22 @@ public class FederatedMemoTable { private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) protected List<FedPlan> _fedPlanVariants; // List of plan variants - public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { - this.hopCommon = new HopCommon(hopRef); + public FedPlanVariants(HopCommon hopCommon, FederatedOutput fedOutType) { + this.hopCommon = hopCommon; this.fedOutType = fedOutType; this._fedPlanVariants = 
new ArrayList<>(); } + public boolean isEmpty() {return _fedPlanVariants.isEmpty();} public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} public List<FedPlan> getFedPlanVariants() {return _fedPlanVariants;} - public boolean isEmpty() {return _fedPlanVariants.isEmpty();} + public FederatedOutput getFedOutType() {return fedOutType;} - public void prune() { + public void pruneFedPlans() { if (_fedPlanVariants.size() > 1) { - // Find the FedPlan with the minimum cost + // Find the FedPlan with the minimum cumulative cost FedPlan minCostPlan = _fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) .orElse(null); // Retain only the minimum cost plan @@ -171,46 +126,28 @@ public class FederatedMemoTable { } /** - * Represents a single federated execution plan with its associated costs and dependencies. - * This class contains: - * 1. selfCost: Cost of current hop (compute + input/output memory access) - * 2. totalCost: Cumulative cost including this plan and all child plans - * 3. netTransferCost: Network transfer cost for this plan to parent plan. - * - * FedPlan is linked to FedPlanVariants, which in turn uses HopCommon to manage common properties and costs. + * Represents common properties and costs associated with a Hop. + * This class holds a reference to the Hop and tracks its execution and network forwarding (transfer) costs. 
*/ - public static class FedPlan { - private double totalCost; // Total cost including child plans - private final FedPlanVariants fedPlanVariants; // Reference to variant list - private final List<Pair<Long, FederatedOutput>> childFedPlans; // Child plan references + public static class HopCommon { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Cost of the hop's computation and memory access + protected double forwardingCost; // Cost of forwarding the hop's output to its parent + protected double weight; // Weight used to calculate cost based on hop execution frequency - public FedPlan(List<Pair<Long, FederatedOutput>> childFedPlans, FedPlanVariants fedPlanVariants) { - this.totalCost = 0; - this.childFedPlans = childFedPlans; - this.fedPlanVariants = fedPlanVariants; + public HopCommon(Hop hopRef, double weight) { + this.hopRef = hopRef; + this.selfCost = 0; + this.forwardingCost = 0; + this.weight = weight; } - public void setTotalCost(double totalCost) {this.totalCost = totalCost;} - public void setSelfCost(double selfCost) {fedPlanVariants.hopCommon.selfCost = selfCost;} - public void setNetTransferCost(double netTransferCost) {fedPlanVariants.hopCommon.netTransferCost = netTransferCost;} - - public Hop getHopRef() {return fedPlanVariants.hopCommon.hopRef;} - public long getHopID() {return fedPlanVariants.hopCommon.hopRef.getHopID();} - public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} - public double getTotalCost() {return totalCost;} - public double getSelfCost() {return fedPlanVariants.hopCommon.selfCost;} - public double getNetTransferCost() {return fedPlanVariants.hopCommon.netTransferCost;} - public List<Pair<Long, FederatedOutput>> getChildFedPlans() {return childFedPlans;} + public Hop getHopRef() {return hopRef;} + public double getSelfCost() {return selfCost;} + public double getForwardingCost() {return forwardingCost;} + public double getWeight() {return weight;} - /** - * 
Calculates the conditional network transfer cost based on output type compatibility. - * Returns 0 if output types match, otherwise returns the network transfer cost. - * @param parentFedOutType The federated output type of the parent plan. - * @return The conditional network transfer cost. - */ - public double getCondNetTransferCost(FederatedOutput parentFedOutType) { - if (parentFedOutType == getFedOutType()) return 0; - return fedPlanVariants.hopCommon.netTransferCost; - } + protected void setSelfCost(double selfCost) {this.selfCost = selfCost;} + protected void setForwardingCost(double forwardingCost) {this.forwardingCost = forwardingCost;} } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java index f7b3343a98..2841256607 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java @@ -1,28 +1,11 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - package org.apache.sysds.hops.fedplanner; import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; import org.apache.sysds.runtime.instructions.fed.FEDInstruction; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import java.util.HashSet; import java.util.List; @@ -35,14 +18,52 @@ public class FederatedMemoTablePrinter { * Additionally, prints the additional total cost once at the beginning. * * @param rootFedPlan The starting point FedPlan to print + * @param rootHopStatSet ??? * @param memoTable The memoization table containing FedPlan variants * @param additionalTotalCost The additional cost to be printed once */ - public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, FederatedMemoTable memoTable, - double additionalTotalCost) { + public static void printFedPlanTree(FederatedMemoTable.FedPlan rootFedPlan, Set<Hop> rootHopStatSet, + FederatedMemoTable memoTable, double additionalTotalCost) { System.out.println("Additional Cost: " + additionalTotalCost); - Set<FederatedMemoTable.FedPlan> visited = new HashSet<>(); + Set<Long> visited = new HashSet<>(); printFedPlanTreeRecursive(rootFedPlan, memoTable, visited, 0); + + for (Hop hop : rootHopStatSet) { + FedPlan plan = memoTable.getFedPlanAfterPrune(hop.getHopID(), FederatedOutput.LOUT); + printNotReferencedFedPlanRecursive(plan, memoTable, visited, 1); + } + } + + /** + * Helper method to recursively print the FedPlan tree. 
+ * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + */ + private static void printNotReferencedFedPlanRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, + Set<Long> visited, int depth) { + long hopID = plan.getHopRef().getHopID(); + + if (visited.contains(hopID)) { + return; + } + + visited.add(hopID); + printFedPlan(plan, depth, true); + + // Process child nodes + List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); i++) { + Pair<Long, FEDInstruction.FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printNotReferencedFedPlanRecursive(childPlan, memoTable, visited, depth + 1); + } + } } /** @@ -53,40 +74,83 @@ public class FederatedMemoTablePrinter { * @param depth The current depth level for indentation */ private static void printFedPlanTreeRecursive(FederatedMemoTable.FedPlan plan, FederatedMemoTable memoTable, - Set<FederatedMemoTable.FedPlan> visited, int depth) { - if (plan == null || visited.contains(plan)) { + Set<Long> visited, int depth) { + long hopID = 0; + + if (depth == 0) { + hopID = -1; + } else { + hopID = plan.getHopRef().getHopID(); + } + + if (visited.contains(hopID)) { return; } - visited.add(plan); + visited.add(hopID); + printFedPlan(plan, depth, false); - Hop hop = plan.getHopRef(); - StringBuilder sb = new StringBuilder(); + // Process child nodes + List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans(); + for (int i = 0; i < childFedPlanPairs.size(); i++) { + Pair<Long, FEDInstruction.FederatedOutput> 
childFedPlanPair = childFedPlanPairs.get(i); + FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); + if (childVariants == null || childVariants.isEmpty()) + continue; + + for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); + } + } + } - // Add FedPlan information - sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) - .append(plan.getHopRef().getOpString()) - .append(" [") - .append(plan.getFedOutType()) - .append("]"); + private static void printFedPlan(FederatedMemoTable.FedPlan plan, int depth, boolean isNotReferenced) { + StringBuilder sb = new StringBuilder(); + Hop hop = null; + + if (depth == 0){ + sb.append("(R) ROOT [Root]"); + } else { + hop = plan.getHopRef(); + // Add FedPlan information + sb.append(String.format("(%d) ", hop.getHopID())) + .append(hop.getOpString()) + .append(" ["); + + if (isNotReferenced) { + sb.append("NRef"); + } else{ + sb.append(plan.getFedOutType()); + } + sb.append("]"); + } StringBuilder childs = new StringBuilder(); childs.append(" ("); + boolean childAdded = false; - for( Hop input : hop.getInput()){ + for (Pair<Long, FederatedOutput> childPair : plan.getChildFedPlans()){ childs.append(childAdded?",":""); - childs.append(input.getHopID()); + childs.append(childPair.getLeft()); childAdded = true; } + childs.append(")"); + if( childAdded ) sb.append(childs.toString()); + if (depth == 0){ + sb.append(String.format(" {Total: %.1f}", plan.getCumulativeCost())); + System.out.println(sb); + return; + } - sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", - plan.getTotalCost(), + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f, Weight: %.1f}", + plan.getCumulativeCost(), plan.getSelfCost(), - plan.getNetTransferCost())); + plan.getForwardingCost(), + plan.getWeight())); // Add matrix characteristics sb.append(" [") @@ -122,18 +186,5 @@ public class 
FederatedMemoTablePrinter { } System.out.println(sb); - - // Process child nodes - List<Pair<Long, FEDInstruction.FederatedOutput>> childFedPlanPairs = plan.getChildFedPlans(); - for (int i = 0; i < childFedPlanPairs.size(); i++) { - Pair<Long, FEDInstruction.FederatedOutput> childFedPlanPair = childFedPlanPairs.get(i); - FederatedMemoTable.FedPlanVariants childVariants = memoTable.getFedPlanVariants(childFedPlanPair); - if (childVariants == null || childVariants.isEmpty()) - continue; - - for (FederatedMemoTable.FedPlan childPlan : childVariants.getFedPlanVariants()) { - printFedPlanTreeRecursive(childPlan, memoTable, visited, depth + 1); - } - } } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index be1cfa7cdf..f3e8cc286d 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -17,218 +17,581 @@ * under the License. */ -package org.apache.sysds.hops.fedplanner; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Objects; -import java.util.LinkedHashMap; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - -/** - * Enumerates and evaluates all possible federated execution plans for a given Hop DAG. - * Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator - * to compute their costs. 
- */ -public class FederatedPlanCostEnumerator { - /** - * Entry point for federated plan enumeration. This method creates a memo table - * and returns the minimum cost plan for the entire Directed Acyclic Graph (DAG). - * It also resolves conflicts where FedPlans have different FederatedOutput types. - * - * @param rootHop The root Hop node from which to start the plan enumeration. - * @param printTree A boolean flag indicating whether to print the federated plan tree. - * @return The optimal FedPlan with the minimum cost for the entire DAG. - */ - public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) { - // Create new memo table to store all plan variants - FederatedMemoTable memoTable = new FederatedMemoTable(); - - // Recursively enumerate all possible plans - enumerateFederatedPlanCost(rootHop, memoTable); - - // Return the minimum cost plan for the root node - FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); - - // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types - double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); - - // Optionally print the federated plan tree if requested - if (printTree) FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost); - - return optimalPlan; - } - - /** - * Recursively enumerates all possible federated execution plans for a Hop DAG. - * For each node: - * 1. First processes all input nodes recursively if not already processed - * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs - * 3. 
Creates and evaluates both FOUT and LOUT variants for current node with each input combination - * - * The enumeration uses a bottom-up approach where: - * - Each input combination is represented by a binary number (i) - * - Bit j in i determines whether input j is FOUT (1) or LOUT (0) - * - Total number of combinations is 2^numInputs - * - * @param hop ? - * @param memoTable ? - */ - private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { - int numInputs = hop.getInput().size(); - + package org.apache.sysds.hops.fedplanner; + import java.util.ArrayList; + import java.util.List; + import java.util.Map; + import java.util.HashMap; + import java.util.LinkedHashMap; + import java.util.Optional; + import java.util.Set; + import java.util.HashSet; + + import org.apache.commons.lang3.tuple.Pair; + + import org.apache.commons.lang3.tuple.ImmutablePair; + import org.apache.sysds.common.Types; + import org.apache.sysds.hops.DataOp; + import org.apache.sysds.hops.Hop; + import org.apache.sysds.hops.LiteralOp; + import org.apache.sysds.hops.UnaryOp; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; + import org.apache.sysds.hops.rewrite.HopRewriteUtils; + import org.apache.sysds.parser.DMLProgram; + import org.apache.sysds.parser.ForStatement; + import org.apache.sysds.parser.ForStatementBlock; + import org.apache.sysds.parser.FunctionStatement; + import org.apache.sysds.parser.FunctionStatementBlock; + import org.apache.sysds.parser.IfStatement; + import org.apache.sysds.parser.IfStatementBlock; + import org.apache.sysds.parser.StatementBlock; + import org.apache.sysds.parser.WhileStatement; + import org.apache.sysds.parser.WhileStatementBlock; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + import 
org.apache.sysds.runtime.util.UtilFunctions; + + public class FederatedPlanCostEnumerator { + private static final double DEFAULT_LOOP_WEIGHT = 10.0; + private static final double DEFAULT_IF_ELSE_WEIGHT = 0.5; + + /** + * Enumerates the entire DML program to generate federated execution plans. + * It processes each statement block, computes the optimal federated plan, + * detects and resolves conflicts, and optionally prints the plan tree. + * + * @param prog The DML program to enumerate. + * @param isPrint A boolean indicating whether to print the federated plan tree. + */ + public static void enumerateProgram(DMLProgram prog, boolean isPrint) { + FederatedMemoTable memoTable = new FederatedMemoTable(); + + Map<String, List<Hop>> outerTransTable = new HashMap<>(); + Map<String, List<Hop>> formerInnerTransTable = new HashMap<>(); + Set<Hop> progRootHopSet = new HashSet<>(); // Set of hops for the root dummy node + // TODO: Just for debug, remove later + Set<Hop> statRootHopSet = new HashSet<>(); // Set of hops that have no parent but are not referenced + + for (StatementBlock sb : prog.getStatementBlocks()) { + Optional.ofNullable(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, 1, false)) + .ifPresent(outerTransTable::putAll); + } + + FedPlan optimalPlan = getMinCostRootFedPlan(progRootHopSet, memoTable); + + // Detect conflicts in the federated plans where different FedPlans have different FederatedOutput types + double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable); + + // Print the federated plan tree if requested + if (isPrint) { + FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, statRootHopSet, memoTable, additionalTotalCost); + } + } + + + /** + * Enumerates the statement block and updates the transient and memoization tables. + * This method processes different types of statement blocks such as If, For, While, and Function blocks. 
+ * It recursively enumerates the Hop DAGs within these blocks and updates the corresponding tables. + * The method also calculates weights recursively for if-else/loops and handles inner and outer block distinctions. + * + * @param sb The statement block to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. + * @return A map of inner transient writes. + */ + public static Map<String, List<Hop>> enumerateStatementBlock(StatementBlock sb, FederatedMemoTable memoTable, Map<String, List<Hop>> outerTransTable, + Map<String, List<Hop>> formerInnerTransTable, Set<Hop> progRootHopSet, Set<Hop> statRootHopSet, double weight, boolean isInnerBlock) { + Map<String, List<Hop>> innerTransTable = new HashMap<>(); + + if (sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + + enumerateHopDAG(isb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + // Treat outerTransTable as immutable in inner blocks + // Write TWrite of sb sequentially in innerTransTable, and update formerInnerTransTable after the sb ends + // In case of if-else, create separate formerInnerTransTables for if and else, merge them after completion, and update formerInnerTransTable + Map<String, List<Hop>> ifFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + Map<String, List<Hop>> 
elseFormerInnerTransTable = new HashMap<>(formerInnerTransTable); + + for (StatementBlock csb : istmt.getIfBody()){ + ifFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, ifFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + for (StatementBlock csb : istmt.getElseBody()){ + elseFormerInnerTransTable.putAll(enumerateStatementBlock(csb, memoTable, outerTransTable, elseFormerInnerTransTable, progRootHopSet, statRootHopSet, DEFAULT_IF_ELSE_WEIGHT * weight, true)); + } + + // If there are common keys: merge elseValue list into ifValue list + elseFormerInnerTransTable.forEach((key, elseValue) -> { + ifFormerInnerTransTable.merge(key, elseValue, (ifValue, newValue) -> { + ifValue.addAll(newValue); + return ifValue; + }); + }); + // Update innerTransTable + innerTransTable.putAll(ifFormerInnerTransTable); + } + else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement)fsb.getStatement(0); + + // Calculate for-loop iteration count if possible + double loopWeight = DEFAULT_LOOP_WEIGHT; + Hop from = fsb.getFromHops().getInput().get(0); + Hop to = fsb.getToHops().getInput().get(0); + Hop incr = (fsb.getIncrementHops() != null) ? 
+ fsb.getIncrementHops().getInput().get(0) : new LiteralOp(1); + + // Calculate for-loop iteration count (weight) if from, to, and incr are literal ops (constant values) + if( from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp ) { + double dfrom = HopRewriteUtils.getDoubleValue((LiteralOp) from); + double dto = HopRewriteUtils.getDoubleValue((LiteralOp) to); + double dincr = HopRewriteUtils.getDoubleValue((LiteralOp) incr); + if( dfrom > dto && dincr == 1 ) + dincr = -1; + loopWeight = UtilFunctions.getSeqLength(dfrom, dto, dincr, false); + } + weight *= loopWeight; + + enumerateHopDAG(fsb.getFromHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getToHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateHopDAG(fsb.getIncrementHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } + else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); + weight *= DEFAULT_LOOP_WEIGHT; + + enumerateHopDAG(wsb.getPredicateHops(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + enumerateStatementBlockBody(wstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } + else if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + + // TODO: Do not descend for visited 
functions (use a hash set for functions using their names) + enumerateStatementBlockBody(fstmt.getBody(), memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight); + } + else { //generic (last-level) + if( sb.getHops() != null ){ + for(Hop c : sb.getHops()) + // In the statement block, if isInner, write hopDAG in innerTransTable, if not, write directly in outerTransTable + enumerateHopDAG(c, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, progRootHopSet, statRootHopSet, weight, isInnerBlock); + } + } + return innerTransTable; + } + + /** + * Enumerates the statement blocks within a body and updates the transient and memoization tables. + * + * @param sbList The list of statement blocks to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track immutable outer transient writes. + * @param formerInnerTransTable The table to track immutable former inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of statement root hops for debugging purposes (check if not referenced). + * @param weight The weight associated with the current Hop. + */ + public static void enumerateStatementBlockBody(List<StatementBlock> sbList, FederatedMemoTable memoTable, Map<String, List<Hop>> outerTransTable, + Map<String, List<Hop>> formerInnerTransTable, Map<String, List<Hop>> innerTransTable, Set<Hop> progRootHopSet, Set<Hop> statRootHopSet, double weight) { + // The statement blocks within the body reference outerTransTable and formerInnerTransTable as immutable read-only, + // and record TWrite in the innerTransTable of the statement block within the body. + // Update the formerInnerTransTable with the contents of the returned innerTransTable. 
+ for (StatementBlock sb : sbList) + formerInnerTransTable.putAll(enumerateStatementBlock(sb, memoTable, outerTransTable, formerInnerTransTable, progRootHopSet, statRootHopSet, weight, true)); + + // Then update and return the innerTransTable of the statement block containing the body. + innerTransTable.putAll(formerInnerTransTable); + } + + /** + * Enumerates the statement hop DAG within a statement block. + * This method recursively enumerates all possible federated execution plans + * and identifies hops to connect to the root dummy node. + * + * @param rootHop The root Hop of the DAG to enumerate. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param progRootHopSet The set of hops to connect to the root dummy node. + * @param statRootHopSet The set of root hops for debugging purposes. + * @param weight The weight associated with the current Hop. + * @param isInnerBlock A boolean indicating if the current block is an inner block. 
+ */ + public static void enumerateHopDAG(Hop rootHop, FederatedMemoTable memoTable, Map<String, List<Hop>> outerTransTable, + Map<String, List<Hop>> formerInnerTransTable, Map<String,List<Hop>> innerTransTable, Set<Hop> progRootHopSet, Set<Hop> statRootHopSet, double weight, boolean isInnerBlock) { + // Recursively enumerate all possible plans + rewireAndEnumerateFedPlan(rootHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInnerBlock); + + // Identify hops to connect to the root dummy node + + if ((rootHop instanceof DataOp && (rootHop.getName().equals("__pred"))) // TWrite "__pred" + || (rootHop instanceof UnaryOp && ((UnaryOp)rootHop).getOp() == Types.OpOp1.PRINT)){ // u(print) + // Connect TWrite pred and u(print) to the root dummy node + // TODO: Should we check all statement-level root hops to see if they are not referenced? + progRootHopSet.add(rootHop); + } else { + // TODO: Just for debug, remove later + // For identifying TWrites that are not referenced later + statRootHopSet.add(rootHop); + } + } + + /** + * Rewires and enumerates federated execution plans for a given Hop. + * This method processes all input nodes, rewires TWrite and TRead operations, + * and generates federated plan variants for both inner and outer code blocks. + * + * @param hop The Hop for which to rewire and enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param outerTransTable The table to track transient writes. + * @param formerInnerTransTable The table to track immutable inner transient writes. + * @param innerTransTable The table to track inner transient writes. + * @param weight The weight associated with the current Hop. + * @param isInner A boolean indicating if the current block is an inner block. 
+ */ + private static void rewireAndEnumerateFedPlan(Hop hop, FederatedMemoTable memoTable, Map<String, List<Hop>> outerTransTable, + Map<String, List<Hop>> formerInnerTransTable, Map<String, List<Hop>> innerTransTable, + double weight, boolean isInner) { // Process all input nodes first if not already in memo table for (Hop inputHop : hop.getInput()) { - if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) - && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { - enumerateFederatedPlanCost(inputHop, memoTable); + long inputHopID = inputHop.getHopID(); + if (!memoTable.contains(inputHopID, FederatedOutput.FOUT) + && !memoTable.contains(inputHopID, FederatedOutput.LOUT)) { + rewireAndEnumerateFedPlan(inputHop, memoTable, outerTransTable, formerInnerTransTable, innerTransTable, weight, isInner); } } - // Generate all possible input combinations using binary representation - // i represents a specific combination of FOUT/LOUT for inputs - for (int i = 0; i < (1 << numInputs); i++) { - List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>(); - - // For each input, determine if it should be FOUT or LOUT based on bit j in i - for (int j = 0; j < numInputs; j++) { - Hop inputHop = hop.getInput().get(j); - // If bit j is set (1), use FOUT; otherwise use LOUT - FederatedOutput childType = ((i & (1 << j)) != 0) ? 
- FederatedOutput.FOUT : FederatedOutput.LOUT; - planChilds.add(Pair.of(inputHop.getHopID(), childType)); - } - - // Create and evaluate FOUT variant for current input combination - FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); - - // Create and evaluate LOUT variant for current input combination - FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); - FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); - } + // Determine modified child hops based on DataOp type and transient operations + List<Hop> childHops = rewireTransReadWrite(hop, outerTransTable, formerInnerTransTable, innerTransTable, isInner); - // Prune MemoTable for hop. - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.LOUT); - memoTable.pruneFedPlan(hop.getHopID(), FederatedOutput.FOUT); + // Enumerate the federated plan for the current Hop + enumerateFedPlan(hop, memoTable, childHops, weight); } - /** - * Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants. - * Used to select the final execution plan after enumeration. - * - * @param HopID ? - * @param memoTable ? - * @return ? 
- */ - private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { - FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); - FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); - - FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getTotalCost)) - .orElse(null); - - if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() - < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { - return minFOutFedPlan; + private static List<Hop> rewireTransReadWrite(Hop hop, Map<String, List<Hop>> outerTransTable, + Map<String, List<Hop>> formerInnerTransTable, + Map<String, List<Hop>> innerTransTable, boolean isInner) { + List<Hop> childHops = hop.getInput(); + + if (!(hop instanceof DataOp) || hop.getName().equals("__pred")) { + return childHops; // Early exit for non-DataOp or __pred } - return minlOutFedPlan; - } - - /** - * Detects and resolves conflicts in federated plans starting from the root plan. - * This function performs a breadth-first search (BFS) to traverse the federated plan tree. - * It identifies conflicts where the same plan ID has different federated output types. - * For each conflict, it records the plan ID and its conflicting parent plans. - * The function ensures that each plan ID is associated with a consistent federated output type - * by resolving these conflicts iteratively. - * - * The process involves: - * - Using a map to track conflicts, associating each plan ID with its federated output type - * and a list of parent plans. - * - Storing detected conflicts in a linked map, each entry containing a plan ID and its - * conflicting parent plans. 
- * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. - * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan - * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. - * - Resolving conflicts by ensuring a consistent federated output type across the plan. - * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. - * - * @param rootPlan The root federated plan from which to start the conflict detection. - * @param memoTable The memoization table used to retrieve pruned federated plans. - * @return The cumulative additional cost for resolving conflicts. - */ - private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { - // Map to track conflicts: maps a plan ID to its federated output type and list of parent plans - Map<Long, Pair<FederatedOutput, List<FedPlan>>> conflictCheckMap = new HashMap<>(); - - // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans - LinkedHashMap<Long, List<FedPlan>> conflictLinkedMap = new LinkedHashMap<>(); - // LinkedMap for BFS traversal starting from the root plan (Do not use value (boolean)) - LinkedHashMap<FedPlan, Boolean> bfsLinkedMap = new LinkedHashMap<>(); - bfsLinkedMap.put(rootPlan, true); + DataOp dataOp = (DataOp) hop; + Types.OpOpData opType = dataOp.getOp(); + String hopName = dataOp.getName(); - // Array to store cumulative additional cost for resolving conflicts - double[] cumulativeAdditionalCost = new double[]{0.0}; - - while (!bfsLinkedMap.isEmpty()) { - // Perform BFS to detect conflicts in federated plans - while (!bfsLinkedMap.isEmpty()) { - FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); - bfsLinkedMap.remove(currentPlan); + if (isInner && opType == Types.OpOpData.TRANSIENTWRITE) { + innerTransTable.computeIfAbsent(hopName, k -> new 
ArrayList<>()).add(hop); + } + else if (isInner && opType == Types.OpOpData.TRANSIENTREAD) { + childHops = rewireInnerTransRead(childHops, hopName, + innerTransTable, formerInnerTransTable, outerTransTable); + } + else if (!isInner && opType == Types.OpOpData.TRANSIENTWRITE) { + outerTransTable.computeIfAbsent(hopName, k -> new ArrayList<>()).add(hop); + } + else if (!isInner && opType == Types.OpOpData.TRANSIENTREAD) { + childHops = rewireOuterTransRead(childHops, hopName, outerTransTable); + } - // Iterate over each child plan of the current plan - for (Pair<Long, FederatedOutput> childPlanPair : currentPlan.getChildFedPlans()) { - FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); + return childHops; + } - // Check if the child plan ID is already visited - if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { - // Retrieve the existing conflict pair for the child plan - Pair<FederatedOutput, List<FedPlan>> conflictChildPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); - // Add the current plan to the list of parent plans - conflictChildPlanPair.getRight().add(currentPlan); + private static List<Hop> rewireInnerTransRead(List<Hop> childHops, String hopName, Map<String, List<Hop>> innerTransTable, + Map<String, List<Hop>> formerInnerTransTable, Map<String, List<Hop>> outerTransTable) { + List<Hop> newChildHops = new ArrayList<>(childHops); - // If the federated output type differs, a conflict is detected - if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { - // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) - // If the existing FedPlan is not removed from the bfsqueue or both actions are performed, duplicate calculations for the same FedPlan and its children occur - if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { - conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); - 
bfsLinkedMap.remove(childFedPlan); - } - } - } else { - // If no conflict exists, create a new entry in the conflict check map - List<FedPlan> parentFedPlanList = new ArrayList<>(); - parentFedPlanList.add(currentPlan); + // Read according to priority: inner -> formerInner -> outer + List<Hop> additionalChildHops = innerTransTable.get(hopName); + if (additionalChildHops == null) { + additionalChildHops = formerInnerTransTable.get(hopName); + } + if (additionalChildHops == null) { + additionalChildHops = outerTransTable.get(hopName); + } - // Map the child plan ID to its output type and list of parent plans - conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); - // Add the child plan to the BFS queue - bfsLinkedMap.put(childFedPlan, true); - } - } - } - // Resolve these conflicts to ensure a consistent federated output type across the plan - // Re-run BFS with resolved conflicts - bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); - conflictLinkedMap.clear(); + if (additionalChildHops != null) { + newChildHops.addAll(additionalChildHops); } + return newChildHops; + } - // Return the cumulative additional cost for resolving conflicts - return cumulativeAdditionalCost[0]; + private static List<Hop> rewireOuterTransRead(List<Hop> childHops, String hopName, Map<String, List<Hop>> outerTransTable) { + List<Hop> newChildHops = new ArrayList<>(childHops); + List<Hop> additionalChildHops = outerTransTable.get(hopName); + if (additionalChildHops != null) { + newChildHops.addAll(additionalChildHops); + } + return newChildHops; } -} + + /** + * Enumerates federated execution plans for a given Hop. + * This method calculates the self cost and child costs for the Hop, + * generates federated plan variants for both LOUT and FOUT output types, + * and prunes redundant plans before adding them to the memo table. 
+ * + * @param hop The Hop for which to enumerate federated plans. + * @param memoTable The memoization table to store plan variants. + * @param childHops The list of child hops. + * @param weight The weight associated with the current Hop. + */ + private static void enumerateFedPlan(Hop hop, FederatedMemoTable memoTable, List<Hop> childHops, double weight){ + long hopID = hop.getHopID(); + HopCommon hopCommon = new HopCommon(hop, weight); + double selfCost = FederatedPlanCostEstimator.computeHopCost(hopCommon); + + FedPlanVariants lOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.LOUT); + FedPlanVariants fOutFedPlanVariants = new FedPlanVariants(hopCommon, FederatedOutput.FOUT); + + int numInputs = childHops.size(); + int numInitInputs = hop.getInput().size(); + + double[][] childCumulativeCost = new double[numInputs][2]; // # of child, LOUT/FOUT of child + double[] childForwardingCost = new double[numInputs]; // # of child + + // The self cost follows its own weight, while the forwarding cost follows the parent's weight. + FederatedPlanCostEstimator.getChildCosts(hopCommon, memoTable, childHops, childCumulativeCost, childForwardingCost); + + if (numInitInputs == numInputs){ + enumerateOnlyInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } else { + enumerateTReadInitChildFedPlan(lOutFedPlanVariants, fOutFedPlanVariants, numInitInputs, numInputs, childHops, childCumulativeCost, childForwardingCost, selfCost); + } + + // Prune the FedPlans to remove redundant plans + lOutFedPlanVariants.pruneFedPlans(); + fOutFedPlanVariants.pruneFedPlans(); + + // Add the FedPlanVariants to the memo table + memoTable.addFedPlanVariants(hopID, FederatedOutput.LOUT, lOutFedPlanVariants); + memoTable.addFedPlanVariants(hopID, FederatedOutput.FOUT, fOutFedPlanVariants); + } + + /** + * Enumerates federated execution plans for initial child hops only. 
+ * This method generates all possible combinations of federated output types (LOUT and FOUT) + * for the initial child hops and calculates their cumulative costs. + * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. + */ + private static void enumerateOnlyInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, int numInitInputs, List<Hop> childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + // Iterate 2^n times, generating two FedPlans (LOUT, FOUT) each time. + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>(); + // LOUT and FOUT share the same planChilds in each iteration (only forwarding cost differs). + enumerateInitChildFedPlan(numInitInputs, childHops, planChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, planChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, planChilds)); + } + } + + /** + * Enumerates federated execution plans for a TRead hop. + * This method calculates the cumulative costs for both LOUT and FOUT federated output types + * by considering the additional child hops, which are TWrite hops. + * It generates all possible combinations of federated output types for the initial child hops + * and adds the pre-calculated costs of the TWrite child hops to these combinations. 
+ * + * @param lOutFedPlanVariants The FedPlanVariants object for LOUT output type. + * @param fOutFedPlanVariants The FedPlanVariants object for FOUT output type. + * @param numInitInputs The number of initial input hops. + * @param numInputs The total number of input hops, including additional TWrite hops. + * @param childHops The list of child hops. + * @param childCumulativeCost The cumulative costs for each child hop. + * @param childForwardingCost The forwarding costs for each child hop. + * @param selfCost The self cost of the current hop. + */ + private static void enumerateTReadInitChildFedPlan(FedPlanVariants lOutFedPlanVariants, FedPlanVariants fOutFedPlanVariants, + int numInitInputs, int numInputs, List<Hop> childHops, + double[][] childCumulativeCost, double[] childForwardingCost, double selfCost){ + double lOutTReadCumulativeCost = selfCost; + double fOutTReadCumulativeCost = selfCost; + + List<Pair<Long, FederatedOutput>> lOutTReadPlanChilds = new ArrayList<>(); + List<Pair<Long, FederatedOutput>> fOutTReadPlanChilds = new ArrayList<>(); + + // Pre-calculate the cost for the additional child hop, which is a TWrite hop, of the TRead hop. + // Constraint: TWrite must have the same FedOutType as TRead. + for (int j = numInitInputs; j < numInputs; j++) { + Hop inputHop = childHops.get(j); + lOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.LOUT)); + fOutTReadPlanChilds.add(Pair.of(inputHop.getHopID(), FederatedOutput.FOUT)); + + lOutTReadCumulativeCost += childCumulativeCost[j][0]; + fOutTReadCumulativeCost += childCumulativeCost[j][1]; + // Skip TWrite -> TRead as they have the same FedOutType. 
+ } + + for (int i = 0; i < (1 << numInitInputs); i++) { + double[] cumulativeCost = new double[]{selfCost, selfCost}; + List<Pair<Long, FederatedOutput>> lOutPlanChilds = new ArrayList<>(); + enumerateInitChildFedPlan(numInitInputs, childHops, lOutPlanChilds, childCumulativeCost, childForwardingCost, cumulativeCost, i); + + // Copy lOutPlanChilds to create fOutPlanChilds and add the pre-calculated cost of the TWrite child hop. + List<Pair<Long, FederatedOutput>> fOutPlanChilds = new ArrayList<>(lOutPlanChilds); + + lOutPlanChilds.addAll(lOutTReadPlanChilds); + fOutPlanChilds.addAll(fOutTReadPlanChilds); + + cumulativeCost[0] += lOutTReadCumulativeCost; + cumulativeCost[1] += fOutTReadCumulativeCost; + + lOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[0], lOutFedPlanVariants, lOutPlanChilds)); + fOutFedPlanVariants.addFedPlan(new FedPlan(cumulativeCost[1], fOutFedPlanVariants, fOutPlanChilds)); + } + } + + // Calculates costs for initial child hops, determining FOUT or LOUT based on `i`. + private static void enumerateInitChildFedPlan(int numInitInputs, List<Hop> childHops, List<Pair<Long, FederatedOutput>> planChilds, + double[][] childCumulativeCost, double[] childForwardingCost, double[] cumulativeCost, int i){ + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInitInputs; j++) { + Hop inputHop = childHops.get(j); + // Calculate the bit value to decide between FOUT and LOUT for the current input + final int bit = (i & (1 << j)) != 0 ? 1 : 0; // Determine the bit value (decides FOUT/LOUT) + final FederatedOutput childType = (bit == 1) ? 
FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + + // Update the cumulative cost for LOUT, FOUT + cumulativeCost[0] += childCumulativeCost[j][bit] + childForwardingCost[j] * bit; + cumulativeCost[1] += childCumulativeCost[j][bit] + childForwardingCost[j] * (1 - bit); + } + } + + // Creates a dummy root node (fedplan) and selects the FedPlan with the minimum cost to return. + // The dummy root node does not have LOUT or FOUT. + private static FedPlan getMinCostRootFedPlan(Set<Hop> progRootHopSet, FederatedMemoTable memoTable) { + double cumulativeCost = 0; + List<Pair<Long, FederatedOutput>> rootFedPlanChilds = new ArrayList<>(); + + // Iterate over each Hop in the progRootHopSet + for (Hop endHop : progRootHopSet){ + // Retrieve the pruned FedPlan for LOUT and FOUT from the memo table + FedPlan lOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.LOUT); + FedPlan fOutFedPlan = memoTable.getFedPlanAfterPrune(endHop.getHopID(), FederatedOutput.FOUT); + + // Compare the cumulative costs of LOUT and FOUT FedPlans + if (lOutFedPlan.getCumulativeCost() <= fOutFedPlan.getCumulativeCost()){ + cumulativeCost += lOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.LOUT)); + } else{ + cumulativeCost += fOutFedPlan.getCumulativeCost(); + rootFedPlanChilds.add(Pair.of(endHop.getHopID(), FederatedOutput.FOUT)); + } + } + + return new FedPlan(cumulativeCost, null, rootFedPlanChilds); + } + + /** + * Detects and resolves conflicts in federated plans starting from the root plan. + * This function performs a breadth-first search (BFS) to traverse the federated plan tree. + * It identifies conflicts where the same plan ID has different federated output types. + * For each conflict, it records the plan ID and its conflicting parent plans. 
+ * The function ensures that each plan ID is associated with a consistent federated output type + * by resolving these conflicts iteratively. + * + * The process involves: + * - Using a map to track conflicts, associating each plan ID with its federated output type + * and a list of parent plans. + * - Storing detected conflicts in a linked map, each entry containing a plan ID and its + * conflicting parent plans. + * - Performing BFS traversal starting from the root plan, checking each child plan for conflicts. + * - If a conflict is detected (i.e., a plan ID has different output types), the conflicting plan + * is removed from the BFS queue and added to the conflict map to prevent duplicate calculations. + * - Resolving conflicts by ensuring a consistent federated output type across the plan. + * - Re-running BFS with resolved conflicts to ensure all inconsistencies are addressed. + * + * @param rootPlan The root federated plan from which to start the conflict detection. + * @param memoTable The memoization table used to retrieve pruned federated plans. + * @return The cumulative additional cost for resolving conflicts. 
+ */ + private static double detectAndResolveConflictFedPlan(FedPlan rootPlan, FederatedMemoTable memoTable) { + // Map to track conflicts: maps a plan ID to the federated output type seen on its first visit and the list of parent plans + Map<Long, Pair<FederatedOutput, List<FedPlan>>> conflictCheckMap = new HashMap<>(); + + // LinkedMap to store detected conflicts, each with a plan ID and its conflicting parent plans + LinkedHashMap<Long, List<FedPlan>> conflictLinkedMap = new LinkedHashMap<>(); + + // LinkedMap for BFS traversal starting from the root plan; it is used as an insertion-ordered queue + // that also supports removal by key, so the Boolean value is a placeholder and is never read + LinkedHashMap<FedPlan, Boolean> bfsLinkedMap = new LinkedHashMap<>(); + bfsLinkedMap.put(rootPlan, true); + + // Single-element array so that resolveConflictFedPlan can accumulate the additional cost in place + double[] cumulativeAdditionalCost = new double[]{0.0}; + + while (!bfsLinkedMap.isEmpty()) { + // Outer loop: alternate between conflict detection (inner BFS) and conflict resolution until no plans remain to traverse + // Perform BFS to detect conflicts in federated plans + while (!bfsLinkedMap.isEmpty()) { + FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next(); + bfsLinkedMap.remove(currentPlan); + + // Iterate over each child plan of the current plan + for (Pair<Long, FederatedOutput> childPlanPair : currentPlan.getChildFedPlans()) { + FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair); + + // Check if the child plan ID is already visited + if (conflictCheckMap.containsKey(childPlanPair.getLeft())) { + // Retrieve the existing conflict pair for the child plan + Pair<FederatedOutput, List<FedPlan>> conflictChildPlanPair = conflictCheckMap.get(childPlanPair.getLeft()); + // Add the current plan to the list of parent plans + conflictChildPlanPair.getRight().add(currentPlan); + + // If the federated output type differs, a conflict is detected + if (conflictChildPlanPair.getLeft() != childPlanPair.getRight()) { + // NOTE(review): childFedPlan was looked up with the NEW output type (childPlanPair), whereas the plan enqueued on the + // first visit had the previously recorded type (conflictChildPlanPair.getLeft()); unless FedPlan equality ignores the + // output type, bfsLinkedMap.remove(childFedPlan) below may not remove the queued plan -- confirm. + // If this is the first detection, remove conflictChildFedPlan from the BFS queue and add it to the conflict linked map (queue) + // If the existing FedPlan is not removed from the bfsqueue or both actions are 
performed, duplicate calculations for the same FedPlan and its children occur + if (!conflictLinkedMap.containsKey(childPlanPair.getLeft())) { + conflictLinkedMap.put(childPlanPair.getLeft(), conflictChildPlanPair.getRight()); + bfsLinkedMap.remove(childFedPlan); + } + } + } else { + // First visit of this plan ID: create a new entry in the conflict check map + List<FedPlan> parentFedPlanList = new ArrayList<>(); + parentFedPlanList.add(currentPlan); + + // Map the child plan ID to its output type and list of parent plans + conflictCheckMap.put(childPlanPair.getLeft(), new ImmutablePair<>(childPlanPair.getRight(), parentFedPlanList)); + // Add the child plan to the BFS queue + bfsLinkedMap.put(childFedPlan, true); + } + } + } + // Resolve these conflicts to ensure a consistent federated output type across the plan + // Re-run BFS from the plans returned by the resolution; the returned map is built from the conflict map, + // so the outer loop ends once a BFS pass detects no conflicts + bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost); + conflictLinkedMap.clear(); + } + + // Return the cumulative additional cost for resolving conflicts + return cumulativeAdditionalCost[0]; + } + } + \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index 7bc7339563..9ff405ab28 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -17,224 +17,248 @@ * under the License. 
*/ -package org.apache.sysds.hops.fedplanner; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.cost.ComputeCost; -import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; -import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - -import java.util.LinkedHashMap; -import java.util.NoSuchElementException; -import java.util.List; -import java.util.Map; - -/** - * Cost estimator for federated execution plans. - * Calculates computation, memory access, and network transfer costs for federated operations. - * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. - */ -public class FederatedPlanCostEstimator { - // Default value is used as a reasonable estimate since we only need - // to compare relative costs between different federated plans - // Memory bandwidth for local computations (25 GB/s) - private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; - // Network bandwidth for data transfers between federated sites (1 Gbps) - private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; - - /** - * Computes total cost of federated plan by: - * 1. Computing current node cost (if not cached) - * 2. Adding minimum-cost child plans - * 3. 
Including network transfer costs when needed - * - * @param currentPlan Plan to compute cost for - * @param memoTable Table containing all plan variants - */ - public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { - double totalCost; - Hop currentHop = currentPlan.getHopRef(); - - // Step 1: Calculate current node costs if not already computed - if (currentPlan.getSelfCost() == 0) { - // Compute cost for current node (computation + memory access) - totalCost = computeCurrentCost(currentHop); - currentPlan.setSelfCost(totalCost); - // Calculate potential network transfer cost if federation type changes - currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); - } else { - totalCost = currentPlan.getSelfCost(); - } - - // Step 2: Process each child plan and add their costs - for (Pair<Long, FederatedOutput> childPlanPair : currentPlan.getChildFedPlans()) { - // Find minimum cost child plan considering federation type compatibility - // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents - // because we're selecting child plans independently for each parent - FedPlan planRef = memoTable.getMinCostFedPlan(childPlanPair); - - // Add child plan cost (includes network transfer cost if federation types differ) - totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); - } - - // Step 3: Set final cumulative cost including current node - currentPlan.setTotalCost(totalCost); - } - - /** - * Resolves conflicts in federated plans where different plans have different FederatedOutput types. - * This function traverses the list of conflicting plans in reverse order to ensure that conflicts - * are resolved from the bottom-up, allowing for consistent federated output types across the plan. - * It calculates additional costs for each potential resolution and updates the cumulative additional cost. 
- * - * @param memoTable The FederatedMemoTable containing all federated plan variants. - * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. - * @param cumulativeAdditionalCost An array to store the cumulative additional cost incurred by resolving conflicts. - * @return A LinkedHashMap of resolved federated plans, marked with a boolean indicating resolution status. - */ - public static LinkedHashMap<FedPlan, Boolean> resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long, List<FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { - // LinkedHashMap to store resolved federated plans for BFS traversal. - LinkedHashMap<FedPlan, Boolean> resolvedFedPlanLinkedMap = new LinkedHashMap<>(); - - // Traverse the conflictFedPlanList in reverse order after BFS to resolve conflicts - for (Map.Entry<Long, List<FedPlan>> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { - long conflictHopID = conflictFedPlanPair.getKey(); - List<FedPlan> conflictParentFedPlans = conflictFedPlanPair.getValue(); - - // Retrieve the conflicting federated plans for LOUT and FOUT types - FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); - FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); - - // Variables to store additional costs for LOUT and FOUT types - double lOutAdditionalCost = 0; - double fOutAdditionalCost = 0; - - // Flags to check if the plan involves network transfer - // Network transfer cost is calculated only once, even if it occurs multiple times - boolean isLOutNetTransfer = false; - boolean isFOutNetTransfer = false; - - // Determine the optimal federated output type based on the calculated costs - FederatedOutput optimalFedOutType; - - // Iterate over each parent federated plan in the current conflict pair - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - // Find 
the calculated FedOutType of the child plan - Pair<Long, FederatedOutput> cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() - .filter(pair -> pair.getLeft().equals(conflictHopID)) - .findFirst() - .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); - - // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. - // CASE 2. Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. - // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. - // CASE 5. Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. - // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. - // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. - // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. - - // Adjust LOUT, FOUT costs based on the calculated plan's output type - if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { - // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost - // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
- fOutAdditionalCost += confilctFOutFedPlan.getTotalCost() - confilctLOutFedPlan.getTotalCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { - // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred - // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later - isFOutNetTransfer = true; - } else { - // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred - // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later - isLOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - - // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it - fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - } - } else { - lOutAdditionalCost += confilctLOutFedPlan.getTotalCost() - confilctFOutFedPlan.getTotalCost(); - - if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { - isLOutNetTransfer = true; - } else { - isFOutNetTransfer = true; - lOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - fOutAdditionalCost -= confilctLOutFedPlan.getNetTransferCost(); - } - } - } - - // Add network transfer costs if applicable - if (isLOutNetTransfer) { - lOutAdditionalCost += confilctLOutFedPlan.getNetTransferCost(); - } - if (isFOutNetTransfer) { - fOutAdditionalCost += confilctFOutFedPlan.getNetTransferCost(); - } - - // Determine the optimal federated output type based on the calculated costs - if (lOutAdditionalCost <= fOutAdditionalCost) { - optimalFedOutType = FederatedOutput.LOUT; - cumulativeAdditionalCost[0] += lOutAdditionalCost; - resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); - } else { - optimalFedOutType = FederatedOutput.FOUT; - cumulativeAdditionalCost[0] += fOutAdditionalCost; - 
resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); - } - - // Update only the optimal federated output type, not the cost itself or recursively - for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { - for (Pair<Long, FederatedOutput> childPlanPair : conflictParentFedPlan.getChildFedPlans()) { - if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { - int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); - conflictParentFedPlan.getChildFedPlans().set(index, - Pair.of(childPlanPair.getLeft(), optimalFedOutType)); - break; - } - } - } - } - return resolvedFedPlanLinkedMap; - } - - /** - * Computes the cost for the current Hop node. - * - * @param currentHop The Hop node whose cost needs to be computed - * @return The total cost for the current node's operation - */ - private static double computeCurrentCost(Hop currentHop){ - double computeCost = ComputeCost.getHOPComputeCost(currentHop); - double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); - double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); - - // Compute total cost assuming: - // 1. Computation and input access can be overlapped (hence taking max) - // 2. Output access must wait for both to complete (hence adding) - return Math.max(computeCost, inputAccessCost) + ouputAccessCost; - } - - /** - * Calculates the memory access cost based on data size and memory bandwidth. - * - * @param memSize Size of data to be accessed (in bytes) - * @return Time cost for memory access (in seconds) - */ - private static double computeHopMemoryAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; - } - - /** - * Calculates the network transfer cost based on data size and network bandwidth. - * Used when federation status changes between parent and child plans. 
- * - * @param memSize Size of data to be transferred (in bytes) - * @return Time cost for network transfer (in seconds) - */ - private static double computeHopNetworkAccessCost(double memSize) { - return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; - } -} + package org.apache.sysds.hops.fedplanner; + import org.apache.commons.lang3.tuple.Pair; + import org.apache.sysds.common.Types; + import org.apache.sysds.hops.DataOp; + import org.apache.sysds.hops.Hop; + import org.apache.sysds.hops.cost.ComputeCost; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; + import org.apache.sysds.hops.fedplanner.FederatedMemoTable.HopCommon; + import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + + import java.util.LinkedHashMap; + import java.util.NoSuchElementException; + import java.util.List; + import java.util.Map; + + /** + * Cost estimator for federated execution plans. + * Calculates computation, memory access, and network transfer costs for federated operations. + * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. 
+ */ + public class FederatedPlanCostEstimator { + // Default value is used as a reasonable estimate since we only need + // to compare relative costs between different federated plans + // Memory bandwidth for local computations in MB/s (25,000 MB/s, i.e. roughly 25 GB/s) + private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; + // Network bandwidth for data transfers between federated sites in MB/s (125 MB/s, i.e. 1 Gbps) + private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + + // Retrieves the cumulative and forwarding costs of the child hops and stores them in the caller-provided arrays: + // childCumulativeCost[i][0] / childCumulativeCost[i][1] receive the cumulative cost of the LOUT / FOUT plan of input i, + // childForwardingCost[i] receives the forwarding cost of input i multiplied by the weight of the parent (hopCommon). + // Both arrays must be pre-allocated with at least inputHops.size() entries; nothing is returned. + public static void getChildCosts(HopCommon hopCommon, FederatedMemoTable memoTable, List<Hop> inputHops, + double[][] childCumulativeCost, double[] childForwardingCost) { + for (int i = 0; i < inputHops.size(); i++) { + long childHopID = inputHops.get(i).getHopID(); + + FedPlan childLOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.LOUT); + FedPlan childFOutFedPlan = memoTable.getFedPlanAfterPrune(childHopID, FederatedOutput.FOUT); + + // The cumulative cost of the child already includes the weight + childCumulativeCost[i][0] = childLOutFedPlan.getCumulativeCost(); + childCumulativeCost[i][1] = childFOutFedPlan.getCumulativeCost(); + + // The forwarding cost is read from the LOUT plan only (presumably it is identical for both plans of a hop -- confirm) + // TODO: Q. Shouldn't the child's forwarding cost follow the parent's weight, regardless of loops or if-else statements? + childForwardingCost[i] = hopCommon.weight * childLOutFedPlan.getForwardingCost(); + } + } + + /** + * Computes the cost associated with a given Hop node. + * This method calculates both the self cost and the forwarding cost for the Hop, + * taking into account its type and the number of parent nodes. + * As a side effect, both costs are stored on hopCommon via setSelfCost and setForwardingCost. + * + * @param hopCommon The HopCommon object containing the Hop and its properties. + * @return The self cost of the Hop (0 for transient reads and transient writes). 
+ */ + public static double computeHopCost(HopCommon hopCommon){ + // TWrite and TRead are meta-data operations, hence selfCost is zero + if (hopCommon.hopRef instanceof DataOp){ + if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTWRITE ){ + hopCommon.setSelfCost(0); + // Since TWrite and TRead have the same FedOutType, forwarding cost is zero + hopCommon.setForwardingCost(0); + return 0; + } else if (((DataOp)hopCommon.hopRef).getOp() == Types.OpOpData.TRANSIENTREAD) { + hopCommon.setSelfCost(0); + // TRead may have a different FedOutType from its parent, so calculate forwarding cost + // (this value is not divided by the number of parents, unlike the general case below) + hopCommon.setForwardingCost(computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate())); + return 0; + } + } + + // In loops, selfCost is repeated, but forwarding may not be + // Therefore, the weight for forwarding follows the parent's weight (TODO: Q. Is the parent also receiving forwarding once?) + double selfCost = hopCommon.weight * computeSelfCost(hopCommon.hopRef); + double forwardingCost = computeHopForwardingCost(hopCommon.hopRef.getOutputMemEstimate()); + + // Split both costs evenly across the parents when the hop has two or more of them + // (presumably so that a hop shared by several parents is not charged once per parent -- confirm) + int numParents = hopCommon.hopRef.getParent().size(); + if (numParents >= 2) { + selfCost /= numParents; + forwardingCost /= numParents; + } + + hopCommon.setSelfCost(selfCost); + hopCommon.setForwardingCost(forwardingCost); + + return selfCost; + } + + /** + * Computes the unweighted self cost of a Hop node: + * max(compute cost, input memory access cost) + output memory access cost. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeSelfCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); + + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. 
Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. + * The size is converted from bytes to MB (division by 1024*1024) and divided by the bandwidth in MB/s. + * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. + * The size is converted from bytes to MB (division by 1024*1024) and divided by the bandwidth in MB/s. + * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopForwardingCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; + } + + /** + * Resolves conflicts in federated plans where different plans have different FederatedOutput types. + * This function visits the conflicting plans in the iteration (insertion) order of the given map and, + * for each of them, selects a single federated output type (LOUT or FOUT) that is then written into + * the child list of every parent, allowing for consistent federated output types across the plan. + * NOTE(review): earlier documentation described a reverse, bottom-up traversal; the loop iterates entrySet() forward -- confirm the intent. + * It calculates additional costs for each potential resolution and updates the cumulative additional cost. + * + * @param memoTable The FederatedMemoTable containing all federated plan variants. + * @param conflictFedPlanLinkedMap A map of plan IDs to lists of parent plans with conflicting federated outputs. + * @param cumulativeAdditionalCost A single-element array; element 0 is incremented by the additional cost of every resolved conflict. + * @return A LinkedHashMap whose keys are the selected (resolved) federated plans; every value is true. 
+ */ + public static LinkedHashMap<FedPlan, Boolean> resolveConflictFedPlan(FederatedMemoTable memoTable, LinkedHashMap<Long, List<FedPlan>> conflictFedPlanLinkedMap, double[] cumulativeAdditionalCost) { + // LinkedHashMap to store resolved federated plans for BFS traversal. + LinkedHashMap<FedPlan, Boolean> resolvedFedPlanLinkedMap = new LinkedHashMap<>(); + + // Visit the detected conflicts in the insertion order of the map (entrySet() is iterated forward, not in reverse) + for (Map.Entry<Long, List<FedPlan>> conflictFedPlanPair : conflictFedPlanLinkedMap.entrySet()) { + long conflictHopID = conflictFedPlanPair.getKey(); + List<FedPlan> conflictParentFedPlans = conflictFedPlanPair.getValue(); + + // Retrieve the conflicting federated plans for LOUT and FOUT types + FedPlan confilctLOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.LOUT); + FedPlan confilctFOutFedPlan = memoTable.getFedPlanAfterPrune(conflictHopID, FederatedOutput.FOUT); + + // Variables to store additional costs for LOUT and FOUT types + double lOutAdditionalCost = 0; + double fOutAdditionalCost = 0; + + // Flags to check if the plan involves network transfer + // Network transfer cost is calculated only once, even if it occurs multiple times + boolean isLOutForwarding = false; + boolean isFOutForwarding = false; + + // Holds the output type selected for this conflict; assigned after the cost comparison further below + FederatedOutput optimalFedOutType; + + // Iterate over each parent federated plan in the current conflict pair + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { + // Find the calculated FedOutType of the child plan + Pair<Long, FederatedOutput> cacluatedConflictPlanPair = conflictParentFedPlan.getChildFedPlans().stream() + .filter(pair -> pair.getLeft().equals(conflictHopID)) + .findFirst() + .orElseThrow(() -> new NoSuchElementException("No matching pair found for ID: " + conflictHopID)); + + // CASE 1. Calculated LOUT / Parent LOUT / Current LOUT: Total cost remains unchanged. + // CASE 2. 
Calculated LOUT / Parent FOUT / Current LOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 3. Calculated FOUT / Parent LOUT / Current LOUT: Change total cost, subtract net cost. + // CASE 4. Calculated FOUT / Parent FOUT / Current LOUT: Change total cost, add net cost later. + // CASE 5. Calculated LOUT / Parent LOUT / Current FOUT: Change total cost, add net cost later. + // CASE 6. Calculated LOUT / Parent FOUT / Current FOUT: Change total cost, subtract net cost. + // CASE 7. Calculated FOUT / Parent LOUT / Current FOUT: Total cost remains unchanged, subtract net cost, add net cost later. + // CASE 8. Calculated FOUT / Parent FOUT / Current FOUT: Total cost remains unchanged. + // Legend: "Calculated" = child output type currently recorded in the parent's child list, + // "Parent" = output type of the parent plan itself, "Current" = candidate output type being evaluated. + + // Adjust LOUT, FOUT costs based on the calculated plan's output type + if (cacluatedConflictPlanPair.getRight() == FederatedOutput.LOUT) { + // When changing from calculated LOUT to current FOUT, subtract the existing LOUT total cost and add the FOUT total cost + // When maintaining calculated LOUT to current LOUT, the total cost remains unchanged. 
+ fOutAdditionalCost += confilctFOutFedPlan.getCumulativeCost() - confilctLOutFedPlan.getCumulativeCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.LOUT) { + // (CASE 1) Previously, calculated was LOUT and parent was LOUT, so no network transfer cost occurred + // (CASE 5) If changing from calculated LOUT to current FOUT, network transfer cost occurs, but calculated later + isFOutForwarding = true; + } else { + // Previously, calculated was LOUT and parent was FOUT, so network transfer cost occurred + // (CASE 2) If maintaining calculated LOUT to current LOUT, subtract existing network transfer cost and calculate later + isLOutForwarding = true; + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + + // (CASE 6) If changing from calculated LOUT to current FOUT, no network transfer cost occurs, so subtract it + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + } + } else { + // Calculated FOUT: mirror image of the branch above (CASES 3, 4, 7 and 8) + // When changing from calculated FOUT to current LOUT, subtract the existing FOUT total cost and add the LOUT total cost + lOutAdditionalCost += confilctLOutFedPlan.getCumulativeCost() - confilctFOutFedPlan.getCumulativeCost(); + + if (conflictParentFedPlan.getFedOutType() == FederatedOutput.FOUT) { + // (CASE 8) Previously, calculated was FOUT and parent was FOUT, so no network transfer cost occurred + // (CASE 4) If changing from calculated FOUT to current LOUT, network transfer cost occurs, but calculated later + isLOutForwarding = true; + } else { + // Previously, calculated was FOUT and parent was LOUT, so network transfer cost occurred + // (CASE 7) If maintaining FOUT, subtract the existing network transfer cost and add it once later + // (CASE 3) If changing to LOUT, no network transfer cost occurs any more, so subtract it + // NOTE(review): the cost incurred before belonged to the FOUT plan, yet the LOUT plan's forwarding cost is subtracted + // in both statements; this is only equivalent if both plans of a hop share the same forwarding cost -- confirm. + isFOutForwarding = true; + lOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + fOutAdditionalCost -= confilctLOutFedPlan.getForwardingCost(); + } + } + } + + // Add network transfer costs if applicable (at most once per candidate type, regardless of the number of parents) + if (isLOutForwarding) { + lOutAdditionalCost += confilctLOutFedPlan.getForwardingCost(); + } + if (isFOutForwarding) { + fOutAdditionalCost += confilctFOutFedPlan.getForwardingCost(); + } + + // Determine the optimal federated output type based on the calculated costs; on a tie LOUT is selected (<=) + if (lOutAdditionalCost <= fOutAdditionalCost) { + optimalFedOutType = FederatedOutput.LOUT; + cumulativeAdditionalCost[0] += lOutAdditionalCost; + resolvedFedPlanLinkedMap.put(confilctLOutFedPlan, true); + } else { + optimalFedOutType = FederatedOutput.FOUT; + cumulativeAdditionalCost[0] += fOutAdditionalCost; + 
+ resolvedFedPlanLinkedMap.put(confilctFOutFedPlan, true); + } + + // Rewrite the child entry of every conflicting parent to the selected output type; the costs stored on the + // parents are not recomputed here and the change is not propagated recursively. + // getLeft() == conflictHopID compares by value because the Long is unboxed against the primitive long; + // because of the break, at most one entry per parent (the first mismatching one) is replaced. + for (FedPlan conflictParentFedPlan : conflictParentFedPlans) { + for (Pair<Long, FederatedOutput> childPlanPair : conflictParentFedPlan.getChildFedPlans()) { + if (childPlanPair.getLeft() == conflictHopID && childPlanPair.getRight() != optimalFedOutType) { + int index = conflictParentFedPlan.getChildFedPlans().indexOf(childPlanPair); + conflictParentFedPlan.getChildFedPlans().set(index, + Pair.of(childPlanPair.getLeft(), optimalFedOutType)); + break; + } + } + } + } + return resolvedFedPlanLinkedMap; + } + } + \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java index 20485588d3..0bc7d9f84f 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanCostEnumeratorTest.java @@ -17,75 +17,94 @@ * under the License. 
*/ -package org.apache.sysds.test.component.federated; + package org.apache.sysds.test.component.federated; -import java.io.IOException; -import java.util.HashMap; + import java.io.IOException; + import java.util.HashMap; + import org.junit.Assert; + import org.junit.Test; + import org.apache.sysds.api.DMLScript; + import org.apache.sysds.conf.ConfigurationManager; + import org.apache.sysds.conf.DMLConfig; + import org.apache.sysds.parser.DMLProgram; + import org.apache.sysds.parser.DMLTranslator; + import org.apache.sysds.parser.ParserFactory; + import org.apache.sysds.parser.ParserWrapper; + import org.apache.sysds.test.AutomatedTestBase; + import org.apache.sysds.test.TestConfiguration; + import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + + public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase + { + private static final String TEST_DIR = "functions/federated/privacy/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; + + @Override + public void setUp() {} + + @Test + public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } + + @Test + public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } + + @Test + public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } + + @Test + public void testFederatedPlanCostEnumerator4() { runTest("FederatedPlanCostEnumeratorTest4.dml"); } + + @Test + public void testFederatedPlanCostEnumerator5() { runTest("FederatedPlanCostEnumeratorTest5.dml"); } + + @Test + public void testFederatedPlanCostEnumerator6() { runTest("FederatedPlanCostEnumeratorTest6.dml"); } + + @Test + public void testFederatedPlanCostEnumerator7() { runTest("FederatedPlanCostEnumeratorTest7.dml"); } + + @Test + public void testFederatedPlanCostEnumerator8() { 
runTest("FederatedPlanCostEnumeratorTest8.dml"); } + + @Test + public void testFederatedPlanCostEnumerator9() { runTest("FederatedPlanCostEnumeratorTest9.dml"); } -import org.apache.sysds.hops.Hop; -import org.junit.Assert; -import org.junit.Test; -import org.apache.sysds.api.DMLScript; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; -import org.apache.sysds.parser.DMLProgram; -import org.apache.sysds.parser.DMLTranslator; -import org.apache.sysds.parser.ParserFactory; -import org.apache.sysds.parser.ParserWrapper; -import org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.test.TestConfiguration; -import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + @Test + public void testFederatedPlanCostEnumerator10() { runTest("FederatedPlanCostEnumeratorTest10.dml"); } - -public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase -{ - private static final String TEST_DIR = "functions/federated/privacy/"; - private static final String HOME = SCRIPT_DIR + TEST_DIR; - private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; - - @Override - public void setUp() {} - - @Test - public void testFederatedPlanCostEnumerator1() { runTest("FederatedPlanCostEnumeratorTest1.dml"); } - - @Test - public void testFederatedPlanCostEnumerator2() { runTest("FederatedPlanCostEnumeratorTest2.dml"); } - - @Test - public void testFederatedPlanCostEnumerator3() { runTest("FederatedPlanCostEnumeratorTest3.dml"); } - - // Todo: Need to write test scripts for the federated version - private void runTest( String scriptFilename ) { - int index = scriptFilename.lastIndexOf(".dml"); - String testName = scriptFilename.substring(0, index > 0 ? 
index : scriptFilename.length()); - TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); - addTestConfiguration(testName, testConfig); - loadTestConfiguration(testConfig); - - try { - DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); - ConfigurationManager.setLocalConfig(conf); - - //read script - String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); - - //parsing and dependency analysis - ParserWrapper parser = ParserFactory.createParser(); - DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); - DMLTranslator dmlt = new DMLTranslator(prog); - dmlt.liveVariableAnalysis(prog); - dmlt.validateParseTree(prog); - dmlt.constructHops(prog); - dmlt.rewriteHopsDAG(prog); - dmlt.constructLops(prog); - - Hop hops = prog.getStatementBlocks().get(0).getHops().get(0); - FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true); - } - catch (IOException e) { - e.printStackTrace(); - Assert.fail(); - } - } -} + // TODO: Need to write test scripts for the federated version + // Smoke test helper: compiles the given DML script and runs the federated plan enumeration on the whole program. + // It fails only when an IOException is thrown; no assertions are made about the enumerated plan or its cost. + private void runTest( String scriptFilename ) { + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? 
index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + try { + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + //read script + String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); + + //parsing and dependency analysis + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); + dmlt.rewriteLopDAG(prog); + + FederatedPlanCostEnumerator.enumerateProgram(prog, true); + } + catch (IOException e) { + e.printStackTrace(); + Assert.fail(); + } + } + } + \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py new file mode 100644 index 0000000000..b083c77913 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/federated/FederatedPlanVisualizer.py @@ -0,0 +1,268 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------

"""Render a federated plan trace file as a DAG image.

Each relevant trace line is expected to look like::

    (<id>) <hop name> [<kind>] (<ref>,<ref>,...) {Total: <n>, Weight: <n>}

where ``<id>`` is either ``R`` or a number. Lines that do not start with
such an ID are ignored. The resulting graph is drawn with matplotlib,
saved next to the input file as ``<input without extension>.png`` and
then shown interactively.
"""

import os
import re
import sys

import networkx as nx
import matplotlib.pyplot as plt

try:
    import pygraphviz  # noqa: F401 -- imported only to probe availability
    from networkx.drawing.nx_agraph import graphviz_layout
    HAS_PYGRAPHVIZ = True
except ImportError:
    HAS_PYGRAPHVIZ = False
    print("[WARNING] pygraphviz not found. Please install via 'pip install pygraphviz'.\n"
          "If not installed, we will use an alternative layout (spring_layout).")


# Patterns are compiled once because parse_line runs for every trace line.
_NODE_ID_RE = re.compile(r'^\((R|\d+)\)')
_LABEL_RE = re.compile(r'^(.*?)\s*\[')
_KIND_RE = re.compile(r'\[([^\]]+)\]')
_CURLY_RE = re.compile(r'\{([^}]+)\}')
_TOTAL_RE = re.compile(r'Total:\s*([\d\.]+)')
_WEIGHT_RE = re.compile(r'Weight:\s*([\d\.]+)')
# Whitespace around the commas is accepted, so "(1, 2)" parses like "(1,2)".
_REFS_RE = re.compile(r'\(\s*(\d+(?:\s*,\s*\d+)*)\s*\)')

# Fill color per node kind (compared case-insensitively).
_KIND_COLORS = {
    'fout': 'tomato',
    'lout': 'dodgerblue',
    'nref': 'mediumpurple',
}
_DEFAULT_COLOR = 'mediumseagreen'


def parse_line(line: str):
    """
    Parse a single line from the trace file.

    Returns None when the line does not start with "(R)" or "(<number>)".
    Otherwise returns a dict with these string-valued entries:
      - 'node_id':   the ID inside the leading parentheses
      - 'operation': text between the ID and the first "[" (the whole
                     remainder of the line if there is no "[")
      - 'kind':      content of the first "[...]" (e.g., FOUT, LOUT, NREF),
                     or "" if absent
      - 'total':     number following "Total:" inside the first "{...}",
                     or "" if absent
      - 'weight':    number following "Weight:" inside the first "{...}",
                     or "" if absent
      - 'refs':      list of IDs (strings) from the first parenthesized,
                     comma-separated list of numbers after the node ID
    """
    # 1) Node ID in the form "(R)" or "(<number>)" at the very start.
    id_match = _NODE_ID_RE.match(line)
    if not id_match:
        return None
    node_id = id_match.group(1)

    # 2) Everything after the node ID; the ID is cut off so that it cannot
    #    be mistaken for the reference list in step 5.
    after_id = line[id_match.end():].strip()

    # Operation (hop name): text before the first "[".
    label_match = _LABEL_RE.search(after_id)
    operation = label_match.group(1).strip() if label_match else after_id.strip()

    # 3) Kind: content of the first "[...]".
    kind_match = _KIND_RE.search(after_id)
    kind = kind_match.group(1).strip() if kind_match else ""

    # 4) Total and weight from the first "{...}".
    total = ""
    weight = ""
    curly_match = _CURLY_RE.search(line)
    if curly_match:
        curly_content = curly_match.group(1)
        total_match = _TOTAL_RE.search(curly_content)
        weight_match = _WEIGHT_RE.search(curly_content)
        if total_match:
            total = total_match.group(1)
        if weight_match:
            weight = weight_match.group(1)

    # 5) References: first parenthesized list of numbers after the hop name.
    refs_match = _REFS_RE.search(after_id)
    if refs_match:
        refs = [ref.strip() for ref in refs_match.group(1).split(',')]
    else:
        refs = []

    return {
        'node_id': node_id,
        'operation': operation,
        'kind': kind,
        'total': total,
        'weight': weight,
        'refs': refs
    }


def build_dag_from_file(filename: str):
    """
    Read a trace file line by line and build a directed graph using NetworkX.

    Every line accepted by parse_line becomes a node carrying the attributes
    'label', 'kind', 'total' and 'weight'. For each referenced ID an edge
    reference -> node is added. Blank lines and lines that parse_line rejects
    are skipped.
    """
    G = nx.DiGraph()
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            info = parse_line(line)
            if not info:
                continue

            node_id = info['node_id']

            # Add the node, or fill in the attributes of a placeholder that
            # was created earlier because another node referenced this ID.
            G.add_node(node_id, label=info['operation'], kind=info['kind'],
                       total=info['total'], weight=info['weight'])

            # Add edges from references to this node.
            for ref in info['refs']:
                if ref not in G:
                    # Placeholder until (and unless) the referenced node's
                    # own line is read.
                    G.add_node(ref, label=ref, kind="", total="", weight="")
                G.add_edge(ref, node_id)
    return G


def _kind_color(kind):
    """Return the fill color for a node of the given kind."""
    return _KIND_COLORS.get(kind.lower(), _DEFAULT_COLOR)


def _compute_layout(G):
    """Return node positions: graphviz 'dot' if available, else spring layout."""
    if HAS_PYGRAPHVIZ:
        # rankdir=BT lays the graph out bottom to top.
        return graphviz_layout(G, prog='dot', args='-Grankdir=BT -Gnodesep=0.5 -Granksep=0.8')
    # Fixed seed keeps the fallback layout reproducible.
    return nx.spring_layout(G, seed=42)


def _draw_nodes(G, pos, ax):
    """Draw all nodes, with the marker shape chosen from the operation name.

    - '^' (triangle) if the label contains "twrite"
    - 's' (square)   if the label contains "tread"
    - 'o' (circle)   otherwise
    """
    lowered = {n: G.nodes[n].get('label', '').lower() for n in G.nodes()}
    shape_groups = (
        ('^', [n for n, lbl in lowered.items() if 'twrite' in lbl]),
        ('s', [n for n, lbl in lowered.items() if 'tread' in lbl]),
        ('o', [n for n, lbl in lowered.items()
               if 'twrite' not in lbl and 'tread' not in lbl]),
    )
    for shape, nodes in shape_groups:
        if not nodes:
            # Nothing to draw; also avoids relying on what
            # draw_networkx_nodes returns for an empty node list.
            continue
        colors = [_kind_color(G.nodes[n].get('kind', '')) for n in nodes]
        collection = nx.draw_networkx_nodes(
            G, pos, nodelist=nodes, node_size=800,
            node_color=colors, node_shape=shape, ax=ax
        )
        if collection is not None:
            collection.set_zorder(1)


def _draw_legend(ax):
    """Draw the label-format note and the color legend in axes coordinates."""
    ax.text(1, 1,
            "[LABEL]\n hopID: hopName\n C(Total) (W(Weight))",
            fontsize=12, ha='right', va='top', transform=ax.transAxes)

    # Mini-legend for the different 'kind' values.
    legend_entries = (
        (0.05, 'dodgerblue', "LOUT"),
        (0.18, 'tomato', "FOUT"),
        (0.31, 'mediumpurple', "NREF"),
    )
    for x, color, _ in legend_entries:
        ax.scatter(x, 0.95, color=color, s=200, transform=ax.transAxes)
    for x, _, text in legend_entries:
        ax.text(x + 0.03, 0.95, text, fontsize=12, va='center', transform=ax.transAxes)


def main():
    """
    Main function that:
      - Reads a filename from command-line arguments
      - Builds a DAG from the file
      - Draws the DAG, saves it as a PNG next to the input file,
        and displays it using matplotlib

    Exits with status 1 if no filename is given or if the file contains
    no parseable plan nodes.
    """
    # Get filename from command-line argument.
    if len(sys.argv) < 2:
        script_name = os.path.basename(sys.argv[0]) or "FederatedPlanVisualizer.py"
        print(f"[ERROR] No filename provided.\nUsage: python {script_name} <filename>")
        sys.exit(1)
    filename = sys.argv[1]

    print(f"[INFO] Running with filename '{filename}'")

    # Build the DAG.
    G = build_dag_from_file(filename)
    if G.number_of_nodes() == 0:
        # An empty plot is never what the user wants; report it instead.
        print(f"[ERROR] No plan nodes could be parsed from '{filename}'.")
        sys.exit(1)

    # Print debug info: nodes and edges.
    print("Nodes:", G.nodes(data=True))
    print("Edges:", list(G.edges()))

    pos = _compute_layout(G)

    # Grow the figure with the number of nodes so large plans stay readable.
    node_count = G.number_of_nodes()
    fig_width = 10 + node_count / 10.0
    fig_height = 6 + node_count / 10.0
    plt.figure(figsize=(fig_width, fig_height), facecolor='white', dpi=300)
    ax = plt.gca()
    ax.set_facecolor('white')

    # Node labels in the format:
    #   node_id: operation_name
    #   C<total> (W<weight>)
    labels = {
        n: f"{n}: {G.nodes[n].get('label', n)}\n C{G.nodes[n].get('total', '')} (W{G.nodes[n].get('weight', '')})"
        for n in G.nodes()
    }

    # z-order: nodes (1) below edges (2) below labels (3).
    _draw_nodes(G, pos, ax)

    edge_collection = nx.draw_networkx_edges(G, pos, arrows=True, arrowstyle='->', ax=ax)
    # Depending on the NetworkX version this is a list of arrow patches
    # or a single collection.
    if isinstance(edge_collection, list):
        for patch in edge_collection:
            patch.set_zorder(2)
    elif edge_collection is not None:
        edge_collection.set_zorder(2)

    label_dict = nx.draw_networkx_labels(G, pos, labels=labels, font_size=9, ax=ax)
    for text in label_dict.values():
        text.set_zorder(3)

    plt.title("Program Level Federated Plan", fontsize=14, fontweight="bold")
    _draw_legend(ax)
    plt.axis("off")

    # Save next to the input file with a .png extension. splitext only
    # strips an extension of the final path component, so dots in directory
    # names (e.g. "./trace" or "../run/trace") are left alone.
    output_filename = os.path.splitext(filename)[0] + ".png"
    plt.savefig(output_filename, format='png', dpi=300, bbox_inches='tight')

    plt.show()


if __name__ == '__main__':
    main()
--git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml new file mode 100644 index 0000000000..276de7bde9 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest10.dml @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# Recursive function: Calculate factorial +factorialUser = function(int n) return (int result) { + if (n <= 1) { + result = 1; # base case + } else { + result = n * factorialUser(n - 1); # recursive call + } +} + +number = 5; +fact_result = factorialUser(number); +print("Factorial of " + number + ": " + fact_result); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml new file mode 100644 index 0000000000..06533df144 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest4.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +a = matrix(7,10,10); +if (sum(a) > 0.5) + b = a * 2; +else + b = a * 3; +c = sqrt(b); +print(sum(c)); \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml new file mode 100644 index 0000000000..2721bbcbaf --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest5.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +for( i in 1:100 ) +{ + b = i + 1; + print(b); +} \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml new file mode 100644 index 0000000000..b95ae1b5bb --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest6.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +A = matrix(7, rows=10, cols=10) +b = rand(rows = 1, cols = ncol(A), min = 1, max = 2); +i = 0 + +while (sum(b) < i) { + i = i + 1 + b = b + i + A = A * A + s = b %*% A + print(mean(s)) +} +c = sqrt(A) +print(sum(c)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml new file mode 100644 index 0000000000..e3efaa2851 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest7.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +a = 1; + +parfor( i in 1:10 ) +{ + b = i + a; + #print(b); +} diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml new file mode 100644 index 0000000000..1587ff613b --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest8.dml @@ -0,0 +1,49 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +a = rand(); +b= rand(); +c= rand(); +d= rand(); +e= rand(); +f= rand(); +h= rand(); +i= rand(); + +if (a < 30){ + a = a + b; + + if (a < 20) { + a = a * c; + } else { + a = a + d; + + if (a < 10) { + a = a + e; + } else { + a = a + f; + } + } +} else { + a = a + h; +} +c = a + i; +print(mean(c)) \ No newline at end of file diff --git a/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml new file mode 100644 index 0000000000..b5713374f2 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/FederatedPlanCostEnumeratorTest9.dml @@ -0,0 +1,58 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + + +# Define UDFs +meanUser = function (matrix[double] A) return (double m) { + m = sum(A)/nrow(A) +} + +minMaxUser = function( matrix[double] M) return (double minVal, double maxVal) { + minVal = min(M); + maxVal = max(M); +} + +# Recursive function: Calculate factorial +factorialUser = function(int n) return (int result) { + if (n <= 1) { + result = 1; # base case + } else { + result = n * factorialUser(n - 1); # recursive call + } +} + +# Main script +# 1. Create matrix and calculate statistics +M = rand(rows=4, cols=4, min=1, max=5); # 4x4 random matrix +avg = meanUser(M); +[min_val, max_val] = minMaxUser(M); + +# 2. Call recursive function (factorial) +number = 5; +fact_result = factorialUser(number); + +# 3. Print results +print("=== Matrix Statistics ==="); +print("Average: " + avg); +print("Min: " + min_val + ", Max: " + max_val); + +print("\n=== Recursive Function ==="); +print("Factorial of " + number + ": " + fact_result); \ No newline at end of file