sebwrede commented on a change in pull request #1395:
URL: https://github.com/apache/systemds/pull/1395#discussion_r723272733



##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();
+                       double inputTransferCost = 
inputTransferCostEstimate(hasFederatedInput, root);
+                       double computingCost = 
ComputeCost.getHOPComputeCost(root.hopRef);
+                       if ( hasFederatedInput ){
+                               //Find the number of inputs that has FOUT set.
+                               int numWorkers = 
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+                               //divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               //the number of parallel processes at each 
worker multiplied by the FLOPS of each process
+                               //This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               //and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       //Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       //If the root is a federated DataOp, the data is forced 
to the coordinator even if no input is LOUT
+                       double outputTransferCost = ( root.hasLocalOutput() && 
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+                               
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+                       double readCost = 
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+               }
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hopRel for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, 
HopRel root){
+               if ( hasFederatedInput )
+                       return root.inputDependency.stream()

Review comment:
       Yes :smiley: 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to