min-guk opened a new pull request, #2147:
URL: https://github.com/apache/systemds/pull/2147
This implementation is based on the newly implemented FederatedPlanCostEstimator and FederatedMemoTable, following the direction we previously discussed.

## 1. FederatedMemoTable (MemoTable)

```java
public class MemoTable {
    private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable = new HashMap<>();

    public static class FedPlan {
        @SuppressWarnings("unused")
        private final Hop hopRef;                        // The associated Hop object
        private final double cost;                       // Cost of this federated plan
        @SuppressWarnings("unused")
        private final List<Pair<Long, FType>> planRefs;  // References to dependent plans
    }
}
```

The previous FedPlan class structure had several issues:
- Each `<HopID, FederatedOutput>` pair stored multiple FedPlans as a list in the MemoTable, so the same hopRef was stored redundantly in every FedPlan.
- The computeCost and accessCost of the same Hop were recalculated redundantly for every one of its FedPlans, i.e., `2^(planRefs+1)` times.
- FedPlan did not store its own FederatedOutput.

```java
public class FederatedMemoTable {
    private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

    public static class FedPlanVariants {
        protected final Hop hopRef;               // Reference to the associated Hop
        protected double currentCost;             // Current execution cost (compute + memory access)
        protected double netTransferCost;         // Network transfer cost
        protected List<FedPlan> _fedPlanVariants; // All FedPlans sharing this Hop and output type
    }

    public static class FedPlan {
        private double cumulativeCost;                               // Total cost including child plans
        private final FederatedOutput fedOutType;                    // Output type (FOUT/LOUT)
        private final FedPlanVariants fedPlanVariantList;            // Reference to variant list
        private List<Pair<Long, FederatedOutput>> metaChildFedPlans; // Child plan references
        private List<FedPlan> selectedFedPlans;                      // Selected child plans
    }
}
```

The key points of the redesigned FederatedMemoTable are as follows:
- Each `<HopID, FederatedOutput>` pair has exactly one FedPlanVariants, which stores the hopRef, currentCost, and netTransferCost once and shares them with all FedPlans in `_fedPlanVariants`.
- Each `<HopID, FederatedOutput>` pair calculates its computeCost and accessCost only once.
- FedPlan stores its own FederatedOutput.

## 2. CostEstimator

```java
// Do not create or allocate any new FedPlan;
// just calculate the cost for the given FedPlans.
// The cost of the dependent FedPlans in planRefs has already been calculated.
public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
    double cost = computeFederatedPlanCost(currentPlan.getHopRef());
    for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getPlanRefs()) {
        FedPlan planRef = memoTable.getFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
        cost += planRef.getCost();
        // Add the network transfer cost when parent and child output types differ
        if (currentPlan.getFedOutType() != planRef.getFedOutType()) {
            cost += computeHopNetworkAccessCost(planRef.getHopRef().getOutputMemEstimate());
        }
    }
    currentPlan.setCost(cost);
}
```

The previous CostEstimator also had several issues:
- It recalculated the current Hop's cost every time, for every FedPlan.
- The optimal FedPlan should minimize the total cost of computation, memory access, and network access.
- However, the previous CostEstimator selected the ref plan with the minimum cost excluding network cost and only then added the network cost, so it could not guarantee the minimum-cost FedPlan.
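A tiny numeric sketch of the last point, with made-up costs (not measured in SystemDS): a parent that produces FOUT output has one child with a cheap LOUT variant and a more expensive FOUT variant. `ChildSelectionSketch`, `ChildOption`, `pickExcludingNetwork`, and `pickIncludingNetwork` are hypothetical names for illustration only, not SystemDS APIs; the only rule borrowed from the text is that the child's network transfer cost is added when its output type differs from the parent's. It needs Java 16+ for `record`.

```java
import java.util.List;

public class ChildSelectionSketch {
    // Hypothetical stand-in for SystemDS' FederatedOutput
    enum FedOut { FOUT, LOUT }

    // Hypothetical stand-in for one child FedPlan: its output type,
    // its cumulative cost, and the cost of moving its output over the network.
    record ChildOption(FedOut type, double cumulativeCost, double netTransferCost) {
        // Cost of this child as seen by a parent with the given output type
        double parentViewCost(FedOut parentType) {
            return cumulativeCost + (type != parentType ? netTransferCost : 0.0);
        }
    }

    // Previous rule: pick the cheapest child ignoring network cost, then add it afterwards
    static double pickExcludingNetwork(List<ChildOption> options, FedOut parentType) {
        ChildOption best = options.get(0);
        for (ChildOption o : options)
            if (o.cumulativeCost() < best.cumulativeCost()) best = o;
        return best.parentViewCost(parentType);
    }

    // Redesigned rule: pick the child that minimizes the parent-view cost directly
    static double pickIncludingNetwork(List<ChildOption> options, FedOut parentType) {
        double best = Double.POSITIVE_INFINITY;
        for (ChildOption o : options)
            best = Math.min(best, o.parentViewCost(parentType));
        return best;
    }

    public static void main(String[] args) {
        List<ChildOption> child = List.of(
            new ChildOption(FedOut.LOUT, 10.0, 50.0),  // cheap locally, expensive to ship
            new ChildOption(FedOut.FOUT, 20.0, 50.0)); // more expensive, but stays federated
        System.out.println(pickExcludingNetwork(child, FedOut.FOUT)); // 60.0
        System.out.println(pickIncludingNetwork(child, FedOut.FOUT)); // 20.0
    }
}
```

Choosing the child by cumulative cost alone takes the LOUT variant and then pays the transfer (60.0), whereas comparing parent-view costs picks the FOUT variant (20.0). This is the behavior the redesigned `getParentViewCost` comparison is meant to capture, for a single parent.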
The redesigned CostEstimator is as follows:

```java
public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
    double cumulativeCost = 0;
    Hop currentHop = currentPlan.getHopRef();

    // Step 1: Calculate current node costs if not already computed
    if (currentPlan.getCurrentCost() == 0) {
        // Compute cost for current node (computation + memory access)
        cumulativeCost = computeCurrentCost(currentHop);
        currentPlan.setCurrentCost(cumulativeCost);
        // Calculate potential network transfer cost if federation type changes
        currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
    } else {
        cumulativeCost = currentPlan.getCurrentCost();
    }

    // Step 2: Process each child plan and add their costs
    for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getMetaChildFedPlans()) {
        // Find minimum cost child plan considering federation type compatibility
        // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents
        // because we're selecting child plans independently for each parent
        FedPlan planRef = memoTable.getMinCostChildFedPlan(
            planRefMeta.getLeft(), planRefMeta.getRight(), currentPlan.getFedOutType());

        // Add child plan cost (includes network transfer cost if federation types differ)
        cumulativeCost += planRef.getParentViewCost(currentPlan.getFedOutType());

        // Store selected child plan
        // Note: Selected plan has minimum parent view cost, not minimum cumulative cost,
        // which means it is highly unlikely to be found through simple pruning after enumeration
        currentPlan.putChildFedPlan(planRef);
    }

    // Step 3: Set final cumulative cost including current node
    currentPlan.setCumulativeCost(cumulativeCost);
}
```

The key points of the redesigned CostEstimator are as follows:
- It calculates the compute cost and access cost of the current Hop only once per `<HopID, FederatedOutput>` pair.
- When selecting the minimum-cost child plan, it includes the network transfer cost in the comparison, so the selected child minimizes the total cost seen by that parent.
- It stores the selected child plans in a list as pointers.
  - This is necessary because, if we pruned the whole MemoTable at once afterward, we could not compute the network cost without knowing the fOutType of each FedPlan's parent FedPlan, and therefore could not identify the optimal plan. For this reason, pruning in the current MemoTable has been removed.

However, the current CostEstimator can cause two problems, because it selects each child plan by looking only at a single current (parent) plan and that child:
1. A child plan can have multiple parent plans, and different parents can select different variants of the same child. The combined result can therefore treat the same child as both FOUT and LOUT, i.e., a FedPlan that cannot actually exist.
2. Since a child plan can have multiple parent plans, its fOutType should be chosen to minimize the sum of the costs of all parent plans that reference it. Otherwise, a suboptimal plan may be selected.

We need to devise a new algorithm to solve these two problems.

## 3. FederatedPlanCostEnumerator

```java
public class FederatedPlanCostEnumerator {
    public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
        FederatedMemoTable memoTable = new FederatedMemoTable();
        enumerateFederatedPlanCost(rootHop, memoTable);
        return getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
    }

    /**
     * Recursively enumerates all possible federated execution plans for a Hop DAG.
     * For each node:
     * 1. First processes all input nodes recursively if not already processed
     * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs
     * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination
     *
     * The enumeration uses a bottom-up approach where:
     * - Each input combination is represented by a binary number (i)
     * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
     * - Total number of combinations is 2^numInputs
     */
    private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {
        int numInputs = hop.getInput().size();

        // Process all input nodes first if not already in memo table
        for (Hop inputHop : hop.getInput()) {
            if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT)
                && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) {
                enumerateFederatedPlanCost(inputHop, memoTable);
            }
        }

        // Generate all possible input combinations using binary representation
        // i represents a specific combination of FOUT/LOUT for inputs
        for (int i = 0; i < (1 << numInputs); i++) {
            List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>();

            // For each input, determine if it should be FOUT or LOUT based on bit j in i
            for (int j = 0; j < numInputs; j++) {
                Hop inputHop = hop.getInput().get(j);
                // If bit j is set (1), use FOUT; otherwise use LOUT
                FederatedOutput childType = ((i & (1 << j)) != 0) ? FederatedOutput.FOUT : FederatedOutput.LOUT;
                planChilds.add(Pair.of(inputHop.getHopID(), childType));
            }

            // Create and evaluate FOUT variant for current input combination
            FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);

            // Create and evaluate LOUT variant for current input combination
            FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
        }
    }
}
```
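The bitmask loop in `enumerateFederatedPlanCost` can be tried in isolation. The sketch below is a standalone approximation: `InputCombinationSketch`, its local `FedOut` enum (a stand-in for the real `FederatedOutput`), and `fedPlansPerHop` are hypothetical, and no Hops or memo table are involved. It only reproduces the mapping "bit j of i set means input j is FOUT", which yields 2^numInputs input combinations and, with one FOUT and one LOUT variant per combination, 2^(numInputs+1) FedPlans per Hop.

```java
import java.util.ArrayList;
import java.util.List;

public class InputCombinationSketch {
    // Hypothetical stand-in for SystemDS' FederatedOutput
    enum FedOut { FOUT, LOUT }

    // Enumerates all 2^numInputs assignments of FOUT/LOUT to the inputs of a hop.
    // Bit j of the combination index i selects FOUT (1) or LOUT (0) for input j.
    static List<List<FedOut>> inputCombinations(int numInputs) {
        List<List<FedOut>> result = new ArrayList<>();
        for (int i = 0; i < (1 << numInputs); i++) {
            List<FedOut> combo = new ArrayList<>();
            for (int j = 0; j < numInputs; j++)
                combo.add((i & (1 << j)) != 0 ? FedOut.FOUT : FedOut.LOUT);
            result.add(combo);
        }
        return result;
    }

    // Each input combination produces one FOUT and one LOUT FedPlan for the hop
    static int fedPlansPerHop(int numInputs) {
        return 2 * (1 << numInputs);
    }

    public static void main(String[] args) {
        // Binary hop: prints [LOUT, LOUT], [FOUT, LOUT], [LOUT, FOUT], [FOUT, FOUT]
        for (List<FedOut> combo : inputCombinations(2))
            System.out.println(combo);
        System.out.println(fedPlansPerHop(2)); // 8
    }
}
```

The count doubles with every additional input, which is why sharing the per-Hop costs in FedPlanVariants (instead of recomputing them per FedPlan) matters for wide operators.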
- I'm not sure how to create complex Hop DAGs similar to real scenarios in the test code. Could you please provide some reference test code that I can refer to?