min-guk opened a new pull request, #2147:
URL: https://github.com/apache/systemds/pull/2147

   This implementation is based on the newly implemented FederatedPlanCostEstimator and FederatedMemoTable, following the direction we previously discussed.
   
   ## 1. FederatedMemoTable (MemoTable)
   ```java
   public class MemoTable {
       private final Map<Pair<Long, FTypes.FType>, List<FedPlan>> hopMemoTable = new HashMap<>();

       public static class FedPlan {
           @SuppressWarnings("unused")
           private final Hop hopRef;                        // The associated Hop object
           private final double cost;                       // Cost of this federated plan
           @SuppressWarnings("unused")
           private final List<Pair<Long, FType>> planRefs;  // References to dependent plans
       }
   }
   ```
   The previous FedPlan class structure had several issues:
   - A single `<HopID, FederatedOutput>` pair stored multiple FedPlans as a list in the MemoTable, and each of them redundantly stored the same hopRef.
   - A single `<HopID, FederatedOutput>` pair had to calculate its computeCost and accessCost 2^(planRefs+1) times redundantly.
   - FedPlan did not store its own FederatedOutput.
   
   ```java
   public class FederatedMemoTable {
       private final Map<Pair<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

       public static class FedPlanVariants {
           protected final Hop hopRef;                // Reference to the associated Hop
           protected double currentCost;              // Current execution cost (compute + memory access)
           protected double netTransferCost;          // Network transfer cost
           protected List<FedPlan> _fedPlanVariants;  // FedPlans sharing this <HopID, FederatedOutput>
       }

       public static class FedPlan {
           private double cumulativeCost;                                // Total cost including child plans
           private final FederatedOutput fedOutType;                     // Output type (FOUT/LOUT)
           private final FedPlanVariants fedPlanVariantList;             // Reference to variant list
           private List<Pair<Long, FederatedOutput>> metaChildFedPlans;  // Child plan references
           private List<FedPlan> selectedFedPlans;                       // Selected child plans
       }
   }
   ```
   The key points of the redesigned FederatedMemoTable are as follows:
   - A single `<HopID, FederatedOutput>` pair has one FedPlanVariants object, which stores the hopRef, currentCost, and netTransferCost once and shares them with all FedPlans in its fedPlanVariants list.
   - A single `<HopID, FederatedOutput>` pair calculates its computeCost and accessCost only once.
   - FedPlan stores its own FederatedOutput.
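To make the sharing concrete, here is a minimal, self-contained sketch of the layout described above. It is not the actual SystemDS code: `ToyHop` is a hypothetical stand-in for `Hop`, `Map.Entry<Long, FederatedOutput>` (created with `Map.entry`) replaces `Pair<Long, FederatedOutput>`, and `numVariants` is an illustrative helper. Two FedPlans added under the same `<HopID, FederatedOutput>` key point to one FedPlanVariants object, so hopRef and currentCost exist only once.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of the FederatedMemoTable layout (not SystemDS code).
class MemoTableSketch {
    enum FederatedOutput { FOUT, LOUT }

    // Stand-in for SystemDS' Hop: only the ID is modeled.
    static class ToyHop {
        final long hopId;
        ToyHop(long hopId) { this.hopId = hopId; }
    }

    // One object per <HopID, FederatedOutput>; shared data is stored here exactly once.
    static class FedPlanVariants {
        final ToyHop hopRef;
        double currentCost;      // compute + memory access cost, shared by all variants
        double netTransferCost;  // network transfer cost, shared by all variants
        final List<FedPlan> fedPlanVariants = new ArrayList<>();

        FedPlanVariants(ToyHop hopRef) { this.hopRef = hopRef; }
    }

    // One object per input combination; shared data is reached via fedPlanVariantList.
    static class FedPlan {
        double cumulativeCost;
        final FederatedOutput fedOutType;
        final FedPlanVariants fedPlanVariantList;
        final List<Map.Entry<Long, FederatedOutput>> metaChildFedPlans;
        final List<FedPlan> selectedFedPlans = new ArrayList<>();

        FedPlan(FederatedOutput fedOutType, FedPlanVariants variants,
                List<Map.Entry<Long, FederatedOutput>> children) {
            this.fedOutType = fedOutType;
            this.fedPlanVariantList = variants;
            this.metaChildFedPlans = children;
        }

        ToyHop getHopRef() { return fedPlanVariantList.hopRef; }
        double getCurrentCost() { return fedPlanVariantList.currentCost; }
    }

    private final Map<Map.Entry<Long, FederatedOutput>, FedPlanVariants> hopMemoTable = new HashMap<>();

    // Creates the FedPlanVariants entry on first use and appends a new FedPlan to it.
    FedPlan addFedPlan(ToyHop hop, FederatedOutput fedOut,
                       List<Map.Entry<Long, FederatedOutput>> children) {
        FedPlanVariants variants = hopMemoTable.computeIfAbsent(
                Map.entry(hop.hopId, fedOut), k -> new FedPlanVariants(hop));
        FedPlan plan = new FedPlan(fedOut, variants, children);
        variants.fedPlanVariants.add(plan);
        return plan;
    }

    boolean contains(long hopId, FederatedOutput fedOut) {
        return hopMemoTable.containsKey(Map.entry(hopId, fedOut));
    }

    int numVariants(long hopId, FederatedOutput fedOut) {
        FedPlanVariants v = hopMemoTable.get(Map.entry(hopId, fedOut));
        return v == null ? 0 : v.fedPlanVariants.size();
    }
}
```

Because every FedPlan reaches hopRef, currentCost, and netTransferCost through `fedPlanVariantList`, setting the shared cost once is visible from every other variant of the same key.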
   
   ## 2. CostEstimator
   ```java
       // Do not create and allocate any new FedPlan.
       // Just calculate the cost for the given FedPlan.
       // The cost of the dependent FedPlans in planRefs is already calculated.
       public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
           double cost = computeFederatedPlanCost(currentPlan.getHopRef());

           for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getPlanRefs()) {
               FedPlan planRef = memoTable.getFedPlan(planRefMeta.getLeft(), planRefMeta.getRight());
               cost += planRef.getCost();

               if (currentPlan.getFedOutType() != planRef.getFedOutType()) {
                   cost += computeHopNetworkAccessCost(planRef.getHopRef().getOutputMemEstimate());
               }
           }
           currentPlan.setCost(cost);
       }
   ```
   The previous CostEstimator also had several issues:
   - It recalculated the current Hop's cost every time it was called, i.e., for every FedPlan of that Hop.
   - The optimal FedPlan should minimize the total cost of compute, memory access, and network access.
   - However, the previous CostEstimator selected the referenced plan with the minimum cost excluding network cost and only added the network cost afterward, so it could not guarantee the minimum-cost FedPlan.
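A small worked example of the last point, with made-up numbers: the child's FOUT plan costs 10, its LOUT plan costs 12, a network transfer costs 5, and the parent produces LOUT. Choosing the cheaper child first picks FOUT (10) and then pays the transfer (15), while comparing costs that already include the transfer picks LOUT (12). The class and method names below are hypothetical and only model this comparison.

```java
// Hypothetical illustration of the two child-selection strategies (not SystemDS code).
class ChildSelectionSketch {
    enum FederatedOutput { FOUT, LOUT }

    // Cost of a child plan as seen from a parent with output type parentType:
    // the child's own cost plus a network transfer if the output types differ.
    static double parentViewCost(double childCost, FederatedOutput childType,
                                 FederatedOutput parentType, double netTransferCost) {
        return childCost + (childType != parentType ? netTransferCost : 0.0);
    }

    // Previous approach: pick the cheapest child ignoring network cost, then add the transfer.
    static double selectThenAddNetwork(double foutCost, double loutCost,
                                       FederatedOutput parentType, double net) {
        FederatedOutput chosen = foutCost <= loutCost ? FederatedOutput.FOUT : FederatedOutput.LOUT;
        double chosenCost = chosen == FederatedOutput.FOUT ? foutCost : loutCost;
        return parentViewCost(chosenCost, chosen, parentType, net);
    }

    // Redesigned approach: the network cost is part of the comparison itself.
    static double selectByParentView(double foutCost, double loutCost,
                                     FederatedOutput parentType, double net) {
        return Math.min(parentViewCost(foutCost, FederatedOutput.FOUT, parentType, net),
                        parentViewCost(loutCost, FederatedOutput.LOUT, parentType, net));
    }
}
```

With the numbers above, `selectThenAddNetwork(10.0, 12.0, LOUT, 5.0)` returns 15.0 and `selectByParentView(10.0, 12.0, LOUT, 5.0)` returns 12.0.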
   
   ```java
       public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) {
           double cumulativeCost = 0;
           Hop currentHop = currentPlan.getHopRef();

           // Step 1: Calculate current node costs if not already computed
           if (currentPlan.getCurrentCost() == 0) {
               // Compute cost for current node (computation + memory access)
               cumulativeCost = computeCurrentCost(currentHop);
               currentPlan.setCurrentCost(cumulativeCost);
               // Calculate potential network transfer cost if federation type changes
               currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate()));
           } else {
               cumulativeCost = currentPlan.getCurrentCost();
           }

           // Step 2: Process each child plan and add their costs
           for (Pair<Long, FederatedOutput> planRefMeta : currentPlan.getMetaChildFedPlans()) {
               // Find minimum cost child plan considering federation type compatibility
               // Note: This approach might lead to suboptimal or wrong solutions when a child
               // has multiple parents, because we're selecting child plans independently for each parent
               FedPlan planRef = memoTable.getMinCostChildFedPlan(
                       planRefMeta.getLeft(), planRefMeta.getRight(), currentPlan.getFedOutType());

               // Add child plan cost (includes network transfer cost if federation types differ)
               cumulativeCost += planRef.getParentViewCost(currentPlan.getFedOutType());

               // Store selected child plan
               // Note: Selected plan has minimum parent view cost, not minimum cumulative cost,
               // which means it is highly unlikely to be found through simple pruning after enumeration
               currentPlan.putChildFedPlan(planRef);
           }

           // Step 3: Set final cumulative cost including current node
           currentPlan.setCumulativeCost(cumulativeCost);
       }
   ```
   The key points of the redesigned CostEstimator are as follows:
   - It calculates the compute cost and memory access cost of the current Hop only once per HopID.
   - When selecting the minimum-cost child plan, it includes the network cost in the comparison, so that compute, memory access, and network cost are minimized together from the parent's point of view.
   - It stores the selected child plans in a list as pointers.
     - This is because if we pruned the whole MemoTable at once later, we could not calculate the network cost without knowing the fOutType of each FedPlan's parent FedPlan, and therefore could not identify the optimal plan. For this reason, pruning in the current MemoTable has been removed.
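The following simplified, self-contained sketch models the three steps of the new `computeFederatedPlanCost` under these assumptions: a `DoubleSupplier` stands in for `computeCurrentCost(hop)`, each child reference resolves to a caller-supplied list of candidate plans instead of `getMinCostChildFedPlan`, and all class and field names are hypothetical. The `currentCostEvaluations` counter shows that the shared cost is computed only once per variants object, however many FedPlans use it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleSupplier;

// Hypothetical, simplified model of the redesigned estimator (not SystemDS code).
class CostEstimatorSketch {
    enum FederatedOutput { FOUT, LOUT }

    static int currentCostEvaluations = 0; // how often the expensive per-Hop cost was computed

    // Shared per <HopID, FederatedOutput>, like FedPlanVariants.
    static class Variants {
        final DoubleSupplier costFunction; // stand-in for computeCurrentCost(hop)
        final double netTransferCost;
        double currentCost = 0;            // 0 means "not computed yet", mirroring the PR code

        Variants(DoubleSupplier costFunction, double netTransferCost) {
            this.costFunction = costFunction;
            this.netTransferCost = netTransferCost;
        }
    }

    static class FedPlan {
        final FederatedOutput fedOutType;
        final Variants variants;
        final List<List<FedPlan>> childCandidates; // one non-empty candidate list per child reference
        final List<FedPlan> selectedChildren = new ArrayList<>();
        double cumulativeCost;

        FedPlan(FederatedOutput fedOutType, Variants variants, List<List<FedPlan>> childCandidates) {
            this.fedOutType = fedOutType;
            this.variants = variants;
            this.childCandidates = childCandidates;
        }

        // Cost as seen by a parent: add this Hop's transfer cost if the output types differ.
        double getParentViewCost(FederatedOutput parentType) {
            return cumulativeCost + (fedOutType != parentType ? variants.netTransferCost : 0.0);
        }
    }

    static void computeFederatedPlanCost(FedPlan plan) {
        Variants v = plan.variants;
        // Step 1: compute the shared current cost at most once per variants object
        if (v.currentCost == 0) {
            currentCostEvaluations++;
            v.currentCost = v.costFunction.getAsDouble();
        }
        double cumulative = v.currentCost;
        // Step 2: pick the child candidate with the lowest parent-view cost and keep a pointer to it
        for (List<FedPlan> candidates : plan.childCandidates) {
            FedPlan best = null;
            for (FedPlan c : candidates) {
                if (best == null || c.getParentViewCost(plan.fedOutType)
                        < best.getParentViewCost(plan.fedOutType)) {
                    best = c;
                }
            }
            cumulative += best.getParentViewCost(plan.fedOutType);
            plan.selectedChildren.add(best);
        }
        // Step 3: store the cumulative cost including the current node
        plan.cumulativeCost = cumulative;
    }
}
```

Like the PR code, the sketch uses `currentCost == 0` as the "not computed yet" marker, so a Hop whose real cost is 0 would be recomputed on every call; a separate boolean flag would avoid that.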
   
   However, the current CostEstimator may cause two problems, because it selects child plans based only on the costs of a single parent plan and its child plan:
   1. A child plan can have multiple parent plans, and different parent plans can select different child plans. As a result, different parents may assign different fOutTypes to the same child Hop, forming a FedPlan that cannot actually exist.
   2. Since a child plan can have multiple parent plans, it should select the fOutType that minimizes the sum of the costs of all parent plans referencing it. Otherwise, it may select a suboptimal plan.
     - We need to devise a new algorithm to solve these two problems.
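Both problems can be reproduced with a shared child C that has one FOUT parent and one LOUT parent. With made-up costs (C as FOUT: 10, C as LOUT: 11, transfer: 5), each parent deciding on its own picks a different type for C, which is problem 1. Fixing C to one type and summing over both parents, assuming one transfer per parent whose output type differs, gives 15 for FOUT and 16 for LOUT, which illustrates problem 2. This sketch only brute-forces a single shared child to illustrate the issue; it is not the new algorithm called for above, and all names are hypothetical.

```java
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration of the shared-child problems (not SystemDS code).
class SharedChildSketch {
    enum FederatedOutput { FOUT, LOUT }

    // Builds the cost table of the shared child C for each output type.
    static Map<FederatedOutput, Double> costs(double fout, double lout) {
        Map<FederatedOutput, Double> m = new EnumMap<>(FederatedOutput.class);
        m.put(FederatedOutput.FOUT, fout);
        m.put(FederatedOutput.LOUT, lout);
        return m;
    }

    // Parent-view cost of C for a parent with the given output type.
    static double parentView(Map<FederatedOutput, Double> childCost, FederatedOutput childType,
                             FederatedOutput parentType, double net) {
        return childCost.get(childType) + (childType != parentType ? net : 0.0);
    }

    // What one parent picks when it decides independently (current estimator behavior).
    static FederatedOutput independentChoice(Map<FederatedOutput, Double> childCost,
                                             FederatedOutput parentType, double net) {
        double f = parentView(childCost, FederatedOutput.FOUT, parentType, net);
        double l = parentView(childCost, FederatedOutput.LOUT, parentType, net);
        return f <= l ? FederatedOutput.FOUT : FederatedOutput.LOUT;
    }

    // Joint cost of fixing C to one type: C's own cost once, plus one transfer
    // for every parent whose output type differs (assumption for this sketch).
    static double jointCost(Map<FederatedOutput, Double> childCost, FederatedOutput childType,
                            List<FederatedOutput> parentTypes, double net) {
        double cost = childCost.get(childType);
        for (FederatedOutput p : parentTypes) {
            if (p != childType) cost += net;
        }
        return cost;
    }

    // The single type for C that minimizes the joint cost over all its parents.
    static FederatedOutput bestJointType(Map<FederatedOutput, Double> childCost,
                                         List<FederatedOutput> parentTypes, double net) {
        double f = jointCost(childCost, FederatedOutput.FOUT, parentTypes, net);
        double l = jointCost(childCost, FederatedOutput.LOUT, parentTypes, net);
        return f <= l ? FederatedOutput.FOUT : FederatedOutput.LOUT;
    }
}
```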
   
   ## 3. FederatedPlanCostEnumerator 
   ```java
   public class FederatedPlanCostEnumerator {
       public static FedPlan enumerateFederatedPlanCost(Hop rootHop) {
           FederatedMemoTable memoTable = new FederatedMemoTable();
           enumerateFederatedPlanCost(rootHop, memoTable);
           return getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
       }

       /**
        * Recursively enumerates all possible federated execution plans for a Hop DAG.
        * For each node:
        * 1. First processes all input nodes recursively if not already processed
        * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs
        * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination
        *
        * The enumeration uses a bottom-up approach where:
        * - Each input combination is represented by a binary number (i)
        * - Bit j in i determines whether input j is FOUT (1) or LOUT (0)
        * - Total number of combinations is 2^numInputs
        */
       private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {
           int numInputs = hop.getInput().size();

           // Process all input nodes first if not already in memo table
           for (Hop inputHop : hop.getInput()) {
               if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT)
                       && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) {
                   enumerateFederatedPlanCost(inputHop, memoTable);
               }
           }

           // Generate all possible input combinations using binary representation
           // i represents a specific combination of FOUT/LOUT for inputs
           for (int i = 0; i < (1 << numInputs); i++) {
               List<Pair<Long, FederatedOutput>> planChilds = new ArrayList<>();

               // For each input, determine if it should be FOUT or LOUT based on bit j in i
               for (int j = 0; j < numInputs; j++) {
                   Hop inputHop = hop.getInput().get(j);
                   // If bit j is set (1), use FOUT; otherwise use LOUT
                   FederatedOutput childType = ((i & (1 << j)) != 0) ?
                           FederatedOutput.FOUT : FederatedOutput.LOUT;
                   planChilds.add(Pair.of(inputHop.getHopID(), childType));
               }

               // Create and evaluate FOUT variant for current input combination
               FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds);
               FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);

               // Create and evaluate LOUT variant for current input combination
               FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds);
               FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
           }
       }
   }
   ```
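The bit convention in `enumerateFederatedPlanCost` can be checked in isolation. This standalone sketch uses hypothetical names (`ToyHop` stands in for `Hop`, and a `Set<Long>` plays the role of the `memoTable.contains` check). It reproduces the `(i & (1 << j)) != 0` mapping and counts how many FedPlans the enumeration creates: 2 × 2^numInputs per Hop, with each shared input enumerated only once.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical, standalone model of the enumeration scheme (not SystemDS code).
class InputCombinationSketch {
    enum FederatedOutput { FOUT, LOUT }

    static class ToyHop {
        final long id;
        final List<ToyHop> inputs;
        ToyHop(long id, List<ToyHop> inputs) { this.id = id; this.inputs = inputs; }
    }

    // Bit j of i set -> input j is FOUT, otherwise LOUT (same convention as the PR).
    static List<List<FederatedOutput>> enumerate(int numInputs) {
        List<List<FederatedOutput>> combos = new ArrayList<>();
        for (int i = 0; i < (1 << numInputs); i++) {
            List<FederatedOutput> combo = new ArrayList<>(numInputs);
            for (int j = 0; j < numInputs; j++) {
                combo.add(((i & (1 << j)) != 0) ? FederatedOutput.FOUT : FederatedOutput.LOUT);
            }
            combos.add(combo);
        }
        return combos;
    }

    // Each input combination yields one FOUT and one LOUT plan for the current Hop.
    static int plansPerHop(int numInputs) {
        return 2 * (1 << numInputs);
    }

    // Bottom-up traversal; the visited set prevents re-enumerating shared inputs.
    static int countPlans(ToyHop hop, Set<Long> visited) {
        int total = 0;
        for (ToyHop in : hop.inputs) {
            if (!visited.contains(in.id)) {
                total += countPlans(in, visited);
            }
        }
        visited.add(hop.id);
        return total + plansPerHop(hop.inputs.size());
    }

    static int countPlans(ToyHop root) {
        return countPlans(root, new HashSet<>());
    }
}
```

For a diamond DAG (a → b, a → c, {b, c} → d), `countPlans(d)` returns 2 + 4 + 4 + 8 = 18, and `a` is enumerated only once even though two parents reference it.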
   - I'm not sure how to build complex Hop DAGs that resemble real scenarios in the test code. Could you point me to some existing test code that I can use as a reference?
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
