Baunsgaard commented on a change in pull request #1395:
URL: https://github.com/apache/systemds/pull/1395#discussion_r722214525



##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)

Review comment:
       Consistently add a space after the // that starts a comment.
   If you use our code formatter, it does this for you.

##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();
+                       double inputTransferCost = 
inputTransferCostEstimate(hasFederatedInput, root);
+                       double computingCost = 
ComputeCost.getHOPComputeCost(root.hopRef);
+                       if ( hasFederatedInput ){
+                               //Find the number of inputs that has FOUT set.
+                               int numWorkers = 
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+                               //divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               //the number of parallel processes at each 
worker multiplied by the FLOPS of each process
+                               //This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               //and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       //Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       //If the root is a federated DataOp, the data is forced 
to the coordinator even if no input is LOUT
+                       double outputTransferCost = ( root.hasLocalOutput() && 
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+                               
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+                       double readCost = 
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+               }
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hopRel for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, 
HopRel root){
+               if ( hasFederatedInput )
+                       return root.inputDependency.stream()
+                               .filter(input -> 
(root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : 
input.hasLocalOutput() )
+                               .mapToDouble(in -> 
in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+                               .map(inMem -> inMem/ 
WORKER_NETWORK_BANDWIDTH_BYTES_PS)
+                               .sum();
+               else return 0;
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hop for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, Hop 
root){
+               if ( hasFederatedInput )
+                       return root.getInput().stream()
+                               .filter(input -> (root.isFederatedDataOp()) ? 
input.hasFederatedOutput() : input.hasLocalOutput() )
+                               .mapToDouble(in -> 
in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+                               .map(inMem -> inMem/ 
WORKER_NETWORK_BANDWIDTH_BYTES_PS)

Review comment:
       If you divide each element by a value, it is the same as dividing the 
sum by that value. Therefore you don't need the map here.
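
The equivalence behind this comment can be checked with a small sketch. Everything in it is hypothetical: the class name, the method names, and the bandwidth value stand in for WORKER_NETWORK_BANDWIDTH_BYTES_PS and are not taken from the PR. In general the two variants can differ in the last floating-point bits; the inputs below are powers of two, so the results match exactly.

```java
import java.util.stream.DoubleStream;

public class TransferCostSketch {
    // Hypothetical stand-in for WORKER_NETWORK_BANDWIDTH_BYTES_PS; the value is made up.
    static final double BANDWIDTH_BYTES_PS = 1024.0;

    // Divides every element before summing, as in the reviewed code.
    static double dividePerElement(double[] memEstimates) {
        return DoubleStream.of(memEstimates)
            .map(m -> m / BANDWIDTH_BYTES_PS)
            .sum();
    }

    // Sums first and divides once, as the review comment suggests.
    static double divideSum(double[] memEstimates) {
        return DoubleStream.of(memEstimates).sum() / BANDWIDTH_BYTES_PS;
    }

    public static void main(String[] args) {
        double[] mem = {2048.0, 4096.0, 1024.0};
        System.out.println(dividePerElement(mem)); // prints 7.0
        System.out.println(divideSum(mem));        // prints 7.0
    }
}
```

The second variant drops one stream stage and performs a single division instead of one per input.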

##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();
+                       double inputTransferCost = 
inputTransferCostEstimate(hasFederatedInput, root);
+                       double computingCost = 
ComputeCost.getHOPComputeCost(root.hopRef);
+                       if ( hasFederatedInput ){
+                               //Find the number of inputs that has FOUT set.
+                               int numWorkers = 
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+                               //divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               //the number of parallel processes at each 
worker multiplied by the FLOPS of each process
+                               //This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               //and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       //Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       //If the root is a federated DataOp, the data is forced 
to the coordinator even if no input is LOUT
+                       double outputTransferCost = ( root.hasLocalOutput() && 
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+                               
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+                       double readCost = 
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+               }
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hopRel for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, 
HopRel root){
+               if ( hasFederatedInput )
+                       return root.inputDependency.stream()

Review comment:
       Very fond of the stream API here :)

##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();

Review comment:
       Seems overkill with a stream sum here. I would suggest making a method 
where you use a for loop instead.
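
One way the suggested helper could look, as a minimal sketch. The Node class is a hypothetical, heavily reduced stand-in for HopRel, and inputCosts and totalCost are illustrative names, not methods from the PR. Only existingCostPointer follows the logic shown in the diff. A plain loop also keeps the cost-pointer update out of a stream lambda, where side effects are generally discouraged.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class InputCostSketch {
    // Hypothetical, simplified stand-in for HopRel: only what the loop needs.
    static class Node {
        final long id;
        final double ownCost;
        final List<Node> inputs = new ArrayList<>();
        final Set<Long> costPointers = new HashSet<>();

        Node(long id, double ownCost) {
            this.id = id;
            this.ownCost = ownCost;
        }

        // Same logic as existingCostPointer in the diff: true if a hop
        // other than currentId already points to this node.
        boolean existingCostPointer(long currentId) {
            return costPointers.contains(currentId)
                ? costPointers.size() > 1
                : !costPointers.isEmpty();
        }
    }

    // For-loop version of the stream sum: an input's cost is counted only
    // the first time the input is used; later users add zero.
    static double inputCosts(Node root) {
        double sum = 0;
        for (Node in : root.inputs) {
            if (!in.existingCostPointer(root.id))
                sum += totalCost(in);
            in.costPointers.add(root.id);
        }
        return sum;
    }

    static double totalCost(Node n) {
        return n.ownCost + inputCosts(n);
    }

    public static void main(String[] args) {
        Node shared = new Node(1, 10.0);
        Node b = new Node(2, 1.0);
        Node c = new Node(3, 1.0);
        b.inputs.add(shared);
        c.inputs.add(shared);
        System.out.println(totalCost(b)); // prints 11.0: shared input counted here
        System.out.println(totalCost(c)); // prints 1.0: shared input already counted
    }
}
```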

##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of 
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated 
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal 
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+       protected Hop hopRef;
+       protected FEDInstruction.FederatedOutput fedOut;
+       protected FederatedCost cost;
+       protected Set<Long> costPointerSet = new HashSet<>();
+       protected List<HopRel> inputDependency = new ArrayList<>();
+
+       /**
+        * Constructs a HopRel with input dependency and cost estimate based on 
entries in hopRelMemo.

Review comment:
       If you use our formatter, it will add a blank line after the description 
and before the @param tags.
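
For illustration, the layout the comment describes looks roughly like the sketch below; the Javadoc in IPAPassRewriteFederatedPlan.java quoted later in this email already follows it. The class and method are hypothetical and exist only to carry the comment block.

```java
public class JavadocLayoutSketch {
    /**
     * Returns the transfer time for a payload of the given size.
     *
     * @param bytes payload size in bytes
     * @param bandwidthBytesPerSec link bandwidth in bytes per second
     * @return estimated transfer time in seconds
     */
    static double transferSeconds(double bytes, double bandwidthBytesPerSec) {
        return bytes / bandwidthBytesPerSec;
    }
}
```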

##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of 
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated 
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal 
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+       protected Hop hopRef;
+       protected FEDInstruction.FederatedOutput fedOut;
+       protected FederatedCost cost;
+       protected Set<Long> costPointerSet = new HashSet<>();
+       protected List<HopRel> inputDependency = new ArrayList<>();
+
+       /**
+        * Constructs a HopRel with input dependency and cost estimate based on 
entries in hopRelMemo.
+        * @param associatedHop hop associated with this HopRel
+        * @param fedOut FederatedOutput value assigned to this HopRel
+        * @param hopRelMemo memo table storing other HopRels including the 
inputs of associatedHop
+        */
+       public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
Map<Long, List<HopRel>> hopRelMemo){
+               hopRef = associatedHop;
+               this.fedOut = fedOut;
+               setInputDependency(hopRelMemo);
+               cost = new FederatedCostEstimator().costEstimate(this, 
hopRelMemo);
+       }
+
+       /**
+        * Adds hopID to set of hops pointing to this HopRel.
+        * By storing the hopID it can later be determined if the cost
+        * stored in this HopRel is already used as input cost in another 
HopRel.
+        * @param hopID added to set of stored cost pointers
+        */
+       public void addCostPointer(long hopID){
+               costPointerSet.add(hopID);
+       }
+
+       /**
+        * Checks if another Hop is referring to this HopRel in memo table.
+        * A reference from a HopRel with same Hop ID is allowed, so this
+        * ID is ignored when checking references.
+        * @param currentHopID to ignore when checking references
+        * @return true if another Hop refers to this HopRel in memo table
+        */
+       public boolean existingCostPointer(long currentHopID){
+               if ( costPointerSet.contains(currentHopID) )
+                       return costPointerSet.size() > 1;
+               else return costPointerSet.size() > 0;
+       }
+
+       public boolean hasLocalOutput(){
+               return fedOut == FederatedOutput.LOUT;
+       }
+
+       public boolean hasFederatedOutput(){
+               return fedOut == FederatedOutput.FOUT;
+       }
+
+       public FederatedOutput getFederatedOutput(){
+               return fedOut;
+       }
+
+       public List<HopRel> getInputDependency(){
+               return inputDependency;
+       }
+
+       public Hop getHopRef(){
+               return hopRef;
+       }
+
+       /**
+        * Returns FOUT HopRel for given hop found in hopRelMemo or returns 
null if HopRel not found.
+        * @param hop to look for in hopRelMemo
+        * @param hopRelMemo memo table storing HopRels
+        * @return FOUT HopRel found in hopRelMemo
+        */
+       private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> 
hopRelMemo){
+               return 
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
+       }
+
+       /**
+        * Get the HopRel with minimum cost for given hop
+        * @param hopRelMemo memo table storing HopRels
+        * @param input hop for which minimum cost HopRel is found
+        * @return HopRel with minimum cost for given hop
+        */
+       private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop 
input){
+               return hopRelMemo.get(input.getHopID()).stream()
+                       .min(Comparator.comparingDouble(a -> a.cost.getTotal()))
+                       .orElseThrow(() -> new DMLException("No element in Memo 
Table found for input"));
+       }
+
+       /**
+        * Set valid and optimal input dependency for this HopRel as a field.
+        * @param hopRelMemo memo table storing input HopRels
+        */
+       private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+               if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
+                       if ( fedOut == FederatedOutput.FOUT && 
!hopRef.isFederatedDataOp() ) {
+                               int lowestFOUTIndex = 0;
+                               HopRel lowestFOUTHopRel = 
getFOUTHopRel(hopRef.getInput().get(0), hopRelMemo);
+                               for(int i = 1; i < hopRef.getInput().size(); 
i++) {
+                                       Hop input = hopRef.getInput(i);
+                                       HopRel foutHopRel = 
getFOUTHopRel(input, hopRelMemo);
+                                       if(lowestFOUTHopRel == null) {
+                                               lowestFOUTHopRel = foutHopRel;
+                                               lowestFOUTIndex = i;
+                                       }
+                                       else if(foutHopRel != null) {
+                                               if(foutHopRel.getCost() < 
lowestFOUTHopRel.getCost()) {
+                                                       lowestFOUTHopRel = 
foutHopRel;
+                                                       lowestFOUTIndex = i;
+                                               }
+                                       }
+                               }
+
+                               HopRel[] inputHopRels = new 
HopRel[hopRef.getInput().size()];
+                               for(int i = 0; i < hopRef.getInput().size(); 
i++) {
+                                       if(i != lowestFOUTIndex) {
+                                               Hop input = hopRef.getInput(i);
+                                               inputHopRels[i] = 
getMinOfInput(hopRelMemo, input);
+                                       }
+                                       else {
+                                               inputHopRels[i] = 
lowestFOUTHopRel;
+                                       }
+                               }
+                               
inputDependency.addAll(Arrays.asList(inputHopRels));
+                       } else {
+                               inputDependency.addAll(
+                                       hopRef.getInput().stream()
+                                               .map(input -> 
getMinOfInput(hopRelMemo, input))
+                                               .collect(Collectors.toList()));
+                       }
+               }
+               validateInputDependency();
+       }
+
+       /**
+        * Throws exception if any input dependency is null.
+        * If any of the input dependencies are null, it is not possible to 
build a federated execution plan.
+        * If this null-state is not found here, an exception will be thrown at 
another difficult-to-debug place.
+        */
+       private void validateInputDependency(){
+               for ( int i = 0; i < inputDependency.size(); i++){
+                       if ( inputDependency.get(i) == null)
+                               throw new DMLException("HopRel input number " + 
i + " (" + hopRef.getInput(i) + ")"
+                                       + " is null for root: \n" + this);
+               }

Review comment:
       you can remove the last outer "}"

##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of 
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated 
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal 
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+       protected Hop hopRef;
+       protected FEDInstruction.FederatedOutput fedOut;
+       protected FederatedCost cost;
+       protected Set<Long> costPointerSet = new HashSet<>();
+       protected List<HopRel> inputDependency = new ArrayList<>();

Review comment:
       Can some of these fields be made final?
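
A rough sketch of what that could look like, assuming the fields are assigned only in the constructor or at declaration. The diff shown here does not cover the whole file, so whether that holds for every field is an open question. Rel is a hypothetical reduced version of HopRel, with Object and String standing in for Hop and FederatedOutput.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class FinalFieldsSketch {
    // Hypothetical reduced version of HopRel, not the class from the PR.
    static class Rel {
        // Assigned exactly once in the constructor, so they can be final.
        protected final Object hopRef;
        protected final String fedOut;
        // final fixes the reference only; the collections stay mutable.
        protected final Set<Long> costPointerSet = new HashSet<>();
        protected final List<Rel> inputDependency = new ArrayList<>();

        Rel(Object hopRef, String fedOut) {
            this.hopRef = hopRef;
            this.fedOut = fedOut;
        }

        void addCostPointer(long hopID) {
            costPointerSet.add(hopID);
        }
    }
}
```

Marking the collection fields final does not prevent addCostPointer from adding entries; it only prevents reassigning the field to another collection.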

##########
File path: 
src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
##########
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting 
costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in 
OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+       private final static Map<Long, List<HopRel>> hopRelMemo = new 
HashMap<>();
+
+       /**
+        * Indicates if an IPA pass is applicable for the current configuration.
+        * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+        *
+        * @param fgraph function call graph
+        * @return true if federated compilation is activated.
+        */
+       @Override
+       public boolean isApplicable(FunctionCallGraph fgraph) {
+               return OptimizerUtils.FEDERATED_COMPILATION;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the program.
+        *
+        * @param prog       dml program
+        * @param fgraph     function call graph
+        * @param fcallSizes function call size infos
+        * @return false since the function call graph never has to be rebuilt
+        */
+       @Override
+       public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes) {
+               rewriteStatementBlocks(prog.getStatementBlocks());
+               return false;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the statement 
blocks.
+        * The method calls the contained statement blocks recursively.
+        *
+        * @param sbs   list of statement blocks
+        * @return list of statement blocks with the federated output value 
updated for each hop
+        */
+       public ArrayList<StatementBlock> 
rewriteStatementBlocks(List<StatementBlock> sbs) {
+               ArrayList<StatementBlock> rewrittenStmBlocks = new 
ArrayList<>();
+               for ( StatementBlock stmBlock : sbs )
+                       
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+               return rewrittenStmBlocks;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the statement 
blocks.
+        * The method calls the contained statement blocks recursively.
+        *
+        * @param sb    statement block
+        * @return list of statement blocks with the federated output value 
updated for each hop
+        */
+       public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock 
sb) {
+               if ( sb instanceof WhileStatementBlock)
+                       return rewriteWhileStatementBlock((WhileStatementBlock) 
sb);
+               else if ( sb instanceof IfStatementBlock)
+                       return rewriteIfStatementBlock((IfStatementBlock) sb);
+               else if ( sb instanceof ForStatementBlock){
+                       // This also includes ParForStatementBlocks
+                       return rewriteForStatementBlock((ForStatementBlock) sb);
+               }
+               else if ( sb instanceof FunctionStatementBlock)
+                       return 
rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+               else {
+                       // StatementBlock type (no subclass)
+                       selectFederatedExecutionPlan(sb.getHops());
+               }
+               return new ArrayList<>(Collections.singletonList(sb));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+               Hop whilePredicateHop = whileSB.getPredicateHops();
+               selectFederatedExecutionPlan(whilePredicateHop);
+               for ( Statement stm : whileSB.getStatements() ){
+                       WhileStatement whileStm = (WhileStatement) stm;
+                       
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(whileSB));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteIfStatementBlock(IfStatementBlock ifSB){
+               selectFederatedExecutionPlan(ifSB.getPredicateHops());
+               for ( Statement statement : ifSB.getStatements() ){
+                       IfStatement ifStatement = (IfStatement) statement;
+                       
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+                       
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(ifSB));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteForStatementBlock(ForStatementBlock forSB){
+               selectFederatedExecutionPlan(forSB.getFromHops());
+               selectFederatedExecutionPlan(forSB.getToHops());
+               selectFederatedExecutionPlan(forSB.getIncrementHops());
+               for ( Statement statement : forSB.getStatements() ){
+                       ForStatement forStatement = ((ForStatement)statement);
+                       
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(forSB));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+               for ( Statement statement : funcSB.getStatements() ){
+                       FunctionStatement funcStm = (FunctionStatement) 
statement;
+                       
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(funcSB));
+       }
+
+       /**
+        * Sets FederatedOutput field of all hops in DAG starting from given 
root.
+        * The FederatedOutput chosen for root is the minimum cost HopRel found 
in memo table for the given root.
+        * The FederatedOutput values chosen for the inputs to the root are 
chosen based on the input dependencies.
+        * @param root hop for which FederatedOutput needs to be set
+        */
+       private void setFinalFedout(Hop root){
+               HopRel optimalRootHopRel = 
hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+                       .orElseThrow(() -> new DMLException("Hop root " + root 
+ " has no feasible federated output alternatives"));
+               setFinalFedout(root, optimalRootHopRel);
+       }
+
+       /**
+        * Update the FederatedOutput value and cost based on information 
stored in given rootHopRel.
+        * @param root hop for which FederatedOutput is set
+        * @param rootHopRel from which FederatedOutput value and cost is 
retrieved
+        */
+       private void setFinalFedout(Hop root, HopRel rootHopRel){
+               updateFederatedOutput(root, rootHopRel);
+               visitInputDependency(rootHopRel);
+       }
+
+       /**
+        * Sets FederatedOutput value for each of the inputs of rootHopRel
+        * @param rootHopRel which has its input values updated
+        */
+       private void visitInputDependency(HopRel rootHopRel){
+               List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+               for ( HopRel input : hopRelInputs )
+                       setFinalFedout(input.getHopRef(), input);
+       }
+
+       /**
+        * Updates FederatedOutput value and cost estimate based on 
updateHopRel values.
+        * @param root which has its values updated
+        * @param updateHopRel from which the values are retrieved
+        */
+       private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+               root.setFederatedOutput(updateHopRel.getFederatedOutput());
+               root.setFederatedCost(updateHopRel.getCostObject());
+       }
+
+       /**
+        * Select federated execution plan for every Hop in the DAG starting 
from given roots.
+        * The cost estimates of the hops are also updated when FederatedOutput 
is updated in the hops.
+        * @param roots starting point for going through the Hop DAG to update 
the FederatedOutput fields.
+        */
+       private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+               for ( Hop root : roots ){
+                       selectFederatedExecutionPlan(root);
+               }
+       }
+
+       /**
+        * Select federated execution plan for every Hop in the DAG starting 
from given root.
+        * @param root starting point for going through the Hop DAG to update 
the federatedOutput fields
+        */
+       private void selectFederatedExecutionPlan(Hop root){
+               visitFedPlanHop(root);
+               setFinalFedout(root);
+       }
+
+       /**
+        * Go through the Hop DAG and set the FederatedOutput field and cost 
estimate for each Hop from leaf to given currentHop.
+        * @param currentHop the Hop from which the DAG is visited
+        */
+       private void visitFedPlanHop(Hop currentHop){
+               // If the currentHop is in the hopRelMemo table, it means that 
it has been visited
+               if ( hopRelMemo.containsKey(currentHop.getHopID()) )
+                       return;
+               // If the currentHop has input, then the input should be 
visited depth-first
+               if ( currentHop.getInput() != null && 
currentHop.getInput().size() > 0 ){
+                       for ( Hop input : currentHop.getInput() )
+                               visitFedPlanHop(input);
+               }
+               // Put FOUT, LOUT, and None HopRels into the memo table
+               ArrayList<HopRel> hopRels = new ArrayList<>();
+               if ( isFedInstSupportedHop(currentHop) ){
+                       for ( FEDInstruction.FederatedOutput fedoutValue : 
FEDInstruction.FederatedOutput.values() )
+                               if ( isFedOutSupported(currentHop, fedoutValue) 
)
+                                       hopRels.add(new 
HopRel(currentHop,fedoutValue, hopRelMemo));
+               }
+               if ( hopRels.isEmpty() )
+                       hopRels.add(new HopRel(currentHop, 
FEDInstruction.FederatedOutput.NONE, hopRelMemo));
+               hopRelMemo.put(currentHop.getHopID(), hopRels);
+               currentHop.setVisited();
+       }
+
+       /**
+        * Checks if the instructions related to the given hop supports 
FOUT/LOUT processing.
+        * @param hop to check for federated support
+        * @return true if federated instructions related to hop supports 
FOUT/LOUT processing
+        */
+       private boolean isFedInstSupportedHop(Hop hop){
+               // The following operations are supported given that the above 
conditions have not returned already
+               return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp 
|| hop instanceof ReorgOp
+                       || hop instanceof AggUnaryOp || hop instanceof 
TernaryOp || hop instanceof DataOp);
+       }
+
+       /**
+        * Checks if the associatedHop supports the given federated output 
value.
+        * @param associatedHop to check support of
+        * @param fedOut federated output value
+        * @return true if associatedHop supports fedOut
+        */
+       private boolean isFedOutSupported(Hop associatedHop, 
FEDInstruction.FederatedOutput fedOut){
+               switch(fedOut){
+                       case FOUT:
+                               return isFOUTSupported(associatedHop);
+                       case LOUT:
+                               return isLOUTSupported(associatedHop);
+                       case NONE:
+                               return false;
+                       default:
+                               return true;
+               }
+       }
+
+       /**
+        * Checks to see if the associatedHop supports FOUT.
+        * @param associatedHop for which FOUT support is checked
+        * @return true if FOUT is supported by the associatedHop
+        */
+       private boolean isFOUTSupported(Hop associatedHop){
+               // If the output of AggUnaryOp is a scalar, the operation 
cannot be FOUT
+               if ( associatedHop instanceof AggUnaryOp && 
associatedHop.isScalar() )
+                       return false;
+               // It can only be FOUT if at least one of the inputs are FOUT, 
except if it is a federated DataOp
+               if ( associatedHop.getInput().stream().noneMatch(
+                       input -> 
hopRelMemo.get(input.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput) )
+                       && !associatedHop.isFederatedDataOp() )
+                       return false;
+               return true;
+       }
+
+       /**
+        * Checks to see if the associatedHop supports LOUT.
+        * It supports LOUT if the output has no privacy constraints.
+        * @param associatedHop for which LOUT support is checked.
+        * @return true if LOUT is supported by the associatedHop
+        */
+       private boolean isLOUTSupported(Hop associatedHop){
+               return associatedHop.getPrivacy() == null

Review comment:
       I would remove the first part of the parenthesis, because it is 
already implied by the first part of the or.
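
To illustrate the point, here is a minimal sketch. The quoted line is cut off after `getPrivacy() == null`, so the rest of the condition is an assumption: I assume the right-hand side repeats a non-null check before testing the constraints. `Privacy` and `hasConstraints()` are hypothetical stand-ins, not necessarily the SystemDS API.

```java
public class LoutCheckSketch {
    // Hypothetical stand-in for the privacy object; not the SystemDS class.
    static final class Privacy {
        private final boolean constraints;
        Privacy(boolean constraints) { this.constraints = constraints; }
        boolean hasConstraints() { return constraints; }
    }

    // Assumed original shape: the inner null check repeats what the
    // left operand of || has already established.
    static boolean verbose(Privacy p) {
        return p == null || (p != null && !p.hasConstraints());
    }

    // Simplified shape: || short-circuits, so the right operand is only
    // evaluated when p is non-null.
    static boolean simplified(Privacy p) {
        return p == null || !p.hasConstraints();
    }

    public static void main(String[] args) {
        Privacy[] cases = { null, new Privacy(false), new Privacy(true) };
        // Prints true for every case: both forms agree.
        for (Privacy p : cases)
            System.out.println(verbose(p) == simplified(p));
    }
}
```

Both forms return the same value for all three cases, so the shorter one is safe under this assumption.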

##########
File path: 
src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
##########
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting 
costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in 
OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+       private final static Map<Long, List<HopRel>> hopRelMemo = new 
HashMap<>();
+
+       /**
+        * Indicates if an IPA pass is applicable for the current configuration.
+        * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+        *
+        * @param fgraph function call graph
+        * @return true if federated compilation is activated.
+        */
+       @Override
+       public boolean isApplicable(FunctionCallGraph fgraph) {
+               return OptimizerUtils.FEDERATED_COMPILATION;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the program.
+        *
+        * @param prog       dml program
+        * @param fgraph     function call graph
+        * @param fcallSizes function call size infos
+        * @return false since the function call graph never has to be rebuilt
+        */
+       @Override
+       public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph 
fgraph, FunctionCallSizeInfo fcallSizes) {
+               rewriteStatementBlocks(prog.getStatementBlocks());
+               return false;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the statement 
blocks.
+        * The method calls the contained statement blocks recursively.
+        *
+        * @param sbs   list of statement blocks
+        * @return list of statement blocks with the federated output value 
updated for each hop
+        */
+       public ArrayList<StatementBlock> 
rewriteStatementBlocks(List<StatementBlock> sbs) {
+               ArrayList<StatementBlock> rewrittenStmBlocks = new 
ArrayList<>();
+               for ( StatementBlock stmBlock : sbs )
+                       
rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+               return rewrittenStmBlocks;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the statement 
blocks.
+        * The method calls the contained statement blocks recursively.
+        *
+        * @param sb    statement block
+        * @return list of statement blocks with the federated output value 
updated for each hop
+        */
+       public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock 
sb) {
+               if ( sb instanceof WhileStatementBlock)
+                       return rewriteWhileStatementBlock((WhileStatementBlock) 
sb);
+               else if ( sb instanceof IfStatementBlock)
+                       return rewriteIfStatementBlock((IfStatementBlock) sb);
+               else if ( sb instanceof ForStatementBlock){
+                       // This also includes ParForStatementBlocks
+                       return rewriteForStatementBlock((ForStatementBlock) sb);
+               }
+               else if ( sb instanceof FunctionStatementBlock)
+                       return 
rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+               else {
+                       // StatementBlock type (no subclass)
+                       selectFederatedExecutionPlan(sb.getHops());
+               }
+               return new ArrayList<>(Collections.singletonList(sb));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+               Hop whilePredicateHop = whileSB.getPredicateHops();
+               selectFederatedExecutionPlan(whilePredicateHop);
+               for ( Statement stm : whileSB.getStatements() ){
+                       WhileStatement whileStm = (WhileStatement) stm;
+                       
whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(whileSB));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteIfStatementBlock(IfStatementBlock ifSB){
+               selectFederatedExecutionPlan(ifSB.getPredicateHops());
+               for ( Statement statement : ifSB.getStatements() ){
+                       IfStatement ifStatement = (IfStatement) statement;
+                       
ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+                       
ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(ifSB));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteForStatementBlock(ForStatementBlock forSB){
+               selectFederatedExecutionPlan(forSB.getFromHops());
+               selectFederatedExecutionPlan(forSB.getToHops());
+               selectFederatedExecutionPlan(forSB.getIncrementHops());
+               for ( Statement statement : forSB.getStatements() ){
+                       ForStatement forStatement = ((ForStatement)statement);
+                       
forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(forSB));
+       }
+
+       private ArrayList<StatementBlock> 
rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+               for ( Statement statement : funcSB.getStatements() ){
+                       FunctionStatement funcStm = (FunctionStatement) 
statement;
+                       
funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(funcSB));
+       }
+
+       /**
+        * Sets FederatedOutput field of all hops in DAG starting from given 
root.
+        * The FederatedOutput chosen for root is the minimum cost HopRel found 
in memo table for the given root.
+        * The FederatedOutput values chosen for the inputs to the root are 
chosen based on the input dependencies.
+        * @param root hop for which FederatedOutput needs to be set
+        */
+       private void setFinalFedout(Hop root){
+               HopRel optimalRootHopRel = 
hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+                       .orElseThrow(() -> new DMLException("Hop root " + root 
+ " has no feasible federated output alternatives"));
+               setFinalFedout(root, optimalRootHopRel);
+       }
+
+       /**
+        * Update the FederatedOutput value and cost based on information 
stored in given rootHopRel.
+        * @param root hop for which FederatedOutput is set
+        * @param rootHopRel from which FederatedOutput value and cost is 
retrieved
+        */
+       private void setFinalFedout(Hop root, HopRel rootHopRel){
+               updateFederatedOutput(root, rootHopRel);
+               visitInputDependency(rootHopRel);
+       }
+
+       /**
+        * Sets FederatedOutput value for each of the inputs of rootHopRel
+        * @param rootHopRel which has its input values updated
+        */
+       private void visitInputDependency(HopRel rootHopRel){
+               List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+               for ( HopRel input : hopRelInputs )
+                       setFinalFedout(input.getHopRef(), input);
+       }
+
+       /**
+        * Updates FederatedOutput value and cost estimate based on 
updateHopRel values.
+        * @param root which has its values updated
+        * @param updateHopRel from which the values are retrieved
+        */
+       private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+               root.setFederatedOutput(updateHopRel.getFederatedOutput());
+               root.setFederatedCost(updateHopRel.getCostObject());
+       }
+
+       /**
+        * Select federated execution plan for every Hop in the DAG starting 
from given roots.
+        * The cost estimates of the hops are also updated when FederatedOutput 
is updated in the hops.
+        * @param roots starting point for going through the Hop DAG to update 
the FederatedOutput fields.
+        */
+       private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+               for ( Hop root : roots ){

Review comment:
       Fix the formatting and remove the {

##########
File path: src/main/java/org/apache/sysds/hops/cost/HopRel.java
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of 
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated 
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal 
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+       protected Hop hopRef;
+       protected FEDInstruction.FederatedOutput fedOut;
+       protected FederatedCost cost;
+       protected Set<Long> costPointerSet = new HashSet<>();
+       protected List<HopRel> inputDependency = new ArrayList<>();
+
+       /**
+        * Constructs a HopRel with input dependency and cost estimate based on 
entries in hopRelMemo.
+        * @param associatedHop hop associated with this HopRel
+        * @param fedOut FederatedOutput value assigned to this HopRel
+        * @param hopRelMemo memo table storing other HopRels including the 
inputs of associatedHop
+        */
+       public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
Map<Long, List<HopRel>> hopRelMemo){
+               hopRef = associatedHop;
+               this.fedOut = fedOut;
+               setInputDependency(hopRelMemo);
+               cost = new FederatedCostEstimator().costEstimate(this, 
hopRelMemo);
+       }
+
+       /**
+        * Adds hopID to set of hops pointing to this HopRel.
+        * By storing the hopID it can later be determined if the cost
+        * stored in this HopRel is already used as input cost in another 
HopRel.
+        * @param hopID added to set of stored cost pointers
+        */
+       public void addCostPointer(long hopID){
+               costPointerSet.add(hopID);
+       }
+
+       /**
+        * Checks if another Hop is refering to this HopRel in memo table.
+        * A reference from a HopRel with same Hop ID is allowed, so this
+        * ID is ignored when checking references.
+        * @param currentHopID to ignore when checking references
+        * @return true if another Hop refers to this HopRel in memo table
+        */
+       public boolean existingCostPointer(long currentHopID){
+               if ( costPointerSet.contains(currentHopID) )
+                       return costPointerSet.size() > 1;
+               else return costPointerSet.size() > 0;
+       }
+
+       public boolean hasLocalOutput(){
+               return fedOut == FederatedOutput.LOUT;
+       }
+
+       public boolean hasFederatedOutput(){
+               return fedOut == FederatedOutput.FOUT;
+       }
+
+       public FederatedOutput getFederatedOutput(){
+               return fedOut;
+       }
+
+       public List<HopRel> getInputDependency(){
+               return inputDependency;
+       }
+
+       public Hop getHopRef(){
+               return hopRef;
+       }
+
+       /**
+        * Returns FOUT HopRel for given hop found in hopRelMemo or returns 
null if HopRel not found.
+        * @param hop to look for in hopRelMemo
+        * @param hopRelMemo memo table storing HopRels
+        * @return FOUT HopRel found in hopRelMemo
+        */
+       private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> 
hopRelMemo){
+               return 
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);

Review comment:
       Fancy one-liner.

##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();
+                       double inputTransferCost = 
inputTransferCostEstimate(hasFederatedInput, root);
+                       double computingCost = 
ComputeCost.getHOPComputeCost(root.hopRef);
+                       if ( hasFederatedInput ){
+                               //Find the number of inputs that has FOUT set.
+                               int numWorkers = 
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+                               //divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               //the number of parallel processes at each 
worker multiplied by the FLOPS of each process
+                               //This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               //and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
+                       //Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       //If the root is a federated DataOp, the data is forced 
to the coordinator even if no input is LOUT
+                       double outputTransferCost = ( root.hasLocalOutput() && 
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+                               
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+                       double readCost = 
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+               }
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hopRel for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, 
HopRel root){
+               if ( hasFederatedInput )
+                       return root.inputDependency.stream()
+                               .filter(input -> 
(root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : 
input.hasLocalOutput() )
+                               .mapToDouble(in -> 
in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+                               .map(inMem -> inMem/ 
WORKER_NETWORK_BANDWIDTH_BYTES_PS)

Review comment:
       Same argument here: don't divide each element, divide the sum instead.
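
A minimal sketch of what this suggestion could look like, assuming the stream chain ends in a `.sum()` (the quoted hunk is cut off before that point). `TransferCostSketch`, `dividePerElement`, `divideSum`, and `NETWORK_BANDWIDTH_BYTES_PS` are hypothetical names for illustration, standing in for the PR's `WORKER_NETWORK_BANDWIDTH_BYTES_PS` and the memory estimates of the filtered inputs:

```java
import java.util.stream.DoubleStream;

class TransferCostSketch {
    // Hypothetical stand-in for WORKER_NETWORK_BANDWIDTH_BYTES_PS; the value is assumed.
    static final double NETWORK_BANDWIDTH_BYTES_PS = 1024 * 1024 * 1024;

    // Shape of the code under review: one division per input, then a sum.
    static double dividePerElement(double[] inputMemEstimates) {
        return DoubleStream.of(inputMemEstimates)
            .map(inMem -> inMem / NETWORK_BANDWIDTH_BYTES_PS)
            .sum();
    }

    // Suggested shape: sum the memory estimates first, then divide once.
    static double divideSum(double[] inputMemEstimates) {
        return DoubleStream.of(inputMemEstimates).sum() / NETWORK_BANDWIDTH_BYTES_PS;
    }

    public static void main(String[] args) {
        double[] mem = {8e6, 1.6e7, 4e6};
        System.out.println(dividePerElement(mem));
        System.out.println(divideSum(mem));
    }
}
```

Both variants compute the same estimate up to floating-point rounding; the second drops the extra `map` stage and performs a single division.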

##########
File path: src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
##########
@@ -196,6 +194,90 @@ private FederatedCost costEstimate(Hop root){
                }
        }
 
+       /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       //the input cost is included the first time the input 
hop is used
+                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();

Review comment:
       If you don't do that, I would suggest that you define a method to 
use in the mapToDouble call instead of the lambda function defined here.
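
A rough sketch of the named-method variant, under stated assumptions: `Node` is a hypothetical stand-in for `HopRel` that models only the cost-pointer bookkeeping the lambda touches, and `inputCost` is an illustrative helper name, not something from the PR. The memo table and the other cost terms are left out:

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

class InputCostSketch {
    // Hypothetical stand-in for HopRel; only the fields needed for this example.
    static class Node {
        final long id;
        final double ownCost;
        final List<Node> inputs = new ArrayList<>();
        final Set<Long> costPointers = new HashSet<>();

        Node(long id, double ownCost) {
            this.id = id;
            this.ownCost = ownCost;
        }
    }

    static double costEstimate(Node root) {
        // The multi-line lambda is reduced to a call to a named helper.
        double inputCosts = root.inputs.stream()
            .mapToDouble(in -> inputCost(in, root.id))
            .sum();
        return root.ownCost + inputCosts;
    }

    // Cost of an input is counted the first time a given parent uses it;
    // further uses by the same parent add zero.
    static double inputCost(Node in, long parentId) {
        double cost = in.costPointers.contains(parentId) ? 0 : costEstimate(in);
        in.costPointers.add(parentId);
        return cost;
    }

    public static void main(String[] args) {
        Node shared = new Node(1L, 5.0);
        Node root = new Node(2L, 1.0);
        root.inputs.add(shared);
        root.inputs.add(shared);
        System.out.println(costEstimate(root));
    }
}
```

Pulling the body into a named method keeps the stream pipeline to one line per stage and gives the side effect on the cost pointers a name that can be documented and tested on its own.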




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

