This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 235a165  [SYSTEMDS-3018] Federated Planner with Memo Table
235a165 is described below

commit 235a16530d3e5047d7a45662a1f70fa938ead814
Author: sebwrede <[email protected]>
AuthorDate: Tue Aug 24 12:50:04 2021 +0200

    [SYSTEMDS-3018] Federated Planner with Memo Table
    
    This commit:
    (1) Change the federated plan rewriter into an IPA pass that processes an entire DML program.
    (2) Add a basic HopRel class.
    (3) Add HopRel-based cost estimation to FederatedCostEstimator.
    
    Closes #1395.
---
 src/main/java/org/apache/sysds/hops/DataOp.java    |  16 +-
 src/main/java/org/apache/sysds/hops/Hop.java       |  20 +-
 .../sysds/hops/cost/FederatedCostEstimator.java    | 116 ++++++--
 .../java/org/apache/sysds/hops/cost/HopRel.java    | 218 ++++++++++++++
 .../sysds/hops/ipa/InterProceduralAnalysis.java    |   4 +-
 .../hops/rewrite/IPAPassRewriteFederatedPlan.java  | 314 +++++++++++++++++++++
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   1 -
 .../hops/rewrite/RewriteFederatedExecution.java    | 123 +-------
 .../rewrite/RewriteFederatedStatementBlocks.java   |  66 -----
 .../fedplanning/FederatedCostEstimatorTest.java    |  26 +-
 10 files changed, 671 insertions(+), 233 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java 
b/src/main/java/org/apache/sysds/hops/DataOp.java
index 9bc4607..548417d 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -364,7 +364,8 @@ public class DataOp extends Hop {
                return( _op == OpOpData.PERSISTENTREAD || _op == 
OpOpData.PERSISTENTWRITE );
        }
 
-       public boolean isFederatedData(){
+       @Override
+       public boolean isFederatedDataOp(){
                return _op == OpOpData.FEDERATED;
        }
 
@@ -496,17 +497,10 @@ public class DataOp extends Hop {
                        
                        _etype = letype;
                }
-               
-               return _etype;
-       }
 
-       /**
-        * True if execution is federated, if output is federated, or if 
OpOpData is federated.
-        * @return true if federated
-        */
-       @Override
-       public boolean isFederated() {
-               return super.isFederated() || getOp() == OpOpData.FEDERATED;
+               updateETFed();
+
+               return _etype;
        }
 
        @Override
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index 45fc3af..a25cf10 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -865,15 +865,19 @@ public abstract class Hop implements ParseInfo {
        }
 
        /**
-        * Update the execution type if input is federated and federated 
compilation is activated.
-        * Federated compilation is activated in OptimizerUtils.
+        * Update the execution type if input is federated.
         * This method only has an effect if FEDERATED_COMPILATION is activated.
+        * Federated compilation is activated in OptimizerUtils.
         */
        protected void updateETFed(){
-               if ( _federatedOutput.isForced() )
+               if ( someInputFederated() || isFederatedDataOp() )
                        _etype = ExecType.FED;
        }
-       
+
+       /**
+        * Checks if ExecType is federated.
+        * @return true if ExecType is federated
+        */
        public boolean isFederated(){
                return getExecType() == ExecType.FED;
        }
@@ -882,6 +886,14 @@ public abstract class Hop implements ParseInfo {
                return getInput().stream().anyMatch(Hop::hasFederatedOutput);
        }
 
+       /**
+        * Checks if the hop is a DataOp with federated data.
+        * @return true if hop is a federated DataOp
+        */
+       public boolean isFederatedDataOp(){
+               return false;
+       }
+
        public ArrayList<Hop> getParent() {
                return _parent;
        }
diff --git 
a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java 
b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
index 3e2f994..7089ed8 100644
--- a/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
+++ b/src/main/java/org/apache/sysds/hops/cost/FederatedCostEstimator.java
@@ -33,6 +33,8 @@ import org.apache.sysds.parser.WhileStatement;
 import org.apache.sysds.parser.WhileStatementBlock;
 
 import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Cost estimator for federated executions with methods and constants for 
going through DML programs to estimate costs.
@@ -41,7 +43,7 @@ public class FederatedCostEstimator {
        public int DEFAULT_MEMORY_ESTIMATE = 8;
        public int DEFAULT_ITERATION_NUMBER = 15;
        public double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024*1024*1024; 
//Default network bandwidth in bytes per second
-       public double WORKER_COMPUTE_BANDWITH_FLOPS = 2.5*1024*1024*1024; 
//Default compute bandwidth in FLOPS
+       public double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5*1024*1024*1024; 
//Default compute bandwidth in FLOPS
        public double WORKER_DEGREE_OF_PARALLELISM = 8; //Default number of 
parallel processes for workers
        public double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5*1024*1024*1024; 
//Default read bandwidth in bytes per second
 
@@ -154,34 +156,30 @@ public class FederatedCostEstimator {
         * @param root of Hop DAG for which cost is estimated
         * @return cost estimation of Hop DAG starting from given root
         */
-       private FederatedCost costEstimate(Hop root){
+       public FederatedCost costEstimate(Hop root){
                if ( root.federatedCostInitialized() )
                        return root.getFederatedCost();
                else {
                        // If no input has FOUT, the root will be processed by 
the coordinator
                        boolean hasFederatedInput = root.someInputFederated();
-                       //the input cost is included the first time the input 
hop is used
-                       //for additional usage, the additional cost is zero 
(disregarding potential read cost)
+                       // The input cost is included the first time the input 
hop is used.
+                       // For additional usage, the additional cost is zero 
(disregarding potential read cost).
                        double inputCosts = root.getInput().stream()
                                .mapToDouble( in -> 
in.federatedCostInitialized() ? 0 : costEstimate(in).getTotal() )
                                .sum();
-                       double inputTransferCost = hasFederatedInput ? 
root.getInput().stream()
-                               .filter(Hop::hasLocalOutput)
-                               .mapToDouble(in -> 
in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
-                               .map(inMem -> inMem/ 
WORKER_NETWORK_BANDWIDTH_BYTES_PS)
-                               .sum() : 0;
+                       double inputTransferCost = 
inputTransferCostEstimate(hasFederatedInput, root);
                        double computingCost = 
ComputeCost.getHOPComputeCost(root);
                        if ( hasFederatedInput ){
-                               //Find the number of inputs that has FOUT set.
+                               // Find the number of inputs that has FOUT set.
                                int numWorkers = 
(int)root.getInput().stream().filter(Hop::hasFederatedOutput).count();
-                               //divide memory usage by the number of workers 
the computation would be split to multiplied by
-                               //the number of parallel processes at each 
worker multiplied by the FLOPS of each process
-                               //This assumes uniform workload among the 
workers with FOUT data involved in the operation
-                               //and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
-                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
-                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM*WORKER_COMPUTE_BANDWITH_FLOPS);
-                       //Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
-                       double outputTransferCost = ( root.hasLocalOutput() && 
hasFederatedInput ) ?
+                               // Divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               // the number of parallel processes at each 
worker multiplied by the FLOPS of each process.
+                               // This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               // and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM* WORKER_COMPUTE_BANDWIDTH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM* WORKER_COMPUTE_BANDWIDTH_FLOPS);
+                       // Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       double outputTransferCost = ( root.hasLocalOutput() && 
(hasFederatedInput || root.isFederatedDataOp()) ) ?
                                
root.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
                        double readCost = 
root.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
 
@@ -197,6 +195,88 @@ public class FederatedCostEstimator {
        }
 
        /**
+        * Return cost estimate in bytes of Hop DAG starting from given root 
HopRel.
+        * @param root HopRel of Hop DAG for which cost is estimated
+        * @param hopRelMemo memo table of HopRels for calculating input costs
+        * @return cost estimation of Hop DAG starting from given root HopRel
+        */
+       public FederatedCost costEstimate(HopRel root, Map<Long, List<HopRel>> 
hopRelMemo){
+               // Check if root is in memo table.
+               if ( hopRelMemo.containsKey(root.hopRef.getHopID())
+                       && 
hopRelMemo.get(root.hopRef.getHopID()).stream().anyMatch(h -> h.fedOut == 
root.fedOut) ){
+                       return root.getCostObject();
+               }
+               else {
+                       // If no input has FOUT, the root will be processed by 
the coordinator
+                       boolean hasFederatedInput = 
root.inputDependency.stream().anyMatch(in -> in.hopRef.hasFederatedOutput());
+                       // The input cost is included the first time the input 
hop is used.
+                       // For additional usage, the additional cost is zero 
(disregarding potential read cost).
+                       double inputCosts = root.inputDependency.stream()
+                               .mapToDouble( in -> {
+                                       double inCost = 
in.existingCostPointer(root.hopRef.getHopID()) ?
+                                               0 : costEstimate(in, 
hopRelMemo).getTotal();
+                                       
in.addCostPointer(root.hopRef.getHopID());
+                                       return inCost;
+                               } )
+                               .sum();
+                       double inputTransferCost = 
inputTransferCostEstimate(hasFederatedInput, root);
+                       double computingCost = 
ComputeCost.getHOPComputeCost(root.hopRef);
+                       if ( hasFederatedInput ){
+                               // Find the number of inputs that has FOUT set.
+                               int numWorkers = 
(int)root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count();
+                               // Divide memory usage by the number of workers 
the computation would be split to multiplied by
+                               // the number of parallel processes at each 
worker multiplied by the FLOPS of each process
+                               // This assumes uniform workload among the 
workers with FOUT data involved in the operation
+                               // and assumes that the degree of parallelism 
and compute bandwidth are equal for all workers
+                               computingCost = computingCost / 
(numWorkers*WORKER_DEGREE_OF_PARALLELISM* WORKER_COMPUTE_BANDWIDTH_FLOPS);
+                       } else computingCost = computingCost / 
(WORKER_DEGREE_OF_PARALLELISM* WORKER_COMPUTE_BANDWIDTH_FLOPS);
+                       // Calculate output transfer cost if the operation is 
computed at federated workers and the output is forced to the coordinator
+                       // If the root is a federated DataOp, the data is 
forced to the coordinator even if no input is LOUT
+                       double outputTransferCost = ( root.hasLocalOutput() && 
(hasFederatedInput || root.hopRef.isFederatedDataOp()) ) ?
+                               
root.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_NETWORK_BANDWIDTH_BYTES_PS : 0;
+                       double readCost = 
root.hopRef.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / 
WORKER_READ_BANDWIDTH_BYTES_PS;
+
+                       return new FederatedCost(readCost, inputTransferCost, 
outputTransferCost, computingCost, inputCosts);
+               }
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hopRel for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, 
HopRel root){
+               if ( hasFederatedInput )
+                       return root.inputDependency.stream()
+                               .filter(input -> 
(root.hopRef.isFederatedDataOp()) ? input.hasFederatedOutput() : 
input.hasLocalOutput() )
+                               .mapToDouble(in -> 
in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+                               .sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+               else return 0;
+       }
+
+       /**
+        * Returns input transfer cost estimate.
+        * The input transfer cost estimate is based on the memory estimate of 
LOUT when some input is FOUT
+        * except if root is a federated DataOp, since all input for this has 
to be at the coordinator.
+        * When no input is FOUT, the input transfer cost is always 0.
+        * @param hasFederatedInput true if root has any FOUT input
+        * @param root hop for which cost is estimated
+        * @return input transfer cost estimate
+        */
+       private double inputTransferCostEstimate(boolean hasFederatedInput, Hop 
root){
+               if ( hasFederatedInput )
+                       return root.getInput().stream()
+                               .filter(input -> (root.isFederatedDataOp()) ? 
input.hasFederatedOutput() : input.hasLocalOutput() )
+                               .mapToDouble(in -> 
in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
+                               .sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
+               else return 0;
+       }
+
+       /**
         * Prints costs and information about root for debugging purposes
         * @param root hop for which information is printed
         */
diff --git a/src/main/java/org/apache/sysds/hops/cost/HopRel.java 
b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
new file mode 100644
index 0000000..6191a6c
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/cost/HopRel.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.cost;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * HopRel provides a representation of the relation between a hop, the cost of 
setting a given FederatedOutput value,
+ * and the input dependency with the given FederatedOutput value.
+ * The HopRel class is used when building and selecting an optimal federated 
execution plan in IPAPassRewriteFederatedPlan.
+ * The input dependency is needed to hold the valid and optimal 
FederatedOutput values for the inputs.
+ */
+public class HopRel {
+       protected final Hop hopRef;
+       protected final FEDInstruction.FederatedOutput fedOut;
+       protected final FederatedCost cost;
+       protected final Set<Long> costPointerSet = new HashSet<>();
+       protected final List<HopRel> inputDependency = new ArrayList<>();
+
+       /**
+        * Constructs a HopRel with input dependency and cost estimate based on 
entries in hopRelMemo.
+        * @param associatedHop hop associated with this HopRel
+        * @param fedOut FederatedOutput value assigned to this HopRel
+        * @param hopRelMemo memo table storing other HopRels including the 
inputs of associatedHop
+        */
+       public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, 
Map<Long, List<HopRel>> hopRelMemo){
+               hopRef = associatedHop;
+               this.fedOut = fedOut;
+               setInputDependency(hopRelMemo);
+               cost = new FederatedCostEstimator().costEstimate(this, 
hopRelMemo);
+       }
+
+       /**
+        * Adds hopID to set of hops pointing to this HopRel.
+        * By storing the hopID it can later be determined if the cost
+        * stored in this HopRel is already used as input cost in another 
HopRel.
+        * @param hopID added to set of stored cost pointers
+        */
+       public void addCostPointer(long hopID){
+               costPointerSet.add(hopID);
+       }
+
+       /**
+        * Checks if another Hop is refering to this HopRel in memo table.
+        * A reference from a HopRel with same Hop ID is allowed, so this
+        * ID is ignored when checking references.
+        * @param currentHopID to ignore when checking references
+        * @return true if another Hop refers to this HopRel in memo table
+        */
+       public boolean existingCostPointer(long currentHopID){
+               if ( costPointerSet.contains(currentHopID) )
+                       return costPointerSet.size() > 1;
+               else return costPointerSet.size() > 0;
+       }
+
+       public boolean hasLocalOutput(){
+               return fedOut == FederatedOutput.LOUT;
+       }
+
+       public boolean hasFederatedOutput(){
+               return fedOut == FederatedOutput.FOUT;
+       }
+
+       public FederatedOutput getFederatedOutput(){
+               return fedOut;
+       }
+
+       public List<HopRel> getInputDependency(){
+               return inputDependency;
+       }
+
+       public Hop getHopRef(){
+               return hopRef;
+       }
+
+       /**
+        * Returns FOUT HopRel for given hop found in hopRelMemo or returns 
null if HopRel not found.
+        * @param hop to look for in hopRelMemo
+        * @param hopRelMemo memo table storing HopRels
+        * @return FOUT HopRel found in hopRelMemo
+        */
+       private HopRel getFOUTHopRel(Hop hop, Map<Long, List<HopRel>> 
hopRelMemo){
+               return 
hopRelMemo.get(hop.getHopID()).stream().filter(in->in.fedOut==FederatedOutput.FOUT).findFirst().orElse(null);
+       }
+
+       /**
+        * Get the HopRel with minimum cost for given hop
+        * @param hopRelMemo memo table storing HopRels
+        * @param input hop for which minimum cost HopRel is found
+        * @return HopRel with minimum cost for given hop
+        */
+       private HopRel getMinOfInput(Map<Long, List<HopRel>> hopRelMemo, Hop 
input){
+               return hopRelMemo.get(input.getHopID()).stream()
+                       .min(Comparator.comparingDouble(a -> a.cost.getTotal()))
+                       .orElseThrow(() -> new DMLException("No element in Memo 
Table found for input"));
+       }
+
+       /**
+        * Set valid and optimal input dependency for this HopRel as a field.
+        * @param hopRelMemo memo table storing input HopRels
+        */
+       private void setInputDependency(Map<Long, List<HopRel>> hopRelMemo){
+               if (hopRef.getInput() != null && hopRef.getInput().size() > 0) {
+                       if ( fedOut == FederatedOutput.FOUT && 
!hopRef.isFederatedDataOp() ) {
+                               int lowestFOUTIndex = 0;
+                               HopRel lowestFOUTHopRel = 
getFOUTHopRel(hopRef.getInput().get(0), hopRelMemo);
+                               for(int i = 1; i < hopRef.getInput().size(); 
i++) {
+                                       Hop input = hopRef.getInput(i);
+                                       HopRel foutHopRel = 
getFOUTHopRel(input, hopRelMemo);
+                                       if(lowestFOUTHopRel == null) {
+                                               lowestFOUTHopRel = foutHopRel;
+                                               lowestFOUTIndex = i;
+                                       }
+                                       else if(foutHopRel != null) {
+                                               if(foutHopRel.getCost() < 
lowestFOUTHopRel.getCost()) {
+                                                       lowestFOUTHopRel = 
foutHopRel;
+                                                       lowestFOUTIndex = i;
+                                               }
+                                       }
+                               }
+
+                               HopRel[] inputHopRels = new 
HopRel[hopRef.getInput().size()];
+                               for(int i = 0; i < hopRef.getInput().size(); 
i++) {
+                                       if(i != lowestFOUTIndex) {
+                                               Hop input = hopRef.getInput(i);
+                                               inputHopRels[i] = 
getMinOfInput(hopRelMemo, input);
+                                       }
+                                       else {
+                                               inputHopRels[i] = 
lowestFOUTHopRel;
+                                       }
+                               }
+                               
inputDependency.addAll(Arrays.asList(inputHopRels));
+                       } else {
+                               inputDependency.addAll(
+                                       hopRef.getInput().stream()
+                                               .map(input -> 
getMinOfInput(hopRelMemo, input))
+                                               .collect(Collectors.toList()));
+                       }
+               }
+               validateInputDependency();
+       }
+
+       /**
+        * Throws exception if any input dependency is null.
+        * If any of the input dependencies are null, it is not possible to 
build a federated execution plan.
+        * If this null-state is not found here, an exception will be thrown at 
another difficult-to-debug place.
+        */
+       private void validateInputDependency(){
+               for ( int i = 0; i < inputDependency.size(); i++){
+                       if ( inputDependency.get(i) == null)
+                               throw new DMLException("HopRel input number " + 
i + " (" + hopRef.getInput(i) + ")"
+                                       + " is null for root: \n" + this);
+               }
+       }
+
+       /**
+        * Get total cost as double
+        * @return cost as double
+        */
+       public double getCost(){
+               return cost.getTotal();
+       }
+
+       /**
+        * Get cost object
+        * @return cost object
+        */
+       public FederatedCost getCostObject(){
+               return cost;
+       }
+
+       @Override
+       public String toString(){
+               StringBuilder strB = new StringBuilder();
+               strB.append(this.getClass().getSimpleName());
+               strB.append(" {HopID: ");
+               strB.append(hopRef.getHopID());
+               strB.append(", Opcode: ");
+               strB.append(hopRef.getOpString());
+               strB.append(", FedOut: ");
+               strB.append(fedOut);
+               strB.append(", Cost: ");
+               strB.append(cost);
+               strB.append(", Number of inputs: ");
+               strB.append(inputDependency.size());
+               strB.append("}");
+               return strB.toString();
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java 
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index 8224192..b0597eb 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -34,6 +34,7 @@ import org.apache.sysds.hops.HopsException;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.hops.rewrite.IPAPassRewriteFederatedPlan;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.DataIdentifier;
@@ -243,7 +244,8 @@ public class InterProceduralAnalysis
                List<IPAPass> fpasses = Arrays.asList(
                        new IPAPassRemoveUnusedFunctions(),
                        new IPAPassCompressionWorkloadAnalysis(), // 
workload-aware compression
-                       new IPAPassApplyStaticAndDynamicHopRewrites());  
//split after compress
+                       new IPAPassApplyStaticAndDynamicHopRewrites(),  //split 
after compress
+                       new IPAPassRewriteFederatedPlan());
                for(IPAPass pass : fpasses)
                        if( pass.isApplicable(graph2) )
                                pass.rewriteProgram(_prog, graph2, null);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
new file mode 100644
index 0000000..cbc21cf
--- /dev/null
+++ b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.AggUnaryOp;
+import org.apache.sysds.hops.BinaryOp;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.ReorgOp;
+import org.apache.sysds.hops.TernaryOp;
+import org.apache.sysds.hops.cost.HopRel;
+import org.apache.sysds.hops.ipa.FunctionCallGraph;
+import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
+import org.apache.sysds.hops.ipa.IPAPass;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This rewrite generates a federated execution plan by estimating and setting costs and the FederatedOutput values of
+ * all relevant hops in the DML program.
+ * The rewrite is only applied if federated compilation is activated in OptimizerUtils.
+ */
+public class IPAPassRewriteFederatedPlan extends IPAPass {
+
+       /**
+        * Memo table from hop ID to the HopRel alternatives (one per feasible FederatedOutput value) of that hop.
+        * It is an instance field, cleared at the start of rewriteProgram, so that it is scoped to one run of this
+        * pass: a static table was never cleared, grew for the lifetime of the JVM, and made visitFedPlanHop skip
+        * hops seen in an earlier run, reusing their stale alternatives instead of re-estimating them.
+        */
+       private final Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();
+
+       /**
+        * Indicates if an IPA pass is applicable for the current configuration.
+        * The configuration depends on OptimizerUtils.FEDERATED_COMPILATION.
+        *
+        * @param fgraph function call graph
+        * @return true if federated compilation is activated.
+        */
+       @Override
+       public boolean isApplicable(FunctionCallGraph fgraph) {
+               return OptimizerUtils.FEDERATED_COMPILATION;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the program.
+        *
+        * @param prog       dml program
+        * @param fgraph     function call graph
+        * @param fcallSizes function call size infos
+        * @return false since the function call graph never has to be rebuilt
+        */
+       @Override
+       public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
+               // Plan from an empty memo table so that alternatives estimated in an earlier run are never reused.
+               hopRelMemo.clear();
+               rewriteStatementBlocks(prog.getStatementBlocks());
+               return false;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the statement blocks.
+        * The method calls the contained statement blocks recursively.
+        *
+        * @param sbs   list of statement blocks
+        * @return list of statement blocks with the federated output value updated for each hop
+        */
+       public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs) {
+               ArrayList<StatementBlock> rewrittenStmBlocks = new ArrayList<>();
+               for ( StatementBlock stmBlock : sbs )
+                       rewrittenStmBlocks.addAll(rewriteStatementBlock(stmBlock));
+               return rewrittenStmBlocks;
+       }
+
+       /**
+        * Estimates cost and selects a federated execution plan
+        * by setting the federated output value of each hop in the statement block.
+        * The method calls the contained statement blocks recursively.
+        *
+        * @param sb    statement block
+        * @return list of statement blocks with the federated output value updated for each hop
+        */
+       public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb) {
+               if ( sb instanceof WhileStatementBlock)
+                       return rewriteWhileStatementBlock((WhileStatementBlock) sb);
+               else if ( sb instanceof IfStatementBlock)
+                       return rewriteIfStatementBlock((IfStatementBlock) sb);
+               else if ( sb instanceof ForStatementBlock){
+                       // This also includes ParForStatementBlocks
+                       return rewriteForStatementBlock((ForStatementBlock) sb);
+               }
+               else if ( sb instanceof FunctionStatementBlock)
+                       return rewriteFunctionStatementBlock((FunctionStatementBlock) sb);
+               else {
+                       // StatementBlock type (no subclass)
+                       selectFederatedExecutionPlan(sb.getHops());
+               }
+               return new ArrayList<>(Collections.singletonList(sb));
+       }
+
+       private ArrayList<StatementBlock> rewriteWhileStatementBlock(WhileStatementBlock whileSB){
+               Hop whilePredicateHop = whileSB.getPredicateHops();
+               selectFederatedExecutionPlan(whilePredicateHop);
+               for ( Statement stm : whileSB.getStatements() ){
+                       WhileStatement whileStm = (WhileStatement) stm;
+                       whileStm.setBody(rewriteStatementBlocks(whileStm.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(whileSB));
+       }
+
+       private ArrayList<StatementBlock> rewriteIfStatementBlock(IfStatementBlock ifSB){
+               selectFederatedExecutionPlan(ifSB.getPredicateHops());
+               for ( Statement statement : ifSB.getStatements() ){
+                       IfStatement ifStatement = (IfStatement) statement;
+                       ifStatement.setIfBody(rewriteStatementBlocks(ifStatement.getIfBody()));
+                       ifStatement.setElseBody(rewriteStatementBlocks(ifStatement.getElseBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(ifSB));
+       }
+
+       private ArrayList<StatementBlock> rewriteForStatementBlock(ForStatementBlock forSB){
+               // The from/to/increment hops may be absent (null); selectFederatedExecutionPlan skips null roots.
+               selectFederatedExecutionPlan(forSB.getFromHops());
+               selectFederatedExecutionPlan(forSB.getToHops());
+               selectFederatedExecutionPlan(forSB.getIncrementHops());
+               for ( Statement statement : forSB.getStatements() ){
+                       ForStatement forStatement = ((ForStatement)statement);
+                       forStatement.setBody(rewriteStatementBlocks(forStatement.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(forSB));
+       }
+
+       private ArrayList<StatementBlock> rewriteFunctionStatementBlock(FunctionStatementBlock funcSB){
+               for ( Statement statement : funcSB.getStatements() ){
+                       FunctionStatement funcStm = (FunctionStatement) statement;
+                       funcStm.setBody(rewriteStatementBlocks(funcStm.getBody()));
+               }
+               return new ArrayList<>(Collections.singletonList(funcSB));
+       }
+
+       /**
+        * Sets FederatedOutput field of all hops in DAG starting from given root.
+        * The FederatedOutput chosen for root is the minimum cost HopRel found in memo table for the given root.
+        * The FederatedOutput values chosen for the inputs to the root are chosen based on the input dependencies.
+        * @param root hop for which FederatedOutput needs to be set
+        */
+       private void setFinalFedout(Hop root){
+               HopRel optimalRootHopRel = hopRelMemo.get(root.getHopID()).stream().min(Comparator.comparingDouble(HopRel::getCost))
+                       .orElseThrow(() -> new DMLException("Hop root " + root + " has no feasible federated output alternatives"));
+               setFinalFedout(root, optimalRootHopRel);
+       }
+
+       /**
+        * Update the FederatedOutput value and cost based on information stored in given rootHopRel.
+        * @param root hop for which FederatedOutput is set
+        * @param rootHopRel from which FederatedOutput value and cost is retrieved
+        */
+       private void setFinalFedout(Hop root, HopRel rootHopRel){
+               updateFederatedOutput(root, rootHopRel);
+               visitInputDependency(rootHopRel);
+       }
+
+       /**
+        * Sets FederatedOutput value for each of the inputs of rootHopRel
+        * NOTE(review): inputs shared by several consumers are visited once per path, and the value set by the
+        * last visit wins; confirm that consumers cannot select conflicting alternatives for a shared input.
+        * @param rootHopRel which has its input values updated
+        */
+       private void visitInputDependency(HopRel rootHopRel){
+               List<HopRel> hopRelInputs = rootHopRel.getInputDependency();
+               for ( HopRel input : hopRelInputs )
+                       setFinalFedout(input.getHopRef(), input);
+       }
+
+       /**
+        * Updates FederatedOutput value and cost estimate based on updateHopRel values.
+        * @param root which has its values updated
+        * @param updateHopRel from which the values are retrieved
+        */
+       private void updateFederatedOutput(Hop root, HopRel updateHopRel){
+               root.setFederatedOutput(updateHopRel.getFederatedOutput());
+               root.setFederatedCost(updateHopRel.getCostObject());
+       }
+
+       /**
+        * Select federated execution plan for every Hop in the DAG starting from given roots.
+        * The cost estimates of the hops are also updated when FederatedOutput is updated in the hops.
+        * @param roots starting point for going through the Hop DAG to update the FederatedOutput fields;
+        *              a null list (a statement block without hops) is ignored.
+        */
+       private void selectFederatedExecutionPlan(ArrayList<Hop> roots){
+               if ( roots == null )
+                       return;
+               for ( Hop root : roots )
+                       selectFederatedExecutionPlan(root);
+       }
+
+       /**
+        * Select federated execution plan for every Hop in the DAG starting from given root.
+        * A null root is ignored, since hop getters of control statement blocks
+        * (e.g. ForStatementBlock.getIncrementHops) may return null for parts that are not present.
+        * @param root starting point for going through the Hop DAG to update the federatedOutput fields
+        */
+       private void selectFederatedExecutionPlan(Hop root){
+               if ( root == null )
+                       return;
+               visitFedPlanHop(root);
+               setFinalFedout(root);
+       }
+
+       /**
+        * Go through the Hop DAG and set the FederatedOutput field and cost estimate for each Hop from leaf to given currentHop.
+        * @param currentHop the Hop from which the DAG is visited
+        */
+       private void visitFedPlanHop(Hop currentHop){
+               // If the currentHop is in the hopRelMemo table, it means that it has been visited
+               if ( hopRelMemo.containsKey(currentHop.getHopID()) )
+                       return;
+               // If the currentHop has input, then the input should be visited depth-first
+               if ( currentHop.getInput() != null && currentHop.getInput().size() > 0 ){
+                       for ( Hop input : currentHop.getInput() )
+                               visitFedPlanHop(input);
+               }
+               // Put FOUT, LOUT, and None HopRels into the memo table
+               ArrayList<HopRel> hopRels = new ArrayList<>();
+               if ( isFedInstSupportedHop(currentHop) ){
+                       for ( FEDInstruction.FederatedOutput fedoutValue : FEDInstruction.FederatedOutput.values() )
+                               if ( isFedOutSupported(currentHop, fedoutValue) )
+                                       hopRels.add(new HopRel(currentHop,fedoutValue, hopRelMemo));
+               }
+               // NONE is only added as a fallback when no FOUT/LOUT alternative is supported
+               if ( hopRels.isEmpty() )
+                       hopRels.add(new HopRel(currentHop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
+               hopRelMemo.put(currentHop.getHopID(), hopRels);
+               currentHop.setVisited();
+       }
+
+       /**
+        * Checks if the instructions related to the given hop supports FOUT/LOUT processing.
+        * @param hop to check for federated support
+        * @return true if federated instructions related to hop supports FOUT/LOUT processing
+        */
+       private boolean isFedInstSupportedHop(Hop hop){
+               // Hop types whose federated instructions are treated as supporting FOUT/LOUT output
+               return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp
+                       || hop instanceof AggUnaryOp || hop instanceof TernaryOp || hop instanceof DataOp);
+       }
+
+       /**
+        * Checks if the associatedHop supports the given federated output value.
+        * NONE yields false here; visitFedPlanHop adds a NONE alternative only when nothing else is supported.
+        * @param associatedHop to check support of
+        * @param fedOut federated output value
+        * @return true if associatedHop supports fedOut
+        */
+       private boolean isFedOutSupported(Hop associatedHop, FEDInstruction.FederatedOutput fedOut){
+               switch(fedOut){
+                       case FOUT:
+                               return isFOUTSupported(associatedHop);
+                       case LOUT:
+                               return isLOUTSupported(associatedHop);
+                       case NONE:
+                               return false;
+                       default:
+                               return true;
+               }
+       }
+
+       /**
+        * Checks to see if the associatedHop supports FOUT.
+        * @param associatedHop for which FOUT support is checked
+        * @return true if FOUT is supported by the associatedHop
+        */
+       private boolean isFOUTSupported(Hop associatedHop){
+               // If the output of AggUnaryOp is a scalar, the operation cannot be FOUT
+               if ( associatedHop instanceof AggUnaryOp && associatedHop.isScalar() )
+                       return false;
+               // It can only be FOUT if at least one of the inputs are FOUT, except if it is a federated DataOp
+               if ( associatedHop.getInput().stream().noneMatch(
+                       input -> hopRelMemo.get(input.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput) )
+                       && !associatedHop.isFederatedDataOp() )
+                       return false;
+               return true;
+       }
+
+       /**
+        * Checks to see if the associatedHop supports LOUT.
+        * It supports LOUT if the output has no privacy constraints.
+        * @param associatedHop for which LOUT support is checked.
+        * @return true if LOUT is supported by the associatedHop
+        */
+       private boolean isLOUTSupported(Hop associatedHop){
+               return associatedHop.getPrivacy() == null || !associatedHop.getPrivacy().hasConstraints();
+       }
+}
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 04cdf32..2e3edb0 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -140,7 +140,6 @@ public class ProgramRewriter
                        }
                        if ( OptimizerUtils.FEDERATED_COMPILATION ) {
                                _dagRuleSet.add( new 
RewriteFederatedExecution() );
-                               _sbRuleSet.add( new 
RewriteFederatedStatementBlocks() );
                        }
                }
                
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
index e6a92ce..75fa735 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java
@@ -23,14 +23,8 @@ import org.apache.commons.lang3.tuple.Pair;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.sysds.api.DMLException;
-import org.apache.sysds.hops.AggBinaryOp;
-import org.apache.sysds.hops.AggUnaryOp;
-import org.apache.sysds.hops.BinaryOp;
-import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
-import org.apache.sysds.hops.ReorgOp;
-import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -40,8 +34,6 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
 import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
-import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -58,127 +50,24 @@ import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.EnumMap;
-import java.util.Map;
 import java.util.concurrent.Future;
 
 public class RewriteFederatedExecution extends HopRewriteRule {
+
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
                if ( roots == null )
                        return null;
                for ( Hop root : roots )
                        visitHop(root);
-
-               return selectFederatedExecutionPlan(roots);
+               return roots;
        }
 
        @Override public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus 
state) {
                return null;
        }
 
-       /**
-        * Select federated execution plan for every Hop in the DAG starting 
from given roots.
-        * @param roots starting point for going through the Hop DAG to update 
the FederatedOutput fields.
-        * @return the list of roots with updated FederatedOutput fields.
-        */
-       private static ArrayList<Hop> 
selectFederatedExecutionPlan(ArrayList<Hop> roots){
-               for (Hop root : roots){
-                       root.resetVisitStatus();
-               }
-               for ( Hop root : roots ){
-                       visitFedPlanHop(root);
-               }
-               return roots;
-       }
-
-       /**
-        * Go through the Hop DAG and set the FederatedOutput field for each 
Hop from leaf to given currentHop.
-        * @param currentHop the Hop from which the DAG is visited
-        */
-       private static void visitFedPlanHop(Hop currentHop){
-               if ( currentHop.isVisited() )
-                       return;
-               if ( currentHop.getInput() != null && 
currentHop.getInput().size() > 0 && !isFederatedDataOp(currentHop) ){
-                       // Depth first to get to the input
-                       for ( Hop input : currentHop.getInput() )
-                               visitFedPlanHop(input);
-               } else if ( isFederatedDataOp(currentHop) ) {
-                       // leaf federated node
-                       //TODO: This will block the cases where the federated 
DataOp is based on input that are also federated.
-                       // This means that the actual federated leaf nodes will 
never be reached.
-                       currentHop.setFederatedOutput(FederatedOutput.FOUT);
-               }
-               if ( ( isFedInstSupportedHop(currentHop) ) ){
-                       // The Hop can be FOUT or LOUT or None. Check utility 
of FOUT vs LOUT vs None.
-                       
currentHop.setFederatedOutput(getHighestUtilFedOut(currentHop));
-               }
-               else
-                       
currentHop.setFederatedOutput(FEDInstruction.FederatedOutput.NONE);
-               currentHop.setVisited();
-       }
-
-       /**
-        * Returns the FederatedOutput with the highest utility out of the 
valid FederatedOutput values.
-        * @param hop for which the utility is found
-        * @return the FederatedOutput value with highest utility for the given 
Hop
-        */
-       private static FederatedOutput getHighestUtilFedOut(Hop hop){
-               Map<FederatedOutput,Long> fedOutUtilMap = new 
EnumMap<>(FederatedOutput.class);
-               if ( isFOUTSupported(hop) )
-                       fedOutUtilMap.put(FederatedOutput.FOUT, getUtilFout());
-               if ( hop.getPrivacy() == null || (hop.getPrivacy() != null && 
!hop.getPrivacy().hasConstraints()) )
-                       fedOutUtilMap.put(FederatedOutput.LOUT, 
getUtilLout(hop));
-               fedOutUtilMap.put(FederatedOutput.NONE, 0L);
-
-               Map.Entry<FederatedOutput, Long> fedOutMax = 
Collections.max(fedOutUtilMap.entrySet(), Map.Entry.comparingByValue());
-               return fedOutMax.getKey();
-       }
-
-       /**
-        * Utility if hop is FOUT. This is a simple version where it always 
returns 1.
-        * @return utility if hop is FOUT
-        */
-       private static long getUtilFout(){
-               //TODO: Make better utility estimation
-               return 1;
-       }
-
-       /**
-        * Utility if hop is LOUT. This is a simple version only based on 
dimensions.
-        * @param hop for which utility is calculated
-        * @return utility if hop is LOUT
-        */
-       private static long getUtilLout(Hop hop){
-               //TODO: Make better utility estimation
-               return -(long)hop.getMemEstimate();
-       }
-
-       private static boolean isFedInstSupportedHop(Hop hop){
-
-               // Check that some input is FOUT, otherwise none of the fed 
instructions will run unless it is fedinit
-               if ( (!isFederatedDataOp(hop)) && 
hop.getInput().stream().noneMatch(Hop::hasFederatedOutput) )
-                       return false;
-
-               // The following operations are supported given that the above 
conditions have not returned already
-               return ( hop instanceof AggBinaryOp || hop instanceof BinaryOp 
|| hop instanceof ReorgOp
-                       || hop instanceof AggUnaryOp || hop instanceof 
TernaryOp || hop instanceof DataOp );
-       }
-
-       /**
-        * Checks to see if the associatedHop supports FOUT.
-        * @param associatedHop for which FOUT support is checked
-        * @return true if FOUT is supported by the associatedHop
-        */
-       private static boolean isFOUTSupported(Hop associatedHop){
-               // If the output of AggUnaryOp is a scalar, the operation 
cannot be FOUT
-               if ( associatedHop instanceof AggUnaryOp )
-                       return !associatedHop.isScalar();
-               return true;
-       }
-
-       private static void visitHop(Hop hop){
+       private void visitHop(Hop hop){
                if (hop.isVisited())
                        return;
 
@@ -206,7 +95,7 @@ public class RewriteFederatedExecution extends 
HopRewriteRule {
         * @hop hop for which privacy constraints are loaded
         */
        private static void loadFederatedPrivacyConstraints(Hop hop){
-               if ( isFederatedDataOp(hop) && hop.getPrivacy() == null){
+               if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){
                        try {
                                PrivacyConstraint privConstraint = 
unwrapPrivConstraint(sendPrivConstraintRequest(hop));
                                hop.setPrivacy(privConstraint);
@@ -238,10 +127,6 @@ public class RewriteFederatedExecution extends 
HopRewriteRule {
                return (PrivacyConstraint) privConstraintResponse.getData()[0];
        }
 
-       private static boolean isFederatedDataOp(Hop hop){
-               return hop instanceof DataOp && ((DataOp) 
hop).isFederatedData();
-       }
-
        /**
         * FederatedUDF for retrieving privacy constraint of data stored in 
file name.
         */
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
deleted file mode 100644
index 18b36d5..0000000
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedStatementBlocks.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.hops.rewrite;
-
-import org.apache.sysds.parser.StatementBlock;
-
-import java.util.Arrays;
-import java.util.List;
-
-public class RewriteFederatedStatementBlocks extends StatementBlockRewriteRule 
{
-
-       /**
-        * Indicates if the rewrite potentially splits dags, which is used
-        * for phase ordering of rewrites.
-        *
-        * @return true if dag splits are possible.
-        */
-       @Override public boolean createsSplitDag() {
-               return false;
-       }
-
-       /**
-        * Handle an arbitrary statement block. Specific type constraints have 
to be ensured
-        * within the individual rewrites. If a rewrite does not apply to 
individual blocks, it
-        * should simply return the input block.
-        *
-        * @param sb    statement block
-        * @param state program rewrite status
-        * @return list of statement blocks
-        */
-       @Override
-       public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, 
ProgramRewriteStatus state) {
-               return Arrays.asList(sb);
-       }
-
-       /**
-        * Handle a list of statement blocks. Specific type constraints have to 
be ensured
-        * within the individual rewrites. If a rewrite does not require 
sequence access, it
-        * should simply return the input list of statement blocks.
-        *
-        * @param sbs   list of statement blocks
-        * @param state program rewrite status
-        * @return list of statement blocks
-        */
-       @Override
-       public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> 
sbs, ProgramRewriteStatus state) {
-               return sbs;
-       }
-}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
index 0092a3a..906ed1f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedCostEstimatorTest.java
@@ -63,7 +63,7 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void simpleBinary() {
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
 
                /*
@@ -75,7 +75,7 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
                 * TOSTRING             1                               1       
                        800                     0.0625                          
80
                 * UnaryOp              1                               1       
                        8                       0.0625                          
0.8
                 */
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
                double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
 
                double expectedCost = computeCost + readCost;
@@ -84,9 +84,9 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void ifElseTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
                double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double expectedCost = ((computeCost + readCost + 0.8 + 0.0625 + 
0.0625) / 2) + 0.0625 + 0.8 + 0.0625;
                runTest("IfElseCostEstimatorTest.dml", false, expectedCost);
@@ -94,9 +94,9 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void whileTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
                double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double expectedCost = (computeCost + readCost + 0.0625) * 
fedCostEstimator.DEFAULT_ITERATION_NUMBER + 0.0625 + 0.8;
                runTest("WhileCostEstimatorTest.dml", false, expectedCost);
@@ -104,9 +104,9 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void forLoopTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
                double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 
0.0625 + 0.0625 + 0.8 + 0.0625;
                double expectedCost = (computeCost + readCost + predicateCost) 
* 5;
@@ -115,9 +115,9 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void parForLoopTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
                double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double predicateCost = 0.0625 + 0.8 + 0.0625 + 0.0625 + 0.8 + 
0.0625 + 0.0625 + 0.8 + 0.0625;
                double expectedCost = (computeCost + readCost + predicateCost) 
* 5;
@@ -126,9 +126,9 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void functionTest(){
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
-               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
+               double computeCost = (16+2*100+100+1+1) / 
(fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS 
*fedCostEstimator.WORKER_DEGREE_OF_PARALLELISM);
                double readCost = (2*64+1600+800+8) / 
(fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS);
                double expectedCost = (computeCost + readCost);
                runTest("FunctionCostEstimatorTest.dml", false, expectedCost);
@@ -136,7 +136,7 @@ public class FederatedCostEstimatorTest extends 
AutomatedTestBase {
 
        @Test
        public void federatedMultiply() {
-               fedCostEstimator.WORKER_COMPUTE_BANDWITH_FLOPS = 2;
+               fedCostEstimator.WORKER_COMPUTE_BANDWIDTH_FLOPS = 2;
                fedCostEstimator.WORKER_READ_BANDWIDTH_BYTES_PS = 10;
                fedCostEstimator.WORKER_NETWORK_BANDWIDTH_BYTES_PS = 5;
 

Reply via email to